
Commit

Fix (weight_eq): fix for llm equalization (#638)
* Fix (weight_eq): fix for llm equalization

* rename new_module to new_value
Giuseppe5 authored Jun 22, 2023
1 parent 0a55295 commit 6f2a0a6
Showing 1 changed file with 25 additions and 6 deletions.
31 changes: 25 additions & 6 deletions src/brevitas/graph/equalize.py
@@ -444,11 +444,17 @@ def _no_equalize():
     if len(src_axes) > 0:
         for module, axis in src_axes.items():
             if hasattr(module, 'bias') and module.bias is not None:
-                module.bias.data = module.bias.data * inverse_scaling_factors.view_as(module.bias)
+                _update_weights(
+                    module,
+                    module.bias.clone() * inverse_scaling_factors.view_as(module.bias),
+                    attr='bias')
             src_broadcast_size = [1] * module.weight.ndim
             src_broadcast_size[axis] = module.weight.size(axis)
-            module.weight.data = module.weight.data * torch.reshape(
-                inverse_scaling_factors, src_broadcast_size)
+            _update_weights(
+                module, (
+                    module.weight.clone() *
+                    torch.reshape(inverse_scaling_factors, src_broadcast_size)),
+                attr='weight')
     for module, axis in sink_axes.items():
         src_broadcast_size = [1] * module.weight.ndim
         src_broadcast_size[axis] = module.weight.size(axis)
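Note: on the source side, the reshape simply aligns one inverse scaling factor per channel with the equalized axis of the weight tensor so that the multiply broadcasts correctly. A minimal sketch of that broadcasting pattern (the layer shape below is illustrative, not taken from the diff):

    import torch

    # Hypothetical source layer: weight of shape (8, 16), equalized along
    # the output axis (axis 0), one inverse factor per output channel.
    weight = torch.randn(8, 16)
    inverse_scaling_factors = torch.rand(8) + 0.5

    src_broadcast_size = [1] * weight.ndim   # [1, 1]
    src_broadcast_size[0] = weight.size(0)   # [8, 1]

    # Reshaped to (8, 1), the factors broadcast across the input dimension.
    scaled = weight * torch.reshape(inverse_scaling_factors, src_broadcast_size)
    assert scaled.shape == weight.shape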
@@ -457,12 +463,23 @@ def _no_equalize():
             # additive factor for equalization.
             additive_factor = module.running_mean.data * module.weight.data / torch.sqrt(
                 module.running_var.data + module.eps)
-            module.bias.data = module.bias.data + additive_factor * (scaling_factors - 1)
-            module.weight.data = module.weight.data * torch.reshape(scaling_factors, src_broadcast_size)
+            _update_weights(
+                module, module.bias.clone() + additive_factor * (scaling_factors - 1), attr='bias')
+            _update_weights(
+                module,
+                module.weight.clone() * torch.reshape(scaling_factors, src_broadcast_size),
+                attr='weight')
 
     return scaling_factors
 
 
+def _update_weights(original_module, new_value, attr='weight'):
+    if isinstance(original_module, WeightBiasTuple):
+        setattr(getattr(original_module, attr), 'data', new_value)
+    else:
+        setattr(original_module, attr, nn.Parameter(new_value))
+
+
 def _equalize(
         model: GraphModule,
         regions: Set[Tuple[str]],
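Note: the new _update_weights helper centralizes the write-back of scaled tensors. For a plain module it rebinds the attribute to a fresh nn.Parameter built from a cloned tensor instead of mutating the old tensor's storage via `.data = ...`; only the brevitas WeightBiasTuple wrapper keeps the in-place `.data` path. A self-contained sketch of the rebinding branch (update_weights here is a stand-in for the helper, omitting the WeightBiasTuple case):

    import torch
    import torch.nn as nn

    def update_weights(module: nn.Module, new_value: torch.Tensor, attr: str = 'weight'):
        # Rebind the attribute to a brand-new Parameter rather than
        # writing into the existing tensor's storage.
        setattr(module, attr, nn.Parameter(new_value))

    linear = nn.Linear(16, 8)
    scale = torch.rand(8, 1) + 0.5
    update_weights(linear, linear.weight.clone() * scale, attr='weight')
    print(linear.weight.shape)  # torch.Size([8, 16])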
@@ -509,7 +526,9 @@ def _is_supported_module(graph_model: GraphModule, node: Node) -> bool:
         if isinstance(module, _supported_layers):
             # We support only self-attention
             if isinstance(module, nn.MultiheadAttention):
-                return all([node.all_input_nodes[0].name == n.name for n in node.all_input_nodes])
+                kwargs = dict(node.kwargs)
+                kwargs.update(zip(module.forward.__code__.co_varnames[1:], node.args))
+                return kwargs['query'].name == kwargs['key'].name == kwargs['value'].name
             return True
     return False
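Note: the old check compared every input node against the first one, which misfires as soon as extra inputs such as attention masks reach the node. The replacement rebuilds the full name-to-argument mapping of forward, covering both positional and keyword arguments, before comparing query, key and value. A sketch of the mapping trick outside of torch.fx (fake_forward is a made-up stand-in for nn.MultiheadAttention.forward):

    def fake_forward(self, query, key, value, attn_mask=None):
        pass

    args = ('x', 'x')                         # query and key passed positionally
    kwargs = {'value': 'x', 'attn_mask': 'm'}

    merged = dict(kwargs)
    # co_varnames begins with 'self', so skip it; zip stops at the shorter
    # sequence, pairing each positional arg with its parameter name.
    merged.update(zip(fake_forward.__code__.co_varnames[1:], args))

    print(merged['query'] == merged['key'] == merged['value'])  # True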

