From e073d2c80a3f61233a427cc8e6eef9e505e8351e Mon Sep 17 00:00:00 2001 From: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> Date: Wed, 13 Apr 2022 11:28:19 -0400 Subject: [PATCH] Add defensive check for config num_labels and id2label (#16709) * Add defensive check for config num_labels and id2label * Actually check value... * Only warning inside init plus better error message --- src/transformers/configuration_utils.py | 16 +++++++++++++++- 1 file changed, 15 insertions(+), 1 deletion(-) diff --git a/src/transformers/configuration_utils.py b/src/transformers/configuration_utils.py index f7318bf8ab84bb..9b8b6e70fd28ed 100755 --- a/src/transformers/configuration_utils.py +++ b/src/transformers/configuration_utils.py @@ -304,7 +304,12 @@ def __init__(self, **kwargs): self.id2label = kwargs.pop("id2label", None) self.label2id = kwargs.pop("label2id", None) if self.id2label is not None: - kwargs.pop("num_labels", None) + num_labels = kwargs.pop("num_labels", None) + if num_labels is not None and len(self.id2label) != num_labels: + logger.warn( + f"You passed along `num_labels={num_labels}` with an incompatible id to label map: " + f"{self.id2label}. The number of labels wil be overwritten to {self.num_labels}." + ) self.id2label = dict((int(key), value) for key, value in self.id2label.items()) # Keys are always strings in JSON so convert ids to int here. else: @@ -678,6 +683,15 @@ def from_dict(cls, config_dict: Dict[str, Any], **kwargs) -> "PretrainedConfig": config.pruned_heads = dict((int(key), value) for key, value in config.pruned_heads.items()) # Update config with kwargs if needed + if "num_labels" in kwargs and "id2label" in kwargs: + num_labels = kwargs["num_labels"] + id2label = kwargs["id2label"] if kwargs["id2label"] is not None else [] + if len(id2label) != num_labels: + raise ValueError( + f"You passed along `num_labels={num_labels }` with an incompatible id to label map: " + f"{kwargs['id2label']}. Since those arguments are inconsistent with each other, you should remove " + "one of them." + ) to_remove = [] for key, value in kwargs.items(): if hasattr(config, key):