fp8 BWD
Enable BWD fp8 with split kernel

Enable BWD fp8 with per-block scale factors for p and ds

This is a combination of 9 commits.

Enable BWD fp8

This is a combination of 12 commits.

add backward test case

save clean up

disable ci

lse is good

dv matches

reduce diff

use do fp8 for dv

kinda working

group size is a constexpr

clean up a bit

everything except mqa/gqa works

skip mqa cases

20 cases have nan on dropout

save what you have

disable tests

failing

enable tests

per block descale_p and descale_ds

use max(abs())

clean up tests a bit more

fix bug

disable ci for now

pass variables

add flags

add alternate path. Still need to load descale factors

dv working

dk works

save

add type info for backward

fix DEBUG flag bug

fix bug with backward. Normal forward works with dropout. Segfault with causal. Varlen has some issues. Might be related to strides.

pass descale strides

test causal

fix causal compiler assert. min head should be 32

remove descale_p

save

explicit name as causal

isolate bad case

just run fp8 tests

bench with autotune

min changes

cast_fp8 helper

cast_varlen_to_fp8

save

minor

highlight failing configs

increase test cases

mark failing

recategorize misc tests

group failing gqa configs

add more tests

add vis code

min ci changes

dump folder

single image per tensor

add tensor comparison

gen varlen tensor

vis varlen tensors

varlen diff

nice varlen vis

vis function

show seqlen in varlen
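The per-block scale factors for p and ds described at the top of this message follow the max(abs(...)) recipe noted in the commit list above ("use max(abs())", "group size is a constexpr"). Below is a minimal host-side sketch of that idea, not the Triton kernel itself; the fp8 dtype, the block size, and the helper name are illustrative assumptions.

```python
import torch

# Assumed fp8 format and block size; the kernel's actual choices may differ.
FP8_DTYPE = torch.float8_e4m3fnuz
FP8_MAX = torch.finfo(FP8_DTYPE).max

def per_block_cast_to_fp8(x: torch.Tensor, block_size: int = 128):
    """Quantize blocks of the last dim to fp8 and return per-block descale factors.

    Assumes the last dim is a multiple of block_size.
    """
    *lead, n = x.shape
    blocks = x.reshape(*lead, n // block_size, block_size)
    amax = blocks.abs().amax(dim=-1, keepdim=True).clamp(min=1e-12)  # max(abs(block))
    scale = FP8_MAX / amax                        # brings each block into fp8 range
    x_fp8 = (blocks * scale).to(FP8_DTYPE).reshape(*lead, n)
    descale = (1.0 / scale).squeeze(-1)           # applied after the fp8 matmul
    return x_fp8, descale
```

In the split backward kernel, factors of this kind would be computed per tile of p and ds so that the fp8 dot products can be descaled back into the working precision.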
micmelesse committed Feb 14, 2025
1 parent 92529cc commit d9de311
Showing 8 changed files with 1,215 additions and 467 deletions.
60 changes: 39 additions & 21 deletions flash_attn/flash_attn_interface.py
@@ -93,8 +93,7 @@ def _flash_attn_forward(
return_softmax: bool,
descale_q: Optional[torch.Tensor] = None,
descale_k: Optional[torch.Tensor] = None,
- descale_v: Optional[torch.Tensor] = None,
- descale_p: Optional[torch.Tensor] = None
+ descale_v: Optional[torch.Tensor] = None
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
q, k, v = [maybe_contiguous(x) for x in (q, k, v)]
out, softmax_lse, S_dmask, rng_state = flash_attn_gpu.fwd(
@@ -113,8 +112,7 @@ def _flash_attn_forward(
None,
descale_q,
descale_k,
- descale_v,
- descale_p
+ descale_v
)
return out, softmax_lse, S_dmask, rng_state

@@ -175,7 +173,6 @@ def _flash_attn_varlen_forward(
descale_q: Optional[torch.Tensor] = None,
descale_k: Optional[torch.Tensor] = None,
descale_v: Optional[torch.Tensor] = None,
- descale_p: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
q, k, v = [maybe_contiguous(x) for x in (q, k, v)]
out, softmax_lse, S_dmask, rng_state = flash_attn_gpu.varlen_fwd(
@@ -202,8 +199,7 @@ def _flash_attn_varlen_forward(
None,
descale_q,
descale_k,
- descale_v,
- descale_p
+ descale_v
)
# if out.isnan().any() or softmax_lse.isnan().any():
# breakpoint()
@@ -273,6 +269,10 @@ def _flash_attn_backward(
alibi_slopes: Optional[torch.Tensor],
deterministic: bool,
rng_state: Optional[torch.Tensor] = None,
+ descale_q: Optional[torch.Tensor] = None,
+ descale_k: Optional[torch.Tensor] = None,
+ descale_v: Optional[torch.Tensor] = None,
+ descale_do: Optional[torch.Tensor] = None
) -> torch.Tensor:
# dq, dk, dv are allocated by us so they should already be contiguous
dout, q, k, v, out = [maybe_contiguous(x) for x in (dout, q, k, v, out)]
@@ -301,6 +301,10 @@ def _flash_attn_backward(
deterministic,
None,
rng_state,
+ descale_q,
+ descale_k,
+ descale_v,
+ descale_do
)
return softmax_d

@@ -369,6 +373,10 @@ def _flash_attn_varlen_backward(
alibi_slopes: Optional[torch.Tensor],
deterministic: bool,
rng_state: Optional[torch.Tensor] = None,
+ descale_q: Optional[torch.Tensor] = None,
+ descale_k: Optional[torch.Tensor] = None,
+ descale_v: Optional[torch.Tensor] = None,
+ descale_do: Optional[torch.Tensor] = None
) -> torch.Tensor:
# dq, dk, dv are allocated by us so they should already be contiguous
dout, q, k, v, out = [maybe_contiguous(x) for x in (dout, q, k, v, out)]
@@ -402,6 +410,10 @@ def _flash_attn_varlen_backward(
deterministic,
None,
rng_state,
+ descale_q,
+ descale_k,
+ descale_v,
+ descale_do
)
# if dk.isnan().any() or dk.isnan().any() or dv.isnan().any() or softmax_d.isnan().any():
# breakpoint()
@@ -823,7 +835,7 @@ def forward(
descale_q,
descale_k,
descale_v,
- descale_p
+ descale_do
):
if softmax_scale is None:
softmax_scale = q.shape[-1] ** (-0.5)
@@ -846,10 +858,9 @@ def forward(
return_softmax=return_softmax and dropout_p > 0,
descale_q=descale_q,
descale_k=descale_k,
- descale_v=descale_v,
- descale_p=descale_p,
+ descale_v=descale_v
)
- ctx.save_for_backward(q, k, v, out_padded, softmax_lse, rng_state)
+ ctx.save_for_backward(q, k, v, out_padded, softmax_lse, rng_state, descale_q, descale_k, descale_v, descale_do)
ctx.dropout_p = dropout_p
ctx.softmax_scale = softmax_scale
ctx.causal = causal
@@ -862,7 +873,7 @@ def forward(

@staticmethod
def backward(ctx, dout, *args):
- q, k, v, out, softmax_lse, rng_state = ctx.saved_tensors
+ q, k, v, out, softmax_lse, rng_state, descale_q, descale_k, descale_v, descale_do = ctx.saved_tensors
dq, dk, dv = torch.empty_like(q), torch.empty_like(k), torch.empty_like(v)
head_size_og = dout.size(3)
dout_padded = dout
@@ -887,6 +898,10 @@ def backward(ctx, dout, *args):
ctx.alibi_slopes,
ctx.deterministic,
rng_state=rng_state,
+ descale_q=descale_q,
+ descale_k=descale_k,
+ descale_v=descale_v,
+ descale_do=descale_do
)
dq = dq[..., : dout.shape[-1]] # We could have padded the head dimension
dk = dk[..., : dout.shape[-1]]
@@ -917,7 +932,7 @@ def forward(
descale_q,
descale_k,
descale_v,
- descale_p
+ descale_do
):
if softmax_scale is None:
softmax_scale = q.shape[-1] ** (-0.5)
@@ -945,11 +960,10 @@ def forward(
block_table=block_table,
descale_q=descale_q,
descale_k=descale_k,
- descale_v=descale_v,
- descale_p=descale_p
+ descale_v=descale_v
)
ctx.save_for_backward(
- q, k, v, out_padded, softmax_lse, cu_seqlens_q, cu_seqlens_k, rng_state
+ q, k, v, out_padded, softmax_lse, cu_seqlens_q, cu_seqlens_k, rng_state, descale_q, descale_k, descale_v, descale_do
)
ctx.dropout_p = dropout_p
ctx.max_seqlen_q = max_seqlen_q
@@ -965,7 +979,7 @@ def forward(

@staticmethod
def backward(ctx, dout, *args):
- q, k, v, out, softmax_lse, cu_seqlens_q, cu_seqlens_k, rng_state = ctx.saved_tensors
+ q, k, v, out, softmax_lse, cu_seqlens_q, cu_seqlens_k, rng_state, descale_q, descale_k, descale_v, descale_do = ctx.saved_tensors
dq, dk, dv = torch.empty_like(q), torch.empty_like(k), torch.empty_like(v)
head_size_og = dout.size(2)
dout_padded = dout
@@ -994,6 +1008,10 @@ def backward(ctx, dout, *args):
ctx.alibi_slopes,
ctx.deterministic,
rng_state=rng_state,
+ descale_q=descale_q,
+ descale_k=descale_k,
+ descale_v=descale_v,
+ descale_do=descale_do
)
dq = dq[..., : dout.shape[-1]] # We could have padded the head dimension
dk = dk[..., : dout.shape[-1]]
@@ -1151,7 +1169,7 @@ def flash_attn_func(
descale_q=None,
descale_k=None,
descale_v=None,
- descale_p=None
+ descale_do=None
):
"""dropout_p should be set to 0.0 during evaluation
Supports multi-query and grouped-query attention (MQA/GQA) by passing in KV with fewer heads
@@ -1216,7 +1234,7 @@ def flash_attn_func(
descale_q,
descale_k,
descale_v,
- descale_p
+ descale_do
)


@@ -1396,7 +1414,7 @@ def flash_attn_varlen_func(
descale_q=None,
descale_k=None,
descale_v=None,
- descale_p=None
+ descale_do=None
):
"""dropout_p should be set to 0.0 during evaluation
Supports multi-query and grouped-query attention (MQA/GQA) by passing in K, V with fewer heads
@@ -1473,7 +1491,7 @@ def flash_attn_varlen_func(
descale_q,
descale_k,
descale_v,
- descale_p
+ descale_do
)


3 changes: 3 additions & 0 deletions flash_attn/flash_attn_triton_amd/README.md
@@ -61,6 +61,9 @@ Inside the docker, it should open to the flash attention repo with everything in
pytest tests/test_flash_attn_triton_amd.py
```

+ ##### FP8
+ In our fork, we have modified the API to work with FP8. You provide tensors that are scaled into FP8 range along with their associated descale factors.
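For example, a minimal usage sketch under assumed per-tensor scaling and scalar descale tensors (the fp8 dtype, the shapes, and the to_fp8 helper below are illustrative assumptions, not requirements of the fork):

```python
import torch
from flash_attn import flash_attn_func

FP8_DTYPE = torch.float8_e4m3fnuz  # assumed fp8 format

def to_fp8(t: torch.Tensor):
    """Scale a tensor into fp8 range and return it with its descale factor."""
    amax = t.abs().amax().clamp(min=1e-12)
    scale = torch.finfo(FP8_DTYPE).max / amax
    return (t * scale).to(FP8_DTYPE), (1.0 / scale).float().reshape(1)

q = torch.randn(2, 1024, 8, 64, device="cuda", dtype=torch.float16)
k = torch.randn_like(q)
v = torch.randn_like(q)

q_fp8, descale_q = to_fp8(q)
k_fp8, descale_k = to_fp8(k)
v_fp8, descale_v = to_fp8(v)

# Passing q/k/v already cast to fp8 is an assumption here; the exact input
# dtype expected by the fork is not shown in this diff.
out = flash_attn_func(
    q_fp8, k_fp8, v_fp8,
    causal=True,
    descale_q=descale_q,
    descale_k=descale_k,
    descale_v=descale_v,
)
```

descale_do plays the same role for dout when the backward pass runs.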

##### Credits
AMD Triton kernels team
