Skip to content

Commit c233f67

Browse files
authored
infra: use fixture for Chainer and XGBoost Python version, clean up remaining version fixtures (#1631)
1 parent 6edac7c commit c233f67

File tree

6 files changed

+70
-67
lines changed

6 files changed

+70
-67
lines changed

tests/conftest.py

+28-37
Original file line numberDiff line numberDiff line change
@@ -44,21 +44,6 @@ def pytest_addoption(parser):
4444
parser.addoption("--sagemaker-client-config", action="store", default=None)
4545
parser.addoption("--sagemaker-runtime-config", action="store", default=None)
4646
parser.addoption("--boto-config", action="store", default=None)
47-
parser.addoption("--chainer-full-version", action="store", default="5.0.0")
48-
parser.addoption("--ei-mxnet-full-version", action="store", default="1.5.1")
49-
parser.addoption(
50-
"--rl-coach-mxnet-full-version",
51-
action="store",
52-
default=RLEstimator.COACH_LATEST_VERSION_MXNET,
53-
)
54-
parser.addoption(
55-
"--rl-coach-tf-full-version", action="store", default=RLEstimator.COACH_LATEST_VERSION_TF
56-
)
57-
parser.addoption(
58-
"--rl-ray-full-version", action="store", default=RLEstimator.RAY_LATEST_VERSION
59-
)
60-
parser.addoption("--ei-tf-full-version", action="store")
61-
parser.addoption("--xgboost-full-version", action="store", default="1.0-1")
6247

6348

6449
def pytest_configure(config):
@@ -248,8 +233,13 @@ def rl_ray_version(request):
248233

249234

250235
@pytest.fixture(scope="module")
251-
def chainer_full_version(request):
252-
return request.config.getoption("--chainer-full-version")
236+
def chainer_full_version():
237+
return "5.0.0"
238+
239+
240+
@pytest.fixture(scope="module")
241+
def chainer_full_py_version():
242+
return "py3"
253243

254244

255245
@pytest.fixture(scope="module")
@@ -263,8 +253,8 @@ def mxnet_full_py_version():
263253

264254

265255
@pytest.fixture(scope="module")
266-
def ei_mxnet_full_version(request):
267-
return request.config.getoption("--ei-mxnet-full-version")
256+
def ei_mxnet_full_version():
257+
return "1.5.1"
268258

269259

270260
@pytest.fixture(scope="module")
@@ -283,18 +273,18 @@ def pytorch_full_ei_version():
283273

284274

285275
@pytest.fixture(scope="module")
286-
def rl_coach_mxnet_full_version(request):
287-
return request.config.getoption("--rl-coach-mxnet-full-version")
276+
def rl_coach_mxnet_full_version():
277+
return RLEstimator.COACH_LATEST_VERSION_MXNET
288278

289279

290280
@pytest.fixture(scope="module")
291-
def rl_coach_tf_full_version(request):
292-
return request.config.getoption("--rl-coach-tf-full-version")
281+
def rl_coach_tf_full_version():
282+
return RLEstimator.COACH_LATEST_VERSION_TF
293283

294284

295285
@pytest.fixture(scope="module")
296-
def rl_ray_full_version(request):
297-
return request.config.getoption("--rl-ray-full-version")
286+
def rl_ray_full_version():
287+
return RLEstimator.RAY_LATEST_VERSION
298288

299289

300290
@pytest.fixture(scope="module")
@@ -347,13 +337,19 @@ def tf_full_py_version(tf_full_version):
347337
return "py37"
348338

349339

350-
@pytest.fixture(scope="module", params=["1.15.0", "2.0.0"])
351-
def ei_tf_full_version(request):
352-
tf_ei_version = request.config.getoption("--ei-tf-full-version")
353-
if tf_ei_version is None:
354-
return request.param
355-
else:
356-
tf_ei_version
340+
@pytest.fixture(scope="module")
341+
def ei_tf_full_version():
342+
return "2.0.0"
343+
344+
345+
@pytest.fixture(scope="module")
346+
def xgboost_full_version():
347+
return "1.0-1"
348+
349+
350+
@pytest.fixture(scope="module")
351+
def xgboost_full_py_version():
352+
return "py3"
357353

358354

359355
@pytest.fixture(scope="session")
@@ -409,8 +405,3 @@ def pytest_generate_tests(metafunc):
409405
):
410406
params.append("ml.p2.xlarge")
411407
metafunc.parametrize("instance_type", params, scope="session")
412-
413-
414-
@pytest.fixture(scope="module")
415-
def xgboost_full_version(request):
416-
return request.config.getoption("--xgboost-full-version")

