This repo is based on torchprime and integrated with Weights & Biases and Hugging Face. It is designed to be a flexible pipeline for training custom research-scale models on Google Cloud TPUs using PyTorch/XLA.
This pipeline prioritizes:
- Flexibility
- Customizability
- Simplicity
- Ease-of-use
- Research-scale models (1B-10B parameters)
Features
Without touching the base pipeline, easy-torch-tpu allows you to:
- Define custom train step functions (optionally including parameter update logic).
- Implement new nn.Module-based models.
- Create optimizers with custom step logic (with auxiliary metric logging).
- Use custom dataloaders (based on collate functions).
- Define custom recursive module scanning and activation checkpointing.
- Define custom activation and parameter sharding configs (with FSDP).
- Save and load checkpoints with Hugging Face.
- Log training metrics to Weights & Biases.

...and more.
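For instance, the custom-dataloader hook builds batches from a collate function. The sketch below is purely illustrative: the name `pad_collate` and its signature are assumptions, not part of easy-torch-tpu's actual API. It shows the general shape of a collate function that pads ragged token sequences into a rectangular batch.

```python
from typing import List

def pad_collate(samples: List[List[int]], pad_id: int = 0) -> List[List[int]]:
    """Pad variable-length token sequences to the batch's max length.

    Hypothetical example: `pad_collate` is an illustrative name, not a
    function defined by easy-torch-tpu.
    """
    max_len = max(len(s) for s in samples)
    return [s + [pad_id] * (max_len - len(s)) for s in samples]

# A dataloader built on this collate function turns ragged samples
# into a rectangular batch:
batch = pad_collate([[1, 2, 3], [4, 5]])
# batch == [[1, 2, 3], [4, 5, 0]]
```

In practice such a function would return tensors rather than lists, but the padding logic is the same.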
Installation
- Create a single-slice TPU VM with runtime version `tpu-ubuntu2204-base`.
- Clone the repo onto all VM devices (see the cli-commands documentation):
  `git clone https://github.com/aklein4/easy-torch-tpu`
- Run the installation script on all VM devices (see `tpu_setup.sh` for more info):
  `cd ~/easy-torch-tpu && . tpu_setup.sh <HF_ID> <HF_TOKEN> <WANDB_TOKEN>`
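Commands like the clone and setup steps above must run on every host in the slice. One way to do that (a sketch; substitute your own VM name and zone for the placeholders) is gcloud's `--worker=all` flag:

```shell
# Hypothetical invocation: <TPU_NAME> and <ZONE> are placeholders for
# your VM's name and zone. --worker=all runs the command on every host.
gcloud compute tpus tpu-vm ssh <TPU_NAME> \
  --zone=<ZONE> \
  --worker=all \
  --command="git clone https://github.com/aklein4/easy-torch-tpu"
```

The same pattern works for the `tpu_setup.sh` step; see the cli-commands documentation for the repo's recommended workflow.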
Getting Started
The `docs` folder contains useful information about configuration, training, and customization.
Contributing
If you find a problem, have a suggestion, or want to contribute, open a GitHub issue.
Acknowledgements
Research supported with Cloud TPUs from Google's TPU Research Cloud (TRC).