diff --git a/src/transformers/models/wav2vec2/modeling_flax_wav2vec2.py b/src/transformers/models/wav2vec2/modeling_flax_wav2vec2.py
index 3c8d3213956866..52c332418699d7 100644
--- a/src/transformers/models/wav2vec2/modeling_flax_wav2vec2.py
+++ b/src/transformers/models/wav2vec2/modeling_flax_wav2vec2.py
@@ -920,7 +920,7 @@ def __call__(
     def _get_feat_extract_output_lengths(
         self, input_lengths: Union[jnp.ndarray, int], add_adapter: Optional[bool] = None
     ):
-        return self.module._get_feat_extract_output_lengths(input_lengths)
+        return self.module._get_feat_extract_output_lengths(input_lengths, add_adapter=add_adapter)
 
 
 class FlaxWav2Vec2Module(nn.Module):
@@ -956,15 +956,10 @@ def __call__(
 
         # make sure that no loss is computed on padded inputs
         if attention_mask is not None:
-            # compute real output lengths according to convolution formula
-            output_lengths = self._get_feat_extract_output_lengths(attention_mask.sum(-1).astype("i4"))
-
-            attention_mask = jnp.zeros(extract_features.shape[:2], dtype=self.dtype)
-
-            # these two operations makes sure that all values
-            # before the output lengths indices are attended to
-            attention_mask = attention_mask.at[jnp.arange(attention_mask.shape[0]), output_lengths - 1].set(1)
-            attention_mask = jnp.flip(jnp.flip(attention_mask, -1).cumsum(-1), -1).astype("bool")
+            # compute reduced attention_mask corresponding to feature vectors
+            attention_mask = self._get_feature_vector_attention_mask(
+                extract_features.shape[1], attention_mask, add_adapter=False
+            )
 
         hidden_states, extract_features = self.feature_projection(extract_features, deterministic=deterministic)
         if mask_time_indices is not None:  # apply SpecAugment along time axis with given indices
@@ -1034,12 +1029,10 @@ def _get_feature_vector_attention_mask(
         batch_size = attention_mask.shape[0]
 
         attention_mask = jnp.zeros((batch_size, feature_vector_length), dtype=attention_mask.dtype)
-        # these two operations makes sure that all values before the output lengths idxs are attended to
-        idx = (jnp.arange(attention_mask.shape[0]), output_lengths - 1)
-        attention_mask = attention_mask.at[idx].set(1)
-        attention_mask = jnp.flip(jnp.flip(attention_mask, axis=-1).cumsum(axis=-1), axis=-1)
-
-        attention_mask = jnp.array(attention_mask, dtype=bool)
+        # these two operations makes sure that all values
+        # before the output lengths indices are attended to
+        attention_mask = attention_mask.at[jnp.arange(attention_mask.shape[0]), output_lengths - 1].set(1)
+        attention_mask = jnp.flip(jnp.flip(attention_mask, -1).cumsum(-1), -1).astype("bool")
         return attention_mask
 
 
@@ -1286,11 +1279,15 @@ def __call__(
             attentions=outputs.attentions,
         )
 
-    def _get_feat_extract_output_lengths(self, input_lengths: Union[jnp.ndarray, int]):
+    def _get_feat_extract_output_lengths(
+        self, input_lengths: Union[jnp.ndarray, int], add_adapter: Optional[bool] = None
+    ):
         """
         Computes the output length of the convolutional layers
         """
 
+        add_adapter = self.config.add_adapter if add_adapter is None else add_adapter
+
         def _conv_out_length(input_length, kernel_size, stride):
             # 1D convolutional layer output length formula taken
             # from https://pytorch.org/docs/stable/generated/torch.nn.Conv1d.html
@@ -1299,6 +1296,10 @@ def _conv_out_length(input_length, kernel_size, stride):
         for kernel_size, stride in zip(self.config.conv_kernel, self.config.conv_stride):
             input_lengths = _conv_out_length(input_lengths, kernel_size, stride)
 
+        if add_adapter:
+            for _ in range(self.config.num_adapter_layers):
+                input_lengths = _conv_out_length(input_lengths, 1, self.config.adapter_stride)
+
         return input_lengths
 
 
diff --git a/tests/speech_encoder_decoder/test_modeling_flax_speech_encoder_decoder.py b/tests/speech_encoder_decoder/test_modeling_flax_speech_encoder_decoder.py
index 873c09105bc0f0..0549e650645238 100644
--- a/tests/speech_encoder_decoder/test_modeling_flax_speech_encoder_decoder.py
+++ b/tests/speech_encoder_decoder/test_modeling_flax_speech_encoder_decoder.py
@@ -539,6 +539,12 @@ def test_pt_flax_equivalence(self):
         self.check_equivalence_pt_to_flax(config, decoder_config, inputs_dict)
         self.check_equivalence_flax_to_pt(config, decoder_config, inputs_dict)
 
+        # check `add_adapter` works as expected
+        config.add_adapter = True
+        self.assertTrue(config.add_adapter)
+        self.check_equivalence_pt_to_flax(config, decoder_config, inputs_dict)
+        self.check_equivalence_flax_to_pt(config, decoder_config, inputs_dict)
+
     @slow
     def test_real_model_save_load_from_pretrained(self):
         model_2 = self.get_pretrained_model()
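
For context, the patch does two things: it threads `add_adapter` through `_get_feat_extract_output_lengths` so the adapter's strided convolutions are counted when computing output lengths, and it replaces an inline mask computation with the shared `_get_feature_vector_attention_mask` helper. Below is a minimal, self-contained sketch of those two building blocks, not the library code itself; the `conv_kernel`, `conv_stride`, `num_adapter_layers`, and `adapter_stride` values are illustrative placeholders rather than a real Wav2Vec2 config.

```python
import jax.numpy as jnp

conv_kernel = (10, 3, 3)  # assumed feature-extractor kernel sizes
conv_stride = (5, 2, 2)   # assumed feature-extractor strides
num_adapter_layers = 2    # assumed adapter depth
adapter_stride = 2        # assumed adapter stride


def _conv_out_length(input_length, kernel_size, stride):
    # 1D convolution output length (no padding, dilation=1)
    return (input_length - kernel_size) // stride + 1


def get_feat_extract_output_lengths(input_lengths, add_adapter=False):
    # chain the formula through every conv layer of the feature extractor
    for kernel_size, stride in zip(conv_kernel, conv_stride):
        input_lengths = _conv_out_length(input_lengths, kernel_size, stride)
    if add_adapter:
        # each adapter layer acts like a kernel-size-1 conv with its own stride
        for _ in range(num_adapter_layers):
            input_lengths = _conv_out_length(input_lengths, 1, adapter_stride)
    return input_lengths


def get_feature_vector_attention_mask(feature_vector_length, attention_mask, add_adapter=False):
    # real (unpadded) input length per example, mapped to feature-frame counts
    output_lengths = get_feat_extract_output_lengths(
        attention_mask.sum(-1).astype("i4"), add_adapter=add_adapter
    )
    batch_size = attention_mask.shape[0]
    mask = jnp.zeros((batch_size, feature_vector_length), dtype=attention_mask.dtype)
    # set a 1 at the last valid frame; the reversed cumulative sum then fills
    # every position before it, yielding a left-aligned boolean mask
    mask = mask.at[jnp.arange(batch_size), output_lengths - 1].set(1)
    return jnp.flip(jnp.flip(mask, -1).cumsum(-1), -1).astype("bool")


# example: two inputs of 80 samples, the second one padded after 40
attention_mask = jnp.array([[1] * 80, [1] * 40 + [0] * 40], dtype="i4")
feat_len = int(get_feat_extract_output_lengths(80))
print(get_feature_vector_attention_mask(feat_len, attention_mask))
```

With these placeholder strides, the 80-sample input reduces to 3 feature frames and receives an all-True mask, while the half-padded input keeps only its first frame, i.e. `[True, False, False]`; this is the same left-aligned mask the removed inline code used to build in `__call__`.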