From 62cb855efe8390806448dd2e1e39ee6b1c0ca6f1 Mon Sep 17 00:00:00 2001
From: Francois Ledoyen
Date: Mon, 24 Feb 2025 10:59:35 +0100
Subject: [PATCH] fix: generate encoder decoder model

---
 src/adapters/model_mixin.py | 7 +++++--
 1 file changed, 5 insertions(+), 2 deletions(-)

diff --git a/src/adapters/model_mixin.py b/src/adapters/model_mixin.py
index 01571d358..70369678b 100644
--- a/src/adapters/model_mixin.py
+++ b/src/adapters/model_mixin.py
@@ -1730,11 +1730,12 @@ def _prepare_encoder_decoder_kwargs_for_generation(
         }
         encoder_signature = set(inspect.signature(encoder.forward).parameters)
         encoder_accepts_wildcard = "kwargs" in encoder_signature or "model_kwargs" in encoder_signature
+        forward_context_args = ["adapter_input_parallelized", "task_ids"]
         if not encoder_accepts_wildcard:
             encoder_kwargs = {
                 argument: value
                 for argument, value in encoder_kwargs.items()
-                if argument in encoder_signature or argument == "adapter_input_parallelized"
+                if argument in encoder_signature or argument in forward_context_args
             }
         encoder_kwargs["output_attentions"] = generation_config.output_attentions
         encoder_kwargs["output_hidden_states"] = generation_config.output_hidden_states
@@ -1744,7 +1745,9 @@ def _prepare_encoder_decoder_kwargs_for_generation(
         encoder_kwargs["return_dict"] = True
         encoder_kwargs[model_input_name] = inputs_tensor
         with ForwardContext(self, **encoder_kwargs):
-            encoder_kwargs.pop("adapter_input_parallelized", None)  # This should not be passed to actual model
+            for arg_name in forward_context_args:
+                encoder_kwargs.pop(arg_name, None)  # This should not be passed to actual model
+
             model_kwargs["encoder_outputs"]: ModelOutput = encoder(**encoder_kwargs)

         return model_kwargs
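
Note (not part of the patch): below is a minimal, self-contained sketch of the
kwarg-handling pattern this change generalizes, where context-only arguments
such as "task_ids" must survive the signature filter so the forward context can
read them, yet never reach the actual encoder call. The helper call_encoder and
the toy encoder are hypothetical stand-ins for illustration, not the adapters
API; in the real code the context manager is adapters' ForwardContext.

import inspect

# Hypothetical stand-in list for illustration; mirrors forward_context_args in the patch.
FORWARD_CONTEXT_ARGS = ["adapter_input_parallelized", "task_ids"]

def call_encoder(encoder_forward, **encoder_kwargs):
    signature = set(inspect.signature(encoder_forward).parameters)
    accepts_wildcard = "kwargs" in signature or "model_kwargs" in signature
    if not accepts_wildcard:
        # Keep arguments the encoder accepts, plus the context-only arguments,
        # so the forward context (ForwardContext in the real code) can see them.
        encoder_kwargs = {
            name: value
            for name, value in encoder_kwargs.items()
            if name in signature or name in FORWARD_CONTEXT_ARGS
        }
    # Context-only arguments must be removed before the actual model call.
    for arg_name in FORWARD_CONTEXT_ARGS:
        encoder_kwargs.pop(arg_name, None)
    return encoder_forward(**encoder_kwargs)

def toy_encoder(input_ids, output_attentions=False):
    return {"input_ids": input_ids, "output_attentions": output_attentions}

# task_ids survives the signature filter but never reaches toy_encoder.
print(call_encoder(toy_encoder, input_ids=[1, 2, 3], task_ids=0))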