19
19
import pytest
20
20
21
21
from sagemaker .tensorflow import TensorFlow
22
- from sagemaker .tensorflow .defaults import LATEST_SERVING_VERSION
22
+ from sagemaker .tensorflow .defaults import LATEST_VERSION , LATEST_SERVING_VERSION
23
23
from sagemaker .utils import unique_name_from_base , sagemaker_timestamp
24
24
25
25
import tests .integ
26
- from tests .integ import timeout
27
- from tests .integ import kms_utils
26
+ from tests .integ import kms_utils , timeout , PYTHON_VERSION
28
27
from tests .integ .retry import retries
29
28
from tests .integ .s3_utils import assert_s3_files_exist
30
29
39
38
MPI_DISTRIBUTION = {"mpi" : {"enabled" : True }}
40
39
TAGS = [{"Key" : "some-key" , "Value" : "some-value" }]
41
40
41
+ PY37_SUPPORTED_FRAMEWORK_VERSION = [TensorFlow ._LATEST_1X_VERSION , LATEST_VERSION ]
42
+
42
43
43
44
@pytest .fixture (scope = "module" )
44
- def py_version (tf_full_version , tf_serving_version ):
45
- return "py37" if tf_full_version == tf_serving_version else tests . integ . PYTHON_VERSION
45
+ def py_version (tf_full_version ):
46
+ return "py37" if tf_full_version in PY37_SUPPORTED_FRAMEWORK_VERSION else PYTHON_VERSION
46
47
47
48
48
49
def test_mnist_with_checkpoint_config (
@@ -89,7 +90,7 @@ def test_mnist_with_checkpoint_config(
89
90
assert actual_training_checkpoint_config == expected_training_checkpoint_config
90
91
91
92
92
- def test_server_side_encryption (sagemaker_session , tf_serving_version , py_version ):
93
+ def test_server_side_encryption (sagemaker_session , tf_full_version , py_version ):
93
94
with kms_utils .bucket_with_encryption (sagemaker_session , ROLE ) as (bucket_with_kms , kms_key ):
94
95
output_path = os .path .join (
95
96
bucket_with_kms , "test-server-side-encryption" , time .strftime ("%y%m%d-%H%M" )
@@ -102,7 +103,7 @@ def test_server_side_encryption(sagemaker_session, tf_serving_version, py_versio
102
103
train_instance_count = 1 ,
103
104
train_instance_type = "ml.c5.xlarge" ,
104
105
sagemaker_session = sagemaker_session ,
105
- framework_version = tf_serving_version ,
106
+ framework_version = tf_full_version ,
106
107
py_version = py_version ,
107
108
code_location = output_path ,
108
109
output_path = output_path ,
@@ -154,13 +155,13 @@ def test_mnist_distributed(sagemaker_session, instance_type, tf_full_version, py
154
155
)
155
156
156
157
157
- def test_mnist_async (sagemaker_session , cpu_instance_type , tf_full_version , py_version ):
158
+ def test_mnist_async (sagemaker_session , cpu_instance_type ):
158
159
estimator = TensorFlow (
159
160
entry_point = SCRIPT ,
160
161
role = ROLE ,
161
162
train_instance_count = 1 ,
162
163
train_instance_type = "ml.c5.4xlarge" ,
163
- py_version = tests . integ . PYTHON_VERSION ,
164
+ py_version = PYTHON_VERSION ,
164
165
sagemaker_session = sagemaker_session ,
165
166
# testing py-sdk functionality, no need to run against all TF versions
166
167
framework_version = LATEST_SERVING_VERSION ,
@@ -195,18 +196,16 @@ def test_mnist_async(sagemaker_session, cpu_instance_type, tf_full_version, py_v
195
196
_assert_model_name_match (sagemaker_session .sagemaker_client , endpoint_name , model_name )
196
197
197
198
198
- def test_deploy_with_input_handlers (
199
- sagemaker_session , instance_type , tf_serving_version , py_version
200
- ):
199
+ def test_deploy_with_input_handlers (sagemaker_session , instance_type , tf_serving_version ):
201
200
estimator = TensorFlow (
202
201
entry_point = "training.py" ,
203
202
source_dir = TFS_RESOURCE_PATH ,
204
203
role = ROLE ,
205
204
train_instance_count = 1 ,
206
205
train_instance_type = instance_type ,
207
- py_version = py_version ,
208
- sagemaker_session = sagemaker_session ,
209
206
framework_version = tf_serving_version ,
207
+ py_version = PYTHON_VERSION ,
208
+ sagemaker_session = sagemaker_session ,
210
209
tags = TAGS ,
211
210
)
212
211
0 commit comments