Refactor mamba to integrate torch.compile, reference invocation, accuracy errors #1

Open · wants to merge 1 commit into base: main
35 changes: 35 additions & 0 deletions mamba_run.py
@@ -0,0 +1,35 @@
import torch

from mamba_lm import MambaLM, MambaLMConfig

import time

def nano_time(func, inps, *, iterations=100):
    start_time = time.time_ns()

    for _ in range(iterations):
        func(inps)

    end_time = time.time_ns()

    total_time_ns = end_time - start_time
    return total_time_ns

config = MambaLMConfig(d_model=16, n_layers=4, vocab_size=32000)
model = MambaLM(config)

# Changing backend to inductor causes accuracy errors!
compiled_model = torch.compile(backend="eager", fullgraph=True)(model)

x = torch.randint(high=32000, size=(16, 64))
ref = model(x)
# First run (cheat by preheating, lol)
logits = compiled_model(x)  # (B, L, vocab_size)
assert torch.equal(logits, ref)

eager_time = nano_time(model, x)
print("Eager time:", eager_time)
compile_time = nano_time(compiled_model, x)
print("Compiled time:", compile_time)

print("Speedup or slowdown?", eager_time / compile_time)
105 changes: 53 additions & 52 deletions pscan.py
@@ -13,53 +13,54 @@

 # TODO avoid the .flip() calls by writing a reverse pscan (with a flag)
 # TODO move comments into docstrings
+# TODO(voz): This was moved out of the autograd.Function due to torch.compile being weird
+# about user defined functions and methods grafted onto autograd.Function.
+# We should just teach torch.compile to treat methods here properly via inlining.
+def pscan_inner(A, X):
+    # A : (B, D, L, N)
+    # X : (B, D, L, N)
 
-class PScan(torch.autograd.Function):
-    @staticmethod
-    def pscan(A, X):
-        # A : (B, D, L, N)
-        # X : (B, D, L, N)
-
-        # modifies X in place by doing a parallel scan.
-        # more formally, X will be populated by these values :
-        # H[t] = A[t] * H[t-1] + X[t] with H[0] = 0
-        # which are computed in parallel (2*log2(T) sequential steps (ideally), instead of T sequential steps)
-
-        B, D, L, _ = A.size()
-        num_steps = int(math.log2(L))
-
-        # up sweep or reduction step
-        Aa = A
-        Xa = X
-        for k in range(num_steps):
-            T = 2 * (Xa.size(2) // 2)
-
-            Aa = Aa[:, :, :T].view(B, D, T//2, 2, -1)
-            Xa = Xa[:, :, :T].view(B, D, T//2, 2, -1)
-
-            Xa[:, :, :, 1].add_(Aa[:, :, :, 1].mul(Xa[:, :, :, 0]))
-            Aa[:, :, :, 1].mul_(Aa[:, :, :, 0])
-
-            Aa = Aa[:, :, :, 1]
-            Xa = Xa[:, :, :, 1]
-
-        # down sweep
-        for k in range(num_steps-1, -1, -1):
-            Aa = A[:, :, 2**k-1:L:2**k]
-            Xa = X[:, :, 2**k-1:L:2**k]
-
-            T = 2 * (Xa.size(2) // 2)
-
-            if T < Xa.size(2):
-                Xa[:, :, -1].add_(Aa[:, :, -1].mul(Xa[:, :, -2]))
-                Aa[:, :, -1].mul_(Aa[:, :, -2])
-
-            Aa = Aa[:, :, :T].view(B, D, T//2, 2, -1)
-            Xa = Xa[:, :, :T].view(B, D, T//2, 2, -1)
-
-            Xa[:, :, 1:, 0].add_(Aa[:, :, 1:, 0].mul(Xa[:, :, :-1, 1]))
-            Aa[:, :, 1:, 0].mul_(Aa[:, :, :-1, 1])
+    # modifies X in place by doing a parallel scan.
+    # more formally, X will be populated by these values :
+    # H[t] = A[t] * H[t-1] + X[t] with H[0] = 0
+    # which are computed in parallel (2*log2(T) sequential steps (ideally), instead of T sequential steps)
+
+    B, D, L, _ = A.size()
+    num_steps = int(math.log2(L))
+
+    # up sweep or reduction step
+    Aa = A
+    Xa = X
+    for k in range(num_steps):
+        T = 2 * (Xa.size(2) // 2)
+
+        Aa = Aa[:, :, :T].view(B, D, T//2, 2, -1)
+        Xa = Xa[:, :, :T].view(B, D, T//2, 2, -1)
+
+        Xa[:, :, :, 1].add_(Aa[:, :, :, 1].mul(Xa[:, :, :, 0]))
+        Aa[:, :, :, 1].mul_(Aa[:, :, :, 0])
+
+        Aa = Aa[:, :, :, 1]
+        Xa = Xa[:, :, :, 1]
+
+    # down sweep
+    for k in range(num_steps-1, -1, -1):
+        Aa = A[:, :, 2**k-1:L:2**k]
+        Xa = X[:, :, 2**k-1:L:2**k]
+
+        T = 2 * (Xa.size(2) // 2)
+
+        if T < Xa.size(2):
+            Xa[:, :, -1].add_(Aa[:, :, -1].mul(Xa[:, :, -2]))
+            Aa[:, :, -1].mul_(Aa[:, :, -2])
+
+        Aa = Aa[:, :, :T].view(B, D, T//2, 2, -1)
+        Xa = Xa[:, :, :T].view(B, D, T//2, 2, -1)
+
+        Xa[:, :, 1:, 0].add_(Aa[:, :, 1:, 0].mul(Xa[:, :, :-1, 1]))
+        Aa[:, :, 1:, 0].mul_(Aa[:, :, :-1, 1])
 
+class PScan(torch.autograd.Function):
     @staticmethod
     def forward(ctx, A_in, X_in):
         """
@@ -76,18 +77,18 @@ def forward(ctx, A_in, X_in):
         # clone tensor (in-place ops)
         A = A_in.clone() # (B, L, D, N)
         X = X_in.clone() # (B, L, D, N)
 
         # prepare tensors
         A = A.transpose(2, 1) # (B, D, L, N)
         X = X.transpose(2, 1) # (B, D, L, N)
 
         # parallel scan
-        PScan.pscan(A, X)
+        pscan_inner(A, X)
 
         ctx.save_for_backward(A_in, X)
 
         return X.transpose(2, 1)
 
     @staticmethod
     def backward(ctx, grad_output_in):
         """
@@ -103,7 +104,7 @@ def backward(ctx, grad_output_in):

         A_in, X = ctx.saved_tensors
 
-        # clone tensors
+        # clone tensors
         A = A_in.clone()
         # grad_output_in will be cloned with flip()

@@ -114,12 +115,12 @@

         # reverse parallel scan
         grad_output_b = grad_output_b.flip(2)
-        PScan.pscan(A, grad_output_b)
+        pscan_inner(A, grad_output_b)
         grad_output_b = grad_output_b.flip(2)
 
         Q = torch.zeros_like(X)
         Q[:, :, 1:].add_(X[:, :, :-1] * grad_output_b[:, :, 1:])
 
         return Q.transpose(2, 1), grad_output_b.transpose(2, 1)
-pscan = PScan.apply
 
+pscan = PScan.apply
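
For reference, the recurrence that pscan_inner computes in place is H[t] = A[t] * H[t-1] + X[t] along the L dimension of (B, D, L, N) tensors, with the initial hidden state taken as zero (so H[0] = X[0]). A naive sequential version is a handy spot-check for a refactor like this one; the sketch below is not part of the PR, and sequential_scan is a hypothetical helper name introduced here for illustration.

import torch

def sequential_scan(A, X):
    # Reference for H[t] = A[t] * H[t-1] + X[t] with zero initial state.
    # A, X : (B, D, L, N); returns H : (B, D, L, N) and leaves the inputs untouched.
    H = torch.empty_like(X)
    h = torch.zeros_like(X[:, :, 0])   # running state, shape (B, D, N)
    for t in range(X.size(2)):
        h = A[:, :, t] * h + X[:, :, t]
        H[:, :, t] = h
    return H

# Example spot-check against the in-place parallel scan (a power-of-two L such as 8
# keeps the comparison simple):
# A, X = torch.rand(2, 3, 8, 4), torch.rand(2, 3, 8, 4)
# expected = sequential_scan(A, X)
# Ac, Xc = A.clone(), X.clone()
# pscan_inner(Ac, Xc)                      # mutates Ac and Xc in place
# torch.testing.assert_close(Xc, expected)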