lightning-extra
A collection of PyTorch Lightning plugins and utilities for cloud-native machine learning, with a focus on Azure Blob Storage integration and experiment tracking.
Features
- AzureBlobCheckpointIO - PyTorch Lightning checkpoint plugin for Azure Blob Storage with content-addressable naming
- AzureBlobDataset - Efficient dataset class for loading files from Azure Blob Storage with local caching
- SQLiteLogger - Custom Lightning logger with SQLite backend for local experiment tracking and hyperparameter management
Quick Start
AzureBlobCheckpointIO
This plugin takes care of saving Lightning model checkpoints into Azure Blob Storage.
Checkpoint names are optionally suffixed with a content hash that makes them unique, and can carry metrics of interest (e.g. validation loss) in the name.
We also provide a convenience function for retrieving a checkpoint and loading it back into PyTorch for inference.
```python
import torch
from azure.storage.blob import BlobServiceClient
from lightning import Trainer

from lightning.extra.azure.blob import AzureBlobCheckpointIO

# Setup Azure container
blob_service_client = BlobServiceClient.from_connection_string(connection_string)
container = blob_service_client.get_container_client("checkpoints")

# Create checkpoint plugin
checkpoint_io = AzureBlobCheckpointIO(
    container=container,
    prefix="experiments/exp-001/",
    use_content_hash=True
)

# Use with trainer
trainer = Trainer(
    plugins=[checkpoint_io],
    max_epochs=100
)
trainer.fit(model, train_loader, val_loader)

# Load checkpoint back for inference
from lightning.extra.azure.blob import load_checkpoint_from_azure

checkpoint = load_checkpoint_from_azure(
    container=container,
    checkpoint_name="epoch=10-hash=a3f9c2e1.ckpt",
    prefix="experiments/exp-001/",
    map_location="cpu"
)

# Restore model state
model.load_state_dict(checkpoint["state_dict"])
model.eval()

# Run inference
with torch.no_grad():
    predictions = model(test_batch)
```
AzureBlobDataset
Load PyTorch datasets directly from Azure Blob Storage, with local disk caching (cache size is configurable).
Internally, the caching mechanism relies on diskcache.
NB: The file cache is cleared once the AzureBlobDataset object is garbage collected (in the worst case, at the end of the program).
```python
from azure.storage.blob import BlobServiceClient
from torch.utils.data import DataLoader

from lightning.extra.azure.blob import AzureBlobDataset

# Setup
blob_service_client = BlobServiceClient.from_connection_string(connection_string)
container = blob_service_client.get_container_client("training-data")

# List files to load
blob_files = [
    "images/cat_001.jpg",
    "images/cat_002.jpg",
    "images/dog_001.jpg",
]

# Create dataset (pre-caches all files)
dataset = AzureBlobDataset(
    container=container,
    abs_fnames=blob_files
)

# Use with DataLoader
dataloader = DataLoader(dataset, batch_size=32)
trainer.fit(model, dataloader)
```
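Under the hood, the get-or-download pattern behind this is roughly the following (a minimal sketch with illustrative names such as `_fetch` and the cache directory; not the library's actual code):

```python
import diskcache
from azure.storage.blob import ContainerClient

# A size-bounded local cache; entries are evicted once the limit is exceeded.
cache = diskcache.Cache("/tmp/blob-cache", size_limit=1 * 1024**3)  # 1 GiB

def _fetch(container: ContainerClient, blob_name: str) -> bytes:
    """Return the blob's bytes, downloading from Azure only on a cache miss."""
    data = cache.get(blob_name)
    if data is None:
        data = container.download_blob(blob_name).readall()  # cache miss: download
        cache[blob_name] = data
    return data
```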
SQLiteLogger
Track and search experiments with good ol' SQL:
```python
from lightning import Trainer, LightningModule

from lightning.extra.sqlite import SQLiteLogger

# Create logger
logger = SQLiteLogger(
    db_path="experiments/training.db",
    experiment_name="image_classification"
)

# Use with trainer
trainer = Trainer(
    logger=logger,
    max_epochs=100
)
trainer.fit(model, train_loader, val_loader)
```
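Anything you log from your LightningModule with `self.log(...)` is routed to the active logger, so those values end up in the SQLite database and can be queried later. A minimal sketch (the module below is illustrative; `__init__`, `forward`, and optimizers are omitted):

```python
import torch.nn.functional as F
from lightning import LightningModule

class LitClassifier(LightningModule):
    def training_step(self, batch, batch_idx):
        x, y = batch
        loss = F.cross_entropy(self(x), y)
        self.log("train_loss", loss)  # recorded by the configured logger
        return loss

    def validation_step(self, batch, batch_idx):
        x, y = batch
        acc = (self(x).argmax(dim=-1) == y).float().mean()
        self.log("val_acc", acc)      # queryable later with SQL (see below)
```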
Installation
Using Conda (Recommended)
```bash
# Create a new conda environment
conda create -n lightning-ml python=3.11

# Activate the environment
conda activate lightning-ml

# Clone the repository
git clone https://github.com/ocramz/lightning-extra.git
cd lightning-extra

# Install in development mode
pip install -e .

# For development with testing tools
pip install -e ".[dev]"
```
Using pip
```bash
# Clone and install
git clone https://github.com/ocramz/lightning-extra.git
cd lightning-extra
pip install -e .
```
Configuration
Environment Setup
For Azure Blob Storage access, set your connection string:
```bash
export AZURE_STORAGE_CONNECTION_STRING="your-connection-string-here"
```
Or create a .env file in the project root:
```
AZURE_STORAGE_CONNECTION_STRING=your-connection-string-here
```
The test fixtures will automatically load this via python-dotenv.
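To replicate that setup in your own scripts, the pattern is roughly the following (a sketch, not the project's actual conftest.py):

```python
import os

from azure.storage.blob import BlobServiceClient
from dotenv import load_dotenv

load_dotenv()  # reads .env from the working directory, if present

connection_string = os.environ["AZURE_STORAGE_CONNECTION_STRING"]
blob_service_client = BlobServiceClient.from_connection_string(connection_string)
```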
Testing
Run All Tests (Unit + SQLite)
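```bash
make test
```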
Run Specific Test Suite
```bash
# Unit tests only
make test-unit

# SQLite logger tests only
make test-sqlite

# Azure integration tests (requires AZURE_STORAGE_CONNECTION_STRING)
make test-azure

# All tests including Azure
make test-all

# With verbose output
make test-verbose

# With coverage report
make coverage
```
Project Structure
```
lightning-extra/
├── lightning/
│   └── extra/
│       ├── azure/
│       │   ├── blob/
│       │   │   ├── checkpoint_plugin.py   # AzureBlobCheckpointIO
│       │   │   ├── dataset.py             # AzureBlobDataset
│       │   │   └── __init__.py
│       │   └── __init__.py
│       └── sqlite/
│           ├── logger.py                  # SQLiteLogger
│           └── __init__.py
├── tests/
│   ├── conftest.py                        # Pytest fixtures (e.g. Azure setup)
│   ├── test_checkpoint.py                 # Checkpoint plugin tests
│   ├── test_dataset.py                    # Dataset tests
│   ├── test_sqlite_logger.py              # SQLite logger tests
│   └── test_integration_checkpoint.py     # Full Lightning training integration
├── pyproject.toml                         # Build configuration
├── pytest.ini                             # Pytest configuration with marks
├── Makefile                               # Development tasks
└── README.md
```
Test Marks
Tests are organized with pytest marks for selective execution:
- `@pytest.mark.unit` - Unit tests (utilities, no external deps)
- `@pytest.mark.sqlite` - SQLiteLogger tests
- `@pytest.mark.azure` - Azure Blob Storage tests (require credentials)
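Marking a test looks like this (illustrative test; the make targets above presumably select tests by these marks via pytest's `-m` option):

```python
import pytest

@pytest.mark.unit
def test_checkpoint_name_contains_hash():
    # Illustrative unit test, selected by `make test-unit`.
    name = "epoch=10-hash=a3f9c2e1.ckpt"
    assert "hash=" in name
```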
Dependencies
Core Dependencies
- `torch >= 2.0`
- `pytorch-lightning >= 2.6`
- `azure-storage-blob >= 12.27`
- `diskcache >= 5.6`
Integration Tests
The project includes comprehensive integration tests that exercise real PyTorch Lightning training workflows:
- test_train_save_and_load_checkpoint - Full training with checkpoint save/load
- test_train_upload_and_inference - Training, upload to Azure, and inference
- test_checkpoint_with_metrics_in_name - Checkpoint naming with metrics
- test_multiple_checkpoints_per_training - Multiple checkpoints per run
- test_training_resumption_from_checkpoint - Resume training from checkpoint
Advanced Usage
Content-Addressable Checkpoints
Enable content-addressable checkpoint naming so identical model weights produce the same filename:
```python
checkpoint_io = AzureBlobCheckpointIO(
    container=container,
    prefix="experiments/exp-2024-12/",
    use_content_hash=True,
    hash_length=8
)

# Checkpoints named like: epoch=10-hash=a3f9c2e1.ckpt
```
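Conceptually, the content hash maps identical checkpoint contents to identical filenames, along these lines (a sketch; SHA-256 and the helper name are assumptions, not the plugin's actual implementation):

```python
import hashlib
import io

import torch

def content_hash(checkpoint: dict, hash_length: int = 8) -> str:
    """Hash the serialized checkpoint so identical contents yield identical names."""
    buffer = io.BytesIO()
    torch.save(checkpoint, buffer)                  # serialize checkpoint to bytes
    digest = hashlib.sha256(buffer.getvalue()).hexdigest()
    return digest[:hash_length]                     # e.g. 'a3f9c2e1'
```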
Query Experiment Database
```python
import sqlite3

conn = sqlite3.connect("experiments.db")
cursor = conn.cursor()

# Find best validation accuracy
cursor.execute('''
    SELECT experiments.version, metrics.metric_value
    FROM metrics
    JOIN experiments ON metrics.experiment_id = experiments.id
    WHERE experiments.name = ? AND metrics.metric_name = ?
    ORDER BY metrics.metric_value DESC
    LIMIT 1
''', ('image_classification', 'val_acc'))

best_version, best_acc = cursor.fetchone()
print(f"Best accuracy: {best_acc} (version {best_version})")
```
Type-Aware Hyperparameter Storage
```python
logger = SQLiteLogger(db_path="experiments.db", experiment_name="exp1")

# Log hyperparameters with automatic type preservation
logger.log_hyperparams({
    "learning_rate": 0.001,        # float
    "batch_size": 32,              # int
    "optimizer": "adam",           # str
    "use_regularization": True,    # bool
    "layer_sizes": [128, 64, 32]   # list
})

# Retrieve with types preserved
hparams = logger.get_hyperparams()
assert isinstance(hparams["learning_rate"], float)
assert isinstance(hparams["layer_sizes"], list)
```
Contributing
Contributions are welcome!
If you have a bug to report or a new feature you'd like to implement, please get in touch with the maintainers using the issue tracker.
Once your contribution is ready, please ensure all tests pass with `make test`.
License
This project is licensed under the Apache License 2.0 - see the LICENSE file for details.
Citation
If you use lightning-extra in your research, please cite:
```bibtex
@software{lightning_extra,
  title={lightning-extra: PyTorch Lightning plugins and utilities for cloud-native ML},
  author={Marco Zocca},
  year={2025},
  url={https://github.com/ocramz/lightning-extra}
}
```