Skip to content

Commit

Permalink
Assign transformers model as dpr_encoder...bert_model
Browse files Browse the repository at this point in the history
  • Loading branch information
Timoeller committed Jun 2, 2021
1 parent 1a36345 commit 399f502
Showing 1 changed file with 2 additions and 2 deletions.
4 changes: 2 additions & 2 deletions farm/modeling/language_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -1474,7 +1474,7 @@ def load(cls, pretrained_model_name_or_path, language=None, **kwargs):
original_config_dict.update(kwargs)
dpr_question_encoder.model = transformers.DPRQuestionEncoder(config=transformers.DPRConfig(**original_config_dict))
language_model_class = cls.get_language_model_class(farm_lm_config)
dpr_question_encoder.model.base_model.bert_model = cls.subclasses[language_model_class].load(str(pretrained_model_name_or_path))
dpr_question_encoder.model.base_model.bert_model = cls.subclasses[language_model_class].load(str(pretrained_model_name_or_path)).model
dpr_question_encoder.language = dpr_question_encoder.model.config.language
else:
original_model_config = AutoConfig.from_pretrained(pretrained_model_name_or_path)
Expand Down Expand Up @@ -1589,7 +1589,7 @@ def load(cls, pretrained_model_name_or_path, language=None, **kwargs):
dpr_context_encoder.model = transformers.DPRContextEncoder(config=transformers.DPRConfig(**original_config_dict))
language_model_class = cls.get_language_model_class(farm_lm_config)
dpr_context_encoder.model.base_model.bert_model = cls.subclasses[language_model_class].load(
str(pretrained_model_name_or_path))
str(pretrained_model_name_or_path)).model
dpr_context_encoder.language = dpr_context_encoder.model.config.language

else:
Expand Down

0 comments on commit 399f502

Please # to comment.