feat: support cached cos/sin in rope APIs #585

Merged (4 commits) on Nov 5, 2024

@yzh119 yzh119 commented Nov 5, 2024

As requested in #530, this PR implements RoPE with cached cos/sin embeddings, which is more flexible in some use cases.

In our previous RoPE implementations, cos/sin values are computed on-the-fly inside the kernels in float32 instead of being read from cached values.
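To make the difference between the two approaches concrete, here is a minimal NumPy sketch, not the FlashInfer kernels themselves. The helper names (`build_cos_sin_cache`, `rope_on_the_fly`, `rope_cached`), the non-interleaved (rotate-half) layout, and `theta=1e4` are all assumptions chosen for illustration; the actual API and cache layout in this PR may differ.

```python
import numpy as np

def _rotate(x, cos, sin):
    # Non-interleaved layout (an assumption): pair element i with element i + half.
    half = x.shape[-1] // 2
    x1, x2 = x[..., :half], x[..., half:]
    return np.concatenate([x1 * cos - x2 * sin, x2 * cos + x1 * sin], axis=-1)

def _angles(positions, rotary_dim, theta):
    # Rotation angle for each (position, frequency) pair, computed in float32.
    inv_freq = 1.0 / theta ** (np.arange(0, rotary_dim, 2, dtype=np.float32) / rotary_dim)
    return positions.astype(np.float32)[:, None] * inv_freq[None, :]

def rope_on_the_fly(x, positions, theta=1e4):
    # Recompute cos/sin on every call, as the earlier implementations do in-kernel.
    a = _angles(positions, x.shape[-1], theta)
    return _rotate(x, np.cos(a), np.sin(a))

def build_cos_sin_cache(max_pos, rotary_dim, theta=1e4):
    # Hypothetical helper: float32 tables of shape (max_pos, rotary_dim // 2).
    a = _angles(np.arange(max_pos), rotary_dim, theta)
    return np.cos(a), np.sin(a)

def rope_cached(x, positions, cos_cache, sin_cache):
    # Look up precomputed rows instead of evaluating cos/sin again.
    return _rotate(x, cos_cache[positions], sin_cache[positions])

rng = np.random.default_rng(0)
x = rng.standard_normal((4, 64)).astype(np.float32)
pos = np.array([0, 1, 7, 100])
cos_cache, sin_cache = build_cos_sin_cache(128, 64)
y_fly = rope_on_the_fly(x, pos)
y_cached = rope_cached(x, pos, cos_cache, sin_cache)
```

With a float32 cache built from the same formula, the two paths should agree up to float32 rounding. The cached variant also lets the caller supply tables produced some other way, which is presumably the flexibility the request was after.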

While working on this PR, we found that with an f16 cos/sin cache the RoPE result deviates substantially from our original implementation, flashinfer.apply_rope (which computes cos/sin in fp32). We therefore require cos_cache and sin_cache to use the fp32 data type.
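The effect of the cache data type can be illustrated in isolation with a small NumPy experiment. This is a sketch under assumptions (rotate-half layout, `theta=1e4`, random inputs, a float64 reference, and a hypothetical helper `max_error_for_cache_dtype`); it does not reproduce the PR's kernels or the exact discrepancy observed there, only the rounding error introduced by storing cos/sin in half precision.

```python
import numpy as np

def _rotate(x, cos, sin):
    # Non-interleaved layout (an assumption): pair element i with element i + half.
    half = x.shape[-1] // 2
    x1, x2 = x[..., :half], x[..., half:]
    return np.concatenate([x1 * cos - x2 * sin, x2 * cos + x1 * sin], axis=-1)

def max_error_for_cache_dtype(cache_dtype, n_pos=4096, head_dim=128, theta=1e4, seed=0):
    # Hypothetical helper: worst-case absolute error against a float64 reference
    # when the cos/sin tables are stored in `cache_dtype`.
    rng = np.random.default_rng(seed)
    x = rng.standard_normal((n_pos, head_dim)).astype(np.float32)
    inv_freq = 1.0 / theta ** (np.arange(0, head_dim, 2, dtype=np.float64) / head_dim)
    angles = np.arange(n_pos, dtype=np.float64)[:, None] * inv_freq[None, :]
    cos64, sin64 = np.cos(angles), np.sin(angles)
    ref = _rotate(x.astype(np.float64), cos64, sin64)
    # Round the tables to the cache dtype, then do the arithmetic in float32.
    cos = cos64.astype(cache_dtype).astype(np.float32)
    sin = sin64.astype(cache_dtype).astype(np.float32)
    out = _rotate(x, cos, sin)
    return float(np.max(np.abs(out - ref)))

err_fp16 = max_error_for_cache_dtype(np.float16)
err_fp32 = max_error_for_cache_dtype(np.float32)
print(f"max abs error, fp16 cache: {err_fp16:.3e}")
print(f"max abs error, fp32 cache: {err_fp32:.3e}")
```

Half precision keeps only about 11 significant bits, so each stored cos/sin value carries a relative rounding error of up to roughly 2^-11, and that error is multiplied into every rotated element. A float32 cache avoids this, which is consistent with the requirement stated above.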

cc @dreaming-panda @ByronHsu

@yzh119 yzh119 merged commit 83e541d into main Nov 5, 2024
@yzh119 yzh119 mentioned this pull request Nov 10, 2024
@yzh119 yzh119 deleted the rope-cached-sin-cos branch November 10, 2024 08:47