RotaryEmbedding applied to the incorrect channel dimension #841

Open · sagadre opened this issue Aug 31, 2023 · 0 comments

sagadre commented Aug 31, 2023

🐛 Bug

Input tensors to attention must be in the format [B, M, H, K], where B is the batch size, M the sequence length, H the number of heads, and K the embedding size per head, as documented here.

Hence, positional embeddings (e.g., rotary embeddings) should be applied along dim=1, the sequence dimension. However, in the RotaryEmbedding class, seq_dimension=-2 is being passed, which corresponds to dim=2 (the head dimension), as seen here:

def forward(
    self, q: torch.Tensor, k: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor]:
    self._cos_cached, self._sin_cached = self._update_cos_sin_tables(
        k, seq_dimension=-2  # should be seq_dimension=1, or no argument, since the default is correct
    )

    return (
        apply_rotary_pos_emb(q, self._cos_cached, self._sin_cached),
        apply_rotary_pos_emb(k, self._cos_cached, self._sin_cached),
    )
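
To make the mismatch concrete, here is a minimal, self-contained sketch (not the xformers implementation; make_cos_table and the shapes below are hypothetical, for illustration only). A rotary cache is sized by the length of whatever axis is treated as the sequence, so building it with seq_dimension=-2 on a [B, M, H, K] tensor produces a table sized by the number of heads H rather than the sequence length M.

import torch

B, M, H, K = 2, 16, 8, 64  # batch, sequence length, heads, embedding size per head

def make_cos_table(x: torch.Tensor, seq_dimension: int) -> torch.Tensor:
    # The cos cache is indexed by positions along the axis chosen as "sequence".
    seq_len = x.shape[seq_dimension]
    inv_freq = 1.0 / (10000 ** (torch.arange(0, K, 2).float() / K))
    t = torch.arange(seq_len).float()
    freqs = torch.einsum("i,j->ij", t, inv_freq)
    return torch.cat((freqs, freqs), dim=-1).cos()

k = torch.randn(B, M, H, K)

cos_wrong = make_cos_table(k, seq_dimension=-2)  # sized by H (heads)
cos_right = make_cos_table(k, seq_dimension=1)   # sized by M (sequence)

print(cos_wrong.shape)  # torch.Size([8, 64])
print(cos_right.shape)  # torch.Size([16, 64])

In other words, with seq_dimension=-2 the rotation varies per head instead of per token position.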

Additional context

Thanks to @jmercat, who found symptoms of this problem downstream of xformers!
