Skip to content

Commit

Permalink
ENH Add update metadata to repocard (#844)
Browse files Browse the repository at this point in the history
* add `metadata_update` function

* add tests

* add docstring

* Apply suggestions from code review

Co-authored-by: Adrin Jalali <adrin.jalali@gmail.com>

* refactor `_update_metadata_model_index`

* Apply suggestions from code review

Co-authored-by: Adrin Jalali <adrin.jalali@gmail.com>

* fix style and imports

* switch to deepcopy everywhere

* load repo in repocard test into tmp folder

* simplify results and metrics checks when updating metadata

* run black

* Apply suggestions from code review

Co-authored-by: Omar Sanseviero <osanseviero@gmail.com>

* fix pyyaml version to work with `sort_keys` kwarg

* don't allow empty commits if file hasn't changed

* switch order of updates to first check model-index for easier readability

* expose repocard functions through `__init__`

* fix init

* make style & quality

* revert to for-loop

* Apply suggestions from code review

Co-authored-by: Julien Chaumond <julien@huggingface.co>
Co-authored-by: Lysandre Debut <lysandre.debut@reseau.eseo.fr>

* post suggestion fixes

* add example

* add type to list

Co-authored-by: Adrin Jalali <adrin.jalali@gmail.com>
Co-authored-by: Omar Sanseviero <osanseviero@gmail.com>
Co-authored-by: Julien Chaumond <julien@huggingface.co>
Co-authored-by: Lysandre Debut <lysandre.debut@reseau.eseo.fr>
  • Loading branch information
5 people authored May 9, 2022
1 parent da25c67 commit f6343cb
Show file tree
Hide file tree
Showing 4 changed files with 361 additions and 3 deletions.
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ def get_version() -> str:
"filelock",
"requests",
"tqdm",
"pyyaml",
"pyyaml>=5.1",
"typing-extensions>=3.7.4.3", # to be able to import TypeAlias
"importlib_metadata;python_version<'3.8'",
"packaging>=20.9",
Expand Down
6 changes: 6 additions & 0 deletions src/huggingface_hub/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,12 @@
push_to_hub_keras,
save_pretrained_keras,
)
from .repocard import (
metadata_eval_result,
metadata_load,
metadata_save,
metadata_update,
)
from .repository import Repository
from .snapshot_download import snapshot_download
from .utils import logging
Expand Down
182 changes: 182 additions & 0 deletions src/huggingface_hub/repocard.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@
from typing import Any, Dict, Optional, Union

import yaml
from huggingface_hub.file_download import hf_hub_download
from huggingface_hub.hf_api import HfApi
from huggingface_hub.repocard_types import (
ModelIndex,
SingleMetric,
Expand All @@ -13,10 +15,15 @@
SingleResultTask,
)

from .constants import REPOCARD_NAME


# exact same regex as in the Hub server. Please keep in sync.
# Captures (group 1) the text enclosed between two `---` fence lines.
REGEX_YAML_BLOCK = re.compile(r"---[\n\r]+([\S\s]*?)[\n\r]+---[\n\r]")

# Keys compared to decide whether two model-index results are the same entry.
UNIQUE_RESULT_FEATURES = ["dataset", "task"]
# Keys compared to decide whether two metrics of a result are the same entry.
UNIQUE_METRIC_FEATURES = ["name", "type"]


def metadata_load(local_path: Union[str, Path]) -> Optional[Dict]:
content = Path(local_path).read_text()
Expand Down Expand Up @@ -99,3 +106,178 @@ def metadata_eval_result(
model_index, dict_factory=lambda x: {k: v for (k, v) in x if v is not None}
)
return {"model-index": [data]}


