torchprime is a reference implementation for training PyTorch models on TPU. It is designed to showcase best practices for large-scale, high-performance model training using torch_xla, with minimal changes to model code. It aims to demystify training on XLA-based accelerators, providing clear patterns and best practices to help the PyTorch community unlock top performance and efficiency on Google Cloud TPUs.

torchprime is under active development, and we're eager for feedback and input from the PyTorch community.
Before installing torchprime, first install torch_xla by following that project's README.

Install torchprime:
```sh
git clone https://github.com/AI-Hypercomputer/torchprime.git
cd torchprime
pip install -e '.[dev]'
```
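After installation, you can optionally sanity-check that torch_xla sees an XLA device on your TPU VM. This is a minimal check, not part of torchprime itself:

```python
# Quick sanity check that torch_xla can see an XLA (TPU) device on this VM.
import torch
import torch_xla.core.xla_model as xm

device = xm.xla_device()
x = torch.randn(2, 2, device=device)
print(x.device)  # e.g. "xla:0"
```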
Here is a simple example of training Llama 3 8B on a single TPU VM using torch_xla:
```sh
export HF_TOKEN='...your huggingface token...'
XLA_IR_DEBUG=1 XLA_HLO_DEBUG=1 python3 torchprime/torch_xla_models/train.py
```
Refer to the README.md in torchprime/torch_xla_models for more details.
torchprime uses Hydra to read configurations (e.g. model name, batch size) from the command line and from `.yaml` files. In the `torch_xla_models` directory, you'll find `configs/default.yaml`, which specifies the default configuration for the trainer. You may override configs on the command line with `key=value` syntax. For example, the following command will train Mixtral 8x7B with a global batch size of 256 and set the FSDP SPMD ICI mesh axis length to 64:
```sh
python3 torchprime/torch_xla_models/train.py \
    model=mixtral-8x7b \
    global_batch_size=256 \
    ici_mesh.fsdp=64
```
You may refer to the Hydra docs for other ways to specify configs.
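Under the hood, a Hydra entry point merges `configs/default.yaml` with your command-line overrides before training starts. The following is a minimal sketch of that pattern, not torchprime's actual `train.py`:

```python
# Minimal Hydra entry-point sketch; the config keys shown in the commands above
# (model, global_batch_size, ici_mesh.fsdp) arrive merged into `config`.
import hydra
from omegaconf import DictConfig, OmegaConf

@hydra.main(version_base=None, config_path="configs", config_name="default")
def main(config: DictConfig) -> None:
    # Command-line overrides such as `global_batch_size=256` are already applied here.
    print(OmegaConf.to_yaml(config))

if __name__ == "__main__":
    main()
```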
torchprime uses xpk as the standard path for iterating on distributed training code.
First, tell torchprime about the XPK cluster it will use, the artifact storage location, and so on. You only need to do this after the first clone or when switching to a different topology or cluster. Example:
```sh
tp use \
    --cluster <XPK CLUSTER NAME> \
    --project my-gcp-project \
    --zone us-east5-b \
    --num-slices 1 \
    --tpu-type v6e-256 \
    --artifact-dir gs://bucket/dir
```
torchprime natively supports multi-slice and multi-pod training. `--num-slices` specifies the number of slices used by the workload, and `--tpu-type` specifies the accelerator type in each slice. To do multi-pod training, simply specify a `--tpu-type` that is as big as a pod.
After configuring the cluster, prepend `tp run` to the Python command you would like to run remotely, including its arguments, e.g.
```sh
# Train Llama 3.0 8B on 256 chips
tp run torchprime/torch_xla_models/train.py \
    model=llama-3-8b \
    global_batch_size=256 \
    ici_mesh.fsdp=256
```
`tp run` will broadcast the specified command to all VMs in the XPK cluster, which is the convention for running SPMD distributed workloads. See `tp run --help` for more advanced features.
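The `ici_mesh.fsdp` override above corresponds to an axis of the XLA SPMD device mesh that the trainer builds internally. Roughly, sharding over such a mesh with torch_xla looks like the sketch below; this is illustrative only, not torchprime's actual implementation:

```python
# Illustrative torch_xla SPMD sketch: build a 1D "fsdp" mesh over all devices
# and shard a weight matrix along it. torchprime derives its mesh from the
# ici_mesh config rather than hard-coding it like this.
import numpy as np
import torch
import torch_xla.core.xla_model as xm
import torch_xla.runtime as xr
import torch_xla.distributed.spmd as xs

xr.use_spmd()

num_devices = xr.global_runtime_device_count()  # e.g. 256 on a v6e-256 slice
mesh = xs.Mesh(np.arange(num_devices), (num_devices,), ("fsdp",))

w = torch.randn(4096, 4096, device=xm.xla_device())
xs.mark_sharding(w, mesh, ("fsdp", None))  # shard rows across the "fsdp" axis
```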
`tp run` will pick up these environment variables locally and proxy them to the distributed workload, if found:

- `HF_TOKEN`: HuggingFace token
- `XLA_IR_DEBUG`: torch_xla debugging flag
- `XLA_HLO_DEBUG`: torch_xla debugging flag
- `LIBTPU_INIT_ARGS`: XLA flags that affect compilation and execution behavior
Besides forwarding your command-line arguments, `tp run` will add:

- `profile_dir=[...]`: path to a profile directory accessible by the workload
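Inside the remote workload, the forwarded variables show up as ordinary environment variables, while `profile_dir` arrives as a config override. A minimal illustration (not torchprime's actual code):

```python
# Minimal illustration of what the remote workload sees; not torchprime's code.
import os

print("HF_TOKEN set:", "HF_TOKEN" in os.environ)         # forwarded from your local shell
print("XLA_IR_DEBUG:", os.environ.get("XLA_IR_DEBUG"))   # "1" if you set it locally
# profile_dir is injected as a command-line (Hydra) override rather than an
# environment variable, so the trainer reads it from its config object.
```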
Below is the status of various models. There are five stages for each model:
- TODO: We need to implement the model.
- Implemented: The model runs either a training or an inference step.
- Optimized: We found the best scaling configuration for the model on one or more hardware configurations. One-off performance data is available.
- Convergence: We tested that the training loss converges to a reasonable value, or that the loss curve tracks an existing reference if one exists.
- Production: Not only is the model optimized and converging, but its performance is also continuously monitored. This is a good state for using the model in production.
All implemented models will at least have unit tests to verify basic numerical correctness, and the convergence verification stage serves as an additional correctness guarantee.
If a model is implemented, you'll also find a training recipe linked from the checkmark emoji in the table. If a model is optimized, you'll also find MFU numbers linked from the table. Note that a model may continue to receive further optimization thereafter.
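As a rough guide to how an MFU number relates to measured step time, the standard approximation for dense decoder-only models is about `6 * parameters * tokens` training FLOPs per step. The figures below are illustrative placeholders, not torchprime measurements:

```python
# Back-of-the-envelope MFU estimate for a dense decoder-only model.
# All numbers are illustrative placeholders; check your accelerator's datasheet
# for its actual peak FLOP/s.
params = 8e9                   # e.g. Llama 3 8B
tokens_per_step = 256 * 8192   # global batch size * sequence length
step_time_s = 2.0              # measured step time
peak_flops_per_chip = 900e12   # assumed bf16 peak per chip
num_chips = 256

model_flops = 6 * params * tokens_per_step  # ~6*N*T rule of thumb
mfu = model_flops / (step_time_s * peak_flops_per_chip * num_chips)
print(f"MFU ~= {mfu:.1%}")
```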
Model | Implemented | Optimized | Converges |
---|---|---|---|
Llama 3.0 8B | ✅ | ✅ | TODO |
Llama 3.1 8B | ✅ | TODO | TODO |
Llama 3.1 70B | TODO | TODO | TODO |
Llama 3.1 405B | ✅ | ✅ | TODO |
Llama 4 Scout | TODO | TODO | TODO |
Llama 4 Maverick | TODO | TODO | TODO |
Mixtral 8x7B | ✅ | TODO | TODO |
Mixtral 8x22B | TODO | TODO | TODO |
DeepSeek V3/R1 | TODO | TODO | TODO |
Stable Diffusion 2.0 | TODO | TODO | TODO |
Stable Diffusion 2.1 | TODO | TODO | TODO |
This repo will contain a set of reference models that we have optimized and that run well on TPU. The best-performing scaling configuration (parallelism techniques, checkpointing, etc.) for each model on various hardware will be provided for ease of reproducibility.
- `docs` contains guides for optimizing performance and debugging issues.
- `torchprime/launcher` contains scripts to train a model on a large TPU cluster.
- `torchprime/data` contains dataset and data loading utilities.
- `torchprime/torch_xla_models` contains model implementations using torch_xla.
- `torchprime/experimental/torchax_models` contains model implementations using torchax.
Finally, each model may also provide a GPU "original" version that illustrates and attributes where the model code came from, if any. This also helps to showcase what changes we have made to make it performant on TPU. The original version is not expected to be run.
Contributions are welcome! Please feel free to submit a pull request.

When developing, use `pip install -e '.[dev]'` to install dev dependencies such as the linter and formatter.
```sh
pytest
tp -i test ... # replace with path to tests/directories
ruff format
ruff check [--fix]
```
You can install a Ruff VSCode plugin to check errors and format files from the editor.
You can also run locally without XPK using Docker. When running inside the Docker container, the same dependencies and build process are used as in the XPK approach, improving hermeticity and reliability.
```sh
tp docker-run torchprime/torch_xla_models/train.py
```
This will run the torchprime Docker image locally. You can also add `--use-hf` to run a HuggingFace model locally:
```sh
tp docker-run --use-hf torchprime/hf_models/train.py
```
torchprime supports running with user-specified torch and torch_xla wheels placed under the `local_dist/` directory. The wheels will be automatically installed in the Docker image when using the `tp run` command. To use the wheels, add the `--use-local-wheel` flag to `tp run`:
```sh
tp run --use-local-wheel torchprime/hf_models/train.py
```
The wheels should be built inside a PyTorch/XLA development Docker image or the PyTorch/XLA VSCode Dev Container to minimize compatibility issues.
This project is licensed under the New BSD License - see the LICENSE file for details.
For more information on PyTorch/XLA, visit the official documentation.