Skip to content
This repository has been archived by the owner on May 25, 2024. It is now read-only.

Commit

Permalink
Improve the documentation of resolve_filespecs. Also have it return 0…
Browse files Browse the repository at this point in the history
… for num_shards instead of None when there are no shards; update all call sites to reflect this change.

PiperOrigin-RevId: 213698824
  • Loading branch information
mdepristo authored and Copybara-Service committed Sep 19, 2018
1 parent d72543c commit ce1ee08
Show file tree
Hide file tree
Showing 2 changed files with 61 additions and 24 deletions.
13 changes: 7 additions & 6 deletions nucleus/util/io_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,9 +52,9 @@ def ParseShardedFileSpec(spec): # pylint:disable=invalid-name
'gs://some/file@200.txt'. Here, '@200' specifies the number of shards.
Returns:
basename: The basename for the files.
num_shards: The number of shards.
suffix: The suffix if there is one, or '' if not.
basename: str. The basename for the files.
num_shards: int >= 0. The number of shards.
suffix: str. The suffix if there is one, or '' if not.
Raises:
ShardError: If the spec is not a valid sharded specification.
"""
Expand Down Expand Up @@ -188,8 +188,9 @@ def resolve_filespecs(shard, *filespecs):
shard-specific file paths.
Returns:
A list. The first element is the number of shards, followed by the
shard-specific paths for each filespec, in order.
A list. The first element is the number of shards, which is an int >= 1 when
filespecs contains sharded paths and 0 if none do. All subsequent
returned values follow the shard-specific paths for each filespec, in order.
Raises:
ValueError: if any filespecs are inconsistent.
Expand All @@ -200,7 +201,7 @@ def resolve_filespecs(shard, *filespecs):
master = filespecs[0]
master_is_sharded = IsShardedFileSpec(master)

master_num_shards = None
master_num_shards = 0
if master_is_sharded:
_, master_num_shards, _ = ParseShardedFileSpec(master)
if shard >= master_num_shards or shard < 0:
Expand Down
72 changes: 54 additions & 18 deletions nucleus/util/io_utils_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,27 +34,63 @@ class IOTest(parameterized.TestCase):

@parameterized.parameters(
# Unsharded outputs pass through as expected.
(0, ['foo.txt'], [None, 'foo.txt']),
(0, ['foo.txt', 'bar.txt'], [None, 'foo.txt', 'bar.txt']),
(0, ['bar.txt', 'foo.txt'], [None, 'bar.txt', 'foo.txt']),
dict(task_id=0, filespecs=['foo.txt'], expected=[0, 'foo.txt']),
dict(
task_id=0,
filespecs=['foo.txt', 'bar.txt'],
expected=[0, 'foo.txt', 'bar.txt']),
dict(
task_id=0,
filespecs=['bar.txt', 'foo.txt'],
expected=[0, 'bar.txt', 'foo.txt']),
# It's ok to have False values for other bindings.
(0, ['foo.txt', None], [None, 'foo.txt', None]),
(0, ['foo.txt', ''], [None, 'foo.txt', '']),
(0, ['foo@10.txt', None], [10, 'foo-00000-of-00010.txt', None]),
(0, ['foo@10.txt', ''], [10, 'foo-00000-of-00010.txt', '']),
dict(
task_id=0, filespecs=['foo.txt', None], expected=[0, 'foo.txt',
None]),
dict(task_id=0, filespecs=['foo.txt', ''], expected=[0, 'foo.txt', '']),
dict(
task_id=0,
filespecs=['foo@10.txt', None],
expected=[10, 'foo-00000-of-00010.txt', None]),
dict(
task_id=0,
filespecs=['foo@10.txt', ''],
expected=[10, 'foo-00000-of-00010.txt', '']),
# Simple check that master behaves as expected.
(0, ['foo@10.txt', None], [10, 'foo-00000-of-00010.txt', None]),
(0, ['foo@10', None], [10, 'foo-00000-of-00010', None]),
(1, ['foo@10', None], [10, 'foo-00001-of-00010', None]),
(9, ['foo@10', None], [10, 'foo-00009-of-00010', None]),
# Make sure we handle sharding of multiple outputs.
(0, ['foo@10', 'bar@10', 'baz@10'],
[10, 'foo-00000-of-00010', 'bar-00000-of-00010', 'baz-00000-of-00010']),
(9, ['foo@10', 'bar@10', 'baz@10'],
[10, 'foo-00009-of-00010', 'bar-00009-of-00010', 'baz-00009-of-00010']),
dict(
task_id=0,
filespecs=['foo@10.txt', None],
expected=[10, 'foo-00000-of-00010.txt', None]),
dict(
task_id=0,
filespecs=['foo@10', None],
expected=[10, 'foo-00000-of-00010', None]),
dict(
task_id=1,
filespecs=['foo@10', None],
expected=[10, 'foo-00001-of-00010', None]),
dict(
task_id=9,
filespecs=['foo@10', None],
expected=[10, 'foo-00009-of-00010', None]),
# Make sure we handle sharding of multiple filespecs.
dict(
task_id=0,
filespecs=['foo@10', 'bar@10', 'baz@10'],
expected=[
10, 'foo-00000-of-00010', 'bar-00000-of-00010',
'baz-00000-of-00010'
]),
dict(
task_id=9,
filespecs=['foo@10', 'bar@10', 'baz@10'],
expected=[
10, 'foo-00009-of-00010', 'bar-00009-of-00010',
'baz-00009-of-00010'
]),
)
def test_resolve_filespecs(self, task_id, outputs, expected):
self.assertEqual(io.resolve_filespecs(task_id, *outputs), expected)
def test_resolve_filespecs(self, task_id, filespecs, expected):
self.assertEqual(io.resolve_filespecs(task_id, *filespecs), expected)

@parameterized.parameters(
# shard >= num_shards.
Expand Down

0 comments on commit ce1ee08

Please # to comment.