TorchTL
A very minimal training loop abstraction for PyTorch.
Why TorchTL?
- Minimal: PyTorch is the only dependency
- Flexible: Use existing PyTorch models, no need to subclass
- Extensible: Callback system for custom behavior
- Automatic: Handles device management, mixed precision, gradient accumulation
- No magic: Simple, readable code that does what you expect
Features
- Automatic device management (CPU/CUDA)
- Mixed precision training
- Gradient accumulation
- Gradient clipping
- Checkpoints with resume capability
- Callback system for extensibility
- Early stopping
- Learning rate scheduling
- Progress tracking
- Exponential moving average (EMA)
Installation
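Assuming the package is published on PyPI under its project name:

```bash
pip install torchtl
```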
Quick Overview
```python
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from torchtl import Trainer

model = nn.Linear(10, 1)
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
loss_fn = nn.MSELoss()

trainer = Trainer(
    model=model,
    optimizer=optimizer,
    loss_fn=loss_fn,
    device='cuda',
    mixed_precision=True
)

# train_loader / val_loader are ordinary DataLoader instances
history = trainer.fit(train_loader, val_loader, epochs=10)
```
Basic Usage
Simple training loop
```python
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset
from torchtl import Trainer

# Toy regression data
X_train = torch.randn(1000, 10)
y_train = torch.randn(1000, 1)
train_dataset = TensorDataset(X_train, y_train)
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)

model = nn.Sequential(
    nn.Linear(10, 64),
    nn.ReLU(),
    nn.Linear(64, 1)
)
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
loss_fn = nn.MSELoss()

trainer = Trainer(model, optimizer, loss_fn)
trainer.fit(train_loader, epochs=10)
```
Training with validation
```python
X_val = torch.randn(200, 10)
y_val = torch.randn(200, 1)
val_dataset = TensorDataset(X_val, y_val)
val_loader = DataLoader(val_dataset, batch_size=32)

history = trainer.fit(train_loader, val_loader, epochs=10)
print(f"Train losses: {history['train_loss']}")
print(f"Val losses: {history['val_loss']}")
```
Mixed precision training
```python
trainer = Trainer(
    model=model,
    optimizer=optimizer,
    loss_fn=loss_fn,
    mixed_precision=True
)
```
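Mixed precision runs parts of the computation in reduced precision, which typically cuts memory use and speeds up training on modern CUDA GPUs.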
Gradient accumulation
```python
trainer = Trainer(
    model=model,
    optimizer=optimizer,
    loss_fn=loss_fn,
    grad_acc_steps=4  # number of batches to accumulate before each optimizer step
)
```
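With `grad_acc_steps=4`, the effective batch size is four times the DataLoader's `batch_size`, which is useful when the batch you actually want does not fit in memory.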
Gradient clipping
```python
trainer = Trainer(
    model=model,
    optimizer=optimizer,
    loss_fn=loss_fn,
    max_grad_norm=1.0  # clip the global gradient norm to 1.0
)
```
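Clipping the global gradient norm before each optimizer step guards against exploding gradients from occasional bad batches.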
Callbacks
Progress tracking
```python
from torchtl import ProgressCallback

trainer = Trainer(model, optimizer, loss_fn)
trainer.add_callback(ProgressCallback(print_every=100))
trainer.fit(train_loader, epochs=10)
```
Checkpointing
```python
from torchtl import CheckpointCallback

checkpoint_cb = CheckpointCallback(
    checkpoint_dir='./checkpoints',
    save_every_n_epochs=1,
    keep_last_n=3
)
trainer.add_callback(checkpoint_cb)
trainer.fit(train_loader, val_loader, epochs=10)
```
Save best model only
```python
checkpoint_cb = CheckpointCallback(
    checkpoint_dir='./checkpoints',
    save_best_only=True,
    monitor='val_loss',
    mode='min'
)
trainer.add_callback(checkpoint_cb)
trainer.fit(train_loader, val_loader, epochs=10)
```
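With `save_best_only=True`, only the checkpoint with the best monitored metric is kept; here, the one with the lowest validation loss.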
Early stopping
```python
from torchtl import EarlyStoppingCallback, StopTraining

early_stop_cb = EarlyStoppingCallback(
    patience=5,
    monitor='val_loss',
    mode='min',
    min_delta=0.001
)
trainer.add_callback(early_stop_cb)

try:
    trainer.fit(train_loader, val_loader, epochs=100)
except StopTraining as e:
    print(f"Training stopped: {e}")
```
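With these settings, training stops once `val_loss` has failed to improve by at least `min_delta` for `patience` consecutive epochs.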
Learning rate scheduling
```python
from torchtl import LearningRateSchedulerCallback
from torch.optim.lr_scheduler import StepLR

scheduler = StepLR(optimizer, step_size=5, gamma=0.1)  # decay LR by 10x every 5 epochs
scheduler_cb = LearningRateSchedulerCallback(scheduler)
trainer.add_callback(scheduler_cb)
trainer.fit(train_loader, epochs=20)
```
ReduceLROnPlateau
```python
from torch.optim.lr_scheduler import ReduceLROnPlateau

scheduler = ReduceLROnPlateau(optimizer, mode='min', patience=3)
scheduler_cb = LearningRateSchedulerCallback(scheduler)
trainer.add_callback(scheduler_cb)
trainer.fit(train_loader, val_loader, epochs=20)
```
Multiple callbacks
```python
from torchtl import (
    ProgressCallback,
    CheckpointCallback,
    EarlyStoppingCallback,
    LearningRateSchedulerCallback
)

trainer.add_callback(ProgressCallback(print_every=50))
trainer.add_callback(CheckpointCallback('./checkpoints', save_best_only=True))
trainer.add_callback(EarlyStoppingCallback(patience=5))
trainer.add_callback(LearningRateSchedulerCallback(scheduler))
trainer.fit(train_loader, val_loader, epochs=100)
```
Checkpoints
Manual save/load
```python
trainer.save_checkpoint('./checkpoint.pt')

# Later: restore state and resume training
trainer.load_checkpoint('./checkpoint.pt')
trainer.fit(train_loader, epochs=10)
```
Save with extra state
```python
trainer.save_checkpoint('./checkpoint.pt', best_accuracy=0.95, notes="best model")
```
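Extra keyword arguments are stored in the checkpoint alongside the model and optimizer state, so arbitrary metadata can travel with it.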
Load with strict=False
```python
trainer.load_checkpoint('./checkpoint.pt', strict=False)
```
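`strict=False` relaxes state-dict key matching (as with PyTorch's `load_state_dict`), which helps when the model architecture has changed since the checkpoint was saved.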
Utilities
Count parameters
```python
from torchtl import count_params

total_params = count_params(model)
trainable_params = count_params(model, trainable_only=True)
print(f"Total: {total_params}, Trainable: {trainable_params}")
```
Freeze/unfreeze layers
```python
from torchtl import freeze_layers, unfreeze_layers

# Freeze every parameter in the model
freeze_layers(model)

# Unfreeze only the heads, e.g. for fine-tuning
unfreeze_layers(model, layer_names=['fc', 'classifier'])

# Freeze specific layers by name
freeze_layers(model, layer_names=['conv1', 'conv2'])
```
Set random seed
```python
from torchtl import set_seed

set_seed(42)
```
Get/set learning rate
```python
from torchtl import get_lr, set_lr

current_lr = get_lr(optimizer)
print(f"Current LR: {current_lr}")
set_lr(optimizer, 0.0001)
```
Exponential moving average
```python
from torchtl import ExponentialMovingAverage

ema = ExponentialMovingAverage(model, decay=0.999)
epochs = 10

for epoch in range(epochs):
    trainer.train_epoch(train_loader)
    ema.update()  # fold the latest weights into the moving average

    # Evaluate with the EMA weights, then restore the raw training weights
    ema.apply_shadow()
    val_metrics = trainer.validate(val_loader)
    ema.restore()
```
Custom Callbacks
```python
from torchtl import Callback

class CustomCallback(Callback):
    def on_epoch_start(self, trainer):
        print(f"Starting epoch {trainer.epoch + 1}")

    def on_epoch_end(self, trainer, metrics):
        print(f"Epoch {trainer.epoch} finished with loss: {metrics['loss']:.4f}")

    def on_batch_end(self, trainer, batch_idx, batch, metrics):
        if trainer.global_step % 100 == 0:
            print(f"Step {trainer.global_step}, Loss: {metrics['loss']:.4f}")

trainer.add_callback(CustomCallback())
trainer.fit(train_loader, epochs=10)
```
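As a further sketch, here is a callback that appends per-epoch metrics to a CSV file. It assumes only the `on_epoch_end(trainer, metrics)` hook shown above; the class name and file format are illustrative, not part of TorchTL:

```python
import csv
from torchtl import Callback

class CSVLoggerCallback(Callback):
    """Illustrative callback: append each epoch's metrics to a CSV file."""

    def __init__(self, path='metrics.csv'):
        self.path = path
        self.header_written = False

    def on_epoch_end(self, trainer, metrics):
        with open(self.path, 'a', newline='') as f:
            writer = csv.DictWriter(f, fieldnames=['epoch', *metrics.keys()])
            if not self.header_written:
                writer.writeheader()
                self.header_written = True
            writer.writerow({'epoch': trainer.epoch, **metrics})

trainer.add_callback(CSVLoggerCallback('./metrics.csv'))
```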
Batch Format Support
TorchTL supports multiple batch formats.
Tuple/list format
```python
batch = (inputs, targets)
```
Dictionary format
```python
batch = {'inputs': inputs, 'targets': targets}
batch = {'input': inputs, 'target': targets}  # singular key names also work
```
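For instance, here is a minimal sketch of a dataset that yields dictionary batches; the `DictDataset` class is illustrative, not part of TorchTL. PyTorch's default collate function stacks the dictionary values into batched tensors:

```python
import torch
from torch.utils.data import Dataset, DataLoader

class DictDataset(Dataset):
    """Illustrative dataset whose items are dictionaries."""

    def __init__(self, inputs, targets):
        self.inputs = inputs
        self.targets = targets

    def __len__(self):
        return len(self.inputs)

    def __getitem__(self, idx):
        return {'inputs': self.inputs[idx], 'targets': self.targets[idx]}

# Default collation yields batches like {'inputs': Tensor[B, 10], 'targets': Tensor[B, 1]}
dataset = DictDataset(torch.randn(100, 10), torch.randn(100, 1))
loader = DataLoader(dataset, batch_size=32, shuffle=True)
```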
Misc usage
Custom training loop
```python
best_loss = float('inf')

for epoch in range(10):
    train_metrics = trainer.train_epoch(train_loader)
    val_metrics = trainer.validate(val_loader)
    print(f"Epoch {epoch}: Train Loss={train_metrics['loss']:.4f}, "
          f"Val Loss={val_metrics['val_loss']:.4f}")

    # Keep the best model seen so far
    if val_metrics['val_loss'] < best_loss:
        best_loss = val_metrics['val_loss']
        trainer.save_checkpoint('./best_model.pt')
```
Access internal state
print(f"Current epoch: {trainer.epoch}") print(f"Global step: {trainer.global_step}") print(f"Device: {trainer.device}")
License
Apache License 2.0.