Ah, I always forget that there are intermediate ops in ML that aren't just matrix multiplies.

A single Python interpreter stack frame that dispatches into a 10^4 * 10^4 GEMM BLAS kernel written in C is not a bottleneck, but paying for 10^8 Python interpreter stack frames to do a pointwise broadcast addition element by element would be.
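
Rough sketch of what I mean (just an illustration, not a benchmark -- timings depend on the machine, the loop only touches a small slice because the full 10^8 iterations would take far too long, and the two 10^4 * 10^4 float32 tensors are ~800 MB):

```python
import time
import torch

n = 10_000
a = torch.rand(n, n)
b = torch.rand(n, n)

# One Python call: the whole 10^4 x 10^4 GEMM runs inside a single C/BLAS kernel.
t0 = time.perf_counter()
c = a @ b
print("one dispatch, full GEMM:", time.perf_counter() - t0, "s")

# Python-level pointwise loop: every iteration pays interpreter + dispatch overhead.
# Only 10^5 of the 10^8 elements here; extrapolate to see the problem.
x, y = a.ravel(), b.ravel()
out = torch.empty_like(x)
t0 = time.perf_counter()
for i in range(100_000):
    out[i] = x[i] + y[i]
print("10^5 interpreted adds:  ", time.perf_counter() - t0, "s")
```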

Doesn't PyTorch overload the common broadcast operations too, though? I was under the impression that it dispatched those to single vectorized kernels as well, just like matmul. Or is the remaining per-op overhead what `torch.compile` attempts to solve?
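
Something like this is what I had in mind (assuming PyTorch 2.x for `torch.compile`; the `fused` function is just my own toy example):

```python
import torch

a = torch.rand(10_000, 10_000)
b = torch.rand(10_000, 10_000)

# `+` is overloaded (Tensor.__add__), so this is one Python call that launches
# a single vectorized kernel over all 10^8 elements -- no per-element frames.
c = a + b

# What eager mode still pays is one dispatch per op; my understanding is that
# torch.compile traces the function and can fuse the pointwise chain into
# fewer kernels, cutting that per-op overhead.
def fused(x, y):
    return torch.relu(x + y) * 2.0

compiled = torch.compile(fused)
out = compiled(a, b)
```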