Skip to content

Commit 1104baf

Browse files
committed
use jumpstart deployment config image as default optimization image
1 parent a58654e commit 1104baf

File tree

4 files changed

+332
-7
lines changed

4 files changed

+332
-7
lines changed

src/sagemaker/serve/builder/jumpstart_builder.py

+111-2
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717
import re
1818
from abc import ABC, abstractmethod
1919
from datetime import datetime, timedelta
20-
from typing import Type, Any, List, Dict, Optional
20+
from typing import Type, Any, List, Dict, Optional, Tuple
2121
import logging
2222

2323
from botocore.exceptions import ClientError
@@ -82,6 +82,7 @@
8282
ModelServer.DJL_SERVING,
8383
ModelServer.TGI,
8484
}
85+
_JS_MINIMUM_VERSION_IMAGE = "{}:0.31.0-lmi13.0.0-cu124"
8586

8687
logger = logging.getLogger(__name__)
8788

@@ -829,7 +830,13 @@ def _optimize_for_jumpstart(
829830
self.pysdk_model._enable_network_isolation = False
830831

831832
if quantization_config or sharding_config or is_compilation:
832-
return create_optimization_job_args
833+
# only apply default image for vLLM usecases.
834+
# vLLM does not support compilation for now so skip on compilation
835+
return (
836+
create_optimization_job_args
837+
if is_compilation
838+
else self._set_optimization_image_default(create_optimization_job_args)
839+
)
833840
return None
834841

835842
def _is_gated_model(self, model=None) -> bool:
@@ -986,3 +993,105 @@ def _get_neuron_model_env_vars(
986993
)
987994
return job_model.env
988995
return None
996+
997+
def _set_optimization_image_default(
    self, create_optimization_job_args: Dict[str, Any]
) -> Dict[str, Any]:
    """Defaults the optimization image to the JumpStart deployment config default.

    Scans every optimization config in the request and keeps the newest LMI
    image version seen — a user-provided image wins over the JumpStart default
    only when it is at least as new — then stamps that single winning image
    onto every quantization/sharding config so the job uses one consistent image.

    Args:
        create_optimization_job_args (Dict[str, Any]): create optimization job request

    Returns:
        Dict[str, Any]: create optimization job request with image uri default
    """
    # Start from the JumpStart deployment config image, bumped to the
    # minimum vLLM-capable version if it is too old.
    default_image = self._get_default_vllm_image(self.pysdk_model.init_kwargs["image_uri"])

    # Guard against a missing "OptimizationConfigs" key (original code would
    # raise TypeError iterating over None).
    optimization_configs = create_optimization_job_args.get("OptimizationConfigs") or []

    # find the latest vLLM image version among all provided configs
    for optimization_config in optimization_configs:
        for config_name in ("ModelQuantizationConfig", "ModelShardingConfig"):
            config = optimization_config.get(config_name)
            if not config:
                continue
            provided_image = config.get("Image")
            if provided_image and self._get_latest_lmi_version_from_list(
                default_image, provided_image
            ):
                default_image = provided_image

    # default every config to the latest vLLM version found above
    for optimization_config in optimization_configs:
        for config_name in ("ModelQuantizationConfig", "ModelShardingConfig"):
            if optimization_config.get(config_name) is not None:
                optimization_config[config_name]["Image"] = default_image

    logger.info("Defaulting to %s image for optimization job", default_image)

    return create_optimization_job_args
1037+
1038+
def _get_default_vllm_image(self, image: str) -> str:
    """Ensures the minimum working image version for vLLM enabled optimization techniques.

    Args:
        image (str): JumpStart provided default image

    Returns:
        str: the given image unchanged, or the minimum working vLLM image
            (same DLC repository, pinned minimum tag) when the given image's
            LMI major version is below the supported minimum
    """
    dlc_name, _ = image.split(":")
    # _parse_lmi_version already returns ints — no int() re-cast needed.
    major_version_number, _, _ = self._parse_lmi_version(image)

    if major_version_number < self._parse_lmi_version(_JS_MINIMUM_VERSION_IMAGE)[0]:
        # Keep the same DLC repository, substitute the minimum supported tag.
        return _JS_MINIMUM_VERSION_IMAGE.format(dlc_name)
    return image
1054+
1055+
def _get_latest_lmi_version_from_list(self, version: str, version_to_compare: str) -> bool:
    """LMI version comparator

    Args:
        version (str): current version
        version_to_compare (str): version to compare to

    Returns:
        bool: if version_to_compare larger or equal to version
    """
    # (major, minor, patch) int tuples compare lexicographically, which is
    # exactly the hand-rolled major -> minor -> patch cascade this replaces.
    return self._parse_lmi_version(version_to_compare) >= self._parse_lmi_version(version)
1082+
1083+
def _parse_lmi_version(self, image: str) -> Tuple[int, int, int]:
1084+
"""Parse out LMI version
1085+
1086+
Args:
1087+
image (str): image to parse version out of
1088+
1089+
Returns:
1090+
Tuple[int, int, it]: LMI version split into major, minor, patch
1091+
"""
1092+
_, dlc_tag = image.split(":")
1093+
_, lmi_version, _ = dlc_tag.split("-")
1094+
major_version, minor_version, patch_version = lmi_version.split(".")
1095+
major_version_number = major_version[3:]
1096+
1097+
return (int(major_version_number), int(minor_version), int(patch_version))

tests/integ/sagemaker/serve/test_serve_js_deep_unit_tests.py

+18
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,8 @@ def test_js_model_with_optimize_speculative_decoding_config_gated_requests_are_e
3232
iam_client = sagemaker_session.boto_session.client("iam")
3333
role_arn = iam_client.get_role(RoleName=ROLE_NAME)["Role"]["Arn"]
3434

35+
sagemaker_session.sagemaker_client.create_optimization_job = MagicMock()
36+
3537
schema_builder = SchemaBuilder("test", "test")
3638
model_builder = ModelBuilder(
3739
model="meta-textgeneration-llama-3-1-8b-instruct",
@@ -50,6 +52,8 @@ def test_js_model_with_optimize_speculative_decoding_config_gated_requests_are_e
5052
accept_eula=True,
5153
)
5254

55+
assert not sagemaker_session.sagemaker_client.create_optimization_job.called
56+
5357
optimized_model.deploy()
5458

5559
mock_create_model.assert_called_once_with(
@@ -126,6 +130,13 @@ def test_js_model_with_optimize_sharding_and_resource_requirements_requests_are_
126130
accept_eula=True,
127131
)
128132

133+
assert (
134+
sagemaker_session.sagemaker_client.create_optimization_job.call_args_list[0][1][
135+
"OptimizationConfigs"
136+
][0]["ModelShardingConfig"]["Image"]
137+
is not None
138+
)
139+
129140
optimized_model.deploy(
130141
resources=ResourceRequirements(requests={"memory": 196608, "num_accelerators": 8})
131142
)
@@ -206,6 +217,13 @@ def test_js_model_with_optimize_quantization_on_pre_optimized_model_requests_are
206217
accept_eula=True,
207218
)
208219

220+
assert (
221+
sagemaker_session.sagemaker_client.create_optimization_job.call_args_list[0][1][
222+
"OptimizationConfigs"
223+
][0]["ModelQuantizationConfig"]["Image"]
224+
is not None
225+
)
226+
209227
optimized_model.deploy()
210228

211229
mock_create_model.assert_called_once_with(

0 commit comments

Comments
 (0)