19
19
from sagemaker .fw_utils import (
20
20
framework_name_from_image ,
21
21
framework_version_from_tag ,
22
- empty_framework_version_warning ,
23
- python_deprecation_warning ,
24
22
is_version_equal_or_higher ,
23
+ python_deprecation_warning ,
24
+ validate_version_or_image_args ,
25
25
)
26
26
from sagemaker .pytorch import defaults
27
27
from sagemaker .pytorch .model import PyTorchModel
@@ -40,10 +40,10 @@ class PyTorch(Framework):
40
40
def __init__ (
41
41
self ,
42
42
entry_point ,
43
+ framework_version = None ,
44
+ py_version = None ,
43
45
source_dir = None ,
44
46
hyperparameters = None ,
45
- py_version = defaults .PYTHON_VERSION ,
46
- framework_version = None ,
47
47
image_name = None ,
48
48
** kwargs
49
49
):
@@ -69,6 +69,13 @@ def __init__(
69
69
file which should be executed as the entry point to training.
70
70
If ``source_dir`` is specified, then ``entry_point``
71
71
must point to a file located at the root of ``source_dir``.
72
+ framework_version (str): PyTorch version you want to use for
73
+ executing your model training code. Defaults to ``None``. Required unless
74
+ ``image_name`` is provided. List of supported versions:
75
+ https://github.com/aws/sagemaker-python-sdk#pytorch-sagemaker-estimators.
76
+ py_version (str): Python version you want to use for executing your
77
+ model training code. One of 'py2' or 'py3'. Defaults to ``None``. Required
78
+ unless ``image_name`` is provided.
72
79
source_dir (str): Path (absolute, relative or an S3 URI) to a directory
73
80
with any other training source code dependencies aside from the entry
74
81
point file (default: None). If ``source_dir`` is an S3 URI, it must
@@ -80,12 +87,6 @@ def __init__(
80
87
SageMaker. For convenience, this accepts other types for keys
81
88
and values, but ``str()`` will be called to convert them before
82
89
training.
83
- py_version (str): Python version you want to use for executing your
84
- model training code (default: 'py3'). One of 'py2' or 'py3'.
85
- framework_version (str): PyTorch version you want to use for
86
- executing your model training code. List of supported versions
87
- https://github.com/aws/sagemaker-python-sdk#pytorch-sagemaker-estimators.
88
- If not specified, this will default to 0.4.
89
90
image_name (str): If specified, the estimator will use this image
90
91
for training and hosting, instead of selecting the appropriate
91
92
SageMaker official image based on framework_version and
@@ -95,6 +96,9 @@ def __init__(
95
96
* ``123412341234.dkr.ecr.us-west-2.amazonaws.com/my-custom-image:1.0``
96
97
* ``custom-image:latest``
97
98
99
+ If ``framework_version`` or ``py_version`` are ``None``, then
100
+ ``image_name`` is required. If also ``None``, then a ``ValueError``
101
+ will be raised.
98
102
**kwargs: Additional kwargs passed to the :class:`~sagemaker.estimator.Framework`
99
103
constructor.
100
104
@@ -104,28 +108,25 @@ def __init__(
104
108
:class:`~sagemaker.estimator.Framework` and
105
109
:class:`~sagemaker.estimator.EstimatorBase`.
106
110
"""
107
- if framework_version is None :
111
+ validate_version_or_image_args (framework_version , py_version , image_name )
112
+ if py_version == "py2" :
108
113
logger .warning (
109
- empty_framework_version_warning ( defaults . PYTORCH_VERSION , self . LATEST_VERSION )
114
+ python_deprecation_warning ( self . __framework_name__ , defaults . LATEST_PY2_VERSION )
110
115
)
111
- self .framework_version = framework_version or defaults .PYTORCH_VERSION
116
+ self .framework_version = framework_version
117
+ self .py_version = py_version
112
118
113
119
if "enable_sagemaker_metrics" not in kwargs :
114
120
# enable sagemaker metrics for PT v1.3 or greater:
115
- if is_version_equal_or_higher ([1 , 3 ], self .framework_version ):
121
+ if self .framework_version and is_version_equal_or_higher (
122
+ [1 , 3 ], self .framework_version
123
+ ):
116
124
kwargs ["enable_sagemaker_metrics" ] = True
117
125
118
126
super (PyTorch , self ).__init__ (
119
127
entry_point , source_dir , hyperparameters , image_name = image_name , ** kwargs
120
128
)
121
129
122
- if py_version == "py2" :
123
- logger .warning (
124
- python_deprecation_warning (self .__framework_name__ , defaults .LATEST_PY2_VERSION )
125
- )
126
-
127
- self .py_version = py_version
128
-
129
130
def create_model (
130
131
self ,
131
132
model_server_workers = None ,
@@ -177,12 +178,12 @@ def create_model(
177
178
self .model_data ,
178
179
role or self .role ,
179
180
entry_point or self .entry_point ,
181
+ framework_version = self .framework_version ,
182
+ py_version = self .py_version ,
180
183
source_dir = (source_dir or self ._model_source_dir ()),
181
184
enable_cloudwatch_metrics = self .enable_cloudwatch_metrics ,
182
185
container_log_level = self .container_log_level ,
183
186
code_location = self .code_location ,
184
- py_version = self .py_version ,
185
- framework_version = self .framework_version ,
186
187
model_server_workers = model_server_workers ,
187
188
sagemaker_session = self .sagemaker_session ,
188
189
vpc_config = self .get_vpc_config (vpc_config_override ),
@@ -210,15 +211,19 @@ class constructor
210
211
image_name = init_params .pop ("image" )
211
212
framework , py_version , tag , _ = framework_name_from_image (image_name )
212
213
214
+ if tag is None :
215
+ framework_version = None
216
+ else :
217
+ framework_version = framework_version_from_tag (tag )
218
+ init_params ["framework_version" ] = framework_version
219
+ init_params ["py_version" ] = py_version
220
+
213
221
if not framework :
214
222
# If we were unable to parse the framework name from the image it is not one of our
215
223
# officially supported images, in this case just add the image to the init params.
216
224
init_params ["image_name" ] = image_name
217
225
return init_params
218
226
219
- init_params ["py_version" ] = py_version
220
- init_params ["framework_version" ] = framework_version_from_tag (tag )
221
-
222
227
training_job_name = init_params ["base_job_name" ]
223
228
224
229
if framework != cls .__framework_name__ :
0 commit comments