Skip to content

Commit 8bed9d8

Browse files
Fix: fix wrong error message in CelebA datasets when invalid split argument is passed in. (#8866)
Co-authored-by: Nicolas Hug <nh.nicolas.hug@gmail.com>
1 parent a187f1b commit 8bed9d8

File tree

2 files changed

+31
-2
lines changed

2 files changed

+31
-2
lines changed

test/test_datasets.py

+24-1
Original file line numberDiff line numberDiff line change
@@ -532,7 +532,8 @@ def inject_fake_data(self, tmpdir, config):
532532
self._create_bbox_txt(base_folder, num_images)
533533
self._create_landmarks_txt(base_folder, num_images)
534534

535-
return dict(num_examples=num_images_per_split[config["split"]], attr_names=attr_names)
535+
num_samples = num_images_per_split.get(config["split"], 0) if isinstance(config["split"], str) else 0
536+
return dict(num_examples=num_samples, attr_names=attr_names)
536537

537538
def _create_split_txt(self, root):
538539
num_images_per_split = dict(train=4, valid=3, test=2)
@@ -635,6 +636,28 @@ def test_transforms_v2_wrapper_spawn(self):
635636
with self.create_dataset(target_type=target_type, transform=v2.Resize(size=expected_size)) as (dataset, _):
636637
datasets_utils.check_transforms_v2_wrapper_spawn(dataset, expected_size=expected_size)
637638

639+
def test_invalid_split_list(self):
640+
with pytest.raises(ValueError, match="Expected type str for argument split, but got type <class 'list'>."):
641+
with self.create_dataset(split=[1]):
642+
pass
643+
644+
def test_invalid_split_int(self):
645+
with pytest.raises(ValueError, match="Expected type str for argument split, but got type <class 'int'>."):
646+
with self.create_dataset(split=1):
647+
pass
648+
649+
def test_invalid_split_value(self):
650+
with pytest.raises(
651+
ValueError,
652+
match="Unknown value '{value}' for argument {arg}. Valid values are {{{valid_values}}}.".format(
653+
value="invalid",
654+
arg="split",
655+
valid_values=("train", "valid", "test", "all"),
656+
),
657+
):
658+
with self.create_dataset(split="invalid"):
659+
pass
660+
638661

639662
class VOCSegmentationTestCase(datasets_utils.ImageDatasetTestCase):
640663
DATASET_CLASS = datasets.VOCSegmentation

torchvision/datasets/celeba.py

+7-1
Original file line numberDiff line numberDiff line change
@@ -93,7 +93,13 @@ def __init__(
9393
"test": 2,
9494
"all": None,
9595
}
96-
split_ = split_map[verify_str_arg(split.lower(), "split", ("train", "valid", "test", "all"))]
96+
split_ = split_map[
97+
verify_str_arg(
98+
split.lower() if isinstance(split, str) else split,
99+
"split",
100+
("train", "valid", "test", "all"),
101+
)
102+
]
97103
splits = self._load_csv("list_eval_partition.txt")
98104
identity = self._load_csv("identity_CelebA.txt")
99105
bbox = self._load_csv("list_bbox_celeba.txt", header=1)

0 commit comments

Comments
 (0)