Skip to content

Commit 272a6bc

Browse files
committed
feature: add python module entrypoint type, add python module support to torch_distributed
1 parent b7c660b commit 272a6bc

File tree

3 files changed

+10
-2
lines changed

3 files changed

+10
-2
lines changed

Diff for: src/sagemaker_training/_entry_point_type.py

+4
Original file line numberDiff line numberDiff line change
@@ -21,11 +21,13 @@
2121
class _EntryPointType(enum.Enum):
2222
"""Enumerated type consisting of valid types of training entry points."""
2323

24+
PYTHON_MODULE = "PYTHON_MODULE"
2425
PYTHON_PACKAGE = "PYTHON_PACKAGE"
2526
PYTHON_PROGRAM = "PYTHON_PROGRAM"
2627
COMMAND = "COMMAND"
2728

2829

30+
PYTHON_MODULE = _EntryPointType.PYTHON_MODULE
2931
PYTHON_PACKAGE = _EntryPointType.PYTHON_PACKAGE
3032
PYTHON_PROGRAM = _EntryPointType.PYTHON_PROGRAM
3133
COMMAND = _EntryPointType.COMMAND
@@ -46,5 +48,7 @@ def get(path, name): # type: (str, str) -> _EntryPointType
4648
return _EntryPointType.PYTHON_PACKAGE
4749
elif name.endswith(".py"):
4850
return _EntryPointType.PYTHON_PROGRAM
51+
elif name.startswith("-m "):
52+
return _EntryPointType.PYTHON_MODULE
4953
else:
5054
return _EntryPointType.COMMAND

Diff for: src/sagemaker_training/torch_distributed.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -93,7 +93,7 @@ def _create_command(self):
9393
"Please use a python script as the entry-point"
9494
)
9595

96-
if entrypoint_type is _entry_point_type.PYTHON_PROGRAM:
96+
if entrypoint_type is _entry_point_type.PYTHON_PROGRAM or entrypoint_type is _entry_point_type.PYTHON_MODULE:
9797
num_hosts = len(self._hosts)
9898
torchrun_cmd = []
9999

@@ -135,7 +135,7 @@ def _create_command(self):
135135
torchrun_cmd += self._args
136136
return torchrun_cmd
137137
else:
138-
raise errors.ClientError("Unsupported entry point type for torch_distributed")
138+
raise errors.ClientError(f"Unsupported entry point type for torch_distributed: {entrypoint_type}")
139139

140140
def run(self, capture_error=True, wait=True):
141141
"""

Diff for: test/unit/test_entry_point_type.py

+4
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,10 @@ def has_requirements():
3636
yield
3737

3838

39+
def test_get_module():
40+
assert _entry_point_type.get("bla", "-m program") == _entry_point_type.PYTHON_MODULE
41+
42+
3943
def test_get_package(entry_point_type_module):
4044
assert _entry_point_type.get("bla", "program.py") == _entry_point_type.PYTHON_PACKAGE
4145

0 commit comments

Comments
 (0)