The guts of an LLM aren't something I'm well versed in, but
> to get the first N tokens sorted, only when the big model and small model diverge do you infer on the big model
suggests there is something I'm unaware of. If you compare the small and big models, don't you have to wait for the big model anyway, and then what's the point? I assume I'm missing some detail here, but what?
Speculative decoding takes advantage of the fact that it's faster to validate that a big model would have produced a particular sequence of tokens than to generate that sequence from scratch, because validation can check many positions in parallel. So the process is: generate with the small model -> validate with the big model -> fall back to generating with the big model only where validation fails.
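A rough sketch of that loop in Python, where `draft_model` and `target_model` are hypothetical stand-ins (toy deterministic functions, not any real library's API):

```python
import random

random.seed(0)
VOCAB_SIZE = 100

def draft_model(context):
    # Hypothetical cheap model: a deterministic stand-in, not a real LLM.
    return (sum(context) * 31 + len(context)) % VOCAB_SIZE

def target_model(context):
    # Hypothetical expensive model: agrees with the draft ~80% of the time
    # so the example exercises both the "accept" and "diverge" paths.
    tok = draft_model(context)
    return tok if random.random() < 0.8 else (tok + 1) % VOCAB_SIZE

def speculative_decode(prompt, num_tokens, k=4):
    out = list(prompt)
    while len(out) - len(prompt) < num_tokens:
        # 1. Draft k tokens cheaply, one at a time, with the small model.
        draft = []
        for _ in range(k):
            draft.append(draft_model(out + draft))

        # 2. Verify the draft with the big model. In a real system this is a
        #    single parallel forward pass over all k positions; the loop here
        #    is only because our stand-in returns one token at a time.
        accepted = []
        for i in range(k):
            big_tok = target_model(out + accepted)
            if big_tok == draft[i]:
                accepted.append(big_tok)   # big model agrees: keep the drafted token
            else:
                accepted.append(big_tok)   # diverged: take the big model's token...
                break                      # ...and throw the rest of the draft away
        out.extend(accepted)
    return out[len(prompt):]

print(speculative_decode([1, 2, 3], num_tokens=12))
```

The payoff is that step 2 checks k drafted tokens for roughly the cost of one big-model step, whereas generating those same k tokens with the big model alone would take k sequential steps.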
More info:
* https://research.google/blog/looking-back-at-speculative-dec...
* https://pytorch.org/blog/hitchhikers-guide-speculative-decod...
See also speculative cascades, which is a nice read and furthered my understanding of how it all works:
https://research.google/blog/speculative-cascades-a-hybrid-a...
Verification is faster than generation: one forward pass can verify multiple tokens, versus one pass for every new token in generation.
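Concretely, something like this toy numpy sketch, where `big_model_logits` is a made-up stand-in for one batched forward pass of the big model:

```python
import numpy as np

VOCAB_SIZE = 50

def big_model_logits(token_ids):
    # Made-up stand-in for the big model: ONE forward pass returns logits for
    # every position of the input; the "preferred" next token after each prefix
    # is a simple deterministic function so the example is reproducible.
    logits = np.zeros((len(token_ids), VOCAB_SIZE))
    for j in range(len(token_ids)):
        preferred = (sum(token_ids[: j + 1]) * 7 + 3) % VOCAB_SIZE
        logits[j, preferred] = 1.0
    return logits

prompt = [3, 17, 42]
draft = [37, 46, 25, 30]  # tokens guessed by the small model (first two happen to agree)

# One pass over prompt + draft gives the big model's prediction at every position.
logits = big_model_logits(prompt + draft)

# The logits at position j predict the token at position j + 1, so the big model's
# greedy choices for the drafted positions start at index len(prompt) - 1.
preds = logits[len(prompt) - 1 : len(prompt) - 1 + len(draft)].argmax(axis=-1)

# Accept drafted tokens up to the first disagreement; everything after a mismatch
# was drafted from a context the big model rejects, so it gets thrown away.
n_accepted = 0
for drafted, predicted in zip(draft, preds):
    if drafted != int(predicted):
        break
    n_accepted += 1

print(f"accepted {n_accepted} of {len(draft)} drafted tokens in one big-model pass")
```

And on a mismatch the big model's own argmax at that position is a free corrected token, so each verification pass still makes progress even when the draft is wrong.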
I don't understand how it would work either, but it may be something similar to this: https://developers.openai.com/api/docs/guides/predicted-outp...
They are referring to a thing called "speculative decoding" I think.
When you predict with the small model, the big model can verify those predictions as a batch, closer in speed to how it processes input tokens, as long as the predictions are good and don't have to be redone.
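To put rough numbers on "if the predictions are good", here's a throwaway simulation of how many tokens each big-model verification pass commits at a given draft/target agreement rate (the rates and draft length are made-up parameters, not benchmarks):

```python
import random

def tokens_per_big_pass(match_rate, draft_len, trials=100_000, seed=0):
    """Average tokens committed per big-model pass, assuming each drafted token
    independently matches the big model with probability match_rate."""
    rng = random.Random(seed)
    total = 0
    for _ in range(trials):
        accepted = 0
        while accepted < draft_len and rng.random() < match_rate:
            accepted += 1
        # Each pass also yields one token from the big model itself: the correction
        # at the mismatch, or the next token after a fully accepted draft.
        total += accepted + 1
    return total / trials

for rate in (0.5, 0.8, 0.95):
    print(f"match rate {rate}: ~{tokens_per_big_pass(rate, draft_len=4):.2f} tokens per big-model pass")
```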