Skip to content

[Feature] Batch Decoding #477

New issue

Have a question about this project? # for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “#”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? # to your account

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open

Conversation

AnaRhisT94
Copy link

@AnaRhisT94 AnaRhisT94 commented Jul 18, 2024

Hi! @tridao

Thanks to @agiwave for the base code to get me started. (Implementation on CPU which isn't fully working)

I present an implementation on GPU which is working below.

There are 3 scripts:

  1. mamba_simple.py - which has all the logic of decoding N tokens.
  2. test_mamba_ssm_state.py - which shows that batch decoding works.
  3. time_measure.py - which measures the timing between batching decoding N=1 and N>1 (You can see the speedup is 10x atleast)

# for free to join this conversation on GitHub. Already have an account? # to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

1 participant