@@ -98,6 +98,7 @@ class Fp8LinearMethod(LinearMethodBase):
     """

     def __init__(self, quant_config: Fp8Config):
+        self.fused_module_in_checkpoint = False
         self.quant_config = quant_config
         self.cutlass_fp8_supported = cutlass_fp8_supported()

@@ -111,6 +112,7 @@ def _create_scale_param(
         scale = Parameter(torch.empty(len(output_partition_sizes),
                                       dtype=torch.float32),
                           requires_grad=False)
+        scale[:] = torch.finfo(torch.float8_e4m3fn).min
         layer.register_parameter(scale_name, scale)
         set_weight_attrs(
             scale, {
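The added `scale[:] = torch.finfo(torch.float8_e4m3fn).min` pre-fills the freshly allocated per-shard scale tensor. A minimal standalone sketch of the apparent intent, assuming the interpretation that slots never written by a fused checkpoint must not win the later `.max()` reduction (the shard count and values below are illustrative, not taken from the PR):

import torch

# Hypothetical fused qkv_proj with three logical shards (illustrative only).
num_shards = 3
scale = torch.empty(num_shards, dtype=torch.float32)

# Sentinel: the most negative representable fp8 value (-448.0 for e4m3fn),
# so uninitialized slots can never dominate a max() over the scales.
scale[:] = torch.finfo(torch.float8_e4m3fn).min

# A fused checkpoint ships a single scale; the shard indexer puts it in slot 0.
scale[0] = 0.5

# The reduction recovers the real scale instead of torch.empty() garbage.
assert scale.max().item() == 0.5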
@@ -169,11 +171,15 @@ def create_weights(
             **extra_weight_attrs)

     def scales_shard_indexer(
-            self, param: torch.Tensor, loaded_weight: torch.Tensor,
-            shard_id: Union[str, int]) -> Tuple[torch.Tensor, torch.Tensor]:
+            self, param: torch.Tensor, loaded_weight: torch.Tensor,
+            shard_id: Optional[Union[str,
+                                     int]]) -> Tuple[torch.Tensor, torch.Tensor]:
         qkv_idxs = {"q": 0, "k": 1, "v": 2}

-        if isinstance(shard_id, int):
+        if shard_id is None:
+            shard_id = 0
+            self.fused_module_in_checkpoint = True
+        elif isinstance(shard_id, int):
             pass
         elif isinstance(shard_id, str):
             if shard_id not in qkv_idxs:
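Read in isolation, the new branch gives `shard_id=None` the meaning "the checkpoint already stores the fused module, so there is exactly one scale and it belongs in slot 0", and records that fact for `process_weights_after_loading`. A self-contained sketch of just the index mapping; the helper name `shard_index` is hypothetical:

from typing import Optional, Union

def shard_index(shard_id: Optional[Union[str, int]]) -> int:
    """Map a loader shard id to a slot in the per-shard scale tensor."""
    qkv_idxs = {"q": 0, "k": 1, "v": 2}
    if shard_id is None:           # fused checkpoint: single scale, slot 0
        return 0
    if isinstance(shard_id, int):  # integer ids, e.g. merged gate/up shards
        return shard_id
    return qkv_idxs[shard_id]      # "q"/"k"/"v" for qkv projections

assert shard_index(None) == 0
assert shard_index("v") == 2
assert shard_index(1) == 1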
@@ -205,15 +211,17 @@ def process_weights_after_loading(self, layer: Module) -> None:
             # WEIGHT_SCALE / WEIGHT
             #   Loop over logical weights, requantizing with single scale.
             max_w_scale = layer.weight_scale.max()
-            start = 0
-            for idx, logical_width in enumerate(layer.logical_widths):
-                end = start + logical_width
-                weight_dq = per_tensor_dequantize(layer.weight[start:end, :],
-                                                  layer.weight_scale[idx])
-
-                layer.weight[start:end, :] = per_tensor_quantize(
-                    weight_dq, layer.weight_scale.max())
-                start = end
+
+            if not self.fused_module_in_checkpoint:
+                start = 0
+                for idx, logical_width in enumerate(layer.logical_widths):
+                    end = start + logical_width
+                    weight_dq = per_tensor_dequantize(
+                        layer.weight[start:end, :], layer.weight_scale[idx])
+
+                    layer.weight[start:end, :] = per_tensor_quantize(
+                        weight_dq, layer.weight_scale.max())
+                    start = end
             layer.weight_scale = Parameter(max_w_scale, requires_grad=False)

             # WEIGHT
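The requantization loop is unchanged in spirit: each logical slice is dequantized with its own scale and re-quantized with the shared maximum, and the whole loop is now skipped when the checkpoint already stores one fused scale, since there is nothing to reconcile. A rough CPU-only sketch of that round trip, using simplified stand-ins rather than the PR's `per_tensor_quantize`/`per_tensor_dequantize`, and assuming a PyTorch build whose float8_e4m3fn tensors support casting, slicing and copies:

import torch

FP8 = torch.float8_e4m3fn
FINFO = torch.finfo(FP8)

def quantize(t: torch.Tensor, scale: torch.Tensor) -> torch.Tensor:
    # Simplified per-tensor symmetric quantization into fp8.
    return (t / scale).clamp(FINFO.min, FINFO.max).to(FP8)

def dequantize(q: torch.Tensor, scale: torch.Tensor) -> torch.Tensor:
    return q.to(torch.float32) * scale

logical_widths = [4, 4]              # two logical shards in one fused weight
scales = torch.tensor([0.02, 0.05])  # their original per-shard scales
full = torch.randn(sum(logical_widths), 8)

# Checkpoint state: each slice quantized with its own scale.
weight = torch.empty(sum(logical_widths), 8, dtype=FP8)
weight[0:4] = quantize(full[0:4], scales[0])
weight[4:8] = quantize(full[4:8], scales[1])

# Round-trip every slice through fp32 so the fused weight shares max(scales).
max_scale = scales.max()
start = 0
for idx, width in enumerate(logical_widths):
    end = start + width
    weight[start:end] = quantize(dequantize(weight[start:end], scales[idx]),
                                 max_scale)
    start = end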
@@ -227,10 +235,6 @@ def process_weights_after_loading(self, layer: Module) -> None:
             if self.quant_config.activation_scheme == "dynamic":
                 layer.input_scale = None
             elif self.quant_config.activation_scheme == "static":
-                if not all_close_1d(layer.input_scale):
-                    raise ValueError(
-                        "All the input_scales for the logical weights of a "
-                        f"layer must be equal. But got {layer.input_scale}")
                 layer.input_scale = Parameter(layer.input_scale.max(),
                                               requires_grad=False)
             else:
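Dropping the all_close_1d guard leans on the usual quantization argument: running static activation quantization with the largest of the candidate input scales cannot clip harder than the exact scale would, it only spends a little resolution. A quick numeric illustration (values are arbitrary, not from the PR):

import torch

FP8 = torch.float8_e4m3fn
x = torch.randn(1024)

exact_scale = x.abs().max() / torch.finfo(FP8).max
larger_scale = 2 * exact_scale  # e.g. the max over several logical shards

for s in (exact_scale, larger_scale):
    xq = (x / s).clamp(torch.finfo(FP8).min, torch.finfo(FP8).max).to(FP8)
    err = (xq.to(torch.float32) * s - x).abs().max().item()
    print(f"scale={s.item():.6f}  max abs reconstruction error={err:.6f}")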
@@ -317,11 +321,6 @@ def process_weights_after_loading(self, layer: Module) -> None:
         del layer.kv_scale


-def all_close_1d(x: torch.Tensor) -> bool:
-    assert len(x.shape) == 1
-    return all(torch.allclose(x[0], x[i]) for i in range(x.shape[0]))
-
-
 def per_tensor_quantize(tensor: torch.Tensor,
                         inv_scale: Union[float, torch.Tensor]) -> torch.Tensor:
     finfo = torch.finfo(torch.float8_e4m3fn)