Skip to content

Commit 1a016b2

Browse files
ananth102grenmesterJacky Leejiapinwbenieric
authored
feat(sagemaker-mlflow): New features for SageMaker MLflow (#4744)
* feat: add support for mlflow inputs (#1441) * feat: add support for mlflow inputs * fix: typo * fix: doc * fix: S3 regex * fix: refactor * fix: refactor typo * fix: pylint * fix: pylint * fix: black and pylint --------- Co-authored-by: Jacky Lee <drjacky@amazon.com> * fix: lineage tracking bug (#1447) * fix: lineage bug * fix: lineage * fix: add validation for tracking ARN input with MLflow input type * fix: bug * fix: unit tests * fix: mock * fix: args --------- Co-authored-by: Jacky Lee <drjacky@amazon.com> * [Fix] regex for RunId to handle empty artifact path and change mlflow plugin name (#1455) * [Fix] run id regex pattern such that empty artifact path is handled * Change mlflow plugin name as per legal team requirement * Update describe_mlflow_tracking_server call to align with api changes (#1466) * feat: (sagemaker-mlflow) Adding Presigned Url function to SDK (#1462) (#1477) * mlflow presigned url changes * addressing design feedback * test changes * change: mlflow plugin name (#1489) Co-authored-by: Jacky Lee <drjacky@amazon.com> --------- Co-authored-by: Jacky Lee <dr.jackylee@gmail.com> Co-authored-by: Jacky Lee <drjacky@amazon.com> Co-authored-by: jiapinw <95885824+jiapinw@users.noreply.github.com> Co-authored-by: Erick Benitez-Ramos <141277478+benieric@users.noreply.github.com>
1 parent 3243d3f commit 1a016b2

File tree

15 files changed

+598
-92
lines changed

15 files changed

+598
-92
lines changed

requirements/extras/test_requirements.txt

+1
Original file line numberDiff line numberDiff line change
@@ -37,3 +37,4 @@ nbformat>=5.9,<6
3737
accelerate>=0.24.1,<=0.27.0
3838
schema==0.7.5
3939
tensorflow>=2.1,<=2.16
40+
mlflow>=2.12.2,<2.13

src/sagemaker/mlflow/__init__.py

+12
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,12 @@
1+
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License"). You
4+
# may not use this file except in compliance with the License. A copy of
5+
# the License is located at
6+
#
7+
# http://aws.amazon.com/apache2.0/
8+
#
9+
# or in the "license" file accompanying this file. This file is
10+
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
11+
# ANY KIND, either express or implied. See the License for the specific
12+
# language governing permissions and limitations under the License.
+50
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,50 @@
1+
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License"). You
4+
# may not use this file except in compliance with the License. A copy of
5+
# the License is located at
6+
#
7+
# http://aws.amazon.com/apache2.0/
8+
#
9+
# or in the "license" file accompanying this file. This file is
10+
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
11+
# ANY KIND, either express or implied. See the License for the specific
12+
# language governing permissions and limitations under the License.
13+
14+
15+
"""This module contains code related to the Mlflow Tracking Server."""
16+
17+
from __future__ import absolute_import
18+
from typing import Optional, TYPE_CHECKING
19+
from sagemaker.apiutils import _utils
20+
21+
if TYPE_CHECKING:
22+
from sagemaker import Session
23+
24+
25+
def generate_mlflow_presigned_url(
26+
name: str,
27+
expires_in_seconds: Optional[int] = None,
28+
session_expiration_duration_in_seconds: Optional[int] = None,
29+
sagemaker_session: Optional["Session"] = None,
30+
) -> str:
31+
"""Generate a presigned url to access the Mlflow UI.
32+
33+
Args:
34+
name (str): Name of the Mlflow Tracking Server
35+
expires_in_seconds (int): Expiration time of the first usage
36+
of the presigned url in seconds.
37+
session_expiration_duration_in_seconds (int): Session duration of the presigned url in
38+
seconds after the first use.
39+
sagemaker_session (sagemaker.session.Session): Session object which
40+
manages interactions with Amazon SageMaker APIs and any other
41+
AWS services needed. If not specified, one is created using the
42+
default AWS configuration chain.
43+
Returns:
44+
(str): Authorized Url to access the Mlflow UI.
45+
"""
46+
session = sagemaker_session or _utils.default_session()
47+
api_response = session.create_presigned_mlflow_tracking_server_url(
48+
name, expires_in_seconds, session_expiration_duration_in_seconds
49+
)
50+
return api_response["AuthorizedUrl"]

src/sagemaker/serve/builder/model_builder.py

+131-27
Original file line numberDiff line numberDiff line change
@@ -12,11 +12,14 @@
1212
# language governing permissions and limitations under the License.
1313
"""Holds the ModelBuilder class and the ModelServer enum."""
1414
from __future__ import absolute_import
15+
16+
import importlib.util
1517
import uuid
1618
from typing import Any, Type, List, Dict, Optional, Union
1719
from dataclasses import dataclass, field
1820
import logging
1921
import os
22+
import re
2023

2124
from pathlib import Path
2225

@@ -43,12 +46,15 @@
4346
from sagemaker.predictor import Predictor
4447
from sagemaker.serve.model_format.mlflow.constants import (
4548
MLFLOW_MODEL_PATH,
49+
MLFLOW_TRACKING_ARN,
50+
MLFLOW_RUN_ID_REGEX,
51+
MLFLOW_REGISTRY_PATH_REGEX,
52+
MODEL_PACKAGE_ARN_REGEX,
4653
MLFLOW_METADATA_FILE,
4754
MLFLOW_PIP_DEPENDENCY_FILE,
4855
)
4956
from sagemaker.serve.model_format.mlflow.utils import (
5057
_get_default_model_server_for_mlflow,
51-
_mlflow_input_is_local_path,
5258
_download_s3_artifacts,
5359
_select_container_for_mlflow_model,
5460
_generate_mlflow_artifact_path,
@@ -276,8 +282,9 @@ class ModelBuilder(Triton, DJL, JumpStart, TGI, Transformers, TensorflowServing,
276282
default=None,
277283
metadata={
278284
"help": "Define the model metadata to override, currently supports `HF_TASK`, "
279-
"`MLFLOW_MODEL_PATH`. HF_TASK should be set for new models without task metadata in "
280-
"the Hub, Adding unsupported task types will throw an exception"
285+
"`MLFLOW_MODEL_PATH`, and `MLFLOW_TRACKING_ARN`. HF_TASK should be set for new "
286+
"models without task metadata in the Hub, Adding unsupported task types will "
287+
"throw an exception"
281288
},
282289
)
283290

@@ -502,6 +509,7 @@ def _model_builder_register_wrapper(self, *args, **kwargs):
502509
mlflow_model_path=self.model_metadata[MLFLOW_MODEL_PATH],
503510
s3_upload_path=self.s3_upload_path,
504511
sagemaker_session=self.sagemaker_session,
512+
tracking_server_arn=self.model_metadata.get(MLFLOW_TRACKING_ARN),
505513
)
506514
return new_model_package
507515

@@ -572,6 +580,7 @@ def _model_builder_deploy_wrapper(
572580
mlflow_model_path=self.model_metadata[MLFLOW_MODEL_PATH],
573581
s3_upload_path=self.s3_upload_path,
574582
sagemaker_session=self.sagemaker_session,
583+
tracking_server_arn=self.model_metadata.get(MLFLOW_TRACKING_ARN),
575584
)
576585
return predictor
577586

@@ -625,11 +634,30 @@ def wrapper(*args, **kwargs):
625634

626635
return wrapper
627636

628-
def _check_if_input_is_mlflow_model(self) -> bool:
629-
"""Checks whether an MLmodel file exists in the given directory.
637+
def _handle_mlflow_input(self):
638+
"""Check whether an MLflow model is present and handle accordingly"""
639+
self._is_mlflow_model = self._has_mlflow_arguments()
640+
if not self._is_mlflow_model:
641+
return
642+
643+
mlflow_model_path = self.model_metadata.get(MLFLOW_MODEL_PATH)
644+
artifact_path = self._get_artifact_path(mlflow_model_path)
645+
if not self._mlflow_metadata_exists(artifact_path):
646+
logger.info(
647+
"MLflow model metadata not detected in %s. ModelBuilder is not "
648+
"handling MLflow model input",
649+
mlflow_model_path,
650+
)
651+
return
652+
653+
self._initialize_for_mlflow(artifact_path)
654+
_validate_input_for_mlflow(self.model_server, self.env_vars.get("MLFLOW_MODEL_FLAVOR"))
655+
656+
def _has_mlflow_arguments(self) -> bool:
657+
"""Check whether MLflow model arguments are present
630658
631659
Returns:
632-
bool: True if the MLmodel file exists, False otherwise.
660+
bool: True if MLflow arguments are present, False otherwise.
633661
"""
634662
if self.inference_spec or self.model:
635663
logger.info(
@@ -644,16 +672,82 @@ def _check_if_input_is_mlflow_model(self) -> bool:
644672
)
645673
return False
646674

647-
path = self.model_metadata.get(MLFLOW_MODEL_PATH)
648-
if not path:
675+
mlflow_model_path = self.model_metadata.get(MLFLOW_MODEL_PATH)
676+
if not mlflow_model_path:
649677
logger.info(
650678
"%s is not provided in ModelMetadata. ModelBuilder is not handling MLflow model "
651679
"input",
652680
MLFLOW_MODEL_PATH,
653681
)
654682
return False
655683

656-
# Check for S3 path
684+
return True
685+
686+
def _get_artifact_path(self, mlflow_model_path: str) -> str:
687+
"""Retrieves the model artifact location given the Mlflow model input.
688+
689+
Args:
690+
mlflow_model_path (str): The MLflow model path input.
691+
692+
Returns:
693+
str: The path to the model artifact.
694+
"""
695+
if (is_run_id_type := re.match(MLFLOW_RUN_ID_REGEX, mlflow_model_path)) or re.match(
696+
MLFLOW_REGISTRY_PATH_REGEX, mlflow_model_path
697+
):
698+
mlflow_tracking_arn = self.model_metadata.get(MLFLOW_TRACKING_ARN)
699+
if not mlflow_tracking_arn:
700+
raise ValueError(
701+
"%s is not provided in ModelMetadata or through set_tracking_arn "
702+
"but MLflow model path was provided." % MLFLOW_TRACKING_ARN,
703+
)
704+
705+
if not importlib.util.find_spec("sagemaker_mlflow"):
706+
raise ImportError(
707+
"Unable to import sagemaker_mlflow, check if sagemaker_mlflow is installed"
708+
)
709+
710+
import mlflow
711+
712+
mlflow.set_tracking_uri(mlflow_tracking_arn)
713+
if is_run_id_type:
714+
_, run_id, model_path = mlflow_model_path.split("/", 2)
715+
artifact_uri = mlflow.get_run(run_id).info.artifact_uri
716+
if not artifact_uri.endswith("/"):
717+
artifact_uri += "/"
718+
return artifact_uri + model_path
719+
720+
mlflow_client = mlflow.MlflowClient()
721+
if not mlflow_model_path.endswith("/"):
722+
mlflow_model_path += "/"
723+
724+
if "@" in mlflow_model_path:
725+
_, model_name_and_alias, artifact_uri = mlflow_model_path.split("/", 2)
726+
model_name, model_alias = model_name_and_alias.split("@")
727+
model_metadata = mlflow_client.get_model_version_by_alias(model_name, model_alias)
728+
else:
729+
_, model_name, model_version, artifact_uri = mlflow_model_path.split("/", 3)
730+
model_metadata = mlflow_client.get_model_version(model_name, model_version)
731+
732+
source = model_metadata.source
733+
if not source.endswith("/"):
734+
source += "/"
735+
return source + artifact_uri
736+
737+
if re.match(MODEL_PACKAGE_ARN_REGEX, mlflow_model_path):
738+
model_package = self.sagemaker_session.sagemaker_client.describe_model_package(
739+
ModelPackageName=mlflow_model_path
740+
)
741+
return model_package["SourceUri"]
742+
743+
return mlflow_model_path
744+
745+
def _mlflow_metadata_exists(self, path: str) -> bool:
746+
"""Checks whether an MLmodel file exists in the given directory.
747+
748+
Returns:
749+
bool: True if the MLmodel file exists, False otherwise.
750+
"""
657751
if path.startswith("s3://"):
658752
s3_downloader = S3Downloader()
659753
if not path.endswith("/"):
@@ -665,17 +759,18 @@ def _check_if_input_is_mlflow_model(self) -> bool:
665759
file_path = os.path.join(path, MLFLOW_METADATA_FILE)
666760
return os.path.isfile(file_path)
667761

668-
def _initialize_for_mlflow(self) -> None:
669-
"""Initialize mlflow model artifacts, image uri and model server."""
670-
mlflow_path = self.model_metadata.get(MLFLOW_MODEL_PATH)
671-
if not _mlflow_input_is_local_path(mlflow_path):
672-
# TODO: extend to package arn, run id and etc.
673-
logger.info(
674-
"Start downloading model artifacts from %s to %s", mlflow_path, self.model_path
675-
)
676-
_download_s3_artifacts(mlflow_path, self.model_path, self.sagemaker_session)
762+
def _initialize_for_mlflow(self, artifact_path: str) -> None:
763+
"""Initialize mlflow model artifacts, image uri and model server.
764+
765+
Args:
766+
artifact_path (str): The path to the artifact store.
767+
"""
768+
if artifact_path.startswith("s3://"):
769+
_download_s3_artifacts(artifact_path, self.model_path, self.sagemaker_session)
770+
elif os.path.exists(artifact_path):
771+
_copy_directory_contents(artifact_path, self.model_path)
677772
else:
678-
_copy_directory_contents(mlflow_path, self.model_path)
773+
raise ValueError("Invalid path: %s" % artifact_path)
679774
mlflow_model_metadata_path = _generate_mlflow_artifact_path(
680775
self.model_path, MLFLOW_METADATA_FILE
681776
)
@@ -728,6 +823,8 @@ def build( # pylint: disable=R0911
728823
self.role_arn = role_arn
729824
self.sagemaker_session = sagemaker_session or Session()
730825

826+
self.sagemaker_session.settings._local_download_dir = self.model_path
827+
731828
# https://github.com/boto/botocore/blob/develop/botocore/useragent.py#L258
732829
# decorate to_string() due to
733830
# https://github.com/boto/botocore/blob/develop/botocore/client.py#L1014-L1015
@@ -739,14 +836,8 @@ def build( # pylint: disable=R0911
739836
self.serve_settings = self._get_serve_setting()
740837

741838
self._is_custom_image_uri = self.image_uri is not None
742-
self._is_mlflow_model = self._check_if_input_is_mlflow_model()
743-
if self._is_mlflow_model:
744-
logger.warning(
745-
"Support of MLflow format models is experimental and is not intended"
746-
" for production at this moment."
747-
)
748-
self._initialize_for_mlflow()
749-
_validate_input_for_mlflow(self.model_server, self.env_vars.get("MLFLOW_MODEL_FLAVOR"))
839+
840+
self._handle_mlflow_input()
750841

751842
if isinstance(self.model, str):
752843
model_task = None
@@ -836,6 +927,19 @@ def validate(self, model_dir: str) -> Type[bool]:
836927

837928
return get_metadata(model_dir)
838929

930+
def set_tracking_arn(self, arn: str):
931+
"""Set tracking server ARN"""
932+
# TODO: support native MLflow URIs
933+
if importlib.util.find_spec("sagemaker_mlflow"):
934+
import mlflow
935+
936+
mlflow.set_tracking_uri(arn)
937+
self.model_metadata[MLFLOW_TRACKING_ARN] = arn
938+
else:
939+
raise ImportError(
940+
"Unable to import sagemaker_mlflow, check if sagemaker_mlflow is installed"
941+
)
942+
839943
def _hf_schema_builder_init(self, model_task: str):
840944
"""Initialize the schema builder for the given HF_TASK
841945

src/sagemaker/serve/model_format/mlflow/constants.py

+3-2
Original file line numberDiff line numberDiff line change
@@ -22,9 +22,10 @@
2222
MODEL_PACKAGE_ARN_REGEX = (
2323
r"^arn:aws:sagemaker:[a-z0-9\-]+:[0-9]{12}:model-package\/(.*?)(?:/(\d+))?$"
2424
)
25-
MLFLOW_RUN_ID_REGEX = r"^runs:/[a-zA-Z0-9]+(/[a-zA-Z0-9]+)*$"
26-
MLFLOW_REGISTRY_PATH_REGEX = r"^models:/[a-zA-Z0-9\-_\.]+(/[0-9]+)*$"
25+
MLFLOW_RUN_ID_REGEX = r"^runs:/[a-zA-Z0-9]+(/[a-zA-Z0-9\-_\.]*)+$"
26+
MLFLOW_REGISTRY_PATH_REGEX = r"^models:/[a-zA-Z0-9\-_\.]+[@/]?[a-zA-Z0-9\-_\.][/a-zA-Z0-9\-_\.]*$"
2727
S3_PATH_REGEX = r"^s3:\/\/[a-zA-Z0-9\-_\.]+(?:\/[a-zA-Z0-9\-_\/\.]*)?$"
28+
MLFLOW_TRACKING_ARN = "MLFLOW_TRACKING_ARN"
2829
MLFLOW_MODEL_PATH = "MLFLOW_MODEL_PATH"
2930
MLFLOW_METADATA_FILE = "MLmodel"
3031
MLFLOW_PIP_DEPENDENCY_FILE = "requirements.txt"

src/sagemaker/serve/model_format/mlflow/utils.py

-22
Original file line numberDiff line numberDiff line change
@@ -227,28 +227,6 @@ def _get_python_version_from_parsed_mlflow_model_file(
227227
raise ValueError(f"{MLFLOW_PYFUNC} cannot be found in MLmodel file.")
228228

229229

230-
def _mlflow_input_is_local_path(model_path: str) -> bool:
231-
"""Checks if the given model_path is a local filesystem path.
232-
233-
Args:
234-
- model_path (str): The model path to check.
235-
236-
Returns:
237-
- bool: True if model_path is a local path, False otherwise.
238-
"""
239-
if model_path.startswith("s3://"):
240-
return False
241-
242-
if "/runs/" in model_path or model_path.startswith("runs:"):
243-
return False
244-
245-
# Check if it's not a local file path
246-
if not os.path.exists(model_path):
247-
return False
248-
249-
return True
250-
251-
252230
def _download_s3_artifacts(s3_path: str, dst_path: str, session: Session) -> None:
253231
"""Downloads all artifacts from a specified S3 path to a local destination path.
254232

src/sagemaker/serve/utils/lineage_constants.py

+2
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,8 @@
1616

1717
LINEAGE_POLLER_INTERVAL_SECS = 15
1818
LINEAGE_POLLER_MAX_TIMEOUT_SECS = 120
19+
TRACKING_SERVER_ARN_REGEX = r"arn:(.*?):sagemaker:(.*?):(.*?):mlflow-tracking-server/(.*?)$"
20+
TRACKING_SERVER_CREATION_TIME_FORMAT = "%Y-%m-%dT%H:%M:%S.%fZ"
1921
MODEL_BUILDER_MLFLOW_MODEL_PATH_LINEAGE_ARTIFACT_TYPE = "ModelBuilderInputModelData"
2022
MLFLOW_S3_PATH = "S3"
2123
MLFLOW_MODEL_PACKAGE_PATH = "ModelPackage"

0 commit comments

Comments
 (0)