-
Notifications
You must be signed in to change notification settings - Fork 34
/
Copy pathtest_batch_inference.py
106 lines (82 loc) · 4.5 KB
/
test_batch_inference.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
import logging
from datetime import timedelta
import sagemaker
from sagemaker.pytorch import PyTorch
from sagemaker.utils import name_from_base
from sagemaker_ssh_helper.wrapper import SSHModelWrapper, SSHTransformerWrapper
import test_util
def test_clean_batch_inference():
    """End-to-end check of a plain (non-SSH) batch transform.

    Trains a PyTorch model, runs a batch transform over CSV input uploaded to
    the session's default bucket, then pulls the transform output back to the
    local ``./output`` directory. Assertions on the output are expected to be
    done elsewhere (or manually); this test verifies the pipeline completes.
    """
    # noinspection DuplicatedCode
    session = sagemaker.Session()
    default_bucket = session.default_bucket()

    # Train a simple model on a single CPU instance; warm pool kept for 30 min.
    pytorch_estimator = PyTorch(
        entry_point='train_clean.py',
        source_dir='source_dir/training_clean/',
        framework_version='1.9.1',
        py_version='py38',
        instance_count=1,
        instance_type='ml.m5.xlarge',
        max_run=int(timedelta(minutes=15).total_seconds()),
        keep_alive_period_in_seconds=1800,
        container_log_level=logging.INFO,
    )
    pytorch_estimator.fit()

    inference_model = pytorch_estimator.create_model(
        entry_point='inference_clean.py',
        source_dir='source_dir/inference_clean/',
    )

    # Stage the CSV input in S3 and pick an output prefix in the same bucket.
    input_s3_uri = session.upload_data(
        path='data/batch_transform/input',
        bucket=default_bucket,
        key_prefix='batch-transform/input',
    )
    output_s3_uri = f"s3://{default_bucket}/batch-transform/output"

    batch_transformer = inference_model.transformer(
        instance_count=1,
        instance_type="ml.m5.xlarge",
        accept='text/csv',
        strategy='SingleRecord',
        assemble_with='Line',
        output_path=output_s3_uri,
    )
    # Join each output line with its input record so results are traceable.
    batch_transformer.transform(
        data=input_s3_uri,
        content_type='text/csv',
        split_type='Line',
        join_source="Input",
    )

    # Fetch the transform results locally into a clean ./output directory.
    test_util._cleanup_dir("./output", recreate=True)
    session.download_data(
        path='output',
        bucket=default_bucket,
        key_prefix='batch-transform/output',
    )
def test_batch_ssh():
    """End-to-end batch transform wired up with SSH Helper.

    Same pipeline as the clean test, but the model is wrapped with
    ``SSHModelWrapper``/``SSHTransformerWrapper`` so an SSM connection can be
    opened into the transform container while the job runs. The transform is
    started with ``wait=False`` and the SSH wrapper is used to connect, print
    connection details, and then block until the job finishes.
    """
    # noinspection DuplicatedCode
    session = sagemaker.Session()
    default_bucket = session.default_bucket()

    # Train a simple model on a single CPU instance; warm pool kept for 30 min.
    pytorch_estimator = PyTorch(
        entry_point='train_clean.py',
        source_dir='source_dir/training_clean/',
        framework_version='1.9.1',
        py_version='py38',
        instance_count=1,
        instance_type='ml.m5.xlarge',
        max_run=int(timedelta(minutes=15).total_seconds()),
        keep_alive_period_in_seconds=1800,
        container_log_level=logging.INFO,
    )
    pytorch_estimator.fit()

    # The SSH-enabled inference entry point needs the helper's dependency dir.
    inference_model = pytorch_estimator.create_model(
        entry_point='inference_ssh.py',
        source_dir='source_dir/inference/',
        dependencies=[SSHModelWrapper.dependency_dir()],
    )

    # Stage the CSV input in S3 and pick an output prefix in the same bucket.
    input_s3_uri = session.upload_data(
        path='data/batch_transform/input',
        bucket=default_bucket,
        key_prefix='batch-transform/input',
    )
    output_s3_uri = f"s3://{default_bucket}/batch-transform/output"

    # Wrap the model before creating the transformer; the container will wait
    # up to an hour for an SSM connection.
    model_ssh_wrapper = SSHModelWrapper.create(inference_model, connection_wait_time_seconds=3600)

    batch_transformer = inference_model.transformer(
        instance_count=1,
        instance_type="ml.m5.xlarge",
        accept='text/csv',
        strategy='SingleRecord',
        assemble_with='Line',
        output_path=output_s3_uri,
    )
    transformer_ssh_wrapper = SSHTransformerWrapper.create(batch_transformer, model_ssh_wrapper)

    # Kick off the transform without blocking so we can attach over SSM.
    job_name = name_from_base('ssh-batch-transform')
    batch_transformer.transform(
        data=input_s3_uri,
        job_name=job_name,
        content_type='text/csv',
        split_type='Line',
        join_source="Input",
        wait=False,
    )

    transformer_ssh_wrapper.start_ssm_connection_and_continue(16022)
    transformer_ssh_wrapper.print_ssh_info()
    transformer_ssh_wrapper.wait_transform_job()

    # Fetch the transform results locally into a clean ./output directory.
    test_util._cleanup_dir("./output", recreate=True)
    session.download_data(
        path='output',
        bucket=default_bucket,
        key_prefix='batch-transform/output',
    )