WIP inference engine divergence testing

11 min read Original article ↗
#!/usr/bin/env python3 import argparse, os, gc, json, random, csv os.environ.setdefault("PYTORCH_CUDA_ALLOC_CONF", "expandable_segments:True") # os.environ.setdefault("CUDA_LAUNCH_BLOCKING", "1") import numpy as np import torch import torch.nn.functional as F from datasets import load_dataset from transformers import AutoTokenizer, AutoModelForCausalLM from matplotlib import pyplot as plt # Optional, used only when --engine vllm try: from vllm import LLM, SamplingParams except Exception: LLM = None SamplingParams = None SUFFIX = "Let's think step by step and answer in \\boxed{}." def ensure_dir(d): os.makedirs(d, exist_ok=True) def set_seed(s): random.seed(s); np.random.seed(s); torch.manual_seed(s) if torch.cuda.is_available(): torch.cuda.manual_seed_all(s) def _get_base_module_and_head(model: AutoModelForCausalLM): head = getattr(model, "lm_head", None) or getattr(model, "embed_out", None) base = None base_prefix = getattr(model, "base_model_prefix", None) if isinstance(base_prefix, str) and hasattr(model, base_prefix): base = getattr(model, base_prefix) for name in ["model", "transformer", "language_model", "backbone", "base_model"]: if base is None and hasattr(model, name): base = getattr(model, name) if base is None or head is None: raise RuntimeError("Could not locate base transformer or lm_head on the HF model.") return base, head @torch.no_grad() def _chunked_token_logprobs_from_hidden(hiddens: torch.Tensor, head_weight: torch.Tensor, targets: torch.Tensor, time_chunk: int = 256) -> torch.Tensor: """ hiddens: [B, L, H] (on CPU) head_weight: [V, H] on its own (likely CUDA) device targets: [B, L-1] (on CPU) returns lp: [B, L-1] (float32, on CPU) """ B, L, H = hiddens.shape T = L - 1 V, Hw = head_weight.shape assert H == Hw, f"Hidden size mismatch: {H} vs {Hw}" out_device = hiddens.device # CPU weight_device = head_weight.device # e.g., cuda:0/1 # Ensure we run on the correct CUDA device to avoid cross-device kernel weirdness if weight_device.type == 'cuda': torch.cuda.set_device(weight_device.index) # keep W on the head device, dtype matched with hiddens (bf16 on CPU is fine) W = head_weight.to(dtype=hiddens.dtype, device=weight_device) lp = torch.empty((B, T), dtype=torch.float32, device=out_device) t = 0 cur_chunk = max(1, int(time_chunk)) while t < T: cur = min(cur_chunk, T - t) try: # Slice, make contiguous, then copy CPU->CUDA (blocking copy for stability) h = hiddens[:, t:t+cur, :].contiguous().view(-1, H).to(weight_device, non_blocking=False) y = targets[:, t:t+cur].contiguous().view(-1).to(weight_device, non_blocking=False) logits = F.linear(h, W) # [B*cur, V] on weight_device # Numerically stable log softmax for chosen tokens m = logits.max(dim=-1).values lse = torch.logsumexp((logits - m.unsqueeze(1)).to(torch.float32), dim=-1) + m.to(torch.float32) chosen = logits.gather(1, y.unsqueeze(1)).squeeze(1).to(torch.float32) lp_chunk = (chosen - lse).view(B, cur).to(out_device, non_blocking=False) lp[:, t:t+cur] = lp_chunk # cleanup del h, y, logits, m, lse, chosen, lp_chunk if torch.cuda.is_available(): torch.cuda.empty_cache() t += cur cur_chunk = time_chunk except RuntimeError as e: # Back off time chunk on OOM if "out of memory" in str(e).lower() and cur > 1: if torch.cuda.is_available(): torch.cuda.empty_cache() cur_chunk = max(1, cur // 2) continue raise return lp def infer_log_probs_batch(model: AutoModelForCausalLM, sequences, device_hint: str, time_chunk: int = 256): """ Run base forward (sharded on GPUs), offload [B,L,H] to CPU, then do chunked head on the head device. """ use_cpu_io = hasattr(model, "hf_device_map") or hasattr(model, "device_map") tgt_device = 'cpu' if use_cpu_io else device_hint lens = [len(s) for s in sequences] Lm = max(lens) if lens else 0 pad_id = model.config.pad_token_id if model.config.pad_token_id is not None else (model.config.eos_token_id or 0) inp = torch.full((len(sequences), Lm), pad_id, dtype=torch.long, device=tgt_device) for i, s in enumerate(sequences): if len(s) > 0: inp[i, :len(s)] = torch.tensor(s, dtype=torch.long, device=tgt_device) attn = (inp != pad_id).long() try: model.config.use_cache = False except Exception: pass base, head = _get_base_module_and_head(model) # If the model was loaded with a device_map (even single GPU), ensure inputs # are moved to the same device as the base module's first parameter (embedding device) if use_cpu_io: try: first_param_device = next(base.parameters()).device if inp.device != first_param_device: inp = inp.to(first_param_device) attn = attn.to(first_param_device) except Exception: pass with torch.inference_mode(): outputs = base( input_ids=inp, attention_mask=attn, use_cache=False, output_attentions=False, output_hidden_states=False, return_dict=False, ) hidden_states = outputs[0] # [B, Lm, H] on last shard GPU del outputs # Offload activations to CPU, then make contiguous + (optional) pin if hidden_states.is_cuda: hidden_states = hidden_states.to('cpu', non_blocking=False) torch.cuda.synchronize(); torch.cuda.empty_cache() hidden_states = hidden_states.contiguous() try: hidden_states = hidden_states.pin_memory() except Exception: pass tgt = inp[:, 1:] # [B, Lm-1] on CPU lp = _chunked_token_logprobs_from_hidden( hiddens=hidden_states, head_weight=head.weight, targets=tgt, time_chunk=time_chunk, ) del hidden_states if torch.cuda.is_available(): torch.cuda.empty_cache() return lp, lens # ------------------------- plotting ------------------------- def plot_correlation(eng_logp, hf_logp, out_png, log_space=False): p_min, p_max = (-40, 0) if log_space else (0, 1) X = (eng_logp if log_space else eng_logp.exp()).float().cpu().numpy() Y = (hf_logp if log_space else hf_logp.exp()).float().cpu().numpy() fig, axes = plt.subplots(2, 1, figsize=(10, 10), gridspec_kw={'height_ratios': [4, 2]}) axes[0].set_aspect('equal') axes[0].set_xlim(p_min, p_max); axes[0].set_ylim(p_min, p_max) axes[1].set_xlim(p_min, p_max) hist, xe, ye = np.histogram2d(X, Y, bins=100, range=[[p_min, p_max], [p_min, p_max]], density=False) hist = np.log(hist + 1e-10) Xm, Ym = np.meshgrid(xe[:-1], ye[:-1]) im = axes[0].pcolormesh(Xm, Ym, hist.T, shading='auto') axes[0].plot([p_min, p_max], [p_min, p_max], linestyle='--', linewidth=1) axes[0].set_xlabel('Engine ' + ('log-prob' if log_space else 'probability')) axes[0].set_ylabel('HF ' + ('log-prob' if log_space else 'probability')) fig.colorbar(im, ax=axes[0], label='Log Frequency') hx, xe1 = np.histogram(X, bins=100, range=[p_min, p_max], density=True) hy, ye1 = np.histogram(Y, bins=100, range=[p_min, p_max], density=True) axes[1].plot(xe1[:-1], np.log(hx + 1e-12), label='Engine') axes[1].plot(ye1[:-1], np.log(hy + 1e-12), label='HF') axes[1].legend(); axes[1].set_ylabel('Log Density'); axes[1].set_xlabel('log-prob' if log_space else 'probability') plt.tight_layout(); plt.savefig(out_png, dpi=150); plt.close() def plot_sample_prob_diff(hf_logp, eng_logp, out_png): diff = hf_logp.exp() - eng_logp.exp() xs = np.arange(len(diff)) plt.figure(figsize=(10, 4)) plt.plot(xs, diff.cpu().numpy()) plt.xlabel('Response token index'); plt.ylabel('Δ prob (HF − Engine)') plt.tight_layout(); plt.savefig(out_png, dpi=150); plt.close() # ------------------------- engines ------------------------- def run_vllm(prompt_ids_list, model, batch_size, max_new_tokens, seed=0, use_inductor=False): assert LLM is not None, 'vLLM is not installed.' llm = LLM( model=model, dtype=torch.bfloat16, trust_remote_code=True, compilation_config={ "use_inductor": use_inductor, } ) # NOTE: script not valid for temp != 1.0, vllm needs a patch bc logprobs are returned before sampling sp = SamplingParams( temperature=1.0, top_p=1.0, top_k=-1, max_tokens=int(max_new_tokens), logprobs=0, detokenize=True, seed=seed ) prompt_ids, gen_ids, gen_logp, texts = [], [], [], [] for i in range(0, len(prompt_ids_list), batch_size): batch_prompts = prompt_ids_list[i:i+batch_size] # vLLM expects a list of PromptInputs; use tokenized prompts schema batch_inputs = [{"prompt_token_ids": ids} for ids in batch_prompts] outs = llm.generate(batch_inputs, sampling_params=sp) for pid_sent, o in zip(batch_prompts, outs): sample = o.outputs[0] p_ids = list(pid_sent) g_ids = list(sample.token_ids) if sample.logprobs is None: raise RuntimeError("vLLM returned no logprobs; set SamplingParams.logprobs >= 1.") chosen_lp = [] for t, tok_id in enumerate(g_ids): lp_dict = sample.logprobs[t] # dict[token_id -> Logprob] lp_obj = lp_dict.get(tok_id) if lp_obj is None: raise RuntimeError("Chosen token not in returned top-k logprobs (???)") chosen_lp.append(float(lp_obj.logprob)) g_lp = torch.tensor(chosen_lp, dtype=torch.float32) prompt_ids.append(p_ids); gen_ids.append(g_ids); gen_logp.append(g_lp); texts.append(sample.text) del llm; gc.collect() if torch.cuda.is_available(): torch.cuda.empty_cache() return prompt_ids, gen_ids, gen_logp, texts def run_sglang(prompt_ids_list, model, batch_size, max_new_tokens): from sglang.srt.entrypoints.engine import Engine engine = Engine(model_path=model, dtype='bfloat16', tp_size=1, trust_remote_code=True, load_format='auto', log_level='INFO', max_running_requests=1024) prompt_ids, gen_ids, gen_logp, texts = [], [], [], [] for i in range(0, len(prompt_ids_list), batch_size): batch_ids = prompt_ids_list[i:i+batch_size] sp = { "n": 1, "max_new_tokens": int(max_new_tokens), "temperature": 1.0, "top_p": 1.0, "top_k": -1, "ignore_eos": False, "min_new_tokens": 0, "skip_special_tokens": True, "spaces_between_special_tokens": True, } outs = engine.generate(prompt=None, sampling_params=sp, return_logprob=True, input_ids=batch_ids, image_data=None) for pid, out in zip(batch_ids, outs): tpl = out["meta_info"]["output_token_logprobs"] g_lp = torch.tensor([float(t[0]) for t in tpl], dtype=torch.float32) g_ii = [int(t[1]) for t in tpl] prompt_ids.append(pid); gen_ids.append(g_ii); gen_logp.append(g_lp); texts.append(out["text"]) try: engine.shutdown() except Exception: pass del engine; gc.collect() if torch.cuda.is_available(): torch.cuda.empty_cache() return prompt_ids, gen_ids, gen_logp, texts def build_hf_model(model_name): # Give accelerate a memory budget so it offloads instead of spiking one GPU max_memory = {} if torch.cuda.is_available(): for i in range(torch.cuda.device_count()): total = torch.cuda.get_device_properties(i).total_memory giB = int((total * 0.95) // (1024**3)) max_memory[i] = f"{giB}GiB" max_memory["cpu"] = "256GiB" # Flash-Attention 2 only hf = AutoModelForCausalLM.from_pretrained( model_name, dtype=torch.bfloat16, attn_implementation="flash_attention_2", trust_remote_code=True, device_map='auto' if torch.cuda.is_available() else None, max_memory=max_memory if torch.cuda.is_available() else None, ) hf.eval() try: hf.config.use_cache = False except Exception: pass if not torch.cuda.is_available(): hf.to("cpu") return hf, ('cuda' if torch.cuda.is_available() else 'cpu') # ------------------------- main ------------------------- def main(): ap = argparse.ArgumentParser() ap.add_argument('--engine', choices=['vllm','sglang'], required=True) ap.add_argument('--model', required=True) ap.add_argument('--batch-size', type=int, default=256) ap.add_argument('--hf-batch-size', type=int, default=64, help='Batch size for HF scoring (keep small to avoid OOM)') ap.add_argument('--time-chunk-size', type=int, default=128, help='Time chunk (tokens) for head projection') ap.add_argument('--max-new-tokens', type=int, default=32768) ap.add_argument('--seed', type=int, default=0) ap.add_argument('--n', type=int, default=16) ap.add_argument('--out', default='out') ap.add_argument('--dataset', default='AI-MO/aimo-validation-aime') # or MathArena/aime_2025 ap.add_argument('--vllm-use-inductor', action='store_true') args = ap.parse_args() set_seed(args.seed); ensure_dir(args.out) if args.vllm_use_inductor: assert args.engine == 'vllm', 'vLLM with inductor requires vLLM engine' ds = load_dataset(args.dataset, split='train') user_texts = [f"{row['problem']}\n\n{SUFFIX}" for row in ds] user_texts = user_texts * args.n tokenizer = AutoTokenizer.from_pretrained(args.model, trust_remote_code=True) # Build chat prompts using the model's chat template messages_list = [[{"role": "user", "content": text}] for text in user_texts] chat_prompt_ids = [ tokenizer.apply_chat_template(msgs, add_generation_prompt=True, tokenize=True, return_tensors=None) for msgs in messages_list ] # Ensure lists of ints for sglang chat_prompt_ids = [ ids.tolist() if hasattr(ids, 'tolist') else (list(ids[0]) if isinstance(ids, (tuple, list)) and hasattr(ids[0], '__iter__') else list(ids)) for ids in chat_prompt_ids ] if args.engine == 'vllm': p_ids, g_ids, g_lp, texts = run_vllm(chat_prompt_ids, args.model, args.batch_size, args.max_new_tokens, seed=args.seed, use_inductor=args.vllm_use_inductor) else: p_ids, g_ids, g_lp, texts = run_sglang(chat_prompt_ids, args.model, args.batch_size, args.max_new_tokens) hf, device = build_hf_model(args.model) # Sanity checks V = hf.get_input_embeddings().weight.size(0) for idx, (pi, gi, elp) in enumerate(zip(p_ids, g_ids, g_lp)): assert len(gi) == len(elp), f"Engine IDs/logprobs length mismatch at sample {idx} (len(ids)={len(gi)} vs len(lp)={len(elp)})" if (pi and max(pi) >= V) or (gi and max(gi) >= V): raise ValueError(f"Token id out of range at sample {idx}: vocab={V}, max_id={max(pi+gi)}") sequences = [pi + gi for pi, gi in zip(p_ids, g_ids)] max_pos = getattr(hf.config, "max_position_embeddings", None) if max_pos is not None: too_long = [i for i, s in enumerate(sequences) if len(s) > max_pos] if len(too_long) > 0: print(f"Warning: {len(too_long)} sequences exceed model max_position_embeddings={max_pos} and may be truncated or slow.") # ---------- HF scoring (length-bucket to cut padding) ---------- idxs = list(range(len(sequences))) idxs.sort(key=lambda i: len(sequences[i])) # shortest -> longest hf_rows = [None] * len(sequences) for start in range(0, len(sequences), args.hf_batch_size): batch_idx = idxs[start:start+args.hf_batch_size] seq_batch = [sequences[i] for i in batch_idx] lp_batch, lens_batch = infer_log_probs_batch( hf, seq_batch, device, time_chunk=args.time_chunk_size ) for j, (i_orig, L) in enumerate(zip(batch_idx, lens_batch)): hf_rows[i_orig] = lp_batch[j, :L-1].detach().cpu() if torch.cuda.is_available(): torch.cuda.empty_cache() # Align engine vs HF on generated tokens eng_slices, hf_slices, slice_indices = [], [], [] for idx, (pi, gi, eng_lp) in enumerate(zip(p_ids, g_ids, g_lp)): Lp, Lg = len(pi), len(gi) hf_row = hf_rows[idx] start = max(Lp - 1, 0) hf_slice = hf_row[start:start+Lg] m = min(len(eng_lp), len(hf_slice)) if m > 0: eng_slices.append(eng_lp[:m]) hf_slices.append(hf_slice[:m]) slice_indices.append(idx) eng_all = torch.cat(eng_slices) if eng_slices else torch.empty(0) hf_all = torch.cat(hf_slices) if hf_slices else torch.empty(0) # Save raw per-item # Re-ensure output directory and use absolute path to avoid cwd ambiguity out_dir = os.path.abspath(args.out) os.makedirs(out_dir, exist_ok=True) out_jsonl = os.path.join(out_dir, 'engine_outputs.jsonl') with open(out_jsonl, 'w', encoding='utf-8') as f: for pi, gi, elp, txt in zip(p_ids, g_ids, g_lp, texts): f.write(json.dumps({'prompt_ids': pi, 'gen_ids': gi, 'gen_logprobs': [float(x) for x in elp.tolist()], 'text': txt})+'\n') # Summary metrics + plots if len(eng_all) > 0: e = eng_all.float().cpu().numpy() h = hf_all.float().cpu().numpy() mae = float(np.mean(np.abs(h - e))) rmse = float(np.sqrt(np.mean((h - e) ** 2))) corr = float(np.corrcoef(h, e)[0, 1]) if len(h) > 1 else float('nan') # Additional metrics in probability space lnp = np.clip(h.astype(np.float64), -80.0, 0.0) lnq = np.clip(e.astype(np.float64), -80.0, 0.0) p_raw = np.exp(lnp) q_raw = np.exp(lnq) diff = p_raw - q_raw rollout_probs_diff_mean = float(np.mean(diff)) rollout_probs_diff_std = float(np.std(diff)) # Bernoulli KL between chosen-token probabilities (clip to avoid log(0)) eps = 1e-12 p = np.clip(p_raw, eps, 1.0 - eps) q = np.clip(q_raw, eps, 1.0 - eps) kl_vals = p * (np.log(p) - np.log(q)) + (1.0 - p) * (np.log1p(-p) - np.log1p(-q)) kl_divergence = float(np.mean(kl_vals)) # Completion length stats (in tokens) completion_lengths = np.array([len(gen_ids) for gen_ids in g_ids], dtype=np.int64) avg_completion_length = float(np.mean(completion_lengths)) if completion_lengths.size > 0 else 0.0 min_completion_length = int(np.min(completion_lengths)) if completion_lengths.size > 0 else 0 max_completion_length = int(np.max(completion_lengths)) if completion_lengths.size > 0 else 0 with open(os.path.join(out_dir, 'summary_metrics.json'), 'w', encoding='utf-8') as f: json.dump({ "mae_logprob": mae, "rmse_logprob": rmse, "pearson_r": corr, "kl_divergence": kl_divergence, "rollout_probs_diff_mean": rollout_probs_diff_mean, "rollout_probs_diff_std": rollout_probs_diff_std, "avg_completion_length": avg_completion_length, "min_completion_length": min_completion_length, "max_completion_length": max_completion_length, "n_tokens": int(len(h)) }, f, indent=2) plot_correlation(eng_all, hf_all, os.path.join(out_dir, 'diff_raw.png'), log_space=False) plot_correlation(eng_all, hf_all, os.path.join(out_dir, 'diff_log.png'), log_space=True) j = len(eng_slices) // 2 if eng_slices else 0 if eng_slices: plot_sample_prob_diff(hf_slices[j], eng_slices[j], os.path.join(out_dir, 'sample_prob_diff.png')) orig_j = slice_indices[j] with open(os.path.join(out_dir, 'sample_completion.txt'), 'w', encoding='utf-8') as f: f.write(texts[orig_j]) toks = tokenizer.convert_ids_to_tokens(g_ids[orig_j][:len(eng_slices[j])]) with open(os.path.join(out_dir, 'sample_token_diffs.csv'), 'w', newline='', encoding='utf-8') as cf: w = csv.writer(cf); w.writerow(['idx','token','prob_hf','prob_engine','delta']) for i, (hlp, elp, t) in enumerate(zip(hf_slices[j], eng_slices[j], toks)): ph, pe = float(hlp.exp()), float(elp.exp()) w.writerow([i, t, ph, pe, ph-pe]) j_longest = max(range(len(slice_indices)), key=lambda k: len(g_ids[slice_indices[k]])) orig_longest = slice_indices[j_longest] plot_sample_prob_diff(hf_slices[j_longest], eng_slices[j_longest], os.path.join(out_dir, 'longest_prob_diff.png')) with open(os.path.join(out_dir, 'longest_completion.txt'), 'w', encoding='utf-8') as f: f.write(texts[orig_longest]) toks = tokenizer.convert_ids_to_tokens(g_ids[orig_longest][:len(eng_slices[j_longest])]) with open(os.path.join(out_dir, 'longest_token_diffs.csv'), 'w', newline='', encoding='utf-8') as cf: w = csv.writer(cf); w.writerow(['idx','token','prob_hf','prob_engine','delta']) for i, (hlp, elp, t) in enumerate(zip(hf_slices[j_longest], eng_slices[j_longest], toks)): ph, pe = float(hlp.exp()), float(elp.exp()) w.writerow([i, t, ph, pe, ph-pe]) # Clean up torch.distributed process group if initialized (avoids NCCL warning) try: import torch.distributed as dist if dist.is_available() and dist.is_initialized(): dist.destroy_process_group() except Exception: pass print('Saved to', out_dir) if __name__ == '__main__': main()