diff --git a/farm/modeling/language_model.py b/farm/modeling/language_model.py index 6b6564ad..eeea6d0e 100644 --- a/farm/modeling/language_model.py +++ b/farm/modeling/language_model.py @@ -1474,7 +1474,7 @@ def load(cls, pretrained_model_name_or_path, language=None, **kwargs): original_config_dict.update(kwargs) dpr_question_encoder.model = transformers.DPRQuestionEncoder(config=transformers.DPRConfig(**original_config_dict)) language_model_class = cls.get_language_model_class(farm_lm_config) - dpr_question_encoder.model.base_model.bert_model = cls.subclasses[language_model_class].load(str(pretrained_model_name_or_path)) + dpr_question_encoder.model.base_model.bert_model = cls.subclasses[language_model_class].load(str(pretrained_model_name_or_path)).model dpr_question_encoder.language = dpr_question_encoder.model.config.language else: original_model_config = AutoConfig.from_pretrained(pretrained_model_name_or_path) @@ -1589,7 +1589,7 @@ def load(cls, pretrained_model_name_or_path, language=None, **kwargs): dpr_context_encoder.model = transformers.DPRContextEncoder(config=transformers.DPRConfig(**original_config_dict)) language_model_class = cls.get_language_model_class(farm_lm_config) dpr_context_encoder.model.base_model.bert_model = cls.subclasses[language_model_class].load( - str(pretrained_model_name_or_path)) + str(pretrained_model_name_or_path)).model dpr_context_encoder.language = dpr_context_encoder.model.config.language else: