Skip to content

Commit 749cd30

Browse files
committed
feature: python module support to torch_distributed
1 parent 889d7b7 commit 749cd30

File tree

2 files changed

+5
-4
lines changed

2 files changed

+5
-4
lines changed

src/sagemaker/fw_utils.py

+4-3
Original file line numberDiff line numberDiff line change
@@ -200,11 +200,12 @@ def validate_source_dir(script, directory):
200200
not contain ``script``.
201201
"""
202202
if directory:
203-
if not os.path.isfile(os.path.join(directory, script)):
203+
if script.endswith(".py") and not os.path.isfile(os.path.join(directory, script)):
204204
raise ValueError(
205205
'No file named "{}" was found in directory "{}".'.format(script, directory)
206206
)
207207

208+
208209
return True
209210

210211

@@ -1143,10 +1144,10 @@ def validate_torch_distributed_distribution(
11431144
)
11441145

11451146
# Check entry point type
1146-
if not entry_point.endswith(".py"):
1147+
if not entry_point.startswith("-m") and not entry_point.endswith(".py"):
11471148
err_msg += (
11481149
"Unsupported entry point type for the distribution torch_distributed.\n"
1149-
"Only python programs (*.py) are supported."
1150+
"Only python programs (*.py) or modules (-m *) are supported."
11501151
)
11511152

11521153
if err_msg:

tests/unit/test_fw_utils.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -186,7 +186,7 @@ def test_validate_source_dir_is_not_directory(sagemaker_session):
186186

187187

188188
def test_validate_source_dir_file_not_in_dir():
189-
script = " !@#$%^&*() .myscript. !@#$%^&*() "
189+
script = " !@#$%^&*() .myscript. !@#$%^&*().py"
190190
directory = "."
191191
with pytest.raises(ValueError):
192192
fw_utils.validate_source_dir(script, directory)

0 commit comments

Comments
 (0)