From 59c39cdbad452df4b3f015ef71400587f47e5958 Mon Sep 17 00:00:00 2001
From: Yih-Dar <2521628+ydshieh@users.noreply.github.com>
Date: Fri, 15 Apr 2022 06:37:10 +0200
Subject: [PATCH] Fix PT TF ViTMAE (#16766)

Co-authored-by: ydshieh
---
 src/transformers/models/vit_mae/modeling_tf_vit_mae.py | 4 +++-
 src/transformers/models/vit_mae/modeling_vit_mae.py | 2 +-
 2 files changed, 4 insertions(+), 2 deletions(-)

diff --git a/src/transformers/models/vit_mae/modeling_tf_vit_mae.py b/src/transformers/models/vit_mae/modeling_tf_vit_mae.py
index 6ff588fce3d44a..40642ef4a63df0 100644
--- a/src/transformers/models/vit_mae/modeling_tf_vit_mae.py
+++ b/src/transformers/models/vit_mae/modeling_tf_vit_mae.py
@@ -860,7 +860,9 @@ def __init__(self, config, num_patches, **kwargs):
 
         self.decoder_norm = tf.keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="decoder_norm")
         self.decoder_pred = tf.keras.layers.Dense(
-            config.patch_size**2 * config.num_channels, name="decoder_pred"
+            config.patch_size**2 * config.num_channels,
+            kernel_initializer=get_initializer(config.initializer_range),
+            name="decoder_pred",
         )  # encoder to decoder
         self.config = config
         self.num_patches = num_patches
diff --git a/src/transformers/models/vit_mae/modeling_vit_mae.py b/src/transformers/models/vit_mae/modeling_vit_mae.py
index 5f257dd61f98a2..473ccd14feb099 100755
--- a/src/transformers/models/vit_mae/modeling_vit_mae.py
+++ b/src/transformers/models/vit_mae/modeling_vit_mae.py
@@ -756,7 +756,7 @@ def __init__(self, config, num_patches):
             [ViTMAELayer(decoder_config) for _ in range(config.decoder_num_hidden_layers)]
         )
 
-        self.decoder_norm = nn.LayerNorm(config.decoder_hidden_size)
+        self.decoder_norm = nn.LayerNorm(config.decoder_hidden_size, eps=config.layer_norm_eps)
         self.decoder_pred = nn.Linear(
             config.decoder_hidden_size, config.patch_size**2 * config.num_channels, bias=True
         )  # encoder to decoder