@@ -532,7 +532,8 @@ def inject_fake_data(self, tmpdir, config):
532
532
self ._create_bbox_txt (base_folder , num_images )
533
533
self ._create_landmarks_txt (base_folder , num_images )
534
534
535
- return dict (num_examples = num_images_per_split [config ["split" ]], attr_names = attr_names )
535
+ num_samples = num_images_per_split .get (config ["split" ], 0 ) if isinstance (config ["split" ], str ) else 0
536
+ return dict (num_examples = num_samples , attr_names = attr_names )
536
537
537
538
def _create_split_txt (self , root ):
538
539
num_images_per_split = dict (train = 4 , valid = 3 , test = 2 )
@@ -635,6 +636,28 @@ def test_transforms_v2_wrapper_spawn(self):
635
636
with self .create_dataset (target_type = target_type , transform = v2 .Resize (size = expected_size )) as (dataset , _ ):
636
637
datasets_utils .check_transforms_v2_wrapper_spawn (dataset , expected_size = expected_size )
637
638
639
+ def test_invalid_split_list (self ):
640
+ with pytest .raises (ValueError , match = "Expected type str for argument split, but got type <class 'list'>." ):
641
+ with self .create_dataset (split = [1 ]):
642
+ pass
643
+
644
+ def test_invalid_split_int (self ):
645
+ with pytest .raises (ValueError , match = "Expected type str for argument split, but got type <class 'int'>." ):
646
+ with self .create_dataset (split = 1 ):
647
+ pass
648
+
649
+ def test_invalid_split_value (self ):
650
+ with pytest .raises (
651
+ ValueError ,
652
+ match = "Unknown value '{value}' for argument {arg}. Valid values are {{{valid_values}}}." .format (
653
+ value = "invalid" ,
654
+ arg = "split" ,
655
+ valid_values = ("train" , "valid" , "test" , "all" ),
656
+ ),
657
+ ):
658
+ with self .create_dataset (split = "invalid" ):
659
+ pass
660
+
638
661
639
662
class VOCSegmentationTestCase (datasets_utils .ImageDatasetTestCase ):
640
663
DATASET_CLASS = datasets .VOCSegmentation
0 commit comments