diff --git a/vllm/attention/backends/pallas.py b/vllm/attention/backends/pallas.py index c324d62d44d79..83fdef16ef5cb 100644 --- a/vllm/attention/backends/pallas.py +++ b/vllm/attention/backends/pallas.py @@ -124,7 +124,10 @@ def __init__( self.megacore_mode = None tpu_env = torch_xla.tpu.get_tpu_env() - tpu_type = tpu_env.get("TYPE") or tpu_env.get("ACCELERATOR_TYPE") + tpu_type = (tpu_env.get("ACCELERATOR_TYPE", None) + or tpu_env.get("TYPE", None) + or tpu_env.get("TPU_ACCELERATOR_TYPE", None)) + assert tpu_type is not None tpu_type = tpu_type.lower() if "lite" not in tpu_type: