diff --git a/docs/hub/endpoints.md b/docs/hub/endpoints.md index 3690330d2a..ab5d23186d 100644 --- a/docs/hub/endpoints.md +++ b/docs/hub/endpoints.md @@ -14,15 +14,17 @@ We have open endpoints that you can use to retrieve information from the Hub as |------------------------------------------------------------------------------|------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|----------------------------------------|-----------------------------------------------------------------------------------------------------------------------------------------------------------------------|-----------------------------------------------------| | /api/models GET | Get information from all models in the Hub. You can specify additional parameters to have more specific results. - `search`: Filter based on substrings for repos and their usernames, such as `resnet` or `microsoft` - `author`: Filter models by an author or organization, such as `huggingface` or `microsoft` - `filter`: Filter based on tags, such as `text-classification` or `spacy`. - `sort`: Property to use when sorting, such as `downloads` or `author`. - `direction`: Direction in which to sort, such as `-1` for descending, and anything else for ascending. - `limit`: Limit the number of models fetched. 
- `full`: Whether to fetch most model data, such as all tags, the files, etc. - `config`: Whether to also fetch the repo config. | `list_models()` | ```params= { "search":"search", "author":"author", "filter":"filter", "sort":"sort", "direction":"direction", "limit":"limit", "full":"full", "config":"config"}``` | | | /api/models/{repo_id} /api/models/{repo_id}/revision/{revision} GET | Get all information for a specific model. | `model_info(repo_id, revision)` | ```headers = { "authorization" : "Bearer $token" }``` | | +| /api/models-tags-by-type GET | Gets all the available model tags hosted in the Hub | `get_model_tags()` | | | | /api/datasets GET | Get information from all datasets in the Hub. You can specify additional parameters to have more specific results. - `search`: Filter based on substrings for repos and their usernames, such as `pets` or `microsoft` - `author`: Filter datasets by an author or organization, such as `huggingface` or `microsoft` - `filter`: Filter based on tags, such as `task_categories:text-classification` or `languages:en`. - `sort`: Property to use when sorting, such as `downloads` or `author`. - `direction`: Direction in which to sort, such as `-1` for descending, and anything else for ascending. - `limit`: Limit the number of datasets fetched. - `full`: Whether to fetch most dataset data, such as all tags, the files, etc. | `list_datasets()` | ```params= { "search":"search", "author":"author", "filter":"filter", "sort":"sort", "direction":"direction", "limit":"limit", "full":"full", "config":"config"}``` | | | /api/datasets/{repo_id} /api/datasets/{repo_id}/revision/{revision} GET | Get all information for a specific dataset. - `full`: Whether to fetch most dataset data, such as all tags, the files, etc. 
| `dataset_info(repo_id, revision)` | ```headers = { "authorization" : "Bearer $token", "full" : "full" }``` | | +| /api/datasets-tags-by-type GET | Gets all the available dataset tags hosted in the Hub | `get_dataset_tags()` | | | | /api/metrics GET | Get information from all metrics in the Hub. | `list_metrics()` | | | | /api/repos/ls GET ⚠️ deprecated | Get list of all stored files for user or organization. | `list_repos_objs(token, organization)` | ```headers = { "authorization" : "Bearer $token" }``` ```params= { "organization":"organization"}``` | | | /api/repos/create POST | Create a repository. It's a model repo by default. - type: Type of repo (datasets or spaces; model by default). - name: Name of repo. - organization: Name of organization. - - private: Whether the repo is private. | `create_repo()` | ```headers = { authorization : "Bearer $token" }``` ```json= {"type":"type", "name":"name", "organization":"organization", "private":"private"}``` | | | /api/repos/delete DELETE | Delete a repository. It's a model repo by default. - type: Type of repo (datasets or spaces; model by default). - name: Name of repo. - organization: Name of organization. | `delete_repo()` | ```headers = { "authorization" : "Bearer $token" }``` ```json= {"type":"type", "name":"name", "organization":"organization"}``` | | | /api/repos/{type}/{repo_id}/settings PUT | Update repo visibility. | `update_repo_visibility()` | ```headers = { "authorization" : "Bearer $token" }``` ```json= {"private":"private"}``` | | | /api/{type}/{repo_id}/ upload/{revision}/{path_in_repo} POST | Upload a file to a specific repository. | `upload_file()` | ```headers = { "authorization" : "Bearer $token" }``` ```"data"="bytestream"``` | | -| /api/login POST | Login user and obtain authentication token. | `login(username, password)` | ```json = { "username" : "username", "password": "password" }``` +| /api/login POST | Login user and obtain authentication token. 
| `login(username, password)` | ```json = { "username" : "username", "password": "password" }``` | /api/whoami GET | Get username and organizations the user belongs to. | `whoami(token)` | ```headers = { "authorization" : "Bearer $token" }``` | | | /api/logout POST | Log out user. | `logout(token)` | ```headers = { "authorization" : "Bearer $token" }``` | | diff --git a/src/huggingface_hub/README.md b/src/huggingface_hub/README.md index 7fa5a3706a..050b553f91 100644 --- a/src/huggingface_hub/README.md +++ b/src/huggingface_hub/README.md @@ -121,9 +121,11 @@ With the `HfApi` class there are methods to query models, datasets, and metrics - **Models**: - `list_models()` - `model_info()` + - `get_model_tags()` - **Datasets**: - `list_datasets()` - `dataset_info()` + - `get_dataset_tags()` These lightly wrap around the API Endpoints. Documentation for valid parameters and descriptions can be found [here](https://huggingface.co/docs/hub/endpoints). diff --git a/src/huggingface_hub/__init__.py b/src/huggingface_hub/__init__.py index a019bbe310..b030311a17 100644 --- a/src/huggingface_hub/__init__.py +++ b/src/huggingface_hub/__init__.py @@ -38,7 +38,9 @@ dataset_info, delete_file, delete_repo, + get_dataset_tags, get_full_repo_name, + get_model_tags, list_datasets, list_metrics, list_models, diff --git a/src/huggingface_hub/hf_api.py b/src/huggingface_hub/hf_api.py index 72e63b23d7..158693e7f7 100644 --- a/src/huggingface_hub/hf_api.py +++ b/src/huggingface_hub/hf_api.py @@ -32,6 +32,7 @@ REPO_TYPES_URL_PREFIXES, SPACES_SDK_TYPES, ) +from .utils.tags import DatasetTags, ModelTags if sys.version_info >= (3, 8): @@ -417,6 +418,22 @@ def set_access_token(access_token: str): def unset_access_token(): erase_from_credential_store(USERNAME_PLACEHOLDER) + def get_model_tags(self) -> ModelTags: + "Gets all valid model tags as a nested namespace object" + path = f"{self.endpoint}/api/models-tags-by-type" + r = requests.get(path) + r.raise_for_status() + d = r.json() + return 
ModelTags(d) + + def get_dataset_tags(self) -> DatasetTags: + "Gets all valid dataset tags as a nested namespace object" + path = f"{self.endpoint}/api/datasets-tags-by-type" + r = requests.get(path) + r.raise_for_status() + d = r.json() + return DatasetTags(d) + def list_models( self, filter: Union[str, Iterable[str], None] = None, @@ -1176,6 +1193,9 @@ def delete_token(cls): list_metrics = api.list_metrics +get_model_tags = api.get_model_tags +get_dataset_tags = api.get_dataset_tags + create_repo = api.create_repo delete_repo = api.delete_repo update_repo_visibility = api.update_repo_visibility diff --git a/src/huggingface_hub/utils/tags.py b/src/huggingface_hub/utils/tags.py new file mode 100644 index 0000000000..88f3c76ee2 --- /dev/null +++ b/src/huggingface_hub/utils/tags.py @@ -0,0 +1,144 @@ +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+""" +Helpful utility functions and classes in relation to exploring API endpoints +with the aim of providing a user-friendly interface +""" + + +class AttributeDictionary(dict): + """ + `dict` subclass that also provides access to keys as attributes + + If a key starts with a number, it will exist in the dictionary + but not as an attribute + + Example usage: + + >>> d = AttributeDictionary() + >>> d["test"] = "a" + >>> print(d.test) # prints "a" + + """ + + def __getattr__(self, k): + if k in self: + return self[k] + else: + raise AttributeError(k) + + def __setattr__(self, k, v): + (self.__setitem__, super().__setattr__)[k[0] == "_"](k, v) + + def __delattr__(self, k): + if k in self: + del self[k] + else: + raise AttributeError(k) + + def __dir__(self): + keys = sorted(self.keys()) + keys = [key for key in keys if key.replace("_", "").isalpha()] + return super().__dir__() + keys + + def __repr__(self): + repr_str = "Available Attributes or Keys:\n" + for key in sorted(self.keys()): + repr_str += f" * {key}" + if not key.replace("_", "").isalpha(): + repr_str += " (Key only)" + repr_str += "\n" + return repr_str + + +class GeneralTags(AttributeDictionary): + """ + A namespace object holding all tags, filtered by `keys` + If a tag starts with a number, it will only exist in the dictionary + + Example + >>> a.b.1a # will not work + >>> a.b["1a"] # will work + >>> a["b"]["1a"] # will work + + Args: + tag_dictionary (``dict``): + A dictionary of tags returned from the /api/***-tags-by-type api endpoint + keys (``list``): + A list of keys to unpack the `tag_dictionary` with, such as `["library","language"]` + """ + + def __init__(self, tag_dictionary: dict, keys: list = None): + self._tag_dictionary = tag_dictionary + if keys is None: + keys = list(self._tag_dictionary.keys()) + for key in keys: + self._unpack_and_assign_dictionary(key) + + def _unpack_and_assign_dictionary(self, key: str): + "Assigns nested attributes to `self.key` containing information as an
`AttributeDictionary`" + setattr(self, key, AttributeDictionary()) + for item in self._tag_dictionary[key]: + ref = getattr(self, key) + item["label"] = ( + item["label"].replace(" ", "").replace("-", "_").replace(".", "_") + ) + setattr(ref, item["label"], item["id"]) + + +class ModelTags(GeneralTags): + """ + A namespace object holding all available model tags + If a tag starts with a number, it will only exist in the dictionary + + Example + >>> o.dataset.1_5BArabicCorpus # will not work + >>> a.dataset["1_5BArabicCorpus"] # will work + >>> a["dataset"]["1_5BArabicCorpus"] # will work + + Args: + model_tag_dictionary (``dict``): + A dictionary of valid model tags, returned from the /api/models-tags-by-type api endpoint + """ + + def __init__(self, model_tag_dictionary: dict): + keys = ["library", "language", "license", "dataset", "pipeline_tag"] + super().__init__(model_tag_dictionary, keys) + + +class DatasetTags(GeneralTags): + """ + A namespace object holding all available dataset tags + If a tag starts with a number, it will only exist in the dictionary + + Example + >>> o.size_categories.100K>> a.size_categories["100K>> a["size_categories"]["100K 0) + + +class DatasetTagsTest(unittest.TestCase): + @with_production_testing + def test_tags(self): + _api = HfApi() + path = f"{_api.endpoint}/api/datasets-tags-by-type" + r = requests.get(path) + r.raise_for_status() + d = r.json() + o = DatasetTags(d) + for kind in [ + "languages", + "multilinguality", + "language_creators", + "task_categories", + "size_categories", + "benchmark", + "task_ids", + "licenses", + ]: + self.assertTrue(len(getattr(o, kind).keys()) > 0)