Skip to content

Commit 9df3f5a

Browse files
authored
breaking: change Model parameter order to make model_data optional (#1579)
1 parent 7919331 commit 9df3f5a

17 files changed

+114
-48
lines changed

src/sagemaker/amazon/factorization_machines.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -310,8 +310,8 @@ def __init__(self, model_data, role, sagemaker_session=None, **kwargs):
310310
repo = "{}:{}".format(FactorizationMachines.repo_name, FactorizationMachines.repo_version)
311311
image = "{}/{}".format(registry(sagemaker_session.boto_session.region_name), repo)
312312
super(FactorizationMachinesModel, self).__init__(
313-
model_data,
314313
image,
314+
model_data,
315315
role,
316316
predictor_cls=FactorizationMachinesPredictor,
317317
sagemaker_session=sagemaker_session,

src/sagemaker/amazon/ipinsights.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -216,8 +216,8 @@ def __init__(self, model_data, role, sagemaker_session=None, **kwargs):
216216
)
217217

218218
super(IPInsightsModel, self).__init__(
219-
model_data,
220219
image,
220+
model_data,
221221
role,
222222
predictor_cls=IPInsightsPredictor,
223223
sagemaker_session=sagemaker_session,

src/sagemaker/amazon/kmeans.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -241,8 +241,8 @@ def __init__(self, model_data, role, sagemaker_session=None, **kwargs):
241241
repo = "{}:{}".format(KMeans.repo_name, KMeans.repo_version)
242242
image = "{}/{}".format(registry(sagemaker_session.boto_session.region_name), repo)
243243
super(KMeansModel, self).__init__(
244-
model_data,
245244
image,
245+
model_data,
246246
role,
247247
predictor_cls=KMeansPredictor,
248248
sagemaker_session=sagemaker_session,

src/sagemaker/amazon/knn.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -231,8 +231,8 @@ def __init__(self, model_data, role, sagemaker_session=None, **kwargs):
231231
registry(sagemaker_session.boto_session.region_name, KNN.repo_name), repo
232232
)
233233
super(KNNModel, self).__init__(
234-
model_data,
235234
image,
235+
model_data,
236236
role,
237237
predictor_cls=KNNPredictor,
238238
sagemaker_session=sagemaker_session,

src/sagemaker/amazon/lda.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -215,8 +215,8 @@ def __init__(self, model_data, role, sagemaker_session=None, **kwargs):
215215
registry(sagemaker_session.boto_session.region_name, LDA.repo_name), repo
216216
)
217217
super(LDAModel, self).__init__(
218-
model_data,
219218
image,
219+
model_data,
220220
role,
221221
predictor_cls=LDAPredictor,
222222
sagemaker_session=sagemaker_session,

src/sagemaker/amazon/linear_learner.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -474,8 +474,8 @@ def __init__(self, model_data, role, sagemaker_session=None, **kwargs):
474474
repo = "{}:{}".format(LinearLearner.repo_name, LinearLearner.repo_version)
475475
image = "{}/{}".format(registry(sagemaker_session.boto_session.region_name), repo)
476476
super(LinearLearnerModel, self).__init__(
477-
model_data,
478477
image,
478+
model_data,
479479
role,
480480
predictor_cls=LinearLearnerPredictor,
481481
sagemaker_session=sagemaker_session,

src/sagemaker/amazon/ntm.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -245,8 +245,8 @@ def __init__(self, model_data, role, sagemaker_session=None, **kwargs):
245245
registry(sagemaker_session.boto_session.region_name, NTM.repo_name), repo
246246
)
247247
super(NTMModel, self).__init__(
248-
model_data,
249248
image,
249+
model_data,
250250
role,
251251
predictor_cls=NTMPredictor,
252252
sagemaker_session=sagemaker_session,

src/sagemaker/amazon/object2vec.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -355,8 +355,8 @@ def __init__(self, model_data, role, sagemaker_session=None, **kwargs):
355355
registry(sagemaker_session.boto_session.region_name, Object2Vec.repo_name), repo
356356
)
357357
super(Object2VecModel, self).__init__(
358-
model_data,
359358
image,
359+
model_data,
360360
role,
361361
predictor_cls=RealTimePredictor,
362362
sagemaker_session=sagemaker_session,

src/sagemaker/amazon/pca.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -225,8 +225,8 @@ def __init__(self, model_data, role, sagemaker_session=None, **kwargs):
225225
repo = "{}:{}".format(PCA.repo_name, PCA.repo_version)
226226
image = "{}/{}".format(registry(sagemaker_session.boto_session.region_name), repo)
227227
super(PCAModel, self).__init__(
228-
model_data,
229228
image,
229+
model_data,
230230
role,
231231
predictor_cls=PCAPredictor,
232232
sagemaker_session=sagemaker_session,

src/sagemaker/amazon/randomcutforest.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -206,8 +206,8 @@ def __init__(self, model_data, role, sagemaker_session=None, **kwargs):
206206
registry(sagemaker_session.boto_session.region_name, RandomCutForest.repo_name), repo
207207
)
208208
super(RandomCutForestModel, self).__init__(
209-
model_data,
210209
image,
210+
model_data,
211211
role,
212212
predictor_cls=RandomCutForestPredictor,
213213
sagemaker_session=sagemaker_session,

src/sagemaker/estimator.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -1398,8 +1398,8 @@ def predict_wrapper(endpoint, session):
13981398
kwargs["enable_network_isolation"] = self.enable_network_isolation()
13991399

14001400
return Model(
1401-
self.model_data,
14021401
image or self.train_image(),
1402+
self.model_data,
14031403
role,
14041404
vpc_config=self.get_vpc_config(vpc_config_override),
14051405
sagemaker_session=self.sagemaker_session,

src/sagemaker/model.py

+6-4
Original file line numberDiff line numberDiff line change
@@ -60,8 +60,8 @@ class Model(object):
6060

6161
def __init__(
6262
self,
63-
model_data,
6463
image,
64+
model_data=None,
6565
role=None,
6666
predictor_cls=None,
6767
env=None,
@@ -74,9 +74,9 @@ def __init__(
7474
"""Initialize an SageMaker ``Model``.
7575
7676
Args:
77-
model_data (str): The S3 location of a SageMaker model data
78-
``.tar.gz`` file.
7977
image (str): A Docker image URI.
78+
model_data (str): The S3 location of a SageMaker model data
79+
``.tar.gz`` file (default: None).
8080
role (str): An AWS IAM role (either name or full ARN). The Amazon
8181
SageMaker training jobs and APIs that create Amazon SageMaker
8282
endpoints use this role to access training data and model
@@ -361,6 +361,8 @@ def compile(
361361
)
362362
if job_name is None:
363363
raise ValueError("You must provide a compilation job name")
364+
if self.model_data is None:
365+
raise ValueError("You must provide an S3 path to the compressed model artifacts.")
364366

365367
framework = framework.upper()
366368
framework_version = self._get_framework_version() or framework_version
@@ -778,8 +780,8 @@ def __init__(
778780
:class:`~sagemaker.model.Model`.
779781
"""
780782
super(FrameworkModel, self).__init__(
781-
model_data,
782783
image,
784+
model_data,
783785
role,
784786
predictor_cls=predictor_cls,
785787
env=env,

src/sagemaker/multidatamodel.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -103,8 +103,8 @@ def __init__(
103103
# Set the ``Model`` parameters if the model parameter is not specified
104104
if not self.model:
105105
super(MultiDataModel, self).__init__(
106-
self.model_data_prefix,
107106
image,
107+
self.model_data_prefix,
108108
role,
109109
name=self.name,
110110
sagemaker_session=self.sagemaker_session,

src/sagemaker/sparkml/model.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -96,8 +96,8 @@ def __init__(self, model_data, role=None, spark_version=2.2, sagemaker_session=N
9696
region_name = (sagemaker_session or Session()).boto_region_name
9797
image = "{}/{}:{}".format(registry(region_name, framework_name), repo_name, spark_version)
9898
super(SparkMLModel, self).__init__(
99-
model_data,
10099
image,
100+
model_data,
101101
role,
102102
predictor_cls=SparkMLPredictor,
103103
sagemaker_session=sagemaker_session,

tests/unit/sagemaker/model/test_deploy.py

+13-13
Original file line numberDiff line numberDiff line change
@@ -54,7 +54,7 @@ def test_deploy(name_from_image, prepare_container_def, production_variant, sage
5454
container_def = {"Image": MODEL_IMAGE, "Environment": {}, "ModelDataUrl": MODEL_DATA}
5555
prepare_container_def.return_value = container_def
5656

57-
model = Model(MODEL_DATA, MODEL_IMAGE, role=ROLE, sagemaker_session=sagemaker_session)
57+
model = Model(MODEL_IMAGE, MODEL_DATA, role=ROLE, sagemaker_session=sagemaker_session)
5858
model.deploy(instance_type=INSTANCE_TYPE, initial_instance_count=INSTANCE_COUNT)
5959

6060
name_from_image.assert_called_with(MODEL_IMAGE)
@@ -81,7 +81,7 @@ def test_deploy(name_from_image, prepare_container_def, production_variant, sage
8181
@patch("sagemaker.production_variant")
8282
def test_deploy_accelerator_type(production_variant, create_sagemaker_model, sagemaker_session):
8383
model = Model(
84-
MODEL_DATA, MODEL_IMAGE, role=ROLE, name=MODEL_NAME, sagemaker_session=sagemaker_session
84+
MODEL_IMAGE, MODEL_DATA, role=ROLE, name=MODEL_NAME, sagemaker_session=sagemaker_session
8585
)
8686

8787
production_variant_result = copy.deepcopy(BASE_PRODUCTION_VARIANT)
@@ -113,7 +113,7 @@ def test_deploy_accelerator_type(production_variant, create_sagemaker_model, sag
113113
@patch("sagemaker.model.Model._create_sagemaker_model", Mock())
114114
@patch("sagemaker.production_variant", return_value=BASE_PRODUCTION_VARIANT)
115115
def test_deploy_endpoint_name(sagemaker_session):
116-
model = Model(MODEL_DATA, MODEL_IMAGE, role=ROLE, sagemaker_session=sagemaker_session)
116+
model = Model(MODEL_IMAGE, MODEL_DATA, role=ROLE, sagemaker_session=sagemaker_session)
117117

118118
endpoint_name = "blah"
119119
model.deploy(
@@ -136,7 +136,7 @@ def test_deploy_endpoint_name(sagemaker_session):
136136
@patch("sagemaker.model.Model._create_sagemaker_model")
137137
def test_deploy_tags(create_sagemaker_model, production_variant, sagemaker_session):
138138
model = Model(
139-
MODEL_DATA, MODEL_IMAGE, role=ROLE, name=MODEL_NAME, sagemaker_session=sagemaker_session
139+
MODEL_IMAGE, MODEL_DATA, role=ROLE, name=MODEL_NAME, sagemaker_session=sagemaker_session
140140
)
141141

142142
tags = [{"Key": "ModelName", "Value": "TestModel"}]
@@ -157,7 +157,7 @@ def test_deploy_tags(create_sagemaker_model, production_variant, sagemaker_sessi
157157
@patch("sagemaker.production_variant", return_value=BASE_PRODUCTION_VARIANT)
158158
def test_deploy_kms_key(production_variant, sagemaker_session):
159159
model = Model(
160-
MODEL_DATA, MODEL_IMAGE, role=ROLE, name=MODEL_NAME, sagemaker_session=sagemaker_session
160+
MODEL_IMAGE, MODEL_DATA, role=ROLE, name=MODEL_NAME, sagemaker_session=sagemaker_session
161161
)
162162

163163
key = "some-key-arn"
@@ -177,7 +177,7 @@ def test_deploy_kms_key(production_variant, sagemaker_session):
177177
@patch("sagemaker.production_variant", return_value=BASE_PRODUCTION_VARIANT)
178178
def test_deploy_async(production_variant, sagemaker_session):
179179
model = Model(
180-
MODEL_DATA, MODEL_IMAGE, role=ROLE, name=MODEL_NAME, sagemaker_session=sagemaker_session
180+
MODEL_IMAGE, MODEL_DATA, role=ROLE, name=MODEL_NAME, sagemaker_session=sagemaker_session
181181
)
182182

183183
model.deploy(instance_type=INSTANCE_TYPE, initial_instance_count=INSTANCE_COUNT, wait=False)
@@ -196,7 +196,7 @@ def test_deploy_async(production_variant, sagemaker_session):
196196
@patch("sagemaker.production_variant", return_value=BASE_PRODUCTION_VARIANT)
197197
def test_deploy_data_capture_config(production_variant, sagemaker_session):
198198
model = Model(
199-
MODEL_DATA, MODEL_IMAGE, role=ROLE, name=MODEL_NAME, sagemaker_session=sagemaker_session
199+
MODEL_IMAGE, MODEL_DATA, role=ROLE, name=MODEL_NAME, sagemaker_session=sagemaker_session
200200
)
201201

202202
data_capture_config = Mock()
@@ -223,20 +223,20 @@ def test_deploy_data_capture_config(production_variant, sagemaker_session):
223223
@patch("sagemaker.local.LocalSession")
224224
def test_deploy_creates_correct_session(local_session, session):
225225
# We expect a LocalSession when deploying to instance_type = 'local'
226-
model = Model(MODEL_DATA, MODEL_IMAGE, role=ROLE)
226+
model = Model(MODEL_IMAGE, MODEL_DATA, role=ROLE)
227227
model.deploy(endpoint_name="blah", instance_type="local", initial_instance_count=1)
228228
assert model.sagemaker_session == local_session.return_value
229229

230230
# We expect a real Session when deploying to instance_type != local/local_gpu
231-
model = Model(MODEL_DATA, MODEL_IMAGE, role=ROLE)
231+
model = Model(MODEL_IMAGE, MODEL_DATA, role=ROLE)
232232
model.deploy(
233233
endpoint_name="remote_endpoint", instance_type="ml.m4.4xlarge", initial_instance_count=2
234234
)
235235
assert model.sagemaker_session == session.return_value
236236

237237

238238
def test_deploy_no_role(sagemaker_session):
239-
model = Model(MODEL_DATA, MODEL_IMAGE, sagemaker_session=sagemaker_session)
239+
model = Model(MODEL_IMAGE, MODEL_DATA, sagemaker_session=sagemaker_session)
240240

241241
with pytest.raises(ValueError, match="Role can not be null for deploying a model"):
242242
model.deploy(instance_type=INSTANCE_TYPE, initial_instance_count=INSTANCE_COUNT)
@@ -248,8 +248,8 @@ def test_deploy_no_role(sagemaker_session):
248248
@patch("sagemaker.production_variant", return_value=BASE_PRODUCTION_VARIANT)
249249
def test_deploy_predictor_cls(production_variant, sagemaker_session):
250250
model = Model(
251-
MODEL_DATA,
252251
MODEL_IMAGE,
252+
MODEL_DATA,
253253
role=ROLE,
254254
name=MODEL_NAME,
255255
predictor_cls=sagemaker.predictor.RealTimePredictor,
@@ -269,7 +269,7 @@ def test_deploy_predictor_cls(production_variant, sagemaker_session):
269269

270270

271271
def test_deploy_update_endpoint(sagemaker_session):
272-
model = Model(MODEL_DATA, MODEL_IMAGE, role=ROLE, sagemaker_session=sagemaker_session)
272+
model = Model(MODEL_IMAGE, MODEL_DATA, role=ROLE, sagemaker_session=sagemaker_session)
273273
model.deploy(
274274
instance_type=INSTANCE_TYPE, initial_instance_count=INSTANCE_COUNT, update_endpoint=True
275275
)
@@ -300,7 +300,7 @@ def test_deploy_update_endpoint_optional_args(sagemaker_session):
300300
kms_key = "foo"
301301
data_capture_config = Mock()
302302

303-
model = Model(MODEL_DATA, MODEL_IMAGE, role=ROLE, sagemaker_session=sagemaker_session)
303+
model = Model(MODEL_IMAGE, MODEL_DATA, role=ROLE, sagemaker_session=sagemaker_session)
304304
model.deploy(
305305
instance_type=INSTANCE_TYPE,
306306
initial_instance_count=INSTANCE_COUNT,

0 commit comments

Comments
 (0)