GitHub - refortif-ai/trt-weight-extractor


trt-model-dump

Extract weight tensors from any REFIT-enabled TensorRT .engine file — no ONNX file, no original model, no sidecar metadata required.

How it works

The tool uses TensorRT's IRefitter API and a refit-diff mapping technique to locate and extract every refittable weight from the engine binary:

  1. Enumerate — Load the engine and call IRefitter.get_all_weights() to discover the names of all refittable weights, then query each weight's size and dtype.
  2. Zero baseline — Refit every weight to zero and serialize the engine. This produces a known reference binary.
  3. Marker injection — For each weight, inject sequential marker values 42, 43, ..., 41+N, where N is the weight's element count (all other weights stay zero), then refit and serialize.
  4. Diff & map — Compare the marked binary against the zeroed baseline using vectorized numpy operations. The marker values reveal the exact byte offsets and element ordering of each weight in the engine binary.
  5. Extract — Read the actual weight values from the original (normalized) engine bytes at the mapped offsets.
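Steps 4 and 5 can be sketched with NumPy alone. The code below is a self-contained simulation, not the tool's implementation: `fake_engine` is a hypothetical stand-in for a serialized engine (a header, one float32 weight stored in a shuffled element order, and a trailer), and the sketch assumes each element occupies one `itemsize`-aligned slot, so every changed byte can be rounded down to the start of its slot.

```python
import numpy as np


def map_weight_offsets(baseline: bytes, marked: bytes, n_elements: int,
                       start: int = 42, dtype=np.float32) -> np.ndarray:
    """Return the byte offset of every weight element, indexed by element number."""
    itemsize = np.dtype(dtype).itemsize
    base = np.frombuffer(baseline, dtype=np.uint8)
    mark = np.frombuffer(marked, dtype=np.uint8)
    if base.size != mark.size:
        raise ValueError("binaries differ in length; normalize the engine first")
    changed = np.flatnonzero(base != mark)
    # Assumes aligned storage: round each changed byte down to its element slot.
    slots = np.unique(changed - changed % itemsize)
    # Decode the marker stored in each slot; (marker - start) is the element index.
    raw = mark[slots[:, None] + np.arange(itemsize)]
    markers = raw.view(dtype).ravel()
    index = np.rint(markers).astype(np.int64) - start
    if not np.array_equal(np.sort(index), np.arange(n_elements)):
        raise ValueError("each marker must appear exactly once")
    offsets = np.empty(n_elements, dtype=np.int64)
    offsets[index] = slots
    return offsets


def extract_weight(original: bytes, offsets: np.ndarray, dtype=np.float32) -> np.ndarray:
    """Read element i of the weight from offsets[i] in the (normalized) original engine."""
    itemsize = np.dtype(dtype).itemsize
    buf = np.frombuffer(original, dtype=np.uint8)
    return buf[offsets[:, None] + np.arange(itemsize)].view(dtype).ravel()


# Simulation with a hypothetical engine layout that stores the elements shuffled.
rng = np.random.default_rng(0)
n = 8
true_w = rng.standard_normal(n).astype(np.float32)
perm = rng.permutation(n)


def fake_engine(w: np.ndarray) -> bytes:
    return b"HDR!" * 4 + w[perm].astype(np.float32).tobytes() + b"\x00" * 16


baseline = fake_engine(np.zeros(n, dtype=np.float32))      # step 2: zero baseline
marked = fake_engine(42 + np.arange(n, dtype=np.float32))  # step 3: markers 42..49
offsets = map_weight_offsets(baseline, marked, n)          # step 4: diff & map
recovered = extract_weight(fake_engine(true_w), offsets)   # step 5: extract
print(np.array_equal(recovered, true_w))  # True
```

Because the mapping is keyed by marker value rather than by position, the simulated shuffle (`perm`) is undone automatically.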

Key details

  • Engine normalization: TRT may change the serialized length after a refit cycle, so the tool first re-serializes the original engine to produce a stable reference format.
  • Float16 chunking: float16 can only represent integers exactly up to 2048, so large float16 weights are processed in chunks of 2048 elements each.
  • Fused constants: Some ONNX graph constants (attention scaling factors, GELU coefficients) get fused into TRT kernels and can't be extracted. These are detected and skipped with a warning.
  • All weights must be set: Without REFIT_INDIVIDUAL, TRT requires all weights to be provided before refit_cuda_engine(). The tool sets all non-target weights to zero during each marker injection pass.
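The float16 constraint is easy to check in NumPy. The helper below is a hypothetical sketch, not the tool's code: it splits a weight into chunks whose markers all stay within float16's exact-integer range. With markers starting at 1 the chunks are exactly 2048 elements long; a higher starting marker leaves room for correspondingly fewer elements per chunk.

```python
import numpy as np

# float16 has an 11-bit significand, so every integer in [0, 2048] is exact.
FP16_EXACT_INT_LIMIT = 2048


def fp16_marker_chunks(n_elements: int, start: int = 42):
    """Yield (offset, markers) pairs whose markers are all exact in float16.

    Each chunk covers elements [offset, offset + len(markers)) of the weight and
    reuses the marker range start, start + 1, ... so no marker exceeds the limit.
    """
    chunk = FP16_EXACT_INT_LIMIT - start + 1
    if chunk <= 0:
        raise ValueError("start is beyond float16's exact-integer range")
    for offset in range(0, n_elements, chunk):
        length = min(chunk, n_elements - offset)
        markers = np.arange(start, start + length).astype(np.float16)
        # Round-trip check: each marker must survive conversion to float16 unchanged.
        if not np.array_equal(markers.astype(np.int64), np.arange(start, start + length)):
            raise AssertionError("marker not exactly representable in float16")
        yield offset, markers


print(np.float16(2049) == np.float16(2048))  # True: 2049 rounds to 2048
print([(off, len(m)) for off, m in fp16_marker_chunks(5000, start=1)])
# [(0, 2048), (2048, 2048), (4096, 904)]
```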

Requirements

  • NVIDIA GPU with CUDA support
  • TensorRT 10.x (Python bindings)
  • NumPy
  • safetensors (for .safetensors output)

For running the validation script, additionally:

  • PyTorch
  • ONNX
  • HuggingFace Transformers

Usage

Extract weights from an engine

# Dump to a single safetensors file
python trt_model_dump.py dump <engine.engine> -o weights.safetensors

# Dump to a NumPy .npz archive
python trt_model_dump.py dump <engine.engine> -o weights.npz

# Dump to a directory of individual .npy files
python trt_model_dump.py dump <engine.engine> -o weights_dir/

# Verbose mode (prints weight names, sizes, dtypes)
python trt_model_dump.py dump <engine.engine> -o weights.safetensors -v

Arguments:

Argument        Description
engine          Path to the .engine file (must be built with the REFIT flag)
-o, --output    Output path. The format is inferred from the extension: .safetensors, .npz, or a directory path for individual .npy files
-v, --verbose   Print detailed info about each weight during extraction
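The extension rule for -o amounts to a small dispatch function. This is a hypothetical sketch of the rule as the table describes it, not code taken from trt_model_dump.py; in particular, treating every other path as a directory is an assumption.

```python
from pathlib import Path


def infer_output_format(output: str) -> str:
    """Return 'safetensors', 'npz', or 'npy_dir' for an -o/--output path."""
    suffix = Path(output).suffix
    if suffix == ".safetensors":
        return "safetensors"
    if suffix == ".npz":
        return "npz"
    # Any other path is treated as a directory that receives one .npy per weight.
    return "npy_dir"


for path in ["weights.safetensors", "weights.npz", "weights_dir/"]:
    print(f"{path} -> {infer_output_format(path)}")
```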

Create a test engine (for development/validation)

# Build a REFIT-enabled engine from a tiny MLP model
python convert_hf_to_trt.py -o model.engine

# Keep the intermediate ONNX file
python convert_hf_to_trt.py -o model.engine --keep-onnx

Arguments:

Argument        Description
-o, --output    Output engine path (default: model.engine)
--seq-len       Sequence length for optimization profiles (default: 64)
--keep-onnx     Keep the intermediate ONNX file instead of deleting it

This also saves the PyTorch state dict as <name>_state_dict.pt alongside the engine for manual comparison.
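The source of convert_hf_to_trt.py is not reproduced in this README. For building a REFIT-enabled engine from an existing ONNX file by hand, a minimal sketch with the TensorRT 10 Python API follows. The function name `build_refit_engine` and its dynamic-shape policy (every -1 dimension ranges from 1 to `seq_len`) are assumptions, not the script's actual behavior.

```python
def build_refit_engine(onnx_path: str, engine_path: str, seq_len: int = 64) -> None:
    """Parse an ONNX model and serialize a REFIT-enabled TensorRT engine."""
    # Imported inside the function so this module can be loaded without TensorRT.
    import tensorrt as trt

    logger = trt.Logger(trt.Logger.WARNING)
    builder = trt.Builder(logger)
    network = builder.create_network(0)  # explicit batch is the default in TensorRT 10
    parser = trt.OnnxParser(network, logger)
    with open(onnx_path, "rb") as f:
        if not parser.parse(f.read()):
            errors = "\n".join(str(parser.get_error(i)) for i in range(parser.num_errors))
            raise RuntimeError(f"failed to parse {onnx_path}:\n{errors}")

    config = builder.create_builder_config()
    config.set_flag(trt.BuilderFlag.REFIT)  # required for IRefitter to work on the engine

    # Assumed policy: every dynamic (-1) dimension ranges from 1 to seq_len.
    profile = builder.create_optimization_profile()
    has_dynamic_input = False
    for i in range(network.num_inputs):
        tensor = network.get_input(i)
        dims = tuple(tensor.shape)
        if -1 in dims:
            has_dynamic_input = True
            min_shape = tuple(1 if d == -1 else d for d in dims)
            max_shape = tuple(seq_len if d == -1 else d for d in dims)
            profile.set_shape(tensor.name, min_shape, max_shape, max_shape)
    if has_dynamic_input:
        config.add_optimization_profile(profile)

    serialized = builder.build_serialized_network(network, config)
    if serialized is None:
        raise RuntimeError("TensorRT engine build failed")
    with open(engine_path, "wb") as f:
        f.write(serialized)
```

Usage would look like `build_refit_engine("model.onnx", "model.engine", seq_len=64)`; the resulting engine is a valid input for `trt_model_dump.py dump`.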

Run end-to-end validation

The validation script, validate.py, runs two tests:

  1. TinyMLP — A custom 3-layer MLP (7,312 parameters, 6 weights). Creates the model, converts it to TRT, extracts the weights, and verifies a bit-exact match against the original PyTorch state dict.
  2. prajjwal1/bert-tiny — A real HuggingFace BERT model (4.4M parameters, 39 weights). Downloads it from the Hub, converts it to TRT, extracts the weights, and verifies them.

Both tests should report PASS with 0.00e+00 max error for every weight.
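A check of this kind reduces to comparing flattened arrays. The hypothetical `compare_weights` below is not validate.py's code; it assumes the extracted and reference weights are both NumPy arrays keyed by the same names (a PyTorch state dict would first need `.cpu().numpy()` on each tensor).

```python
import numpy as np


def compare_weights(extracted: dict, reference: dict) -> bool:
    """Print PASS/FAIL and the max absolute error per weight; True if all match."""
    all_ok = True
    for name, ref in reference.items():
        ref_flat = np.asarray(ref).ravel()
        got = extracted.get(name)
        if got is None or np.asarray(got).size != ref_flat.size:
            print(f"FAIL {name}: missing or size mismatch")
            all_ok = False
            continue
        diff = np.asarray(got, dtype=np.float64).ravel() - ref_flat.astype(np.float64)
        max_err = float(np.max(np.abs(diff))) if diff.size else 0.0
        ok = max_err == 0.0  # NaN never equals 0.0, so NaNs are reported as FAIL
        all_ok = all_ok and ok
        print(f"{'PASS' if ok else 'FAIL'} {name}: max error {max_err:.2e}")
    return all_ok


ref = {"net.4.bias": np.linspace(-1, 1, 16, dtype=np.float32)}
print(compare_weights({"net.4.bias": ref["net.4.bias"].copy()}, ref))
# PASS net.4.bias: max error 0.00e+00
# True
```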

Constraints

  • REFIT flag required — The engine must have been built with trt.BuilderFlag.REFIT. Without this flag, the IRefitter API is unavailable and no weights can be extracted.
  • Only refittable weights — Weights that TRT fuses into optimized kernels (e.g., certain biases, layer norm parameters after fusion) are not exposed by the refitter and cannot be recovered.
  • Flat output — Extracted tensors are 1D (flat). The refitter API provides the element count but not the original tensor shape, so shape recovery requires external knowledge of the model architecture.
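For simple models, part of that external knowledge can come from naming conventions. The hypothetical helper below assumes PyTorch nn.Linear naming and a row-major (out_features, in_features) element order, which this README does not confirm; it infers each weight's shape from the size of the matching bias. Layers without a bias, and models such as BERT with other parameter kinds, need an explicit architecture-specific shape map instead.

```python
import numpy as np


def reshape_linear_weights(flat: dict) -> dict:
    """Reshape '<prefix>.weight' to (out_features, in_features) using '<prefix>.bias'.

    Assumes nn.Linear naming and row-major (out, in) order; weights without a
    usable bias are returned unchanged (still 1D).
    """
    shaped = dict(flat)
    for name, w in flat.items():
        if not name.endswith(".weight"):
            continue
        bias = flat.get(name[: -len(".weight")] + ".bias")
        if bias is None or bias.size == 0 or w.size % bias.size:
            continue
        shaped[name] = w.reshape(bias.size, w.size // bias.size)
    return shaped


# Flat sizes taken from the TinyMLP example output in this README.
sizes = {"net.0.weight": 2048, "net.0.bias": 64, "net.2.weight": 4096,
         "net.2.bias": 64, "net.4.weight": 1024, "net.4.bias": 16}
flat = {name: np.zeros(n, dtype=np.float32) for name, n in sizes.items()}
shaped = reshape_linear_weights(flat)
print({name: arr.shape for name, arr in shaped.items() if name.endswith(".weight")})
# {'net.0.weight': (64, 32), 'net.2.weight': (64, 64), 'net.4.weight': (16, 64)}
```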

Output formats

Format            Extension/Path        Description
SafeTensors       .safetensors          Single file, HuggingFace-compatible, memory-mappable
NumPy archive     .npz                  Single .npz archive written with np.savez
NumPy directory   any directory path    One .npy file per weight (slashes in names replaced with __)
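The two NumPy outputs can be read back with NumPy alone (the .safetensors file can be read with `safetensors.numpy.load_file`). The loaders below are a hypothetical sketch that assumes the file layout described in the table. Reversing the `__` substitution is ambiguous for any weight whose original name already contained `__`.

```python
import tempfile
from pathlib import Path

import numpy as np


def load_npz(path: str) -> dict:
    """Load a .npz archive into a plain {weight_name: array} dict."""
    with np.load(path) as archive:
        return {name: archive[name] for name in archive.files}


def load_npy_dir(directory: str) -> dict:
    """Load a directory of per-weight .npy files, restoring '/' in weight names."""
    weights = {}
    for path in sorted(Path(directory).glob("*.npy")):
        # Reverse the naming rule: '__' in file names stands for '/' in weight names.
        weights[path.stem.replace("__", "/")] = np.load(path)
    return weights


with tempfile.TemporaryDirectory() as tmp:
    np.save(Path(tmp) / "encoder__layer.0.weight.npy", np.ones(3, dtype=np.float32))
    loaded = load_npy_dir(tmp)
print(list(loaded))  # ['encoder/layer.0.weight']
```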

Project structure

trt-weight-extractor/
  SPEC.md                  # Design specification
  trt_model_dump.py        # Main extraction tool (the only file that matters)
  convert_hf_to_trt.py     # Test utility: HF model -> ONNX -> REFIT TRT engine
  validate.py              # End-to-end validation script

Example output

$ python trt_model_dump.py dump model.engine -o weights.safetensors -v
Loading engine from model.engine...
Found 6 refittable weights
  net.4.bias: size=16, dtype=<class 'numpy.float32'>
  net.2.bias: size=64, dtype=<class 'numpy.float32'>
  net.0.bias: size=64, dtype=<class 'numpy.float32'>
  net.2.weight: size=4096, dtype=<class 'numpy.float32'>
  net.4.weight: size=1024, dtype=<class 'numpy.float32'>
  net.0.weight: size=2048, dtype=<class 'numpy.float32'>
Normalizing engine format...
Creating zeroed baseline...
  Baseline created in 0.03s
Extracting weights...
  [1/6] net.4.bias: 16 elements, 0.02s
  [2/6] net.2.bias: 64 elements, 0.02s
  [3/6] net.0.bias: 64 elements, 0.02s
  [4/6] net.2.weight: 4096 elements, 0.03s
  [5/6] net.4.weight: 1024 elements, 0.02s
  [6/6] net.0.weight: 2048 elements, 0.02s
Saved to weights.safetensors (safetensors)
Done! Extracted 6 weights to weights.safetensors