12
12
# language governing permissions and limitations under the License.
13
13
"""Holds the ModelBuilder class and the ModelServer enum."""
14
14
from __future__ import absolute_import
15
+
16
+ import importlib .util
15
17
import uuid
16
18
from typing import Any , Type , List , Dict , Optional , Union
17
19
from dataclasses import dataclass , field
18
20
import logging
19
21
import os
22
+ import re
20
23
21
24
from pathlib import Path
22
25
43
46
from sagemaker .predictor import Predictor
44
47
from sagemaker .serve .model_format .mlflow .constants import (
45
48
MLFLOW_MODEL_PATH ,
49
+ MLFLOW_TRACKING_ARN ,
50
+ MLFLOW_RUN_ID_REGEX ,
51
+ MLFLOW_REGISTRY_PATH_REGEX ,
52
+ MODEL_PACKAGE_ARN_REGEX ,
46
53
MLFLOW_METADATA_FILE ,
47
54
MLFLOW_PIP_DEPENDENCY_FILE ,
48
55
)
49
56
from sagemaker .serve .model_format .mlflow .utils import (
50
57
_get_default_model_server_for_mlflow ,
51
- _mlflow_input_is_local_path ,
52
58
_download_s3_artifacts ,
53
59
_select_container_for_mlflow_model ,
54
60
_generate_mlflow_artifact_path ,
@@ -276,8 +282,9 @@ class ModelBuilder(Triton, DJL, JumpStart, TGI, Transformers, TensorflowServing,
276
282
default = None ,
277
283
metadata = {
278
284
"help" : "Define the model metadata to override, currently supports `HF_TASK`, "
279
- "`MLFLOW_MODEL_PATH`. HF_TASK should be set for new models without task metadata in "
280
- "the Hub, Adding unsupported task types will throw an exception"
285
+ "`MLFLOW_MODEL_PATH`, and `MLFLOW_TRACKING_ARN`. HF_TASK should be set for new "
286
+ "models without task metadata in the Hub, Adding unsupported task types will "
287
+ "throw an exception"
281
288
},
282
289
)
283
290
@@ -502,6 +509,7 @@ def _model_builder_register_wrapper(self, *args, **kwargs):
502
509
mlflow_model_path = self .model_metadata [MLFLOW_MODEL_PATH ],
503
510
s3_upload_path = self .s3_upload_path ,
504
511
sagemaker_session = self .sagemaker_session ,
512
+ tracking_server_arn = self .model_metadata .get (MLFLOW_TRACKING_ARN ),
505
513
)
506
514
return new_model_package
507
515
@@ -572,6 +580,7 @@ def _model_builder_deploy_wrapper(
572
580
mlflow_model_path = self .model_metadata [MLFLOW_MODEL_PATH ],
573
581
s3_upload_path = self .s3_upload_path ,
574
582
sagemaker_session = self .sagemaker_session ,
583
+ tracking_server_arn = self .model_metadata .get (MLFLOW_TRACKING_ARN ),
575
584
)
576
585
return predictor
577
586
@@ -625,11 +634,30 @@ def wrapper(*args, **kwargs):
625
634
626
635
return wrapper
627
636
628
- def _check_if_input_is_mlflow_model (self ) -> bool :
629
- """Checks whether an MLmodel file exists in the given directory.
637
def _handle_mlflow_input(self):
    """Check whether an MLflow model is present and handle accordingly.

    Sets ``self._is_mlflow_model``. The flag is left ``True`` only when MLflow
    arguments were supplied and an MLflow metadata file is found at the
    resolved artifact location. In that case the builder is initialized from
    the artifacts and the selected model server is validated against the
    ``MLFLOW_MODEL_FLAVOR`` environment variable.

    Raises:
        ValueError: Propagated from ``_get_artifact_path`` or
            ``_initialize_for_mlflow`` when the input cannot be resolved.
        ImportError: Propagated from ``_get_artifact_path`` when
            ``sagemaker_mlflow`` is required but not installed.
    """
    self._is_mlflow_model = self._has_mlflow_arguments()
    if not self._is_mlflow_model:
        return

    mlflow_model_path = self.model_metadata.get(MLFLOW_MODEL_PATH)
    artifact_path = self._get_artifact_path(mlflow_model_path)
    if not self._mlflow_metadata_exists(artifact_path):
        logger.info(
            "MLflow model metadata not detected in %s. ModelBuilder is not "
            "handling MLflow model input",
            mlflow_model_path,
        )
        # Keep the flag consistent with the message above: without the
        # metadata file the input is not treated as an MLflow model, so later
        # code that branches on ``_is_mlflow_model`` must not take the MLflow
        # path.
        self._is_mlflow_model = False
        return

    self._initialize_for_mlflow(artifact_path)
    _validate_input_for_mlflow(self.model_server, self.env_vars.get("MLFLOW_MODEL_FLAVOR"))
655
+
656
+ def _has_mlflow_arguments (self ) -> bool :
657
+ """Check whether MLflow model arguments are present
630
658
631
659
Returns:
632
- bool: True if the MLmodel file exists , False otherwise.
660
+ bool: True if MLflow arguments are present , False otherwise.
633
661
"""
634
662
if self .inference_spec or self .model :
635
663
logger .info (
@@ -644,16 +672,82 @@ def _check_if_input_is_mlflow_model(self) -> bool:
644
672
)
645
673
return False
646
674
647
- path = self .model_metadata .get (MLFLOW_MODEL_PATH )
648
- if not path :
675
+ mlflow_model_path = self .model_metadata .get (MLFLOW_MODEL_PATH )
676
+ if not mlflow_model_path :
649
677
logger .info (
650
678
"%s is not provided in ModelMetadata. ModelBuilder is not handling MLflow model "
651
679
"input" ,
652
680
MLFLOW_MODEL_PATH ,
653
681
)
654
682
return False
655
683
656
- # Check for S3 path
684
+ return True
685
+
686
def _get_artifact_path(self, mlflow_model_path: str) -> str:
    """Retrieves the model artifact location given the MLflow model input.

    Three input forms are resolved; anything else is returned unchanged:

    * input matching ``MLFLOW_RUN_ID_REGEX``: the run's artifact URI is
      looked up on the tracking server and the relative model path (if any)
      is appended to it.
    * input matching ``MLFLOW_REGISTRY_PATH_REGEX``: the registered model
      version is looked up, either by ``name@alias`` or by ``name/version``,
      and its ``source`` is returned with any trailing sub-path appended.
    * input matching ``MODEL_PACKAGE_ARN_REGEX``: the ``SourceUri`` of the
      described model package is returned.

    Args:
        mlflow_model_path (str): The MLflow model path input.

    Returns:
        str: The path to the model artifact.

    Raises:
        ValueError: If a run or registry path is given but no tracking
            server ARN is present in the model metadata.
        ImportError: If a run or registry path is given but
            ``sagemaker_mlflow`` is not installed.
    """
    is_run_id_type = re.match(MLFLOW_RUN_ID_REGEX, mlflow_model_path)
    if is_run_id_type or re.match(MLFLOW_REGISTRY_PATH_REGEX, mlflow_model_path):
        mlflow_tracking_arn = self.model_metadata.get(MLFLOW_TRACKING_ARN)
        if not mlflow_tracking_arn:
            raise ValueError(
                "%s is not provided in ModelMetadata or through set_tracking_arn "
                "but MLflow model path was provided." % MLFLOW_TRACKING_ARN,
            )

        if not importlib.util.find_spec("sagemaker_mlflow"):
            raise ImportError(
                "Unable to import sagemaker_mlflow, check if sagemaker_mlflow is installed"
            )

        # Imported lazily so that mlflow stays an optional dependency.
        import mlflow

        mlflow.set_tracking_uri(mlflow_tracking_arn)
        if is_run_id_type:
            # Layout is "<scheme>/<run_id>[/<relative model path>]".
            # ``partition`` yields an empty relative path when none is given,
            # where a fixed three-way unpack would fail.
            _, _, run_and_path = mlflow_model_path.partition("/")
            run_id, _, model_path = run_and_path.partition("/")
            artifact_uri = mlflow.get_run(run_id).info.artifact_uri
            if not artifact_uri.endswith("/"):
                artifact_uri += "/"
            return artifact_uri + model_path

        mlflow_client = mlflow.MlflowClient()
        # A trailing slash guarantees the splits below always produce the
        # final (possibly empty) sub-path element.
        if not mlflow_model_path.endswith("/"):
            mlflow_model_path += "/"

        if "@" in mlflow_model_path:
            _, model_name_and_alias, sub_path = mlflow_model_path.split("/", 2)
            model_name, model_alias = model_name_and_alias.split("@")
            model_version_info = mlflow_client.get_model_version_by_alias(
                model_name, model_alias
            )
        else:
            _, model_name, model_version, sub_path = mlflow_model_path.split("/", 3)
            model_version_info = mlflow_client.get_model_version(model_name, model_version)

        source = model_version_info.source
        if not source.endswith("/"):
            source += "/"
        return source + sub_path

    if re.match(MODEL_PACKAGE_ARN_REGEX, mlflow_model_path):
        model_package = self.sagemaker_session.sagemaker_client.describe_model_package(
            ModelPackageName=mlflow_model_path
        )
        return model_package["SourceUri"]

    # Not a run, registry or model package reference: use the input as-is.
    return mlflow_model_path
744
+
745
+ def _mlflow_metadata_exists (self , path : str ) -> bool :
746
+ """Checks whether an MLmodel file exists in the given directory.
747
+
748
+ Returns:
749
+ bool: True if the MLmodel file exists, False otherwise.
750
+ """
657
751
if path .startswith ("s3://" ):
658
752
s3_downloader = S3Downloader ()
659
753
if not path .endswith ("/" ):
@@ -665,17 +759,18 @@ def _check_if_input_is_mlflow_model(self) -> bool:
665
759
file_path = os .path .join (path , MLFLOW_METADATA_FILE )
666
760
return os .path .isfile (file_path )
667
761
668
- def _initialize_for_mlflow (self ) -> None :
669
- """Initialize mlflow model artifacts, image uri and model server."""
670
- mlflow_path = self .model_metadata .get (MLFLOW_MODEL_PATH )
671
- if not _mlflow_input_is_local_path (mlflow_path ):
672
- # TODO: extend to package arn, run id and etc.
673
- logger .info (
674
- "Start downloading model artifacts from %s to %s" , mlflow_path , self .model_path
675
- )
676
- _download_s3_artifacts (mlflow_path , self .model_path , self .sagemaker_session )
762
+ def _initialize_for_mlflow (self , artifact_path : str ) -> None :
763
+ """Initialize mlflow model artifacts, image uri and model server.
764
+
765
+ Args:
766
+ artifact_path (str): The path to the artifact store.
767
+ """
768
+ if artifact_path .startswith ("s3://" ):
769
+ _download_s3_artifacts (artifact_path , self .model_path , self .sagemaker_session )
770
+ elif os .path .exists (artifact_path ):
771
+ _copy_directory_contents (artifact_path , self .model_path )
677
772
else :
678
- _copy_directory_contents ( mlflow_path , self . model_path )
773
+ raise ValueError ( "Invalid path: %s" % artifact_path )
679
774
mlflow_model_metadata_path = _generate_mlflow_artifact_path (
680
775
self .model_path , MLFLOW_METADATA_FILE
681
776
)
@@ -728,6 +823,8 @@ def build( # pylint: disable=R0911
728
823
self .role_arn = role_arn
729
824
self .sagemaker_session = sagemaker_session or Session ()
730
825
826
+ self .sagemaker_session .settings ._local_download_dir = self .model_path
827
+
731
828
# https://github.com/boto/botocore/blob/develop/botocore/useragent.py#L258
732
829
# decorate to_string() due to
733
830
# https://github.com/boto/botocore/blob/develop/botocore/client.py#L1014-L1015
@@ -739,14 +836,8 @@ def build( # pylint: disable=R0911
739
836
self .serve_settings = self ._get_serve_setting ()
740
837
741
838
self ._is_custom_image_uri = self .image_uri is not None
742
- self ._is_mlflow_model = self ._check_if_input_is_mlflow_model ()
743
- if self ._is_mlflow_model :
744
- logger .warning (
745
- "Support of MLflow format models is experimental and is not intended"
746
- " for production at this moment."
747
- )
748
- self ._initialize_for_mlflow ()
749
- _validate_input_for_mlflow (self .model_server , self .env_vars .get ("MLFLOW_MODEL_FLAVOR" ))
839
+
840
+ self ._handle_mlflow_input ()
750
841
751
842
if isinstance (self .model , str ):
752
843
model_task = None
@@ -836,6 +927,19 @@ def validate(self, model_dir: str) -> Type[bool]:
836
927
837
928
return get_metadata (model_dir )
838
929
930
def set_tracking_arn(self, arn: str):
    """Set tracking server ARN.

    Points MLflow at the given tracking server and records the ARN in the
    model metadata under ``MLFLOW_TRACKING_ARN``.

    Args:
        arn (str): The tracking server ARN.

    Raises:
        ImportError: If ``sagemaker_mlflow`` is not installed.
    """
    # TODO: support native MLflow URIs
    if not importlib.util.find_spec("sagemaker_mlflow"):
        raise ImportError(
            "Unable to import sagemaker_mlflow, check if sagemaker_mlflow is installed"
        )

    # Imported lazily so that mlflow stays an optional dependency.
    import mlflow

    mlflow.set_tracking_uri(arn)
    # The metadata field defaults to None; create the dict on first use so the
    # setter also works on a builder constructed without model metadata.
    if self.model_metadata is None:
        self.model_metadata = {}
    self.model_metadata[MLFLOW_TRACKING_ARN] = arn
942
+
839
943
def _hf_schema_builder_init (self , model_task : str ):
840
944
"""Initialize the schema builder for the given HF_TASK
841
945
0 commit comments