diff --git a/models/modeling_flax_speech_encoder_decoder.py b/models/modeling_flax_speech_encoder_decoder.py index 0aa30f6..f546c5e 100644 --- a/models/modeling_flax_speech_encoder_decoder.py +++ b/models/modeling_flax_speech_encoder_decoder.py @@ -223,11 +223,15 @@ def setup(self): else: self.enc_to_dec_proj = None - def _get_feat_extract_output_lengths(self, input_lengths: Union[jnp.ndarray, int]): + def _get_feat_extract_output_lengths( + self, input_lengths: Union[jnp.ndarray, int], add_adapter: Optional[bool] = None + ): """ Computes the output length of the convolutional layers """ + add_adapter = self.config.encoder.add_adapter if add_adapter is None else add_adapter + def _conv_out_length(input_length, kernel_size, stride): # 1D convolutional layer output length formula taken # from https://pytorch.org/docs/stable/generated/torch.nn.Conv1d.html @@ -236,6 +240,10 @@ def _conv_out_length(input_length, kernel_size, stride): for kernel_size, stride in zip(self.config.encoder.conv_kernel, self.config.encoder.conv_stride): input_lengths = _conv_out_length(input_lengths, kernel_size, stride) + if add_adapter: + for _ in range(self.config.encoder.num_adapter_layers): + input_lengths = _conv_out_length(input_lengths, 1, self.config.encoder.adapter_stride) + return input_lengths def _get_encoder_module(self): @@ -429,8 +437,10 @@ def _decoder_forward(module, decoder_input_ids, decoder_attention_mask, decoder_ ) return unfreeze(init_variables["cache"]) - def _get_feat_extract_output_lengths(self, input_lengths: Union[jnp.ndarray, int]): - return self.module._get_feat_extract_output_lengths(input_lengths) + def _get_feat_extract_output_lengths( + self, input_lengths: Union[jnp.ndarray, int], add_adapter: Optional[bool] = None + ): + return self.module._get_feat_extract_output_lengths(input_lengths, add_adapter=add_adapter) @add_start_docstrings(SPEECH_ENCODER_DECODER_ENCODE_INPUTS_DOCSTRING) @replace_return_docstrings(output_type=FlaxBaseModelOutput, config_class=_CONFIG_FOR_DOC)