FlashAttention-T: Towards Tensorized Attention (dl.acm.org)
> Our key insight is to offload critical softmax primitives to idle tensor units, maximizing hardware utilization and throughput.
> … speedups of 1.05–1.17× across diverse attention configurations on Ampere and Hopper GPUs …
Oh wow, there's still work being done on Ampere?
I was wondering - I've been thinking about switching to AI systems programming (I know, easy task), but from what I understand, industry cloud GPUs are the main winners, right? Nobody's going to pay me (assuming I even had the skills) to optimize for consumer GPUs?
From what I understand, it's not just numbers + capacity + performance, it's literal core primitives. I don't think any of the "Blackwell" chips like the Grace one or the RTX 5090 have, for example, SM pairs in their ISA? And likewise there are similar fundamental differences between consumer and cloud Hopper (where the majority of the perf comes from the cloud one's ISA?)
So I guess I'm wondering if I should buy a GPU myself or just rent on the cloud if I want to start getting some experience in this field. How do you even get experience in this normally, anyway? Do you get into really good schools and into their AI labs, which have a lot of funding?
> Nobody's going to pay me (assuming I even had the skills) to optimize for consumer GPUs?
People will, but probably for less; not many people doing AI at the edge can pay the mega millions.
> And likewise there are similar fundamental differences between consumer and cloud Hopper (where the majority of the perf comes from the cloud one's ISA?)
I think Hopper was the version where they did a clean split and it’s only for datacenter
> So I guess I'm wondering if I should buy a GPU myself or just rent on the cloud if I want to start getting some experience in this field. How do you even get experience in this normally, anyway? Do you get into really good schools and into their AI labs, which have a lot of funding?
You can do performance work on any system you have, really; it's just that the details change depending on what you're targeting. You can definitely learn the basics on something like a 3060 by following blog posts.
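For a taste of what those posts build up to, here's a minimal sketch (assuming PyTorch with CUDA; the shapes, dtype, and iteration counts are arbitrary choices): time a naive attention that materializes the full score matrix against the fused kernel and see how far apart they are on your card.

    import torch
    import torch.nn.functional as F

    def naive_attention(q, k, v):
        # Materializes the full (seq, seq) score matrix -- exactly what FlashAttention avoids.
        scores = q @ k.transpose(-2, -1) / (q.shape[-1] ** 0.5)
        return torch.softmax(scores, dim=-1) @ v

    def time_ms(fn, *args, iters=50):
        start = torch.cuda.Event(enable_timing=True)
        end = torch.cuda.Event(enable_timing=True)
        for _ in range(5):          # warmup
            fn(*args)
        start.record()
        for _ in range(iters):
            fn(*args)
        end.record()
        torch.cuda.synchronize()
        return start.elapsed_time(end) / iters

    # batch=8, heads=16, seq=2048, head_dim=64: arbitrary but consumer-GPU sized
    q, k, v = (torch.randn(8, 16, 2048, 64, device="cuda", dtype=torch.float16) for _ in range(3))
    print("naive:", time_ms(naive_attention, q, k, v), "ms")
    print("fused:", time_ms(F.scaled_dot_product_attention, q, k, v), "ms")

From there the usual path is reimplementing the fused kernel yourself in CUDA or Triton and closing the gap step by step.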
Why does publishing papers require the latest and greatest GPUs? My understanding is that the paper talks about very general principles.
> So I guess I'm wondering if I should buy a GPU myself or just rent on the cloud if I want to start getting some experience in this field. How do you even get experience in this normally, anyway? Do you get into really good schools and into their AI labs, which have a lot of funding?
Unless you have money to throw around, you'd better start working on something, write some code, and get it running on a leased GPU before deciding on a long-term plan.
You should check out nanochat. I would personally appreciate it if someone implemented hardware optimized flash attention for my 3090
I do CUDA for a living (not inference) and for the life of me (and a couple of LLMs for that matter) I cannot figure out what you mean by "SM pairs".
Do you mean the coupled dies on stuff like the B200? If so, note that an NVIDIA chip die has many SMs.
Do you mean TMEM MMA cooperative execution? I'm guessing that must be it given what the paper is about.
https://hazyresearch.stanford.edu/blog/2025-03-15-tk-blackwe...
cooperative execution yeah
as you can tell I do not do CUDA for a living :D
I still have 2x NVLinked A6000 and they aren't that bad compared to a single RTX 6000 Pro.
Yep, https://github.com/poad42/cuda-fp8-ampere is another recent attempt at squeezing whatever's left out of Ampere.
Look at the email addresses. If you'll recall, there's an embargo on China.
OT, but instead of quadratic attention can we not have n^10 or something crazier? I feel like we are limiting the intelligence just to save cost. But I can imagine there might be some questions that are worth paying a higher cost for.
I feel like n^10 attention can capture patterns that lower complexity attention may not. So it seems arbitrary that we have n^2 attention.
You can find papers discussing "cubic" attention, i.e. each token gets to interact with each pair of other tokens, but always in very theoretical settings with single-layer transformers on contrived synthetic tasks.
Keep in mind that LLMs have many many layers, so they have plenty of opportunity to model higher-order interactions without needing to brute force every possible combination of 10 previous tokens, of which the vast majority will be useless. Empirically, even full "quadratic" attention is not always necessary, as evidenced by the existence of linear/sparse attention variants that perform almost as well.
What you're missing is that there's no need to do extra work in the kernel smoothing step (what attention essentially is) because all the fancy transformation work is already happening in learning the kernel.
The feedforward networks prior to the attention layer are effectively learning sophisticated kernels. For anyone unfamiliar: a kernel is just a generalization of the dot product, which is the most fundamental way of defining "similarity" between two points.
By learning a kernel, the transformer is learning the best way to define what "similar" means for the task at hand, and then we simply apply some basic smoothing over the data. This handles all sorts of interesting ways to compare points, and that comparison allows every point to provide a little bit of information.
Anything you could hope to achieve by performing more comparisons would be better solved by a better similarity function.
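To make that concrete, here's a toy numpy sketch (sizes arbitrary): the learned projections define the kernel, and the attention output is nothing more than a kernel-weighted average of the values.

    import numpy as np

    rng = np.random.default_rng(0)
    n, d = 6, 4                                   # toy sequence length and width
    X = rng.standard_normal((n, d))               # token representations
    Wq, Wk, Wv = (rng.standard_normal((d, d)) for _ in range(3))  # learned projections

    Q, K, V = X @ Wq, X @ Wk, X @ Wv
    scores = Q @ K.T / np.sqrt(d)                 # learned similarity: the "kernel"
    scores -= scores.max(axis=-1, keepdims=True)  # for numerical stability
    weights = np.exp(scores)
    weights /= weights.sum(axis=-1, keepdims=True)  # softmax: each row sums to 1
    out = weights @ V                             # each output is a weighted average of the values

All the expressive power is in how Wq and Wk shape that similarity; the final step is just smoothing.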
Aren't layers basically doing n^k attention? The attention block is n^2 because it allows one number per input/output pair. But nothing prevents you from stacking these on top of each other and getting k-th order "attentioness", with each layer encoding a different order.
Yes, and it works in theory.
Less so in practice. You saturate the memory of a B200 with a few dozen tokens on attentions higher than order 4. Training is even worse.
To paraphrase Knuth: high order polynomials are much more unimaginably large than mere infinity.
This is a common way of thinking. In practice this type of thing is more like optimizing flop allocation. Surely with an infinite compute and parameter budget you could have a better model with more intensive operations.
Another thing to consider is that transformers are very general computers. You can encode many, many more complex architectures in simpler, multi-layer transformers.
The vast majority of benefits that can be obtained from scaling a single layer inside a neural network can often be better accomplished by having more layers instead.
Here is an illustrative example: you can write higher-order polynomials as a recursive chain of first-order polynomials (Horner's method).
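In code, the idea looks like this (a tiny sketch, names arbitrary):

    def horner(coeffs, x):
        # coeffs = [a_k, ..., a_1, a_0] for a_k*x^k + ... + a_1*x + a_0
        acc = 0
        for c in coeffs:
            acc = acc * x + c   # one first-order step per coefficient
        return acc

    # 2x^3 - 3x + 5 at x = 2: 2*8 - 3*2 + 5 = 15
    assert horner([2, 0, -3, 5], 2) == 15

Each loop iteration is only first-order work, but chaining k of them gives you the degree-k polynomial, the same way stacking layers gives you higher-order interactions.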
Things like TreeConnect [0] scale better if each TreeConnect layer has a depth of two and you add more TreeConnect layers to compensate for the lack of expressivity, instead of choosing a higher depth.
Attention pairs every token against every other token. n^10 would mean pairing each token with nine other tokens. The primary benefit of doing this is that you can have a "function" that accepts the interactions of 10 tokens as input to produce a single output, but you already have that if you have a ten layer network. The interactions of two tokens can form a combined token that contains information of both tokens. The network can repeat this ten times to accumulate the desired information into a single super token and then make a decision based on all ten input tokens.
n^2 isn't a setting someone chose, it's a mathematical consequence of what attention is.
Here's what attention does: every token looks at every other token to decide what's relevant. If you have n tokens, and each one looks at n others, you get n * n = n^2 operations.
Put another way: n^2 is when every token gets to look at every other token. What would n^3 be? n^10?
(sibling comment has the same interpretation as you, then handwaves that transformers can emulate more complex systems)
There are lots of operations more complicated than comparing every token to every other token, & the complexity increases when you start comparing not just token pairs but token bigrams, trigrams, & so on. There is no obvious proof that all those comparisons would be equivalent to the standard attention mechanism of comparing every token to every other one.
While you are correct at a higher level, comparing bigrams/trigrams would be less compute, not more, because there are fewer of them in a given text.
I'm correct on the technical level as well: https://chatgpt.com/s/t_698293481e308191838b4131c1b605f1
That math is for comparing all n-grams for all n <= N simultaneously, which isn't what was being discussed.
For any fixed n-gram size, the complexity is still O(N^2), same as standard attention.
I was talking about all n-gram comparisons.
Thanks for clarifying. I was hoping to clear up the disconnect between you two; it looked like it hinged on "bigrams, trigrams, & so on." That reads idiomatically as enumerating fixed-n cases, and parsing "& so on" as "their simultaneous union" asks quite a bit of those two words. Either way, as ChatGPT showed you and you shared, all-n-gram comparison brings us to O(N^3), still several exponents short of the N^10 that started this thread.
This is getting tiresome. I can make the operations as complicated as necessary by comparing all possible permutations of the input string w/ every other permutation & that will not be reducible to standard attention comparisons. The n-gram was a simple example anyone should be able to understand. You can ask your favorite chatbot to compute the complexity for the permutation version.
No worries! I enjoyed it fwiw, appreciate your time :) (The permutation version would be factorial, fwiw, not polynomial. Different beast entirely.)
That skips an important part: the "deep" in "deep learning".
Attention already composes across layers.
After layer 1, you're not comparing raw tokens anymore. You're comparing tokens-informed-by-their-context. By layer 20, you're effectively comparing rich representations that encode phrases, relationships, and abstract patterns. The "higher-order" stuff emerges from depth. This is the whole point of deep networks, and attention.
TL;DR for the rest of the comment: people have tried shallow-and-wide instead of deep; it doesn't work in practice. (The rest of the comment fleshes out search/ChatGPT prompt terms to look into to understand more of the technical stuff here.)
A shallow network can approximate any function (universal approximation theorem), but it may need exponentially more neurons. Deep networks represent the same functions with way fewer parameters. There's formal work on "depth separation": functions that deep nets compute efficiently but shallow nets need exponential width to match.
Empirically, people have tried shallow-and-wide vs. deep-and-narrow many times, across many domains. Deep wins consistently for the same parameter budget. This is part of why "deep learning" took off: the depth is load-bearing.
For transformers specifically, stacking attention layers is crucial. A single attention layer, even with more heads or bigger dimensions, doesn't match what you get from depth. The representations genuinely get richer in ways that width alone can't replicate.
QM would tell us the order of your Hamiltonian (attention operator) doesn’t limit the complexity of the wave function (hidden state). It might be more efficient to explicitly correlate certain many-body interactions, but pair-wise interactions, depth and a basis (hidden state dimension) approaching completeness "are all you need”.
The terminology is overloaded. Tensors in QM are objects obeying transformation laws; in ML, tensors are just data arranged in multidimensional arrays. There are no constraints on how the data transforms.
Intended as an analogy, but it is essentially a description of the DMRG algorithm (quantum chem). Only pair-wise operators there, but the theory approaches exactness when there are enough terms in your tensor product (iterations ~ depth) and a large enough embedding dimension.
> There are no constraints on how the data transforms.
Except those implicit in your learned representation. And that representation could be the MB WF.
I built guided window attn (literally predict the position of the window) a while ago and that works great. Why are we still stuck on any form of attn that looks at the entire context in any meaningful way? Do humans work this way? Do I need a whole book to predict the next word? Who out there is working on really new unique ways to deal with infinite history, other than me of course :)
> Who out there is working on ... infinite history?
Many people are still working on improving RNNs, mostly in academia. Examples off the top of my head:
* RWKV: https://arxiv.org/abs/2305.13048 / https://arxiv.org/abs/2404.05892
* Linear attention: https://arxiv.org/abs/2006.16236 / https://arxiv.org/abs/2503.14456
* State space models: https://arxiv.org/abs/2312.00752 / https://arxiv.org/abs/2405.21060
* Linear RNNs: https://arxiv.org/abs/2410.01201
Industry OTOH has gone all-in on Transformers.
> Industry OTOH has gone all-in on Transformers.
It's so annoying. Transformers keep improving and recurrent networks are harder to train, so until we hit some real wall companies don't seem eager to diverge. It's like lithium batteries improving fast enough that it was never profitable to work on sodium ones, even though we unfortunately want the sodium ones to be better.
RNNs have two huge issues:

- Long context: recurrence degrades the signal, for the same reason that 'deep' NN architectures don't go much past 3-4 layers before you need residual connections and the like.
- (This is the big one) Training performance is terrible, since you can't parallelize them across a sequence like you can with causal masked attn in transformers.

On the huge benefit side, though, you get:

- A guaranteed state size, so perfect batch packing, perfect memory use, easy load/unload from a batch, and O(1) token gen, so generally massive performance gains in inference.
- Unlimited context (well, no need for a concept of a position embedding or similar system).

Taking the best of both worlds is definitely where the future is at. An architecture that can be trained in parallel, has a fixed state size so you can load/unload and pack batches perfectly, unlimited context (with perfect recall), etc. etc. That is the real architecture to go for.
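A toy way to see the inference-side difference (numpy, shapes arbitrary, not any particular architecture): the RNN step only ever touches a fixed-size state, while the attention step has to touch a cache that grows with every token.

    import numpy as np

    d = 16
    rng = np.random.default_rng(0)
    Wh, Wx = rng.standard_normal((2, d, d)) * 0.1

    def rnn_step(state, x):
        # Memory and compute per token are O(1) in sequence length.
        return np.tanh(state @ Wh + x @ Wx)

    def attn_step(kv_cache, q, k, v):
        # The cache grows by one (k, v) pair per token: O(t) memory and compute.
        kv_cache.append((k, v))
        K = np.stack([kk for kk, _ in kv_cache])
        V = np.stack([vv for _, vv in kv_cache])
        scores = q @ K.T / np.sqrt(d)
        w = np.exp(scores - scores.max())
        w /= w.sum()
        return w @ V, kv_cache

Batch packing falls out of the first one for free: every sequence's state is the same size no matter how long its history is.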
RNN training cannot be parallelized along the sequence dimension like attention can, but it can still be trained in batches on multiple sequences simultaneously. Given the sizes of modern training sets and the limits on context size for transformer-based models, it's not clear to what extent this is an important limitation nowadays. It may have been more relevant in the early days of attention-based models where being able to do experimental training runs quickly on relatively small sizes of training data may have been important.
Not quite, most of the recent work on modern RNNs has been addressing this exact limitation. For instance linear attention yields formulations that can be equivalently interpreted either as a parallel operation or a recursive one. The consequence is that these parallelizable versions of RNNs are often "less expressive per param" than their old-school non-parallelizable RNN counterparts, though you could argue that they make up for that in practice by being more powerful per unit of training compute via much better training efficiency.
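A toy numpy sketch of that duality (unnormalized causal linear attention, sizes arbitrary): the parallel form is what you run at training time, the recurrent form is what you run at inference time, and they give the same outputs.

    import numpy as np

    rng = np.random.default_rng(0)
    T, d = 8, 4
    Q, K, V = rng.standard_normal((3, T, d))

    # Parallel form: O(T^2) work, but trivially parallel across the sequence.
    causal_mask = np.tril(np.ones((T, T)))
    out_parallel = (causal_mask * (Q @ K.T)) @ V

    # Recurrent form: O(T) work with a fixed d x d state, one token at a time.
    S = np.zeros((d, d))
    out_recurrent = np.zeros((T, d))
    for t in range(T):
        S += np.outer(K[t], V[t])        # fold this token's (k, v) into the state
        out_recurrent[t] = Q[t] @ S      # read out against the running state

    assert np.allclose(out_parallel, out_recurrent)

Real models add feature maps, normalization, and decay/gating on the state, but the training-parallel vs. inference-recurrent trade-off has the same shape.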
To get similar tokens/sec in training, though, you would need to swap batch size and seq length so you could have the massive batch size, but then won't you start hitting memory issues with any reasonable sequence length? You would have to do something similar to a minibatch along the sequence and cut the gradients after a short number of tokens on each sequence. So how will they learn truly long sequences for recall? Or is there a different trick I am missing here?
Linear RNNs overcome both issues. All the RNNs I mentioned are linear RNNs.
I'll give them all a look. Thanks!
> Who out there is working on really new unique ways to deal with infinite history, other than me of course :)
I'm working on a novel (I think) linear attention mechanism in my personal lab that's O(L) for effectively infinite context. I haven't yet decided how much of it is going to be open source, but I agree with you that it's important to figure this out.
Was your work open? Is there some place I can read more about it? I'm trying to figure out what to do with my thing on the off-chance that it actually does turn out to work the way I want it to.
I'm trying to figure the same thing out for my stuff. I figured out a simple way to train location prediction, so I'm using it for guided window prediction, which is great for attn (predict a distance in the past to look at) and for memory (predict an x, y location for a 2D window into a memory store that will be helpful to look at). I suspect there are a lot of people out there that have found that one weird trick but haven't released it because they don't know how to capitalize on the idea. Why give OpenAI and others the keys to the future for free?
how does this compare to MoSA (arXiv:2505.00315)? do you require that there's a single contiguous window? and do you literally predict on position, or with a computed feature?
I predict a specific location, then put a window around that. Of course you can predict a different location per head, or multiple window locations per head as well. The cost is negligible (a single linear of emb x 1 size), so attn becomes a fixed cost per token just like traditional windowed attn. Of course this doesn't solve memory consumption, because you still have a KV cache, unless you only do attn over the initial embeddings, at which point you don't need the cache, just the token history. This is the tack I'm taking now, since I have other ways of providing long context at deeper layers that remain O(1) for token prediction and are parallelizable like standard attn.

I think this kind of architecture is the future: infinite context, fixed-size state, O(1) prediction, and externalized memory are all possible, and they break the current context, memory, and compute problems. It is clear that in the future token caching will be dead once these types of models (mine or someone else's with the same properties) are properly tuned and well trained.
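Roughly, a single step has this shape (a stripped-down sketch, not my actual code; the window size, the prediction head, and how the location prediction gets trained are all placeholders here):

    import numpy as np

    def guided_window_step(q_t, h_t, K_hist, V_hist, w_dist, window=64):
        # A tiny head predicts how far back to look; attention then runs only
        # over a fixed-size window centred there, so per-token cost is constant.
        t, d = K_hist.shape
        dist = int(abs(h_t @ w_dist)) % t                  # placeholder distance predictor
        center = t - 1 - dist
        lo, hi = max(0, center - window // 2), min(t, center + window // 2 + 1)
        K, V = K_hist[lo:hi], V_hist[lo:hi]
        scores = q_t @ K.T / np.sqrt(d)
        w = np.exp(scores - scores.max())
        w /= w.sum()
        return w @ V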
Tri Dao isn't on the paper; is it even allowed to call it "FlashAttention"???
Less annoying link directly to the paper: https://dl.acm.org/doi/pdf/10.1145/3774934.3786425?download=...
link if you don't want to automatically download files
TL;DR: 5%-17% speedup from removing a bottleneck by juggling where on a GPU/compute core a computation is done during FlashAttention.