From 6f2a0a60cf86cc7dea94d37a18e9c2e8a1429bc4 Mon Sep 17 00:00:00 2001
From: Giuseppe Franco
Date: Thu, 22 Jun 2023 18:05:22 +0200
Subject: [PATCH] Fix (weight_eq): fix for llm equalization (#638)

* Fix (weight_eq): fix for llm equalization

* rename new_module to new_value
---
 src/brevitas/graph/equalize.py | 31 +++++++++++++++++++++++++------
 1 file changed, 25 insertions(+), 6 deletions(-)

diff --git a/src/brevitas/graph/equalize.py b/src/brevitas/graph/equalize.py
index 9622d6eba..91c7c9c65 100644
--- a/src/brevitas/graph/equalize.py
+++ b/src/brevitas/graph/equalize.py
@@ -444,11 +444,17 @@ def _no_equalize():
     if len(src_axes) > 0:
         for module, axis in src_axes.items():
             if hasattr(module, 'bias') and module.bias is not None:
-                module.bias.data = module.bias.data * inverse_scaling_factors.view_as(module.bias)
+                _update_weights(
+                    module,
+                    module.bias.clone() * inverse_scaling_factors.view_as(module.bias),
+                    attr='bias')
             src_broadcast_size = [1] * module.weight.ndim
             src_broadcast_size[axis] = module.weight.size(axis)
-            module.weight.data = module.weight.data * torch.reshape(
-                inverse_scaling_factors, src_broadcast_size)
+            _update_weights(
+                module, (
+                    module.weight.clone() *
+                    torch.reshape(inverse_scaling_factors, src_broadcast_size)),
+                attr='weight')
     for module, axis in sink_axes.items():
         src_broadcast_size = [1] * module.weight.ndim
         src_broadcast_size[axis] = module.weight.size(axis)
@@ -457,12 +463,23 @@ def _no_equalize():
             # additive factor for equalization.
             additive_factor = module.running_mean.data * module.weight.data / torch.sqrt(
                 module.running_var.data + module.eps)
-            module.bias.data = module.bias.data + additive_factor * (scaling_factors - 1)
-        module.weight.data = module.weight.data * torch.reshape(scaling_factors, src_broadcast_size)
+            _update_weights(
+                module, module.bias.clone() + additive_factor * (scaling_factors - 1), attr='bias')
+        _update_weights(
+            module,
+            module.weight.clone() * torch.reshape(scaling_factors, src_broadcast_size),
+            attr='weight')
 
     return scaling_factors
 
 
+def _update_weights(original_module, new_value, attr='weight'):
+    if isinstance(original_module, WeightBiasTuple):
+        setattr(getattr(original_module, attr), 'data', new_value)
+    else:
+        setattr(original_module, attr, nn.Parameter(new_value))
+
+
 def _equalize(
         model: GraphModule,
         regions: Set[Tuple[str]],
@@ -509,7 +526,9 @@ def _is_supported_module(graph_model: GraphModule, node: Node) -> bool:
         if isinstance(module, _supported_layers):
             # We support only self-attention
             if isinstance(module, nn.MultiheadAttention):
-                return all([node.all_input_nodes[0].name == n.name for n in node.all_input_nodes])
+                kwargs = dict(node.kwargs)
+                kwargs.update(zip(module.forward.__code__.co_varnames[1:], node.args))
+                return kwargs['query'].name == kwargs['key'].name == kwargs['value'].name
             return True
     return False
 
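
The core change is that equalized weights and biases are no longer updated by mutating
the old tensor's .data in place; for regular modules the attribute is rebound to a
fresh nn.Parameter, and only the WeightBiasTuple wrapper keeps the .data write-through.
Below is a minimal standalone sketch of the nn.Module branch of _update_weights
(WeightBiasTuple is Brevitas-internal and omitted here; the nn.Linear usage is
illustrative, not part of the patch):

import torch
import torch.nn as nn


def _update_weights(original_module, new_value, attr='weight'):
    # nn.Module branch only: rebind the attribute to a fresh Parameter
    # instead of mutating the old tensor's .data in place.
    setattr(original_module, attr, nn.Parameter(new_value))


linear = nn.Linear(4, 4)
scale = torch.full((4,), 2.0)
# Scale each output channel, mirroring how the patch applies scaling factors.
_update_weights(linear, linear.weight.clone() * scale.view(-1, 1), attr='weight')
assert isinstance(linear.weight, nn.Parameter)  # still a registered Parameter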
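
The reworked nn.MultiheadAttention check resolves query/key/value by parameter name,
so it accepts both positional and keyword call styles and returns True only when all
three resolve to the same FX node, i.e. genuine self-attention. A runnable sketch of
the same name-resolution trick (the SelfAttn toy module is hypothetical; it assumes
the default torch.fx tracer treats nn.MultiheadAttention as a leaf call_module):

import torch.nn as nn
from torch.fx import symbolic_trace


class SelfAttn(nn.Module):
    # Hypothetical toy module: query, key and value are the same tensor.

    def __init__(self):
        super().__init__()
        self.mha = nn.MultiheadAttention(embed_dim=8, num_heads=2)

    def forward(self, x):
        return self.mha(x, x, x)[0]


gm = symbolic_trace(SelfAttn())
for node in gm.graph.nodes:
    if node.op == 'call_module':
        module = gm.get_submodule(node.target)
        if isinstance(module, nn.MultiheadAttention):
            # Same trick as the patch: map positional args onto the names of
            # MultiheadAttention.forward's parameters (co_varnames[0] is 'self').
            kwargs = dict(node.kwargs)
            kwargs.update(zip(module.forward.__code__.co_varnames[1:], node.args))
            print(kwargs['query'].name == kwargs['key'].name == kwargs['value'].name)  # True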