|
17 | 17 | import re
|
18 | 18 | from abc import ABC, abstractmethod
|
19 | 19 | from datetime import datetime, timedelta
|
20 |
| -from typing import Type, Any, List, Dict, Optional |
| 20 | +from typing import Type, Any, List, Dict, Optional, Tuple |
21 | 21 | import logging
|
22 | 22 |
|
23 | 23 | from botocore.exceptions import ClientError
|
|
82 | 82 | ModelServer.DJL_SERVING,
|
83 | 83 | ModelServer.TGI,
|
84 | 84 | }
|
| 85 | +_JS_MINIMUM_VERSION_IMAGE = "{}:0.31.0-lmi13.0.0-cu124" |
85 | 86 |
|
86 | 87 | logger = logging.getLogger(__name__)
|
87 | 88 |
|
@@ -829,7 +830,13 @@ def _optimize_for_jumpstart(
|
829 | 830 | self.pysdk_model._enable_network_isolation = False
|
830 | 831 |
|
831 | 832 | if quantization_config or sharding_config or is_compilation:
|
832 |
| - return create_optimization_job_args |
| 833 | + # only apply default image for vLLM usecases. |
| 834 | + # vLLM does not support compilation for now so skip on compilation |
| 835 | + return ( |
| 836 | + create_optimization_job_args |
| 837 | + if is_compilation |
| 838 | + else self._set_optimization_image_default(create_optimization_job_args) |
| 839 | + ) |
833 | 840 | return None
|
834 | 841 |
|
835 | 842 | def _is_gated_model(self, model=None) -> bool:
|
@@ -986,3 +993,105 @@ def _get_neuron_model_env_vars(
|
986 | 993 | )
|
987 | 994 | return job_model.env
|
988 | 995 | return None
|
| 996 | + |
| 997 | + def _set_optimization_image_default( |
| 998 | + self, create_optimization_job_args: Dict[str, Any] |
| 999 | + ) -> Dict[str, Any]: |
| 1000 | + """Defaults the optimization image to the JumpStart deployment config default |
| 1001 | +
|
| 1002 | + Args: |
| 1003 | + create_optimization_job_args (Dict[str, Any]): create optimization job request |
| 1004 | +
|
| 1005 | + Returns: |
| 1006 | + Dict[str, Any]: create optimization job request with image uri default |
| 1007 | + """ |
| 1008 | + default_image = self._get_default_vllm_image(self.pysdk_model.init_kwargs["image_uri"]) |
| 1009 | + |
| 1010 | + # find the latest vLLM image version |
| 1011 | + for optimization_config in create_optimization_job_args.get("OptimizationConfigs"): |
| 1012 | + if optimization_config.get("ModelQuantizationConfig"): |
| 1013 | + model_quantization_config = optimization_config.get("ModelQuantizationConfig") |
| 1014 | + provided_image = model_quantization_config.get("Image") |
| 1015 | + if provided_image and self._get_latest_lmi_version_from_list( |
| 1016 | + default_image, provided_image |
| 1017 | + ): |
| 1018 | + default_image = provided_image |
| 1019 | + if optimization_config.get("ModelShardingConfig"): |
| 1020 | + model_sharding_config = optimization_config.get("ModelShardingConfig") |
| 1021 | + provided_image = model_sharding_config.get("Image") |
| 1022 | + if provided_image and self._get_latest_lmi_version_from_list( |
| 1023 | + default_image, provided_image |
| 1024 | + ): |
| 1025 | + default_image = provided_image |
| 1026 | + |
| 1027 | + # default to latest vLLM version |
| 1028 | + for optimization_config in create_optimization_job_args.get("OptimizationConfigs"): |
| 1029 | + if optimization_config.get("ModelQuantizationConfig") is not None: |
| 1030 | + optimization_config.get("ModelQuantizationConfig")["Image"] = default_image |
| 1031 | + if optimization_config.get("ModelShardingConfig") is not None: |
| 1032 | + optimization_config.get("ModelShardingConfig")["Image"] = default_image |
| 1033 | + |
| 1034 | + logger.info("Defaulting to %s image for optimization job", default_image) |
| 1035 | + |
| 1036 | + return create_optimization_job_args |
| 1037 | + |
| 1038 | + def _get_default_vllm_image(self, image: str) -> bool: |
| 1039 | + """Ensures the minimum working image version for vLLM enabled optimization techniques |
| 1040 | +
|
| 1041 | + Args: |
| 1042 | + image (str): JumpStart provided default image |
| 1043 | +
|
| 1044 | + Returns: |
| 1045 | + str: minimum working image version |
| 1046 | + """ |
| 1047 | + dlc_name, _ = image.split(":") |
| 1048 | + major_version_number, _, _ = self._parse_lmi_version(image) |
| 1049 | + |
| 1050 | + if int(major_version_number) < self._parse_lmi_version(_JS_MINIMUM_VERSION_IMAGE)[0]: |
| 1051 | + minimum_version_default = _JS_MINIMUM_VERSION_IMAGE.format(dlc_name) |
| 1052 | + return minimum_version_default |
| 1053 | + return image |
| 1054 | + |
| 1055 | + def _get_latest_lmi_version_from_list(self, version: str, version_to_compare: str) -> bool: |
| 1056 | + """LMI version comparator |
| 1057 | +
|
| 1058 | + Args: |
| 1059 | + version (str): current version |
| 1060 | + version_to_compare (str): version to compare to |
| 1061 | +
|
| 1062 | + Returns: |
| 1063 | + bool: if version_to_compare larger or equal to version |
| 1064 | + """ |
| 1065 | + parse_lmi_version = self._parse_lmi_version(version) |
| 1066 | + parse_lmi_version_to_compare = self._parse_lmi_version(version_to_compare) |
| 1067 | + |
| 1068 | + # Check major version |
| 1069 | + if parse_lmi_version_to_compare[0] > parse_lmi_version[0]: |
| 1070 | + return True |
| 1071 | + # Check minor version |
| 1072 | + if parse_lmi_version_to_compare[0] == parse_lmi_version[0]: |
| 1073 | + if parse_lmi_version_to_compare[1] > parse_lmi_version[1]: |
| 1074 | + return True |
| 1075 | + if parse_lmi_version_to_compare[1] == parse_lmi_version[1]: |
| 1076 | + # Check patch version |
| 1077 | + if parse_lmi_version_to_compare[2] >= parse_lmi_version[2]: |
| 1078 | + return True |
| 1079 | + return False |
| 1080 | + return False |
| 1081 | + return False |
| 1082 | + |
| 1083 | + def _parse_lmi_version(self, image: str) -> Tuple[int, int, int]: |
| 1084 | + """Parse out LMI version |
| 1085 | +
|
| 1086 | + Args: |
| 1087 | + image (str): image to parse version out of |
| 1088 | +
|
| 1089 | + Returns: |
| 1090 | + Tuple[int, int, it]: LMI version split into major, minor, patch |
| 1091 | + """ |
| 1092 | + _, dlc_tag = image.split(":") |
| 1093 | + _, lmi_version, _ = dlc_tag.split("-") |
| 1094 | + major_version, minor_version, patch_version = lmi_version.split(".") |
| 1095 | + major_version_number = major_version[3:] |
| 1096 | + |
| 1097 | + return (int(major_version_number), int(minor_version), int(patch_version)) |
0 commit comments