Skip to content

Commit 1e9605b

Browse files
committed
Tests may use outdated rope scalings, so we patch them as well
1 parent 530d8a0 commit 1e9605b

File tree

3 files changed

+39
-23
lines changed

3 files changed

+39
-23
lines changed

vllm/config.py

+5-1
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,8 @@
1515
from vllm.tracing import is_otel_available, otel_import_error_traceback
1616
from vllm.transformers_utils.config import (ConfigFormat, get_config,
1717
get_hf_image_processor_config,
18-
get_hf_text_config)
18+
get_hf_text_config,
19+
patch_rope_scaling)
1920
from vllm.utils import (GiB_bytes, cuda_device_count_stateless, get_cpu_memory,
2021
is_hip, is_neuron, is_openvino, is_xpu,
2122
print_warning_once)
@@ -1721,6 +1722,9 @@ def _get_and_verify_max_len(
17211722
default_max_len)
17221723
derived_max_model_len = default_max_len
17231724

1725+
# Backwards compatibility
1726+
patch_rope_scaling(hf_config)
1727+
17241728
rope_scaling = getattr(hf_config, "rope_scaling", None)
17251729
if rope_scaling is not None:
17261730
rope_type = rope_scaling["rope_type"]

vllm/model_executor/layers/rotary_embedding.py

+5-2
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@
2828
import torch.nn as nn
2929

3030
from vllm.model_executor.custom_op import CustomOp
31+
from vllm.transformers_utils.config import patch_rope_scaling_dict
3132

3233

3334
def _rotate_neox(x: torch.Tensor) -> torch.Tensor:
@@ -901,6 +902,9 @@ def get_rope(
901902
if dtype is None:
902903
dtype = torch.get_default_dtype()
903904
if rope_scaling is not None:
905+
# Backwards compatibility
906+
patch_rope_scaling_dict(rope_scaling)
907+
904908
# Transforms every value that is a list into a tuple for caching calls
905909
rope_scaling_tuple = {
906910
k: tuple(v) if isinstance(v, list) else v
@@ -920,8 +924,7 @@ def get_rope(
920924
rotary_emb = RotaryEmbedding(head_size, rotary_dim, max_position, base,
921925
is_neox_style, dtype)
922926
else:
923-
scaling_type = rope_scaling[
924-
"type"] if "type" in rope_scaling else rope_scaling["rope_type"]
927+
scaling_type = rope_scaling["rope_type"]
925928

926929
if scaling_type == "llama3":
927930
scaling_factor = rope_scaling["factor"]

vllm/transformers_utils/config.py

+29-20
Original file line numberDiff line numberDiff line change
@@ -90,6 +90,34 @@ def file_or_path_exists(model: Union[str, Path], config_name, revision,
9090
return False
9191

9292

93+
def patch_rope_scaling(config: PretrainedConfig) -> None:
    """Provide backwards compatibility for RoPE.

    Normalizes ``config.rope_scaling`` in place (if present) via
    :func:`patch_rope_scaling_dict`; configs without rope scaling
    are left untouched.
    """
    rope_scaling = getattr(config, "rope_scaling", None)
    if rope_scaling is not None:
        patch_rope_scaling_dict(rope_scaling)
100+
101+
102+
def patch_rope_scaling_dict(rope_scaling: Dict[str, Any]) -> None:
    """Normalize a rope_scaling dict in place for backwards compatibility.

    Ensures the scaling type is stored under both the legacy "type" key and
    the HF-preferred "rope_type" key, and rewrites legacy type names
    ("su", "mrope") to their current equivalents.

    Raises:
        ValueError: if neither "type" nor "rope_type" is present.
    """
    # Although HF prefers "rope_type", we have code that accesses "type",
    # so we populate both keys. "type" takes precedence when both exist.
    if "type" in rope_scaling:
        rope_type = rope_scaling["type"]
    elif "rope_type" in rope_scaling:
        rope_type = rope_scaling["rope_type"]
    else:
        raise ValueError("rope_scaling must have a 'type' or 'rope_type' key")
    rope_scaling["type"] = rope_scaling["rope_type"] = rope_type

    # Translate legacy type names to the ones current code understands.
    if rope_type == "su":
        rope_scaling["type"] = rope_scaling["rope_type"] = "longrope"
        logger.warning("Replacing legacy rope_type 'su' with 'longrope'")
    elif rope_type == "mrope":
        assert "mrope_section" in rope_scaling
        rope_scaling["type"] = rope_scaling["rope_type"] = "default"
        logger.warning("Replacing legacy rope_type 'mrope' with 'default'")
119+
120+
93121
def get_config(
94122
model: Union[str, Path],
95123
trust_remote_code: bool,
@@ -177,26 +205,7 @@ def get_config(
177205
model_type = MODEL_FOR_CAUSAL_LM_MAPPING_NAMES[config.model_type]
178206
config.update({"architectures": [model_type]})
179207

180-
# Backwards compatibility for RoPE
181-
rope_scaling = getattr(config, "rope_scaling", None)
182-
if rope_scaling is not None:
183-
# Although HF prefers "rope_type", we have code that accesses "type",
184-
# so we populate both keys
185-
if "type" in rope_scaling:
186-
rope_type = rope_scaling["rope_type"] = rope_scaling["type"]
187-
elif "rope_type" in rope_scaling:
188-
rope_type = rope_scaling["type"] = rope_scaling["rope_type"]
189-
else:
190-
raise ValueError(
191-
"rope_scaling must have a 'type' or 'rope_type' key.")
192-
193-
if rope_type == "su":
194-
rope_scaling["rope_type"] = rope_type = "longrope"
195-
logger.warning("Replacing legacy rope_type 'su' with 'longrope'")
196-
elif rope_type == "mrope":
197-
assert "mrope_section" in rope_scaling
198-
rope_scaling["rope_type"] = rope_type = "default"
199-
logger.warning("Replacing legacy rope_type 'mrope' with 'default'")
208+
patch_rope_scaling(config)
200209

201210
for key, value in [
202211
("rope_scaling", rope_scaling),

0 commit comments

Comments
 (0)