The interesting thing here is that it's not a straightforward port. JAX is already very fast at running the architecture as it's implemented there. The real gain comes from heavily contracting the network, removing nodes that merely pass their input through, and then massively parallelizing what remains with bitwise operations that work on 64 bits at once. That is where this incredible speedup comes from.
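To make those two tricks concrete, here is a minimal sketch, not the port's actual code. It assumes a hypothetical network of two-input logic gates stored in topological order, where a gate called "PASS_A" just forwards its first input. Contraction rewires consumers of pass-through gates to the original source and drops the gates. Evaluation then packs up to 64 boolean samples into one Python int per input ("bit-slicing"), so each bitwise op computes that gate for all samples at once. The gate names, the (op, a, b) node layout and the helper functions are all illustrative assumptions.

```python
MASK = (1 << 64) - 1  # keep results within a 64-bit word

# Hypothetical gate set; "PASS_A" is the pass-through gate that contraction removes.
OPS = {
    "AND": lambda a, b: a & b,
    "OR": lambda a, b: a | b,
    "XOR": lambda a, b: a ^ b,
    "NAND": lambda a, b: ~(a & b) & MASK,  # mask so Python's unbounded ~ stays 64-bit
    "PASS_A": lambda a, b: a,
}

def contract(nodes, n_inputs):
    """Drop PASS_A nodes by rewiring their consumers to the forwarded source.

    Signal ids: 0..n_inputs-1 are inputs, n_inputs + k is node k.
    Nodes must be topologically ordered (they only reference earlier signals).
    Returns (new_nodes, remap), where remap[old_id] gives the new signal id.
    """
    remap = list(range(n_inputs))
    new_nodes = []
    for op, a, b in nodes:
        a, b = remap[a], remap[b]
        if op == "PASS_A":
            remap.append(a)  # alias this node to its source; emit nothing
        else:
            remap.append(n_inputs + len(new_nodes))
            new_nodes.append((op, a, b))
    return new_nodes, remap

def pack(samples, n_inputs):
    """Bit-slice up to 64 samples: bit i of words[j] is input j of sample i."""
    assert len(samples) <= 64
    words = [0] * n_inputs
    for i, sample in enumerate(samples):
        for j, bit in enumerate(sample):
            if bit:
                words[j] |= 1 << i
    return words

def evaluate(nodes, input_words):
    """Run every gate once; each op handles all packed samples in parallel."""
    signals = list(input_words)
    for op, a, b in nodes:
        signals.append(OPS[op](signals[a], signals[b]))
    return signals

# Two inputs x0, x1; outputs are signals 4 (x0 ^ x1) and 5 ((x0 ^ x1) & x0).
nodes = [("PASS_A", 0, 1), ("XOR", 2, 1), ("PASS_A", 3, 3), ("AND", 4, 0)]
outputs = [4, 5]
words = pack([(0, 0), (0, 1), (1, 0), (1, 1)], 2)  # all four input combinations

small, remap = contract(nodes, 2)
fast = evaluate(small, words)
print(len(small), [fast[remap[o]] for o in outputs])  # 2 [6, 4]
```

Reading the printed words bit by bit, 6 is 0b0110 (XOR is true for samples 1 and 2) and 4 is 0b0100 (only sample 2, x0=1 and x1=0). The contracted network gives the same outputs as the original with half the gates, and each gate costs a single integer operation no matter how many of the 64 slots are filled. In a compiled language the same layout maps onto native `uint64_t` words.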