@@ -87,7 +87,10 @@
     _extract_speculative_draft_model_provider,
     _jumpstart_speculative_decoding,
 )
-from sagemaker.serve.utils.predictors import _get_local_mode_predictor, InProcessModePredictor
+from sagemaker.serve.utils.predictors import (
+    _get_local_mode_predictor,
+    _get_in_process_mode_predictor,
+)
 from sagemaker.serve.utils.hardware_detector import (
     _get_gpu_info,
     _get_gpu_info_fallback,
@@ -435,11 +438,11 @@ def _prepare_for_mode(
             # init the InProcessMode object
             self.modes[str(Mode.IN_PROCESS)] = InProcessMode(
                 inference_spec=self.inference_spec,
+                model=self.model,
                 schema_builder=self.schema_builder,
                 session=self.sagemaker_session,
                 model_path=self.model_path,
                 env_vars=self.env_vars,
-                model_server=self.model_server,
             )
             self.modes[str(Mode.IN_PROCESS)].prepare()
             return None
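
A minimal sketch of how this `InProcessMode` change surfaces to a caller, assuming the public `ModelBuilder`/`SchemaBuilder` API exported from `sagemaker.serve`; the model id and sample payloads are illustrative, not taken from this diff:

```python
from sagemaker.serve import ModelBuilder, SchemaBuilder
from sagemaker.serve.mode.function_pointers import Mode

# SchemaBuilder derives the (de)serializers from example payloads.
schema = SchemaBuilder(sample_input="What is ML?", sample_output="Machine learning is ...")

# With `model=self.model` now forwarded to InProcessMode, the raw model handle
# itself (not only an inference_spec) is available to the in-process server.
builder = ModelBuilder(
    model="HuggingFaceH4/zephyr-7b-beta",  # illustrative Hugging Face model id
    schema_builder=schema,
    mode=Mode.IN_PROCESS,
)
model = builder.build()
```
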
@@ -575,7 +578,7 @@ def _model_builder_deploy_wrapper(
         if self.mode == Mode.IN_PROCESS:
             serializer, deserializer = self._get_client_translators()

-            predictor = InProcessModePredictor(
+            predictor = _get_in_process_mode_predictor(
                 self.modes[str(Mode.IN_PROCESS)], serializer, deserializer
             )

@@ -597,6 +600,7 @@ def _model_builder_deploy_wrapper(
                 self.image_uri, container_timeout_in_second, self.secret_key, predictor
             )
             return predictor
+
         if self.mode == Mode.SAGEMAKER_ENDPOINT:
             # Validate parameters
             # Instance type and instance count parameter validation is done based on deployment type
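
Continuing the sketch above: `_get_in_process_mode_predictor` is an internal helper, so the only caller-visible effect of this hunk is the predictor handed back by `deploy()` (the payload is illustrative):

```python
# deploy() routes through _model_builder_deploy_wrapper; in IN_PROCESS mode it
# now fetches the predictor via the _get_in_process_mode_predictor helper
# rather than constructing InProcessModePredictor directly.
predictor = model.deploy()
print(predictor.predict("What is ML?"))
```
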
@@ -650,16 +654,17 @@ def _build_for_torchserve(self) -> Type[Model]:
         """Build the model for torchserve"""
         self._save_model_inference_spec()

-        self._auto_detect_container()
+        if self.mode != Mode.IN_PROCESS:
+            self._auto_detect_container()

-        self.secret_key = prepare_for_torchserve(
-            model_path=self.model_path,
-            shared_libs=self.shared_libs,
-            dependencies=self.dependencies,
-            session=self.sagemaker_session,
-            image_uri=self.image_uri,
-            inference_spec=self.inference_spec,
-        )
+            self.secret_key = prepare_for_torchserve(
+                model_path=self.model_path,
+                shared_libs=self.shared_libs,
+                dependencies=self.dependencies,
+                session=self.sagemaker_session,
+                image_uri=self.image_uri,
+                inference_spec=self.inference_spec,
+            )

         self._prepare_for_mode()
         self.model = self._create_model()
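
A sketch of what the new guard means for the torchserve path, assuming the `InferenceSpec` base class at `sagemaker.serve.spec.inference_spec`; `EchoSpec` is a made-up stand-in, not part of this diff:

```python
from sagemaker.serve import ModelBuilder, SchemaBuilder
from sagemaker.serve.mode.function_pointers import Mode
from sagemaker.serve.spec.inference_spec import InferenceSpec

class EchoSpec(InferenceSpec):
    """Toy spec: loads a trivial callable and echoes input back through it."""

    def load(self, model_dir: str):
        return lambda payload: payload

    def invoke(self, input_object, model):
        return model(input_object)

# In IN_PROCESS mode the build now skips both _auto_detect_container() and
# prepare_for_torchserve(), so no image_uri or container assets are required.
builder = ModelBuilder(
    inference_spec=EchoSpec(),
    schema_builder=SchemaBuilder(sample_input="hi", sample_output="hi"),
    mode=Mode.IN_PROCESS,
)
```
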
@@ -854,6 +859,7 @@ def build(  # pylint: disable=R0911
         Returns:
             Type[Model]: A deployable ``Model`` object.
         """
+        from sagemaker.modules.train.model_trainer import ModelTrainer

         self.modes = dict()

@@ -908,10 +914,25 @@ def build(  # pylint: disable=R0911

         if isinstance(self.model, str):
             model_task = None
-            if self._is_jumpstart_model_id() or self._use_jumpstart_equivalent():
+
+            if self._is_jumpstart_model_id():
+                if self.mode == Mode.IN_PROCESS:
+                    raise ValueError(
+                        f"{self.mode} is not supported for Jumpstart models. "
+                        "Please use LOCAL_CONTAINER mode to deploy a Jumpstart model"
+                        " on your local machine."
+                    )
                 self.model_hub = ModelHub.JUMPSTART
+                logger.debug("Building for Jumpstart model Id...")
                 self.built_model = self._build_for_jumpstart()
                 return self.built_model
+
+            if self.mode != Mode.IN_PROCESS:
+                if self._use_jumpstart_equivalent():
+                    self.model_hub = ModelHub.JUMPSTART
+                    logger.debug("Building for Jumpstart equivalent model Id...")
+                    self.built_model = self._build_for_jumpstart()
+                    return self.built_model
             self.model_hub = ModelHub.HUGGINGFACE

         if self.model_metadata:
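
A sketch of the new fail-fast behavior for JumpStart ids in IN_PROCESS mode, reusing `ModelBuilder`, `Mode`, and `schema` from the earlier sketch (the model id is an illustrative JumpStart id, not from this diff):

```python
builder = ModelBuilder(
    model="huggingface-llm-falcon-7b-bf16",  # illustrative JumpStart model id
    schema_builder=schema,                   # schema from the earlier sketch
    mode=Mode.IN_PROCESS,
)
try:
    builder.build()
except ValueError as err:
    # "Mode.IN_PROCESS is not supported for Jumpstart models. Please use
    #  LOCAL_CONTAINER mode to deploy a Jumpstart model on your local machine."
    print(err)
```
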
@@ -931,7 +952,7 @@ def build(  # pylint: disable=R0911
             if model_task == "text-generation":
                 self.built_model = self._build_for_tgi()
                 return self.built_model
-            if model_task == "sentence-similarity":
+            if model_task in ["sentence-similarity", "feature-extraction"]:
                 self.built_model = self._build_for_tei()
                 return self.built_model
         elif self._can_fit_on_single_gpu():
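
A sketch of the widened task routing, assuming an embedding model whose resolved Hugging Face task is `feature-extraction` (the model id and payloads are illustrative):

```python
builder = ModelBuilder(
    model="BAAI/bge-base-en-v1.5",  # illustrative embedding model id
    schema_builder=SchemaBuilder(sample_input="embed me", sample_output=[[0.1, 0.2]]),
)
# A resolved task of "feature-extraction" now takes the same TEI build path
# (_build_for_tei) that "sentence-similarity" already used.
model = builder.build()
```
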
@@ -951,16 +972,6 @@ def build(  # pylint: disable=R0911

    def _build_validations(self):
        """Validations needed for model server overrides, or auto-detection or fallback"""
-        if (
-            self.mode == Mode.IN_PROCESS
-            and self.model_server is not ModelServer.MMS
-            and self.model_server is not ModelServer.DJL_SERVING
-            and self.model_server is not ModelServer.TORCHSERVE
-        ):
-            raise ValueError(
-                "IN_PROCESS mode is only supported for the following servers "
-                "in beta release: MMS/Transformers, TORCHSERVE, DJL_SERVING server"
-            )
        if self.inference_spec and self.model:
            raise ValueError("Can only set one of the following: model, inference_spec.")
