The math on the input tokens is definitely wrong. It claims each instance (8 GPUs) can handle 1.44 million tokens/sec of input. Let's check that out.
1.44e6 tokens/sec * 37e9 bytes/token / 3.3e12 bytes/sec/GPU = ~16,000 GPUs
And that's assuming the more likely 1 byte per parameter (rather than 2 bytes for FP16/BF16).
So the article is only off by a factor of at least 1,000. I didn't check any of the rest of the math, but that probably has some impact on their conclusions...
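For anyone who wants to rerun it, here's the same back-of-envelope in Python (same inputs as above, nothing re-derived):

    # Sanity check of the arithmetic, assuming weights are re-read from HBM for every input token
    tokens_per_sec = 1.44e6    # claimed input throughput for one 8-GPU instance
    bytes_per_token = 37e9     # ~37B active parameters at 1 byte each
    hbm_bw_per_gpu = 3.3e12    # ~3.3 TB/s of HBM bandwidth per GPU

    gpus_needed = tokens_per_sec * bytes_per_token / hbm_bw_per_gpu
    print(round(gpus_needed))  # ~16145, i.e. roughly 16,000 GPUs vs. the 8 in one instance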
37 billion bytes per token?
Edit: Oh, assuming this is an estimate based on the model weights moving from HBM to SRAM, that's not how transformers are applied to input tokens. You only have to move the weights for every token during generation, not during "prefill". (And actually during generation you can use speculative decoding to do better than this roofline anyways).
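To put rough numbers on the prefill vs. decode difference (illustrative assumptions only, nothing from the post beyond the 37B active parameters):

    # Weight traffic per token: prefill amortizes one weight read across the whole prompt
    weight_bytes = 37e9         # ~37B active params at 1 byte each (assumption)
    prompt_tokens = 4096        # hypothetical prompt length

    prefill_bytes_per_token = weight_bytes / prompt_tokens  # ~9 MB of weight reads per prompt token
    decode_bytes_per_token = weight_bytes                   # 37 GB/token at batch size 1, no speculation

    print(prefill_bytes_per_token / 1e6, decode_bytes_per_token / 1e9)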
> (And actually during generation you can use speculative decoding to do better than this roofline anyways).
And more importantly batching: taking the example from the blog post, it would be 32 tokens per forward pass in the decoding phase.
There's also the question of how much the KV cache grows with each subsequent token, which would be roughly MBs per token. I think that would be the bottleneck.
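Rough orders of magnitude for both points, with made-up config numbers (a plain MHA layout; GQA or MLA would shrink the cache a lot):

    # Per-token KV cache for a vanilla MHA transformer: 2 (K and V) * layers * heads * head_dim * bytes
    n_layers, n_heads, head_dim, bytes_per_elem = 60, 128, 128, 2   # hypothetical config
    kv_bytes_per_token = 2 * n_layers * n_heads * head_dim * bytes_per_elem
    print(kv_bytes_per_token / 1e6)   # ~3.9 MB/token

    # Weight reads amortized over a decode batch of 32 sequences
    weight_bytes = 37e9
    print(weight_bytes / 32 / 1e9)    # ~1.2 GB of weight traffic per generated token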
Your calculations make no sense. Why are you loading the model for each token independently? You can process all the input tokens at the same time as long as they can fit in memory.
You are doing the calculation as if they were output tokens with a batch size of one, and it would not make sense even in the decode phase.
This. ChatGPT also agrees with you: "74 GB weight read is per pass, not per token." I was checking the math in this blog post with GPT to understand it better and it seems legit for the given assumptions.
Then the right calculation is to use FLOPs, not bandwidth like they did.
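Something like this, if prefill is compute-bound; the per-GPU FLOP/s number below is a placeholder, the point is the formula:

    # Compute-bound prefill estimate: tokens/sec ~ total FLOP/s / (2 * active params)
    active_params = 37e9
    flops_per_token = 2 * active_params      # standard ~2 FLOPs per active parameter per token
    instance_flops = 8 * 1e15                # hypothetical ~1 PFLOP/s per GPU at low precision

    print(instance_flops / flops_per_token)  # ~1.1e5 tokens/sec at 100% utilization with these assumptions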
> 37e9 bytes/token
This doesn't quite sound right...isn't a token just a few characters?
Well, he probably asked some AI to do the math for him.