
Commit 1f01fee

njhill authored and kerthcet committed
[V1][Core] Generic mechanism for handling engine utility (vllm-project#13060)
Signed-off-by: Nick Hill <nhill@redhat.com>
1 parent 03d1596 commit 1f01fee
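
In short: instead of adding a dedicated request type (and client plumbing) for every new engine-core operation, this commit routes arbitrary method calls over a single UTILITY request carrying (call_id, method_name, args); the core dispatches by name and reports the result or failure back in a UtilityOutput. A minimal, self-contained sketch of that dispatch pattern (toy class and names, not vLLM code):

# Toy illustration of the request/response shape this commit introduces;
# the names below are illustrative only.
from typing import Any, Optional


class ToyCore:

    def echo(self, msg: str) -> str:
        return msg

    def handle_utility(self, call_id: int, method_name: str,
                       args: tuple) -> dict:
        result: Any = None
        failure: Optional[str] = None
        try:
            # Dispatch by name: the same getattr pattern used in core.py below.
            result = getattr(self, method_name)(*args)
        except BaseException as e:
            failure = f"Call to {method_name} method failed: {e}"
        return {"call_id": call_id, "failure_message": failure,
                "result": result}


core = ToyCore()
print(core.handle_utility(1, "echo", ("hi",)))  # result: "hi"
print(core.handle_utility(2, "missing", ()))    # failure_message set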

File tree

5 files changed: +197 -56 lines changed


tests/lora/test_add_lora.py

+1 -1
@@ -41,7 +41,7 @@ def download_and_prepare_lora_module():
     ]
     for tokenizer_file in tokenizer_files:
         del_path = Path(LORA_MODULE_DOWNLOAD_PATH) / tokenizer_file
-        del_path.unlink()
+        del_path.unlink(missing_ok=True)
 
 
 @pytest.fixture(autouse=True)

tests/v1/engine/test_engine_core_client.py

+49 -8
@@ -3,7 +3,8 @@
 import asyncio
 import time
 import uuid
-from typing import Dict, List
+from contextlib import ExitStack
+from typing import Dict, List, Optional
 
 import pytest
 from transformers import AutoTokenizer
@@ -14,7 +15,9 @@
 from vllm.platforms import current_platform
 from vllm.usage.usage_lib import UsageContext
 from vllm.v1.engine import EngineCoreRequest
-from vllm.v1.engine.core_client import EngineCoreClient
+from vllm.v1.engine.core import EngineCore
+from vllm.v1.engine.core_client import (AsyncMPClient, EngineCoreClient,
+                                        SyncMPClient)
 from vllm.v1.executor.abstract import Executor
 
 if not current_platform.is_cuda():
@@ -63,7 +66,7 @@ def loop_until_done(client: EngineCoreClient, outputs: Dict):
 async def loop_until_done_async(client: EngineCoreClient, outputs: Dict):
 
     while True:
-        engine_core_outputs = await client.get_output_async().outputs
+        engine_core_outputs = (await client.get_output_async()).outputs
 
         if len(engine_core_outputs) == 0:
             break
@@ -78,14 +81,25 @@ async def loop_until_done_async(client: EngineCoreClient, outputs: Dict):
             break
 
 
+# Dummy utility function to monkey-patch into engine core.
+def echo(self, msg: str, err_msg: Optional[str] = None) -> str:
+    print(f"echo util function called: {msg}, {err_msg}")
+    if err_msg is not None:
+        raise ValueError(err_msg)
+    return msg
+
+
 @fork_new_process_for_each_test
 @pytest.mark.parametrize("multiprocessing_mode", [True, False])
 def test_engine_core_client(monkeypatch, multiprocessing_mode: bool):
 
     with monkeypatch.context() as m:
         m.setenv("VLLM_USE_V1", "1")
 
-        engine_args = EngineArgs(model=MODEL_NAME, compilation_config=3)
+        # Monkey-patch core engine utility function to test.
+        m.setattr(EngineCore, "echo", echo, raising=False)
+
+        engine_args = EngineArgs(model=MODEL_NAME, enforce_eager=True)
         vllm_config = engine_args.create_engine_config(
             UsageContext.UNKNOWN_CONTEXT)
         executor_class = Executor.get_class(vllm_config)
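
Worth noting in the hunk above: monkeypatch.setattr(..., raising=False) attaches an attribute even when the target does not already have one, which is what lets the test graft the brand-new echo method onto EngineCore. A minimal standalone illustration (toy class, independent of vLLM):

# pytest's raising=False skips the "attribute must already exist" check.
class Widget:
    pass


def test_grafted_method(monkeypatch):
    monkeypatch.setattr(Widget, "greet", lambda self: "hi", raising=False)
    assert Widget().greet() == "hi"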
@@ -147,15 +161,30 @@ def test_engine_core_client(monkeypatch, multiprocessing_mode: bool):
 
         client.abort_requests([request.request_id])
 
+        if multiprocessing_mode:
+            """Utility method invocation"""
 
-@fork_new_process_for_each_test
-@pytest.mark.asyncio
+            core_client: SyncMPClient = client
+
+            result = core_client._call_utility("echo", "testarg")
+            assert result == "testarg"
+
+            with pytest.raises(Exception) as e_info:
+                core_client._call_utility("echo", None, "help!")
+
+            assert str(e_info.value) == "Call to echo method failed: help!"
+
+
+@pytest.mark.asyncio(loop_scope="function")
 async def test_engine_core_client_asyncio(monkeypatch):
 
-    with monkeypatch.context() as m:
+    with monkeypatch.context() as m, ExitStack() as after:
         m.setenv("VLLM_USE_V1", "1")
 
-        engine_args = EngineArgs(model=MODEL_NAME)
+        # Monkey-patch core engine utility function to test.
+        m.setattr(EngineCore, "echo", echo, raising=False)
+
+        engine_args = EngineArgs(model=MODEL_NAME, enforce_eager=True)
         vllm_config = engine_args.create_engine_config(
             usage_context=UsageContext.UNKNOWN_CONTEXT)
         executor_class = Executor.get_class(vllm_config)
@@ -166,6 +195,7 @@ async def test_engine_core_client_asyncio(monkeypatch):
             executor_class=executor_class,
             log_stats=True,
         )
+        after.callback(client.shutdown)
 
         MAX_TOKENS = 20
         params = SamplingParams(max_tokens=MAX_TOKENS)
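
The ExitStack registered above is just a tidy way to guarantee client.shutdown() runs when the with block unwinds, even if an assertion fails midway. For reference, the stdlib behavior in isolation:

from contextlib import ExitStack

with ExitStack() as after:
    after.callback(print, "runs second, even if the block raises")
    print("runs first")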
@@ -204,3 +234,14 @@ async def test_engine_core_client_asyncio(monkeypatch):
         else:
             assert len(outputs[req_id]) == MAX_TOKENS, (
                 f"{len(outputs[req_id])=}, {MAX_TOKENS=}")
+
+        """Utility method invocation"""
+
+        core_client: AsyncMPClient = client
+
+        result = await core_client._call_utility_async("echo", "testarg")
+        assert result == "testarg"
+
+        with pytest.raises(Exception) as e_info:
+            await core_client._call_utility_async("echo", None, "help!")
+
+        assert str(e_info.value) == "Call to echo method failed: help!"

vllm/v1/engine/__init__.py

+18 -6
@@ -2,7 +2,7 @@
 
 import enum
 import time
-from typing import List, Optional, Union
+from typing import Any, List, Optional, Union
 
 import msgspec
 
@@ -106,6 +106,18 @@ def finished(self) -> bool:
         return self.finish_reason is not None
 
 
+class UtilityOutput(
+        msgspec.Struct,
+        array_like=True,  # type: ignore[call-arg]
+        gc=False):  # type: ignore[call-arg]
+
+    call_id: int
+
+    # Non-None implies the call failed, result should be None.
+    failure_message: Optional[str] = None
+    result: Any = None
+
+
 class EngineCoreOutputs(
         msgspec.Struct,
         array_like=True,  # type: ignore[call-arg]
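
Because UtilityOutput is a msgspec Struct declared with array_like=True, it serializes as a compact positional msgpack array rather than a map. An illustrative round-trip using msgspec's public msgpack API (a sketch; the import assumes this commit is applied):

import msgspec

from vllm.v1.engine import UtilityOutput

out = UtilityOutput(call_id=7, result="testarg")
wire = msgspec.msgpack.encode(out)  # encodes positionally: [7, None, "testarg"]
back = msgspec.msgpack.decode(wire, type=UtilityOutput)
assert back.call_id == 7 and back.result == "testarg"
assert back.failure_message is None  # None means the call succeeded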
@@ -116,10 +128,12 @@ class EngineCoreOutputs(
     # e.g. columnwise layout
 
     # [num_reqs]
-    outputs: List[EngineCoreOutput]
-    scheduler_stats: Optional[SchedulerStats]
+    outputs: List[EngineCoreOutput] = []
+    scheduler_stats: Optional[SchedulerStats] = None
     timestamp: float = 0.0
 
+    utility_output: Optional[UtilityOutput] = None
+
     def __post_init__(self):
         if self.timestamp == 0.0:
             self.timestamp = time.monotonic()
@@ -132,6 +146,4 @@ class EngineCoreRequestType(enum.Enum):
     """
     ADD = b'\x00'
     ABORT = b'\x01'
-    PROFILE = b'\x02'
-    RESET_PREFIX_CACHE = b'\x03'
-    ADD_LORA = b'\x04'
+    UTILITY = b'\x02'
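
With PROFILE, RESET_PREFIX_CACHE, and ADD_LORA removed, those operations are expected to travel over the single UTILITY type instead, via the client helpers the tests above exercise. Hedged fragments modeled on those tests (the real wiring lives in the commit's fifth file, presumably vllm/v1/engine/core_client.py, which is not shown in this excerpt):

# Hypothetical fragments, not verbatim API documentation.
result = sync_client._call_utility("echo", "testarg")  # SyncMPClient


async def invoke(async_client):  # AsyncMPClient
    return await async_client._call_utility_async("echo", "testarg")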

vllm/v1/engine/core.py

+33 -16
@@ -5,9 +5,11 @@
 import threading
 import time
 from concurrent.futures import Future
+from inspect import isclass, signature
 from multiprocessing.connection import Connection
 from typing import Any, List, Optional, Tuple, Type
 
+import msgspec
 import psutil
 import zmq
 import zmq.asyncio
@@ -21,7 +23,7 @@
 from vllm.v1.core.kv_cache_utils import get_kv_cache_configs
 from vllm.v1.core.scheduler import Scheduler, SchedulerOutput
 from vllm.v1.engine import (EngineCoreOutputs, EngineCoreRequest,
-                            EngineCoreRequestType)
+                            EngineCoreRequestType, UtilityOutput)
 from vllm.v1.engine.mm_input_cache import MMInputCacheServer
 from vllm.v1.executor.abstract import Executor
 from vllm.v1.outputs import ModelRunnerOutput
@@ -330,19 +332,39 @@ def _handle_client_request(self, request_type: EngineCoreRequestType,
             self.add_request(request)
         elif request_type == EngineCoreRequestType.ABORT:
             self.abort_requests(request)
-        elif request_type == EngineCoreRequestType.RESET_PREFIX_CACHE:
-            self.reset_prefix_cache()
-        elif request_type == EngineCoreRequestType.PROFILE:
-            self.model_executor.profile(request)
-        elif request_type == EngineCoreRequestType.ADD_LORA:
-            self.model_executor.add_lora(request)
+        elif request_type == EngineCoreRequestType.UTILITY:
+            call_id, method_name, args = request
+            output = UtilityOutput(call_id)
+            try:
+                method = getattr(self, method_name)
+                output.result = method(
+                    *self._convert_msgspec_args(method, args))
+            except BaseException as e:
+                logger.exception("Invocation of %s method failed", method_name)
+                output.failure_message = (f"Call to {method_name} method"
+                                          f" failed: {str(e)}")
+            self.output_queue.put_nowait(
+                EngineCoreOutputs(utility_output=output))
+
+    @staticmethod
+    def _convert_msgspec_args(method, args):
+        """If a provided arg type doesn't match corresponding target method
+        arg type, try converting to msgspec object."""
+        if not args:
+            return args
+        arg_types = signature(method).parameters.values()
+        assert len(args) <= len(arg_types)
+        return tuple(
+            msgspec.convert(v, type=p.annotation) if isclass(p.annotation)
+            and issubclass(p.annotation, msgspec.Struct)
+            and not isinstance(v, p.annotation) else v
+            for v, p in zip(args, arg_types))
 
     def process_input_socket(self, input_path: str):
         """Input socket IO thread."""
 
         # Msgpack serialization decoding.
         add_request_decoder = MsgpackDecoder(EngineCoreRequest)
-        add_lora_decoder = MsgpackDecoder(LoRARequest)
         generic_decoder = MsgpackDecoder()
 
         with zmq_socket_ctx(input_path, zmq.constants.PULL) as socket:
with zmq_socket_ctx(input_path, zmq.constants.PULL) as socket:
@@ -352,14 +374,9 @@ def process_input_socket(self, input_path: str):
352374
request_type = EngineCoreRequestType(bytes(type_frame.buffer))
353375

354376
# Deserialize the request data.
355-
decoder = None
356-
if request_type == EngineCoreRequestType.ADD:
357-
decoder = add_request_decoder
358-
elif request_type == EngineCoreRequestType.ADD_LORA:
359-
decoder = add_lora_decoder
360-
else:
361-
decoder = generic_decoder
362-
377+
decoder = add_request_decoder if (
378+
request_type
379+
== EngineCoreRequestType.ADD) else generic_decoder
363380
request = decoder.decode(data_frame.buffer)
364381

365382
# Push to input queue for core busy loop.
