Stride and Prejudice: How a 32-bit overflow corrupted a CUDA kernel (and stayed hidden for weeks)


TL;DR

While training our Jamba 3B model with GRPO, we hit a mysterious logprob mismatch between rollout and training. The root cause turned out to be a silent integer overflow deep inside a vLLM CUDA kernel – one that only triggered when the number of cache slots exceeded ~47,935. The fix was two characters away. The journey to find it was not.

One of the trickiest parts of debugging reinforcement learning systems is that the bug almost never introduces itself politely.

In reinforcement learning from human feedback, correctness is load-bearing in ways that aren’t always visible. A common sanity check is to compare the log probabilities the rollout engine assigned to its own completions against those the freshly-loaded training model computes on the same sequences – before any weight update has happened. In a healthy system, these should be nearly identical: same weights, same inputs, same numbers. When they diverge, it’s a canary signal that something upstream has already gone wrong.
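As a rough illustration (a minimal sketch, not our actual code – the tensor shapes and names are assumptions), the check boils down to gathering per-token logprobs for the sampled completion from the training model’s logits and comparing them to what the rollout engine reported:

import torch

def mean_logprob_gap(rollout_logprobs: torch.Tensor,
                     train_logits: torch.Tensor,
                     completion_ids: torch.Tensor) -> float:
    """Mean absolute gap between the rollout engine's per-token logprobs ([T])
    and logprobs recomputed from the training model's logits ([T, V]) for the
    same completion token ids ([T])."""
    train_logprobs = torch.log_softmax(train_logits.float(), dim=-1)
    recomputed = train_logprobs.gather(-1, completion_ids.unsqueeze(-1)).squeeze(-1)
    return (rollout_logprobs - recomputed).abs().mean().item()

# With identical weights and inputs the gap should be ~0 (up to numerics);
# a spike before any weight update means something upstream is already wrong.
logits = torch.randn(10, 32_000)
ids = logits.argmax(dim=-1)
ref = torch.log_softmax(logits, dim=-1).gather(-1, ids.unsqueeze(-1)).squeeze(-1)
print(mean_logprob_gap(ref, logits, ids))   # ~0.0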

Recently, while working on the GRPO training setup for our Jamba 3B model, we ran into exactly this kind of issue. Jamba is AI21’s hybrid attention-Mamba model – a 3B-parameter architecture that combines standard attention with Mamba state-space layers. That hybrid nature, as we’ll see, is exactly what made this bug so hard to find. The training stack included Ray for orchestration, an LM training path using FSDP for weight updates, and a rollout path that generated completions via vLLM and scored them with our internal evaluation service. What we saw was a mismatch in log probabilities between the rollout side and the FSDP side. At first glance, it was not obvious whether the problem came from training, inference, weight synchronization, or something more subtle in the interaction between them.

Spoiler alert: It was a typical unsigned integer overflow of an index deep inside a CUDA kernel. The interesting part was the journey, not the destination, so let’s dig in.

The first question: where does the bug actually live?

In a system like this, GRPO is not really one thing; it is not your average model.fit(). It is a pipeline made of distinct stages, each with its own assumptions:

  • rollout generation
  • FSDP-based training
  • distributed coordination through Ray
  • handoff between inference-time and training-time components

A typical GRPO training flow looks like this:

Figure 1: A typical GRPO training loop using vLLM and FSDP. vLLM handles rollout generation and computes log π_old; the FSDP model recomputes log π_train on the same completions before any weight update. The mean absolute difference between the two (sanity check, right) should be ≈0 at the start of each iteration.

When spikes in the logprob difference between the vLLM engine and the just-loaded model in FSDP started appearing – before any weight updates – the natural temptation was to inspect everything at once. But that usually leads nowhere. In distributed RL systems, if you do not reduce the search space aggressively, you end up debugging the entire stack instead of debugging the bug.

So the real task became: can we determine whether this issue belongs to the rollout side, the FSDP side, or the weight sync in-between?

Conventional debugging wisdom says: find a single hyperparameter that lets you disturb the system in a controlled way, and observe whether the symptom moves with it.

Finding the lever

Watching how the logprob difference behaved, we noticed that the spikes were not random. They appeared periodically – recurring damage and recovery at a fixed interval. With our default experiment config, the spike didn’t occur until roughly step 12, which added more questions than it answered.

Figure 2: Mean absolute logprob difference between vLLM rollout and FSDP recompute (rollouts_num=8, default config). Spikes appear periodically – roughly every 12 steps – rather than randomly, suggesting a structured, recurring source of error rather than training noise.

Looking at the model rollouts around those steps, we found a number of gibberish completions – though not all of them were garbled. We took the weights from such steps and tried reproducing the issue in a standalone vLLM instance. The generation looked fine. Dead end.

We then started a methodical search for a lever. The eureka moment came when we ran experiments doubling the number of rollouts per prompt – 8, 16, 32, 64, up to 128 – and found that the periodicity of the spikes tracked exactly with that number. That changed the debugging story completely.

If the logprob diff spikes move in a pattern tied to rollout count, then the bug is very unlikely to be some generic instability in FSDP training. The symptom is being structured by rollout behavior. That doesn’t fully prove where the bug lives, but it gives you a much sharper hypothesis: the issue is probably introduced before the training path ever consumes the data.

That was the turning point.

Figure 3: Spike periodicity as a function of rollouts_num (8, 16, 32, 64, 128). As the number of rollouts per prompt increases, spikes shift earlier in training. With rollouts_num=128, the spike triggers on the very first step – a key diagnostic signal pointing away from training dynamics and toward the rollout path itself.

From “training instability” to “reproducible rollout issue”

Before that observation, the issue looked like an RL-training bug. After it, it looked more like a rollout-path bug that happened to surface during RL training.

That distinction matters a lot. A full RL training run is a terrible place to debug low-level correctness issues. There are too many layers involved, too much asynchronous machinery, too many places where the original signal gets blurred. The ideal debugging path is always to peel away layers until the issue survives in the smallest possible environment.

So that became the goal: reduce the problem from “something is wrong during GRPO training” to “we can reproduce this on demand, on the earliest possible rollout.”

The next milestone was getting the issue to reproduce at step zero – and looking closely at the chart, the 128 rollout-per-prompt run spiked on the very first rollout.

Figure 4: Zoomed view of step-one logprob differences across rollout counts. With rollouts_num=128 (green), the difference reaches ~0.11 on the very first step – confirming the bug is present before any training state accumulates, and making step-zero reproduction possible.

That was a huge simplification. Once the bug appears at the very first opportunity, you no longer need to reason about optimizer history, training drift, reward effects, gradient accumulation, or long-horizon RL dynamics. You are no longer debugging learning. You are debugging execution – and that is a much better problem to have: a deterministic failure, repeatable on demand, small enough to instrument properly.

In our case, getting to step zero meant the gap between first-rollout logprobs and FSDP logprobs appeared before any weight updates, which let us separate the issue from the full RL workflow and isolate it inside the underlying inference path. Our earlier attempts to reproduce with vLLM alone had failed because we were sampling a normal number of completions (8 or 16); the bug needed scale to surface. What began as a VeRL/GRPO debugging problem became, essentially, an inference-engine debugging problem.

This is the kind of shift you want in systems debugging: from a complex emergent symptom to a minimal, deterministic repro. With a standalone reproduction script in hand, we methodically narrowed things down further and found yet another lever: vLLM’s GPU memory utilization parameter. Decreasing it to 50% avoided the issue entirely; pushing it beyond that reproduced it reliably. Now we knew the issue was tied to vLLM’s internal cache allocation budget.
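For reference, the standalone repro only needed two knobs. Here is a minimal sketch using vLLM’s public API (the checkpoint path, prompt, and exact thresholds are placeholders – in practice the values depend on the model and GPU):

from vllm import LLM, SamplingParams

# Placeholder checkpoint path; in our case this was the Jamba 3B model.
llm = LLM(
    model="path/to/checkpoint",
    gpu_memory_utilization=0.9,  # at 0.5 the issue vanished; higher values reproduced it
)

# The bug needed scale to surface: n=8 or 16 completions per prompt stayed clean,
# while n=128 triggered corrupted completions on the very first rollout.
params = SamplingParams(n=128, max_tokens=512, temperature=1.0, logprobs=1)
outputs = llm.generate(["placeholder prompt"], params)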

In our Jamba case, since it’s a hybrid attention-Mamba model, the culprit was either the attention KV-cache or the Mamba state cache.

Running one more ablation with an attention-only model didn’t surface the issue – that eliminated the KV-cache case and restricted the blast radius to vLLM’s Mamba-1 CUDA kernel.

The bug

It lived in the selective scan CUDA kernel’s pointer arithmetic, and once you see it, it’s not the kind of bug you admire, but the kind you grudgingly respect for how quietly it hid. The SSMParamsBase struct defined index_t as uint32_t, and all stride fields – including ssm_states_batch_stride – inherited this type. In the kernel, the base pointer for each batch’s SSM state is computed by multiplying cache_index (an int) by ssm_states_batch_stride (a uint32_t). Since both operands are 32-bit types, the multiplication is performed in 32-bit arithmetic at runtime, and silently wraps around when the result exceeds UINT32_MAX. With a batch stride of 89,600 elements, the product overflows UINT32_MAX once cache_index exceeds roughly 47,935. In our setup with 69,776 total cache slots, roughly 31% of slots were affected – the kernel was writing SSM states to incorrect memory locations while the intended slots remained zeroed out.
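Those numbers are easy to verify by emulating the 32-bit multiplication – the kernel computes the product with two 32-bit operands, which the modulo below mimics (constants taken from our configuration):

BATCH_STRIDE = 89_600   # ssm_states_batch_stride, in elements
NUM_SLOTS = 69_776      # total cache slots in our setup
UINT32_MAX = 2**32 - 1

# First cache_index whose offset no longer fits in 32 bits.
first_bad = UINT32_MAX // BATCH_STRIDE + 1
print(first_bad)                             # 47935
print((NUM_SLOTS - first_bad) / NUM_SLOTS)   # ~0.31 -> roughly 31% of slots affected

# For an affected slot, the 32-bit product silently wraps instead of
# producing the intended element offset.
cache_index = 50_000
intended = cache_index * BATCH_STRIDE
wrapped = intended % 2**32
print(intended, wrapped)                     # 4480000000 vs. 185032704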

Think about that for a moment: 31% of your cache is silently corrupted. No crash. No warning. Just wrong numbers propagating through a training run, manifesting as a logprob mismatch that looked, from the outside, like training instability.

The fix, submitted as an upstream vLLM PR, was as follows:

struct SSMParamsBase {
-     using index_t = uint32_t;
+     using index_t = size_t;
      ...
};

Widening index_t to size_t – a 64-bit unsigned integer in this kernel – ensures all stride multiplications use 64-bit arithmetic. Weeks of investigation. Two characters changed.

Figure 5: Pointer offset overflow in the Mamba-1 selective scan CUDA kernel. When cache_index × ssm_states_batch_stride exceeds UINT32_MAX (~4.29B), 32-bit arithmetic wraps silently. With a batch stride of 89,600, indices above 47,935 overflow – affecting ~31% of the 69,776 total cache slots, which receive corrupted SSM state writes while their intended slots remain zeroed.

The real debugging lesson

The technical issue itself mattered, of course. But the bigger lesson was methodological.

When debugging RL systems, especially distributed ones, the hardest part is often not fixing the bug. It is finding the right boundary around it.

A few things actually moved the needle here, and they’re worth naming precisely.

  1. The most useful one: look for variables that change the structure of failure, not just its magnitude. Increasing rollout count didn’t just make the error bigger – it changed the spacing of the spikes. That structural shift was far more informative than a louder error signal would have been. It told us the failure was being organized by the rollout process itself, which pointed directly at where to look.
  2. Resist the urge to debug “the whole training.” Ask instead what independent subsystems exist inside the pipeline and how you can stress them in isolation.
  3. Once you have a directional hypothesis, push hard toward minimal reproduction. If the bug can be reproduced at step zero, don’t keep chasing it inside a 500-step RL run.

Closing thought

In distributed RL systems, Ray, FSDP, rollout workers, and the training loop can make a fairly localized bug look like a system-wide failure. That is what makes debugging these setups so deceptive: the visible symptom often reflects the architecture around the bug more than the bug itself. What helped in this case was resisting the urge to treat the full stack as opaque magic, and instead using the system’s own structure against it: varying one dimension, observing the shape of the failure, and asking which component could produce that pattern. That turned a vague log-prob mismatch into a targeted investigation.

To me, this is one of the most underrated parts of working on RL infrastructure: the debugging process is often as interesting as the training algorithm itself. A good debugging session is not just about finding what is broken, but about shrinking a messy, distributed, multi-component problem into something small enough to reason about clearly. In this case, the path went from GRPO training, to rollout-vs-FSDP suspicion, to rollout-correlated periodicity, to step-zero reproduction, and finally to isolating the issue inside the inference engine. That progression is what made the bug solvable, and honestly, it is also what makes this kind of engineering work fun.

Acknowledgements
This work was contributed by Asaf Gardin and Amir Koblyansky.