GitHub - ocramz/lightning-extra: PyTorch Lightning plugins and utilities for cloud-native machine learning

lightning-extra

A collection of PyTorch Lightning plugins and utilities for cloud-native machine learning, with a focus on Azure Blob Storage integration and experiment tracking.

Features

  • AzureBlobCheckpointIO - PyTorch Lightning checkpoint plugin for Azure Blob Storage with content-addressable naming
  • AzureBlobDataset - Efficient dataset class for loading files from Azure Blob Storage with local caching
  • SQLiteLogger - Custom Lightning logger with SQLite backend for local experiment tracking and hyperparameter management

Quick Start

AzureBlobCheckpointIO

This plugin saves Lightning model checkpoints to Azure Blob Storage.

Checkpoint names can optionally be suffixed with a content hash that makes them unique, and can carry metadata about metrics of interest (e.g. validation loss).

We also provide a convenience function for retrieving a checkpoint and loading it back into PyTorch for inference.

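import os

import torch
from azure.storage.blob import BlobServiceClient
from lightning import Trainer
from lightning.extra.azure.blob import AzureBlobCheckpointIO

# `model`, `train_loader`, `val_loader` and `test_batch` are assumed to be defined elsewhere

# Set up the Azure container (connection string taken from the environment)
blob_service_client = BlobServiceClient.from_connection_string(
    os.environ["AZURE_STORAGE_CONNECTION_STRING"]
)
container = blob_service_client.get_container_client("checkpoints")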

# Create checkpoint plugin
checkpoint_io = AzureBlobCheckpointIO(
    container=container,
    prefix="experiments/exp-001/",
    use_content_hash=True
)

# Use with trainer
trainer = Trainer(
    plugins=[checkpoint_io],
    max_epochs=100
)

trainer.fit(model, train_loader, val_loader)

# Load checkpoint back for inference
from lightning.extra.azure.blob import load_checkpoint_from_azure

checkpoint = load_checkpoint_from_azure(
    container=container,
    checkpoint_name="epoch=10-hash=a3f9c2e1.ckpt",
    prefix="experiments/exp-001/",
    map_location="cpu"
)

# Restore model state
model.load_state_dict(checkpoint["state_dict"])
model.eval()

# Run inference
with torch.no_grad():
    predictions = model(test_batch)
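A minimal sketch of the naming scheme described above, assuming a SHA-256 digest of the serialized checkpoint bytes truncated to `hash_length` hex characters, and metrics formatted to four decimals. The helper `checkpoint_name` is hypothetical; the field order and hash function actually used by AzureBlobCheckpointIO may differ.

```python
import hashlib


def checkpoint_name(epoch, payload, metrics=None, use_content_hash=True, hash_length=8):
    """Build a name such as 'epoch=10-val_loss=0.2500-hash=<8 hex chars>.ckpt'.

    Hypothetical helper: field order, metric formatting and the use of
    SHA-256 are assumptions made for illustration.
    """
    parts = [f"epoch={epoch}"]
    # Embed metrics of interest (e.g. validation loss) in a stable, sorted order
    for key, value in sorted((metrics or {}).items()):
        parts.append(f"{key}={value:.4f}")
    if use_content_hash:
        # Hash the serialized checkpoint bytes and keep a short hex prefix
        digest = hashlib.sha256(payload).hexdigest()[:hash_length]
        parts.append(f"hash={digest}")
    return "-".join(parts) + ".ckpt"


print(checkpoint_name(10, b"serialized-weights", {"val_loss": 0.25}))
```

Because the hash depends only on the bytes, saving the same checkpoint twice yields the same name, while the metric fields make checkpoints easy to sort and filter by name alone.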

AzureBlobDataset

Load PyTorch datasets directly from Azure with local disk caching (the cache size is configurable). Internally, the caching mechanism relies on diskcache.

NB: The file cache is cleared once the AzureBlobDataset object is garbage collected (in the worst case, at the end of the program).

import os

from azure.storage.blob import BlobServiceClient
from lightning.extra.azure.blob import AzureBlobDataset
from torch.utils.data import DataLoader

# Set up the Azure container (`model` and `trainer` are assumed to be defined elsewhere)
blob_service_client = BlobServiceClient.from_connection_string(
    os.environ["AZURE_STORAGE_CONNECTION_STRING"]
)
container = blob_service_client.get_container_client("training-data")

# List files to load
blob_files = [
    "images/cat_001.jpg",
    "images/cat_002.jpg",
    "images/dog_001.jpg",
]

# Create dataset (pre-caches all files)
dataset = AzureBlobDataset(
    container=container,
    abs_fnames=blob_files
)

# Use with DataLoader
dataloader = DataLoader(dataset, batch_size=32)
trainer.fit(model, dataloader)
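The caching behaviour described above (download once, serve from local disk, clean up when the dataset object is garbage collected) can be sketched with the standard library alone. This is a hypothetical illustration, not the AzureBlobDataset implementation: `fetch` stands in for an Azure download such as `container.download_blob(name).readall()`, a plain temporary directory replaces diskcache (so no size limit is enforced), and files are cached lazily on first access, whereas AzureBlobDataset is described as pre-caching all files up front.

```python
import os
import shutil
import tempfile
import weakref


class CachedBlobDataset:
    """Hypothetical map-style dataset: fetch each blob once, then read it from disk."""

    def __init__(self, fetch, abs_fnames):
        self._fetch = fetch  # assumed callable: blob name -> bytes
        self._fnames = list(abs_fnames)
        self._cache_dir = tempfile.mkdtemp(prefix="blobcache-")
        # Remove the cache directory when this object is garbage collected
        # (or, at the latest, when the interpreter exits)
        self._finalizer = weakref.finalize(
            self, shutil.rmtree, self._cache_dir, ignore_errors=True
        )

    def __len__(self):
        return len(self._fnames)

    def __getitem__(self, idx):
        path = os.path.join(self._cache_dir, str(idx))
        if not os.path.exists(path):
            # Cache miss: download once and persist to local disk
            with open(path, "wb") as f:
                f.write(self._fetch(self._fnames[idx]))
        with open(path, "rb") as f:
            return f.read()
```

Since it implements `__len__` and `__getitem__`, an object like this can be passed to `torch.utils.data.DataLoader` as a map-style dataset. Using `weakref.finalize` rather than `__del__` guarantees the cleanup callback also runs at interpreter exit.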

SQLiteLogger

Track and search experiments with good ol' SQL:

from lightning import Trainer
from lightning.extra.sqlite import SQLiteLogger

# Create logger
logger = SQLiteLogger(
    db_path="experiments/training.db",
    experiment_name="image_classification"
)

# Use with trainer
trainer = Trainer(
    logger=logger,
    max_epochs=100
)

trainer.fit(model, train_loader, val_loader)
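As a rough picture of what a SQLite-backed logger stores, here is a minimal sketch using only the standard library. The schema is an assumption inferred from the table and column names in the "Query Experiment Database" example further down; the real SQLiteLogger schema may differ. `start_run` and `log_metrics` are hypothetical helpers, the latter mirroring the `log_metrics(metrics, step)` shape of Lightning's logger interface.

```python
import sqlite3

# Hypothetical schema inferred from the query example; not necessarily SQLiteLogger's
SCHEMA = """
CREATE TABLE IF NOT EXISTS experiments (
    id INTEGER PRIMARY KEY,
    name TEXT NOT NULL,
    version INTEGER NOT NULL
);
CREATE TABLE IF NOT EXISTS metrics (
    experiment_id INTEGER NOT NULL REFERENCES experiments(id),
    metric_name TEXT NOT NULL,
    metric_value REAL NOT NULL,
    step INTEGER
);
"""


def start_run(conn, name):
    """Register the next version of experiment `name` and return its row id."""
    (next_version,) = conn.execute(
        "SELECT COALESCE(MAX(version), -1) + 1 FROM experiments WHERE name = ?",
        (name,),
    ).fetchone()
    cur = conn.execute(
        "INSERT INTO experiments (name, version) VALUES (?, ?)", (name, next_version)
    )
    conn.commit()
    return cur.lastrowid


def log_metrics(conn, experiment_id, metrics, step=None):
    """Append one row per metric for the given run and step."""
    conn.executemany(
        "INSERT INTO metrics (experiment_id, metric_name, metric_value, step) "
        "VALUES (?, ?, ?, ?)",
        [(experiment_id, k, float(v), step) for k, v in metrics.items()],
    )
    conn.commit()


conn = sqlite3.connect(":memory:")
conn.executescript(SCHEMA)
run_id = start_run(conn, "image_classification")
log_metrics(conn, run_id, {"val_acc": 0.91, "val_loss": 0.27}, step=0)
```

Keeping one row per (metric, step) makes the history queryable with plain SQL, at the cost of a tall table for long runs.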

Installation

Using Conda (Recommended)

# Create a new conda environment
conda create -n lightning-ml python=3.11

# Activate the environment
conda activate lightning-ml

# Clone the repository
git clone https://github.com/ocramz/lightning-extra.git
cd lightning-extra

# Install in development mode
pip install -e .

# For development with testing tools
pip install -e ".[dev]"

Using pip

# Clone and install
git clone https://github.com/ocramz/lightning-extra.git
cd lightning-extra
pip install -e .

Configuration

Environment Setup

For Azure Blob Storage access, set your connection string:

export AZURE_STORAGE_CONNECTION_STRING="your-connection-string-here"

Or create a .env file in the project root:

AZURE_STORAGE_CONNECTION_STRING=your-connection-string-here

The test fixtures will automatically load this via python-dotenv.
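To make the precedence between the two options concrete, here is a hypothetical stdlib-only stand-in for python-dotenv's `load_dotenv()`; it is for illustration only and handles just simple `KEY=VALUE` lines. As with `load_dotenv`'s default (`override=False`), a variable already exported in the shell wins over the value in `.env`.

```python
import os


def load_env_file(path=".env", override=False):
    """Hypothetical minimal .env loader (illustration only, not python-dotenv).

    Skips blank lines and '#' comments, splits on the first '=' (Azure
    connection strings themselves contain '='), and strips surrounding quotes.
    """
    if not os.path.exists(path):
        return
    with open(path) as f:
        for line in f:
            line = line.strip()
            if not line or line.startswith("#") or "=" not in line:
                continue
            key, value = line.split("=", 1)
            key, value = key.strip(), value.strip().strip('"').strip("'")
            # Existing environment variables take precedence unless override=True
            if override or key not in os.environ:
                os.environ[key] = value
```

In practice, prefer python-dotenv itself; the point here is only that an exported `AZURE_STORAGE_CONNECTION_STRING` shadows the one in `.env` by default.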

Testing

Run All Tests (Unit + SQLite)

make test

Run Specific Test Suite

# Unit tests only
make test-unit

# SQLite logger tests only
make test-sqlite

# Azure integration tests (requires AZURE_STORAGE_CONNECTION_STRING)
make test-azure

# All tests including Azure
make test-all

# With verbose output
make test-verbose

# With coverage report
make coverage

Project Structure

lightning-extra/
├── lightning/
│   └── extra/
│       ├── azure/
│       │   ├── blob/
│       │   │   ├── checkpoint_plugin.py    # AzureBlobCheckpointIO
│       │   │   ├── dataset.py               # AzureBlobDataset
│       │   │   └── __init__.py
│       │   └── __init__.py
│       └── sqlite/
│           ├── logger.py                    # SQLiteLogger
│           └── __init__.py
├── tests/
│   ├── conftest.py                          # Pytest fixtures (e.g. Azure setup)
│   ├── test_checkpoint.py                   # Checkpoint plugin tests
│   ├── test_dataset.py                      # Dataset tests
│   ├── test_sqlite_logger.py                # SQLite logger tests
│   └── test_integration_checkpoint.py       # Full Lightning training integration
├── pyproject.toml                           # Build configuration
├── pytest.ini                               # Pytest configuration with marks
├── Makefile                                 # Development tasks
└── README.md

Test Marks

Tests are organized with pytest marks for selective execution:

  • @pytest.mark.unit - Unit tests (utilities, no external deps)
  • @pytest.mark.sqlite - SQLiteLogger tests
  • @pytest.mark.azure - Azure Blob Storage tests (require credentials)

Dependencies

Core Dependencies

  • torch >= 2.0
  • pytorch-lightning >= 2.6
  • azure-storage-blob >= 12.27
  • diskcache >= 5.6

Integration Tests

The project includes comprehensive integration tests that exercise real PyTorch Lightning training workflows:

  1. test_train_save_and_load_checkpoint - Full training with checkpoint save/load
  2. test_train_upload_and_inference - Training, upload to Azure, and inference
  3. test_checkpoint_with_metrics_in_name - Checkpoint naming with metrics
  4. test_multiple_checkpoints_per_training - Multiple checkpoints per run
  5. test_training_resumption_from_checkpoint - Resume training from checkpoint

Advanced Usage

Content-Addressable Checkpoints

Enable content-addressable checkpoint naming so identical model weights produce the same filename:

checkpoint_io = AzureBlobCheckpointIO(
    container=container,
    prefix="experiments/exp-2024-12/",
    use_content_hash=True,
    hash_length=8
)

# Checkpoints named like: epoch=10-hash=a3f9c2e1.ckpt
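One practical consequence of content-addressable naming is deduplication: if the target name already exists, the upload can be skipped. The sketch below assumes a SHA-256 digest truncated to `hash_length` characters and uses a toy in-memory store in place of an Azure container; `InMemoryStore` and `save_content_addressed` are hypothetical names, not part of lightning-extra.

```python
import hashlib


class InMemoryStore:
    """Toy stand-in for a blob container: maps blob names to bytes."""

    def __init__(self):
        self.blobs = {}
        self.uploads = 0

    def exists(self, name):
        return name in self.blobs

    def upload(self, name, data):
        self.blobs[name] = data
        self.uploads += 1


def save_content_addressed(store, prefix, epoch, payload, hash_length=8):
    """Upload `payload` under a name derived from its digest; skip exact duplicates."""
    digest = hashlib.sha256(payload).hexdigest()[:hash_length]
    name = f"{prefix}epoch={epoch}-hash={digest}.ckpt"
    if not store.exists(name):
        store.upload(name, payload)
    return name


store = InMemoryStore()
first = save_content_addressed(store, "experiments/exp-2024-12/", 10, b"weights-v1")
again = save_content_addressed(store, "experiments/exp-2024-12/", 10, b"weights-v1")
```

On choosing `hash_length`: 8 hex characters encode 32 bits, so by the birthday bound accidental collisions become plausible only after tens of thousands of distinct checkpoints under the same prefix and epoch, which is ample for typical runs.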

Query Experiment Database

import sqlite3

conn = sqlite3.connect("experiments.db")
cursor = conn.cursor()

# Find the best validation accuracy across all versions of an experiment
cursor.execute('''
    SELECT experiments.version, metrics.metric_value
    FROM metrics
    JOIN experiments ON metrics.experiment_id = experiments.id
    WHERE experiments.name = ? AND metrics.metric_name = ?
    ORDER BY metrics.metric_value DESC LIMIT 1
''', ('image_classification', 'val_acc'))

row = cursor.fetchone()
if row is None:
    print("No val_acc metrics found")
else:
    best_version, best_acc = row
    print(f"Best accuracy: {best_acc} (version {best_version})")
conn.close()
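The same join extends naturally to a per-version leaderboard with `GROUP BY`. The sketch below builds a toy in-memory database using the table and column names from the query above; the column types and the sample values are illustrative assumptions, not SQLiteLogger output.

```python
import sqlite3

# Toy database reusing the table/column names from the query above (values are made up)
conn = sqlite3.connect(":memory:")
conn.executescript("""
CREATE TABLE experiments (id INTEGER PRIMARY KEY, name TEXT, version INTEGER);
CREATE TABLE metrics (experiment_id INTEGER, metric_name TEXT, metric_value REAL);
INSERT INTO experiments VALUES (1, 'image_classification', 0), (2, 'image_classification', 1);
INSERT INTO metrics VALUES
    (1, 'val_acc', 0.80), (1, 'val_acc', 0.85),
    (2, 'val_acc', 0.83), (2, 'val_acc', 0.88);
""")

# Best val_acc reached by each version of the experiment, best version first
rows = conn.execute("""
    SELECT experiments.version, MAX(metrics.metric_value) AS best
    FROM metrics
    JOIN experiments ON metrics.experiment_id = experiments.id
    WHERE experiments.name = ? AND metrics.metric_name = ?
    GROUP BY experiments.version
    ORDER BY best DESC
""", ("image_classification", "val_acc")).fetchall()

for version, best in rows:
    print(f"version {version}: best val_acc = {best:.2f}")
conn.close()
```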

Type-Aware Hyperparameter Storage

logger = SQLiteLogger(db_path="experiments.db", experiment_name="exp1")

# Log hyperparameters with automatic type preservation
logger.log_hyperparams({
    "learning_rate": 0.001,           # float
    "batch_size": 32,                 # int
    "optimizer": "adam",              # str
    "use_regularization": True,       # bool
    "layer_sizes": [128, 64, 32]      # list
})

# Retrieve with types preserved
hparams = logger.get_hyperparams()
assert isinstance(hparams["learning_rate"], float)
assert isinstance(hparams["layer_sizes"], list)
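One way type preservation can work is to store each value as JSON text rather than as a raw SQLite value. This is a hypothetical sketch, not necessarily how SQLiteLogger does it; it shows why raw storage is not enough: SQLite has no boolean type, so `True` would come back as the integer `1`.

```python
import json
import sqlite3

# Hypothetical table: one JSON-encoded value per hyperparameter name
conn = sqlite3.connect(":memory:")
conn.execute("CREATE TABLE hparams (name TEXT PRIMARY KEY, value_json TEXT)")


def save_hparams(conn, hparams):
    """Serialize each value to JSON so its Python type survives the round trip."""
    conn.executemany(
        "INSERT OR REPLACE INTO hparams (name, value_json) VALUES (?, ?)",
        [(k, json.dumps(v)) for k, v in hparams.items()],
    )
    conn.commit()


def load_hparams(conn):
    """Decode every stored value back into its original JSON-compatible type."""
    return {k: json.loads(v) for k, v in conn.execute("SELECT name, value_json FROM hparams")}


save_hparams(conn, {
    "learning_rate": 0.001,
    "batch_size": 32,
    "optimizer": "adam",
    "use_regularization": True,
    "layer_sizes": [128, 64, 32],
})
hparams = load_hparams(conn)
```

JSON covers the types listed above, but it is lossy for others: tuples come back as lists and non-string dict keys become strings.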

Contributing

Contributions are welcome!

If you have a bug to report or a new feature you'd like to implement, please get in touch with the maintainers using the issue tracker.

Once your contribution is ready, please make sure all tests pass by running: make test

License

This project is licensed under the Apache License 2.0 - see the LICENSE file for details.

Citation

If you use lightning-extra in your research, please cite:

@software{lightning_extra,
  title={lightning-extra: PyTorch Lightning plugins and utilities for cloud-native ML},
  author={Marco Zocca},
  year={2025},
  url={https://github.com/ocramz/lightning-extra}
}