Skip to content
New issue

Have a question about this project? # for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “#”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? # to your account

Add an isolated implementation of FlashDiffAttention #1633

Open
wants to merge 1 commit into
base: master
Choose a base branch
from

Conversation

zhuzilin
Copy link

@zhuzilin zhuzilin commented Oct 9, 2024

This PR is trying to implement a FlashDiffAttention class similar to the FlashSelfAttention in the origin flash attention repo (https://github.com/Dao-AILab/flash-attention/blob/main/flash_attn/modules/mha.py#L53), so that training frameworks could easily add diff transformer support with and without varlen support.

The main idea is to set the num_head in the training process twice as the origin transformer so that we no longer need to change the code relates to RoPE.

A simple test script for the code is:

from dataclasses import dataclass

import torch
import torch.distributed as dist
from flash_attn.layers.rotary import RotaryEmbedding
from einops import rearrange

from multihead_flashdiff_2 import MultiheadFlashDiff2
from flashdiff import FlashDiffAttention
from kernel.rotary import apply_rotary_emb


@dataclass
class Args:
    model_parallel_size: int
    decoder_kv_attention_heads: int


def create_new_impl(origin_impl, head_dim, depth):
    diff_attn_func = FlashDiffAttention(
        head_dim=embed_dim // num_new_heads, depth=depth, causal=True
    ).to(device, dtype=dtype)
    # make the initialization the same
    diff_attn_func.lambda_q1.data.copy_(origin_impl.lambda_q1.data)
    diff_attn_func.lambda_k1.data.copy_(origin_impl.lambda_k1.data)
    diff_attn_func.lambda_q2.data.copy_(origin_impl.lambda_q2.data)
    diff_attn_func.lambda_k2.data.copy_(origin_impl.lambda_k2.data)
    #diff_attn_func.subln.weight.data.copy_(origin_impl.subln.weight.data)
    
    def new_impl(x, rel_pos):
        bsz, tgt_len, embed_dim = x.size()
        src_len = tgt_len

        q = origin_impl.q_proj(x)
        k = origin_impl.k_proj(x)
        v = origin_impl.v_proj(x)

        # here we no longer need "// 2"
        num_heads = embed_dim // head_dim
        num_kv_heads = k.shape[-1] // head_dim

        q = q.view(bsz, tgt_len, num_heads, head_dim)
        k = k.view(bsz, src_len, num_kv_heads, head_dim)
        v = v.view(bsz, src_len, num_kv_heads, head_dim)

        q = apply_rotary_emb(q, *rel_pos, interleaved=True)
        k = apply_rotary_emb(k, *rel_pos, interleaved=True)

        output = diff_attn_func(q, k, v)
        output = rearrange(output, '... H D -> ... (H D)')

        output = origin_impl.out_proj(output)
        return output
    
    return new_impl


if __name__ == "__main__":
    dist.init_process_group(backend="nccl")
    device = torch.device("cuda")
    dtype = torch.bfloat16
    args = Args(model_parallel_size=1, decoder_kv_attention_heads=4)
    batch_size = 2
    num_heads = 16
    seq_len = 512
    embed_dim = 2048
    depth = 12
    # in the new implementation, the num_heads should be twice the original num_heads
    num_new_heads = num_heads * 2
    head_dim = embed_dim // num_new_heads

    print("initializing modules")
    origin_impl = MultiheadFlashDiff2(args, embed_dim=embed_dim, depth=depth, num_heads=num_heads).to(device, dtype=dtype)
    new_impl = create_new_impl(origin_impl, head_dim, depth)

    print("creating test data")
    rotary_emb = RotaryEmbedding(
        head_dim,
        base=10000.0,
        interleaved=True,
        device=device,
    )
    rotary_emb._update_cos_sin_cache(seq_len, device=device, dtype=torch.bfloat16)
    rel_pos = (rotary_emb._cos_cached, rotary_emb._sin_cached)
    hidden_states = torch.randn((batch_size, seq_len, embed_dim), device=device, dtype=dtype)

    print("run origin forward")
    origin_output = origin_impl(hidden_states, rel_pos)

    print("run new forward")
    new_output = new_impl(hidden_states, rel_pos)

    assert torch.allclose(origin_output, new_output, atol=1e-6)

Thank you for your time on reviewing this PR.

@zhuzilin zhuzilin force-pushed the feature/flash_diff_attn branch 2 times, most recently from 200479b to f9f35a8 Compare October 9, 2024 06:29
@zhuzilin zhuzilin changed the title [WIP] Add a isolated implementation of FlashDiffAttention Add an isolated implementation of FlashDiffAttention Oct 9, 2024
@zhuzilin zhuzilin force-pushed the feature/flash_diff_attn branch from f9f35a8 to c6e6486 Compare October 9, 2024 06:36
@MarktHart
Copy link

You could go even closer to attention and use it as is with a doubled interleave. E.g.

def alternative_forward(
        self,
        x,
        rel_pos,
        attn_mask=None,
    ):
    bsz, tgt_len, embed_dim = x.size()
    src_len = tgt_len

    q = self.q_proj(x)
    k = self.k_proj(x)
    v = self.v_proj(x)

    q = q.view(bsz, tgt_len, 2 * self.num_heads, self.head_dim)
    k = k.view(bsz, src_len, 2 * self.num_kv_heads, self.head_dim)
    v = v.view(bsz, src_len, self.num_kv_heads, 2 * self.head_dim)

    q = apply_rotary_emb(q, *rel_pos, interleaved=True)
    k = apply_rotary_emb(k, *rel_pos, interleaved=True)

    q = q.transpose(1, 2)
    
    k = torch.repeat_interleave(k.transpose(1, 2), dim=1, repeats=self.n_rep)
    v = torch.repeat_interleave(v.transpose(1, 2), dim=1, repeats=self.n_rep * 2)
    if attn_mask is None:
        attn_mask = torch.triu(
            torch.zeros([tgt_len, src_len])
            .float()
            .fill_(float("-inf"))
            .type_as(q),
            1 + src_len - tgt_len,
        )

    lambda_1 = torch.exp(torch.sum(self.lambda_q1 * self.lambda_k1, dim=-1).float()).type_as(q)
    lambda_2 = torch.exp(torch.sum(self.lambda_q2 * self.lambda_k2, dim=-1).float()).type_as(q)
    lambda_full = lambda_1 - lambda_2 + self.lambda_init

    attn_weights = F.scaled_dot_product_attention(query=q, key=k, value=v, attn_mask=attn_mask, scale=self.scaling)
    every_other_mask = torch.arange(attn_weights.size(1)) % 2 == 0
    attn = attn_weights[:, every_other_mask, :, :] - lambda_full * attn_weights[:, ~every_other_mask, :, :]

    attn = self.subln(attn)
    attn = attn * (1 - self.lambda_init)
    attn = attn.transpose(1, 2).reshape(bsz, tgt_len, self.num_heads * 2 * self.head_dim)

    attn = self.out_proj(attn)
    return attn

# for free to join this conversation on GitHub. Already have an account? # to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants