Very cool project. If you haven't already, for JAX and PyTorch support take a look at the Python Array API Standard, https://data-apis.org/array-api/latest/, and see https://data-apis.org/array-api-compat/ for how to use it. If you have written, or can write, everything in terms of the subset of NumPy that the Array API Standard covers, you get support for alternative array libraries almost for free.
That looks like a valuable resource, thank you! I already mostly stuck to a subset supported by both NumPy and JAX (because those are the array libraries I'm familiar with). I hope the others are not too far off...