Skip to content

Commit 04f1858

Browse files
authored
Merge branch 'zwei' into require-framework-version-tensorflow
2 parents 99279ff + 9df3f5a commit 04f1858

28 files changed

+230
-164
lines changed

doc/frameworks/pytorch/using_pytorch.rst

+4-2
Original file line numberDiff line numberDiff line change
@@ -154,7 +154,8 @@ directories ('train' and 'test').
154154
pytorch_estimator = PyTorch('pytorch-train.py',
155155
train_instance_type='ml.p3.2xlarge',
156156
train_instance_count=1,
157-
framework_version='1.0.0',
157+
framework_version='1.5.0',
158+
py_version='py3',
158159
hyperparameters = {'epochs': 20, 'batch-size': 64, 'learning-rate': 0.1})
159160
pytorch_estimator.fit({'train': 's3://my-data-bucket/path/to/my/training/data',
160161
'test': 's3://my-data-bucket/path/to/my/test/data'})
@@ -247,7 +248,8 @@ operation.
247248
pytorch_estimator = PyTorch(entry_point='train_and_deploy.py',
248249
train_instance_type='ml.p3.2xlarge',
249250
train_instance_count=1,
250-
framework_version='1.0.0')
251+
framework_version='1.5.0',
252+
py_version='py3')
251253
pytorch_estimator.fit('s3://my_bucket/my_training_data/')
252254
253255
# Deploy my estimator to a SageMaker Endpoint and get a Predictor

src/sagemaker/amazon/factorization_machines.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -310,8 +310,8 @@ def __init__(self, model_data, role, sagemaker_session=None, **kwargs):
310310
repo = "{}:{}".format(FactorizationMachines.repo_name, FactorizationMachines.repo_version)
311311
image = "{}/{}".format(registry(sagemaker_session.boto_session.region_name), repo)
312312
super(FactorizationMachinesModel, self).__init__(
313-
model_data,
314313
image,
314+
model_data,
315315
role,
316316
predictor_cls=FactorizationMachinesPredictor,
317317
sagemaker_session=sagemaker_session,

src/sagemaker/amazon/ipinsights.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -216,8 +216,8 @@ def __init__(self, model_data, role, sagemaker_session=None, **kwargs):
216216
)
217217

