from time import time
import warnings

import torch
import transformers


def sdpa_attention_forward(
    module: torch.nn.Module,
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    attention_mask: torch.Tensor | None,
    scaling: float,
    dropout: float | None = 0.,
    **_
) -> tuple[torch.Tensor, None]:
    """
    Grouped query attention with sinks using `torch`-native scaled dot product attention.

    Every head gets one extra "sink" key/value slot. The slot's key and value are all
    zeros, so its logit is exactly the additive bias placed in the mask's extra column
    (`module.sinks`); it therefore only enlarges the softmax denominator and adds
    nothing to the output.

    Let
    - N be the batch size,
    - H be the total number of attention heads,
    - G be the number of key/value groups (H must be a multiple of G),
    - L be the query length,
    - S be the key/value length (equal to L without a key/value cache), and
    - E be the per-head embedding dimensionality.

    Parameters
    ----------
    module: torch.nn.Module
        Attention module with `sinks` `torch.Tensor` attribute holding H values
    query: torch.Tensor
        Query tensor of shape N x H x L x E
    key: torch.Tensor
        Key tensor of shape N x G x S x E
    value: torch.Tensor
        Value tensor of shape N x G x S x E
    attention_mask: torch.Tensor | None
        Mask broadcastable to N x H x L x S. A floating-point mask is added to the
        logits; a boolean mask marks allowed positions with `True`. Key positions
        beyond S are dropped. If `None`, a causal mask is used in which the last
        query position is aligned with the last key position.
    scaling: float
        Scaling factor applied to query-key dot products, typically 1 / sqrt(E)
    dropout: float | None = 0.
        Dropout probability; `None` is treated as 0
    **_
        Additional keyword arguments are accepted and ignored

    Returns
    -------
    torch.Tensor
        Attention output tensor of shape N x L x H x E
    None
        Unused attention weights

    Raises
    ------
    ValueError
        If H is not a multiple of G
    """
    N, H, L, _E = query.shape
    G, S = key.shape[1], key.shape[2]
    if G == 0 or H % G:
        raise ValueError(
            f'number of query heads ({H}) must be a multiple of the number of '
            f'key/value groups ({G})'
        )
    repeats = H // G

    if attention_mask is None:
        # `is_causal=True` cannot be combined with the explicit mask that carries the
        # sink column, so the causal pattern is materialised here instead.
        attention_mask = torch.ones(
            L, S, dtype=torch.bool, device=query.device
        ).tril(diagonal=S - L)
    else:
        # Drop mask columns that have no matching key position.
        attention_mask = attention_mask[..., :S]

    if attention_mask.dtype == torch.bool:
        # Concatenating a boolean mask with the float sinks would promote it to 0/1
        # values; turn it into an additive mask first.
        bias = torch.zeros_like(attention_mask, dtype=query.dtype).masked_fill(
            ~attention_mask, float('-inf')
        )
    else:
        bias = attention_mask.to(query.dtype)

    # The extra last column holds the per-head sink logit. `torch.cat` copies its
    # inputs, so the expanded views do not have to be cloned beforehand.
    bias = torch.cat(
        [
            bias.expand(N, H, L, S),
            module.sinks.to(query.dtype).reshape(1, H, 1, 1).expand(N, H, L, 1)
        ],
        dim=3
    )

    # Repeat each key/value group for its query heads and append the zero sink slot.
    key, value = (
        torch.cat(
            [
                tensor.repeat_interleave(repeats, dim=1) if repeats > 1 else tensor,
                tensor.new_zeros(N, H, 1, tensor.shape[-1])
            ],
            dim=2
        )
        for tensor in (key, value)
    )

    attention = torch.nn.functional.scaled_dot_product_attention(
        query,
        key,
        value,
        bias,
        dropout_p=0. if dropout is None else dropout,
        scale=scaling
    )

    return attention.transpose(1, 2).contiguous(), None


# Silence `UserWarning`s attributed to the `torch.cuda` module.
warnings.filterwarnings('ignore', category=UserWarning, module='torch.cuda')

# Enable the flash and memory-efficient scaled dot product attention backends.
torch.backends.cuda.enable_flash_sdp(True)
torch.backends.cuda.enable_mem_efficient_sdp(True)

# Make the sink-aware implementation above the one registered under 'sdpa'.
transformers.AttentionInterface.register('sdpa', sdpa_attention_forward)
# Build the masks for the 'sdpa' implementation with the eager mask function.
# NOTE(review): presumably this yields a materialised additive mask that the custom
# attention function can concatenate the sink column to - confirm against the
# installed `transformers` version.
transformers.AttentionMaskInterface.register(
    'sdpa',
    transformers.masking_utils.eager_mask,
)

# Mark the GPT-OSS model classes as supporting 'sdpa'.
# NOTE(review): presumably required for `attn_implementation='sdpa'` to be accepted
# when loading the model - confirm.
setattr(
    transformers.models.gpt_oss.modeling_gpt_oss.GptOssPreTrainedModel,
    '_supports_sdpa',
    True,
)

# Load the checkpoint in bfloat16 with the MXFP4 weights dequantized, using the
# attention implementation registered under 'sdpa'.
model = transformers.AutoModelForCausalLM.from_pretrained(
    'openai/gpt-oss-20b',
    quantization_config=transformers.Mxfp4Config(dequantize=True),
    dtype=torch.bfloat16,
    attn_implementation='sdpa',
)