tests/integ/__init__.py

-2
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,6 @@
1414

1515
import logging
1616
import os
17-
import sys
1817

1918
import boto3
2019

@@ -23,7 +22,6 @@
2322
TUNING_DEFAULT_TIMEOUT_MINUTES = 20
2423
TRANSFORM_DEFAULT_TIMEOUT_MINUTES = 20
2524
AUTO_ML_DEFAULT_TIMEMOUT_MINUTES = 60
26-
PYTHON_VERSION = "py{}".format(sys.version_info.major)
2725

2826
# these regions have some p2 and p3 instances, but not enough for continuous testing
2927
HOSTING_NO_P2_REGIONS = [

tests/integ/test_airflow_config.py

+5-6
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,7 @@
4545
from sagemaker.utils import sagemaker_timestamp
4646
from sagemaker.workflow import airflow as sm_airflow
4747
from sagemaker.xgboost import XGBoost
48-
from tests.integ import datasets, DATA_DIR, PYTHON_VERSION
48+
from tests.integ import datasets, DATA_DIR
4949
from tests.integ.record_set import prepare_record_set_from_local_files
5050
from tests.integ.timeout import timeout
5151

@@ -404,7 +404,7 @@ def test_rcf_airflow_config_uploads_data_source_to_s3(sagemaker_session, cpu_ins
404404

405405
@pytest.mark.canary_quick
406406
def test_chainer_airflow_config_uploads_data_source_to_s3(
407-
sagemaker_local_session, cpu_instance_type, chainer_full_version
407+
sagemaker_local_session, cpu_instance_type, chainer_full_version, chainer_full_py_version
408408
):
409409
with timeout(seconds=AIRFLOW_CONFIG_TIMEOUT_IN_SECONDS):
410410
script_path = os.path.join(DATA_DIR, "chainer_mnist", "mnist.py")
@@ -416,7 +416,7 @@ def test_chainer_airflow_config_uploads_data_source_to_s3(
416416
train_instance_count=SINGLE_INSTANCE_COUNT,
417417
train_instance_type="local",
418418
framework_version=chainer_full_version,
419-
py_version=PYTHON_VERSION,
419+
py_version=chainer_full_py_version,
420420
sagemaker_session=sagemaker_local_session,
421421
hyperparameters={"epochs": 1},
422422
use_mpi=True,
@@ -545,20 +545,19 @@ def test_tf_airflow_config_uploads_data_source_to_s3(
545545

546546

547547
@pytest.mark.canary_quick
548-
@pytest.mark.skipif(PYTHON_VERSION == "py2", reason="XGBoost container does not support Python 2.")
549548
def test_xgboost_airflow_config_uploads_data_source_to_s3(
550-
sagemaker_session, cpu_instance_type, xgboost_full_version
549+
sagemaker_session, cpu_instance_type, xgboost_full_version, xgboost_full_py_version
551550
):
552551
with timeout(seconds=AIRFLOW_CONFIG_TIMEOUT_IN_SECONDS):
553552
xgboost = XGBoost(
554553
entry_point=os.path.join(DATA_DIR, "dummy_script.py"),
555554
framework_version=xgboost_full_version,
555+
py_version=xgboost_full_py_version,
556556
role=ROLE,
557557
sagemaker_session=sagemaker_session,
558558
train_instance_type=cpu_instance_type,
559559
train_instance_count=SINGLE_INSTANCE_COUNT,
560560
base_job_name="XGBoost job",
561-
py_version=PYTHON_VERSION,
562561
)
563562

564563
training_config = _build_airflow_workflow(

tests/integ/test_chainer_train.py

+32-15
Original file line numberDiff line numberDiff line change
@@ -20,22 +20,32 @@
2020
from sagemaker.chainer.estimator import Chainer
2121
from sagemaker.chainer.model import ChainerModel
2222
from sagemaker.utils import unique_name_from_base
23-
from tests.integ import DATA_DIR, PYTHON_VERSION, TRAINING_DEFAULT_TIMEOUT_MINUTES
23+
from tests.integ import DATA_DIR, TRAINING_DEFAULT_TIMEOUT_MINUTES
2424
from tests.integ.timeout import timeout, timeout_and_delete_endpoint_by_name
2525

2626

2727
@pytest.fixture(scope="module")
28-
def chainer_local_training_job(sagemaker_local_session, chainer_full_version):
29-
return _run_mnist_training_job(sagemaker_local_session, "local", 1, chainer_full_version)
28+
def chainer_local_training_job(
29+
sagemaker_local_session, chainer_full_version, chainer_full_py_version
30+
):
31+
return _run_mnist_training_job(
32+
sagemaker_local_session, "local", 1, chainer_full_version, chainer_full_py_version
33+
)
3034

3135

3236
@pytest.mark.local_mode
33-
def test_distributed_cpu_training(sagemaker_local_session, chainer_full_version):
34-
_run_mnist_training_job(sagemaker_local_session, "local", 2, chainer_full_version)
37+
def test_distributed_cpu_training(
38+
sagemaker_local_session, chainer_full_version, chainer_full_py_version
39+
):
40+
_run_mnist_training_job(
41+
sagemaker_local_session, "local", 2, chainer_full_version, chainer_full_py_version
42+
)
3543

3644

3745
@pytest.mark.local_mode
38-
def test_training_with_additional_hyperparameters(sagemaker_local_session, chainer_full_version):
46+
def test_training_with_additional_hyperparameters(
47+
sagemaker_local_session, chainer_full_version, chainer_full_py_version
48+
):
3949
script_path = os.path.join(DATA_DIR, "chainer_mnist", "mnist.py")
4050
data_path = os.path.join(DATA_DIR, "chainer_mnist")
4151

@@ -45,7 +55,7 @@ def test_training_with_additional_hyperparameters(sagemaker_local_session, chain
4555
train_instance_count=1,
4656
train_instance_type="local",
4757
framework_version=chainer_full_version,
48-
py_version=PYTHON_VERSION,
58+
py_version=chainer_full_py_version,
4959
sagemaker_session=sagemaker_local_session,
5060
hyperparameters={"epochs": 1},
5161
use_mpi=True,
@@ -62,7 +72,9 @@ def test_training_with_additional_hyperparameters(sagemaker_local_session, chain
6272

6373
@pytest.mark.canary_quick
6474
@pytest.mark.regional_testing
65-
def test_attach_deploy(sagemaker_session, chainer_full_version, cpu_instance_type):
75+
def test_attach_deploy(
76+
sagemaker_session, chainer_full_version, chainer_full_py_version, cpu_instance_type
77+
):
6678
with timeout(minutes=TRAINING_DEFAULT_TIMEOUT_MINUTES):
6779
script_path = os.path.join(DATA_DIR, "chainer_mnist", "mnist.py")
6880
data_path = os.path.join(DATA_DIR, "chainer_mnist")
@@ -71,7 +83,7 @@ def test_attach_deploy(sagemaker_session, chainer_full_version, cpu_instance_typ
7183
entry_point=script_path,
7284
role="SageMakerRole",
7385
framework_version=chainer_full_version,
74-
py_version=PYTHON_VERSION,
86+
py_version=chainer_full_py_version,
7587
train_instance_count=1,
7688
train_instance_type=cpu_instance_type,
7789
sagemaker_session=sagemaker_session,
@@ -100,7 +112,12 @@ def test_attach_deploy(sagemaker_session, chainer_full_version, cpu_instance_typ
100112

101113

102114
@pytest.mark.local_mode
103-
def test_deploy_model(chainer_local_training_job, sagemaker_local_session, chainer_full_version):
115+
def test_deploy_model(
116+
chainer_local_training_job,
117+
sagemaker_local_session,
118+
chainer_full_version,
119+
chainer_full_py_version,
120+
):
104121
script_path = os.path.join(DATA_DIR, "chainer_mnist", "mnist.py")
105122

106123
model = ChainerModel(
@@ -109,7 +126,7 @@ def test_deploy_model(chainer_local_training_job, sagemaker_local_session, chain
109126
entry_point=script_path,
110127
sagemaker_session=sagemaker_local_session,
111128
framework_version=chainer_full_version,
112-
py_version=PYTHON_VERSION,
129+
py_version=chainer_full_py_version,
113130
)
114131

115132
predictor = model.deploy(1, "local")
@@ -120,7 +137,7 @@ def test_deploy_model(chainer_local_training_job, sagemaker_local_session, chain
120137

121138

122139
def _run_mnist_training_job(
123-
sagemaker_session, instance_type, instance_count, chainer_full_version, wait=True
140+
sagemaker_session, instance_type, instance_count, chainer_version, py_version
124141
):
125142
script_path = (
126143
os.path.join(DATA_DIR, "chainer_mnist", "mnist.py")
@@ -133,8 +150,8 @@ def _run_mnist_training_job(
133150
chainer = Chainer(
134151
entry_point=script_path,
135152
role="SageMakerRole",
136-
framework_version=chainer_full_version,
137-
py_version=PYTHON_VERSION,
153+
framework_version=chainer_version,
154+
py_version=py_version,
138155
train_instance_count=instance_count,
139156
train_instance_type=instance_type,
140157
sagemaker_session=sagemaker_session,
@@ -147,7 +164,7 @@ def _run_mnist_training_job(
147164
test_input = "file://" + os.path.join(data_path, "test")
148165

149166
job_name = unique_name_from_base("test-chainer-training")
150-
chainer.fit({"train": train_input, "test": test_input}, wait=wait, job_name=job_name)
167+
chainer.fit({"train": train_input, "test": test_input}, job_name=job_name)
151168
return chainer
152169

153170

tests/integ/test_rl.py

+1-4
Original file line numberDiff line numberDiff line change
@@ -19,12 +19,11 @@
1919

2020
from sagemaker.rl import RLEstimator, RLFramework, RLToolkit
2121
from sagemaker.utils import sagemaker_timestamp, unique_name_from_base
22-
from tests.integ import DATA_DIR, PYTHON_VERSION
22+
from tests.integ import DATA_DIR
2323
from tests.integ.timeout import timeout, timeout_and_delete_endpoint_by_name
2424

2525

2626
@pytest.mark.canary_quick
27-
@pytest.mark.skipif(PYTHON_VERSION != "py3", reason="RL images supports only Python 3.")
2827
def test_coach_mxnet(sagemaker_session, rl_coach_mxnet_full_version, cpu_instance_type):
2928
estimator = _test_coach(
3029
sagemaker_session, RLFramework.MXNET, rl_coach_mxnet_full_version, cpu_instance_type
@@ -52,7 +51,6 @@ def test_coach_mxnet(sagemaker_session, rl_coach_mxnet_full_version, cpu_instanc
5251
assert 0 < action[0][1] < 1
5352

5453

55-
@pytest.mark.skipif(PYTHON_VERSION != "py3", reason="RL images supports only Python 3.")
5654
def test_coach_tf(sagemaker_session, rl_coach_tf_full_version, cpu_instance_type):
5755
estimator = _test_coach(
5856
sagemaker_session, RLFramework.TENSORFLOW, rl_coach_tf_full_version, cpu_instance_type
@@ -98,7 +96,6 @@ def _test_coach(sagemaker_session, rl_framework, rl_coach_version, cpu_instance_
9896

9997

10098
@pytest.mark.canary_quick
101-
@pytest.mark.skipif(PYTHON_VERSION != "py3", reason="RL images supports only Python 3.")
10299
def test_ray_tf(sagemaker_session, rl_ray_full_version, cpu_instance_type):
103100
source_dir = os.path.join(DATA_DIR, "ray_cartpole")
104101
cartpole = "train_ray.py"

tests/integ/test_tuner.py

+4-3
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,6 @@
4545
datasets,
4646
vpc_test_utils,
4747
DATA_DIR,
48-
PYTHON_VERSION,
4948
TUNING_DEFAULT_TIMEOUT_MINUTES,
5049
)
5150
from tests.integ.record_set import prepare_record_set_from_local_files
@@ -687,7 +686,9 @@ def test_tuning_tf_vpc_multi(
687686

688687

689688
@pytest.mark.canary_quick
690-
def test_tuning_chainer(sagemaker_session, chainer_full_version, cpu_instance_type):
689+
def test_tuning_chainer(
690+
sagemaker_session, chainer_full_version, chainer_full_py_version, cpu_instance_type
691+
):
691692
with timeout(minutes=TUNING_DEFAULT_TIMEOUT_MINUTES):
692693
script_path = os.path.join(DATA_DIR, "chainer_mnist", "mnist.py")
693694
data_path = os.path.join(DATA_DIR, "chainer_mnist")
@@ -696,7 +697,7 @@ def test_tuning_chainer(sagemaker_session, chainer_full_version, cpu_instance_ty
696697
entry_point=script_path,
697698
role="SageMakerRole",
698699
framework_version=chainer_full_version,
699-
py_version=PYTHON_VERSION,
700+
py_version=chainer_full_py_version,
700701
train_instance_count=1,
701702
train_instance_type=cpu_instance_type,
702703
sagemaker_session=sagemaker_session,

0 commit comments

Comments
 (0)