From 589fb9e2e535b4897954d472b2f2ee16fc8f0b17 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Mathias=20M=C3=BCller?= Date: Fri, 4 Mar 2022 10:46:49 +0100 Subject: [PATCH] add BOBSL data set --- .gitignore | 2 +- README.md | 1 + requirements.txt | 2 + setup.py | 5 +- sign_language_datasets/datasets/__init__.py | 1 + .../datasets/bobsl/__init__.py | 3 + .../datasets/bobsl/bobsl.py | 365 ++++++++++++++++++ .../datasets/bobsl/bobsl_test.py | 24 ++ .../datasets/bobsl/checksums.tsv | 3 + .../datasets/bobsl/create_index.py | 121 ++++++ .../TODO-add_fake_data_in_this_directory.txt | 0 .../datasets/bobsl/openpose.poseheader | Bin 0 -> 1919 bytes .../datasets/ngt_corpus/create_index.py | 2 +- .../utils/downloaders/download_auth.py | 151 ++++++++ 14 files changed, 676 insertions(+), 4 deletions(-) create mode 100644 sign_language_datasets/datasets/bobsl/__init__.py create mode 100644 sign_language_datasets/datasets/bobsl/bobsl.py create mode 100644 sign_language_datasets/datasets/bobsl/bobsl_test.py create mode 100644 sign_language_datasets/datasets/bobsl/checksums.tsv create mode 100644 sign_language_datasets/datasets/bobsl/create_index.py create mode 100644 sign_language_datasets/datasets/bobsl/dummy_data/TODO-add_fake_data_in_this_directory.txt create mode 100644 sign_language_datasets/datasets/bobsl/openpose.poseheader create mode 100644 sign_language_datasets/utils/downloaders/download_auth.py diff --git a/.gitignore b/.gitignore index 6b9b62d..09355ac 100644 --- a/.gitignore +++ b/.gitignore @@ -3,7 +3,7 @@ .idea/ old/ __pycache__ - +.coverage sign_language_datasets/datasets/msasl sign_language_datasets/datasets/ncslgr .coverage diff --git a/README.md b/README.md index 5aa54b4..dbe7839 100644 --- a/README.md +++ b/README.md @@ -56,6 +56,7 @@ rwth_phoenix2014_t = tfds.load(name='rwth_phoenix2014_t', builder_kwargs=dict(co | Video-Based CSL | | | None | | RVL-SLLL ASL | | | None | | ngt_corpus | Yes | | 3.0.0 | +| bobsl | Yes | OpenPose | 1.0.0 | ## Data Interface diff --git a/requirements.txt b/requirements.txt index 8538ce4..2ec3464 100644 --- a/requirements.txt +++ b/requirements.txt @@ -7,3 +7,5 @@ numpy pytest pytest-cov pympi-ling +requests +lxml diff --git a/setup.py b/setup.py index a808a7a..aede693 100644 --- a/setup.py +++ b/setup.py @@ -11,13 +11,14 @@ setup( name="sign-language-datasets", packages=packages, - version="0.0.4", + version="0.0.5", description="TFDS Datasets for sign language", author="Amit Moryossef", author_email="amitmoryossef@gmail.com", url="https://github.com/sign-language-processing/datasets", keywords=[], - install_requires=["python-dotenv", "tqdm", "pose-format", "tfds-nightly", "tensorflow", "numpy", "pympi-ling"], + install_requires=["python-dotenv", "tqdm", "pose-format", "tfds-nightly", "tensorflow", "numpy", "pympi-ling", + "requests", "lxml"], tests_require=['pytest', 'pytest-cov'], long_description=long_description, long_description_content_type="text/markdown", diff --git a/sign_language_datasets/datasets/__init__.py b/sign_language_datasets/datasets/__init__.py index e22851c..86290e9 100644 --- a/sign_language_datasets/datasets/__init__.py +++ b/sign_language_datasets/datasets/__init__.py @@ -12,3 +12,4 @@ from .swojs_glossario import SwojsGlossario from .wlasl import Wlasl from .ngt_corpus import NGTCorpus +from .bobsl import Bobsl diff --git a/sign_language_datasets/datasets/bobsl/__init__.py b/sign_language_datasets/datasets/bobsl/__init__.py new file mode 100644 index 0000000..6e3d8eb --- /dev/null +++ b/sign_language_datasets/datasets/bobsl/__init__.py @@ 
-0,0 +1,3 @@ +"""bobsl dataset.""" + +from .bobsl import Bobsl diff --git a/sign_language_datasets/datasets/bobsl/bobsl.py b/sign_language_datasets/datasets/bobsl/bobsl.py new file mode 100644 index 0000000..7cfd304 --- /dev/null +++ b/sign_language_datasets/datasets/bobsl/bobsl.py @@ -0,0 +1,365 @@ +"""bobsl dataset.""" +import os +import json +import requests + +import tensorflow as tf +import tensorflow_datasets as tfds +from tensorflow_datasets.core import download +from tensorflow_datasets.core import utils + +from pose_format.utils.openpose import load_openpose_directory + +from ...datasets.config import SignDatasetConfig +from ...utils.features import PoseFeature +from ...utils.downloaders import download_auth + +_DESCRIPTION = """BOBSL is a large-scale dataset of British Sign Language (BSL).""" + +_CITATION = """@InProceedings{Albanie2021bobsl, + author = "Samuel Albanie and G{\"u}l Varol and Liliane Momeni and Hannah Bull and Triantafyllos Afouras + and Himel Chowdhury and Neil Fox and Bencie Woll and Rob Cooper and Andrew McParland and Andrew Zisserman", + title = "{BOBSL}: {BBC}-{O}xford {B}ritish {S}ign {L}anguage {D}ataset", + howpublished = "\\url{https://www.robots.ox.ac.uk/~vgg/data/bobsl}", + year = "2021", +} +""" + +_HOMEPAGE = "https://www.robots.ox.ac.uk/~vgg/data/bobsl/" + +INDEX_URL = "https://files.ifi.uzh.ch/cl/archiv/2022/easier/bobsl.json" + +_FRAMERATE = 25 + +_VIDEO_RESOLUTION = (444, 444) # (width, height) + +_OPENPOSE_HEADER = os.path.join(os.path.dirname(os.path.realpath(__file__)), "openpose.poseheader") + + +def _add_subtitles_to_index(index_data: dict, filename: str, folder: str, subtitle_alignment_method: str) -> dict: + """ + + :param index_data: + :param filename: + :param folder: + :param subtitle_alignment_method: + :return: + """ + assert ".vtt" in filename + example_id = filename.replace(".vtt", "") + filepath = os.path.join(folder, filename) + + if example_id not in index_data.keys(): + index_data[example_id] = {} + + index_data[example_id]["subtitles"] = filepath + index_data[example_id]["subtitle_alignment_method"] = subtitle_alignment_method + + return index_data + + +def _walk_subtitles_and_add_to_index(index_data: dict, extracted_path: str) -> dict: + """ + + :param index_data: + :param extracted_path: A local path returned by a tfds download manager. + :return: + """ + manually_aligned_folder = os.path.join(extracted_path, "subtitles", "manually-aligned") + audio_aligned_folder = os.path.join(extracted_path, "subtitles", "audio-aligned") + + for filename in os.listdir(manually_aligned_folder): + index_data = _add_subtitles_to_index(index_data=index_data, filename=filename, folder=manually_aligned_folder, + subtitle_alignment_method="manual") + + for filename in os.listdir(audio_aligned_folder): + index_data = _add_subtitles_to_index(index_data=index_data, filename=filename, folder=audio_aligned_folder, + subtitle_alignment_method="audio") + + return index_data + + +def _add_spottings_to_index(index_data: dict, extracted_path: str) -> dict: + """ + Structure of folder: + + spottings + |- attention_spottings.json + |- dict_spottings.json + |- mouthings.json + + Each JSON file has the following structure: + - Outermost keys, json_dict.keys(): ['train', 'public_test', 'val'] + - First 5 keys of json_dict["train"]: ['aachen', 'aardvark', 'aaron', 'ab', 'aba'] + - Keys of json_dict["train"]["aachen"]: ['global_times', 'names', 'probs'] + + `global_times` corresponds to times in seconds in a video. `names` has dataset example ids. 
`probs` has + one probability for each spotting. + + :param index_data: + :param extracted_path: + :return: + """ + paths = {"spottings_attention": os.path.join(extracted_path, "spottings", "attention_spottings.json"), + "spottings_dict": os.path.join(extracted_path, "spottings", "dict_spottings.json"), + "spottings_mouthings": os.path.join(extracted_path, "spottings", "mouthings.json")} + + for spottings_type, spottings_path in paths.items(): + + with open(spottings_path, "r") as infile: + json_dict = json.load(infile) + + for split, gloss_dict in json_dict.items(): + for gloss_key, gloss_value_dict in gloss_dict.items(): + global_times = gloss_value_dict["global_times"] + names = gloss_value_dict["names"] + probs = gloss_value_dict["probs"] + + for global_time, name, prob in zip(global_times, names, probs): + assert name in index_data.keys() + + if spottings_type not in index_data[name].keys(): + index_data[name][spottings_type] = [] + index_data[name][spottings_type].append({"global_time": global_time, + "prob": prob, + "gloss": gloss_key}) + + return index_data + + +def _download_and_maybe_extract(index_data: dict, + dl_manager: download_auth.DownloadManagerWithAuth) -> dict: + """ + + :param index_data: + :param dl_manager: + :return: + """ + urls_to_download = {} + urls_to_download_and_extract = {} + + for datum in index_data.values(): + for url in datum.values(): + if url.endswith(".tar.gz"): + urls_to_download_and_extract[url] = url + else: + urls_to_download[url] = url + + if urls_to_download: + local_paths_downloaded = dl_manager.download(urls_to_download) + else: + local_paths_downloaded = {} + + if urls_to_download_and_extract: + local_paths_downloaded_and_extracted = dl_manager.download_and_extract(urls_to_download_and_extract) + else: + local_paths_downloaded_and_extracted = {} + + local_paths = {**local_paths_downloaded, **local_paths_downloaded_and_extracted} + + return local_paths + + +class Bobsl(tfds.core.GeneratorBasedBuilder): + """DatasetBuilder for bobsl dataset.""" + + VERSION = tfds.core.Version('1.0.0') + RELEASE_NOTES = { + '1.0.0': 'Initial release.', + } + + BUILDER_CONFIGS = [ + SignDatasetConfig(name="default", include_video=True, include_pose="openpose"), + SignDatasetConfig(name="annotations", include_video=False, include_pose=None), + SignDatasetConfig(name="videos", include_video=True, include_pose=None), + SignDatasetConfig(name="openpose", include_video=False, include_pose="openpose"), + ] + + def __init__(self, bobsl_username: str, bobsl_password: str, **kwargs): + + super(Bobsl, self).__init__(**kwargs) + + self.bobsl_username = bobsl_username + self.bobsl_password = bobsl_password + + def _make_download_manager(self, download_dir, download_config): + """Creates a new download manager object.""" + download_dir = ( + download_dir or os.path.join(self._data_dir_root, "downloads")) + extract_dir = ( + download_config.extract_dir or os.path.join(download_dir, "extracted")) + manual_dir = ( + download_config.manual_dir or os.path.join(download_dir, "manual")) + + if download_config.register_checksums: + # Note: Error will be raised here if user try to record checksums + # from a `zipapp` + # noinspection PyTypeChecker + register_checksums_path = utils.to_write_path(self._checksums_path) + else: + register_checksums_path = None + + return download_auth.DownloadManagerWithAuth( + download_dir=download_dir, + extract_dir=extract_dir, + manual_dir=manual_dir, + url_infos=self.url_infos, + manual_dir_instructions=self.MANUAL_DOWNLOAD_INSTRUCTIONS, + 
force_download=(download_config.download_mode == download.GenerateMode.FORCE_REDOWNLOAD), + force_extraction=(download_config.download_mode == download.GenerateMode.FORCE_REDOWNLOAD), + force_checksums_validation=download_config.force_checksums_validation, + register_checksums=download_config.register_checksums, + register_checksums_path=register_checksums_path, + verify_ssl=download_config.verify_ssl, + dataset_name=self.name, + username=self.bobsl_username, + password=self.bobsl_password + ) + + def _info(self) -> tfds.core.DatasetInfo: + """Returns the dataset metadata.""" + + spottings_feature_dict = {"global_time": tf.float32, + "prob": tf.float32, + "gloss": tfds.features.Text()} + + spottings_feature_sequence = tfds.features.Sequence(spottings_feature_dict, length=None) + + features = { + "id": tfds.features.Text(), + "paths": { + "subtitles": tfds.features.Text(), + }, + "subtitle_alignment_method": tfds.features.Text(), + "spottings": {"spottings_attention": spottings_feature_sequence, + "spottings_dict": spottings_feature_sequence, + "spottings_mouthings": spottings_feature_sequence} + } + + # add video features if requested + if self._builder_config.include_video: + features["fps"] = tf.int32 + features["paths"]["video"] = tfds.features.Text() + + if self._builder_config.process_video: + features["video"] = self._builder_config.video_feature(_VIDEO_RESOLUTION) + + # add pose features if requested + if self._builder_config.include_pose == "holistic": + raise NotImplementedError("Holistic poses are currently not available for the BOBSL corpus.") + elif self._builder_config.include_pose == "openpose": + stride = 1 if self._builder_config.fps is None else _FRAMERATE / self._builder_config.fps + pose_shape = (None, 1, 137, 2) + + features["poses"] = PoseFeature(shape=pose_shape, stride=stride, header_path=_OPENPOSE_HEADER) + + return tfds.core.DatasetInfo( + builder=self, + description=_DESCRIPTION, + features=tfds.features.FeaturesDict(features), + homepage=_HOMEPAGE, + supervised_keys=None, + citation=_CITATION, + ) + + def _split_generators(self, dl_manager: download_auth.DownloadManagerWithAuth): + """Returns SplitGenerators.""" + + # download index without dl_manager, since it would try to authenticate + index_content = requests.get(INDEX_URL).content.decode("utf-8") + index_data = json.loads(index_content) + + # save tar.gz urls, then delete from index + subtitles_url = index_data["subtitles"] + spottings_url = index_data["spottings"] + + del index_data["subtitles"] + del index_data["spottings"] + + # Don't download videos if not necessary + if not self._builder_config.include_video: + for datum in index_data.values(): + del datum["video"] + + # Never download flows at the moment + for datum in index_data.values(): + del datum["flow"] + + # Don't download poses if not necessary + if self._builder_config.include_pose != "openpose": + for datum in index_data.values(): + del datum["openpose"] + + # download or download-and-extract, depending on file type + local_paths = _download_and_maybe_extract(index_data=index_data, dl_manager=dl_manager) + + processed_data = {} + + for _id, datum in index_data.items(): + processed_data[_id] = {} + for key, url in datum.items(): + processed_data[_id][key] = local_paths[url] + + # download and extract subtitles, add to local paths + subtitles_extracted_path = dl_manager.download_and_extract(subtitles_url) + + processed_data = _walk_subtitles_and_add_to_index(index_data=processed_data, + extracted_path=subtitles_extracted_path) + + # download 
and extract spottings, add to local paths + spottings_extracted_path = dl_manager.download_and_extract(spottings_url) + + processed_data = _add_spottings_to_index(index_data=processed_data, + extracted_path=spottings_extracted_path) + + one_example = {"5085344787448740525": processed_data["5085344787448740525"]} + + with open("processed_data.json", "w") as outfile: + json.dump(one_example, outfile) + + return [tfds.core.SplitGenerator(name=tfds.Split.TRAIN, gen_kwargs={"data": processed_data})] + + def _generate_examples(self, data): + """ Yields examples. """ + + def _return_dict_value_or_empty_list(datum: dict, dict_key: str) -> list: + """ + + :param datum: + :param dict_key: + :return: + """ + if dict_key in datum.keys(): + return datum[dict_key] + else: + return [] + + for _id, datum in list(data.items()): + features = { + "id": _id, + "paths": {"subtitles": str(datum["subtitles"])}, + "subtitle_alignment_method": datum["subtitle_alignment_method"], + "spottings": {"spottings_attention": _return_dict_value_or_empty_list(datum, "spottings_attention"), + "spottings_dict": _return_dict_value_or_empty_list(datum, "spottings_dict"), + "spottings_mouthings": _return_dict_value_or_empty_list(datum, "spottings_mouthings")} + } + + if self._builder_config.include_video: + + features["fps"] = self._builder_config.fps if self._builder_config.fps is not None else _FRAMERATE + features["paths"]["video"] = datum["video"] + if self._builder_config.process_video: + features["video"] = datum["video"] + + if self._builder_config.include_pose == "openpose": + features["poses"] = load_openpose_directory(directory=datum["openpose"], + fps=_FRAMERATE, + width=_VIDEO_RESOLUTION[0], + height=_VIDEO_RESOLUTION[1], + depth=0, + num_frames=None) + + if self._builder_config.include_pose == "holistic": + raise NotImplementedError("Holistic poses are currently not available for the BOBSL corpus.") + + yield _id, features diff --git a/sign_language_datasets/datasets/bobsl/bobsl_test.py b/sign_language_datasets/datasets/bobsl/bobsl_test.py new file mode 100644 index 0000000..a561619 --- /dev/null +++ b/sign_language_datasets/datasets/bobsl/bobsl_test.py @@ -0,0 +1,24 @@ +"""bobsl dataset.""" + +import tensorflow_datasets as tfds +from . import bobsl + + +class BobslTest(tfds.testing.DatasetBuilderTestCase): + """Tests for bobsl dataset.""" + # TODO(bobsl): + DATASET_CLASS = bobsl.Bobsl + SPLITS = { + 'train': 3, # Number of fake train example + 'test': 1, # Number of fake test example + } + + # If you are calling `download/download_and_extract` with a dict, like: + # dl_manager.download({'some_key': 'http://a.org/out.txt', ...}) + # then the tests needs to provide the fake output paths relative to the + # fake data directory + # DL_EXTRACT_RESULT = {'some_key': 'output_file1.txt', ...} + + +if __name__ == '__main__': + tfds.testing.test_main() diff --git a/sign_language_datasets/datasets/bobsl/checksums.tsv b/sign_language_datasets/datasets/bobsl/checksums.tsv new file mode 100644 index 0000000..8ba0b5e --- /dev/null +++ b/sign_language_datasets/datasets/bobsl/checksums.tsv @@ -0,0 +1,3 @@ +# TODO(bobsl): If your dataset downloads files, then the checksums +# will be automatically added here when running +# `tfds build --register_checksums`. 
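
For reviewers who want to try this PR out, here is a minimal sketch of how the new `bobsl` builder could be loaded, following the same `tfds.load` pattern already used in the README. It is a sketch under assumptions: the config name `annotations-only` and the `BOBSL_USERNAME`/`BOBSL_PASSWORD` environment variables are illustrative only (credentials must be obtained from the BOBSL license holders), and `builder_kwargs` is used to forward those credentials to `Bobsl.__init__`.

```python
# Sketch only: assumes BOBSL credentials are available as environment variables.
import os

import tensorflow_datasets as tfds

# noinspection PyUnresolvedReferences
import sign_language_datasets.datasets  # registers the bobsl builder with TFDS
from sign_language_datasets.datasets.config import SignDatasetConfig

# Illustrative config: annotations only, no videos and no poses.
config = SignDatasetConfig(name="annotations-only", version="1.0.0",
                           include_video=False, include_pose=None)

bobsl = tfds.load(
    name="bobsl",
    builder_kwargs=dict(
        config=config,
        bobsl_username=os.environ["BOBSL_USERNAME"],  # assumed env vars, as in create_index.py
        bobsl_password=os.environ["BOBSL_PASSWORD"],
    ),
)

# The builder only defines a TRAIN split; each example carries subtitle paths,
# the subtitle alignment method, and the three kinds of spottings.
for example in bobsl["train"].take(1):
    print(example["id"], example["subtitle_alignment_method"])
```

The credentials are passed through `builder_kwargs` rather than a config option because the builder overrides `_make_download_manager` and needs them at construction time.
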
diff --git a/sign_language_datasets/datasets/bobsl/create_index.py b/sign_language_datasets/datasets/bobsl/create_index.py new file mode 100644 index 0000000..c3af63a --- /dev/null +++ b/sign_language_datasets/datasets/bobsl/create_index.py @@ -0,0 +1,121 @@ +""" +Helper file to crawl The BOBSL Corpus and create an up-to-date index of the dataset. + +This script will index all individual files for videos, pose and flow. For subtitles and spottings, it will +only indicate a single *.tar.gz file that contains all files for all dataset examples. + +The environment variables BOBSL_USERNAME and BOBSL_PASSWORD must be set when this script is executed: + +BOBSL_USERNAME=??? BOBSL_PASSWORD=??? python -m sign_language_datasets.datasets.bobsl.create_index + +These credentials can be obtained by signing a license agreement with the data owners: +https://www.robots.ox.ac.uk/~vgg/data/bobsl +""" + +import os +import json +import lxml.html + +from io import BytesIO +from typing import Dict + +from ...utils.downloaders import download_auth + + +BOBSL_USERNAME = os.environ["BOBSL_USERNAME"] +BOBSL_PASSWORD = os.environ["BOBSL_PASSWORD"] + +BASE_URL = "https://thor.robots.ox.ac.uk/~vgg/data/bobsl" + + +def get(url: str, decode: bool = True): + """ + + :param url: + :param decode: + :return: + """ + return download_auth.download_with_auth(url=url, username=BOBSL_USERNAME, password=BOBSL_PASSWORD, decode=decode) + + +def parse_sub_index(index: dict, ids_must_exist: bool, base_url_suffix: str, index_key: str, + file_extension: str) -> dict: + """ + + :param index: + :param ids_must_exist: + :param base_url_suffix: + :param index_key: + :param file_extension: + :return: + """ + sub_base_url = BASE_URL + "/" + base_url_suffix + + subpage_content = get(sub_base_url, decode=False) + doc = lxml.html.parse(BytesIO(subpage_content)) + + for link_element in doc.xpath("//a[contains(text(), '%s')]" % file_extension): + example_id = link_element.text.replace(file_extension, "") + example_url = sub_base_url + "/" + link_element.get("href") + + if ids_must_exist: + assert example_id in index.keys() + else: + assert example_id not in index.keys() + index[example_id] = {} + + index[example_id][index_key] = example_url + + return index + + +def create_index() -> Dict[str, Dict[str, str]]: + + index = {} + + # first pass: get example IDs from video subfolder + + index = parse_sub_index(index=index, ids_must_exist=False, base_url_suffix="videos", + index_key="video", file_extension=".mp4") + + # second pass: get URLs from pose subfolder + + index = parse_sub_index(index=index, ids_must_exist=True, base_url_suffix="pose", + index_key="openpose", file_extension=".tar.gz") + + # third pass: get URLs from flow subfolder + + index = parse_sub_index(index=index, ids_must_exist=True, base_url_suffix="flow", + index_key="flow", file_extension=".tar.gz") + + # add subtitles and spottings URLs + + index["subtitles"] = BASE_URL + "/" + "subtitles.tar.gz" + index["spottings"] = BASE_URL + "/" + "spottings.tar.gz" + + return index + + +def create_and_write_index(json_path: str) -> None: + """ + + :param json_path: + :return: + """ + index = create_index() + + # print some examples as a sanity check + print("10 Examples from the index:") + + for item_index, kv_tuple in enumerate(index.items()): + if item_index == 10: + break + print(kv_tuple) + + with open(json_path, "w") as outfile: + print("Writing structured download dict '%s'." 
% json_path) + json.dump(index, outfile) + + +if __name__ == "__main__": + create_and_write_index(json_path="bobsl.json") diff --git a/sign_language_datasets/datasets/bobsl/dummy_data/TODO-add_fake_data_in_this_directory.txt b/sign_language_datasets/datasets/bobsl/dummy_data/TODO-add_fake_data_in_this_directory.txt new file mode 100644 index 0000000..e69de29 diff --git a/sign_language_datasets/datasets/bobsl/openpose.poseheader b/sign_language_datasets/datasets/bobsl/openpose.poseheader new file mode 100644 index 0000000000000000000000000000000000000000..e1efb6a34937014b0fde12155b5639fd13b98281 GIT binary patch literal 1919 zcmeHH+fEZv6y23xsE7!1^^V{Ly!A4@DX492W8)MjLlgCBQt1E_YD)^n@W!X{0r(ZX z_!qj?K0_eEXP?Z<*=uI6p4GF@*>ip#AFm!JicyiMjLU9+&-ZrxgWZ16+u!ptogw-B zr8+0G5`n%3717_=Ns7C^-T&V0_=C8(weHvc5Adf!u(zKS`%JP!$yO;Y&7ku>*o}&d z7!&tH&-XE}+}r8G*w&`JiJ>ihh+Z|!auk1-hdem5cR zspOtXE~SWE0)~Jn>WKkyAc5y$5xIsjhB1LLj+`Ky1TF(3c)kLR0$1@ii!lw%;B5-y zIxs1}F^6+E+Z%t`#k!Pr7wZ)-6&1v^#*D_SMoVK(V_suHBaR2ZFU53L%23)+#!%Lf zWhiGTAD4O!|5FIL7;?#QCKYlz&FHP508h%O+yfan6E3&^-w#>HZ3_D9cRX=IiL&apHyOCz&1 zGD{<~G%`ygbF0XqUoP~^sb9m(s~=eX!0HE9Kd}0N)eo$GU{>jPc~%K34Wru<+I8%{ zD8O!HS4J*$+{n3(yNb>nLx-j$4P<~UU;#NG4-|kRPy${8uYi=S$SUv}cmuoz%0LCE z0xNP;ZUMJ}JHTDw9xxBw2NsZp{F9s!SmC16pO z5l%}Gx!72Y)M@J literal 0 HcmV?d00001 diff --git a/sign_language_datasets/datasets/ngt_corpus/create_index.py b/sign_language_datasets/datasets/ngt_corpus/create_index.py index f037e11..9e5fd42 100644 --- a/sign_language_datasets/datasets/ngt_corpus/create_index.py +++ b/sign_language_datasets/datasets/ngt_corpus/create_index.py @@ -221,7 +221,7 @@ def create_structured_download_dict(force_rebuild: bool = False, sys.exit() with open(structured_json_path, "w") as outfile: - print("Writing structured download dict '%s'." % "ngt.json") + print("Writing structured download dict '%s'." % structured_json_path) json.dump(structured_download_dict, outfile) diff --git a/sign_language_datasets/utils/downloaders/download_auth.py b/sign_language_datasets/utils/downloaders/download_auth.py new file mode 100644 index 0000000..b1a2e33 --- /dev/null +++ b/sign_language_datasets/utils/downloaders/download_auth.py @@ -0,0 +1,151 @@ +import requests +import shutil +import os +import tensorflow as tf + +import tensorflow_datasets as tfds + +from tensorflow_datasets.core.download import downloader +from tensorflow_datasets.core import units +from tensorflow_datasets.core import utils +from tensorflow_datasets.core.download import checksums as checksums_lib + +from typing import Optional, Any + + +def download_with_auth(url: str, username: str, password: str, decode: bool = True): + """ + Download with basic HTTP authentication. + + :param url: + :param username: + :param password: + :param decode: + :return: + """ + response = requests.get(url, auth=(username, password)) + + if decode: + return response.content.decode("utf-8") + else: + return response.content + + +def download_tar_gz_to_file_with_auth(url: str, + filepath: str, + username: str, + password: str, + unpack: bool = False, + unpack_path: Optional[str] = None): + """ + + :param url: + :param filepath: + :param username: + :param password: + :param unpack: + :param unpack_path: + :return: + """ + response = requests.get(url, auth=(username, password), stream=True) + + with open(filepath, 'wb') as outfile: + outfile.write(response.raw.read()) + + if unpack: + if unpack_path is None: + assert ".tar.gz" in filepath, "'tar.gz' not found in filepath. Specify an explicit 'unpack_path'." 
+ unpack_path = filepath.replace(".tar.gz", "") + + shutil.unpack_archive(filepath, unpack_path) + + +# noinspection PyProtectedMember +@utils.memoize() +def get_downloader_with_auth(*args: Any, **kwargs: Any) -> '_DownloaderWithAuth': + return _DownloaderWithAuth(*args, **kwargs) + + +# noinspection PyProtectedMember +class _DownloaderWithAuth(downloader._Downloader): + + def __init__(self, username: str, password: str, **kwargs): + super(_DownloaderWithAuth, self).__init__(**kwargs) + + self.username = username + self.password = password + + def _sync_download(self, + url: str, + destination_path: str, + verify: bool = True) -> downloader.DownloadResult: + """Synchronous version of `download` method. + To download through a proxy, the `HTTP_PROXY`, `HTTPS_PROXY`, + `REQUESTS_CA_BUNDLE`,... environment variables can be exported, as + described in: + https://requests.readthedocs.io/en/master/user/advanced/#proxies + Args: + url: url to download + destination_path: path where to write it + verify: whether to verify ssl certificates + Returns: + None + Raises: + DownloadError: when download fails. + """ + try: + # If url is on a filesystem that gfile understands, use copy. Otherwise, + # use requests (http) or urllib (ftp). + if not url.startswith('http'): + return self._sync_file_copy(url, destination_path) + except tf.errors.UnimplementedError: + pass + + with downloader._open_url(url, verify=verify, auth=(self.username, self.password)) as (response, iter_content): + fname = downloader._get_filename(response) + path = os.path.join(destination_path, fname) + size = 0 + + # Initialize the download size progress bar + size_mb = 0 + unit_mb = units.MiB + total_size = int(response.headers.get('Content-length', 0)) // unit_mb + self._pbar_dl_size.update_total(total_size) + with tf.io.gfile.GFile(path, 'wb') as file_: + checksum = self._checksumer_cls() + for block in iter_content: + size += len(block) + checksum.update(block) + file_.write(block) + + # Update the download size progress bar + size_mb += len(block) + if size_mb > unit_mb: + self._pbar_dl_size.update(size_mb // unit_mb) + size_mb %= unit_mb + self._pbar_url.update(1) + return downloader.DownloadResult( + path=utils.as_path(path), + url_info=checksums_lib.UrlInfo( + checksum=checksum.hexdigest(), + size=utils.Size(size), + filename=fname, + ), + ) + + +class DownloadManagerWithAuth(tfds.download.DownloadManager): + + def __init__(self, *, username: str, password: str, **kwargs): + super().__init__(**kwargs) + + self.username = username + self.password = password + + self.__downloader = None + + @property + def _downloader(self): + if self.__downloader is None: + self.__downloader = get_downloader_with_auth(username=self.username, password=self.password) + return self.__downloader
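
Finally, a short sketch of the standalone helpers added in `download_auth.py`, used here outside of the TFDS download manager. Only the index URL is taken from `bobsl.py`; the credential environment variables and the output file name are assumptions that mirror `create_index.py`.

```python
# Sketch only: exercising download_with_auth and download_tar_gz_to_file_with_auth.
import json
import os

from sign_language_datasets.utils.downloaders.download_auth import (
    download_with_auth,
    download_tar_gz_to_file_with_auth,
)

username = os.environ["BOBSL_USERNAME"]  # assumed env vars, as in create_index.py
password = os.environ["BOBSL_PASSWORD"]

# Fetch the pre-built index (bobsl.py fetches it without auth; sending
# basic-auth credentials here is harmless) and parse it as JSON.
index = json.loads(download_with_auth(
    url="https://files.ifi.uzh.ch/cl/archiv/2022/easier/bobsl.json",
    username=username,
    password=password,
))

# Download and unpack the subtitles archive into the current working directory.
# With unpack=True and no unpack_path, the ".tar.gz" suffix is stripped,
# so the archive is extracted to "./subtitles".
download_tar_gz_to_file_with_auth(
    url=index["subtitles"],
    filepath="subtitles.tar.gz",
    username=username,
    password=password,
    unpack=True,
)
```
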