diff --git a/models/modeling_flax_wav2vec2.py b/models/modeling_flax_wav2vec2.py
index 6095cd0..bc878d5 100644
--- a/models/modeling_flax_wav2vec2.py
+++ b/models/modeling_flax_wav2vec2.py
@@ -694,7 +694,7 @@ def __call__(
     def _get_feat_extract_output_lengths(
         self, input_lengths: Union[jnp.ndarray, int], add_adapter: Optional[bool] = None
     ):
-        return self.module._get_feat_extract_output_lengths(input_lengths)
+        return self.module._get_feat_extract_output_lengths(input_lengths, add_adapter=add_adapter)
 
 
 class FlaxWav2Vec2Module(nn.Module):
@@ -730,15 +730,10 @@ def __call__(
 
         # make sure that no loss is computed on padded inputs
         if attention_mask is not None:
-            # compute real output lengths according to convolution formula
-            output_lengths = self._get_feat_extract_output_lengths(attention_mask.sum(-1).astype("i4"))
-
-            attention_mask = jnp.zeros(extract_features.shape[:2], dtype=self.dtype)
-
-            # these two operations makes sure that all values
-            # before the output lengths indices are attended to
-            attention_mask = attention_mask.at[jnp.arange(attention_mask.shape[0]), output_lengths - 1].set(1)
-            attention_mask = jnp.flip(jnp.flip(attention_mask, -1).cumsum(-1), -1).astype("bool")
+            # compute reduced attention_mask corresponding to feature vectors
+            attention_mask = self._get_feature_vector_attention_mask(
+                extract_features.shape[1], attention_mask, add_adapter=False
+            )
 
         hidden_states, extract_features = self.feature_projection(extract_features, deterministic=deterministic)
         if mask_time_indices is not None:  # apply SpecAugment along time axis with given indices
@@ -808,12 +803,10 @@ def _get_feature_vector_attention_mask(
         batch_size = attention_mask.shape[0]
 
         attention_mask = jnp.zeros((batch_size, feature_vector_length), dtype=attention_mask.dtype)
-        # these two operations makes sure that all values before the output lengths idxs are attended to
-        idx = (jnp.arange(attention_mask.shape[0]), output_lengths - 1)
-        attention_mask = attention_mask.at[idx].set(1)
-        attention_mask = jnp.flip(jnp.flip(attention_mask, axis=-1).cumsum(axis=-1), axis=-1)
-
-        attention_mask = jnp.array(attention_mask, dtype=bool)
+        # these two operations makes sure that all values
+        # before the output lengths indices are attended to
+        attention_mask = attention_mask.at[jnp.arange(attention_mask.shape[0]), output_lengths - 1].set(1)
+        attention_mask = jnp.flip(jnp.flip(attention_mask, -1).cumsum(-1), -1).astype("bool")
         return attention_mask
 
 
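
For reference, a minimal standalone sketch (outside the diff) of the flip/cumsum/flip trick that _get_feature_vector_attention_mask relies on to expand per-sequence output lengths into a boolean frame mask. The lengths and sizes below are made-up example values, not taken from the model:

import jax.numpy as jnp

# Hypothetical example: two sequences whose conv features are 3 and 5 frames
# long, padded to a common feature_vector_length of 6.
output_lengths = jnp.array([3, 5])
feature_vector_length = 6

attention_mask = jnp.zeros((output_lengths.shape[0], feature_vector_length), dtype="i4")
# Put a single 1 at the last valid frame of each sequence ...
attention_mask = attention_mask.at[jnp.arange(attention_mask.shape[0]), output_lengths - 1].set(1)
# ... then a reversed cumulative sum propagates that 1 backwards, marking every
# frame up to and including the last valid one as attended.
attention_mask = jnp.flip(jnp.flip(attention_mask, -1).cumsum(-1), -1).astype("bool")

print(attention_mask.astype("i4"))
# [[1 1 1 0 0 0]
#  [1 1 1 1 1 0]]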