Workaround transformers overwriting model_type when saving dpr models (#765)

* Workaround transformers overwriting model_type when saving dpr models

* Added test for saving/loading camembert dpr model

* Setting base_model.bert_model for DPR models loaded in FARM style

* Fix loading dpr model with standard bert

* Removed whitespace from test case question

* Renaming names of model weights when saving dpr models

* Assign transformers model as dpr_encoder...bert_model

* Rename save directory to prevent two tests using same directory

* Fix loading DPR with standard BERT models

* Extending dpr test case to different models

* Adjust names of model weights only if non-standard BERT DPR model

* DPREncoder classes handle renaming of model weights instead of LanguageModel parent class

Co-authored-by: Timo Moeller <timo.moeller@deepset.ai>
julian-risch and Timoeller authored Jun 9, 2021
1 parent 1f1fe4c commit 84f67e0
Showing 2 changed files with 338 additions and 7 deletions.
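A minimal usage sketch of the round trip this change enables, assuming FARM's LanguageModel.load accepts a language_model_class hint; the public facebook checkpoint and the save directory below are illustrative only:

# Sketch only: save and reload a DPR question encoder in FARM format.
# The checkpoint name is a public example; the save directory is made up.
from pathlib import Path
from farm.modeling.language_model import LanguageModel

save_dir = Path("saved_models/dpr_question_encoder")
save_dir.mkdir(parents=True, exist_ok=True)

question_encoder = LanguageModel.load(
    pretrained_model_name_or_path="facebook/dpr-question_encoder-single-nq-base",
    language_model_class="DPRQuestionEncoder",
)
question_encoder.save(save_dir)

# Before this fix, transformers overwrote model_type in the saved config, so
# reloading an encoder with a non-BERT base (e.g. camembert) picked the wrong class.
reloaded = LanguageModel.load(save_dir)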
101 changes: 95 additions & 6 deletions farm/modeling/language_model.py
@@ -275,22 +275,31 @@ def save_config(self, save_dir):
        with open(save_filename, "w") as file:
            setattr(self.model.config, "name", self.__class__.__name__)
            setattr(self.model.config, "language", self.language)
            # For DPR models, transformers overwrites the model_type with the one set in DPRConfig
            # Therefore, we copy the model_type from the model config to DPRConfig
            if self.__class__.__name__ == "DPRQuestionEncoder" or self.__class__.__name__ == "DPRContextEncoder":
                setattr(transformers.DPRConfig, "model_type", self.model.config.model_type)
            string = self.model.config.to_json_string()
            file.write(string)

    def save(self, save_dir):
    def save(self, save_dir, state_dict=None):
        """
        Save the model state_dict and its config file so that it can be loaded again.
        :param save_dir: The directory in which the model should be saved.
        :type save_dir: str
        :param state_dict: A dictionary containing a whole state of the module including names of layers. By default, the unchanged state dict of the module is used
        :type state_dict: dict
        """
        # Save Weights
        save_name = Path(save_dir) / "language_model.bin"
        model_to_save = (
            self.model.module if hasattr(self.model, "module") else self.model
        )  # Only save the model itself
        torch.save(model_to_save.state_dict(), save_name)

        if not state_dict:
            state_dict = model_to_save.state_dict()
        torch.save(state_dict, save_name)
        self.save_config(save_dir)

    @classmethod
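The save_config change above relies on transformers reading model_type from the DPRConfig class rather than the instance when serialising; patching the class attribute before writing the JSON keeps the base model's type in language_model_config.json, so FARM can pick the matching LanguageModel subclass on reload. A rough sketch, with an illustrative camembert value:

# Rough sketch of the workaround, not part of the diff.
import transformers

config = transformers.DPRConfig()
print(config.to_dict()["model_type"])            # "dpr"

# What setattr(transformers.DPRConfig, "model_type", ...) above does:
transformers.DPRConfig.model_type = "camembert"  # illustrative value
print(config.to_dict()["model_type"])            # now "camembert"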
@@ -1442,9 +1451,21 @@ def load(cls, pretrained_model_name_or_path, language=None, **kwargs):
        farm_lm_config = Path(pretrained_model_name_or_path) / "language_model_config.json"
        if os.path.exists(farm_lm_config):
            # FARM style
            dpr_config = transformers.DPRConfig.from_pretrained(farm_lm_config)
            original_model_config = AutoConfig.from_pretrained(farm_lm_config)
            farm_lm_model = Path(pretrained_model_name_or_path) / "language_model.bin"
            dpr_question_encoder.model = transformers.DPRQuestionEncoder.from_pretrained(farm_lm_model, config=dpr_config, **kwargs)

            if original_model_config.model_type == "dpr":
                dpr_config = transformers.DPRConfig.from_pretrained(farm_lm_config)
                dpr_question_encoder.model = transformers.DPRQuestionEncoder.from_pretrained(farm_lm_model, config=dpr_config, **kwargs)
            else:
                if original_model_config.model_type != "bert":
                    logger.warning(f"Using a model of type '{original_model_config.model_type}' which might be incompatible with DPR encoders. "
                                   f"Only Bert-based encoders are supported; they need input_ids, token_type_ids and attention_mask as input tensors.")
                original_config_dict = vars(original_model_config)
                original_config_dict.update(kwargs)
                dpr_question_encoder.model = transformers.DPRQuestionEncoder(config=transformers.DPRConfig(**original_config_dict))
                language_model_class = cls.get_language_model_class(farm_lm_config)
                dpr_question_encoder.model.base_model.bert_model = cls.subclasses[language_model_class].load(str(pretrained_model_name_or_path)).model
            dpr_question_encoder.language = dpr_question_encoder.model.config.language
        else:
            original_model_config = AutoConfig.from_pretrained(pretrained_model_name_or_path)
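A condensed sketch of what the new else branch does when the FARM config describes a non-DPR base model: build a fresh transformers DPRQuestionEncoder from the original config and swap its inner bert_model for the separately loaded base model. Here AutoModel and the camembert-base checkpoint stand in for cls.subclasses[language_model_class].load(...).model:

# Sketch only; mirrors the else branch above.
import transformers
from transformers import AutoConfig, AutoModel

original_model_config = AutoConfig.from_pretrained("camembert-base")  # illustrative base
dpr_question_encoder = transformers.DPRQuestionEncoder(
    config=transformers.DPRConfig(**vars(original_model_config))
)
# Replace the randomly initialised encoder inside the DPR wrapper with the real base model
dpr_question_encoder.base_model.bert_model = AutoModel.from_pretrained("camembert-base")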
@@ -1468,6 +1489,32 @@ def load(cls, pretrained_model_name_or_path, language=None, **kwargs):

        return dpr_question_encoder

    def save(self, save_dir, state_dict=None):
        """
        Save the model state_dict and its config file so that it can be loaded again.
        :param save_dir: The directory in which the model should be saved.
        :type save_dir: str
        :param state_dict: A dictionary containing a whole state of the module including names of layers. By default, the unchanged state dict of the module is used
        :type state_dict: Optional[dict]
        """
        model_to_save = (
            self.model.module if hasattr(self.model, "module") else self.model
        )  # Only save the model itself

        if self.model.config.model_type != "dpr" and model_to_save.base_model_prefix.startswith("question_"):
            state_dict = model_to_save.state_dict()
            keys = state_dict.keys()
            for key in list(keys):
                new_key = key
                if key.startswith("question_encoder.bert_model.model."):
                    new_key = key.split("_encoder.bert_model.model.", 1)[1]
                elif key.startswith("question_encoder.bert_model."):
                    new_key = key.split("_encoder.bert_model.", 1)[1]
                state_dict[new_key] = state_dict.pop(key)

        super(DPRQuestionEncoder, self).save(save_dir=save_dir, state_dict=state_dict)

    def forward(
        self,
        query_input_ids,
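The key renaming in save() is needed because, with a non-DPR base model, the weights are stored under the wrapper's question_encoder.bert_model prefix, while reloading FARM-style goes through the base LanguageModel subclass, which expects bare parameter names. A toy illustration with invented keys:

# Toy example with invented keys; in the real state_dict the values are tensors.
state_dict = {
    "question_encoder.bert_model.embeddings.word_embeddings.weight": "tensor A",
    "question_encoder.bert_model.model.pooler.dense.weight": "tensor B",
}
for key in list(state_dict.keys()):
    new_key = key
    if key.startswith("question_encoder.bert_model.model."):
        new_key = key.split("_encoder.bert_model.model.", 1)[1]
    elif key.startswith("question_encoder.bert_model."):
        new_key = key.split("_encoder.bert_model.", 1)[1]
    state_dict[new_key] = state_dict.pop(key)

print(sorted(state_dict))
# ['embeddings.word_embeddings.weight', 'pooler.dense.weight']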
@@ -1540,12 +1587,28 @@ def load(cls, pretrained_model_name_or_path, language=None, **kwargs):
        dpr_context_encoder.name = pretrained_model_name_or_path
        # We need to differentiate between loading model using FARM format and Pytorch-Transformers format
        farm_lm_config = Path(pretrained_model_name_or_path) / "language_model_config.json"

        if os.path.exists(farm_lm_config):
            # FARM style
            dpr_config = transformers.DPRConfig.from_pretrained(farm_lm_config)
            original_model_config = AutoConfig.from_pretrained(farm_lm_config)
            farm_lm_model = Path(pretrained_model_name_or_path) / "language_model.bin"
            dpr_context_encoder.model = transformers.DPRContextEncoder.from_pretrained(farm_lm_model, config=dpr_config, **kwargs)

            if original_model_config.model_type == "dpr":
                dpr_config = transformers.DPRConfig.from_pretrained(farm_lm_config)
                dpr_context_encoder.model = transformers.DPRContextEncoder.from_pretrained(farm_lm_model, config=dpr_config, **kwargs)
            else:
                if original_model_config.model_type != "bert":
                    logger.warning(f"Using a model of type '{original_model_config.model_type}' which might be incompatible with DPR encoders. "
                                   f"Only Bert-based encoders are supported; they need input_ids, token_type_ids and attention_mask as input tensors.")
                original_config_dict = vars(original_model_config)
                original_config_dict.update(kwargs)
                dpr_context_encoder.model = transformers.DPRContextEncoder(config=transformers.DPRConfig(**original_config_dict))
                language_model_class = cls.get_language_model_class(farm_lm_config)
                dpr_context_encoder.model.base_model.bert_model = cls.subclasses[language_model_class].load(str(pretrained_model_name_or_path)).model
            dpr_context_encoder.language = dpr_context_encoder.model.config.language

        else:
            # Pytorch-transformer Style
            original_model_config = AutoConfig.from_pretrained(pretrained_model_name_or_path)
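DPRContextEncoder mirrors the question encoder's loading logic. For reference, a short sketch of how the two encoders are used together downstream, with plain transformers classes and public checkpoint names purely for illustration: embed a query and a passage, then score them with a dot product.

# Illustration only; not part of the diff.
import torch
from transformers import (DPRContextEncoder, DPRContextEncoderTokenizer,
                          DPRQuestionEncoder, DPRQuestionEncoderTokenizer)

q_tok = DPRQuestionEncoderTokenizer.from_pretrained("facebook/dpr-question_encoder-single-nq-base")
q_enc = DPRQuestionEncoder.from_pretrained("facebook/dpr-question_encoder-single-nq-base")
c_tok = DPRContextEncoderTokenizer.from_pretrained("facebook/dpr-ctx_encoder-single-nq-base")
c_enc = DPRContextEncoder.from_pretrained("facebook/dpr-ctx_encoder-single-nq-base")

q_emb = q_enc(**q_tok("Who wrote Faust?", return_tensors="pt")).pooler_output
p_emb = c_enc(**c_tok("Faust is a tragic play by Johann Wolfgang von Goethe.",
                      return_tensors="pt")).pooler_output
score = q_emb @ p_emb.T  # higher dot product = better query/passage match
print(score)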
@@ -1571,6 +1634,32 @@ def load(cls, pretrained_model_name_or_path, language=None, **kwargs):

        return dpr_context_encoder

    def save(self, save_dir, state_dict=None):
        """
        Save the model state_dict and its config file so that it can be loaded again.
        :param save_dir: The directory in which the model should be saved.
        :type save_dir: str
        :param state_dict: A dictionary containing a whole state of the module including names of layers. By default, the unchanged state dict of the module is used
        :type state_dict: Optional[dict]
        """
        model_to_save = (
            self.model.module if hasattr(self.model, "module") else self.model
        )  # Only save the model itself

        if self.model.config.model_type != "dpr" and model_to_save.base_model_prefix.startswith("ctx_"):
            state_dict = model_to_save.state_dict()
            keys = state_dict.keys()
            for key in list(keys):
                new_key = key
                if key.startswith("ctx_encoder.bert_model.model."):
                    new_key = key.split("_encoder.bert_model.model.", 1)[1]
                elif key.startswith("ctx_encoder.bert_model."):
                    new_key = key.split("_encoder.bert_model.", 1)[1]
                state_dict[new_key] = state_dict.pop(key)

        super(DPRContextEncoder, self).save(save_dir=save_dir, state_dict=state_dict)

    def forward(
        self,
        passage_input_ids,
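After saving, the FARM-style directory should keep the base model's type in language_model_config.json rather than a forced "dpr", which is what lets LanguageModel.load dispatch to the right subclass again. A quick illustrative check against a hypothetical save directory:

# Illustrative check; the path and the expected values depend on the saved model.
import json

with open("saved_models/dpr_question_encoder/language_model_config.json") as f:
    config = json.load(f)

print(config["name"])        # e.g. "DPRQuestionEncoder"
print(config["model_type"])  # e.g. "camembert" for a camembert-based encoder, not "dpr"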