Skip to content

Commit

Permalink
Add defensive check for config num_labels and id2label (huggingface#1…
Browse files Browse the repository at this point in the history
…6709)

* Add defensive check for config num_labels and id2label

* Actually check value...

* Only warning inside init plus better error message
  • Loading branch information
sgugger authored and elusenji committed Jun 12, 2022
1 parent 30d1f23 commit e073d2c
Showing 1 changed file with 15 additions and 1 deletion.
16 changes: 15 additions & 1 deletion src/transformers/configuration_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -304,7 +304,12 @@ def __init__(self, **kwargs):
self.id2label = kwargs.pop("id2label", None)
self.label2id = kwargs.pop("label2id", None)
if self.id2label is not None:
kwargs.pop("num_labels", None)
num_labels = kwargs.pop("num_labels", None)
if num_labels is not None and len(self.id2label) != num_labels:
logger.warn(
f"You passed along `num_labels={num_labels}` with an incompatible id to label map: "
f"{self.id2label}. The number of labels wil be overwritten to {self.num_labels}."
)
self.id2label = dict((int(key), value) for key, value in self.id2label.items())
# Keys are always strings in JSON so convert ids to int here.
else:
Expand Down Expand Up @@ -678,6 +683,15 @@ def from_dict(cls, config_dict: Dict[str, Any], **kwargs) -> "PretrainedConfig":
config.pruned_heads = dict((int(key), value) for key, value in config.pruned_heads.items())

# Update config with kwargs if needed
if "num_labels" in kwargs and "id2label" in kwargs:
num_labels = kwargs["num_labels"]
id2label = kwargs["id2label"] if kwargs["id2label"] is not None else []
if len(id2label) != num_labels:
raise ValueError(
f"You passed along `num_labels={num_labels }` with an incompatible id to label map: "
f"{kwargs['id2label']}. Since those arguments are inconsistent with each other, you should remove "
"one of them."
)
to_remove = []
for key, value in kwargs.items():
if hasattr(config, key):
Expand Down

0 comments on commit e073d2c

Please # to comment.