See also jaxtyping, which, despite its name, provides shape and dtype annotations (with optional runtime checking) for JAX, PyTorch, NumPy, MLX, and TensorFlow arrays and tensors.

https://docs.kidger.site/jaxtyping/