
Commit c99b137

[Feature] Add load generation config from model (vllm-project#11164)

Authored by liuyanyi and DarkLight1337, committed by Ubuntu.

Signed-off-by: liuyanyi <wolfsonliu@163.com>
Signed-off-by: Yanyi Liu <wolfsonliu@163.com>
Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
Co-authored-by: Cyrus Leung <tlleungac@connect.ust.hk>

1 parent 725105b

10 files changed: +307 −74 lines
New file: @@ -0,0 +1,30 @@
+from vllm import LLM
+
+# Sample prompts.
+prompts = [
+    "Hello, my name is",
+    "The president of the United States is",
+    "The capital of France is",
+    "The future of AI is",
+]
+
+# Create an LLM with built-in default generation config.
+# The generation config is set to None by default to keep
+# the behavior consistent with the previous version.
+# If you want to use the default generation config from the model,
+# you should set the generation_config to "auto".
+llm = LLM(model="Qwen/Qwen2.5-0.5B-Instruct", generation_config="auto")
+
+# Load the default sampling parameters from the model.
+sampling_params = llm.get_default_sampling_params()
+# Modify the sampling parameters if needed.
+sampling_params.temperature = 0.5
+
+# Generate texts from the prompts. The output is a list of RequestOutput objects
+# that contain the prompt, generated text, and other information.
+outputs = llm.generate(prompts, sampling_params)
+# Print the outputs.
+for output in outputs:
+    prompt = output.prompt
+    generated_text = output.outputs[0].text
+    print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")

tests/entrypoints/openai/test_serving_chat.py (+61)

@@ -1,6 +1,7 @@
 import asyncio
 from contextlib import suppress
 from dataclasses import dataclass
+from typing import Optional
 from unittest.mock import MagicMock

 from vllm.config import MultiModalConfig
@@ -31,6 +32,10 @@ class MockModelConfig:
     multimodal_config = MultiModalConfig()
     hf_config = MockHFConfig()
     logits_processor_pattern = None
+    diff_sampling_param: Optional[dict] = None
+
+    def get_diff_sampling_param(self):
+        return self.diff_sampling_param or {}


 @dataclass
@@ -94,3 +99,59 @@ def test_serving_chat_should_set_correct_max_tokens():
         asyncio.run(serving_chat.create_chat_completion(req))

     assert mock_engine.generate.call_args.args[1].max_tokens == 10
+
+
+def test_serving_chat_could_load_correct_generation_config():
+
+    mock_model_config = MockModelConfig()
+    mock_model_config.diff_sampling_param = {
+        "temperature": 0.5,
+        "repetition_penalty": 1.05
+    }
+
+    mock_engine = MagicMock(spec=MQLLMEngineClient)
+    mock_engine.get_tokenizer.return_value = get_tokenizer(MODEL_NAME)
+    mock_engine.errored = False
+
+    # Initialize the serving chat
+    serving_chat = OpenAIServingChat(mock_engine,
+                                     mock_model_config,
+                                     BASE_MODEL_PATHS,
+                                     response_role="assistant",
+                                     chat_template=CHAT_TEMPLATE,
+                                     chat_template_content_format="auto",
+                                     lora_modules=None,
+                                     prompt_adapters=None,
+                                     request_logger=None)
+    req = ChatCompletionRequest(
+        model=MODEL_NAME,
+        messages=[{
+            "role": "user",
+            "content": "what is 1+1?"
+        }],
+        guided_decoding_backend="outlines",
+    )
+
+    with suppress(Exception):
+        asyncio.run(serving_chat.create_chat_completion(req))
+
+    assert mock_engine.generate.call_args.args[1].temperature == 0.5
+    assert mock_engine.generate.call_args.args[1].repetition_penalty == 1.05
+
+    # Test the param when user set it
+    req.temperature = 0.1
+
+    with suppress(Exception):
+        asyncio.run(serving_chat.create_chat_completion(req))
+
+    assert mock_engine.generate.call_args.args[1].temperature == 0.1
+    assert mock_engine.generate.call_args.args[1].repetition_penalty == 1.05
+
+    # Test When temperature==0.0
+    req.temperature = 0.0
+
+    with suppress(Exception):
+        asyncio.run(serving_chat.create_chat_completion(req))
+
+    assert mock_engine.generate.call_args.args[1].temperature == 0.0
+    assert mock_engine.generate.call_args.args[1].repetition_penalty == 1.05
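The new test pins down the precedence rule: defaults returned by get_diff_sampling_param apply only to fields the request leaves unset, and an explicit request value wins even when it is a "falsy" value like temperature 0.0. A minimal standalone sketch of that resolution logic, using illustrative names that do not appear in this diff:

    def resolve_temperature(request_value, model_default, library_default=1.0):
        # An explicit request value always wins, including 0.0.
        if request_value is not None:
            return request_value
        # Otherwise fall back to the model's generation_config default,
        # then to the library-wide default.
        return model_default if model_default is not None else library_default

    assert resolve_temperature(None, 0.5) == 0.5  # model default applies
    assert resolve_temperature(0.1, 0.5) == 0.1   # user override wins
    assert resolve_temperature(0.0, 0.5) == 0.0   # 0.0 is an override, not "unset"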

vllm/config.py (+57 −2)

@@ -27,7 +27,8 @@
 from vllm.transformers_utils.config import (
     ConfigFormat, get_config, get_hf_image_processor_config,
     get_hf_text_config, get_pooling_config,
-    get_sentence_transformer_tokenizer_config, is_encoder_decoder, uses_mrope)
+    get_sentence_transformer_tokenizer_config, is_encoder_decoder,
+    try_get_generation_config, uses_mrope)
 from vllm.utils import (GiB_bytes, LayerBlockType, cuda_device_count_stateless,
                         get_cpu_memory, print_warning_once, random_uuid,
                         resolve_obj_by_qualname)
@@ -160,6 +161,7 @@ class ModelConfig:
             logits processor qualified names that can be passed with the
             `logits_processors` extra completion argument. Defaults to None,
             which allows no processors.
+        generation_config: Configuration parameter file for generation.
     """

     def compute_hash(self) -> str:
@@ -218,7 +220,8 @@ def __init__(self,
                  disable_mm_preprocessor_cache: bool = False,
                  override_neuron_config: Optional[Dict[str, Any]] = None,
                  override_pooler_config: Optional["PoolerConfig"] = None,
-                 logits_processor_pattern: Optional[str] = None) -> None:
+                 logits_processor_pattern: Optional[str] = None,
+                 generation_config: Optional[str] = None) -> None:
         self.model = model
         self.tokenizer = tokenizer
         self.tokenizer_mode = tokenizer_mode
@@ -348,6 +351,8 @@ def __init__(self,
         self.pooler_config = self._init_pooler_config(override_pooler_config)
         self.logits_processor_pattern = logits_processor_pattern

+        self.generation_config = generation_config
+
         self._verify_quantization()
         self._verify_cuda_graph()
         self._verify_bnb_config()
@@ -813,6 +818,56 @@ def get_multimodal_config(self) -> "MultiModalConfig":

         return self.multimodal_config

+    def try_get_generation_config(self) -> Dict[str, Any]:
+        if self.generation_config is None or self.generation_config == "auto":
+            config = try_get_generation_config(
+                self.model,
+                trust_remote_code=self.trust_remote_code,
+                revision=self.revision,
+            )
+        else:
+            config = try_get_generation_config(
+                self.generation_config,
+                trust_remote_code=self.trust_remote_code,
+            )
+
+        if config is None:
+            return {}
+
+        return config.to_diff_dict()
+
+    def get_diff_sampling_param(self) -> Dict[str, Any]:
+        """
+        This method returns a dictionary containing the parameters
+        that differ from the default sampling parameters, but only
+        if `generation_config` is set. If `generation_config` is not
+        set, an empty dictionary is returned.
+
+        Returns:
+            Dict[str, Any]: A dictionary with the differing sampling
+            parameters if `generation_config` is set, otherwise an
+            empty dictionary.
+        """
+        if self.generation_config is None:
+            # When generation_config is not set
+            return {}
+        config = self.try_get_generation_config()
+        available_params = [
+            "repetition_penalty",
+            "temperature",
+            "top_k",
+            "top_p",
+            "min_p",
+        ]
+        if any(p in config for p in available_params):
+            diff_sampling_param = {
+                p: config.get(p)
+                for p in available_params if config.get(p) is not None
+            }
+        else:
+            diff_sampling_param = {}
+        return diff_sampling_param
+
     @property
     def is_encoder_decoder(self) -> bool:
         """Extract the HF encoder/decoder model flag."""

vllm/engine/arg_utils.py (+14 −1)

@@ -197,6 +197,8 @@ class EngineArgs:

     kv_transfer_config: Optional[KVTransferConfig] = None

+    generation_config: Optional[str] = None
+
     def __post_init__(self):
         if not self.tokenizer:
             self.tokenizer = self.model
@@ -942,6 +944,16 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
             default="auto",
             help='The worker class to use for distributed execution.')

+        parser.add_argument(
+            "--generation-config",
+            type=nullable_str,
+            default=None,
+            help="The folder path to the generation config. "
+            "Defaults to None, will use the default generation config in vLLM. "
+            "If set to 'auto', the generation config will be automatically "
+            "loaded from model. If set to a folder path, the generation config "
+            "will be loaded from the specified folder path.")
+
         return parser

     @classmethod
@@ -985,7 +997,8 @@ def create_model_config(self) -> ModelConfig:
             disable_mm_preprocessor_cache=self.disable_mm_preprocessor_cache,
             override_neuron_config=self.override_neuron_config,
             override_pooler_config=self.override_pooler_config,
-            logits_processor_pattern=self.logits_processor_pattern)
+            logits_processor_pattern=self.logits_processor_pattern,
+            generation_config=self.generation_config)

     def create_load_config(self) -> LoadConfig:
         return LoadConfig(
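The same flag can be exercised programmatically through EngineArgs, which is roughly what passing --generation-config auto on the command line does. A hedged sketch; building the ModelConfig fetches the model's config files, so this assumes the model is reachable over the network or from the local cache:

    from vllm.engine.arg_utils import EngineArgs

    # "auto" mirrors `--generation-config auto`; per the help text above, a
    # folder path containing a generation config should also be accepted.
    engine_args = EngineArgs(model="Qwen/Qwen2.5-0.5B-Instruct",
                             generation_config="auto")
    model_config = engine_args.create_model_config()

    # Sampling-related defaults extracted from the model's generation config.
    print(model_config.get_diff_sampling_param())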

vllm/engine/llm_engine.py (+4 −19)

@@ -5,8 +5,8 @@
 from contextlib import contextmanager
 from dataclasses import dataclass
 from functools import partial
-from typing import (TYPE_CHECKING, Any, Callable, ClassVar, Deque, Dict,
-                    Iterable, List, Mapping, NamedTuple, Optional)
+from typing import (TYPE_CHECKING, Callable, ClassVar, Deque, Dict, Iterable,
+                    List, Mapping, NamedTuple, Optional)
 from typing import Sequence as GenericSequence
 from typing import Set, Type, Union, cast, overload

@@ -52,7 +52,6 @@
                            SequenceGroupOutput, SequenceStatus)
 from vllm.tracing import (SpanAttributes, SpanKind, extract_trace_context,
                           init_tracer)
-from vllm.transformers_utils.config import try_get_generation_config
 from vllm.transformers_utils.detokenizer import Detokenizer
 from vllm.transformers_utils.tokenizer import AnyTokenizer
 from vllm.transformers_utils.tokenizer_group import (
@@ -65,20 +64,6 @@
 logger = init_logger(__name__)
 _LOCAL_LOGGING_INTERVAL_SEC = 5

-
-def _load_generation_config_dict(model_config: ModelConfig) -> Dict[str, Any]:
-    config = try_get_generation_config(
-        model_config.model,
-        trust_remote_code=model_config.trust_remote_code,
-        revision=model_config.revision,
-    )
-
-    if config is None:
-        return {}
-
-    return config.to_diff_dict()
-
-
 _G = TypeVar("_G", bound=BaseTokenizerGroup, default=BaseTokenizerGroup)
 _O = TypeVar("_O", RequestOutput, PoolingRequestOutput)

@@ -274,8 +259,8 @@ def get_tokenizer_for_seq(sequence: Sequence) -> AnyTokenizer:
             return tokenizer_group.get_lora_tokenizer(sequence.lora_request)

         self.seq_counter = Counter()
-        self.generation_config_fields = _load_generation_config_dict(
-            self.model_config)
+        self.generation_config_fields = (
+            self.model_config.try_get_generation_config())

         self.input_preprocessor = InputPreprocessor(self.model_config,
                                                     self.tokenizer,

vllm/entrypoints/llm.py (+8 −1)

@@ -258,6 +258,13 @@ def set_tokenizer(self, tokenizer: AnyTokenizer) -> None:
         else:
             tokenizer_group.tokenizer = get_cached_tokenizer(tokenizer)

+    def get_default_sampling_params(self) -> SamplingParams:
+        diff_sampling_param = (
+            self.llm_engine.model_config.get_diff_sampling_param())
+        if diff_sampling_param:
+            return SamplingParams.from_optional(**diff_sampling_param)
+        return SamplingParams()
+
     @overload
     def generate(
         self,
@@ -441,7 +448,7 @@ def generate(

         if sampling_params is None:
             # Use default sampling params.
-            sampling_params = SamplingParams()
+            sampling_params = self.get_default_sampling_params()

         self._validate_and_add_requests(
             prompts=parsed_prompts,
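One behavioral note on LLM.generate after this change: when no SamplingParams is passed, the defaults now come from get_default_sampling_params(), which is a plain SamplingParams() unless generation_config was set. A short sketch of both cases; the printed value assumes the library default temperature of 1.0:

    from vllm import LLM

    # generation_config defaults to None, so get_diff_sampling_param() is {}
    # and the behavior matches the previous version.
    llm = LLM(model="Qwen/Qwen2.5-0.5B-Instruct")
    params = llm.get_default_sampling_params()
    print(params.temperature)  # assumed library default of 1.0

    # With generation_config="auto" (as in the example file above), the
    # model's own generation_config defaults would be applied instead.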
