
[FlaxWav2Vec2Model] Fix bug in attention mask #9

Merged · 1 commit into main · Apr 14, 2022

Conversation

sanchit-gandhi (Owner)

Mirrors huggingface/transformers#16725

Currently, the FlaxWav2Vec2 reduced attention mask is computed by calling _get_feat_extract_output_lengths without explicitly specifying whether the (optional) adapter module is used:
https://github.com/huggingface/transformers/blob/924484ee4a6ebc79426d27eef31a1ee7d13cbb9a/src/transformers/models/wav2vec2/modeling_flax_wav2vec2.py#L959-L960
By default, if add_adapter is None, the boolean add_adapter will be set based on the config:
https://github.com/huggingface/transformers/blob/924484ee4a6ebc79426d27eef31a1ee7d13cbb9a/src/transformers/models/wav2vec2/modeling_flax_wav2vec2.py#L1001-L1008
Under this default, if the model contains an adapter module, add_adapter is set to True. The output-length formula then accounts for the downsampling performed by the convolutional layers of both the feature extractor and the adapter module.
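
For context, a minimal sketch of the output-length computation (paraphrasing the library helper; the config fields conv_kernel, conv_stride, adapter_kernel_size, adapter_stride and num_adapter_layers are those of Wav2Vec2Config) shows how add_adapter=True folds the adapter's strided convolutions into the computed length:

def _conv_out_length(length, kernel, stride):
    # output length of a 1D convolution with no padding
    return (length - kernel) // stride + 1

def feat_extract_output_lengths(input_length, config, add_adapter):
    # downsampling performed by the feature extractor's conv layers
    for kernel, stride in zip(config.conv_kernel, config.conv_stride):
        input_length = _conv_out_length(input_length, kernel, stride)
    if add_adapter:
        # the adapter applies further strided convolutions, shrinking the
        # sequence length beyond what the encoder actually sees
        for _ in range(config.num_adapter_layers):
            input_length = _conv_out_length(
                input_length, config.adapter_kernel_size, config.adapter_stride
            )
    return input_length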

However, since the reduced attention mask is consumed by the encoder module, it should be computed from the convolutional layers of the feature extractor only, and not those of the subsequent adapter module. This is how the PyTorch Wav2Vec2 modelling code handles it:
https://github.com/huggingface/transformers/blob/924484ee4a6ebc79426d27eef31a1ee7d13cbb9a/src/transformers/models/wav2vec2/modeling_wav2vec2.py#L1350-L1354
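
The fix is therefore to pass add_adapter=False when deriving the encoder's reduced attention mask in the Flax model, mirroring the PyTorch call above. A rough sketch of the corrected call site (paraphrased, not a verbatim diff of the merged commit):

# paraphrased sketch of the corrected call site, not a verbatim diff:
# the reduced attention mask for the encoder must reflect only the feature
# extractor's downsampling, so the adapter is explicitly excluded
output_lengths = self._get_feat_extract_output_lengths(
    attention_mask.sum(-1), add_adapter=False
)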

The following code snippet demonstrates the effect of this bug by means of a PyTorch-Flax cross-test:

import torch
import numpy as np
from transformers import Wav2Vec2Model, FlaxWav2Vec2Model
import tempfile
import random

encoder_id = "hf-internal-testing/tiny-random-wav2vec2"

# load the Flax model with an adapter module, then round-trip the weights
# through a temporary directory to obtain an equivalent PyTorch model
fx_model = FlaxWav2Vec2Model.from_pretrained(encoder_id, add_adapter=True, from_pt=True)

with tempfile.TemporaryDirectory() as tmpdirname:
    fx_model.save_pretrained(tmpdirname)
    pt_model = Wav2Vec2Model.from_pretrained(tmpdirname, config=fx_model.config, from_flax=True)


# create synthetic input data
def ids_tensor(shape, vocab_size, rng=None):
    """Creates a random int32 tensor of the shape within the vocab size."""
    if rng is None:
        rng = random.Random()

    total_dims = 1
    for dim in shape:
        total_dims *= dim

    values = []
    for _ in range(total_dims):
        values.append(rng.randint(0, vocab_size - 1))

    output = np.array(values).reshape(shape)

    return output


def random_attention_mask(shape, rng=None):
    attn_mask = ids_tensor(shape, vocab_size=2, rng=rng)
    # make sure that at least one token is attended to for each batch
    attn_mask[:, -1] = 1
    return attn_mask


def floats_tensor(shape, scale=1.0):
    """Creates a random float32 tensor"""
    total_dims = 1
    for dim in shape:
        total_dims *= dim

    values = []
    for _ in range(total_dims):
        values.append(np.random.randn() * scale)

    return np.array(values, dtype=np.float32).reshape(shape)


def fx_batch(batch_size=2, input_length=96000):
    input_ids = floats_tensor([batch_size, input_length])
    attention_mask = random_attention_mask([batch_size, input_length])
    
    fx_inputs = {
        "input_values": input_ids,
        "attention_mask": attention_mask,
    }
    return fx_inputs


fx_inputs = fx_batch()
pt_inputs = {k: torch.tensor(v.tolist()) for k, v in fx_inputs.items()}

# forward pass through both models on identical inputs
fx_outputs = fx_model(**fx_inputs, output_hidden_states=True)
pt_outputs = pt_model(**pt_inputs, output_hidden_states=True)

# helper function for our analysis
def assert_almost_equals(a: np.ndarray, b: np.ndarray, tol: float = 1e-2):
    diff = np.abs((a - b)).max()
    if diff < tol:
        print(f"✅ Difference between Flax and PyTorch is {diff} (< {tol})")
    else:
        print(f"❌ Difference between Flax and PyTorch is {diff} (>= {tol})")


print("--------------------------Checking hidden states match--------------------------")
for fx_state, pt_state in zip(fx_outputs.hidden_states, pt_outputs.hidden_states):
    assert fx_state.shape == pt_state.shape
    assert_almost_equals(fx_state, pt_state.detach().numpy())


print("--------------------------Checking last hidden states match--------------------------")
print(f"Encoder-decoder output shape: {fx_outputs.last_hidden_state.shape}, encoder-only output shape: {pt_outputs.last_hidden_state.shape}")
assert_almost_equals(fx_outputs.last_hidden_state, pt_outputs.last_hidden_state.detach().numpy())

Output prior to fix:

--------------------------Checking encoder hidden states match--------------------------
❌ Difference between Flax and PyTorch is 0.43152332305908203 (>= 0.01)
❌ Difference between Flax and PyTorch is 0.43074753880500793 (>= 0.01)
❌ Difference between Flax and PyTorch is 0.42613524198532104 (>= 0.01)
❌ Difference between Flax and PyTorch is 0.4301084578037262 (>= 0.01)
❌ Difference between Flax and PyTorch is 4.519614219665527 (>= 0.01)
--------------------------Checking encoder last hidden states match--------------------------
Encoder-decoder output shape: (2, 188, 16), encoder-only output shape: torch.Size([2, 188, 16])
✅ Difference between Flax and PyTorch is 0.0015139428433030844 (< 0.01)

Output following fix:

--------------------------Checking encoder hidden states match--------------------------
✅ Difference between Flax and PyTorch is 3.9674341678619385e-07 (< 0.01)
✅ Difference between Flax and PyTorch is 4.041939973831177e-07 (< 0.01)
✅ Difference between Flax and PyTorch is 4.041939973831177e-07 (< 0.01)
✅ Difference between Flax and PyTorch is 3.948807716369629e-07 (< 0.01)
✅ Difference between Flax and PyTorch is 4.947185516357422e-06 (< 0.01)
--------------------------Checking encoder last hidden states match--------------------------
Encoder-decoder output shape: (2, 188, 16), encoder-only output shape: torch.Size([2, 188, 16])
✅ Difference between Flax and PyTorch is 1.0913936421275139e-09 (< 0.01)

@sanchit-gandhi sanchit-gandhi merged commit b8f5570 into main Apr 14, 2022