[V1] TPU support - refactored
Signed-off-by: Alexander Matveev <amatveev@redhat.com>
alexm-redhat committed Feb 13, 2025
1 parent 14ecab5 commit 4cb766e
Showing 10 changed files with 1,780 additions and 25 deletions.
6 changes: 3 additions & 3 deletions examples/offline_inference/basic.py
@@ -10,15 +10,15 @@
"The future of AI is",
]
# Create a sampling params object.
sampling_params = SamplingParams(temperature=0.8, top_p=0.95)
sampling_params = SamplingParams() #temperature=0.8, top_p=0.95)

# Create an LLM.
llm = LLM(model="facebook/opt-125m")
llm = LLM(model="Qwen/Qwen2-1.5B-Instruct", max_model_len=512, max_num_seqs=16)
# Generate texts from the prompts. The output is a list of RequestOutput objects
# that contain the prompt, generated text, and other information.
outputs = llm.generate(prompts, sampling_params)
# Print the outputs.
for output in outputs:
prompt = output.prompt
generated_text = output.outputs[0].text
print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
15 changes: 11 additions & 4 deletions tests/entrypoints/openai/test_accuracy.py
@@ -21,7 +21,7 @@
FILTER = "exact_match,strict-match"
RTOL = 0.03
EXPECTED_VALUE = 0.58
DEFAULT_ARGS = ["--max-model-len", "2048", "--disable-log-requests"]
DEFAULT_ARGS = ["--max-model-len", "4096", "--disable-log-requests"]
MORE_ARGS_LIST = [
[], # Default
["--enable-chunked-prefill"], # Chunked
@@ -67,14 +67,21 @@ def run_test(more_args):
), f"Expected: {EXPECTED_VALUE} | Measured: {measured_value}"


@pytest.mark.skipif(not current_platform.is_cuda(),
reason="V1 currently only supported on CUDA")
@pytest.mark.skipif(not current_platform.is_cuda()
and not current_platform.is_tpu(),
reason="V1 currently only supported on CUDA and TPU")
def test_lm_eval_accuracy_v1_engine(monkeypatch):
"""Run with the V1 Engine."""

with monkeypatch.context() as m:
m.setenv("VLLM_USE_V1", "1")
run_test([])
more_args = []

# Limit compilation time for V1
if current_platform.is_tpu():
more_args = ["--max-num-seqs", "64"]

run_test(more_args)


@pytest.mark.parametrize("more_args", MORE_ARGS_LIST)
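Read in isolation, the platform-dependent argument selection above amounts to the sketch below. The helper name is hypothetical and only flags visible in this diff are used; the TPU branch caps concurrency to keep V1 compilation time bounded.

from vllm.platforms import current_platform

def build_v1_more_args() -> list:
    # Hypothetical helper restating the branch added to test_lm_eval_accuracy_v1_engine.
    more_args = []
    if current_platform.is_tpu():
        # Limit compilation time for V1 by capping concurrent sequences.
        more_args = ["--max-num-seqs", "64"]
    return more_args
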
2 changes: 1 addition & 1 deletion vllm/platforms/cuda.py
@@ -135,7 +135,7 @@ def check_and_update_config(cls, vllm_config: VllmConfig) -> None:
else:
if envs.VLLM_USE_V1:
parallel_config.worker_cls = \
"vllm.v1.worker.gpu_worker.Worker"
"vllm.v1.worker.gpu_worker.GPUWorker"
else:
parallel_config.worker_cls = "vllm.worker.worker.Worker"

1 change: 1 addition & 0 deletions vllm/platforms/interface.py
@@ -37,6 +37,7 @@ class _Backend(enum.Enum):
TRITON_MLA = enum.auto()
HPU_ATTN = enum.auto()
PALLAS = enum.auto()
PALLAS_VLLM_V1 = enum.auto()
IPEX = enum.auto()
BLOCK_SPARSE_FLASH_ATTN = enum.auto()
NO_ATTENTION = enum.auto()
54 changes: 38 additions & 16 deletions vllm/platforms/tpu.py
@@ -4,6 +4,7 @@

import torch

import vllm.envs as envs
from vllm.logger import init_logger

from .interface import Platform, PlatformEnum, _Backend
@@ -33,22 +34,28 @@ def get_attn_backend_cls(cls, selected_backend: _Backend, head_size: int,
dtype: torch.dtype, kv_cache_dtype: Optional[str],
block_size: int, use_v1: bool,
use_mla: bool) -> str:
if selected_backend != _Backend.PALLAS:
if (selected_backend != _Backend.PALLAS
and selected_backend != _Backend.PALLAS_VLLM_V1):
logger.info("Cannot use %s backend on TPU.", selected_backend)
logger.info("Using Pallas backend.")
return "vllm.attention.backends.pallas.PallasAttentionBackend"

if use_v1:
logger.info("Using Pallas V1 backend.")
return "vllm.v1.attention.backends.pallas.PallasAttentionBackend"
else:
logger.info("Using Pallas backend.")
return "vllm.attention.backends.pallas.PallasAttentionBackend"

@classmethod
def get_device_name(cls, device_id: int = 0) -> str:
raise NotImplementedError
return "tpu"

@classmethod
def get_device_total_memory(cls, device_id: int = 0) -> int:
raise NotImplementedError

@classmethod
def is_async_output_supported(cls, enforce_eager: Optional[bool]) -> bool:
return True
return not envs.VLLM_USE_V1

@classmethod
def inference_mode(cls):
@@ -63,22 +70,18 @@ def check_and_update_config(cls, vllm_config: VllmConfig) -> None:
cache_config.block_size = 16

compilation_config = vllm_config.compilation_config
if compilation_config.level == CompilationLevel.NO_COMPILATION:
# TPU does not support NO_COMPILATION

# TPU only supports DYNAMO_ONCE compilation level
if compilation_config.level != CompilationLevel.DYNAMO_ONCE:
logger.info("[TPU] Forcing DYNAMO_ONCE compilation level")
compilation_config.level = CompilationLevel.DYNAMO_ONCE
assert compilation_config.level < CompilationLevel.PIECEWISE,\
"TPU does not support Inductor."

if compilation_config.backend == "":
compilation_config.backend = "openxla"

assert vllm_config.speculative_config is None, \
"TPU does not support speculative decoding"

assert not vllm_config.scheduler_config.chunked_prefill_enabled, (
"Chunked prefill is not yet supported for TPU backend")
assert not vllm_config.speculative_config, (
"Speculative decoding is not yet supported for TPU backend")
if vllm_config.model_config.dtype in (torch.float16, torch.float32):
logger.warning(
"The TPU backend currently does not support %s. "
@@ -88,8 +91,27 @@ def check_and_update_config(cls, vllm_config: VllmConfig) -> None:
parallel_config = vllm_config.parallel_config
scheduler_config = vllm_config.scheduler_config
if parallel_config.worker_cls == "auto":
if scheduler_config.is_multi_step:
if envs.VLLM_USE_V1:
parallel_config.worker_cls = \
"vllm.worker.multi_step_tpu_worker.MultiStepTPUWorker"
"vllm.v1.worker.tpu_worker.TPUWorker"
else:
parallel_config.worker_cls = "vllm.worker.tpu_worker.TPUWorker"
if scheduler_config.is_multi_step:
parallel_config.worker_cls = \
"vllm.worker.multi_step_tpu_worker.MultiStepTPUWorker"
else:
parallel_config.worker_cls = \
"vllm.worker.tpu_worker.TPUWorker"

# Adjust scheduler config for V1
# TODO: Add support for these
if envs.VLLM_USE_V1 and vllm_config.cache_config.enable_prefix_caching:
logger.warning("[V1][TPU] Disable prefix caching")
vllm_config.cache_config.enable_prefix_caching = False

assert not vllm_config.speculative_config, (
"Speculative decoding is not yet supported for TPU backend")

@classmethod
def is_pin_memory_available(cls):
logger.warning("Pin memory is not supported on TPU.")
return False
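Taken together, the worker routing this file now performs condenses to the sketch below. The standalone helper is hypothetical; it only restates the branches visible in the diff, with V1 selecting the new vllm.v1 TPUWorker and V0 keeping the multi-step/single-step split.

import vllm.envs as envs

def select_tpu_worker_cls(is_multi_step: bool) -> str:
    # Hypothetical condensation of the worker_cls selection in check_and_update_config.
    if envs.VLLM_USE_V1:
        return "vllm.v1.worker.tpu_worker.TPUWorker"
    if is_multi_step:
        return "vllm.worker.multi_step_tpu_worker.MultiStepTPUWorker"
    return "vllm.worker.tpu_worker.TPUWorker"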