# Nested Learning Reproduction
High-fidelity reproduction of Google's Nested Learning (HOPE) architecture, matching the quality bar set by lucidrains' TITAN reference while remaining fully open-source and `uv`-managed.
## Quickstart
```bash
uv python install 3.12
uv sync --all-extras
uv run bash scripts/data/run_sample.sh
uv run bash scripts/run_smoke.sh pilot       # CPU-friendly HOPE block smoke test
uv run bash scripts/run_e2e_smoke.sh         # sync + sample data + smoke train + zeroshot eval
uv run python scripts/eval/zeroshot.py \
  --config configs/hope/pilot.yaml \
  --checkpoint artifacts/examples/pilot_dummy.pt \
  --tokenizer-path artifacts/tokenizer/refinedweb_mix/spm_32000_unigram.model \
  --tasks piqa --max-samples 32 --device cpu
```
## Requirements
- Python 3.12+
- `uv` package manager (https://github.com/astral-sh/uv)
- PyTorch 2.9.0 LTS + CUDA-capable GPUs for accelerated runs (CPU works for smoke tests)
## Setup
```bash
uv python install 3.12
uv sync --all-extras
```
Developer checks:
```bash
uv run ruff check .
uv run mypy src
uv run pytest
```
## Data Pipeline
- Tokenizer training
```bash
uv run python scripts/data/train_tokenizer.py \
  --manifest configs/data/refinedweb_mixture.yaml \
  --vocab-size 32000 \
  --output-dir artifacts/tokenizer/refinedweb_mix \
  --log-file data/mixtures/refinedweb_mix_tokenizer.json
```
- Corpus filtering + sharding
```bash
uv run python scripts/data/process_mixture.py \
  configs/data/refinedweb_mixture_filtered.yaml \
  --tokenizer-path artifacts/tokenizer/refinedweb_mix/spm_32000_unigram.model \
  --log-file data/mixtures/refinedweb_mix_filtered_shards.json
```
- Sample pipeline (downloads licensed datasets, filters, shards, and records stats):
```bash
uv run bash scripts/data/run_sample.sh
```
- Full pipeline (set env vars like `RW_LIMIT`, `WIKI_LIMIT`, etc. to scale ingestion):

```bash
uv run bash scripts/data/run_full.sh   # default ~50k docs per corpus; increase limits as needed
```
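For example, a scaled-up ingestion run might look like the following sketch; only the `RW_LIMIT`/`WIKI_LIMIT` variable names come from the scripts above, and the limit values are placeholders:

```bash
# Illustrative only: raise the per-corpus document limits before launching the full pipeline
RW_LIMIT=200000 WIKI_LIMIT=100000 uv run bash scripts/data/run_full.sh
```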
## Training
- Single GPU / CPU:
```bash
uv run python train.py --config-name pilot_smoke
```
- DDP (torchrun):
```bash
torchrun --nproc_per_node=2 train_dist.py --config-name mid
```
- CPU-only DDP smoke (verifies the `gloo` backend and deterministic seeding):

```bash
uv run bash scripts/run_cpu_ddp_smoke.sh
```
- FSDP (see `docs/FSDP_SCALING_GUIDE.md` for VRAM/batch sizing):

```bash
# 760M run
torchrun --nproc_per_node=2 train_fsdp.py --config-name hope/mid_fsdp
# 1.3B run
torchrun --nproc_per_node=2 train_fsdp.py --config-name hope/target_fsdp
```
- DeepSpeed (requires `deepspeed` installed separately):

```bash
deepspeed --num_gpus=2 train_deepspeed.py --config-name target \
  deepspeed.config=configs/deepspeed/zero3.json
```
## Pilot (3B tokens) workflow
- Ensure a tmux session is running so the long run survives disconnects:
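A minimal sketch, assuming tmux is installed; the session name `pilot` is illustrative, not mandated by the repo:

```bash
tmux new-session -A -s pilot   # attach to the "pilot" session, creating it if it does not exist
```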
- Launch the long run on `cuda:1` (≈52 h wall clock):

```bash
set -a && source git.env && set +a
export UV_CACHE_DIR=/tmp/uv-cache UV_LINK_MODE=copy
uv run python train.py --config-name pilot \
  logging.enabled=true logging.backend=wandb \
  logging.project=nested-learning logging.run_name=pilot-main-$(date +%Y%m%d%H%M%S) \
  train.device=cuda:1
```
- Checkpoints appear in `artifacts/checkpoints/pilot/step_*.pt` every 1,000 steps; the accompanying W&B run captures full telemetry.
- Copy the final checkpoint, config, logs, and eval JSON/CSV into `artifacts/pilot_release/` for distribution (a sketch follows).
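A hypothetical bundling sketch, assuming the default paths used elsewhere in this README; substitute the checkpoint step and eval/log filenames from your actual run:

```bash
# Illustrative only: gather the release bundle; filenames vary per run
mkdir -p artifacts/pilot_release
cp artifacts/checkpoints/pilot/step_230000.pt artifacts/pilot_release/
cp configs/hope/pilot.yaml artifacts/pilot_release/
cp logs/run.json artifacts/pilot_release/
cp eval/*.json eval/*.csv artifacts/pilot_release/
```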
## Logging
Set `logging.enabled=true` in the Hydra configs (or override via CLI) to send metrics to W&B (the default backend). For local JSON logs, use `logging.backend=json logging.path=logs/run.json`. Sample outputs reside in `logs/` and `artifacts/examples/`.
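For example, a CPU-friendly smoke run that logs to a local JSON file instead of W&B (flags as documented above; the output path is just an example):

```bash
uv run python train.py --config-name pilot_smoke \
  logging.enabled=true logging.backend=json logging.path=logs/run.json
```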
## Evaluation
- Zero-shot:

```bash
uv run python scripts/eval/zeroshot.py \
  --config configs/hope/mid.yaml \
  --checkpoint checkpoints/mid/step_000100.pt \
  --tokenizer-path artifacts/tokenizer/refinedweb_mix/spm_32000_unigram.model \
  --tasks all --max-samples 200 --device cuda:0
```

Use `uv run python scripts/eval/zeroshot.py --list-tasks` to display the full benchmark roster (PIQA, HellaSwag, WinoGrande, ARC-E/C, BoolQ, SIQA, CommonsenseQA, OpenBookQA). See `docs/zeroshot_eval.md` for details.

- Needle-in-a-Haystack:

```bash
uv run python scripts/eval/niah.py \
  --config configs/hope/mid.yaml \
  --checkpoint checkpoints/mid/step_000100.pt \
  --tokenizer-path artifacts/tokenizer/refinedweb_mix/spm_32000_unigram.model \
  --context-lengths 2048 4096 8192 --samples-per-length 20
```
- Continual-learning forgetting:

```bash
uv run python scripts/eval/continual.py \
  --config configs/hope/mid.yaml \
  --checkpoints checkpoints/mid/step_000050.pt checkpoints/mid/step_000100.pt \
  --segments-yaml configs/data/continual_segments_sample.yaml \
  --batch-size 4 --max-batches 10 --memorize --memorize-steps 2
```

Plot forgetting curves via `uv run python scripts/eval/plot_forgetting.py --continual-json eval/continual_mid.json`.

- Long-context diagnostics:

```bash
uv run python scripts/eval/passkey.py --config configs/hope/pilot.yaml --checkpoint artifacts/checkpoints/pilot/step_230000.pt \
  --tokenizer-path artifacts/tokenizer/refinedweb_mix/spm_32000_unigram.model --samples 64 --memorize
uv run python scripts/eval/pg19_perplexity.py --config configs/hope/pilot.yaml --checkpoint artifacts/checkpoints/pilot/step_230000.pt \
  --tokenizer-path artifacts/tokenizer/refinedweb_mix/spm_32000_unigram.model --max-samples 64
```
Evaluation summaries are written to `eval/` alongside per-task JSON metrics.
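To take a quick look at a summary without extra tooling, something like the following works; the filename comes from the continual-learning example above, so adjust it to your run:

```bash
uv run python -c "import json, pprint; pprint.pprint(json.load(open('eval/continual_mid.json')))"
```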
## Test-time memorization toggles
Every evaluator supports TITAN-style memorization so you can reproduce test-time adaptation:
```bash
uv run python scripts/eval/zeroshot.py \
  ... \
  --memorize \
  --memorize-steps 2 \
  --memorize-use-correct-answer \
  --memorize-paths titan,cms_fast \
  --memorize-surprise-threshold 0.01 \
  --memorize-no-reset   # optional: retain updates across samples
```

- `--memorize` turns on the learner with one LMS step per example by default.
- `--memorize-steps` controls the number of adaptation passes per prompt.
- `--memorize-use-correct-answer` injects ground-truth text during memorization for ablations.
- `--memorize-no-reset` carries memories across samples; omit it to reset every question.
- `--memorize-paths` restricts which levels receive teach-signal updates (`titan`, `cms_fast`, or `all`).
- `--memorize-surprise-threshold` gates updates on the average teach-signal norm, matching the paper's surprise trigger.
Memorization metrics (baseline vs adaptive) are emitted alongside task accuracy for easy comparisons.
## Releases
Before tagging or announcing a new checkpoint, work through `docs/release_checklist.md` so the bundle includes manifest-validation reports, tokenizer-coverage JSON, zero-shot/NIAH/continual/passkey/PG-19 eval outputs, forgetting plots, and filled-in checkpoint reports.
## Performance & optimizer options
- Mixed precision: enable bf16 autocast via `train.mixed_precision.enabled=true train.mixed_precision.dtype=bf16` (already enabled in the pilot/mid/target configs).
- `torch.compile`: accelerate attention/core loops by toggling `train.compile.enable=true train.compile.mode=max-autotune`; compilation failures fall back to eager unless `train.compile.strict=true`.
- Muon hybrid (default): all HOPE configs now set `optim.type=muon`, routing ≥2D tensors through PyTorch 2.9's Muon optimizer while embeddings/norms stay on AdamW. Training logs emit `optim.muon_param_elems`/`optim.adamw_param_elems` so you can confirm the split.
- Fused AdamW fallback: override with `optim.type=adamw optim.fused=auto` if Muon is unavailable or to compare against the AdamW ablation in `reports/ablations.md`.
- Surprise gating: set `model.surprise_threshold=<float>` to gate all inner updates on the average teach-signal norm (mirrors the paper's "surprise" trigger). Evaluation CLIs expose `--memorize-surprise-threshold` for ad-hoc gating.
All Hydra knobs can be overridden from the CLI or composed via config groups (`configs/hope/*.yaml`). Use these flags in tandem with `scripts/run_e2e_smoke.sh` (automation) or `scripts/run_cpu_ddp_smoke.sh` (CPU-only determinism check) to validate releases quickly.
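For instance, a single training invocation can compose several of the options above; this is an illustrative combination using only the flag names documented in this section, not a tuned recipe, and the threshold value is a placeholder:

```bash
uv run python train.py --config-name pilot \
  train.mixed_precision.enabled=true train.mixed_precision.dtype=bf16 \
  train.compile.enable=true train.compile.mode=max-autotune \
  optim.type=muon model.surprise_threshold=0.01
```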
## Documentation & References
- `docs/guide.md` – full onboarding (setup → data → training → eval).
- `docs/release_plan.md` – release readiness checklist.
- `docs/data_pipeline.md` – large-scale sharding/tokenizer workflow.
- `docs/scaling_guidance.md` – roadmap for expanding data + compute footprints.
- `docs/stage1_plan.md`, `docs/stage2_plan.md` – architecture + experiment roadmaps.
- `docs/stage2_progress.md` – latest dual-GPU training/eval status and commands.
- `docs/experiments_report.md` – draft paper covering completed experiments.
- `docs/stability_journal.md` – chronological notes on NaN fixes & teach-scale tuning.
- `docs/future_directions.md` – prioritized roadmap after the initial release.
- `reports/stage2_smoke.md` – exact commands/artifacts for the release-ready smoke workflow.
- `docs/FSDP_SCALING_GUIDE.md` – dual RTX 6000 Ada instructions for the mid/target FSDP configs.
- `google_papers/` – PDFs/markdown of the Nested Learning & TITAN papers.
- `CHANGELOG.md` – user-facing changes per release.
## Contributing
- Run formatting/tests (`uv run ruff check .`, `uv run pytest`).
- Document new configs or scripts in `docs/guide.md` and update `CHANGELOG.md`.
- Open a PR referencing the relevant NL/TITAN spec sections or planner transcript snippets.