Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[V1][Core] Generic mechanism for handling engine utility methods #13060

Merged
merged 13 commits into from
Feb 19, 2025
Merged
2 changes: 1 addition & 1 deletion tests/lora/test_add_lora.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ def download_and_prepare_lora_module():
]
for tokenizer_file in tokenizer_files:
del_path = Path(LORA_MODULE_DOWNLOAD_PATH) / tokenizer_file
del_path.unlink()
del_path.unlink(missing_ok=True)


@pytest.fixture(autouse=True)
Expand Down
57 changes: 49 additions & 8 deletions tests/v1/engine/test_engine_core_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,8 @@
import asyncio
import time
import uuid
from typing import Dict, List
from contextlib import ExitStack
from typing import Dict, List, Optional

import pytest
from transformers import AutoTokenizer
Expand All @@ -14,7 +15,9 @@
from vllm.platforms import current_platform
from vllm.usage.usage_lib import UsageContext
from vllm.v1.engine import EngineCoreRequest
from vllm.v1.engine.core_client import EngineCoreClient
from vllm.v1.engine.core import EngineCore
from vllm.v1.engine.core_client import (AsyncMPClient, EngineCoreClient,
SyncMPClient)
from vllm.v1.executor.abstract import Executor

if not current_platform.is_cuda():
Expand Down Expand Up @@ -63,7 +66,7 @@ def loop_until_done(client: EngineCoreClient, outputs: Dict):
async def loop_until_done_async(client: EngineCoreClient, outputs: Dict):

while True:
engine_core_outputs = await client.get_output_async().outputs
engine_core_outputs = (await client.get_output_async()).outputs

if len(engine_core_outputs) == 0:
break
Expand All @@ -78,14 +81,25 @@ async def loop_until_done_async(client: EngineCoreClient, outputs: Dict):
break


# Dummy utility function to monkey-patch into engine core.
def echo(self, msg: str, err_msg: Optional[str] = None) -> str:
    """Echo ``msg`` back to the caller.

    If ``err_msg`` is provided, raise ``ValueError(err_msg)`` instead —
    used by the tests to exercise the utility-call failure path.
    """
    print(f"echo util function called: {msg}, {err_msg}")
    if err_msg is None:
        return msg
    raise ValueError(err_msg)


@fork_new_process_for_each_test
@pytest.mark.parametrize("multiprocessing_mode", [True, False])
def test_engine_core_client(monkeypatch, multiprocessing_mode: bool):

with monkeypatch.context() as m:
m.setenv("VLLM_USE_V1", "1")

engine_args = EngineArgs(model=MODEL_NAME, compilation_config=3)
# Monkey-patch core engine utility function to test.
m.setattr(EngineCore, "echo", echo, raising=False)

engine_args = EngineArgs(model=MODEL_NAME, enforce_eager=True)
vllm_config = engine_args.create_engine_config(
UsageContext.UNKNOWN_CONTEXT)
executor_class = Executor.get_class(vllm_config)
Expand Down Expand Up @@ -147,15 +161,30 @@ def test_engine_core_client(monkeypatch, multiprocessing_mode: bool):

client.abort_requests([request.request_id])

if multiprocessing_mode:
"""Utility method invocation"""

@fork_new_process_for_each_test
@pytest.mark.asyncio
core_client: SyncMPClient = client

result = core_client._call_utility("echo", "testarg")
assert result == "testarg"

with pytest.raises(Exception) as e_info:
core_client._call_utility("echo", None, "help!")

assert str(e_info.value) == "Call to echo method failed: help!"


@pytest.mark.asyncio(loop_scope="function")
async def test_engine_core_client_asyncio(monkeypatch):

with monkeypatch.context() as m:
with monkeypatch.context() as m, ExitStack() as after:
m.setenv("VLLM_USE_V1", "1")

engine_args = EngineArgs(model=MODEL_NAME)
# Monkey-patch core engine utility function to test.
m.setattr(EngineCore, "echo", echo, raising=False)

engine_args = EngineArgs(model=MODEL_NAME, enforce_eager=True)
vllm_config = engine_args.create_engine_config(
usage_context=UsageContext.UNKNOWN_CONTEXT)
executor_class = Executor.get_class(vllm_config)
Expand All @@ -166,6 +195,7 @@ async def test_engine_core_client_asyncio(monkeypatch):
executor_class=executor_class,
log_stats=True,
)
after.callback(client.shutdown)

MAX_TOKENS = 20
params = SamplingParams(max_tokens=MAX_TOKENS)
Expand Down Expand Up @@ -204,3 +234,14 @@ async def test_engine_core_client_asyncio(monkeypatch):
else:
assert len(outputs[req_id]) == MAX_TOKENS, (
f"{len(outputs[req_id])=}, {MAX_TOKENS=}")
"""Utility method invocation"""

core_client: AsyncMPClient = client

result = await core_client._call_utility_async("echo", "testarg")
assert result == "testarg"

with pytest.raises(Exception) as e_info:
await core_client._call_utility_async("echo", None, "help!")

assert str(e_info.value) == "Call to echo method failed: help!"
24 changes: 18 additions & 6 deletions vllm/v1/engine/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

import enum
import time
from typing import List, Optional, Union
from typing import Any, List, Optional, Union

import msgspec

