The Pyrefly type checker is starting to support this kind of shape hinting - so far it only works with Torch, but I believe the plan is for it to work with other array packages (e.g. JAX, NumPy) as well.
See also jaxtyping, which, contrary to what its name might imply, covers JAX/PyTorch/NumPy/MLX/TensorFlow arrays and tensors.
https://docs.kidger.site/jaxtyping/
Shape functions and shape analysis are basically mundane infra in almost every ML compiler/language/DSL.
https://mlir.llvm.org/docs/Dialects/ShapeDialect/
I didn't know that, thanks for sharing. It makes sense, but then it also makes me wonder why none of the deep learning libraries (Torch, Jax/NNX, Eigen, etc.) make this information available. Instead, ML people all have their own schemes for tracking shape information, like commenting '# (b, n, t)' on every line, or suffixing shapes onto variable names - and in my experience it's a common source of bugs.
> Torch, Jax
Both of these "make it available". Just because people don't know how to use/find them doesn't mean they're not "available".
> Eigen
This is not an ML anything, it's a linear algebra library.
> like commenting '# (b, n, t)' on every line, or suffixing shapes to variable names
There's a difference between tracking shapes in the compiler and specifying shapes in the model.
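The '# (b, n, t)' comment scheme mentioned upthread can be turned into actual runtime checks with very little code. Here is a minimal stdlib-only sketch (the `check_shape` helper is hypothetical, not from any library): each axis name binds to a concrete size on first use, and later uses of the same name must match.

```python
def check_shape(shape, spec, bindings):
    """Match a concrete shape tuple against axis names like ("b", "n", "t").

    Each name binds to a size on first use; subsequent uses must agree,
    so inconsistent shapes fail loudly instead of silently broadcasting.
    """
    assert len(shape) == len(spec), f"rank mismatch: {shape} vs {spec}"
    for dim, name in zip(shape, spec):
        if name in bindings:
            assert bindings[name] == dim, (
                f"axis '{name}': expected {bindings[name]}, got {dim}"
            )
        else:
            bindings[name] = dim
    return bindings
```

Usage: keep one `bindings` dict per forward pass, e.g. `check_shape(x.shape, ("b", "n", "t"), bindings)` after each op, and a transposed or truncated tensor is caught at the line where it first appears rather than deep inside a matmul.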