🐛 Bug

Input tensors to attention must be in the format [B, M, H, K], where B is the batch size, M the sequence length, H the number of heads, and K the embedding size per head, as documented here.

Hence positional embeddings (e.g., rotary embeddings) should be applied along dim=1. However, in the RotaryEmbedding class, dim=-2 is passed, which corresponds to dim=2, as seen here:
```python
def forward(
    self, q: torch.Tensor, k: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor]:
    self._cos_cached, self._sin_cached = self._update_cos_sin_tables(
        k, seq_dimension=-2  # should be seq_dimension=1, or no argument, since the default value is correct
    )
    return (
        apply_rotary_pos_emb(q, self._cos_cached, self._sin_cached),
        apply_rotary_pos_emb(k, self._cos_cached, self._sin_cached),
    )
```
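To make the dimension mix-up concrete, here is a minimal, self-contained sketch (not the xformers API; `rotary_tables` is a hypothetical helper) of how rotary cos/sin tables are derived from the axis selected by `seq_dimension`. For a [B, M, H, K] tensor, `seq_dimension=-2` selects H rather than M, so the tables end up sized by the number of heads instead of the sequence length:

```python
import torch

# Hypothetical helper, not the xformers implementation: build rotary
# cos/sin tables from one axis of the input tensor.
def rotary_tables(x: torch.Tensor, seq_dimension: int = 1):
    seq_len = x.shape[seq_dimension]  # M if seq_dimension=1, H if -2
    dim = x.shape[-1]                 # K, embedding size per head
    inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2).float() / dim))
    t = torch.arange(seq_len, dtype=inv_freq.dtype)
    freqs = torch.einsum("i,j->ij", t, inv_freq)  # [seq_len, K/2]
    emb = torch.cat((freqs, freqs), dim=-1)       # [seq_len, K]
    # Shaped to broadcast over [B, M, H, K]: positions vary along dim=1.
    return emb.cos()[None, :, None, :], emb.sin()[None, :, None, :]

q = torch.randn(2, 16, 4, 64)  # [B=2, M=16, H=4, K=64]

cos, sin = rotary_tables(q, seq_dimension=1)
assert cos.shape == (1, 16, 1, 64)  # table length is M, as intended

cos_bad, sin_bad = rotary_tables(q, seq_dimension=-2)
assert cos_bad.shape == (1, 4, 1, 64)  # wrong: table length is H, not M
```

Under these assumptions, the buggy argument silently indexes positions by head rather than by token, which is consistent with the downstream symptoms reported below.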
Additional context
Thanks to @jmercat who found symptoms of this problem downstream of xformers!