From 61f4a93d1490f285b0dd3a536dd85a9f3f18ddd9 Mon Sep 17 00:00:00 2001 From: Woosuk Kwon Date: Tue, 3 Sep 2024 18:35:33 -0700 Subject: [PATCH] [TPU][Bugfix] Use XLA rank for persistent cache path (#8137) --- docs/source/getting_started/tpu-installation.rst | 2 +- vllm/worker/tpu_worker.py | 3 ++- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/docs/source/getting_started/tpu-installation.rst b/docs/source/getting_started/tpu-installation.rst index d0c2498d8849e..217028839e347 100644 --- a/docs/source/getting_started/tpu-installation.rst +++ b/docs/source/getting_started/tpu-installation.rst @@ -59,7 +59,7 @@ First, install the dependencies: $ export DATE="20240828" $ export TORCH_VERSION="2.5.0" $ pip install https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch-${TORCH_VERSION}.dev${DATE}-cp310-cp310-linux_x86_64.whl - $ pip3 install https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-${TORCH_VERSION}.dev${DATE}-cp310-cp310-linux_x86_64.whl + $ pip install https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-${TORCH_VERSION}.dev${DATE}-cp310-cp310-linux_x86_64.whl $ # Install JAX and Pallas. $ pip install torch_xla[tpu] -f https://storage.googleapis.com/libtpu-releases/index.html diff --git a/vllm/worker/tpu_worker.py b/vllm/worker/tpu_worker.py index 44fa3aed5816d..9e0c522cee453 100644 --- a/vllm/worker/tpu_worker.py +++ b/vllm/worker/tpu_worker.py @@ -102,8 +102,9 @@ def init_device(self) -> None: # NOTE(woosuk): Set per-rank cache path since different ranks # can have slightly different XLA graphs. world_size = self.parallel_config.world_size + rank = xr.global_ordinal() per_rank_path = os.path.join(envs.VLLM_XLA_CACHE_PATH, - f"tp{world_size}_rank{self.rank}") + f"tp{world_size}_rank{rank}") xr.initialize_cache(per_rank_path, readonly=False) def load_model(self):