I think https://jax-ml.github.io/scaling-book/ is one of the best references to go through. It details how single-device and distributed computations map to TPU hardware features. The emphasis is on mapping transformer computations, both the forward and backward passes, so it requires some familiarity with how transformer networks are structured.