diff --git a/timm/models/davit.py b/timm/models/davit.py index 1450c357cc..f00cf73384 100644 --- a/timm/models/davit.py +++ b/timm/models/davit.py @@ -570,11 +570,7 @@ def forward_features(self, x): return x def forward_head(self, x, pre_logits: bool = False): - x = self.head.global_pool(x) - x = self.head.norm(x) - x = self.head.flatten(x) - x = self.head.drop(x) - return x if pre_logits else self.head.fc(x) + return self.head(x, pre_logits=True) if pre_logits else self.head(x) def forward(self, x): x = self.forward_features(x)