Memory usage increases quadratically with sequence length, so using shorter sequences during fine-tuning keeps memory from blowing up. On my machine with 64GB of RAM, I'm limited to input sequences of about 2,000 tokens, given that the average output for my fine-tuning task is around 1,000 tokens (~3k tokens total).
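For a rough sense of where the quadratic term comes from: when attention is computed the standard way, each layer materializes a seq_len x seq_len score matrix per head. The sketch below is back-of-the-envelope only. The function name is made up, `num_heads=32` and 2 bytes per element (fp16/bf16) are assumed values rather than numbers from my setup, and it ignores weights, optimizer state, and every other activation.

```python
def attention_scores_bytes(seq_len, num_heads=32, batch_size=1, bytes_per_elem=2):
    """Bytes to hold one layer's full attention score matrix.

    Assumed values (not from the thread): 32 heads, 2 bytes per element.
    """
    return batch_size * num_heads * seq_len * seq_len * bytes_per_elem


# Doubling the sequence length quadruples this term.
for n in (1_000, 2_000, 3_000):
    gib = attention_scores_bytes(n) / 2**30
    print(f"{n} tokens: {gib:.3f} GiB per layer")
```

Multiply by the number of layers whose scores are kept for the backward pass to get a feel for why long sequences hurt so quickly.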
Ah, that makes sense; quadratic scaling is brutal. So with 96GB I'd probably get to around 4-5k total sequence length before hitting the wall, which is still pretty limiting for anything multimodal. Do you use gradient checkpointing, or is it not worth the speed tradeoff at these sizes?
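One way to sanity-check that kind of estimate, as a rough sketch: assume the sequence-dependent memory grows purely quadratically, and that some fixed amount (weights, optimizer state) does not scale with length. The function name and the 32GB fixed overhead are hypothetical, picked only for illustration.

```python
import math


def estimated_max_tokens(ref_tokens, ref_gb, new_gb, fixed_gb=0.0):
    """Extrapolate a max sequence length from one known data point.

    Assumes memory = fixed_gb + c * tokens**2, which is a simplification:
    real workloads also have terms that grow linearly with tokens.
    """
    usable_ref = ref_gb - fixed_gb
    usable_new = new_gb - fixed_gb
    if usable_ref <= 0 or usable_new <= 0:
        raise ValueError("fixed_gb must be smaller than both memory sizes")
    return ref_tokens * math.sqrt(usable_new / usable_ref)


# Known point from the thread: ~3k total tokens on 64GB.
print(round(estimated_max_tokens(3000, 64, 96)))               # no fixed overhead
print(round(estimated_max_tokens(3000, 64, 96, fixed_gb=32)))  # hypothetical 32GB fixed
```

Under these assumptions, the bigger the fixed share of the 64GB, the more the extra 32GB helps: with no fixed overhead the estimate lands around 3.7k tokens, and with a hypothetical 32GB of fixed overhead it is around 4.2k.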
Haven’t tried it yet. That’s on the to-do list. But good suggestion.
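Shouldn't FlashAttention address the quadratic growth in memory footprint for fine-tuning/training? It never materializes the full attention matrix, so attention memory grows linearly with sequence length (compute is still quadratic). I'm also pretty sure the quadratic memory growth does not apply to pure inference, because with KV caching each decoding step only attends from the newest token to the cached keys and values.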
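To illustrate the KV-caching point: the cache stores one key and one value vector per token, per layer, per KV head, so it grows linearly with sequence length. A minimal sketch with a hypothetical model shape (32 layers, 32 KV heads, head dimension 128, fp16); none of these numbers come from the thread.

```python
def kv_cache_bytes(seq_len, num_layers=32, num_kv_heads=32, head_dim=128,
                   batch_size=1, bytes_per_elem=2):
    """Bytes to cache keys and values for seq_len tokens.

    The leading 2 accounts for storing both K and V.
    All default shape values are hypothetical.
    """
    per_token = 2 * num_layers * num_kv_heads * head_dim * bytes_per_elem
    return batch_size * seq_len * per_token


# Doubling the sequence length doubles the cache (linear, not quadratic).
for n in (1_024, 2_048, 4_096):
    print(f"{n} tokens: {kv_cache_bytes(n) / 2**30:.2f} GiB")
```

This only covers the cache itself. Processing the initial prompt still runs attention over the whole prompt at once, so whether that step is memory-hungry depends on the attention kernel in use.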