forked from vllm-project/vllm
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathtest_engine_core_client.py
207 lines (160 loc) · 6.57 KB
/
test_engine_core_client.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
import asyncio
import time
import uuid
from typing import Dict, List
import pytest
from transformers import AutoTokenizer
from vllm import SamplingParams
from vllm.engine.arg_utils import EngineArgs
from vllm.platforms import current_platform
from vllm.usage.usage_lib import UsageContext
from vllm.v1.engine import EngineCoreRequest
from vllm.v1.engine.async_llm import AsyncLLM
from vllm.v1.engine.core_client import EngineCoreClient
# These tests drive the V1 engine, which is only exercised on CUDA here;
# on any other platform the whole module is skipped at collection time.
if not current_platform.is_cuda():
    pytest.skip(
        allow_module_level=True,
        reason="V1 currently only supported on CUDA.",
    )

# Model and prompt shared by every request built in this module.
MODEL_NAME = "meta-llama/Llama-3.2-1B-Instruct"
PROMPT = "Hello my name is Robert and I love quantization kernels"

# The prompt is tokenized once at import time and reused by make_request().
TOKENIZER = AutoTokenizer.from_pretrained(MODEL_NAME)
PROMPT_TOKENS = TOKENIZER(PROMPT).input_ids
def make_request(params: SamplingParams) -> EngineCoreRequest:
    """Build an EngineCoreRequest for the module-wide test prompt.

    Each call gets a freshly generated UUID as its request id and the
    current wall-clock time as its arrival time. The multimodal fields,
    EOS token id and LoRA request are all left as ``None``.

    Args:
        params: Sampling parameters to attach to the request.

    Returns:
        A new request carrying ``PROMPT`` / ``PROMPT_TOKENS``.
    """
    # Evaluate the id before the timestamp, matching the original order.
    fresh_id = str(uuid.uuid4())
    arrived_at = time.time()

    return EngineCoreRequest(
        request_id=fresh_id,
        prompt=PROMPT,
        prompt_token_ids=PROMPT_TOKENS,
        mm_data=None,
        mm_placeholders=None,
        mm_processor_kwargs=None,
        sampling_params=params,
        eos_token_id=None,
        arrival_time=arrived_at,
        lora_request=None,
    )
def loop_until_done(client: EngineCoreClient, outputs: Dict):
    """Poll ``client.get_output()`` and bucket results by request id.

    Every output in a batch is appended to ``outputs[out.request_id]``, so
    ``outputs`` must already contain a list for each id that can appear.
    Polling stops as soon as a batch is empty, or once a batch contains
    only outputs whose ``finished`` flag is set.

    Args:
        client: Object exposing a blocking ``get_output()`` method.
        outputs: Mapping of request id to list; mutated in place.
    """
    while len(batch := client.get_output()) != 0:
        unfinished = 0
        for item in batch:
            outputs[item.request_id].append(item)
            if not item.finished:
                unfinished += 1

        # Nothing in this batch is still running: we are done.
        if unfinished == 0:
            return
async def loop_until_done_async(client: EngineCoreClient, outputs: Dict):
    """Await ``client.get_output_async()`` and bucket results by request id.

    Every output in a batch is appended to ``outputs[out.request_id]``, so
    ``outputs`` must already contain a list for each id that can appear.
    Polling stops as soon as a batch is empty, or once a batch contains
    only outputs whose ``finished`` flag is set.

    Args:
        client: Object exposing an awaitable ``get_output_async()`` method.
        outputs: Mapping of request id to list; mutated in place.
    """
    while len(batch := await client.get_output_async()) != 0:
        unfinished = 0
        for item in batch:
            outputs[item.request_id].append(item)
            if not item.finished:
                unfinished += 1

        # Nothing in this batch is still running: we are done.
        if unfinished == 0:
            return
