Feat (nn): set dim names in QuantMHA Linear (#629)
* Feat (nn): set dim names in QuantMHA Linear

* remove dims afterwards

* Fix dim

* Remove no op view

* Remove naming before linear ops

* Correct handling of QuantTensor. No named tensors during tracing/export

---------

Co-authored-by: Giuseppe Franco <giuseppefranco4@gmail.com>
volcacius and Giuseppe5 authored Jul 4, 2023
1 parent 16aa4ae commit e99ae57
Showing 2 changed files with 30 additions and 5 deletions.
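Editor's note: the idea behind the change, in isolation, is that dimensions are marked with PyTorch named tensors right before the quantized projections, and the names are dropped again afterwards, with the whole thing skipped under tracing/export since TorchScript tracing does not support named tensors. Below is a minimal sketch of that pattern, using a plain nn.Linear as a stand-in for the quantized projection, not the Brevitas layer itself.

import torch
import torch.nn as nn

# Stand-in for the quantized q/k/v projection; in Brevitas this is a QuantLinear
# inside QuantMultiheadAttention.
embed_dim = 8
q_proj = nn.Linear(embed_dim, embed_dim)

# (L, N, E) = (sequence length, batch size, embedding dimension)
query = torch.randn(4, 2, embed_dim)

# Mark dimensions through named tensors, but only outside of tracing/export,
# mirroring the guard used in the commit.
if not torch._C._get_tracing_state():
    query.rename_('L', 'N', 'E')

q = q_proj(query)

# Remove the names again so downstream ops without named-tensor support
# do not raise errors.
if not torch._C._get_tracing_state():
    q = q.rename(None)
    query.rename_(None)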
6 changes: 4 additions & 2 deletions src/brevitas/nn/mixin/base.py
@@ -166,13 +166,15 @@ def unpack_input(self, inp: Union[Tensor, QuantTensor]):
if not self.training and not self._export_mode and self.cache_inference_quant_inp:
cached_inp = _CachedIO(inp.detach(), self.cache_quant_io_metadata_only)
self._cached_inp = cached_inp
return inp
else:
inp = QuantTensor(inp, training=self.training)
if not self.training and self.cache_inference_quant_inp:
cached_inp = _CachedIO(inp.detach(), self.cache_quant_io_metadata_only)
self._cached_inp = cached_inp
return inp
# Remove any naming metadata to avoid downstream errors
if not torch._C._get_tracing_state():
inp.value.rename_(None)
return inp

def pack_output(self, quant_output: QuantTensor):
if not self.training and self.cache_inference_quant_out:
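Editor's note on the base.py change: unpack_input now strips any dim names from the incoming value before returning it, so the naming added inside QuantMHA never leaks named tensors into code that cannot handle them. A tiny standalone illustration of what rename_(None) does, not Brevitas code:

import torch

inp = torch.randn(4, 2, 8, names=('L', 'N', 'E'))
print(inp.names)   # ('L', 'N', 'E')

# Same guard as in unpack_input: outside of tracing, drop the naming metadata in place.
if not torch._C._get_tracing_state():
    inp.rename_(None)
print(inp.names)   # (None, None, None)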
29 changes: 26 additions & 3 deletions src/brevitas/nn/quant_mha.py
@@ -403,8 +403,15 @@ def multi_head_attention(
#
# compute in-projection
#

if self.in_proj is not None:
if check_tensors_same_ptr([key, query, value]):
# Mark dimensions through named tensors.
if not torch._C._get_tracing_state():
if isinstance(query, QuantTensor):
query.value.rename_('L', 'N', 'E')
else:
query.rename_('L', 'N', 'E')
# self-attention
q, k, v = self.in_proj(query).chunk(3, dim=-1)
else:
@@ -415,7 +422,15 @@
assert self.q_proj is not None, "use_separate_proj_weight is True but q_proj is None"
assert self.k_proj is not None, "use_separate_proj_weight is True but k_proj is None"
assert self.v_proj is not None, "use_separate_proj_weight is True but v_proj is None"
# Mark dimensions through named tensors.
if not torch._C._get_tracing_state():
for t in [query, key, value]:
t.rename_('L', 'N', 'E')
q, k, v = self.q_proj(query), self.k_proj(key), self.v_proj(value)
# Remove names to avoid errors downstream
if not torch._C._get_tracing_state():
for t in [q, k, v]:
t.rename_(None)

# prep attention mask
if attn_mask is not None:
@@ -546,10 +561,18 @@ def multi_head_attention(
v = self.v_quant(v)

attn_output = torch.bmm(attn_output_weights, v)
attn_output = attn_output.transpose(0, 1).contiguous().view(tgt_len * bsz, embed_dim)

# Preserve the 3D input compared to the float version to be able to do row-wise scaling
attn_output = attn_output.transpose(0, 1).contiguous().view(tgt_len, bsz, embed_dim)
# Set dim names for PTQ algorithms that require it
if not torch._C._get_tracing_state():
attn_output.rename_('L', 'N', 'E')
attn_output = self.out_proj(attn_output)
attn_output = attn_output.view(tgt_len, bsz, attn_output.size(1))
# Remove names to avoid errors in unsupported downstream ops
if not torch._C._get_tracing_state():
if isinstance(attn_output, QuantTensor):
attn_output.value.rename_(None)
else:
attn_output.rename_(None)

if need_weights:
# optionally average attention weights over heads
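Editor's note: the comment "Set dim names for PTQ algorithms that require it" is the motivation for keeping the 3D (L, N, E) shape and the names on the input of out_proj. As a purely hypothetical illustration of how calibration code could consume those names (this is not the Brevitas implementation; the hook and attribute names are made up):

import torch
import torch.nn as nn

# Hypothetical calibration pre-hook: it reads the dim names to locate the
# embedding dim 'E' and collects a per-row absolute maximum over the rest.
def calibration_pre_hook(module, inputs):
    x = inputs[0]
    if any(name is not None for name in x.names):
        emb_dim = x.names.index('E')
        reduce_dims = tuple(d for d in range(x.dim()) if d != emb_dim)
        # Drop the names before reducing, as in the commit's own cleanup step.
        module._input_absmax = x.rename(None).abs().amax(dim=reduce_dims)

out_proj = nn.Linear(8, 8)
out_proj.register_forward_pre_hook(calibration_pre_hook)

x = torch.randn(4, 2, 8).rename_('L', 'N', 'E')
_ = out_proj(x)
print(out_proj._input_absmax.shape)  # torch.Size([8])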
