Skip to content

Commit

Permalink
Merge pull request #8 from sanchit-gandhi/flax-train-script
Browse files Browse the repository at this point in the history
[FlaxSpeechEncoderDecoder] Fix input shape bug in weights init
  • Loading branch information
sanchit-gandhi authored Apr 14, 2022
2 parents 564e72a + 8f37a7a commit db3fe85
Showing 1 changed file with 13 additions and 3 deletions.
16 changes: 13 additions & 3 deletions models/modeling_flax_speech_encoder_decoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -223,11 +223,15 @@ def setup(self):
else:
self.enc_to_dec_proj = None

def _get_feat_extract_output_lengths(self, input_lengths: Union[jnp.ndarray, int]):
def _get_feat_extract_output_lengths(
self, input_lengths: Union[jnp.ndarray, int], add_adapter: Optional[bool] = None
):
"""
Computes the output length of the convolutional layers
"""

add_adapter = self.config.encoder.add_adapter if add_adapter is None else add_adapter

def _conv_out_length(input_length, kernel_size, stride):
# 1D convolutional layer output length formula taken
# from https://pytorch.org/docs/stable/generated/torch.nn.Conv1d.html
Expand All @@ -236,6 +240,10 @@ def _conv_out_length(input_length, kernel_size, stride):
for kernel_size, stride in zip(self.config.encoder.conv_kernel, self.config.encoder.conv_stride):
input_lengths = _conv_out_length(input_lengths, kernel_size, stride)

if add_adapter:
for _ in range(self.config.encoder.num_adapter_layers):
input_lengths = _conv_out_length(input_lengths, 1, self.config.encoder.adapter_stride)

return input_lengths

def _get_encoder_module(self):
Expand Down Expand Up @@ -429,8 +437,10 @@ def _decoder_forward(module, decoder_input_ids, decoder_attention_mask, decoder_
)
return unfreeze(init_variables["cache"])

def _get_feat_extract_output_lengths(self, input_lengths: Union[jnp.ndarray, int]):
return self.module._get_feat_extract_output_lengths(input_lengths)
def _get_feat_extract_output_lengths(
self, input_lengths: Union[jnp.ndarray, int], add_adapter: Optional[bool] = None
):
return self.module._get_feat_extract_output_lengths(input_lengths, add_adapter=add_adapter)

@add_start_docstrings(SPEECH_ENCODER_DECODER_ENCODE_INPUTS_DOCSTRING)
@replace_return_docstrings(output_type=FlaxBaseModelOutput, config_class=_CONFIG_FOR_DOC)
Expand Down

0 comments on commit db3fe85

Please # to comment.