File tree 2 files changed +5
-4
lines changed
2 files changed +5
-4
lines changed Original file line number Diff line number Diff line change @@ -200,11 +200,12 @@ def validate_source_dir(script, directory):
200
200
not contain ``script``.
201
201
"""
202
202
if directory :
203
- if not os .path .isfile (os .path .join (directory , script )):
203
+ if script . endswith ( ".py" ) and not os .path .isfile (os .path .join (directory , script )):
204
204
raise ValueError (
205
205
'No file named "{}" was found in directory "{}".' .format (script , directory )
206
206
)
207
207
208
+
208
209
return True
209
210
210
211
@@ -1143,10 +1144,10 @@ def validate_torch_distributed_distribution(
1143
1144
)
1144
1145
1145
1146
# Check entry point type
1146
- if not entry_point .endswith (".py" ):
1147
+ if not entry_point .startswith ( "-m" ) and not entry_point . endswith (".py" ):
1147
1148
err_msg += (
1148
1149
"Unsupported entry point type for the distribution torch_distributed.\n "
1149
- "Only python programs (*.py) are supported."
1150
+ "Only python programs (*.py) or modules (-m *) are supported."
1150
1151
)
1151
1152
1152
1153
if err_msg :
Original file line number Diff line number Diff line change @@ -186,7 +186,7 @@ def test_validate_source_dir_is_not_directory(sagemaker_session):
186
186
187
187
188
188
def test_validate_source_dir_file_not_in_dir ():
189
- script = " !@#$%^&*() .myscript. !@#$%^&*() "
189
+ script = " !@#$%^&*() .myscript. !@#$%^&*().py "
190
190
directory = "."
191
191
with pytest .raises (ValueError ):
192
192
fw_utils .validate_source_dir (script , directory )
You can’t perform that action at this time.
0 commit comments