20
20
from sagemaker .chainer .estimator import Chainer
21
21
from sagemaker .chainer .model import ChainerModel
22
22
from sagemaker .utils import unique_name_from_base
23
- from tests .integ import DATA_DIR , PYTHON_VERSION , TRAINING_DEFAULT_TIMEOUT_MINUTES
23
+ from tests .integ import DATA_DIR , TRAINING_DEFAULT_TIMEOUT_MINUTES
24
24
from tests .integ .timeout import timeout , timeout_and_delete_endpoint_by_name
25
25
26
26
27
27
@pytest .fixture (scope = "module" )
28
- def chainer_local_training_job (sagemaker_local_session , chainer_full_version ):
29
- return _run_mnist_training_job (sagemaker_local_session , "local" , 1 , chainer_full_version )
28
+ def chainer_local_training_job (
29
+ sagemaker_local_session , chainer_full_version , chainer_full_py_version
30
+ ):
31
+ return _run_mnist_training_job (
32
+ sagemaker_local_session , "local" , 1 , chainer_full_version , chainer_full_py_version
33
+ )
30
34
31
35
32
36
@pytest .mark .local_mode
33
- def test_distributed_cpu_training (sagemaker_local_session , chainer_full_version ):
34
- _run_mnist_training_job (sagemaker_local_session , "local" , 2 , chainer_full_version )
37
+ def test_distributed_cpu_training (
38
+ sagemaker_local_session , chainer_full_version , chainer_full_py_version
39
+ ):
40
+ _run_mnist_training_job (
41
+ sagemaker_local_session , "local" , 2 , chainer_full_version , chainer_full_py_version
42
+ )
35
43
36
44
37
45
@pytest .mark .local_mode
38
- def test_training_with_additional_hyperparameters (sagemaker_local_session , chainer_full_version ):
46
+ def test_training_with_additional_hyperparameters (
47
+ sagemaker_local_session , chainer_full_version , chainer_full_py_version
48
+ ):
39
49
script_path = os .path .join (DATA_DIR , "chainer_mnist" , "mnist.py" )
40
50
data_path = os .path .join (DATA_DIR , "chainer_mnist" )
41
51
@@ -45,7 +55,7 @@ def test_training_with_additional_hyperparameters(sagemaker_local_session, chain
45
55
train_instance_count = 1 ,
46
56
train_instance_type = "local" ,
47
57
framework_version = chainer_full_version ,
48
- py_version = PYTHON_VERSION ,
58
+ py_version = chainer_full_py_version ,
49
59
sagemaker_session = sagemaker_local_session ,
50
60
hyperparameters = {"epochs" : 1 },
51
61
use_mpi = True ,
@@ -62,7 +72,9 @@ def test_training_with_additional_hyperparameters(sagemaker_local_session, chain
62
72
63
73
@pytest .mark .canary_quick
64
74
@pytest .mark .regional_testing
65
- def test_attach_deploy (sagemaker_session , chainer_full_version , cpu_instance_type ):
75
+ def test_attach_deploy (
76
+ sagemaker_session , chainer_full_version , chainer_full_py_version , cpu_instance_type
77
+ ):
66
78
with timeout (minutes = TRAINING_DEFAULT_TIMEOUT_MINUTES ):
67
79
script_path = os .path .join (DATA_DIR , "chainer_mnist" , "mnist.py" )
68
80
data_path = os .path .join (DATA_DIR , "chainer_mnist" )
@@ -71,7 +83,7 @@ def test_attach_deploy(sagemaker_session, chainer_full_version, cpu_instance_typ
71
83
entry_point = script_path ,
72
84
role = "SageMakerRole" ,
73
85
framework_version = chainer_full_version ,
74
- py_version = PYTHON_VERSION ,
86
+ py_version = chainer_full_py_version ,
75
87
train_instance_count = 1 ,
76
88
train_instance_type = cpu_instance_type ,
77
89
sagemaker_session = sagemaker_session ,
@@ -100,7 +112,12 @@ def test_attach_deploy(sagemaker_session, chainer_full_version, cpu_instance_typ
100
112
101
113
102
114
@pytest .mark .local_mode
103
- def test_deploy_model (chainer_local_training_job , sagemaker_local_session , chainer_full_version ):
115
+ def test_deploy_model (
116
+ chainer_local_training_job ,
117
+ sagemaker_local_session ,
118
+ chainer_full_version ,
119
+ chainer_full_py_version ,
120
+ ):
104
121
script_path = os .path .join (DATA_DIR , "chainer_mnist" , "mnist.py" )
105
122
106
123
model = ChainerModel (
@@ -109,7 +126,7 @@ def test_deploy_model(chainer_local_training_job, sagemaker_local_session, chain
109
126
entry_point = script_path ,
110
127
sagemaker_session = sagemaker_local_session ,
111
128
framework_version = chainer_full_version ,
112
- py_version = PYTHON_VERSION ,
129
+ py_version = chainer_full_py_version ,
113
130
)
114
131
115
132
predictor = model .deploy (1 , "local" )
@@ -120,7 +137,7 @@ def test_deploy_model(chainer_local_training_job, sagemaker_local_session, chain
120
137
121
138
122
139
def _run_mnist_training_job (
123
- sagemaker_session , instance_type , instance_count , chainer_full_version , wait = True
140
+ sagemaker_session , instance_type , instance_count , chainer_version , py_version
124
141
):
125
142
script_path = (
126
143
os .path .join (DATA_DIR , "chainer_mnist" , "mnist.py" )
@@ -133,8 +150,8 @@ def _run_mnist_training_job(
133
150
chainer = Chainer (
134
151
entry_point = script_path ,
135
152
role = "SageMakerRole" ,
136
- framework_version = chainer_full_version ,
137
- py_version = PYTHON_VERSION ,
153
+ framework_version = chainer_version ,
154
+ py_version = py_version ,
138
155
train_instance_count = instance_count ,
139
156
train_instance_type = instance_type ,
140
157
sagemaker_session = sagemaker_session ,
@@ -147,7 +164,7 @@ def _run_mnist_training_job(
147
164
test_input = "file://" + os .path .join (data_path , "test" )
148
165
149
166
job_name = unique_name_from_base ("test-chainer-training" )
150
- chainer .fit ({"train" : train_input , "test" : test_input }, wait = wait , job_name = job_name )
167
+ chainer .fit ({"train" : train_input , "test" : test_input }, job_name = job_name )
151
168
return chainer
152
169
153
170
0 commit comments