19
19
from sagemaker .fw_registry import default_framework_uri
20
20
from sagemaker .fw_utils import (
21
21
framework_name_from_image ,
22
- empty_framework_version_warning ,
23
- python_deprecation_warning ,
22
+ framework_version_from_tag ,
23
+ validate_version_or_image_args ,
24
24
)
25
25
from sagemaker .sklearn import defaults
26
26
from sagemaker .sklearn .model import SKLearnModel
@@ -37,10 +37,10 @@ class SKLearn(Framework):
37
37
def __init__ (
38
38
self ,
39
39
entry_point ,
40
- framework_version = defaults .SKLEARN_VERSION ,
40
+ framework_version = None ,
41
+ py_version = "py3" ,
41
42
source_dir = None ,
42
43
hyperparameters = None ,
43
- py_version = "py3" ,
44
44
image_name = None ,
45
45
** kwargs
46
46
):
@@ -68,8 +68,13 @@ def __init__(
68
68
If ``source_dir`` is specified, then ``entry_point``
69
69
must point to a file located at the root of ``source_dir``.
70
70
framework_version (str): Scikit-learn version you want to use for
71
- executing your model training code. List of supported versions
71
+ executing your model training code. Defaults to ``None``. Required
72
+ unless ``image_name`` is provided. List of supported versions:
72
73
https://github.com/aws/sagemaker-python-sdk#sklearn-sagemaker-estimators
74
+ py_version (str): Python version you want to use for executing your
75
+ model training code (default: 'py3'). Currently, 'py3' is the only
76
+ supported version. If ``None`` is passed in, ``image_name`` must be
77
+ provided.
73
78
source_dir (str): Path (absolute, relative or an S3 URI) to a directory
74
79
with any other training source code dependencies aside from the entry
75
80
point file (default: None). If ``source_dir`` is an S3 URI, it must
@@ -81,15 +86,18 @@ def __init__(
81
86
SageMaker. For convenience, this accepts other types for keys
82
87
and values, but ``str()`` will be called to convert them before
83
88
training.
84
- py_version (str): Python version you want to use for executing your
85
- model training code (default: 'py3'). One of 'py2' or 'py3'.
86
89
image_name (str): If specified, the estimator will use this image
87
90
for training and hosting, instead of selecting the appropriate
88
91
SageMaker official image based on framework_version and
89
92
py_version. It can be an ECR url or dockerhub image and tag.
93
+
90
94
Examples:
91
95
123.dkr.ecr.us-west-2.amazonaws.com/my-custom-image:1.0
92
96
custom-image:latest.
97
+
98
+ If ``framework_version`` or ``py_version`` are ``None``, then
99
+ ``image_name`` is required. If also ``None``, then a ``ValueError``
100
+ will be raised.
93
101
**kwargs: Additional kwargs passed to the
94
102
:class:`~sagemaker.estimator.Framework` constructor.
95
103
@@ -99,6 +107,14 @@ def __init__(
99
107
:class:`~sagemaker.estimator.Framework` and
100
108
:class:`~sagemaker.estimator.EstimatorBase`.
101
109
"""
110
+ validate_version_or_image_args (framework_version , py_version , image_name )
111
+ if py_version and py_version != "py3" :
112
+ raise AttributeError (
113
+ "Scikit-learn image only supports Python 3. Please use 'py3' for py_version."
114
+ )
115
+ self .framework_version = framework_version
116
+ self .py_version = py_version
117
+
102
118
# SciKit-Learn does not support distributed training or training on GPU instance types.
103
119
# Fail fast.
104
120
train_instance_type = kwargs .get ("train_instance_type" )
@@ -112,6 +128,7 @@ def __init__(
112
128
"Please remove the 'train_instance_count' argument or set "
113
129
"'train_instance_count=1' when initializing SKLearn."
114
130
)
131
+
115
132
super (SKLearn , self ).__init__ (
116
133
entry_point ,
117
134
source_dir ,
@@ -120,19 +137,6 @@ def __init__(
120
137
** dict (kwargs , train_instance_count = 1 )
121
138
)
122
139
123
- if py_version == "py2" :
124
- logger .warning (
125
- python_deprecation_warning (self .__framework_name__ , defaults .LATEST_PY2_VERSION )
126
- )
127
-
128
- self .py_version = py_version
129
-
130
- if framework_version is None :
131
- logger .warning (
132
- empty_framework_version_warning (defaults .SKLEARN_VERSION , defaults .SKLEARN_VERSION )
133
- )
134
- self .framework_version = framework_version or defaults .SKLEARN_VERSION
135
-
136
140
if image_name is None :
137
141
image_tag = "{}-{}-{}" .format (framework_version , "cpu" , py_version )
138
142
self .image_name = default_framework_uri (
@@ -216,28 +220,40 @@ class constructor
216
220
Args:
217
221
job_details: the returned job details from a describe_training_job
218
222
API call.
219
- model_channel_name:
223
+ model_channel_name (str): Name of the channel where pre-trained
224
+ model data will be downloaded (default: None).
220
225
221
226
Returns:
222
227
dictionary: The transformed init_params
223
228
"""
224
- init_params = super (SKLearn , cls )._prepare_init_params_from_job_description (job_details )
225
-
229
+ init_params = super (SKLearn , cls )._prepare_init_params_from_job_description (
230
+ job_details , model_channel_name
231
+ )
226
232
image_name = init_params .pop ("image" )
227
- framework , py_version , _ , _ = framework_name_from_image (image_name )
233
+ framework , py_version , tag , _ = framework_name_from_image (image_name )
234
+
235
+ if tag is None :
236
+ framework_version = None
237
+ else :
238
+ framework_version = framework_version_from_tag (tag )
239
+ init_params ["framework_version" ] = framework_version
228
240
init_params ["py_version" ] = py_version
229
241
242
+ if not framework :
243
+ # If we were unable to parse the framework name from the image it is not one of our
244
+ # officially supported images, in this case just add the image to the init params.
245
+ init_params ["image_name" ] = image_name
246
+ return init_params
247
+
248
+ training_job_name = init_params ["base_job_name" ]
249
+
230
250
if framework and framework != cls .__framework_name__ :
231
- training_job_name = init_params ["base_job_name" ]
232
251
raise ValueError (
233
252
"Training job: {} didn't use image for requested framework" .format (
234
253
training_job_name
235
254
)
236
255
)
237
- if not framework :
238
- # If we were unable to parse the framework name from the image it is not one of our
239
- # officially supported images, in this case just add the image to the init params.
240
- init_params ["image_name" ] = image_name
256
+
241
257
return init_params
242
258
243
259
0 commit comments