I thought this was pretty well known (at least in the JAX/XLA world). I've hit this many times and have had batch variance explained to me before: https://github.com/google-deepmind/penzai/issues/82 and https://github.com/jax-ml/jax/issues/20047#issuecomment-1975...

This should be the top comment.