Skip to content

Fixes bug in SelectiveScanFn.forward for when B is not variable and last_state is returned #371

New issue

Have a question about this project? # for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “#”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? # to your account

Open
wants to merge 1 commit into
base: main
Choose a base branch
from

Conversation

vidavakil
Copy link

test_selective_scan() fails when is_variable_B is False.

Turns out selective_scan_fwd_kernel incorporates an optimization of not multiplying the state by B if B is not variable. This does not impact MambaInnerFn, because MambaInnerFn never returns the state. But SelectiveScanFn may need to return the last_state. The changes to the code fix this problem, by multiplying the last_state by B before returning it when B is not variable.

…function has to return

the last_state. # The cuda kernel does a peculiar optimization of not multiplying the state
by B if B is not variable! This does not impact MambaInnerFn, because it never returns the
state. But SelectiveScanFn may needd to return the last state! Hence the following is needed.
# for free to join this conversation on GitHub. Already have an account? # to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

1 participant