Skip to content

Commit 7919331

Browse files
authored
breaking: require framework_version, py_version for pytorch (#1568)
1 parent 9977206 commit 7919331

File tree

11 files changed

+116
-116
lines changed

11 files changed

+116
-116
lines changed

doc/frameworks/pytorch/using_pytorch.rst

+4-2
Original file line numberDiff line numberDiff line change
@@ -154,7 +154,8 @@ directories ('train' and 'test').
154154
pytorch_estimator = PyTorch('pytorch-train.py',
155155
train_instance_type='ml.p3.2xlarge',
156156
train_instance_count=1,
157-
framework_version='1.0.0',
157+
framework_version='1.5.0',
158+
py_version='py3',
158159
hyperparameters = {'epochs': 20, 'batch-size': 64, 'learning-rate': 0.1})
159160
pytorch_estimator.fit({'train': 's3://my-data-bucket/path/to/my/training/data',
160161
'test': 's3://my-data-bucket/path/to/my/test/data'})
@@ -247,7 +248,8 @@ operation.
247248
pytorch_estimator = PyTorch(entry_point='train_and_deploy.py',
248249
train_instance_type='ml.p3.2xlarge',
249250
train_instance_count=1,
250-
framework_version='1.0.0')
251+
framework_version='1.5.0',
252+
py_version='py3')
251253
pytorch_estimator.fit('s3://my_bucket/my_training_data/')
252254
253255
# Deploy my estimator to a SageMaker Endpoint and get a Predictor

src/sagemaker/pytorch/estimator.py

