Skip to content

Commit 9977206

Browse files
authored
breaking: require framework_version, py_version for sklearn (#1576)
1 parent 5233dfc commit 9977206

File tree

6 files changed

+113
-81
lines changed

6 files changed

+113
-81
lines changed

doc/frameworks/sklearn/using_sklearn.rst

+5-3
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@ To train a Scikit-learn model by using the SageMaker Python SDK:
3131
Prepare a Scikit-learn Training Script
3232
======================================
3333

34-
Your Scikit-learn training script must be a Python 2.7 or 3.6 compatible source file.
34+
Your Scikit-learn training script must be a Python 3.6 compatible source file.
3535

3636
The training script is similar to a training script you might run outside of SageMaker, but you
3737
can access useful properties about the training environment through various environment variables.
@@ -465,8 +465,10 @@ The following code sample shows how to do this, using the ``SKLearnModel`` class
465465

466466
.. code:: python
467467
468-
sklearn_model = SKLearnModel(model_data="s3://bucket/model.tar.gz", role="SageMakerRole",
469-
entry_point="transform_script.py")
468+
sklearn_model = SKLearnModel(model_data="s3://bucket/model.tar.gz",
469+
role="SageMakerRole",
470+
entry_point="transform_script.py",
471+
framework_version="0.20.0")
470472
471473
predictor = sklearn_model.deploy(instance_type="ml.c4.xlarge", initial_instance_count=1)
472474

src/sagemaker/sklearn/estimator.py

+45-29
Original file line numberDiff line numberDiff line change
@@ -19,8 +19,8 @@
1919
from sagemaker.fw_registry import default_framework_uri
2020
from sagemaker.fw_utils import (
2121
framework_name_from_image,
22-
empty_framework_version_warning,
23-
python_deprecation_warning,
22+
framework_version_from_tag,
23+
validate_version_or_image_args,
2424
)
2525
from sagemaker.sklearn import defaults
2626
from sagemaker.sklearn.model import SKLearnModel
@@ -37,10 +37,10 @@ class SKLearn(Framework):
3737
def __init__(
3838
self,
3939
entry_point,
40-
framework_version=defaults.SKLEARN_VERSION,
40+
framework_version=None,
41+
py_version="py3",
4142
source_dir=None,
4243
hyperparameters=None,
43-
py_version="py3",
4444
image_name=None,
4545
**kwargs
4646
):
@@ -68,8 +68,13 @@ def __init__(
6868
If ``source_dir`` is specified, then ``entry_point``
6969
must point to a file located at the root of ``source_dir``.
7070
framework_version (str): Scikit-learn version you want to use for
71-
executing your model training code. List of supported versions
71+
executing your model training code. Defaults to ``None``. Required
72+
unless ``image_name`` is provided. List of supported versions:
7273
https://github.com/aws/sagemaker-python-sdk#sklearn-sagemaker-estimators
74+
py_version (str): Python version you want to use for executing your
75+
model training code (default: 'py3'). Currently, 'py3' is the only
76+
supported version. If ``None`` is passed in, ``image_name`` must be
77+
provided.
7378
source_dir (str): Path (absolute, relative or an S3 URI) to a directory
7479
with any other training source code dependencies aside from the entry
7580
point file (default: None). If ``source_dir`` is an S3 URI, it must
@@ -81,15 +86,18 @@ def __init__(
8186
SageMaker. For convenience, this accepts other types for keys
8287
and values, but ``str()`` will be called to convert them before
8388
training.
84-
py_version (str): Python version you want to use for executing your
85-
model training code (default: 'py3'). One of 'py2' or 'py3'.
8689
image_name (str): If specified, the estimator will use this image
8790
for training and hosting, instead of selecting the appropriate
8891
SageMaker official image based on framework_version and
8992
py_version. It can be an ECR url or dockerhub image and tag.
93+
9094
Examples:
9195
123.dkr.ecr.us-west-2.amazonaws.com/my-custom-image:1.0
9296
custom-image:latest.
97+
98+
If ``framework_version`` or ``py_version`` are ``None``, then
99+
``image_name`` is required. If also ``None``, then a ``ValueError``
100+
will be raised.
93101
**kwargs: Additional kwargs passed to the
94102
:class:`~sagemaker.estimator.Framework` constructor.
95103
@@ -99,6 +107,14 @@ def __init__(
99107
:class:`~sagemaker.estimator.Framework` and
100108
:class:`~sagemaker.estimator.EstimatorBase`.
101109
"""
110+
validate_version_or_image_args(framework_version, py_version, image_name)
111+
if py_version and py_version != "py3":
112+
raise AttributeError(
113+
"Scikit-learn image only supports Python 3. Please use 'py3' for py_version."
114+
)
115+
self.framework_version = framework_version
116+
self.py_version = py_version
117+
102118
# SciKit-Learn does not support distributed training or training on GPU instance types.
103119
# Fail fast.
104120
train_instance_type = kwargs.get("train_instance_type")
@@ -112,6 +128,7 @@ def __init__(
112128
"Please remove the 'train_instance_count' argument or set "
113129
"'train_instance_count=1' when initializing SKLearn."
114130
)
131+
115132
super(SKLearn, self).__init__(
116133
entry_point,
117134
source_dir,
@@ -120,19 +137,6 @@ def __init__(
120137
**dict(kwargs, train_instance_count=1)
121138
)
122139

123-
if py_version == "py2":
124-
logger.warning(
125-
python_deprecation_warning(self.__framework_name__, defaults.LATEST_PY2_VERSION)
126-
)
127-
128-
self.py_version = py_version
129-
130-
if framework_version is None:
131-
logger.warning(
132-
empty_framework_version_warning(defaults.SKLEARN_VERSION, defaults.SKLEARN_VERSION)
133-
)
134-
self.framework_version = framework_version or defaults.SKLEARN_VERSION
135-
136140
if image_name is None:
137141
image_tag = "{}-{}-{}".format(framework_version, "cpu", py_version)
138142
self.image_name = default_framework_uri(
@@ -216,28 +220,40 @@ class constructor
216220
Args:
217221
job_details: the returned job details from a describe_training_job
218222
API call.
219-
model_channel_name:
223+
model_channel_name (str): Name of the channel where pre-trained
224+
model data will be downloaded (default: None).
220225
221226
Returns:
222227
dictionary: The transformed init_params
223228
"""
224-
init_params = super(SKLearn, cls)._prepare_init_params_from_job_description(job_details)
225-
229+
init_params = super(SKLearn, cls)._prepare_init_params_from_job_description(
230+
job_details, model_channel_name
231+
)
226232
image_name = init_params.pop("image")
227-
framework, py_version, _, _ = framework_name_from_image(image_name)
233+
framework, py_version, tag, _ = framework_name_from_image(image_name)
234+
235+
if tag is None:
236+
framework_version = None
237+
else:
238+
framework_version = framework_version_from_tag(tag)
239+
init_params["framework_version"] = framework_version
228240
init_params["py_version"] = py_version
229241

242+
if not framework:
243+
# If we were unable to parse the framework name from the image it is not one of our
244+
# officially supported images, in this case just add the image to the init params.
245+
init_params["image_name"] = image_name
246+
return init_params
247+
248+
training_job_name = init_params["base_job_name"]
249+
230250
if framework and framework != cls.__framework_name__:
231-
training_job_name = init_params["base_job_name"]
232251
raise ValueError(
233252
"Training job: {} didn't use image for requested framework".format(
234253
training_job_name
235254
)
236255
)
237-
if not framework:
238-
# If we were unable to parse the framework name from the image it is not one of our
239-
# officially supported images, in this case just add the image to the init params.
240-
init_params["image_name"] = image_name
256+
241257
return init_params
242258

243259

src/sagemaker/sklearn/model.py

+22-14
Original file line numberDiff line numberDiff line change
@@ -18,8 +18,8 @@
1818
from sagemaker import fw_utils
1919

2020
import sagemaker
21-
from sagemaker.fw_utils import model_code_key_prefix, python_deprecation_warning
2221
from sagemaker.fw_registry import default_framework_uri
22+
from sagemaker.fw_utils import model_code_key_prefix, validate_version_or_image_args
2323
from sagemaker.model import FrameworkModel, MODEL_SERVER_WORKERS_PARAM_NAME
2424
from sagemaker.predictor import RealTimePredictor, npy_serializer, numpy_deserializer
2525
from sagemaker.sklearn import defaults
@@ -62,9 +62,9 @@ def __init__(
6262
model_data,
6363
role,
6464
entry_point,
65-
image=None,
65+
framework_version=None,
6666
py_version="py3",
67-
framework_version=defaults.SKLEARN_VERSION,
67+
image=None,
6868
predictor_cls=SKLearnPredictor,
6969
model_server_workers=None,
7070
**kwargs
@@ -83,12 +83,19 @@ def __init__(
8383
file which should be executed as the entry point to model
8484
hosting. If ``source_dir`` is specified, then ``entry_point``
8585
must point to a file located at the root of ``source_dir``.
86+
framework_version (str): Scikit-learn version you want to use for
87+
executing your model training code. Defaults to ``None``. Required
88+
unless ``image`` is provided.
89+
py_version (str): Python version you want to use for executing your
90+
model training code (default: 'py3'). Currently, 'py3' is the only
91+
supported version. If ``None`` is passed in, ``image`` must be
92+
provided.
8693
image (str): A Docker image URI (default: None). If not specified, a
8794
default image for Scikit-learn will be used.
88-
py_version (str): Python version you want to use for executing your
89-
model training code (default: 'py3').
90-
framework_version (str): Scikit-learn version you want to use for
91-
executing your model training code.
95+
96+
If ``framework_version`` or ``py_version`` are ``None``, then
97+
``image`` is required. If also ``None``, then a ``ValueError``
98+
will be raised.
9299
predictor_cls (callable[str, sagemaker.session.Session]): A function
93100
to call to create a predictor with an endpoint name and
94101
SageMaker ``Session``. If specified, ``deploy()`` returns the
@@ -105,17 +112,18 @@ def __init__(
105112
:class:`~sagemaker.model.FrameworkModel` and
106113
:class:`~sagemaker.model.Model`.
107114
"""
115+
validate_version_or_image_args(framework_version, py_version, image)
116+
if py_version and py_version != "py3":
117+
raise AttributeError(
118+
"Scikit-learn image only supports Python 3. Please use 'py3' for py_version."
119+
)
120+
self.framework_version = framework_version
121+
self.py_version = py_version
122+
108123
super(SKLearnModel, self).__init__(
109124
model_data, image, role, entry_point, predictor_cls=predictor_cls, **kwargs
110125
)
111126

112-
if py_version == "py2":
113-
logger.warning(
114-
python_deprecation_warning(self.__framework_name__, defaults.LATEST_PY2_VERSION)
115-
)
116-
117-
self.py_version = py_version
118-
self.framework_version = framework_version
119127
self.model_server_workers = model_server_workers
120128

121129
def prepare_container_def(self, instance_type, accelerator_type=None):

tests/integ/test_git.py

+1
Original file line numberDiff line numberDiff line change
@@ -173,6 +173,7 @@ def test_private_github_with_2fa(sagemaker_local_session, sklearn_full_version):
173173
model_data,
174174
"SageMakerRole",
175175
entry_point=script_path,
176+
framework_version=sklearn_full_version,
176177
source_dir=source_dir,
177178
sagemaker_session=sagemaker_local_session,
178179
git_config=git_config,

tests/integ/test_sklearn_train.py

+6-3
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@ def sklearn_training_job(sagemaker_session, sklearn_full_version, cpu_instance_t
3636
sagemaker_session.boto_region_name
3737

3838

39-
@pytest.mark.skipif(PYTHON_VERSION != "py3", reason="Scikit-learn image supports only python 3.")
39+
@pytest.mark.skipif(PYTHON_VERSION != "py3", reason="Scikit-learn image supports only Python 3.")
4040
def test_training_with_additional_hyperparameters(
4141
sagemaker_session, sklearn_full_version, cpu_instance_type
4242
):
@@ -66,7 +66,7 @@ def test_training_with_additional_hyperparameters(
6666
return sklearn.latest_training_job.name
6767

6868

69-
@pytest.mark.skipif(PYTHON_VERSION != "py3", reason="Scikit-learn image supports only python 3.")
69+
@pytest.mark.skipif(PYTHON_VERSION != "py3", reason="Scikit-learn image supports only Python 3.")
7070
def test_training_with_network_isolation(
7171
sagemaker_session, sklearn_full_version, cpu_instance_type
7272
):
@@ -121,7 +121,9 @@ def test_attach_deploy(sklearn_training_job, sagemaker_session, cpu_instance_typ
121121
reason="This test has always failed, but the failure was masked by a bug. "
122122
"This test should be fixed. Details in https://github.com/aws/sagemaker-python-sdk/pull/968"
123123
)
124-
def test_deploy_model(sklearn_training_job, sagemaker_session, cpu_instance_type):
124+
def test_deploy_model(
125+
sklearn_training_job, sagemaker_session, cpu_instance_type, sklearn_full_version
126+
):
125127
endpoint_name = "test-sklearn-deploy-model-{}".format(sagemaker_timestamp())
126128
with timeout_and_delete_endpoint_by_name(endpoint_name, sagemaker_session):
127129
desc = sagemaker_session.sagemaker_client.describe_training_job(
@@ -133,6 +135,7 @@ def test_deploy_model(sklearn_training_job, sagemaker_session, cpu_instance_type
133135
model_data,
134136
"SageMakerRole",
135137
entry_point=script_path,
138+
framework_version=sklearn_full_version,
136139
sagemaker_session=sagemaker_session,
137140
)
138141
predictor = model.deploy(1, cpu_instance_type, endpoint_name=endpoint_name)

0 commit comments

Comments
 (0)