@@ -90,6 +90,34 @@ def file_or_path_exists(model: Union[str, Path], config_name, revision,
90
90
return False
91
91
92
92
93
+ def patch_rope_scaling (config : PretrainedConfig ) -> None :
94
+ """Provide backwards compatibility for RoPE."""
95
+ rope_scaling = getattr (config , "rope_scaling" , None )
96
+ if rope_scaling is None :
97
+ return
98
+
99
+ patch_rope_scaling_dict (rope_scaling )
100
+
101
+
102
+ def patch_rope_scaling_dict (rope_scaling : Dict [str , Any ]) -> None :
103
+ # Although HF prefers "rope_type", we have code that accesses "type",
104
+ # so we populate both keys
105
+ if "type" in rope_scaling :
106
+ rope_type = rope_scaling ["rope_type" ] = rope_scaling ["type" ]
107
+ elif "rope_type" in rope_scaling :
108
+ rope_type = rope_scaling ["type" ] = rope_scaling ["rope_type" ]
109
+ else :
110
+ raise ValueError ("rope_scaling must have a 'type' or 'rope_type' key" )
111
+
112
+ if rope_type == "su" :
113
+ rope_scaling ["type" ] = rope_scaling ["rope_type" ] = "longrope"
114
+ logger .warning ("Replacing legacy rope_type 'su' with 'longrope'" )
115
+ elif rope_type == "mrope" :
116
+ assert "mrope_section" in rope_scaling
117
+ rope_scaling ["type" ] = rope_scaling ["rope_type" ] = "default"
118
+ logger .warning ("Replacing legacy rope_type 'mrope' with 'default'" )
119
+
120
+
93
121
def get_config (
94
122
model : Union [str , Path ],
95
123
trust_remote_code : bool ,
@@ -177,26 +205,7 @@ def get_config(
177
205
model_type = MODEL_FOR_CAUSAL_LM_MAPPING_NAMES [config .model_type ]
178
206
config .update ({"architectures" : [model_type ]})
179
207
180
- # Backwards compatibility for RoPE
181
- rope_scaling = getattr (config , "rope_scaling" , None )
182
- if rope_scaling is not None :
183
- # Although HF prefers "rope_type", we have code that accesses "type",
184
- # so we populate both keys
185
- if "type" in rope_scaling :
186
- rope_type = rope_scaling ["rope_type" ] = rope_scaling ["type" ]
187
- elif "rope_type" in rope_scaling :
188
- rope_type = rope_scaling ["type" ] = rope_scaling ["rope_type" ]
189
- else :
190
- raise ValueError (
191
- "rope_scaling must have a 'type' or 'rope_type' key." )
192
-
193
- if rope_type == "su" :
194
- rope_scaling ["rope_type" ] = rope_type = "longrope"
195
- logger .warning ("Replacing legacy rope_type 'su' with 'longrope'" )
196
- elif rope_type == "mrope" :
197
- assert "mrope_section" in rope_scaling
198
- rope_scaling ["rope_type" ] = rope_type = "default"
199
- logger .warning ("Replacing legacy rope_type 'mrope' with 'default'" )
208
+ patch_rope_scaling (config )
200
209
201
210
for key , value in [
202
211
("rope_scaling" , rope_scaling ),
0 commit comments