speeding up llm rl training by 7.5x on long-prompt, short-response tasks


cf May 11, 2026 9 min read

tl;dr

  • training llms with reinforcement learning slows down as prompts get longer
  • for workloads with long prompts & short completions, computation is dominated by repeated processing of the same prompts
  • by “caching” prompts across completions, we were able to reduce end-to-end step time by up to 7.5x in our measured runs. smaller gains apply as the prompt-to-completion ratio shrinks.

context

for a deeper intro to grpo, the technique behind most open source llm rl, check out this post.

at a high level, llm rl works by taking a prompt, sampling a group of G responses, computing a reward for each of those responses & then training the model to upweight samples with higher rewards and downweight samples with lower rewards.

a key problem with most open source rl engines is that they pack sequences this way for training: each sequence in the batch is the full prompt concatenated with one response.

note that the same prompt is repeated G times. when prompts are long and responses are short, this repetition results in wasted computation.

take a case where prompts are ~1000 tokens and responses are ~100 tokens. for G = 8, this naive packing would process (1000 + 100) * 8 = 8800 tokens. but only 1000 + (100 * 8) = 1800 of those tokens are unique - the remaining 7000 are repeated copies of the prompt, i.e. wasted compute. the naive layout processes nearly 5x as many tokens as it needs to.
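the same counting works for any prompt length, response length and group size. a minimal sketch (the function name is ours); it counts tokens only, which is a rough proxy for compute:

```python
def token_counts(prompt_len: int, resp_len: int, group_size: int) -> tuple[int, int, float]:
    """Tokens processed by the naive layout vs. a layout that stores the prompt once."""
    # naive: every one of the G sequences repeats the full prompt
    naive = (prompt_len + resp_len) * group_size
    # shared-prompt packing: one prompt followed by G responses
    packed = prompt_len + resp_len * group_size
    return naive, packed, naive / packed


naive, packed, ratio = token_counts(prompt_len=1000, resp_len=100, group_size=8)
print(naive, packed, naive - packed, round(ratio, 2))  # 8800 1800 7000 4.89
```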

train-time prompt caching

instead of repeating the prompt once per response (G times in total), we can pack the micro-batch this way: put the prompt once at the front, followed by all G responses back-to-back.
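a minimal plain-python sketch of building one packed group from token ids. the helper and its outputs are hypothetical: segment ids mark which response each token belongs to, and we assume rotary-style position ids should continue from prompt_len for every response, so each response sees the same positions it would in the naive layout (the post doesn't say how positions are handled).

```python
def pack_group(prompt: list[int], responses: list[list[int]]):
    """Pack [prompt | resp1 | ... | respG]; return ids, position ids, segment ids."""
    prompt_len = len(prompt)
    input_ids = list(prompt)
    # the prompt keeps positions 0 .. prompt_len - 1
    position_ids = list(range(prompt_len))
    # segment 0 marks the prompt; segment g (1-based) marks response g
    segment_ids = [0] * prompt_len
    for g, resp in enumerate(responses, start=1):
        input_ids.extend(resp)
        # each response continues right after the prompt, exactly as it
        # would in the naive [prompt + resp_g] layout
        position_ids.extend(range(prompt_len, prompt_len + len(resp)))
        segment_ids.extend([g] * len(resp))
    return input_ids, position_ids, segment_ids


ids, pos, seg = pack_group([101, 102, 103], [[7, 8], [9], [5, 6]])
# ids -> [101, 102, 103, 7, 8, 9, 5, 6]
# pos -> [0, 1, 2, 3, 4, 3, 3, 4]
# seg -> [0, 0, 0, 1, 1, 2, 3, 3]
```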

if this sounds familiar, it should: inference systems already do this, e.g. vLLM’s prefix caching and anthropic’s prompt caching. the idea is to compute the prompt’s KV cache once and reuse it across many responses. the key difference is that inference caching only needs the forward pass. training needs gradients to flow back through the prompt so the model can actually learn from it. for full attention layers this means carefully reconstructing the right attention patterns so that every response sees exactly what it would in the naive layout, and the gradients match; for linear attention layers the recurrent state lets us do something closer to true reuse, as we’ll show below.

the prompt is materialized only once in the batch, so each step processes a fraction of the tokens it did before, which should save significant computation.

while this is conceptually simple, the model's layers need work to support this packing. it's trivial for feedforward layers, which process each token independently. attention layers are harder: each token depends on the tokens before it, and in the new packing the tokens before a response are no longer just its prompt. a standard causal mask would let resp2 attend to resp1, which never happens in the naive layout. this breaks the way most open source attention layers are implemented. here’s how we went about fixing them.

full attention layers

how does attention work?

when a model generates each token, it needs to decide which previous tokens are most relevant to what comes next. attention is the mechanism that does this.

each token produces three vectors: a query (q): “what am i looking for?”, a key (k): “what do i represent?”, and a value (v): “what info do i carry?”. the model computes a similarity score between each query and all available keys, then uses those scores to take a weighted sum of the values. high-scoring tokens contribute more to the output.

causal attention adds a constraint: a token can only attend to tokens that came before it (or itself), not future ones. this is how llm generation works - each new token only sees what’s already been generated.
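a minimal numpy sketch of single-head causal attention as described above: query/key similarity scores, a causal mask, a softmax, then a weighted sum of values. the function name and the standard 1/sqrt(d) score scaling are our additions; real models use many heads and fused gpu kernels.

```python
import numpy as np


def causal_attention(q: np.ndarray, k: np.ndarray, v: np.ndarray) -> np.ndarray:
    """Single-head causal softmax attention. q, k: (T, d); v: (T, d_v)."""
    T, d = q.shape
    # similarity of every query with every key, scaled by sqrt(d)
    scores = q @ k.T / np.sqrt(d)
    # causal constraint: query t may only see keys 0..t
    future = np.triu(np.ones((T, T), dtype=bool), k=1)
    scores = np.where(future, -np.inf, scores)
    # softmax over keys (subtract the row max for numerical stability)
    weights = np.exp(scores - scores.max(axis=1, keepdims=True))
    weights /= weights.sum(axis=1, keepdims=True)
    # weighted sum of values
    return weights @ v


rng = np.random.default_rng(0)
q, k, v = (rng.standard_normal((5, 4)) for _ in range(3))
out = causal_attention(q, k, v)
# the first token can only attend to itself, so its output is its own value
assert np.allclose(out[0], v[0])
```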

the packed layout [prompt | resp1 | resp2 | ... | respG] breaks the standard causal attention assumption: each response should attend to the prompt and to itself, but not to any other response. we solve this with two flash attention passes.
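to make the target pattern concrete, here is the boolean mask the two passes have to reproduce, built from per-token segment ids (0 for the prompt, g for response g). a numpy sketch with hypothetical names; the two-pass approach never materializes this mask, it's only a reference.

```python
import numpy as np


def packed_attention_mask(prompt_len: int, resp_lens: list[int]) -> np.ndarray:
    """mask[i, j] is True when packed token i may attend to packed token j."""
    # segment 0 = prompt, segment g = response g (1-based)
    seg = np.concatenate(
        [np.zeros(prompt_len, dtype=int)]
        + [np.full(n, g, dtype=int) for g, n in enumerate(resp_lens, start=1)]
    )
    idx = np.arange(len(seg))
    causal = idx[:, None] >= idx[None, :]
    same_segment = seg[:, None] == seg[None, :]
    key_in_prompt = (seg == 0)[None, :]
    # every token attends causally within its own segment, and response
    # tokens additionally see the whole prompt; no response sees another
    return causal & (same_segment | key_in_prompt)


print(packed_attention_mask(2, [2, 1]).astype(int))
# [[1 0 0 0 0]
#  [1 1 0 0 0]
#  [1 1 1 0 0]
#  [1 1 1 1 0]
#  [1 1 0 0 1]]
```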

pass 1: prompt self-attention

slice out the prompt tokens and run standard causal self-attention on them. this produces the prompt’s output hidden states, and importantly, the key/value tensors we’ll reuse in pass 2.

pass 2: response attention over [prompt + self]

to express this in a single flash attention call, we build the key/value sequences by placing a copy of k_prompt / v_prompt in front of each response’s own keys/values:

k_full = [k_prompt, k_resp1, k_prompt, k_resp2, k_prompt, k_resp3, ...]

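we then pass cu_seqlens_q and cu_seqlens_k to tell flash attention where each independent sequence starts and ends on the query and key sides. each key sequence is longer than its query sequence by prompt_len, and in that case flash attention (v2.1+) aligns the causal mask to the bottom-right corner, so with causal=True each response token can see all prompt tokens plus its own prior tokens, and nothing else.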
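for a toy group with prompt_len = 4 and responses of 2, 3 and 1 tokens, the offsets work out as below. a plain-python sketch (the helper name is hypothetical); flash attention takes these as int32 tensors on the gpu.

```python
from itertools import accumulate


def pass2_cu_seqlens(prompt_len: int, resp_lens: list[int]) -> tuple[list[int], list[int]]:
    # queries: only the response tokens, back to back
    cu_q = [0, *accumulate(resp_lens)]
    # keys/values: each response is preceded by its own copy of the prompt
    cu_k = [0, *accumulate(prompt_len + n for n in resp_lens)]
    return cu_q, cu_k


print(pass2_cu_seqlens(prompt_len=4, resp_lens=[2, 3, 1]))
# ([0, 2, 5, 6], [0, 6, 13, 18])
```

note the key side: because every response carries its own copy of the prompt's keys and values, k_full and v_full hold G copies of them. what's saved is the per-token work on the prompt (projections, mlps, and the prompt's own self-attention), which now runs once.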

here’s the pseudo-code implementation with flash-attention.

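import torch
from flash_attn import flash_attn_varlen_func

# layout: [prompt | resp1 | resp2 | ... | respG]
# q, k, v: (total_tokens, n_heads, head_dim); resp_lens: list of the G
# response lengths as python ints (a hypothetical name)
G = len(resp_lens)
q_prompt, k_prompt, v_prompt = (
    q[:prompt_len], k[:prompt_len], v[:prompt_len]
)
q_responses, k_responses, v_responses = (
    q[prompt_len:], k[prompt_len:], v[prompt_len:]
)

# Pass 1: causal self-attention on the prompt slice alone (a single sequence)
cu_seqlens_prompt = torch.tensor(
    [0, prompt_len], dtype=torch.int32, device=q.device
)
prompt_out = flash_attn_varlen_func(
    q_prompt, k_prompt, v_prompt,
    cu_seqlens_prompt, cu_seqlens_prompt, prompt_len, prompt_len,
    causal=True,
)

# Pass 2: each response attends to [prompt; response_i]
# split the flat response tensors into one chunk per response
k_chunks = torch.split(k_responses, resp_lens, dim=0)
v_chunks = torch.split(v_responses, resp_lens, dim=0)
# interleave k_prompt before each response chunk:
# [k_prompt, k_resp1, k_prompt, k_resp2, ...]
# this way flash_attn sees one contiguous key sequence per response
k_full = torch.cat(
    [chunk for i in range(G) for chunk in (k_prompt, k_chunks[i])], dim=0
)
v_full = torch.cat(
    [chunk for i in range(G) for chunk in (v_prompt, v_chunks[i])], dim=0
)

# sequence boundaries: response i has resp_lens[i] queries and
# prompt_len + resp_lens[i] keys
lens = torch.tensor(resp_lens, dtype=torch.int32, device=q.device)
zero = torch.zeros(1, dtype=torch.int32, device=q.device)
cu_seqlens_q = torch.cat([zero, lens.cumsum(0, dtype=torch.int32)])
cu_seqlens_k = torch.cat([zero, (lens + prompt_len).cumsum(0, dtype=torch.int32)])

# each response queries into a key sequence of [prompt + that response].
# since seqlen_k > seqlen_q, flash-attn (>= 2.1) aligns the causal mask to the
# bottom-right, so each response token sees all prompt tokens + prior response tokens
responses_out = flash_attn_varlen_func(
    q_responses, k_full, v_full,
    cu_seqlens_q, cu_seqlens_k,
    max(resp_lens), prompt_len + max(resp_lens),
    causal=True,
)

# Reconstruct packed layout: [prompt | resp1 | resp2 | ...]
final_out = torch.cat([prompt_out, responses_out], dim=0)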
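the two-pass trick can be sanity-checked without a gpu. the numpy sketch below is a stand-in, not flash attention: attend mimics one sequence of a varlen call, including the bottom-right causal alignment when query and key lengths differ, and two_pass loops over responses instead of building k_full. all names are ours.

```python
import numpy as np


def attend(q, k, v):
    """Softmax attention for one sequence with a bottom-right-aligned causal mask.

    q: (Tq, d); k, v: (Tk, d) with Tq <= Tk. query i may see keys
    0 .. i + (Tk - Tq), mirroring flash-attn's causal mode when seqlen_q != seqlen_k.
    """
    Tq, Tk = len(q), len(k)
    scores = q @ k.T / np.sqrt(q.shape[1])
    allowed = np.arange(Tk)[None, :] <= np.arange(Tq)[:, None] + (Tk - Tq)
    scores = np.where(allowed, scores, -np.inf)
    w = np.exp(scores - scores.max(axis=1, keepdims=True))
    return (w / w.sum(axis=1, keepdims=True)) @ v


def two_pass(q, k, v, prompt_len, resp_lens):
    """Pass 1 on the prompt, then pass 2 per response over [prompt keys; response keys]."""
    outs = [attend(q[:prompt_len], k[:prompt_len], v[:prompt_len])]
    start = prompt_len
    for n in resp_lens:
        r = slice(start, start + n)
        k_i = np.concatenate([k[:prompt_len], k[r]])
        v_i = np.concatenate([v[:prompt_len], v[r]])
        outs.append(attend(q[r], k_i, v_i))
        start += n
    return np.concatenate(outs)


def naive(q, k, v, prompt_len, resp_lens):
    """Reference: full causal attention over each [prompt + response_i] on its own."""
    outs, start = [], prompt_len
    for i, n in enumerate(resp_lens):
        idx = np.concatenate([np.arange(prompt_len), np.arange(start, start + n)])
        out = attend(q[idx], k[idx], v[idx])
        # prompt rows are identical in every copy, so keep them only once
        outs.append(out if i == 0 else out[prompt_len:])
        start += n
    return np.concatenate(outs)


rng = np.random.default_rng(0)
prompt_len, resp_lens = 5, [3, 1, 4]
T = prompt_len + sum(resp_lens)
q, k, v = (rng.standard_normal((T, 8)) for _ in range(3))
assert np.allclose(two_pass(q, k, v, prompt_len, resp_lens), naive(q, k, v, prompt_len, resp_lens))
```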

linear attention layers

while the above works for standard attention kernels, many recent models mix full attention layers with linear attention layers. we need a different solution for the latter.

standard attention computes relationships between every pair of tokens: powerful, but expensive (cost scales quadratically with sequence length). linear attention instead maintains a single fixed-size state that gets updated as each token is processed, like a running summary. this makes it much cheaper, but our two-pass flash attention trick no longer applies directly.
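to make the "running summary" concrete, here is the simplest variant: ungated, unnormalized causal linear attention with no feature map. a numpy teaching sketch with names of our choosing; real layers (such as the gated delta rule used below) update the state differently, but the shape of the computation is the same: a fixed-size state updated once per token.

```python
import numpy as np


def linear_attention_recurrent(q, k, v):
    """Plain causal linear attention as a running state.

    The state S has shape (d_v, d_k), independent of sequence length.
    """
    S = np.zeros((v.shape[1], k.shape[1]))
    out = []
    for q_t, k_t, v_t in zip(q, k, v):
        S = S + np.outer(v_t, k_t)  # fold token t into the running summary
        out.append(S @ q_t)         # read the summary out with the query
    return np.stack(out)


def linear_attention_quadratic(q, k, v):
    """Same outputs computed pairwise: o_t = sum_{s<=t} (q_t . k_s) v_s."""
    scores = np.tril(q @ k.T)  # causal: zero out future keys (no softmax)
    return scores @ v


rng = np.random.default_rng(0)
q, k, v = (rng.standard_normal((6, 4)) for _ in range(3))
assert np.allclose(linear_attention_recurrent(q, k, v), linear_attention_quadratic(q, k, v))
```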

because linear attention compresses the entire prefix into one fixed-size state, you can run the prompt through once to get that state and then use it as the initial state for each of the packed responses. this works cleanly because the state is a complete summary of the prefix: unlike softmax attention’s KV cache, where each query still needs to re-attend to every past key individually, the linear attention state captures everything a response token needs from the prompt. so the prompt state can be broadcast to all G responses as a shared starting point, and no response has to recompute anything about the prompt.
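the state-reuse argument can be checked with a slow, step-by-step reference. the sketch below assumes the gated delta rule update from the gated deltanet paper (alpha is the per-token decay, beta the per-token write strength), written in numpy rather than with fla's chunked kernel; function and variable names are ours. since a response depends on the prompt only through the final state, starting from that state should reproduce the naive [prompt + response] outputs.

```python
import numpy as np


def gated_delta_rule(q, k, v, alpha, beta, S0=None):
    """Naive recurrent reference (our reading of the gated delta rule):
        S_t = alpha_t * S_{t-1} (I - beta_t k_t k_t^T) + beta_t v_t k_t^T,  o_t = S_t q_t
    q, k: (T, d_k); v: (T, d_v); alpha, beta: (T,). Returns outputs and final state.
    """
    d_k, d_v = k.shape[1], v.shape[1]
    S = np.zeros((d_v, d_k)) if S0 is None else S0.copy()
    eye = np.eye(d_k)
    out = []
    for q_t, k_t, v_t, a_t, b_t in zip(q, k, v, alpha, beta):
        S = a_t * S @ (eye - b_t * np.outer(k_t, k_t)) + b_t * np.outer(v_t, k_t)
        out.append(S @ q_t)
    return np.stack(out), S


rng = np.random.default_rng(0)
d_k, d_v, prompt_len, resp_lens = 4, 3, 6, [2, 5, 3]
T = prompt_len + sum(resp_lens)
q, k = rng.standard_normal((T, d_k)), rng.standard_normal((T, d_k))
k /= np.linalg.norm(k, axis=1, keepdims=True)  # unit-norm keys keep the update stable
v = rng.standard_normal((T, d_v))
alpha, beta = rng.uniform(0.8, 1.0, T), rng.uniform(0.0, 1.0, T)

# pass 1: run the prompt once and keep its final state
p = slice(0, prompt_len)
prompt_out, prompt_state = gated_delta_rule(q[p], k[p], v[p], alpha[p], beta[p])

start = prompt_len
for n in resp_lens:
    r = slice(start, start + n)
    # pass 2: each response starts from the shared prompt state
    resp_out, _ = gated_delta_rule(q[r], k[r], v[r], alpha[r], beta[r], S0=prompt_state)
    # reference: run [prompt + response] from scratch, as the naive layout would
    idx = np.concatenate([np.arange(prompt_len), np.arange(start, start + n)])
    full_out, _ = gated_delta_rule(q[idx], k[idx], v[idx], alpha[idx], beta[idx])
    assert np.allclose(resp_out, full_out[prompt_len:])
    start += n
```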

here’s the pseudo-code implementation

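import itertools
import torch
from fla.ops.gated_delta_rule import chunk_gated_delta_rule

# assumed layout: packed_query/key/value are (1, total_tokens, n_heads, head_dim)
# and packed_g (log decay) / packed_beta are (1, total_tokens, n_heads);
# varlen mode (cu_seqlens) expects batch size 1
# Pass 1: prompt only - capture the final state.
prompt_out, prompt_state = chunk_gated_delta_rule(
    packed_query[:, :prompt_len],
    packed_key[:, :prompt_len],
    packed_value[:, :prompt_len],
    packed_g[:, :prompt_len],
    packed_beta[:, :prompt_len],
    initial_state=None,
    output_final_state=True,  # without this the returned state is None
)
# cu_seqlens holds cumulative offsets of the responses, not their lengths,
# e.g. resp_lens = [3, 5, 2] -> [0, 3, 8, 10]
cu_seqlens_suffixes = torch.tensor(
    [0, *itertools.accumulate(resp_lens)],
    dtype=torch.long, device=packed_query.device,
)
# Pass 2: use the final state from the prompt as the initial state
# for all the responses in the group (one copy per response)
response_out, _ = chunk_gated_delta_rule(
    packed_query[:, prompt_len:],
    packed_key[:, prompt_len:],
    packed_value[:, prompt_len:],
    packed_g[:, prompt_len:],
    packed_beta[:, prompt_len:],
    initial_state=prompt_state.expand(group_size, -1, -1, -1).contiguous(),
    cu_seqlens=cu_seqlens_suffixes,
)
# Concat the prompt with the responses along the time dim for the next layer
final_out = torch.cat([prompt_out, response_out], dim=1)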

results

the above numbers were computed with qwen3.5-4B on a 4xa100-80gb node, batch size = 126, group size = 9. we see significant throughput gains across the board, with the effect becoming more pronounced as the prompt-to-completion ratio grows. in the largest-ratio case shown above, with 16,384-token prompts and 64-token outputs, prompt caching reduces step time from 586s to 78s, a 7.5x speedup. the measured speedups are 7.3x for 128-token outputs, 5.4x for 1,024-token outputs, and 1.7x when prompts are 8,192 tokens and outputs are 4,096 tokens.
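for a rough point of comparison, here is the token-count ratio from the context section applied to the two configurations whose prompt and output lengths are both stated. token count ignores how attention cost scales and any per-step work that doesn't depend on training tokens, so treat it as a proxy only; the helper name is ours.

```python
def token_ratio(prompt_len: int, resp_len: int, group_size: int) -> float:
    # naive layout repeats the prompt in every sequence; packed stores it once
    naive = (prompt_len + resp_len) * group_size
    packed = prompt_len + resp_len * group_size
    return naive / packed


# (prompt tokens, output tokens, measured speedup) as reported above, G = 9
for prompt_len, resp_len, measured in [(16_384, 64, 7.5), (8_192, 4_096, 1.7)]:
    ratio = token_ratio(prompt_len, resp_len, group_size=9)
    print(f"{prompt_len}/{resp_len}: token ratio {ratio:.2f}x, measured {measured}x")
# 16384/64: token ratio 8.73x, measured 7.5x
# 8192/4096: token ratio 2.45x, measured 1.7x
```

in both cases the measured speedup comes in below the token ratio.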

looking forward

custom attention masks

the approach above got us throughput gains quickly, without modifying the internals of the flash attention / flash linear attention kernels. the longer-term solution is to express the packed layout as a custom attention mask. specifically, we can create a custom attention mask for packed samples this way:

each response attends to the shared prompt (green squares) and causally to itself, but not to any of the other responses (gray squares). unfortunately, flash attention & flash linear attention do not support custom attention masks today, so we plan to implement this next with flex attention and flex linear attention.
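flex attention expresses masks as a predicate over (batch, head, query index, key index). below is a sketch of such a predicate for the packed layout, assuming per-token segment ids (0 for the prompt, g for response g); the names are ours, and it is tested here with numpy arrays and plain ints rather than inside flex attention.

```python
import numpy as np


def make_packed_mask_mod(segment_ids):
    """Return a flex-attention-style mask_mod(b, h, q_idx, kv_idx) for the packed layout.

    segment_ids[i] is 0 for prompt tokens and g (1-based) for tokens of response g.
    """
    def mask_mod(b, h, q_idx, kv_idx):
        causal = q_idx >= kv_idx
        same_segment = segment_ids[q_idx] == segment_ids[kv_idx]
        kv_in_prompt = segment_ids[kv_idx] == 0
        # prompt tokens attend causally to the prompt; response tokens attend
        # to the whole prompt and causally to themselves, never to other responses
        return causal & (same_segment | kv_in_prompt)

    return mask_mod


seg = np.array([0, 0, 1, 1, 2])
mask_mod = make_packed_mask_mod(seg)
dense = np.array([[mask_mod(0, 0, i, j) for j in range(5)] for i in range(5)])
```

with pytorch's flex attention, we'd expect to pass the same function (with segment_ids as a device tensor) to create_block_mask(mask_mod, None, None, T, T) and the result to flex_attention(q, k, v, block_mask=...); this wiring is our assumption, not something shown here.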

more intelligent caching

currently, we only support prompt-level caching. but rollouts sometimes share more than the prompt, e.g. the same sequence of initial tool calls. consider this case:

a subset of rollouts follows the same initial tool-call trajectory. caching these shared segments dynamically, rather than only the prompt, could save further compute.
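one way to find such shared segments is a token-level prefix trie over a group's rollouts: every trie node reached by more than one rollout is computation that could, in principle, be done once. the sketch below only counts the tokens a trie-based layout would need; it's a hypothetical illustration with names of our choosing, not something we've built.

```python
def prefix_savings(rollouts: list[list[int]]) -> tuple[int, int]:
    """Return (tokens in naive layout, tokens if every shared prefix is stored once)."""
    trie: dict = {}
    unique = 0
    for seq in rollouts:
        node = trie
        for tok in seq:
            if tok not in node:
                node[tok] = {}
                unique += 1  # first time this prefix appears: must be computed
            node = node[tok]
    naive = sum(len(seq) for seq in rollouts)
    return naive, unique


prompt = [1, 2, 3, 4]
tool_call = [50, 51, 52]  # the same initial tool call in two rollouts
rollouts = [
    prompt + tool_call + [60, 61],
    prompt + tool_call + [70],
    prompt + [80, 81],
]
print(prefix_savings(rollouts))  # (23, 12)
```

prompt-level caching alone would need 4 + (5 + 4 + 2) = 15 tokens for this group, so the shared tool call accounts for the remaining gap.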

come work with us

we’re building fast, efficient rl infrastructure for frontier models. if this kind of systems + ml work interests you, reach out at coder@castform.com.