@pytest.mark.parametrize("multiprocessing_mode", [True, False])
def test_engine_core_client(monkeypatch, multiprocessing_mode: bool):
    """Drive a synchronous EngineCoreClient through add and abort cycles.

    Three phases run against one client:

    1. Ten requests are added and drained; each must yield exactly
       ``MAX_TOKENS`` outputs.
    2. The same requests are re-added and every even-indexed one is aborted
       right after submission; aborted requests must yield fewer than
       ``MAX_TOKENS`` outputs, the others exactly ``MAX_TOKENS``.
    3. A request is added, given ten seconds, and then aborted. Nothing is
       asserted here; presumably this checks that aborting an already
       finished request does not raise.

    The client is shut down in a ``finally`` block so that a failing
    assertion does not leave it running for the rest of the session.

    Args:
        monkeypatch: pytest fixture, used to set ``VLLM_USE_V1=1``.
        multiprocessing_mode: Forwarded as ``multiprocess_mode`` to
            ``EngineCoreClient.make_client``.
    """
    with monkeypatch.context() as m:
        m.setenv("VLLM_USE_V1", "1")

        engine_args = EngineArgs(model=MODEL_NAME, compilation_config=3)
        vllm_config = engine_args.create_engine_config(
            UsageContext.UNKNOWN_CONTEXT)
        executor_class = AsyncLLM._get_executor_cls(vllm_config)
        client = EngineCoreClient.make_client(
            vllm_config,
            executor_class,
            UsageContext.UNKNOWN_CONTEXT,
            multiprocess_mode=multiprocessing_mode,
            asyncio_mode=False,
        )

        try:
            MAX_TOKENS = 20
            params = SamplingParams(max_tokens=MAX_TOKENS)

            # --- Normal Request Cycle. ---
            requests = [make_request(params) for _ in range(10)]
            request_ids = [req.request_id for req in requests]

            # Add requests to the engine.
            for request in requests:
                client.add_request(request)
                time.sleep(0.01)

            outputs: Dict[str, List] = {req_id: [] for req_id in request_ids}
            loop_until_done(client, outputs)

            for req_id in request_ids:
                assert len(outputs[req_id]) == MAX_TOKENS, (
                    f"{outputs[req_id]=}, {MAX_TOKENS=}")

            # --- Abort Request Cycle. ---
            # Note: this code pathway will only work for multiprocessing
            # since we have to call get_output() explicitly

            # Add requests to the engine, aborting every even-indexed one
            # immediately after it has been submitted.
            for idx, request in enumerate(requests):
                client.add_request(request)
                time.sleep(0.01)
                if idx % 2 == 0:
                    client.abort_requests([request.request_id])

            outputs = {req_id: [] for req_id in request_ids}
            loop_until_done(client, outputs)

            for idx, req_id in enumerate(request_ids):
                if idx % 2 == 0:
                    assert len(outputs[req_id]) < MAX_TOKENS, (
                        f"{len(outputs[req_id])=}, {MAX_TOKENS=}")
                else:
                    assert len(outputs[req_id]) == MAX_TOKENS, (
                        f"{len(outputs[req_id])=}, {MAX_TOKENS=}")

            # --- Abort after request is finished. ---
            # Note: this code pathway will only work for multiprocessing
            # since we have to call get_output() explicitly
            request = requests[0]
            client.add_request(request)
            # Leave the request ten seconds before aborting it.
            time.sleep(10.)

            client.abort_requests([request.request_id])
        finally:
            # Shutdown the client, even if a check above failed.
            client.shutdown()
@pytest.mark.asyncio
async def test_engine_core_client_asyncio(monkeypatch):
    """Drive an asyncio EngineCoreClient through add and abort cycles.

    Two phases run against one client created with
    ``multiprocess_mode=True`` and ``asyncio_mode=True``:

    1. Ten requests are added and drained; each must yield exactly
       ``MAX_TOKENS`` outputs.
    2. The same requests are re-added and every even-indexed one is aborted
       right after submission; aborted requests must yield fewer than
       ``MAX_TOKENS`` outputs, the others exactly ``MAX_TOKENS``.

    The client is shut down in a ``finally`` block so that a failing
    assertion does not leave it running for the rest of the session.

    Args:
        monkeypatch: pytest fixture, used to set ``VLLM_USE_V1=1``.
    """
    with monkeypatch.context() as m:
        m.setenv("VLLM_USE_V1", "1")

        engine_args = EngineArgs(model=MODEL_NAME)
        vllm_config = engine_args.create_engine_config(
            usage_context=UsageContext.UNKNOWN_CONTEXT)
        executor_class = AsyncLLM._get_executor_cls(vllm_config)
        client = EngineCoreClient.make_client(
            vllm_config,
            executor_class,
            UsageContext.UNKNOWN_CONTEXT,
            multiprocess_mode=True,
            asyncio_mode=True,
        )

        try:
            MAX_TOKENS = 20
            params = SamplingParams(max_tokens=MAX_TOKENS)

            # --- Normal Request Cycle. ---
            requests = [make_request(params) for _ in range(10)]
            request_ids = [req.request_id for req in requests]

            # Add requests to the engine.
            for request in requests:
                await client.add_request_async(request)
                await asyncio.sleep(0.01)

            outputs: Dict[str, List] = {req_id: [] for req_id in request_ids}
            await loop_until_done_async(client, outputs)

            for req_id in request_ids:
                assert len(outputs[req_id]) == MAX_TOKENS, (
                    f"{outputs[req_id]=}, {MAX_TOKENS=}")

            # --- Abort Request Cycle. ---

            # Add requests to the engine, aborting every even-indexed one
            # immediately after it has been submitted.
            for idx, request in enumerate(requests):
                await client.add_request_async(request)
                await asyncio.sleep(0.01)
                if idx % 2 == 0:
                    await client.abort_requests_async([request.request_id])

            outputs = {req_id: [] for req_id in request_ids}
            await loop_until_done_async(client, outputs)

            for idx, req_id in enumerate(request_ids):
                if idx % 2 == 0:
                    assert len(outputs[req_id]) < MAX_TOKENS, (
                        f"{len(outputs[req_id])=}, {MAX_TOKENS=}")
                else:
                    assert len(outputs[req_id]) == MAX_TOKENS, (
                        f"{len(outputs[req_id])=}, {MAX_TOKENS=}")
        finally:
            # Shutdown the client, even if a check above failed.
            client.shutdown()