Expand Down Expand Up @@ -106,6 +106,18 @@ def finished(self) -> bool:
return self.finish_reason is not None


class UtilityOutput(
        msgspec.Struct,
        array_like=True,  # type: ignore[call-arg]
        gc=False):  # type: ignore[call-arg]
    """Result envelope for a UTILITY request handled by the engine core.

    Sent back to the client embedded in ``EngineCoreOutputs`` so the
    caller can correlate the response with its pending utility call.
    """

    # Identifier copied from the originating utility request, used by the
    # client to match this output to the corresponding pending call.
    call_id: int

    # Non-None implies the call failed, result should be None.
    failure_message: Optional[str] = None
    # Return value of the invoked engine-core method on success.
    result: Any = None


class EngineCoreOutputs(
msgspec.Struct,
array_like=True, # type: ignore[call-arg]
Expand All @@ -116,10 +128,12 @@ class EngineCoreOutputs(
# e.g. columnwise layout

# [num_reqs]
outputs: List[EngineCoreOutput]
scheduler_stats: Optional[SchedulerStats]
outputs: List[EngineCoreOutput] = []
scheduler_stats: Optional[SchedulerStats] = None
timestamp: float = 0.0

utility_output: Optional[UtilityOutput] = None

    def __post_init__(self):
        # The default timestamp of 0.0 means "unset": stamp the message
        # with the current monotonic clock at construction time.
        if self.timestamp == 0.0:
            self.timestamp = time.monotonic()
Expand All @@ -132,6 +146,4 @@ class EngineCoreRequestType(enum.Enum):
"""
ADD = b'\x00'
ABORT = b'\x01'
PROFILE = b'\x02'
RESET_PREFIX_CACHE = b'\x03'
ADD_LORA = b'\x04'
UTILITY = b'\x02'
49 changes: 33 additions & 16 deletions vllm/v1/engine/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,11 @@
import threading
import time
from concurrent.futures import Future
from inspect import isclass, signature
from multiprocessing.connection import Connection
from typing import Any, List, Optional, Tuple, Type

import msgspec
import psutil
import zmq
import zmq.asyncio
Expand All @@ -21,7 +23,7 @@
from vllm.v1.core.kv_cache_utils import get_kv_cache_configs
from vllm.v1.core.scheduler import Scheduler, SchedulerOutput
from vllm.v1.engine import (EngineCoreOutputs, EngineCoreRequest,
EngineCoreRequestType)
EngineCoreRequestType, UtilityOutput)
from vllm.v1.engine.mm_input_cache import MMInputCacheServer
from vllm.v1.executor.abstract import Executor
from vllm.v1.outputs import ModelRunnerOutput
Expand Down Expand Up @@ -330,19 +332,39 @@ def _handle_client_request(self, request_type: EngineCoreRequestType,
self.add_request(request)
elif request_type == EngineCoreRequestType.ABORT:
self.abort_requests(request)
elif request_type == EngineCoreRequestType.RESET_PREFIX_CACHE:
self.reset_prefix_cache()
elif request_type == EngineCoreRequestType.PROFILE:
self.model_executor.profile(request)
elif request_type == EngineCoreRequestType.ADD_LORA:
self.model_executor.add_lora(request)
elif request_type == EngineCoreRequestType.UTILITY:
call_id, method_name, args = request
output = UtilityOutput(call_id)
try:
method = getattr(self, method_name)
output.result = method(
*self._convert_msgspec_args(method, args))
except BaseException as e:
logger.exception("Invocation of %s method failed", method_name)
output.failure_message = (f"Call to {method_name} method"
f" failed: {str(e)}")
self.output_queue.put_nowait(
EngineCoreOutputs(utility_output=output))

@staticmethod
def _convert_msgspec_args(method, args):
    """If a provided arg type doesn't match corresponding target method
    arg type, try converting to msgspec object."""
    if not args:
        return args
    arg_types = signature(method).parameters.values()
    assert len(args) <= len(arg_types)

    def _coerce(value, param):
        # Convert only when the parameter annotation is a msgspec Struct
        # subclass and the value is not already an instance of it.
        ann = param.annotation
        if (isclass(ann) and issubclass(ann, msgspec.Struct)
                and not isinstance(value, ann)):
            return msgspec.convert(value, type=ann)
        return value

    return tuple(_coerce(v, p) for v, p in zip(args, arg_types))

def process_input_socket(self, input_path: str):
"""Input socket IO thread."""

# Msgpack serialization decoding.
add_request_decoder = MsgpackDecoder(EngineCoreRequest)
add_lora_decoder = MsgpackDecoder(LoRARequest)
generic_decoder = MsgpackDecoder()

with zmq_socket_ctx(input_path, zmq.constants.PULL) as socket:
Expand All @@ -352,14 +374,9 @@ def process_input_socket(self, input_path: str):
request_type = EngineCoreRequestType(bytes(type_frame.buffer))

# Deserialize the request data.
decoder = None
if request_type == EngineCoreRequestType.ADD:
decoder = add_request_decoder
elif request_type == EngineCoreRequestType.ADD_LORA:
decoder = add_lora_decoder
else:
decoder = generic_decoder

decoder = add_request_decoder if (
request_type
== EngineCoreRequestType.ADD) else generic_decoder
request = decoder.decode(data_frame.buffer)

# Push to input queue for core busy loop.
Expand Down
Loading