+31-26
Original file line numberDiff line numberDiff line change
@@ -19,9 +19,9 @@
1919
from sagemaker.fw_utils import (
2020
framework_name_from_image,
2121
framework_version_from_tag,
22-
empty_framework_version_warning,
23-
python_deprecation_warning,
2422
is_version_equal_or_higher,
23+
python_deprecation_warning,
24+
validate_version_or_image_args,
2525
)
2626
from sagemaker.pytorch import defaults
2727
from sagemaker.pytorch.model import PyTorchModel
@@ -40,10 +40,10 @@ class PyTorch(Framework):
4040
def __init__(
4141
self,
4242
entry_point,
43+
framework_version=None,
44+
py_version=None,
4345
source_dir=None,
4446
hyperparameters=None,
45-
py_version=defaults.PYTHON_VERSION,
46-
framework_version=None,
4747
image_name=None,
4848
**kwargs
4949
):
@@ -69,6 +69,13 @@ def __init__(
6969
file which should be executed as the entry point to training.
7070
If ``source_dir`` is specified, then ``entry_point``
7171
must point to a file located at the root of ``source_dir``.
72+
framework_version (str): PyTorch version you want to use for
73+
executing your model training code. Defaults to ``None``. Required unless
74+
``image_name`` is provided. List of supported versions:
75+
https://github.com/aws/sagemaker-python-sdk#pytorch-sagemaker-estimators.
76+
py_version (str): Python version you want to use for executing your
77+
model training code. One of 'py2' or 'py3'. Defaults to ``None``. Required
78+
unless ``image_name`` is provided.
7279
source_dir (str): Path (absolute, relative or an S3 URI) to a directory
7380
with any other training source code dependencies aside from the entry
7481
point file (default: None). If ``source_dir`` is an S3 URI, it must
@@ -80,12 +87,6 @@ def __init__(
8087
SageMaker. For convenience, this accepts other types for keys
8188
and values, but ``str()`` will be called to convert them before
8289
training.
83-
py_version (str): Python version you want to use for executing your
84-
model training code (default: 'py3'). One of 'py2' or 'py3'.
85-
framework_version (str): PyTorch version you want to use for
86-
executing your model training code. List of supported versions
87-
https://github.com/aws/sagemaker-python-sdk#pytorch-sagemaker-estimators.
88-
If not specified, this will default to 0.4.
8990
image_name (str): If specified, the estimator will use this image
9091
for training and hosting, instead of selecting the appropriate
9192
SageMaker official image based on framework_version and
@@ -95,6 +96,9 @@ def __init__(
9596
* ``123412341234.dkr.ecr.us-west-2.amazonaws.com/my-custom-image:1.0``
9697
* ``custom-image:latest``
9798
99+
If ``framework_version`` or ``py_version`` are ``None``, then
100+
``image_name`` is required. If also ``None``, then a ``ValueError``
101+
will be raised.
98102
**kwargs: Additional kwargs passed to the :class:`~sagemaker.estimator.Framework`
99103
constructor.
100104
@@ -104,28 +108,25 @@ def __init__(
104108
:class:`~sagemaker.estimator.Framework` and
105109
:class:`~sagemaker.estimator.EstimatorBase`.
106110
"""
107-
if framework_version is None:
111+
validate_version_or_image_args(framework_version, py_version, image_name)
112+
if py_version == "py2":
108113
logger.warning(
109-
empty_framework_version_warning(defaults.PYTORCH_VERSION, self.LATEST_VERSION)
114+
python_deprecation_warning(self.__framework_name__, defaults.LATEST_PY2_VERSION)
110115
)
111-
self.framework_version = framework_version or defaults.PYTORCH_VERSION
116+
self.framework_version = framework_version
117+
self.py_version = py_version
112118

113119
if "enable_sagemaker_metrics" not in kwargs:
114120
# enable sagemaker metrics for PT v1.3 or greater:
115-
if is_version_equal_or_higher([1, 3], self.framework_version):
121+
if self.framework_version and is_version_equal_or_higher(
122+
[1, 3], self.framework_version
123+
):
116124
kwargs["enable_sagemaker_metrics"] = True
117125

118126
super(PyTorch, self).__init__(
119127
entry_point, source_dir, hyperparameters, image_name=image_name, **kwargs
120128
)
121129

122-
if py_version == "py2":
123-
logger.warning(
124-
python_deprecation_warning(self.__framework_name__, defaults.LATEST_PY2_VERSION)
125-
)
126-
127-
self.py_version = py_version
128-
129130
def create_model(
130131
self,
131132
model_server_workers=None,
@@ -177,12 +178,12 @@ def create_model(
177178
self.model_data,
178179
role or self.role,
179180
entry_point or self.entry_point,
181+
framework_version=self.framework_version,
182+
py_version=self.py_version,
180183
source_dir=(source_dir or self._model_source_dir()),
181184
enable_cloudwatch_metrics=self.enable_cloudwatch_metrics,
182185
container_log_level=self.container_log_level,
183186
code_location=self.code_location,
184-
py_version=self.py_version,
185-
framework_version=self.framework_version,
186187
model_server_workers=model_server_workers,
187188
sagemaker_session=self.sagemaker_session,
188189
vpc_config=self.get_vpc_config(vpc_config_override),
@@ -210,15 +211,19 @@ class constructor
210211
image_name = init_params.pop("image")
211212
framework, py_version, tag, _ = framework_name_from_image(image_name)
212213

214+
if tag is None:
215+
framework_version = None
216+
else:
217+
framework_version = framework_version_from_tag(tag)
218+
init_params["framework_version"] = framework_version
219+
init_params["py_version"] = py_version
220+
213221
if not framework:
214222
# If we were unable to parse the framework name from the image it is not one of our
215223
# officially supported images, in this case just add the image to the init params.
216224
init_params["image_name"] = image_name
217225
return init_params
218226

219-
init_params["py_version"] = py_version
220-
init_params["framework_version"] = framework_version_from_tag(tag)
221-
222227
training_job_name = init_params["base_job_name"]
223228

224229
if framework != cls.__framework_name__:

src/sagemaker/pytorch/model.py

+19-19
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@
2121
create_image_uri,
2222
model_code_key_prefix,
2323
python_deprecation_warning,
24-
empty_framework_version_warning,
24+
validate_version_or_image_args,
2525
)
2626
from sagemaker.model import FrameworkModel, MODEL_SERVER_WORKERS_PARAM_NAME
2727
from sagemaker.pytorch import defaults
@@ -66,9 +66,9 @@ def __init__(
6666
model_data,
6767
role,
6868
entry_point,
69-
image=None,
70-
py_version=defaults.PYTHON_VERSION,
7169
framework_version=None,
70+
py_version=None,
71+
image=None,
7272
predictor_cls=PyTorchPredictor,
7373
model_server_workers=None,
7474
**kwargs
@@ -87,12 +87,16 @@ def __init__(
8787
file which should be executed as the entry point to model
8888
hosting. If ``source_dir`` is specified, then ``entry_point``
8989
must point to a file located at the root of ``source_dir``.
90-
image (str): A Docker image URI (default: None). If not specified, a
91-
default image for PyTorch will be used.
92-
py_version (str): Python version you want to use for executing your
93-
model training code (default: 'py3').
9490
framework_version (str): PyTorch version you want to use for
95-
executing your model training code.
91+
executing your model training code. Defaults to None. Required
92+
unless ``image`` is provided.
93+
py_version (str): Python version you want to use for executing your
94+
model training code. Defaults to ``None``. Required unless
95+
``image`` is provided.
96+
image (str): A Docker image URI (default: None). If not specified, a
97+
default image for PyTorch will be used. If ``framework_version``
98+
or ``py_version`` are ``None``, then ``image`` is required. If
99+
also ``None``, then a ``ValueError`` will be raised.
96100
predictor_cls (callable[str, sagemaker.session.Session]): A function
97101
to call to create a predictor with an endpoint name and
98102
SageMaker ``Session``. If specified, ``deploy()`` returns the
@@ -109,22 +113,18 @@ def __init__(
109113
:class:`~sagemaker.model.FrameworkModel` and
110114
:class:`~sagemaker.model.Model`.
111115
"""
112-
super(PyTorchModel, self).__init__(
113-
model_data, image, role, entry_point, predictor_cls=predictor_cls, **kwargs
114-
)
115-
116-
if py_version == "py2":
116+
validate_version_or_image_args(framework_version, py_version, image)
117+
if py_version and py_version == "py2":
117118
logger.warning(
118119
python_deprecation_warning(self.__framework_name__, defaults.LATEST_PY2_VERSION)
119120
)
121+
self.framework_version = framework_version
122+
self.py_version = py_version
120123

121-
if framework_version is None:
122-
logger.warning(
123-
empty_framework_version_warning(defaults.PYTORCH_VERSION, defaults.LATEST_VERSION)
124-
)
124+
super(PyTorchModel, self).__init__(
125+
model_data, image, role, entry_point, predictor_cls=predictor_cls, **kwargs
126+
)
125127

126-
self.py_version = py_version
127-
self.framework_version = framework_version or defaults.PYTORCH_VERSION
128128
self.model_server_workers = model_server_workers
129129

130130
def prepare_container_def(self, instance_type, accelerator_type=None):

tests/conftest.py

+5
Original file line numberDiff line numberDiff line change
@@ -173,6 +173,11 @@ def pytorch_version(request):
173173
return request.param
174174

175175

176+
@pytest.fixture(scope="module", params=["py2", "py3"])
177+
def pytorch_py_version(request):
178+
return request.param
179+
180+
176181
@pytest.fixture(scope="module", params=["0.20.0"])
177182
def sklearn_version(request):
178183
return request.param

tests/integ/test_airflow_config.py

+4-2
Original file line numberDiff line numberDiff line change
@@ -608,13 +608,14 @@ def test_xgboost_airflow_config_uploads_data_source_to_s3(sagemaker_session, cpu
608608

609609
@pytest.mark.canary_quick
610610
def test_pytorch_airflow_config_uploads_data_source_to_s3_when_inputs_not_provided(
611-
sagemaker_session, cpu_instance_type
611+
sagemaker_session, cpu_instance_type, pytorch_full_version
612612
):
613613
with timeout(seconds=AIRFLOW_CONFIG_TIMEOUT_IN_SECONDS):
614614
estimator = PyTorch(
615615
entry_point=PYTORCH_MNIST_SCRIPT,
616616
role=ROLE,
617-
framework_version="1.1.0",
617+
framework_version=pytorch_full_version,
618+
py_version="py3",
618619
train_instance_count=2,
619620
train_instance_type=cpu_instance_type,
620621
hyperparameters={"epochs": 6, "backend": "gloo"},
@@ -639,6 +640,7 @@ def test_pytorch_12_airflow_config_uploads_data_source_to_s3_when_inputs_not_pro
639640
entry_point=PYTORCH_MNIST_SCRIPT,
640641
role=ROLE,
641642
framework_version="1.2.0",
643+
py_version="py3",
642644
train_instance_count=2,
643645
train_instance_type=cpu_instance_type,
644646
hyperparameters={"epochs": 6, "backend": "gloo"},

tests/integ/test_git.py

+4-2
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,6 @@
2121

2222
from tests.integ import lock as lock
2323
from sagemaker.mxnet.estimator import MXNet
24-
from sagemaker.pytorch.defaults import PYTORCH_VERSION
2524
from sagemaker.pytorch.estimator import PyTorch
2625
from sagemaker.sklearn.estimator import SKLearn
2726
from sagemaker.sklearn.model import SKLearnModel
@@ -56,11 +55,14 @@ def test_github(sagemaker_local_session):
5655
script_path = "mnist.py"
5756
data_path = os.path.join(DATA_DIR, "pytorch_mnist")
5857
git_config = {"repo": GIT_REPO, "branch": BRANCH, "commit": COMMIT}
58+
59+
# TODO: fails for newer pytorch versions when using MNIST from torchvision due to missing dataset
60+
# "algo-1-v767u_1 | RuntimeError: Dataset not found. You can use download=True to download it"
5961
pytorch = PyTorch(
6062
entry_point=script_path,
6163
role="SageMakerRole",
6264
source_dir="pytorch",
63-
framework_version=PYTORCH_VERSION,
65+
framework_version="0.4", # hard-code to last known good pytorch for now (see TODO above)
6466
py_version=PYTHON_VERSION,
6567
train_instance_count=1,
6668
train_instance_type="local",

tests/integ/test_pytorch_train.py

+9-11
Original file line numberDiff line numberDiff line change
@@ -98,11 +98,9 @@ def test_fit_deploy(sagemaker_local_session, pytorch_full_version):
9898
predictor.delete_endpoint()
9999

100100

101-
@pytest.mark.skipif(
102-
PYTHON_VERSION == "py2",
103-
reason="Python 2 is supported by PyTorch {} and lower versions.".format(LATEST_PY2_VERSION),
104-
)
105-
def test_deploy_model(pytorch_training_job, sagemaker_session, cpu_instance_type):
101+
def test_deploy_model(
102+
pytorch_training_job, sagemaker_session, cpu_instance_type, pytorch_full_version
103+
):
106104
endpoint_name = "test-pytorch-deploy-model-{}".format(sagemaker_timestamp())
107105

108106
with timeout_and_delete_endpoint_by_name(endpoint_name, sagemaker_session):
@@ -114,6 +112,8 @@ def test_deploy_model(pytorch_training_job, sagemaker_session, cpu_instance_type
114112
model_data,
115113
"SageMakerRole",
116114
entry_point=MNIST_SCRIPT,
115+
framework_version=pytorch_full_version,
116+
py_version="py3",
117117
sagemaker_session=sagemaker_session,
118118
)
119119
predictor = model.deploy(1, cpu_instance_type, endpoint_name=endpoint_name)
@@ -125,10 +125,6 @@ def test_deploy_model(pytorch_training_job, sagemaker_session, cpu_instance_type
125125
assert output.shape == (batch_size, 10)
126126

127127

128-
@pytest.mark.skipif(
129-
PYTHON_VERSION == "py2",
130-
reason="Python 2 is supported by PyTorch {} and lower versions.".format(LATEST_PY2_VERSION),
131-
)
132128
def test_deploy_packed_model_with_entry_point_name(sagemaker_session, cpu_instance_type):
133129
endpoint_name = "test-pytorch-deploy-model-{}".format(sagemaker_timestamp())
134130

@@ -139,6 +135,7 @@ def test_deploy_packed_model_with_entry_point_name(sagemaker_session, cpu_instan
139135
"SageMakerRole",
140136
entry_point="mnist.py",
141137
framework_version="1.4.0",
138+
py_version="py3",
142139
sagemaker_session=sagemaker_session,
143140
)
144141
predictor = model.deploy(1, cpu_instance_type, endpoint_name=endpoint_name)
@@ -160,8 +157,9 @@ def test_deploy_model_with_accelerator(sagemaker_session, cpu_instance_type):
160157
pytorch = PyTorchModel(
161158
model_data,
162159
"SageMakerRole",
163-
framework_version="1.3.1",
164160
entry_point=EIA_SCRIPT,
161+
framework_version="1.3.1",
162+
py_version="py3",
165163
sagemaker_session=sagemaker_session,
166164
)
167165
with timeout_and_delete_endpoint_by_name(endpoint_name, sagemaker_session):
@@ -193,7 +191,7 @@ def _get_pytorch_estimator(
193191
entry_point=entry_point,
194192
role="SageMakerRole",
195193
framework_version=pytorch_full_version,
196-
py_version=PYTHON_VERSION,
194+
py_version="py3",
197195
train_instance_count=1,
198196
train_instance_type=instance_type,
199197
sagemaker_session=sagemaker_session,

tests/integ/test_source_dirs.py

+3
Original file line numberDiff line numberDiff line change
@@ -30,11 +30,14 @@ def test_source_dirs(tmpdir, sagemaker_local_session):
3030
with open(lib, "w") as f:
3131
f.write("def question(to_anything): return 42")
3232

33+
# TODO: fails on newer versions of pytorch in call to np.load(BytesIO(stream.read()))
34+
# "ValueError: Cannot load file containing pickled data when allow_pickle=False"
3335
estimator = PyTorch(
3436
entry_point="train.py",
3537
role="SageMakerRole",
3638
source_dir=source_dir,
3739
dependencies=[lib],
40+
framework_version="0.4", # hard-code to last known good pytorch for now (see TODO above)
3841
py_version=PYTHON_VERSION,
3942
train_instance_count=1,
4043
train_instance_type="local",

tests/integ/test_transformer.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -178,7 +178,7 @@ def test_transform_pytorch_vpc_custom_model_bucket(
178178
entry_point=os.path.join(data_dir, "mnist.py"),
179179
role="SageMakerRole",
180180
framework_version=pytorch_full_version,
181-
py_version=PYTHON_VERSION,
181+
py_version="py3",
182182
sagemaker_session=sagemaker_session,
183183
vpc_config={"Subnets": subnet_ids, "SecurityGroupIds": [security_group_id]},
184184
code_location="s3://{}".format(custom_bucket_name),

tests/integ/test_tuner.py

+3-2
Original file line numberDiff line numberDiff line change
@@ -819,15 +819,16 @@ def test_tuning_chainer(sagemaker_session, cpu_instance_type):
819819
reason="This test has always failed, but the failure was masked by a bug. "
820820
"This test should be fixed. Details in https://github.com/aws/sagemaker-python-sdk/pull/968"
821821
)
822-
def test_attach_tuning_pytorch(sagemaker_session, cpu_instance_type):
822+
def test_attach_tuning_pytorch(sagemaker_session, cpu_instance_type, pytorch_full_version):
823823
mnist_dir = os.path.join(DATA_DIR, "pytorch_mnist")
824824
mnist_script = os.path.join(mnist_dir, "mnist.py")
825825

826826
estimator = PyTorch(
827827
entry_point=mnist_script,
828828
role="SageMakerRole",
829829
train_instance_count=1,
830-
py_version=PYTHON_VERSION,
830+
framework_version=pytorch_full_version,
831+
py_version="py3",
831832
train_instance_type=cpu_instance_type,
832833
sagemaker_session=sagemaker_session,
833834
)

0 commit comments

Comments
 (0)