Skip to content
New issue

Have a question about this project? # for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “#”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? # to your account

[Feature] Add load generation config from model #11164

Merged
merged 20 commits into from
Dec 19, 2024
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 9 additions & 0 deletions tests/entrypoints/openai/test_serving_chat.py
Original file line number Diff line number Diff line change
Expand Up @@ -146,3 +146,12 @@ def test_serving_chat_could_load_correct_generation_config():

assert mock_engine.generate.call_args.args[1].temperature == 0.1
assert mock_engine.generate.call_args.args[1].repetition_penalty == 1.05

# Test When temperature==0.0
req.temperature = 0.0

with suppress(Exception):
asyncio.run(serving_chat.create_chat_completion(req))

assert mock_engine.generate.call_args.args[1].temperature == 0.0
assert mock_engine.generate.call_args.args[1].repetition_penalty == 1.05
9 changes: 5 additions & 4 deletions vllm/entrypoints/openai/protocol.py
Original file line number Diff line number Diff line change
Expand Up @@ -361,8 +361,9 @@ def to_beam_search_params(
if default_sampling_params is None:
default_sampling_params = {}
n = self.n if self.n is not None else 1
temperature = self.temperature or default_sampling_params.get(
"temperature", 0.0)

if (temperature := self.temperature) is None:
temperature = default_sampling_params.get("temperature", 0.0)

return BeamSearchParams(
beam_width=n,
Expand All @@ -389,7 +390,7 @@ def to_sampling_params(
repetition_penalty = (default_sampling_params.get(
"repetition_penalty", 1.0))
if (temperature := self.temperature) is None:
temperature = default_sampling_params.get("temperature", 0.7)
temperature = default_sampling_params.get("temperature", 1.0)
if (top_p := self.top_p) is None:
top_p = default_sampling_params.get("top_p", 1.0)
if (top_k := self.top_k) is None:
Expand Down Expand Up @@ -705,7 +706,7 @@ def to_beam_search_params(
default_sampling_params = {}
n = self.n if self.n is not None else 1
temperature = self.temperature or default_sampling_params.get(
"temperature", 0.0)
"temperature", 1.0)

return BeamSearchParams(
beam_width=n,
Expand Down
Loading