Ask HN: How do you scale transformer context lengths over multiple machines?
Context length is such an important aspect in today's AI race. All major players actively advertise it too. Given how matrix math works, how do people run inference for a transformer when the context is so long that it can't fit on one GPU or one machine?

Not sure how they do it specifically for LLMs, but you can use what is called model or tensor parallelism, where you split a layer over multiple GPUs or even nodes.
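To make the tensor-parallel idea concrete, here is a minimal NumPy sketch that simulates splitting one linear layer's weight matrix column-wise across `n_devices` GPUs. The "devices" are just list entries here, and `column_parallel_matmul` is a hypothetical name for illustration. In a real setup each shard would live on its own GPU and the final concatenation would be a collective all-gather (e.g. over NCCL or MPI).

```python
import numpy as np

def column_parallel_matmul(x, w, n_devices):
    """Simulate column-wise tensor parallelism for y = x @ w.

    w is split along its output dimension into n_devices shards. Each
    simulated device multiplies the full input by its own shard, and
    concatenating the partial outputs recovers the full result.
    """
    shards = np.array_split(w, n_devices, axis=1)  # one column block per device
    # On real hardware, each of these products would run on a separate GPU.
    partial_outputs = [x @ shard for shard in shards]
    # Stand-in for the all-gather that collects the partial outputs.
    return np.concatenate(partial_outputs, axis=1)

rng = np.random.default_rng(0)
x = rng.standard_normal((4, 8))   # 4 tokens, hidden size 8
w = rng.standard_normal((8, 12))  # a layer projecting 8 -> 12
y = column_parallel_matmul(x, w, n_devices=3)
print(np.allclose(y, x @ w))  # True
```

Each device only stores 1/n of the layer's weights, so per-GPU memory for that layer shrinks; the cost is the communication step after every sharded matmul.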
If you look under the hood, it's the same distributed matrix multiplication stuff with MPI, as far as I know. I think DeepSpeed has bespoke transformer kernels that handle this specifically.

I'm a few months behind on the latest, but one approach used to be summarizing the history once the context started getting huge: summarize the earliest n-k messages and keep the last k verbatim.

Quantizing it down to 8 bits seems to be one solution. TensorRT-LLM does this (and I think it requires an H100?). ExLlama also does this on much lesser hardware.

Wouldn't that mean still trying to fit it on one machine?

Indeed :P Honestly, I'm not sure how context "sharding" works across multiple GPUs atm. Decent, really long-context OSS models like Yi 200K and YaRN finetunes are very new.
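The history-summarizing approach (summarize the earliest n-k messages, keep the last k) can be sketched as below. The `summarize` callback is an assumption standing in for whatever actually produces the summary, typically another LLM call; `compact_history` and `fake_summarize` are hypothetical names for illustration.

```python
from typing import Callable, List

def compact_history(messages: List[str], k: int,
                    summarize: Callable[[List[str]], str]) -> List[str]:
    """Keep the last k messages verbatim and replace the earlier
    n-k messages with a single summary string."""
    if k < 0:
        raise ValueError("k must be non-negative")
    if len(messages) <= k:
        return list(messages)  # nothing old enough to summarize
    cutoff = len(messages) - k  # index of the first message kept verbatim
    older, recent = messages[:cutoff], messages[cutoff:]
    return [summarize(older)] + recent

def fake_summarize(msgs: List[str]) -> str:
    # Hypothetical stand-in; a real system would ask a model for a summary.
    return f"[summary of {len(msgs)} earlier messages]"

history = ["m1", "m2", "m3", "m4", "m5"]
print(compact_history(history, k=2, summarize=fake_summarize))
# ['[summary of 3 earlier messages]', 'm4', 'm5']
```

Computing `cutoff` explicitly avoids the `messages[-0:]` pitfall, which would return the whole list when k is 0.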
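As a rough illustration of what quantizing to 8 bits buys, here is a generic symmetric absmax int8 scheme with one scale per row, applied to a cached tensor such as one attention head's keys. This is an assumption for illustration only: the thread doesn't say whether the 8 bits apply to weights or the KV cache, and this is not the exact scheme TensorRT-LLM or ExLlama use. The tensor shape is made up.

```python
import numpy as np

def quantize_int8(x):
    """Symmetric per-row absmax quantization: each row gets one scale."""
    scale = np.abs(x).max(axis=-1, keepdims=True) / 127.0
    scale = np.where(scale == 0, 1.0, scale)  # avoid dividing by zero on all-zero rows
    q = np.clip(np.round(x / scale), -127, 127).astype(np.int8)
    return q, scale.astype(np.float32)

def dequantize_int8(q, scale):
    """Map int8 values back to approximate float32 values."""
    return q.astype(np.float32) * scale

rng = np.random.default_rng(0)
kv = rng.standard_normal((1024, 128)).astype(np.float16)  # hypothetical cached keys
q, scale = quantize_int8(kv.astype(np.float32))
approx = dequantize_int8(q, scale)

print(kv.nbytes, q.nbytes + scale.nbytes)  # 262144 135168
print(np.abs(approx - kv.astype(np.float32)).max())  # small per-element rounding error
```

Storage roughly halves versus fp16 (plus a small per-row scale overhead), which is why it helps, but as the reply points out, it still means squeezing everything onto one machine rather than sharding the context.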