I'm wondering how this might be summarized in simple terms? It sounds like, after processing some text, the entire prompt is included in the in-memory internal state of the program that's doing inference.
But it seems like it would need to remember the prompt to answer questions about it. How does this interact with the attention mechanism?
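To check my own understanding, here's a toy numpy sketch of what I imagine that in-memory state to be (my own naming, not any particular library's API): each prompt token leaves behind a key and a value vector, and attention for every later token just reads against those cached vectors.

```python
import numpy as np

d = 8                      # toy head dimension
rng = np.random.default_rng(0)

# Pretend the prompt has already been processed: each prompt token
# left behind a key and a value vector (a "KV cache").
prompt_keys = rng.normal(size=(5, d))    # 5 prompt tokens
prompt_values = rng.normal(size=(5, d))

def attend(query, keys, values):
    """Scaled dot-product attention for a single query vector."""
    scores = keys @ query / np.sqrt(d)   # one score per cached prompt token
    weights = np.exp(scores - scores.max())
    weights /= weights.sum()             # softmax over the prompt positions
    return weights @ values              # weighted mix of cached values

# A new token's query reads directly from the cached prompt state.
new_query = rng.normal(size=d)
out = attend(new_query, prompt_keys, prompt_values)
print(out.shape)  # (8,) -- attention output for the new token
```

So in that picture the prompt is "remembered" in the sense that its per-token keys and values stay resident in memory for attention to look up.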
I wonder why this is such a surprise; isn't this what you would naively expect given the way the residual stream is structured?
Each attention block adds to the residual stream. And we already know from logit-lens type work that the residual stream roughly remains in the same "basis" [1], which I vaguely remember is something that resnet architectures explicitly try to achieve.
So maybe it's my armchair naivety, but for both of these to hold while still letting the LLM do some sort of "abstraction", it seems natural for the initial token embedding to be projected into some high-dimensional subspace, and then, as it passes through the attention blocks, to have "deltas" added on top of it, filling out other parts of that subspace.
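A minimal numpy sketch of that picture (random matrices standing in for trained weights, so purely illustrative): the stream starts as the token embedding, each block only adds a delta to it, and a logit-lens-style readout multiplies the intermediate stream by the same unembedding at every layer.

```python
import numpy as np

rng = np.random.default_rng(0)
d_model, vocab, n_layers = 16, 50, 4

# Toy stand-ins for trained weights (assumption: random matrices here).
embed = rng.normal(size=(vocab, d_model))      # token embedding table
unembed = rng.normal(size=(d_model, vocab))    # shared readout for the logit lens
blocks = [rng.normal(size=(d_model, d_model)) * 0.1 for _ in range(n_layers)]

token_id = 7
stream = embed[token_id].copy()                # residual stream starts as the embedding

for i, W in enumerate(blocks):
    delta = np.tanh(stream @ W)                # whatever the block computes...
    stream = stream + delta                    # ...only gets *added* to the stream

    # Logit-lens-style readout: project the intermediate stream with the
    # same unembedding used at the final layer.
    logits = stream @ unembed
    print(f"layer {i}: top token under logit lens = {logits.argmax()}")
```

The point of the sketch is just that nothing ever overwrites the stream; the original embedding direction is still sitting in there under all the added deltas.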
And looking at the overall attention network from an information-passing perspective, keeping the input tokens encoded and accessible is certainly a nice thing to have.
Now maybe it could be argued that the original input could be lossily "converted" into some other, more abstract representation, but if the latent space is large enough not to force this, then there's probably no strict reason to do so. And in fact we do know that traditional BPE token embeddings don't even form a subspace (there's a fixed vocab size and the embeddings are just a lookup table, so they're just a bunch of scattered points).
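For the lookup-table point, a trivial sketch: the embedding layer is literally row indexing into a fixed matrix, so the only vectors it can ever emit are the vocab-size many rows themselves.

```python
import numpy as np

rng = np.random.default_rng(0)
vocab, d_model = 1000, 64
embedding_table = rng.normal(size=(vocab, d_model))  # fixed lookup table

token_ids = np.array([5, 42, 999])
embeddings = embedding_table[token_ids]   # pure indexing: only `vocab` possible outputs

# Every token embedding is one of exactly `vocab` fixed points in R^d_model;
# nothing in between is ever produced by the embedding layer itself.
print(embeddings.shape)  # (3, 64)
```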
I wonder whether, if this work were repeated with something like vision tokens, you would get the same results.
[1] https://nitter.poast.org/khoomeik/status/1920620258353610896