
[FlaxWav2Vec2Model] Fix bug in attention mask #16725

Merged: 3 commits into huggingface:main on Apr 12, 2022

Conversation

@sanchit-gandhi (Contributor) commented on Apr 12, 2022

Currently, the FlaxWav2Vec2 reduced attention mask is computed by calling the function _get_feat_extract_output_lengths without explicitly specifying whether the (optional) adapter module is used:

# compute real output lengths according to convolution formula
output_lengths = self._get_feat_extract_output_lengths(attention_mask.sum(-1).astype("i4"))

By default, if add_adapter is None, the boolean add_adapter will be set based on the config:
def _get_feat_extract_output_lengths(
    self, input_lengths: Union[jnp.ndarray, int], add_adapter: Optional[bool] = None
):
    """
    Computes the output length of the convolutional layers
    """
    add_adapter = self.config.add_adapter if add_adapter is None else add_adapter

With this default, if the model contains an adapter module, add_adapter is set to True. The convolution formula then accounts for the downsampling performed both by the convolutional layers of the feature extractor and by those of the adapter module.
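For concreteness, here is a minimal standalone sketch of that output-length arithmetic (an illustration, not the library code itself). The kernel/stride values and adapter defaults (num_adapter_layers=3, adapter_stride=2) in the usage example are assumed to match the full-size wav2vec2 feature extractor, not the tiny random checkpoint used in the snippet below:

def conv_output_length(input_length, kernel_size, stride):
    # standard 1D convolution output-length formula (no padding, dilation 1)
    return (input_length - kernel_size) // stride + 1


def feat_extract_output_lengths(input_length, conv_kernel, conv_stride,
                                add_adapter=False, num_adapter_layers=3, adapter_stride=2):
    # downsampling performed by the feature extractor's convolutional layers
    for kernel_size, stride in zip(conv_kernel, conv_stride):
        input_length = conv_output_length(input_length, kernel_size, stride)
    # counting the adapter's strided convolutions as well yields lengths that are
    # shorter than the encoder's actual sequence length
    if add_adapter:
        for _ in range(num_adapter_layers):
            input_length = conv_output_length(input_length, 1, adapter_stride)
    return input_length


conv_kernel, conv_stride = (10, 3, 3, 3, 3, 2, 2), (5, 2, 2, 2, 2, 2, 2)
print(feat_extract_output_lengths(96000, conv_kernel, conv_stride))                    # 299
print(feat_extract_output_lengths(96000, conv_kernel, conv_stride, add_adapter=True))  # 38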

However, since the reduced attention mask is required for the encoder module, it should be computed based on the convolutional layers of the feature extractor only, and not those of the subsequent adapter module. This is highlighted by the PyTorch Wav2Vec2 modelling code:

if attention_mask is not None:
    # compute reduced attention_mask corresponding to feature vectors
    attention_mask = self._get_feature_vector_attention_mask(
        extract_features.shape[1], attention_mask, add_adapter=False
    )
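
This PR mirrors that behaviour on the Flax side by passing add_adapter=False explicitly when computing the reduced attention mask. A minimal sketch of the changed call (the exact diff may differ slightly):

# compute real output lengths according to convolution formula, based on the
# feature extractor's conv layers only (exclude the adapter's downsampling)
output_lengths = self._get_feat_extract_output_lengths(
    attention_mask.sum(-1).astype("i4"), add_adapter=False
)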

The following code snippet demonstrates the effect of this bug by means of a PyTorch-Flax cross-test:

import torch
import numpy as np
from transformers import Wav2Vec2Model, FlaxWav2Vec2Model
import tempfile
import random

encoder_id = "hf-internal-testing/tiny-random-wav2vec2"

fx_model = FlaxWav2Vec2Model.from_pretrained(encoder_id, add_adapter=True, from_pt=True)

with tempfile.TemporaryDirectory() as tmpdirname:
    fx_model.save_pretrained(tmpdirname)
    pt_model = Wav2Vec2Model.from_pretrained(tmpdirname, config=fx_model.config, from_flax=True)


# create synthetic input data
def ids_tensor(shape, vocab_size, rng=None):
    """Creates a random int32 tensor of the shape within the vocab size."""
    if rng is None:
        rng = random.Random()

    total_dims = 1
    for dim in shape:
        total_dims *= dim

    values = []
    for _ in range(total_dims):
        values.append(rng.randint(0, vocab_size - 1))

    output = np.array(values).reshape(shape)

    return output


def random_attention_mask(shape, rng=None):
    attn_mask = ids_tensor(shape, vocab_size=2, rng=rng)
    # make sure that at least one token is attended to for each batch
    attn_mask[:, -1] = 1
    return attn_mask


def floats_tensor(shape, scale=1.0):
    """Creates a random float32 tensor"""
    total_dims = 1
    for dim in shape:
        total_dims *= dim

    values = []
    for _ in range(total_dims):
        values.append(np.random.randn() * scale)

    return np.array(values, dtype=np.float32).reshape(shape)


def fx_batch(batch_size=2, input_length=96000):
    input_ids = floats_tensor([batch_size, input_length])
    attention_mask = random_attention_mask([batch_size, input_length])
    
    fx_inputs = {
        "input_values": input_ids,
        "attention_mask": attention_mask,
    }
    return fx_inputs


fx_inputs = fx_batch()
pt_inputs = {k: torch.tensor(v.tolist()) for k, v in fx_inputs.items()}

fx_outputs = fx_model(**fx_inputs, output_hidden_states=True)
pt_outputs = pt_model(**pt_inputs, output_hidden_states=True)

# helper function for our analysis
def assert_almost_equals(a: np.ndarray, b: np.ndarray, tol: float = 1e-2):
    diff = np.abs((a - b)).max()
    if diff < tol:
        print(f"✅ Difference between Flax and PyTorch is {diff} (< {tol})")
    else:
        print(f"❌ Difference between Flax and PyTorch is {diff} (>= {tol})")


print("--------------------------Checking hidden states match--------------------------")
for fx_state, pt_state in zip(fx_outputs.hidden_states, pt_outputs.hidden_states):
    assert fx_state.shape == pt_state.shape
    assert_almost_equals(fx_state, pt_state.detach().numpy())


print("--------------------------Checking last hidden states match--------------------------")
print(f"Encoder-decoder output shape: {fx_outputs.last_hidden_state.shape}, encoder-only output shape: {pt_outputs.last_hidden_state.shape}")
assert_almost_equals(fx_outputs.last_hidden_state, pt_outputs.last_hidden_state.detach().numpy())

Output prior to fix:

--------------------------Checking encoder hidden states match--------------------------
❌ Difference between Flax and PyTorch is 0.43152332305908203 (>= 0.01)
❌ Difference between Flax and PyTorch is 0.43074753880500793 (>= 0.01)
❌ Difference between Flax and PyTorch is 0.42613524198532104 (>= 0.01)
❌ Difference between Flax and PyTorch is 0.4301084578037262 (>= 0.01)
❌ Difference between Flax and PyTorch is 4.519614219665527 (>= 0.01)
--------------------------Checking encoder last hidden states match--------------------------
Flax output shape: (2, 188, 16), PyTorch output shape: torch.Size([2, 188, 16])
✅ Difference between Flax and PyTorch is 0.0015139428433030844 (< 0.01)

Output following fix:

--------------------------Checking encoder hidden states match--------------------------
✅ Difference between Flax and PyTorch is 3.9674341678619385e-07 (< 0.01)
✅ Difference between Flax and PyTorch is 4.041939973831177e-07 (< 0.01)
✅ Difference between Flax and PyTorch is 4.041939973831177e-07 (< 0.01)
✅ Difference between Flax and PyTorch is 3.948807716369629e-07 (< 0.01)
✅ Difference between Flax and PyTorch is 4.947185516357422e-06 (< 0.01)
--------------------------Checking encoder last hidden states match--------------------------
Flax output shape: (2, 188, 16), PyTorch output shape: torch.Size([2, 188, 16])
✅ Difference between Flax and PyTorch is 1.0913936421275139e-09 (< 0.01)

@HuggingFaceDocBuilderDev commented on Apr 12, 2022

The documentation is not available anymore as the PR was closed or merged.

@patrickvonplaten (Contributor) left a comment

Great PR & very clean! Good job

@sanchit-gandhi merged commit a960406 into huggingface:main on Apr 12, 2022
elusenji pushed a commit to elusenji/transformers that referenced this pull request Jun 12, 2022
* [FlaxWav2Vec2Model] Fix bug in attention mask

* more fixes

* add (Flax)SpeechEncoderDecoderModel PT-FX cross-test
@sanchit-gandhi deleted the flax-wav2vec2 branch on Jun 25, 2023