Perceiver AR is an autoregressive, modality-agnostic architecture that uses cross-attention to map long-range inputs to a small number of latents while maintaining end-to-end causal masking. Perceiver AR can directly attend to over a hundred thousand tokens, enabling practical long-context density estimation without hand-crafted sparsity patterns or memory mechanisms.
For more details, see our ICML 2022 paper: https://arxiv.org/abs/2202.07765
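The core idea can be illustrated with a minimal sketch: a single causally masked cross-attention step that maps a long input sequence onto a short stack of latents aligned with the final positions. This is an illustrative simplification, not the released implementation; the name `causal_cross_attend` is ours, and the real model adds learned projections, multiple heads, and a self-attention stack over the latents.

```python
# Minimal sketch of causally masked cross-attention: a long input is
# summarized into a few latents, each of which may only attend to input
# positions at or before its own absolute position.
import jax
import jax.numpy as jnp

def causal_cross_attend(inputs, num_latents):
    """inputs: [seq_len, dim] embedded tokens -> [num_latents, dim] latents."""
    seq_len, dim = inputs.shape
    # Latents are tied to the last `num_latents` positions of the sequence.
    latents = inputs[-num_latents:]                  # queries, [L, D]
    scores = latents @ inputs.T / jnp.sqrt(dim)      # [L, S]
    latent_pos = jnp.arange(seq_len - num_latents, seq_len)[:, None]
    input_pos = jnp.arange(seq_len)[None, :]
    mask = input_pos <= latent_pos                   # causal: no lookahead
    scores = jnp.where(mask, scores, -1e30)
    weights = jax.nn.softmax(scores, axis=-1)
    return weights @ inputs                          # [L, D]

# 16 latents summarize a 4096-token context in one cross-attend.
x = jax.random.normal(jax.random.PRNGKey(0), (4096, 64))
print(causal_cross_attend(x, 16).shape)              # (16, 64)
```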
An implementation is also available for T5X/Flaxformer.
First, install dependencies following these instructions:

- Create a virtual env:

  ```shell
  python3 -m venv ~/.venv/perceiver-ar
  ```

- Switch to the virtual env:

  ```shell
  source ~/.venv/perceiver-ar/bin/activate
  ```

- Follow instructions for installing JAX on your platform:
  https://github.com/google/jax#installation

- Install other dependencies:

  ```shell
  pip install -r requirements.txt
  ```
As an example, a 32-position version of the Copy Task from our paper can be trained using only a local CPU (a sketch of the task's data follows the command below):

```shell
PYTHONPATH=.:$PYTHONPATH python perceiver_ar/experiment.py \
  --config=perceiver_ar/experiment.py:random_mirrored_32
```
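The `random_mirrored_32` name suggests the task format: our best reading is that each example is a random sequence whose second half mirrors the first, so the model must copy tokens from long-range context. The exact vocabulary size and formatting in the repo's dataset code may differ from this illustration:

```python
# Hedged sketch of a "random mirrored" copy-task example: the second half
# is the reverse of the first, so predicting it requires copying from
# earlier context. Vocabulary size and layout here are assumptions.
import numpy as np

def random_mirrored_example(length=32, vocab_size=256, seed=0):
    rng = np.random.default_rng(seed)
    half = rng.integers(0, vocab_size, size=length // 2)
    return np.concatenate([half, half[::-1]])

print(random_mirrored_example(length=8))  # e.g. [a b c d d c b a]
```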
By default, checkpoints and events will be saved to `/tmp/perceiver_ar/`.
Training metrics are periodically written to TensorBoard event files, which can be viewed with:

```shell
tensorboard --logdir /tmp/perceiver_ar/
```
During training, press Ctrl+C to save a checkpoint and continue, or Ctrl+\ to save a checkpoint and exit.
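Under the hood, these key presses deliver SIGINT (Ctrl+C) and SIGQUIT (Ctrl+\) to the training process. A minimal sketch of how such handlers are typically installed, not jaxline's actual implementation:

```python
# Illustrative only: map SIGINT to "checkpoint and continue" and SIGQUIT
# to "checkpoint and exit". The real handling lives inside jaxline.
import signal

def save_checkpoint():
    print("checkpoint saved")  # placeholder for real checkpointing logic

def on_sigint(signum, frame):
    save_checkpoint()          # keep training afterwards

def on_sigquit(signum, frame):
    save_checkpoint()
    raise SystemExit           # stop the training loop

signal.signal(signal.SIGINT, on_sigint)
signal.signal(signal.SIGQUIT, on_sigquit)
```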
To evaluate the latest saved checkpoint:

```shell
CHECKPOINTS="/tmp/perceiver_ar"
LATEST_CHECKPOINT="${CHECKPOINTS}/models/latest/$(ls -tr ${CHECKPOINTS}/models/latest/ | tail -n 1)"
echo "Evaluating ${LATEST_CHECKPOINT}"

PYTHONPATH=.:$PYTHONPATH python perceiver_ar/experiment.py \
  --config=perceiver_ar/experiment.py:random_mirrored_32 \
  --jaxline_mode=eval \
  --config.one_off_evaluate=True \
  --config.restore_path="${LATEST_CHECKPOINT}"
```
Results will be written to the console and can also be viewed in TensorBoard.
To run inference in a local Jupyter notebook:

```shell
jupyter notebook
```

Load `inference.ipynb` and follow the instructions in the notebook.
The notebook also supports loading a pretrained checkpoint for the 131k-position copy task used in our paper. This model is fairly large, so inferring more than a few positions will likely require a large accelerator. The notebook has been tested on a GCP TPU VM using a TPU v3-8.
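Before sampling from the large checkpoint, it can help to confirm that JAX can see an accelerator (for example, the 8 cores of a TPU v3-8); `jax.devices()` and `jax.device_count()` are standard JAX calls:

```python
import jax

# On a TPU v3-8 this should list 8 TpuDevice entries; on CPU-only hosts
# it will list CpuDevice, a hint that large-scale inference will be slow.
print(jax.devices())
print(jax.device_count())
```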
To run all unit tests:

```shell
pytest
```
This is not an officially supported Google product.