Skip to content

Commit

Permalink
Add waterbirds (#228)
Browse files Browse the repository at this point in the history
  • Loading branch information
seyuboglu authored Feb 15, 2022
1 parent 961c227 commit 513bf9a
Show file tree
Hide file tree
Showing 2 changed files with 61 additions and 0 deletions.
7 changes: 7 additions & 0 deletions meerkat/contrib/registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -185,3 +185,10 @@ def yesno(dataset_dir: str = None, download: bool = True, **kwargs):
from .torchaudio import get_yesno

return get_yesno(dataset_dir=dataset_dir, download=download, **kwargs)


@datasets.register()
def waterbirds(dataset_dir: str = None, download: bool = True, **kwargs):
    """Registry entry point that builds the Waterbirds DataPanel.

    The builder is imported inside the function, so the `.waterbirds` module
    (and whatever it imports) is only loaded when this dataset is requested.

    Args:
        dataset_dir (str): Passed through to ``build_waterbirds_dp``.
        download (bool): Passed through to ``build_waterbirds_dp``.
        **kwargs: Any extra keyword arguments, forwarded unchanged.

    Returns:
        Whatever ``build_waterbirds_dp`` returns.
    """
    from .waterbirds import build_waterbirds_dp

    builder_args = dict(dataset_dir=dataset_dir, download=download, **kwargs)
    return build_waterbirds_dp(**builder_args)
54 changes: 54 additions & 0 deletions meerkat/contrib/waterbirds.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
import os

import pandas as pd
from wilds import get_dataset
from wilds.datasets.wilds_dataset import WILDSDataset

import meerkat as mk

# flake8: noqa
# CodaLab bundle URL. Not referenced by any code visible in this module;
# presumably the raw Waterbirds archive kept for reference (TODO confirm).
URL = "http://worksheets.codalab.org/rest/bundles/0x505056d5cdea4e4eaa0e242cbfe2daa4/contents/blob/"


def build_waterbirds_dp(
    dataset_dir: str,
    download: bool = True,
):
    """Build a DataPanel for the WILDS Waterbirds dataset.

    Args:
        dataset_dir (str): Directory handed to ``wilds.get_dataset`` as
            ``root_dir``. Images are resolved relative to its
            ``waterbirds_v1.0`` subdirectory.
        download (bool): Forwarded to ``wilds.get_dataset``. Defaults to True.

    Returns:
        a DataPanel holding the WILDS metadata columns (which include `y` and
        `background`) plus `filepath`, `image`, `split` and `group`.
    """
    wilds_dataset = get_dataset(
        dataset="waterbirds", root_dir=dataset_dir, download=download
    )

    # One row per example: the metadata fields plus the relative image path.
    metadata_df = pd.DataFrame(
        wilds_dataset.metadata_array, columns=wilds_dataset.metadata_fields
    )
    metadata_df["filepath"] = wilds_dataset._input_array
    dp = mk.DataPanel.from_pandas(metadata_df)

    dp["image"] = mk.ImageColumn(
        dp["filepath"], base_dir=os.path.join(dataset_dir, "waterbirds_v1.0")
    )

    # DEFAULT_SPLITS maps split name -> id; invert it so the id array can be
    # translated to names, spelling "val" as "valid".
    split_names = {}
    for split_name, split_id in WILDSDataset.DEFAULT_SPLITS.items():
        split_names[split_id] = "valid" if split_name == "val" else split_name
    dp["split"] = pd.Series(wilds_dataset._split_array).map(split_names)

    # Human-readable group labels, keyed by the concatenated "<y><background>"
    # index string (e.g. "01").
    background_names = wilds_dataset._metadata_map["background"]
    bird_names = wilds_dataset._metadata_map["y"]
    group_names = {}
    for bird in (0, 1):
        for background in (0, 1):
            label = f"{bird_names[bird]}-{background_names[background]}"
            group_names[f"{bird}{background}"] = label

    group_keys = dp["y"].astype(str) + dp["background"].data.astype(str)
    dp["group"] = group_keys.map(group_names)
    return dp

0 comments on commit 513bf9a

Please sign in to comment.