218218
super(IPInsightsModel, self).__init__(
219-
model_data,
220219
image,
220+
model_data,
221221
role,
222222
predictor_cls=IPInsightsPredictor,
223223
sagemaker_session=sagemaker_session,

src/sagemaker/amazon/kmeans.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -241,8 +241,8 @@ def __init__(self, model_data, role, sagemaker_session=None, **kwargs):
241241
repo = "{}:{}".format(KMeans.repo_name, KMeans.repo_version)
242242
image = "{}/{}".format(registry(sagemaker_session.boto_session.region_name), repo)
243243
super(KMeansModel, self).__init__(
244-
model_data,
245244
image,
245+
model_data,
246246
role,
247247
predictor_cls=KMeansPredictor,
248248
sagemaker_session=sagemaker_session,

src/sagemaker/amazon/knn.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -231,8 +231,8 @@ def __init__(self, model_data, role, sagemaker_session=None, **kwargs):
231231
registry(sagemaker_session.boto_session.region_name, KNN.repo_name), repo
232232
)
233233
super(KNNModel, self).__init__(
234-
model_data,
235234
image,
235+
model_data,
236236
role,
237237
predictor_cls=KNNPredictor,
238238
sagemaker_session=sagemaker_session,

src/sagemaker/amazon/lda.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -215,8 +215,8 @@ def __init__(self, model_data, role, sagemaker_session=None, **kwargs):
215215
registry(sagemaker_session.boto_session.region_name, LDA.repo_name), repo
216216
)
217217
super(LDAModel, self).__init__(
218-
model_data,
219218
image,
219+
model_data,
220220
role,
221221
predictor_cls=LDAPredictor,
222222
sagemaker_session=sagemaker_session,

src/sagemaker/amazon/linear_learner.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -474,8 +474,8 @@ def __init__(self, model_data, role, sagemaker_session=None, **kwargs):
474474
repo = "{}:{}".format(LinearLearner.repo_name, LinearLearner.repo_version)
475475
image = "{}/{}".format(registry(sagemaker_session.boto_session.region_name), repo)
476476
super(LinearLearnerModel, self).__init__(
477-
model_data,
478477
image,
478+
model_data,
479479
role,
480480
predictor_cls=LinearLearnerPredictor,
481481
sagemaker_session=sagemaker_session,

src/sagemaker/amazon/ntm.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -245,8 +245,8 @@ def __init__(self, model_data, role, sagemaker_session=None, **kwargs):
245245
registry(sagemaker_session.boto_session.region_name, NTM.repo_name), repo
246246
)
247247
super(NTMModel, self).__init__(
248-
model_data,
249248
image,
249+
model_data,
250250
role,
251251
predictor_cls=NTMPredictor,
252252
sagemaker_session=sagemaker_session,

src/sagemaker/amazon/object2vec.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -355,8 +355,8 @@ def __init__(self, model_data, role, sagemaker_session=None, **kwargs):
355355
registry(sagemaker_session.boto_session.region_name, Object2Vec.repo_name), repo
356356
)
357357
super(Object2VecModel, self).__init__(
358-
model_data,
359358
image,
359+
model_data,
360360
role,
361361
predictor_cls=RealTimePredictor,
362362
sagemaker_session=sagemaker_session,

src/sagemaker/amazon/pca.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -225,8 +225,8 @@ def __init__(self, model_data, role, sagemaker_session=None, **kwargs):
225225
repo = "{}:{}".format(PCA.repo_name, PCA.repo_version)
226226
image = "{}/{}".format(registry(sagemaker_session.boto_session.region_name), repo)
227227
super(PCAModel, self).__init__(
228-
model_data,
229228
image,
229+
model_data,
230230
role,
231231
predictor_cls=PCAPredictor,
232232
sagemaker_session=sagemaker_session,

src/sagemaker/amazon/randomcutforest.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -206,8 +206,8 @@ def __init__(self, model_data, role, sagemaker_session=None, **kwargs):
206206
registry(sagemaker_session.boto_session.region_name, RandomCutForest.repo_name), repo
207207
)
208208
super(RandomCutForestModel, self).__init__(
209-
model_data,
210209
image,
210+
model_data,
211211
role,
212212
predictor_cls=RandomCutForestPredictor,
213213
sagemaker_session=sagemaker_session,

src/sagemaker/estimator.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -1398,8 +1398,8 @@ def predict_wrapper(endpoint, session):
13981398
kwargs["enable_network_isolation"] = self.enable_network_isolation()
13991399

14001400
return Model(
1401-
self.model_data,
14021401
image or self.train_image(),
1402+
self.model_data,
14031403
role,
14041404
vpc_config=self.get_vpc_config(vpc_config_override),
14051405
sagemaker_session=self.sagemaker_session,

src/sagemaker/model.py

+6-4
Original file line numberDiff line numberDiff line change
@@ -60,8 +60,8 @@ class Model(object):
6060

6161
def __init__(
6262
self,
63-
model_data,
6463
image,
64+
model_data=None,
6565
role=None,
6666
predictor_cls=None,
6767
env=None,
@@ -74,9 +74,9 @@ def __init__(
7474
"""Initialize an SageMaker ``Model``.
7575
7676
Args:
77-
model_data (str): The S3 location of a SageMaker model data
78-
``.tar.gz`` file.
7977
image (str): A Docker image URI.
78+
model_data (str): The S3 location of a SageMaker model data
79+
``.tar.gz`` file (default: None).
8080
role (str): An AWS IAM role (either name or full ARN). The Amazon
8181
SageMaker training jobs and APIs that create Amazon SageMaker
8282
endpoints use this role to access training data and model
@@ -361,6 +361,8 @@ def compile(
361361
)
362362
if job_name is None:
363363
raise ValueError("You must provide a compilation job name")
364+
if self.model_data is None:
365+
raise ValueError("You must provide an S3 path to the compressed model artifacts.")
364366

365367
framework = framework.upper()
366368
framework_version = self._get_framework_version() or framework_version
@@ -778,8 +780,8 @@ def __init__(
778780
:class:`~sagemaker.model.Model`.
779781
"""
780782
super(FrameworkModel, self).__init__(
781-
model_data,
782783
image,
784+
model_data,
783785
role,
784786
predictor_cls=predictor_cls,
785787
env=env,

src/sagemaker/multidatamodel.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -103,8 +103,8 @@ def __init__(
103103
# Set the ``Model`` parameters if the model parameter is not specified
104104
if not self.model:
105105
super(MultiDataModel, self).__init__(
106-
self.model_data_prefix,
107106
image,
107+
self.model_data_prefix,
108108
role,
109109
name=self.name,
110110
sagemaker_session=self.sagemaker_session,

src/sagemaker/pytorch/estimator.py

+31-26
Original file line numberDiff line numberDiff line change
@@ -19,9 +19,9 @@
1919
from sagemaker.fw_utils import (
2020
framework_name_from_image,
2121
framework_version_from_tag,
22-
empty_framework_version_warning,
23-
python_deprecation_warning,
2422
is_version_equal_or_higher,
23+
python_deprecation_warning,
24+
validate_version_or_image_args,
2525
)
2626
from sagemaker.pytorch import defaults
2727
from sagemaker.pytorch.model import PyTorchModel
@@ -40,10 +40,10 @@ class PyTorch(Framework):
4040
def __init__(
4141
self,
4242
entry_point,
43+
framework_version=None,
44+
py_version=None,
4345
source_dir=None,
4446
hyperparameters=None,
45-
py_version=defaults.PYTHON_VERSION,
46-
framework_version=None,
4747
image_name=None,
4848
**kwargs
4949
):
@@ -69,6 +69,13 @@ def __init__(
6969
file which should be executed as the entry point to training.
7070
If ``source_dir`` is specified, then ``entry_point``
7171
must point to a file located at the root of ``source_dir``.
72+
framework_version (str): PyTorch version you want to use for
73+
executing your model training code. Defaults to ``None``. Required unless
74+
``image_name`` is provided. List of supported versions:
75+
https://github.com/aws/sagemaker-python-sdk#pytorch-sagemaker-estimators.
76+
py_version (str): Python version you want to use for executing your
77+
model training code. One of 'py2' or 'py3'. Defaults to ``None``. Required
78+
unless ``image_name`` is provided.
7279
source_dir (str): Path (absolute, relative or an S3 URI) to a directory
7380
with any other training source code dependencies aside from the entry
7481
point file (default: None). If ``source_dir`` is an S3 URI, it must
@@ -80,12 +87,6 @@ def __init__(
8087
SageMaker. For convenience, this accepts other types for keys
8188
and values, but ``str()`` will be called to convert them before
8289
training.
83-
py_version (str): Python version you want to use for executing your
84-
model training code (default: 'py3'). One of 'py2' or 'py3'.
85-
framework_version (str): PyTorch version you want to use for
86-
executing your model training code. List of supported versions
87-
https://github.com/aws/sagemaker-python-sdk#pytorch-sagemaker-estimators.
88-
If not specified, this will default to 0.4.
8990
image_name (str): If specified, the estimator will use this image
9091
for training and hosting, instead of selecting the appropriate
9192
SageMaker official image based on framework_version and
@@ -95,6 +96,9 @@ def __init__(
9596
* ``123412341234.dkr.ecr.us-west-2.amazonaws.com/my-custom-image:1.0``
9697
* ``custom-image:latest``
9798
99+
If ``framework_version`` or ``py_version`` are ``None``, then
100+
``image_name`` is required. If also ``None``, then a ``ValueError``
101+
will be raised.
98102
**kwargs: Additional kwargs passed to the :class:`~sagemaker.estimator.Framework`
99103
constructor.
100104
@@ -104,28 +108,25 @@ def __init__(
104108
:class:`~sagemaker.estimator.Framework` and
105109
:class:`~sagemaker.estimator.EstimatorBase`.
106110
"""
107-
if framework_version is None:
111+
validate_version_or_image_args(framework_version, py_version, image_name)
112+
if py_version == "py2":
108113
logger.warning(
109-
empty_framework_version_warning(defaults.PYTORCH_VERSION, self.LATEST_VERSION)
114+
python_deprecation_warning(self.__framework_name__, defaults.LATEST_PY2_VERSION)
110115
)
111-
self.framework_version = framework_version or defaults.PYTORCH_VERSION
116+
self.framework_version = framework_version
117+
self.py_version = py_version
112118

113119
if "enable_sagemaker_metrics" not in kwargs:
114120
# enable sagemaker metrics for PT v1.3 or greater:
115-
if is_version_equal_or_higher([1, 3], self.framework_version):
121+
if self.framework_version and is_version_equal_or_higher(
122+
[1, 3], self.framework_version
123+
):
116124
kwargs["enable_sagemaker_metrics"] = True
117125

118126
super(PyTorch, self).__init__(
119127
entry_point, source_dir, hyperparameters, image_name=image_name, **kwargs
120128
)
121129

122-
if py_version == "py2":
123-
logger.warning(
124-
python_deprecation_warning(self.__framework_name__, defaults.LATEST_PY2_VERSION)
125-
)
126-
127-
self.py_version = py_version
128-
129130
def create_model(
130131
self,
131132
model_server_workers=None,
@@ -177,12 +178,12 @@ def create_model(
177178
self.model_data,
178179
role or self.role,
179180
entry_point or self.entry_point,
181+
framework_version=self.framework_version,
182+
py_version=self.py_version,
180183
source_dir=(source_dir or self._model_source_dir()),
181184
enable_cloudwatch_metrics=self.enable_cloudwatch_metrics,
182185
container_log_level=self.container_log_level,
183186
code_location=self.code_location,
184-
py_version=self.py_version,
185-
framework_version=self.framework_version,
186187
model_server_workers=model_server_workers,
187188
sagemaker_session=self.sagemaker_session,
188189
vpc_config=self.get_vpc_config(vpc_config_override),
@@ -210,15 +211,19 @@ class constructor
210211
image_name = init_params.pop("image")
211212
framework, py_version, tag, _ = framework_name_from_image(image_name)
212213

214+
if tag is None:
215+
framework_version = None
216+
else:
217+
framework_version = framework_version_from_tag(tag)
218+
init_params["framework_version"] = framework_version
219+
init_params["py_version"] = py_version
220+
213221
if not framework:
214222
# If we were unable to parse the framework name from the image it is not one of our
215223
# officially supported images, in this case just add the image to the init params.
216224
init_params["image_name"] = image_name
217225
return init_params
218226

219-
init_params["py_version"] = py_version
220-
init_params["framework_version"] = framework_version_from_tag(tag)
221-
222227
training_job_name = init_params["base_job_name"]
223228

224229
if framework != cls.__framework_name__:

src/sagemaker/pytorch/model.py

+19-19
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@
2121
create_image_uri,
2222
model_code_key_prefix,
2323
python_deprecation_warning,
24-
empty_framework_version_warning,
24+
validate_version_or_image_args,
2525
)
2626
from sagemaker.model import FrameworkModel, MODEL_SERVER_WORKERS_PARAM_NAME
2727
from sagemaker.pytorch import defaults
@@ -66,9 +66,9 @@ def __init__(
6666
model_data,
6767
role,
6868
entry_point,
69-
image=None,
70-
py_version=defaults.PYTHON_VERSION,
7169
framework_version=None,
70+
py_version=None,
71+
image=None,
7272
predictor_cls=PyTorchPredictor,
7373
model_server_workers=None,
7474
**kwargs
@@ -87,12 +87,16 @@ def __init__(
8787
file which should be executed as the entry point to model
8888
hosting. If ``source_dir`` is specified, then ``entry_point``
8989
must point to a file located at the root of ``source_dir``.
90-
image (str): A Docker image URI (default: None). If not specified, a
91-
default image for PyTorch will be used.
92-
py_version (str): Python version you want to use for executing your
93-
model training code (default: 'py3').
9490
framework_version (str): PyTorch version you want to use for
95-
executing your model training code.
91+
executing your model training code. Defaults to None. Required
92+
unless ``image`` is provided.
93+
py_version (str): Python version you want to use for executing your
94+
model training code. Defaults to ``None``. Required unless
95+
``image`` is provided.
96+
image (str): A Docker image URI (default: None). If not specified, a
97+
default image for PyTorch will be used. If ``framework_version``
98+
or ``py_version`` are ``None``, then ``image`` is required. If
99+
also ``None``, then a ``ValueError`` will be raised.
96100
predictor_cls (callable[str, sagemaker.session.Session]): A function
97101
to call to create a predictor with an endpoint name and
98102
SageMaker ``Session``. If specified, ``deploy()`` returns the
@@ -109,22 +113,18 @@ def __init__(
109113
:class:`~sagemaker.model.FrameworkModel` and
110114
:class:`~sagemaker.model.Model`.
111115
"""
112-
super(PyTorchModel, self).__init__(
113-
model_data, image, role, entry_point, predictor_cls=predictor_cls, **kwargs
114-
)
115-
116-
if py_version == "py2":
116+
validate_version_or_image_args(framework_version, py_version, image)
117+
if py_version and py_version == "py2":
117118
logger.warning(
118119
python_deprecation_warning(self.__framework_name__, defaults.LATEST_PY2_VERSION)
119120
)
121+
self.framework_version = framework_version
122+
self.py_version = py_version
120123

121-
if framework_version is None:
122-
logger.warning(
123-
empty_framework_version_warning(defaults.PYTORCH_VERSION, defaults.LATEST_VERSION)
124-
)
124+
super(PyTorchModel, self).__init__(
125+
model_data, image, role, entry_point, predictor_cls=predictor_cls, **kwargs
126+
)
125127

126-
self.py_version = py_version
127-
self.framework_version = framework_version or defaults.PYTORCH_VERSION
128128
self.model_server_workers = model_server_workers
129129

130130
def prepare_container_def(self, instance_type, accelerator_type=None):

0 commit comments

Comments
 (0)