Skip to content

Commit 23f70a0

Browse files
committed
Pass SIGTERM to training subprocess
feature: Pass SIGTERM to training subprocess fix: #125
1 parent 22a170a commit 23f70a0

File tree

5 files changed

+81
-5
lines changed

5 files changed

+81
-5
lines changed

Diff for: CONTRIBUTING.md

+1-1
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,7 @@ GitHub provides additional document on [forking a repository](https://help.githu
3737
### Running the unit tests
3838

3939
1. Install tox using `pip install tox`
40-
1. Install coverage using `pip install .[test]`
40+
1. Install coverage using `pip install ".[test]"`
4141
1. cd into the sagemaker-training-toolkit folder: `cd sagemaker-training-toolkit`
4242
1. Run the following tox command and verify that all code checks and unit tests pass: `tox test/unit`
4343

Diff for: README.md

+2
Original file line numberDiff line numberDiff line change
@@ -201,6 +201,8 @@ entry_point.run(uri=env.module_dir,
201201

202202
If the entry point execution fails, `trainer.train()` will write the error message to `/opt/ml/output/failure`. Otherwise, it will write to the file `/opt/ml/success`.
203203

204+
If `sagemaker_training` receives a `SIGTERM`, such as from `StopTrainingJob`, it will pass that signal to your script.
205+
204206
## :scroll: License
205207

206208
This library is licensed under the [Apache 2.0 License](http://aws.amazon.com/apache2.0/).

Diff for: src/sagemaker_training/process.py

+26-4
Original file line numberDiff line numberDiff line change
@@ -17,8 +17,10 @@
1717

1818
import asyncio
1919
from asyncio.subprocess import PIPE
20+
from contextlib import contextmanager
2021
import os
2122
import re
23+
import signal
2224
import subprocess
2325
import sys
2426

@@ -36,6 +38,24 @@
3638
_DEFAULT_BUF_SIZE = 1024 * 64
3739

3840

41+
@contextmanager
42+
def capture_signal(signalnum, callback):
43+
"""
44+
Install handler to capture signal
45+
46+
Args:
47+
signalnum: signal to capture
48+
callback: callback if signal occurs
49+
50+
"""
51+
original_handler = signal.getsignal(signalnum)
52+
signal.signal(signalnum, callback)
53+
try:
54+
yield
55+
finally:
56+
signal.signal(signalnum, original_handler)
57+
58+
3959
async def watch(stream, proc_per_host):
4060
"""Process the stdout and stderr streams on the fly.
4161
Decode the output lines
@@ -118,9 +138,10 @@ async def run_async(cmd, processes_per_host, env, cwd, stderr, **kwargs):
118138
cmd, env=env, cwd=cwd, stdout=PIPE, stderr=stderr, **kwargs
119139
)
120140

121-
output = await asyncio.gather(
122-
watch(proc.stdout, processes_per_host), watch(proc.stderr, processes_per_host)
123-
)
141+
with capture_signal(signal.SIGTERM, lambda signalnum, *_: proc.send_signal(signalnum)):
142+
output = await asyncio.gather(
143+
watch(proc.stdout, processes_per_host), watch(proc.stderr, processes_per_host)
144+
)
124145
return_code = proc.returncode
125146
return return_code, output, proc
126147

@@ -198,7 +219,8 @@ def check_error(cmd, error_class, processes_per_host, cwd=None, capture_error=Tr
198219
process = subprocess.Popen(
199220
cmd, env=os.environ, cwd=cwd or environment.code_dir, stderr=stderr, **kwargs
200221
)
201-
return_code = process.wait()
222+
with capture_signal(signal.SIGTERM, lambda signalnum, *_: process.send_signal(signalnum)):
223+
return_code = process.wait()
202224
if return_code:
203225
extra_info = None
204226
if return_code == 137:

Diff for: test/unit/_test_process_helper.py

+24
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,24 @@
1+
"""
2+
Helper script for testing signal handling
3+
4+
- If it receives SIGTERM, immediately exit "21"
5+
- If it doesn't receive a signal, sleep for 3 seconds then exit "-1"
6+
"""
7+
8+
import signal
9+
import time
10+
11+
12+
def signal_handler(signalnum, *_):
13+
assert signalnum == signal.SIGTERM
14+
exit(21)
15+
16+
17+
def main():
18+
signal.signal(signal.SIGTERM, signal_handler)
19+
time.sleep(3)
20+
exit(-1)
21+
22+
23+
if __name__ == "__main__":
24+
main()

Diff for: test/unit/test_process.py

+28
Original file line numberDiff line numberDiff line change
@@ -13,8 +13,11 @@
1313
from __future__ import absolute_import
1414

1515
import asyncio
16+
import multiprocessing
1617
import os
18+
import signal
1719
import sys
20+
import time
1821

1922
from mock import ANY, MagicMock, patch
2023
import pytest
@@ -175,3 +178,28 @@ def test_run_python(log, async_shell, async_gather, entry_point_type_script, eve
175178
stdout=asyncio.subprocess.PIPE,
176179
)
177180
log.assert_called_with(cmd, {})
181+
182+
183+
def _sleep_subprocess(capture_error):
184+
with pytest.raises(errors.ExecuteUserScriptError) as error:
185+
process.check_error(
186+
[sys.executable, os.path.abspath(os.path.join(__file__, "../_test_process_helper.py"))],
187+
errors.ExecuteUserScriptError,
188+
1,
189+
capture_error=capture_error,
190+
)
191+
assert int(error.value.return_code) == 21
192+
exit(42)
193+
194+
195+
@pytest.mark.skipif(
196+
sys.version_info < (3, 7) or sys.version_info >= (3, 8), reason="requires python3.7"
197+
)
198+
@pytest.mark.parametrize("capture_error", [True, False])
199+
def test_check_error_signal(capture_error):
200+
proc = multiprocessing.Process(target=_sleep_subprocess, args=(capture_error,))
201+
proc.start()
202+
time.sleep(1)
203+
os.kill(proc.pid, signal.SIGTERM)
204+
proc.join(1)
205+
assert int(proc.exitcode) == 42

0 commit comments

Comments
 (0)