GitHub - abdimoallim/torchtl: A very minimal training loop abstraction for PyTorch

3 min read Original article ↗

TorchTL

A very minimal training loop abstraction for PyTorch.

Why TorchTL?

  • Minimal: PyTorch is the only dependency
  • Flexible: Use existing PyTorch models, no need to subclass
  • Extensible: Callback system for custom behavior
  • Automatic: Handles device management, mixed precision, gradient accumulation
  • No magic: Simple, readable code that does what you expect

Features

TorchTL provides automatic device management (CPU/CUDA), mixed precision training, gradient accumulation, gradient clipping, checkpointing with resume capability, an extensible callback system, early stopping, learning rate scheduling, progress tracking, exponential moving average (EMA), and more.

Installation

Quick Overview

import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from torchtl import Trainer

model = nn.Linear(10, 1)
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
loss_fn = nn.MSELoss()

trainer = Trainer(
    model=model,
    optimizer=optimizer,
    loss_fn=loss_fn,
    device='cuda',
    mixed_precision=True
)

history = trainer.fit(train_loader, val_loader, epochs=10)

Basic Usage

Simple training loop

import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset
from torchtl import Trainer

X_train = torch.randn(1000, 10)
y_train = torch.randn(1000, 1)
train_dataset = TensorDataset(X_train, y_train)
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)

model = nn.Sequential(
    nn.Linear(10, 64),
    nn.ReLU(),
    nn.Linear(64, 1)
)

optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
loss_fn = nn.MSELoss()

trainer = Trainer(model, optimizer, loss_fn)
trainer.fit(train_loader, epochs=10)

Training with validation

X_val = torch.randn(200, 10)
y_val = torch.randn(200, 1)
val_dataset = TensorDataset(X_val, y_val)
val_loader = DataLoader(val_dataset, batch_size=32)

history = trainer.fit(train_loader, val_loader, epochs=10)
print(f"Train losses: {history['train_loss']}")
print(f"Val losses: {history['val_loss']}")

Mixed precision training

trainer = Trainer(
    model=model,
    optimizer=optimizer,
    loss_fn=loss_fn,
    mixed_precision=True
)

Gradient accumulation

trainer = Trainer(
    model=model,
    optimizer=optimizer,
    loss_fn=loss_fn,
    grad_acc_steps=4
)

Gradient clipping

trainer = Trainer(
    model=model,
    optimizer=optimizer,
    loss_fn=loss_fn,
    max_grad_norm=1.0
)

Callbacks

Progress tracking

from torchtl import ProgressCallback

trainer = Trainer(model, optimizer, loss_fn)
trainer.add_callback(ProgressCallback(print_every=100))
trainer.fit(train_loader, epochs=10)

Checkpointing

from torchtl import CheckpointCallback

checkpoint_cb = CheckpointCallback(
    checkpoint_dir='./checkpoints',
    save_every_n_epochs=1,
    keep_last_n=3
)

trainer.add_callback(checkpoint_cb)
trainer.fit(train_loader, val_loader, epochs=10)

Save best model only

checkpoint_cb = CheckpointCallback(
    checkpoint_dir='./checkpoints',
    save_best_only=True,
    monitor='val_loss',
    mode='min'
)

trainer.add_callback(checkpoint_cb)
trainer.fit(train_loader, val_loader, epochs=10)

Early stopping

from torchtl import EarlyStoppingCallback, StopTraining

early_stop_cb = EarlyStoppingCallback(
    patience=5,
    monitor='val_loss',
    mode='min',
    min_delta=0.001
)

trainer.add_callback(early_stop_cb)

try:
    trainer.fit(train_loader, val_loader, epochs=100)
except StopTraining as e:
    print(f"Training stopped: {e}")

Learning rate scheduling

from torchtl import LearningRateSchedulerCallback
from torch.optim.lr_scheduler import StepLR

scheduler = StepLR(optimizer, step_size=5, gamma=0.1)
scheduler_cb = LearningRateSchedulerCallback(scheduler)

trainer.add_callback(scheduler_cb)
trainer.fit(train_loader, epochs=20)

ReduceLROnPlateau

from torch.optim.lr_scheduler import ReduceLROnPlateau

scheduler = ReduceLROnPlateau(optimizer, mode='min', patience=3)
scheduler_cb = LearningRateSchedulerCallback(scheduler)

trainer.add_callback(scheduler_cb)
trainer.fit(train_loader, val_loader, epochs=20)

Multiple callbacks

from torchtl import (
    ProgressCallback,
    CheckpointCallback,
    EarlyStoppingCallback,
    LearningRateSchedulerCallback
)

trainer.add_callback(ProgressCallback(print_every=50))
trainer.add_callback(CheckpointCallback('./checkpoints', save_best_only=True))
trainer.add_callback(EarlyStoppingCallback(patience=5))
trainer.add_callback(LearningRateSchedulerCallback(scheduler))

trainer.fit(train_loader, val_loader, epochs=100)

Checkpoints

Manual save/load

trainer.save_checkpoint('./checkpoint.pt')

trainer.load_checkpoint('./checkpoint.pt')
trainer.fit(train_loader, epochs=10)

Save with extra state

trainer.save_checkpoint('./checkpoint.pt', best_accuracy=0.95, notes="best model")

Load with strict=False

trainer.load_checkpoint('./checkpoint.pt', strict=False)

Utilities

Count parameters

from torchtl import count_params

total_params = count_params(model)
trainable_params = count_params(model, trainable_only=True)
print(f"Total: {total_params}, Trainable: {trainable_params}")

Freeze/unfreeze layers

from torchtl import freeze_layers, unfreeze_layers

freeze_layers(model)

unfreeze_layers(model, layer_names=['fc', 'classifier'])

freeze_layers(model, layer_names=['conv1', 'conv2'])

Set random seed

from torchtl import set_seed

set_seed(42)

Learning rate

from torchtl import get_lr, set_lr

current_lr = get_lr(optimizer)
print(f"Current LR: {current_lr}")

set_lr(optimizer, 0.0001)

Exponential moving average

from torchtl import ExponentialMovingAverage

ema = ExponentialMovingAverage(model, decay=0.999)

for epoch in range(epochs):
    trainer.train_epoch(train_loader)
    ema.update()

ema.apply_shadow()
val_metrics = trainer.validate(val_loader)
ema.restore()

Custom Callbacks

from torchtl import Callback

class CustomCallback(Callback):
    def on_epoch_start(self, trainer):
        print(f"Starting epoch {trainer.epoch + 1}")

    def on_epoch_end(self, trainer, metrics):
        print(f"Epoch {trainer.epoch} finished with loss: {metrics['loss']:.4f}")

    def on_batch_end(self, trainer, batch_idx, batch, metrics):
        if trainer.global_step % 100 == 0:
            print(f"Step {trainer.global_step}, Loss: {metrics['loss']:.4f}")

trainer.add_callback(CustomCallback())
trainer.fit(train_loader, epochs=10)

Batch Format Support

TorchTL supports multiple batch formats.

Tuple/list format

batch = (inputs, targets)

Dictionary format

batch = {'inputs': inputs, 'targets': targets}
batch = {'input': inputs, 'target': targets}

Misc usage

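Custom training loop

Note: initialize best_loss (for example, best_loss = float('inf')) before the loop so that the first comparison against the validation loss succeeds.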

for epoch in range(10):
    train_metrics = trainer.train_epoch(train_loader)
    val_metrics = trainer.validate(val_loader)

    print(f"Epoch {epoch}: Train Loss={train_metrics['loss']:.4f}, Val Loss={val_metrics['val_loss']:.4f}")

    if val_metrics['val_loss'] < best_loss:
        best_loss = val_metrics['val_loss']
        trainer.save_checkpoint('./best_model.pt')

Access internal state

print(f"Current epoch: {trainer.epoch}")
print(f"Global step: {trainer.global_step}")
print(f"Device: {trainer.device}")

License

Licensed under the Apache License, Version 2.0.