A while ago I also worked on recreating the original Deep Differentiable Logic Gate Networks paper [1], so I have a couple of things to add.

> I wanted to see if I could learn the wires in addition to the gates. I still think it’s possible, but it’s something I had to abandon to get the model to converge.

Actually, I read another paper where they also learned the wiring, but they did so by alternating between training the gates and the wires: in some iterations they learned the wiring while keeping the gates frozen, and in others they learned the gates while keeping the wiring frozen. The problem with this approach is that it doesn't scale. You need a lot of gates to approximate the behavior of even a simple MLP, and if each layer needs a full NxM learned matrix to encode its wiring (N input gates, M output gates), the memory needed to learn even something like MNIST quickly becomes huge. I think there are two fixes for this:

- You don't actually need to learn a full NxM matrix to increase the expressivity of the network. For each output gate, you can select a random subset of K candidate input gates, and then you only need a learned matrix of size KxM. I ran the numbers, and even a moderately small K, like 16 or 32, greatly increases the number of circuits you can learn with fewer layers and gates.

- You could use a LoRA-style low-rank factorization: instead of one NxM matrix, use a pair of matrices NxR and RxM, where R << N, M.
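To put rough numbers on the memory argument, here is a back-of-the-envelope count of the learned wiring parameters for one layer (and one gate input slot) under the full matrix and the two fixes. The layer size of 10,000 gates and K = R = 32 are hypothetical values picked only for illustration, not figures from any paper.

```python
def wiring_params(n_in: int, n_out: int, k: int, rank: int) -> dict:
    """Learned wiring parameters for ONE input slot of ONE layer.

    n_in is N (gates in the previous layer), n_out is M (gates in this layer).
    """
    return {
        # Full matrix: one logit for every (input gate, output gate) pair.
        "full": n_in * n_out,
        # Random subset: each output gate only scores k pre-sampled candidates.
        "subset": k * n_out,
        # Low-rank: logits = A @ B with A of shape (N, R) and B of shape (R, M).
        "low_rank": rank * (n_in + n_out),
    }


def as_megabytes(n_params: int, bytes_per_param: int = 4) -> float:
    """Memory for the parameters alone (float32 by default), ignoring
    gradients and optimizer state, which multiply this further."""
    return n_params * bytes_per_param / 1e6


if __name__ == "__main__":
    # Hypothetical layer sizes, chosen only for illustration.
    counts = wiring_params(n_in=10_000, n_out=10_000, k=32, rank=32)
    for name, n in counts.items():
        print(f"{name:>8}: {n:>12,} params  ~{as_megabytes(n):,.2f} MB")
```

With those values this prints about 400 MB for the full matrix versus roughly 1.28 MB and 2.56 MB for the two fixes, before counting gradients and optimizer state. One caveat: the low-rank version only shrinks the stored parameters; if each output takes a softmax over all N inputs, the NxM logits still have to be formed (at least in chunks) during the forward pass.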
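A minimal sketch of the first fix, in plain Python with no autodiff (forward pass only). The class name `SubsetWiring`, the softmax-over-candidates relaxation, and the uniform random choice of candidates are my own assumptions for illustration, not necessarily how the paper mentioned above or anyone else does it. Each output slot keeps K fixed random candidate indices plus K learned logits, so the only trainable wiring state is KxM instead of NxM.

```python
import math
import random


def softmax(logits):
    """Numerically stable softmax over a list of floats."""
    m = max(logits)
    exps = [math.exp(z - m) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]


class SubsetWiring:
    """Hypothetical learned wiring for one input slot of a layer of M gates.

    Each output slot chooses among k fixed, randomly sampled input gates.
    A 2-input gate would need two of these (one per input slot).
    """

    def __init__(self, n_in, n_out, k, seed=0):
        rng = random.Random(seed)
        # Fixed, NOT learned: k distinct candidate input indices per output slot.
        self.candidates = [rng.sample(range(n_in), k) for _ in range(n_out)]
        # Learned: one logit per candidate, i.e. a k x n_out table in total.
        self.logits = [[0.0] * k for _ in range(n_out)]

    def forward(self, x):
        """Map n_in soft truth values in [0, 1] to n_out soft truth values."""
        out = []
        for cand, logit in zip(self.candidates, self.logits):
            weights = softmax(logit)
            # Soft selection: a convex combination of the candidate inputs,
            # so the result stays in [0, 1].
            out.append(sum(w * x[i] for w, i in zip(weights, cand)))
        return out

    def harden(self):
        """Discretize after training: keep only the highest-scoring candidate."""
        return [cand[max(range(len(logit)), key=logit.__getitem__)]
                for cand, logit in zip(self.candidates, self.logits)]


if __name__ == "__main__":
    wiring = SubsetWiring(n_in=8, n_out=3, k=4, seed=1)
    x = [float(i % 2) for i in range(8)]
    # With all-zero logits every output is the mean of its 4 candidates.
    print(wiring.forward(x))
    # A dominant logit makes the soft pick approach a hard wire.
    wiring.logits[0] = [50.0, 0.0, 0.0, 0.0]
    print(wiring.forward(x)[0], x[wiring.harden()[0]])
```

During training the gradient would flow into `logits` only; at the end, `harden()` turns each soft choice back into an ordinary fixed wire, so the discrete circuit looks the same as one with random wiring.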

Learning the wiring has other benefits, too. Because each gate can learn to swap its inputs if needed, you can remove learnable gates that are "mirrors" or permutations of each other (a and not b vs. not a and b; a or not b vs. not a or b). This helps when scaling the networks to gates with more inputs (I tried 3-input and 4-input gates).
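To quantify the gate-pruning point, here is a small sketch that counts how many distinct n-input Boolean functions remain once functions that differ only by a permutation of their inputs are merged. It assumes the learned wiring can realize any permutation of a gate's inputs (and nothing more, e.g. no free negations), and it uses Burnside's lemma instead of enumerating every truth table.

```python
from itertools import permutations


def count_cycles(perm, n):
    """Number of cycles an input permutation induces on the 2**n truth-table rows."""
    seen = set()
    cycles = 0
    for row in range(2 ** n):
        if row in seen:
            continue
        cycles += 1
        r = row
        while r not in seen:
            seen.add(r)
            # Move bit i of the row to position perm[i].
            r = sum(((r >> i) & 1) << perm[i] for i in range(n))
    return cycles


def gates_up_to_input_swaps(n):
    """Distinct n-input Boolean functions when input order is free (Burnside).

    A function is unchanged by a permutation iff it is constant on each
    row cycle, so that permutation fixes 2**cycles functions; the number of
    equivalence classes is the average of that over all permutations.
    """
    perms = list(permutations(range(n)))
    total = sum(2 ** count_cycles(p, n) for p in perms)
    return total // len(perms)


if __name__ == "__main__":
    for n in range(1, 5):
        print(n, 2 ** (2 ** n), "->", gates_up_to_input_swaps(n))
```

For 2-input gates this gives 12 instead of 16: the four dropped functions are one member of each mirror pair (a and not b / not a and b, a or not b / not a or b, a / b, not a / not b). For 3-input gates it is 80 instead of 256, and for 4-input gates 3984 instead of 65536.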

Also, as the author pointed out, it was very difficult to get the models to converge. Frustratingly, I never managed to get a model that performed really well on MNIST. In the end, I gave up on that and focused on making the network reliably learn simple 3-input or 4-input functions with perfect accuracy, and I got it to do so consistently within a couple dozen iterations, which was nice.

[1] https://arxiv.org/abs/2210.08277

Very cool, thank you for sharing!