diff --git a/attn_gym/utils.py b/attn_gym/utils.py index 69cae71..8b6dc97 100644 --- a/attn_gym/utils.py +++ b/attn_gym/utils.py @@ -100,9 +100,9 @@ def visualize_attention_scores( Returns: None """ - assert ( - score_mod is not None or mask_mod is not None - ), "Must provide either score_mod or mask_mod" + assert score_mod is not None or mask_mod is not None, ( + "Must provide either score_mod or mask_mod" + ) query = query[batch_idx, head_idx, :, :] key = key[batch_idx, head_idx, :, :] scores_viz = create_score_mod( diff --git a/examples/benchmark.py b/examples/benchmark.py index 50debe2..690d524 100644 --- a/examples/benchmark.py +++ b/examples/benchmark.py @@ -26,18 +26,19 @@ from attn_gym.mods import generate_alibi_bias, generate_tanh_softcap +def _causal_score(score, b, h, q_idx, kv_idx): + return causal_mask(b, h, q_idx, kv_idx).where(score, torch.finfo(score.dtype).min) + + AVAILABLE_EXAMPLES = { "causal": lambda: test_mask(mask_mod=causal_mask), - "alibi": lambda: test_mask(score_mod=generate_alibi_bias(16), skip_correctness=True), + "causal_score": lambda: test_mask(score_mod=_causal_score), + "alibi": lambda: test_mask(score_mod=generate_alibi_bias(16), skip_correctness=False), "sliding_window": lambda: test_mask(mask_mod=generate_sliding_window(window_size=1024)), "prefix_lm": lambda: test_mask(mask_mod=generate_prefix_lm_mask(prefix_length=1024)), "document": lambda: run_document_masking(max_seq_len=32768, num_docs=12), - "softcap": lambda: test_mask( - score_mod=generate_tanh_softcap(30, approx=False), skip_correctness=True - ), - "softcap_approx": lambda: test_mask( - score_mod=generate_tanh_softcap(30, approx=True), skip_correctness=True - ), + "softcap": lambda: test_mask(score_mod=generate_tanh_softcap(30, approx=False)), + "softcap_approx": lambda: test_mask(score_mod=generate_tanh_softcap(30, approx=True)), } @@ -91,8 +92,15 @@ def test_mask( block_mask = create_block_mask_cached(mask_mod, 1, 1, S, S, device=device) else: block_mask = None - sdpa_mask_fn = mask_mod if mask_mod is not None else score_mod - mask = create_mask(sdpa_mask_fn, 1, 1, S, S, device=device) + mask = create_mask(mask_mod, 1, H, S, S, device=device) if mask_mod else None + bias = create_mask(score_mod, 1, H, S, S, device=device) if score_mod else None + if bias is not None: + bias = bias.to(dtype=data_type) + if mask: + bias = bias.where(mask, torch.finfo(data_type).min) + mask = bias + else: + assert mask is not None qkv = [ torch.randn(B, H, S, D, device=device, dtype=data_type, requires_grad=True) @@ -121,6 +129,11 @@ def test_mask( del fwd_out torch.cuda.empty_cache() + ( + (causal_fa2_time, causal_fa2_bw_time), + (sdpa_mask_time, sdpa_mask_bw_time), + (flex_ms, flex_bw_ms), + ) = times print_header( f"{score_mod.__name__ if score_mod is not None else mask_mod.__name__}".replace( @@ -152,11 +165,6 @@ def test_mask( print("Correctness check passed ✅") - ( - (causal_fa2_time, causal_fa2_bw_time), - (sdpa_mask_time, sdpa_mask_bw_time), - (flex_ms, flex_bw_ms), - ) = times # Usage in your results formatting: results = [ [