
Commit

Merge pull request #9 from sanchit-gandhi/flax-train-script
[FlaxWav2Vec2Model] Fix bug in attention mask
sanchit-gandhi committed Apr 14, 2022
2 parents db3fe85 + 8a18ec4 commit b8f5570
Showing 1 changed file with 9 additions and 16 deletions.
25 changes: 9 additions & 16 deletions models/modeling_flax_wav2vec2.py
@@ -694,7 +694,7 @@ def __call__(
     def _get_feat_extract_output_lengths(
         self, input_lengths: Union[jnp.ndarray, int], add_adapter: Optional[bool] = None
     ):
-        return self.module._get_feat_extract_output_lengths(input_lengths)
+        return self.module._get_feat_extract_output_lengths(input_lengths, add_adapter=add_adapter)


 class FlaxWav2Vec2Module(nn.Module):
@@ -730,15 +730,10 @@ def __call__(

         # make sure that no loss is computed on padded inputs
         if attention_mask is not None:
-            # compute real output lengths according to convolution formula
-            output_lengths = self._get_feat_extract_output_lengths(attention_mask.sum(-1).astype("i4"))
-
-            attention_mask = jnp.zeros(extract_features.shape[:2], dtype=self.dtype)
-
-            # these two operations makes sure that all values
-            # before the output lengths indices are attended to
-            attention_mask = attention_mask.at[jnp.arange(attention_mask.shape[0]), output_lengths - 1].set(1)
-            attention_mask = jnp.flip(jnp.flip(attention_mask, -1).cumsum(-1), -1).astype("bool")
+            # compute reduced attention_mask corresponding to feature vectors
+            attention_mask = self._get_feature_vector_attention_mask(
+                extract_features.shape[1], attention_mask, add_adapter=False
+            )

         hidden_states, extract_features = self.feature_projection(extract_features, deterministic=deterministic)
         if mask_time_indices is not None:  # apply SpecAugment along time axis with given indices
@@ -808,12 +803,10 @@ def _get_feature_vector_attention_mask(
         batch_size = attention_mask.shape[0]

         attention_mask = jnp.zeros((batch_size, feature_vector_length), dtype=attention_mask.dtype)
-        # these two operations makes sure that all values before the output lengths idxs are attended to
-        idx = (jnp.arange(attention_mask.shape[0]), output_lengths - 1)
-        attention_mask = attention_mask.at[idx].set(1)
-        attention_mask = jnp.flip(jnp.flip(attention_mask, axis=-1).cumsum(axis=-1), axis=-1)
-
-        attention_mask = jnp.array(attention_mask, dtype=bool)
+        # these two operations makes sure that all values
+        # before the output lengths indices are attended to
+        attention_mask = attention_mask.at[jnp.arange(attention_mask.shape[0]), output_lengths - 1].set(1)
+        attention_mask = jnp.flip(jnp.flip(attention_mask, -1).cumsum(-1), -1).astype("bool")
         return attention_mask
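For context (not part of the commit itself): both the removed inline code and the `_get_feature_vector_attention_mask` helper it now delegates to build the mask the same way, by placing a 1 at each example's last valid frame and running a flip → cumsum → flip over it. A minimal sketch of that trick in JAX, with made-up batch size, mask length, and output lengths:

```python
# Illustration only (not part of the commit): turning per-example output
# lengths into a boolean attention mask with the flip -> cumsum -> flip trick.
import jax.numpy as jnp

batch_size, feature_vector_length = 2, 6
output_lengths = jnp.array([4, 2])  # hypothetical frame counts per example

mask = jnp.zeros((batch_size, feature_vector_length), dtype="i4")
# place a single 1 at the last attended index of each row ...
mask = mask.at[jnp.arange(batch_size), output_lengths - 1].set(1)
# ... then flip, cumsum, flip back so every position up to that index is True
mask = jnp.flip(jnp.flip(mask, -1).cumsum(-1), -1).astype("bool")
print(mask)
# [[ True  True  True  True False False]
#  [ True  True False False False False]]
```

Setting a single 1 at index `output_lengths - 1` and accumulating from the right marks every earlier frame as attended, which is the boolean per-frame mask the encoder expects.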


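Also for context: the "convolution formula" mentioned in the removed comment is the standard output-length rule for an unpadded 1D convolution, floor((length - kernel_size) / stride) + 1, applied once per layer of the feature extractor. A hedged sketch with placeholder kernel sizes and strides (not the model's real config):

```python
# Illustration only: the standard output-length formula for a stack of
# 1D convolutions (no padding, dilation 1), as referenced by the removed
# "convolution formula" comment. Kernel sizes and strides are placeholders,
# not the model's actual config values.
import jax.numpy as jnp


def _conv_out_length(input_length, kernel_size, stride):
    # floor((L - kernel_size) / stride) + 1
    return (input_length - kernel_size) // stride + 1


def feat_extract_output_lengths(input_lengths, conv_kernel=(10, 3, 3), conv_stride=(5, 2, 2)):
    # fold the formula over every conv layer of the feature extractor
    for kernel_size, stride in zip(conv_kernel, conv_stride):
        input_lengths = _conv_out_length(input_lengths, kernel_size, stride)
    return input_lengths


lengths = jnp.array([16000, 8000])            # raw waveform lengths per batch item
print(feat_extract_output_lengths(lengths))   # frame counts after the conv stack
```

The `add_adapter=add_adapter` forwarding restored in the first hunk feeds into this same length computation: in the upstream Hugging Face Wav2Vec2 implementation, that flag controls whether the adapter's extra strided layers are also applied when reducing the lengths.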
