
Commit fa02963

pravali96 authored and pintaoz-aws committed
add modelID support to model builder InProcess model (#1580)
* add modelID support to model builder InProcess model
* fix format
* import pkg only if model id is specified
* fix format and remove unused imports
* add inference_spec to triton inProcess mode
* fix format
* fix format using black
* fix app.py
* Implement stop_server for fast api, remove model_server param from in_process mode
* fix tests and format
* rebase
* remove Inference spec support arg in DJL builder, add fastapi reqs to pyproject
* make changes to not support JS in in_process mode
* fix code style errors
* fix flake8 warnings
* fix tgi unit tests to build for TGI usecase instead of JS
1 parent 611ea9a commit fa02963
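
In effect, ModelBuilder can now take a plain Hugging Face model ID in IN_PROCESS mode and host it on an in-process FastAPI server rather than a model-server container. A minimal usage sketch against the post-commit SDK; the model ID and sample payloads are illustrative, not taken from this commit:

    from sagemaker.serve import ModelBuilder, SchemaBuilder
    from sagemaker.serve.mode.function_pointers import Mode

    # Hypothetical IN_PROCESS usage after this change: pass a Hugging Face
    # model ID directly; no model_server argument is needed anymore.
    model_builder = ModelBuilder(
        mode=Mode.IN_PROCESS,
        model="HuggingFaceH4/zephyr-7b-beta",  # illustrative HF model ID
        schema_builder=SchemaBuilder(
            sample_input="What is machine learning?",  # illustrative samples
            sample_output="Machine learning is...",
        ),
    )
    model = model_builder.build()
    predictor = model.deploy()  # in-process predictor backed by FastAPI
    print(predictor.predict("Tell me about SageMaker."))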

12 files changed, +242 -112 lines changed

pyproject.toml (+2)

@@ -35,6 +35,7 @@ dependencies = [
     "boto3>=1.34.142,<2.0",
     "cloudpickle==2.2.1",
     "docker",
+    "fastapi",
     "google-pasta",
     "importlib-metadata>=1.4.0,<7.0",
     "jsonschema",
@@ -54,6 +55,7 @@ dependencies = [
     "tblib>=1.7.0,<4",
     "tqdm",
     "urllib3>=1.26.8,<3.0.0",
+    "uvicorn"
 ]

 [project.scripts]
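
These two new runtime dependencies back the in-process server. A quick sanity check, under the assumption that an environment with the updated package installed should now resolve both (no version pins beyond what pip chooses):

    # Verify the two dependencies this commit adds are importable.
    import fastapi
    import uvicorn

    print("fastapi", fastapi.__version__)
    print("uvicorn", uvicorn.__version__)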

src/sagemaker/serve/builder/djl_builder.py (+1 -7)

@@ -78,7 +78,6 @@ def __init__(self):
         self.mode = None
         self.model_server = None
         self.image_uri = None
-        self.inference_spec = None
         self._is_custom_image_uri = False
         self.image_config = None
         self.vpc_config = None
@@ -263,12 +262,7 @@ def _build_for_hf_djl(self):

         _create_dir_structure(self.model_path)
         if not hasattr(self, "pysdk_model"):
-            if self.inference_spec is not None:
-                self.env_vars.update({"HF_MODEL_ID": self.inference_spec.get_model()})
-            else:
-                self.env_vars.update({"HF_MODEL_ID": self.model})
-
-            logger.info(self.env_vars)
+            self.env_vars.update({"HF_MODEL_ID": self.model})

         self.hf_model_config = _get_model_config_properties_from_hf(
             self.env_vars.get("HF_MODEL_ID"), self.env_vars.get("HF_TOKEN")

src/sagemaker/serve/builder/model_builder.py (+35 -24)

@@ -87,7 +87,10 @@
     _extract_speculative_draft_model_provider,
     _jumpstart_speculative_decoding,
 )
-from sagemaker.serve.utils.predictors import _get_local_mode_predictor, InProcessModePredictor
+from sagemaker.serve.utils.predictors import (
+    _get_local_mode_predictor,
+    _get_in_process_mode_predictor,
+)
 from sagemaker.serve.utils.hardware_detector import (
     _get_gpu_info,
     _get_gpu_info_fallback,
@@ -435,11 +438,11 @@ def _prepare_for_mode(
             # init the InProcessMode object
             self.modes[str(Mode.IN_PROCESS)] = InProcessMode(
                 inference_spec=self.inference_spec,
+                model=self.model,
                 schema_builder=self.schema_builder,
                 session=self.sagemaker_session,
                 model_path=self.model_path,
                 env_vars=self.env_vars,
-                model_server=self.model_server,
             )
             self.modes[str(Mode.IN_PROCESS)].prepare()
             return None
@@ -575,7 +578,7 @@ def _model_builder_deploy_wrapper(
         if self.mode == Mode.IN_PROCESS:
             serializer, deserializer = self._get_client_translators()

-            predictor = InProcessModePredictor(
+            predictor = _get_in_process_mode_predictor(
                 self.modes[str(Mode.IN_PROCESS)], serializer, deserializer
             )

@@ -597,6 +600,7 @@ def _model_builder_deploy_wrapper(
                 self.image_uri, container_timeout_in_second, self.secret_key, predictor
             )
             return predictor
+
         if self.mode == Mode.SAGEMAKER_ENDPOINT:
             # Validate parameters
             # Instance type and instance count parameter validation is done based on deployment type
@@ -650,16 +654,17 @@ def _build_for_torchserve(self) -> Type[Model]:
         """Build the model for torchserve"""
         self._save_model_inference_spec()

-        self._auto_detect_container()
+        if self.mode != Mode.IN_PROCESS:
+            self._auto_detect_container()

-        self.secret_key = prepare_for_torchserve(
-            model_path=self.model_path,
-            shared_libs=self.shared_libs,
-            dependencies=self.dependencies,
-            session=self.sagemaker_session,
-            image_uri=self.image_uri,
-            inference_spec=self.inference_spec,
-        )
+            self.secret_key = prepare_for_torchserve(
+                model_path=self.model_path,
+                shared_libs=self.shared_libs,
+                dependencies=self.dependencies,
+                session=self.sagemaker_session,
+                image_uri=self.image_uri,
+                inference_spec=self.inference_spec,
+            )

         self._prepare_for_mode()
         self.model = self._create_model()
@@ -854,6 +859,7 @@ def build( # pylint: disable=R0911
         Returns:
             Type[Model]: A deployable ``Model`` object.
         """
+        from sagemaker.modules.train.model_trainer import ModelTrainer

         self.modes = dict()

@@ -908,10 +914,25 @@ def build( # pylint: disable=R0911

         if isinstance(self.model, str):
             model_task = None
-            if self._is_jumpstart_model_id() or self._use_jumpstart_equivalent():
+
+            if self._is_jumpstart_model_id():
+                if self.mode == Mode.IN_PROCESS:
+                    raise ValueError(
+                        f"{self.mode} is not supported for Jumpstart models. "
+                        "Please use LOCAL_CONTAINER mode to deploy a Jumpstart model"
+                        " on your local machine."
+                    )
                 self.model_hub = ModelHub.JUMPSTART
+                logger.debug("Building for Jumpstart model Id...")
                 self.built_model = self._build_for_jumpstart()
                 return self.built_model
+
+            if self.mode != Mode.IN_PROCESS:
+                if self._use_jumpstart_equivalent():
+                    self.model_hub = ModelHub.JUMPSTART
+                    logger.debug("Building for Jumpstart equivalent model Id...")
+                    self.built_model = self._build_for_jumpstart()
+                    return self.built_model
             self.model_hub = ModelHub.HUGGINGFACE

         if self.model_metadata:
@@ -931,7 +952,7 @@ def build( # pylint: disable=R0911
             if model_task == "text-generation":
                 self.built_model = self._build_for_tgi()
                 return self.built_model
-            if model_task == "sentence-similarity":
+            if model_task in ["sentence-similarity", "feature-extraction"]:
                 self.built_model = self._build_for_tei()
                 return self.built_model
         elif self._can_fit_on_single_gpu():
@@ -951,16 +972,6 @@ def build( # pylint: disable=R0911

     def _build_validations(self):
         """Validations needed for model server overrides, or auto-detection or fallback"""
-        if (
-            self.mode == Mode.IN_PROCESS
-            and self.model_server is not ModelServer.MMS
-            and self.model_server is not ModelServer.DJL_SERVING
-            and self.model_server is not ModelServer.TORCHSERVE
-        ):
-            raise ValueError(
-                "IN_PROCESS mode is only supported for the following servers "
-                "in beta release: MMS/Transformers, TORCHSERVE, DJL_SERVING server"
-            )
         if self.inference_spec and self.model:
             raise ValueError("Can only set one of the following: model, inference_spec.")
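
The net behavior of the new build() guards: a JumpStart model ID combined with IN_PROCESS mode now fails fast, and the JumpStart-equivalent lookup is skipped entirely in that mode. A sketch of the failure path, reusing the imports from the usage example above; the model ID here is a hypothetical JumpStart ID:

    # Hypothetical: JumpStart model IDs are rejected in IN_PROCESS mode.
    js_builder = ModelBuilder(
        mode=Mode.IN_PROCESS,
        model="huggingface-llm-falcon-7b-bf16",  # hypothetical JumpStart model ID
        schema_builder=SchemaBuilder(sample_input="hi", sample_output="hello"),
    )
    try:
        js_builder.build()
    except ValueError as err:
        # The error message points users to LOCAL_CONTAINER mode instead.
        print(err)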

src/sagemaker/serve/builder/tei_builder.py (+14 -2)

@@ -26,13 +26,14 @@
 )
 from sagemaker.serve.model_server.tgi.prepare import _create_dir_structure
 from sagemaker.serve.utils.optimize_utils import _is_optimized
-from sagemaker.serve.utils.predictors import TeiLocalModePredictor
+from sagemaker.serve.utils.predictors import InProcessModePredictor, TeiLocalModePredictor
 from sagemaker.serve.utils.types import ModelServer
 from sagemaker.serve.mode.function_pointers import Mode
 from sagemaker.serve.utils.telemetry_logger import _capture_telemetry
 from sagemaker.base_predictor import PredictorBase

 logger = logging.getLogger(__name__)
+LOCAL_MODES = [Mode.LOCAL_CONTAINER, Mode.IN_PROCESS]

 _CODE_FOLDER = "code"

@@ -141,6 +142,17 @@ def _tei_model_builder_deploy_wrapper(self, *args, **kwargs) -> Type[PredictorBase]:

         serializer = self.schema_builder.input_serializer
         deserializer = self.schema_builder._output_deserializer
+        if self.mode == Mode.IN_PROCESS:
+            self._prepare_for_mode()
+            predictor = InProcessModePredictor(
+                self.modes[str(Mode.IN_PROCESS)], serializer, deserializer
+            )
+
+            self.modes[str(Mode.IN_PROCESS)].create_server(
+                predictor,
+            )
+            return predictor
+
         if self.mode == Mode.LOCAL_CONTAINER:
             timeout = kwargs.get("model_data_download_timeout")

@@ -222,7 +234,7 @@ def _build_for_hf_tei(self):

         self.pysdk_model = self._create_tei_model()

-        if self.mode == Mode.LOCAL_CONTAINER:
+        if self.mode in LOCAL_MODES:
             self._prepare_for_mode()

         return self.pysdk_model
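
Together with the model_builder.py change that routes "feature-extraction" tasks to TEI, this lets embedding models run in-process as well. A hedged sketch, again reusing the imports from the first example; the model ID and sample vector are illustrative:

    # Hypothetical: an embeddings model whose task resolves to
    # "feature-extraction" now builds for TEI and can deploy in-process.
    embed_builder = ModelBuilder(
        mode=Mode.IN_PROCESS,
        model="BAAI/bge-base-en-v1.5",  # illustrative HF embeddings model ID
        schema_builder=SchemaBuilder(
            sample_input="hello world",
            sample_output=[0.12, -0.08, 0.33],  # illustrative embedding vector
        ),
    )
    embed_predictor = embed_builder.build().deploy()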

src/sagemaker/serve/builder/tgi_builder.py (+14 -2)

@@ -49,13 +49,14 @@
     _get_gpu_info_fallback,
 )
 from sagemaker.serve.model_server.tgi.prepare import _create_dir_structure
-from sagemaker.serve.utils.predictors import TgiLocalModePredictor
+from sagemaker.serve.utils.predictors import TgiLocalModePredictor, InProcessModePredictor
 from sagemaker.serve.utils.types import ModelServer
 from sagemaker.serve.mode.function_pointers import Mode
 from sagemaker.serve.utils.telemetry_logger import _capture_telemetry
 from sagemaker.base_predictor import PredictorBase

 logger = logging.getLogger(__name__)
+LOCAL_MODES = [Mode.LOCAL_CONTAINER, Mode.IN_PROCESS]

 _CODE_FOLDER = "code"
 _INVALID_SAMPLE_DATA_EX = (
@@ -176,6 +177,17 @@ def _tgi_model_builder_deploy_wrapper(self, *args, **kwargs) -> Type[PredictorBase]:

         serializer = self.schema_builder.input_serializer
         deserializer = self.schema_builder._output_deserializer
+
+        if self.mode == Mode.IN_PROCESS:
+            predictor = InProcessModePredictor(
+                self.modes[str(Mode.IN_PROCESS)], serializer, deserializer
+            )
+
+            self.modes[str(Mode.IN_PROCESS)].create_server(
+                predictor,
+            )
+            return predictor
+
         if self.mode == Mode.LOCAL_CONTAINER:
             timeout = kwargs.get("model_data_download_timeout")

@@ -280,7 +292,7 @@ def _build_for_hf_tgi(self):
             ] = _default_max_new_tokens
         self.pysdk_model = self._create_tgi_model()

-        if self.mode == Mode.LOCAL_CONTAINER:
+        if self.mode in LOCAL_MODES:
             self._prepare_for_mode()

         return self.pysdk_model

src/sagemaker/serve/mode/in_process_mode.py (+13 -17)

@@ -4,14 +4,13 @@

 from pathlib import Path
 import logging
-from typing import Dict, Type
+from typing import Dict, Type, Optional
 import time
 from datetime import datetime, timedelta

 from sagemaker.base_predictor import PredictorBase
 from sagemaker.serve.spec.inference_spec import InferenceSpec
 from sagemaker.serve.builder.schema_builder import SchemaBuilder
-from sagemaker.serve.utils.types import ModelServer
 from sagemaker.serve.utils.exceptions import InProcessDeepPingException
 from sagemaker.serve.model_server.in_process_model_server.in_process_server import InProcessServing
 from sagemaker.session import Session
@@ -26,8 +25,8 @@ class InProcessMode(InProcessServing):

     def __init__(
         self,
-        model_server: ModelServer,
-        inference_spec: Type[InferenceSpec],
+        model: Optional[str],
+        inference_spec: Optional[InferenceSpec],
         schema_builder: Type[SchemaBuilder],
         session: Session,
         model_path: str = None,
@@ -36,12 +35,12 @@ def __init__(
         # pylint: disable=bad-super-call
         super().__init__()

+        self.model = model
         self.inference_spec = inference_spec
         self.model_path = model_path
         self.env_vars = env_vars
         self.session = session
         self.schema_builder = schema_builder
-        self.model_server = model_server
         self._ping_local_server = None

     def load(self, model_path: str = None):
@@ -61,18 +60,15 @@ def create_server(
         self,
         predictor: PredictorBase,
     ):
-        """Creating the server and checking ping health."""
-        logger.info("Waiting for model server %s to start up...", self.model_server)
-
-        if self.model_server == ModelServer.MMS:
-            self._ping_local_server = self._deep_ping
-            self._start_serving()
-        elif self.model_server == ModelServer.DJL_SERVING:
-            self._ping_local_server = self._deep_ping
-            self._start_serving()
-        elif self.model_server == ModelServer.TORCHSERVE:
-            self._ping_local_server = self._deep_ping
-            self._start_serving()
+        """Creating the fast api server and checking ping health."""
+
+        logger.info("Waiting for fastapi server to start up...")
+
+        logger.warning("Note: This is not a standard model server.")
+        logger.warning("The model is being hosted directly on the FastAPI server.")
+
+        self._ping_local_server = self._deep_ping
+        self._start_serving()

         # allow some time for server to be ready.
         time.sleep(1)
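
After this change, create_server no longer branches on a model server; every IN_PROCESS deployment is served by FastAPI run under uvicorn. The SDK's actual app.py is not part of this diff, so the following is only a minimal sketch of that pattern, with the route, port, and threading model as assumptions:

    import threading

    import uvicorn
    from fastapi import FastAPI, Request

    app = FastAPI()


    @app.post("/invoke")  # assumed route; the SDK's real app.py is not shown here
    async def invoke(request: Request):
        payload = await request.body()
        # In the real server, the loaded model and the schema builder's
        # serializer/deserializer would transform this payload.
        return {"echo": payload.decode("utf-8")}


    def start_serving(host: str = "127.0.0.1", port: int = 9007):  # port is an assumption
        """Run uvicorn in a daemon thread so the caller's process keeps control."""
        config = uvicorn.Config(app, host=host, port=port, log_level="info")
        server = uvicorn.Server(config)
        thread = threading.Thread(target=server.run, daemon=True)
        thread.start()
        return server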
