
Commit 2cd402e

Robert Shaw (robertgshaw2-redhat) and Robert Shaw authored
[ Bugfix ] Enabling Loading Models With Fused QKV/MLP on Disk with FP8 (#5921)
Co-authored-by: Robert Shaw <rshaw@neuralmagic>
1 parent b185230 commit 2cd402e

File tree

2 files changed: +32 −23 lines


vllm/model_executor/layers/linear.py

+12 −2
@@ -383,8 +383,13 @@ def weight_loader(self,
                                            None)
 
         if loaded_shard_id is None:
-            # Loaded weight is already packed.
+            # Loaded weight is already fused on disk (qkv/mlp).
             if output_dim is None:
+                # If fp8 + scale, need to send to each shard.
+                if fp8_scales_shard_indexer is not None:
+                    param_data, loaded_weight = fp8_scales_shard_indexer(
+                        param_data, loaded_weight, loaded_shard_id)
+
                 assert param_data.shape == loaded_weight.shape
                 param_data.copy_(loaded_weight)
                 return
@@ -567,8 +572,13 @@ def weight_loader(self,
                                            None)
 
         if loaded_shard_id is None:
-            # Loaded weight is already packed.
+            # Loaded weight is already fused on disk (qkv/mlp).
             if output_dim is None:
+                # If fp8 + scale, need to send to each shard.
+                if fp8_scales_shard_indexer is not None:
+                    param_data, loaded_weight = fp8_scales_shard_indexer(
+                        param_data, loaded_weight, loaded_shard_id)
+
                 assert param_data.shape == loaded_weight.shape
                 param_data.copy_(loaded_weight)
                 return
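
Both hunks make the same change, in the two weight loaders for fused projections in this file. When a checkpoint already stores the projection fused, loaded_shard_id is None and the tensor can be copied wholesale, but an FP8 scale attached to the parameter must still pass through its shard indexer first. A minimal sketch of that loading path, using the attribute and argument names from the diff and eliding the per-shard logic (not the full vLLM loader):

import torch


def load_fused_weight(param: torch.nn.Parameter,
                      loaded_weight: torch.Tensor,
                      loaded_shard_id=None) -> None:
    # Sketch only: mirrors the early-copy branch shown in the diff above.
    param_data = param.data
    output_dim = getattr(param, "output_dim", None)
    # Attribute attached to scale parameters by the quantization method.
    fp8_scales_shard_indexer = getattr(param, "fp8_scales_shard_indexer",
                                       None)

    if loaded_shard_id is None:
        # Loaded weight is already fused on disk (qkv/mlp).
        if output_dim is None:
            # If fp8 + scale, need to send to each shard.
            if fp8_scales_shard_indexer is not None:
                param_data, loaded_weight = fp8_scales_shard_indexer(
                    param_data, loaded_weight, loaded_shard_id)
            assert param_data.shape == loaded_weight.shape
            param_data.copy_(loaded_weight)
            return
    # ... per-shard handling (loaded_shard_id given) elided ...

A plain weight tensor still takes the early copy path unchanged; the new branch only matters for the scale parameters, which (per the fp8.py changes below) keep one slot per logical weight.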

vllm/model_executor/layers/quantization/fp8.py

+20 −21
@@ -98,6 +98,7 @@ class Fp8LinearMethod(LinearMethodBase):
     """
 
     def __init__(self, quant_config: Fp8Config):
+        self.fused_module_in_checkpoint = False
         self.quant_config = quant_config
         self.cutlass_fp8_supported = cutlass_fp8_supported()
 
@@ -111,6 +112,7 @@ def _create_scale_param(
         scale = Parameter(torch.empty(len(output_partition_sizes),
                                       dtype=torch.float32),
                           requires_grad=False)
+        scale[:] = torch.finfo(torch.float8_e4m3fn).min
         layer.register_parameter(scale_name, scale)
         set_weight_attrs(
             scale, {
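
The scale tensor allocated here has one slot per logical weight (len(output_partition_sizes)) and is now pre-filled with the minimum of the FP8 e4m3 range. One plausible reading, given the .max() reductions used later in process_weights_after_loading: a fused-on-disk checkpoint fills only a single slot, so the remaining slots must hold a value that can never win the max. A small sketch of that effect (numbers are made up):

import torch
from torch.nn import Parameter

# One slot per logical weight (e.g. gate/up for a fused MLP, q/k/v for QKV).
scale = Parameter(torch.empty(3, dtype=torch.float32), requires_grad=False)
scale[:] = torch.finfo(torch.float8_e4m3fn).min   # -448.0 for e4m3fn

# A fused checkpoint provides a single scale; the indexer writes it to slot 0.
scale.data[0] = 0.5

# The later .max() reduction recovers the loaded scale rather than whatever
# torch.empty() happened to leave in the unfilled slots.
assert scale.max().item() == 0.5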
@@ -169,11 +171,15 @@ def create_weights(
             **extra_weight_attrs)
 
     def scales_shard_indexer(
-            self, param: torch.Tensor, loaded_weight: torch.Tensor,
-            shard_id: Union[str, int]) -> Tuple[torch.Tensor, torch.Tensor]:
+        self, param: torch.Tensor, loaded_weight: torch.Tensor,
+        shard_id: Optional[Union[str,
+                                 int]]) -> Tuple[torch.Tensor, torch.Tensor]:
         qkv_idxs = {"q": 0, "k": 1, "v": 2}
 
-        if isinstance(shard_id, int):
+        if shard_id is None:
+            shard_id = 0
+            self.fused_module_in_checkpoint = True
+        elif isinstance(shard_id, int):
             pass
         elif isinstance(shard_id, str):
             if shard_id not in qkv_idxs:
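
With the widened signature, shard_id may now be None, which is how the linear layer signals a fused-on-disk checkpoint: the scale is routed to slot 0 and the layer is flagged via fused_module_in_checkpoint so the requantization pass below can be skipped. A standalone sketch of the routing; the return statement is an assumption since the hunk ends before it, and the real method is bound to Fp8LinearMethod rather than taking a state dict:

from typing import Optional, Tuple, Union

import torch


def scales_shard_indexer_sketch(
        param: torch.Tensor, loaded_weight: torch.Tensor,
        shard_id: Optional[Union[str, int]],
        state: dict) -> Tuple[torch.Tensor, torch.Tensor]:
    qkv_idxs = {"q": 0, "k": 1, "v": 2}

    if shard_id is None:
        # Fused qkv/mlp checkpoint: one scale, stored in slot 0.
        shard_id = 0
        state["fused_module_in_checkpoint"] = True
    elif isinstance(shard_id, int):
        pass  # merged MLP shards arrive as integer indices
    elif isinstance(shard_id, str):
        shard_id = qkv_idxs[shard_id]  # "q"/"k"/"v" -> 0/1/2
    else:
        raise ValueError(f"Shard id must be int or str but got {shard_id}")

    # Assumed: hand back the selected slot so the loader copies into it.
    return param[shard_id], loaded_weight

Called with shard_id=None, the fused checkpoint's single scale lands in slot 0 and the flag is set; together with the finfo.min fill above, the downstream .max() reductions then pick up the loaded value.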
@@ -205,15 +211,17 @@ def process_weights_after_loading(self, layer: Module) -> None:
             # WEIGHT_SCALE / WEIGHT
             # Loop over logical weights, requantizing with single scale.
             max_w_scale = layer.weight_scale.max()
-            start = 0
-            for idx, logical_width in enumerate(layer.logical_widths):
-                end = start + logical_width
-                weight_dq = per_tensor_dequantize(layer.weight[start:end, :],
-                                                  layer.weight_scale[idx])
-
-                layer.weight[start:end, :] = per_tensor_quantize(
-                    weight_dq, layer.weight_scale.max())
-                start = end
+
+            if not self.fused_module_in_checkpoint:
+                start = 0
+                for idx, logical_width in enumerate(layer.logical_widths):
+                    end = start + logical_width
+                    weight_dq = per_tensor_dequantize(
+                        layer.weight[start:end, :], layer.weight_scale[idx])
+
+                    layer.weight[start:end, :] = per_tensor_quantize(
+                        weight_dq, layer.weight_scale.max())
+                    start = end
             layer.weight_scale = Parameter(max_w_scale, requires_grad=False)
 
             # WEIGHT
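
For a non-fused FP8 checkpoint, each logical weight was quantized with its own scale, so the loader dequantizes every slice and requantizes it under the shared max scale; a fused checkpoint already uses one scale for the whole tensor, so that pass is now skipped and only max_w_scale is kept. A sketch of the skipped step, with per_tensor_quantize / per_tensor_dequantize written as plausible stand-ins rather than copies of the helpers in fp8.py:

import torch


def per_tensor_dequantize(t: torch.Tensor,
                          scale: torch.Tensor) -> torch.Tensor:
    # Plausible stand-in: widen from fp8 and undo the scaling.
    return t.to(torch.float16) * scale


def per_tensor_quantize(t: torch.Tensor, scale: torch.Tensor) -> torch.Tensor:
    # Plausible stand-in: rescale, clamp to the e4m3 range, cast to fp8.
    finfo = torch.finfo(torch.float8_e4m3fn)
    return (t / scale).clamp(min=finfo.min,
                             max=finfo.max).to(torch.float8_e4m3fn)


def requantize_to_shared_scale(weight: torch.Tensor,
                               weight_scale: torch.Tensor,
                               logical_widths) -> torch.Tensor:
    # Re-express every logical slice of weight under the single max scale,
    # mirroring the loop the diff now guards with
    # "if not self.fused_module_in_checkpoint".
    max_w_scale = weight_scale.max()
    start = 0
    for idx, logical_width in enumerate(logical_widths):
        end = start + logical_width
        weight_dq = per_tensor_dequantize(weight[start:end, :],
                                          weight_scale[idx])
        weight[start:end, :] = per_tensor_quantize(weight_dq, max_w_scale)
        start = end
    return max_w_scale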
@@ -227,10 +235,6 @@ def process_weights_after_loading(self, layer: Module) -> None:
             if self.quant_config.activation_scheme == "dynamic":
                 layer.input_scale = None
             elif self.quant_config.activation_scheme == "static":
-                if not all_close_1d(layer.input_scale):
-                    raise ValueError(
-                        "All the input_scales for the logical weights of a "
-                        f"layer must be equal. But got {layer.input_scale}")
                 layer.input_scale = Parameter(layer.input_scale.max(),
                                               requires_grad=False)
             else:
else:
@@ -317,11 +321,6 @@ def process_weights_after_loading(self, layer: Module) -> None:
         del layer.kv_scale
 
 
-def all_close_1d(x: torch.Tensor) -> bool:
-    assert len(x.shape) == 1
-    return all(torch.allclose(x[0], x[i]) for i in range(x.shape[0]))
-
-
 def per_tensor_quantize(tensor: torch.Tensor,
                         inv_scale: Union[float, torch.Tensor]) -> torch.Tensor:
     finfo = torch.finfo(torch.float8_e4m3fn)
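
The all_close_1d guard is dropped, presumably because a fused checkpoint stores only one input_scale, so the per-logical-weight slots are no longer expected to be equal; the static path above now simply takes the max, which also skips over the finfo.min filler in unloaded slots. A tiny illustration with made-up numbers:

import torch

# Three slots (q/k/v), pre-filled with the fp8 minimum as in
# _create_scale_param; a fused checkpoint populates only slot 0.
input_scale = torch.full((3,), torch.finfo(torch.float8_e4m3fn).min)
input_scale[0] = 0.5

# The removed check would have rejected this (slots are not all equal);
# taking the max recovers the single loaded scale instead.
layer_input_scale = input_scale.max()
assert layer_input_scale.item() == 0.5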