def metadata_update(
    repo_id: str,
    metadata: Dict,
    *,
    repo_type: Optional[str] = None,
    overwrite: bool = False,
    token: Optional[str] = None,
) -> str:
    """
    Updates the metadata in the README.md of a repository on the Hugging Face Hub.

    The current card is downloaded, the entries of `metadata` are merged into
    its metadata, and the resulting file is uploaded back to the repository.

    Example:
        >>> from huggingface_hub import metadata_update
        >>> metadata = {'model-index': [{'name': 'RoBERTa fine-tuned on ReactionGIF',
        ...             'results': [{'dataset': {'name': 'ReactionGIF',
        ...                                      'type': 'julien-c/reactiongif'},
        ...                           'metrics': [{'name': 'Recall',
        ...                                        'type': 'recall',
        ...                                        'value': 0.7762102282047272}],
        ...                          'task': {'name': 'Text Classification',
        ...                                   'type': 'text-classification'}}]}]}
        >>> metadata_update("julien-c/reactiongif-roberta", metadata)

    Args:
        repo_id (`str`):
            The name of the repository.
        metadata (`dict`):
            A dictionary containing the metadata to be updated.
        repo_type (`str`, *optional*):
            Set to `"dataset"` or `"space"` if updating to a dataset or space,
            `None` or `"model"` if updating to a model. Default is `None`.
        overwrite (`bool`, *optional*, defaults to `False`):
            If set to `True` an existing field can be overwritten, otherwise
            attempting to overwrite an existing field will cause an error.
        token (`str`, *optional*):
            The Hugging Face authentication token.

    Returns:
        `str`: URL of the commit which updated the card metadata.

    Raises:
        `ValueError`: If a field or metric that already exists would receive a
            different value while `overwrite` is `False`.
    """
    # Always fetch a fresh copy so the merge is based on the current card.
    filepath = hf_hub_download(
        repo_id,
        filename=REPOCARD_NAME,
        repo_type=repo_type,
        use_auth_token=token,
        force_download=True,
    )
    # `metadata_load` is annotated as returning `Optional[Dict]`; fall back to
    # an empty dict so the membership checks below cannot fail on `None`.
    existing_metadata = metadata_load(filepath) or {}

    for key, new_value in metadata.items():
        # update model index containing the evaluation results
        if key == "model-index":
            if "model-index" not in existing_metadata:
                existing_metadata["model-index"] = new_value
            else:
                # the model-index contains a list of results as used by PwC but
                # only has one element, thus we take the first one
                existing_metadata["model-index"][0][
                    "results"
                ] = _update_metadata_model_index(
                    existing_metadata["model-index"][0]["results"],
                    new_value[0]["results"],
                    overwrite=overwrite,
                )
        # update all fields except model index
        elif key in existing_metadata and not overwrite:
            # Re-sending an identical value is a no-op; only a change is an error.
            if existing_metadata[key] != new_value:
                raise ValueError(
                    f"""You passed a new value for the existing meta data field '{key}'. Set `overwrite=True` to overwrite existing metadata."""
                )
        else:
            existing_metadata[key] = new_value

    # save and push to hub
    metadata_save(filepath, existing_metadata)

    return HfApi().upload_file(
        path_or_fileobj=filepath,
        path_in_repo=REPOCARD_NAME,
        repo_id=repo_id,
        repo_type=repo_type,
        identical_ok=False,
        token=token,
    )


def _update_metadata_model_index(existing_results, new_results, overwrite=False):
    """
    Updates the model-index fields in the metadata. If results with same unique
    features exist they are updated, else a new result is appended. Updating existing
    values is only possible if `overwrite=True`.

    Args:
        existing_results (`List[dict]`):
            List of existing metadata results. This list is modified in place.
        new_results (`List[dict]`):
            List of new metadata results.
        overwrite (`bool`, *optional*, defaults to `False`):
            If set to `True`, an existing metric values can be overwritten, otherwise
            attempting to overwrite an existing field will cause an error.

    Returns:
        `list`: List of updated metadata results (the same object as
        `existing_results`).
    """
    for new_result in new_results:
        result_found = False
        for existing_result in existing_results:
            # Two results describe the same evaluation when every identifying
            # feature listed in UNIQUE_RESULT_FEATURES is equal.
            if all(
                new_result[feat] == existing_result[feat]
                for feat in UNIQUE_RESULT_FEATURES
            ):
                result_found = True
                # `existing_result` is the element stored in the list, so
                # assigning through it updates `existing_results` directly.
                existing_result["metrics"] = _update_metadata_results_metric(
                    new_result["metrics"],
                    existing_result["metrics"],
                    overwrite=overwrite,
                )
        if not result_found:
            existing_results.append(new_result)
    return existing_results


def _update_metadata_results_metric(new_metrics, existing_metrics, overwrite=False):
    """
    Merges `new_metrics` into `existing_metrics` for a single result.

    A new metric whose unique features equal those of an existing metric
    replaces that metric's value; a new metric without a counterpart is
    appended. Replacing a differing value requires `overwrite=True`.

    Args:
        new_metrics (`list`):
            Metrics to merge in.
        existing_metrics (`list`):
            Metrics already present; this list is modified in place.
        overwrite (`bool`, *optional*, defaults to `False`):
            Whether a differing value of an already present metric may be
            replaced. When `False`, such a conflict raises an error.

    Returns:
        `list`: The merged metrics (the same object as `existing_metrics`).
    """
    for incoming in new_metrics:
        # Lazy on purpose: candidates are examined one at a time, in order.
        same_identity = (
            current
            for current in existing_metrics
            if all(
                incoming[feat] == current[feat] for feat in UNIQUE_METRIC_FEATURES
            )
        )
        seen = False
        for current in same_identity:
            if overwrite:
                current["value"] = incoming["value"]
            elif current["value"] != incoming["value"]:
                # an existing metric with a different value needs the overwrite flag
                identity = ", ".join(
                    f"{feat}: {incoming[feat]}" for feat in UNIQUE_METRIC_FEATURES
                )
                raise ValueError(
                    "You passed a new value for the existing metric"
                    f" '{identity}'. Set `overwrite=True` to overwrite"
                    " existing metrics."
                )
            seen = True
        if not seen:
            existing_metrics.append(incoming)
    return existing_metrics
Loading

0 comments on commit f6343cb

Please # to comment.