diff --git a/nucleus/util/io_utils.py b/nucleus/util/io_utils.py index 25db56c..467eca6 100644 --- a/nucleus/util/io_utils.py +++ b/nucleus/util/io_utils.py @@ -52,9 +52,9 @@ def ParseShardedFileSpec(spec): # pylint:disable=invalid-name 'gs://some/file@200.txt'. Here, '@200' specifies the number of shards. Returns: - basename: The basename for the files. - num_shards: The number of shards. - suffix: The suffix if there is one, or '' if not. + basename: str. The basename for the files. + num_shards: int >= 0. The number of shards. + suffix: str. The suffix if there is one, or '' if not. Raises: ShardError: If the spec is not a valid sharded specification. """ @@ -188,8 +188,9 @@ def resolve_filespecs(shard, *filespecs): shard-specific file paths. Returns: - A list. The first element is the number of shards, followed by the - shard-specific paths for each filespec, in order. + A list. The first element is the number of shards, which is an int >= 1 when + filespecs contains sharded paths and 0 if none do. All subsequent + returned values follow the shard-specific paths for each filespec, in order. Raises: ValueError: if any filespecs are inconsistent. @@ -200,7 +201,7 @@ def resolve_filespecs(shard, *filespecs): master = filespecs[0] master_is_sharded = IsShardedFileSpec(master) - master_num_shards = None + master_num_shards = 0 if master_is_sharded: _, master_num_shards, _ = ParseShardedFileSpec(master) if shard >= master_num_shards or shard < 0: diff --git a/nucleus/util/io_utils_test.py b/nucleus/util/io_utils_test.py index 53832ec..1482b1b 100644 --- a/nucleus/util/io_utils_test.py +++ b/nucleus/util/io_utils_test.py @@ -34,27 +34,63 @@ class IOTest(parameterized.TestCase): @parameterized.parameters( # Unsharded outputs pass through as expected. - (0, ['foo.txt'], [None, 'foo.txt']), - (0, ['foo.txt', 'bar.txt'], [None, 'foo.txt', 'bar.txt']), - (0, ['bar.txt', 'foo.txt'], [None, 'bar.txt', 'foo.txt']), + dict(task_id=0, filespecs=['foo.txt'], expected=[0, 'foo.txt']), + dict( + task_id=0, + filespecs=['foo.txt', 'bar.txt'], + expected=[0, 'foo.txt', 'bar.txt']), + dict( + task_id=0, + filespecs=['bar.txt', 'foo.txt'], + expected=[0, 'bar.txt', 'foo.txt']), # It's ok to have False values for other bindings. - (0, ['foo.txt', None], [None, 'foo.txt', None]), - (0, ['foo.txt', ''], [None, 'foo.txt', '']), - (0, ['foo@10.txt', None], [10, 'foo-00000-of-00010.txt', None]), - (0, ['foo@10.txt', ''], [10, 'foo-00000-of-00010.txt', '']), + dict( + task_id=0, filespecs=['foo.txt', None], expected=[0, 'foo.txt', + None]), + dict(task_id=0, filespecs=['foo.txt', ''], expected=[0, 'foo.txt', '']), + dict( + task_id=0, + filespecs=['foo@10.txt', None], + expected=[10, 'foo-00000-of-00010.txt', None]), + dict( + task_id=0, + filespecs=['foo@10.txt', ''], + expected=[10, 'foo-00000-of-00010.txt', '']), # Simple check that master behaves as expected. - (0, ['foo@10.txt', None], [10, 'foo-00000-of-00010.txt', None]), - (0, ['foo@10', None], [10, 'foo-00000-of-00010', None]), - (1, ['foo@10', None], [10, 'foo-00001-of-00010', None]), - (9, ['foo@10', None], [10, 'foo-00009-of-00010', None]), - # Make sure we handle sharding of multiple outputs. - (0, ['foo@10', 'bar@10', 'baz@10'], - [10, 'foo-00000-of-00010', 'bar-00000-of-00010', 'baz-00000-of-00010']), - (9, ['foo@10', 'bar@10', 'baz@10'], - [10, 'foo-00009-of-00010', 'bar-00009-of-00010', 'baz-00009-of-00010']), + dict( + task_id=0, + filespecs=['foo@10.txt', None], + expected=[10, 'foo-00000-of-00010.txt', None]), + dict( + task_id=0, + filespecs=['foo@10', None], + expected=[10, 'foo-00000-of-00010', None]), + dict( + task_id=1, + filespecs=['foo@10', None], + expected=[10, 'foo-00001-of-00010', None]), + dict( + task_id=9, + filespecs=['foo@10', None], + expected=[10, 'foo-00009-of-00010', None]), + # Make sure we handle sharding of multiple filespecs. + dict( + task_id=0, + filespecs=['foo@10', 'bar@10', 'baz@10'], + expected=[ + 10, 'foo-00000-of-00010', 'bar-00000-of-00010', + 'baz-00000-of-00010' + ]), + dict( + task_id=9, + filespecs=['foo@10', 'bar@10', 'baz@10'], + expected=[ + 10, 'foo-00009-of-00010', 'bar-00009-of-00010', + 'baz-00009-of-00010' + ]), ) - def test_resolve_filespecs(self, task_id, outputs, expected): - self.assertEqual(io.resolve_filespecs(task_id, *outputs), expected) + def test_resolve_filespecs(self, task_id, filespecs, expected): + self.assertEqual(io.resolve_filespecs(task_id, *filespecs), expected) @parameterized.parameters( # shard >= num_shards.