Commit 118b80b

[Bugfix] Fix torch dynamo fixes caused by replace_parameters (vllm-project#8748)

1 parent: 992413a

File tree: 1 file changed (+8, -4 lines)


vllm/model_executor/layers/quantization/utils/layer_utils.py (+8, -4)
```diff
@@ -21,13 +21,17 @@ def replace_parameter(mod: torch.nn.Module, name: str,
                       new: Union[torch.Tensor, torch.nn.Parameter]) -> None:
 
     old = getattr(mod, name)
-    if old.dtype == new.dtype and \
+    if type(old) is type(new) and old.dtype == new.dtype and \
             old.untyped_storage().nbytes() == new.untyped_storage().nbytes():
         # If we can just update in-place to avoid re-registering
         #  can be faster if the underlying storage is the same
         update_tensor_inplace(old, new)
     else:
-        # Fallback re-register parameter
+        # Fallback re-register parameter, convert to Parameter if necessary
+        # this not only ensures we don't register a tensor as a parameter, but
+        # also ensures that all parameter subclasses get re-registered as
+        # parameters for `torch.compile` compatibility
         if not isinstance(new, torch.nn.Parameter):
-            new = torch.nn.Parameter(new)
-        mod.register_parameter(name, torch.nn.Parameter(new))
+            new = torch.nn.Parameter(new, requires_grad=False)
+        mod.register_parameter(name,
+                               torch.nn.Parameter(new, requires_grad=False))
```
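For context, here is a minimal, self-contained sketch of the patched helper. The `replace_parameter` body mirrors the diff above; the `update_tensor_inplace` stand-in is an assumption inferred from the call site (vllm's actual helper lives elsewhere in the same module), not the committed implementation:

```python
import torch
from typing import Union


def update_tensor_inplace(dst: torch.Tensor, src: torch.Tensor) -> None:
    # Hypothetical stand-in for vllm's in-place update helper: retarget
    # dst onto src's shape/stride, then copy data if the storages differ.
    assert dst.dtype == src.dtype, "Tensors must have the same dtype"
    dst.as_strided_(src.shape, src.stride())
    if dst.data_ptr() != src.data_ptr():
        dst.copy_(src)


def replace_parameter(mod: torch.nn.Module, name: str,
                      new: Union[torch.Tensor, torch.nn.Parameter]) -> None:
    old = getattr(mod, name)
    if type(old) is type(new) and old.dtype == new.dtype and \
            old.untyped_storage().nbytes() == new.untyped_storage().nbytes():
        # Fast path: same class, dtype, and storage size, so the update can
        # happen in place without re-registering the parameter.
        update_tensor_inplace(old, new)
    else:
        # Fallback: re-register. Wrapping in a plain torch.nn.Parameter both
        # avoids registering a bare tensor as a parameter and re-registers
        # Parameter subclasses, per the diff's `torch.compile` note.
        if not isinstance(new, torch.nn.Parameter):
            new = torch.nn.Parameter(new, requires_grad=False)
        mod.register_parameter(name,
                               torch.nn.Parameter(new, requires_grad=False))
```

A quick usage check (hypothetical example, not from the commit) shows why the fallback matters: replacing a `Linear` weight with a differently-typed tensor still yields a registered, grad-free `Parameter`:

```python
lin = torch.nn.Linear(4, 4, bias=False)
replace_parameter(lin, "weight", torch.randn(4, 4))
assert isinstance(lin.weight, torch.nn.Parameter)
assert lin.weight.requires_grad is False
```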
