diff --git a/wenet/transformer/attention.py b/wenet/transformer/attention.py index a6b4bd6b4..c9d8f07b4 100644 --- a/wenet/transformer/attention.py +++ b/wenet/transformer/attention.py @@ -444,7 +444,18 @@ def forward( else: q, k, v = self.forward_qkv(query, key, value) new_cache = (k, v) if not self.training else cache - + # for multi query or multi groups attention + if self.h_kv != self.h and self.h_kv != 1: + k = torch.repeat_interleave( + k, + self.h // self.h_kv, + dim=-3, + ) + v = torch.repeat_interleave( + v, + self.h // self.h_kv, + dim=-3, + ) B = query.size(0) Beams = 1 if B != k.size(0):