diff --git a/src/transformers/configuration_utils.py b/src/transformers/configuration_utils.py index f7318bf8ab84bb..9b8b6e70fd28ed 100755 --- a/src/transformers/configuration_utils.py +++ b/src/transformers/configuration_utils.py @@ -304,7 +304,12 @@ def __init__(self, **kwargs): self.id2label = kwargs.pop("id2label", None) self.label2id = kwargs.pop("label2id", None) if self.id2label is not None: - kwargs.pop("num_labels", None) + num_labels = kwargs.pop("num_labels", None) + if num_labels is not None and len(self.id2label) != num_labels: + logger.warn( + f"You passed along `num_labels={num_labels}` with an incompatible id to label map: " + f"{self.id2label}. The number of labels wil be overwritten to {self.num_labels}." + ) self.id2label = dict((int(key), value) for key, value in self.id2label.items()) # Keys are always strings in JSON so convert ids to int here. else: @@ -678,6 +683,15 @@ def from_dict(cls, config_dict: Dict[str, Any], **kwargs) -> "PretrainedConfig": config.pruned_heads = dict((int(key), value) for key, value in config.pruned_heads.items()) # Update config with kwargs if needed + if "num_labels" in kwargs and "id2label" in kwargs: + num_labels = kwargs["num_labels"] + id2label = kwargs["id2label"] if kwargs["id2label"] is not None else [] + if len(id2label) != num_labels: + raise ValueError( + f"You passed along `num_labels={num_labels }` with an incompatible id to label map: " + f"{kwargs['id2label']}. Since those arguments are inconsistent with each other, you should remove " + "one of them." + ) to_remove = [] for key, value in kwargs.items(): if hasattr(config, key):