Commit 200a13c

youkaichao authored and tlrmchlsmth committed
[9/N] torch.compile LLM usage (vllm-project#10552)
Signed-off-by: youkaichao <youkaichao@gmail.com>
Signed-off-by: Tyler Michael Smith <tyler@neuralmagic.com>
1 parent 3e4efea commit 200a13c

File tree

2 files changed, +16 -4 lines changed


tests/tpu/test_compilation.py

+2 -3

@@ -4,7 +4,7 @@
 
 import depyf
 
-from vllm.config import CompilationConfig, CompilationLevel
+from vllm.config import CompilationLevel
 
 temp_dir = tempfile.mkdtemp()
 with depyf.prepare_debug(temp_dir):
@@ -34,8 +34,7 @@
     # all the control
     llm = LLM(model="google/gemma-2b",
               enforce_eager=True,
-              compilation_config=CompilationConfig(
-                  level=CompilationLevel.DYNAMO_AS_IS))
+              compilation_config={"level": CompilationLevel.DYNAMO_AS_IS})
     outputs = llm.generate(prompts, sampling_params)
     for output, answer in zip(outputs, answers):
         prompt = output.prompt
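
For existing callers, the old and new forms are equivalent; the dict is converted internally into a CompilationConfig. A minimal migration sketch (model name taken from the test above):

    from vllm import LLM
    from vllm.config import CompilationConfig, CompilationLevel

    # Before this commit: construct the config object directly.
    llm = LLM(model="google/gemma-2b",
              enforce_eager=True,
              compilation_config=CompilationConfig(
                  level=CompilationLevel.DYNAMO_AS_IS))

    # After this commit: pass a plain dict (or a bare int level);
    # LLM.__init__ converts it via CompilationConfig.from_cli.
    llm = LLM(model="google/gemma-2b",
              enforce_eager=True,
              compilation_config={"level": CompilationLevel.DYNAMO_AS_IS})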

vllm/entrypoints/llm.py

+14 -1
@@ -1,4 +1,5 @@
 import itertools
+import json
 import warnings
 from contextlib import contextmanager
 from typing import (Any, ClassVar, Dict, List, Optional, Sequence, Tuple, Type,
@@ -9,6 +10,7 @@
 from vllm import envs
 from vllm.beam_search import (BeamSearchInstance, BeamSearchOutput,
                               BeamSearchSequence, get_beam_search_score)
+from vllm.config import CompilationConfig
 from vllm.engine.arg_utils import (EngineArgs, HfOverrides, PoolerConfig,
                                    TaskOption)
 from vllm.engine.llm_engine import LLMEngine
@@ -107,13 +109,16 @@ class LLM:
         hf_overrides: If a dictionary, contains arguments to be forwarded to the
             HuggingFace config. If a callable, it is called to update the
             HuggingFace config.
+        compilation_config: Either an integer or a dictionary. If it is an integer,
+            it is used as the level of compilation optimization. If it is a dictionary,
+            it can specify the full compilation configuration.
         **kwargs: Arguments for :class:`~vllm.EngineArgs`. (See
             :ref:`engine_args`)
 
     Note:
         This class is intended to be used for offline inference. For online
         serving, use the :class:`~vllm.AsyncLLMEngine` class instead.
-    """
+    """ # noqa
 
     DEPRECATE_LEGACY: ClassVar[bool] = False
     """A flag to toggle whether to deprecate the legacy generate/encode API."""
@@ -166,6 +171,7 @@ def __init__(
         # After positional args are removed, move this right below `model`
         task: TaskOption = "auto",
         override_pooler_config: Optional[PoolerConfig] = None,
+        compilation_config: Optional[Union[int, Dict[str, Any]]] = None,
         **kwargs,
     ) -> None:
         '''
@@ -178,6 +184,12 @@
         if "disable_log_stats" not in kwargs:
             kwargs["disable_log_stats"] = True
 
+        if compilation_config is not None:
+            compilation_config_instance = CompilationConfig.from_cli(
+                json.dumps(compilation_config))
+        else:
+            compilation_config_instance = None
+
         engine_args = EngineArgs(
             model=model,
             task=task,
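
The int-or-dict value is funneled through the JSON path shown above: it is serialized with json.dumps and handed to CompilationConfig.from_cli (the CLI-facing parser, going by its name). A standalone sketch of that round trip, assuming from_cli accepts both a bare level string and a JSON object, as the int-or-dict docstring implies:

    import json

    from vllm.config import CompilationConfig

    # json.dumps(3) yields "3"; json.dumps({"level": 3}) yields a JSON
    # object. Either string is parsed back into a CompilationConfig.
    config = CompilationConfig.from_cli(json.dumps({"level": 3}))
    assert config.level == 3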
@@ -202,6 +214,7 @@
             hf_overrides=hf_overrides,
             mm_processor_kwargs=mm_processor_kwargs,
             override_pooler_config=override_pooler_config,
+            compilation_config=compilation_config_instance,
             **kwargs,
         )
         # Logic to switch between engines is done at runtime instead of import

0 commit comments
