From b0c47c5d501f86fc097a070fd9d52fa7d9414f00 Mon Sep 17 00:00:00 2001
From: Louis Lac
Date: Wed, 1 Jan 2025 12:03:34 +0100
Subject: [PATCH] Fixed unfused attn2d scale

The unfused path of Attention2d scaled queries by num_heads ** -0.5
instead of by the head dimension. Scale the attention logits by the
head dim to match the fused F.scaled_dot_product_attention path, and
add a test comparing fused and unfused outputs.
---
 tests/test_layers.py       | 25 ++++++++++++++++++++++++-
 timm/layers/attention2d.py |  8 ++++----
 2 files changed, 28 insertions(+), 5 deletions(-)

diff --git a/tests/test_layers.py b/tests/test_layers.py
index 2cc8420abf..7726c3d055 100644
--- a/tests/test_layers.py
+++ b/tests/test_layers.py
@@ -1,7 +1,8 @@
+import pytest
 import torch
 import torch.nn as nn
 
-from timm.layers import create_act_layer, set_layer_config, get_act_layer, get_act_fn
+from timm.layers import create_act_layer, set_layer_config, get_act_layer, get_act_fn, Attention2d
 
 import importlib
 import os
@@ -119,3 +120,25 @@ def test_get_act_fn_none():
     assert get_act_fn(None) is None
     assert get_act_fn('') is None
 
+
+@pytest.mark.parametrize("bias", [True, False])
+@pytest.mark.parametrize("expand_first", [True, False])
+@pytest.mark.parametrize("head_first", [True, False])
+@pytest.mark.parametrize("attn_mask", [True, False])
+def test_attn2d(bias, expand_first, head_first, attn_mask):
+    x = torch.randn(1, 128, 32, 48)
+    attn = Attention2d(
+        128, 128, num_heads=4, bias=bias, expand_first=expand_first, head_first=head_first
+    )
+
+    if attn_mask:
+        # additive float mask over the flattened 32 * 48 spatial positions
+        mask = torch.randint(0, 2, size=(32 * 48, 32 * 48), dtype=torch.float32)
+    else:
+        mask = None
+
+    o1 = attn(x, mask)
+    attn.fused_attn = False
+    o2 = attn(x, mask)
+
+    assert torch.allclose(o1, o2, atol=1e-5), f"{torch.abs(o1 - o2).max()}"
diff --git a/timm/layers/attention2d.py b/timm/layers/attention2d.py
index 31200adf76..1b4c658429 100644
--- a/timm/layers/attention2d.py
+++ b/timm/layers/attention2d.py
@@ -312,7 +312,6 @@ def __init__(
         self.num_heads = num_heads
         self.dim_head = dim_attn // num_heads
         self.head_first = head_first
-        self.scale = num_heads ** -0.5
         self.fused_attn = use_fused_attn()
 
         self.qkv = nn.Conv2d(dim, dim_attn * 3, 1, bias=bias)
@@ -337,14 +336,15 @@ def forward(self, x, attn_mask: Optional[torch.Tensor] = None):
                 dropout_p=self.attn_drop.p if self.training else 0.,
             ).transpose(-1, -2).reshape(B, -1, H, W)
         else:
-            q = q * self.scale
-            attn = q.transpose(-2, -1) @ k
+            q = q.transpose(-1, -2)
+            v = v.transpose(-1, -2)
+            attn = q @ k * q.size(-1) ** -0.5  # scale by head dim to match the fused path
             if attn_mask is not None:
                 # NOTE: assumes mask is float and in correct shape
                 attn = attn + attn_mask
             attn = attn.softmax(dim=-1)
             attn = self.attn_drop(attn)
-            x = (v @ attn.transpose(-2, -1)).view(B, -1, H, W)
+            x = (attn @ v).transpose(-1, -2).reshape(B, -1, H, W)
 
         x = self.proj(x)
         x = self.proj_drop(x)
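
For reference, a minimal standalone sketch (not part of the patch, and independent of timm)
illustrating why the new unfused scaling agrees with the fused path: PyTorch's
F.scaled_dot_product_attention defaults to a 1/sqrt(head_dim) scale, which is what
q.size(-1) ** -0.5 now applies. The shapes below are illustrative only.

import torch
import torch.nn.functional as F

torch.manual_seed(0)
B, heads, N, head_dim = 1, 4, 16, 32
q = torch.randn(B, heads, N, head_dim)
k = torch.randn(B, heads, N, head_dim)
v = torch.randn(B, heads, N, head_dim)

# Unfused: scale logits by head_dim ** -0.5 (the head dim, not num_heads).
attn = (q @ k.transpose(-2, -1)) * head_dim ** -0.5
out_unfused = attn.softmax(dim=-1) @ v

# Fused: SDPA uses the same 1/sqrt(head_dim) scale by default.
out_fused = F.scaled_dot_product_attention(q, k, v)

print(torch.allclose(out_unfused, out_fused, atol=1e-5))  # True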