diff --git a/ppim/models/deit.py b/ppim/models/deit.py index d9f6428..4c7c6de 100644 --- a/ppim/models/deit.py +++ b/ppim/models/deit.py @@ -68,10 +68,10 @@ def __init__(self, if class_dim > 0: self.head_dist = nn.Linear(self.embed_dim, self.class_dim) + self.head_dist.apply(self._init_weights) trunc_normal_(self.dist_token) trunc_normal_(self.pos_embed) - self.head_dist.apply(self._init_weights) def forward_features(self, x): B = paddle.shape(x)[0]