Skip to content

Commit 7629e3c

Browse files
ananth102knikure
authored andcommitted
feat: (sagemaker-mlflow) Adding Presigned Url function to SDK (aws#1462) (aws#1477)
* mlflow presigned url changes * addressing design feedback * test changes
1 parent ca40ac8 commit 7629e3c

File tree

6 files changed

+152
-0
lines changed

6 files changed

+152
-0
lines changed

src/sagemaker/mlflow/__init__.py

+12
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,12 @@
1+
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License"). You
4+
# may not use this file except in compliance with the License. A copy of
5+
# the License is located at
6+
#
7+
# http://aws.amazon.com/apache2.0/
8+
#
9+
# or in the "license" file accompanying this file. This file is
10+
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
11+
# ANY KIND, either express or implied. See the License for the specific
12+
# language governing permissions and limitations under the License.
+50
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,50 @@
1+
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License"). You
4+
# may not use this file except in compliance with the License. A copy of
5+
# the License is located at
6+
#
7+
# http://aws.amazon.com/apache2.0/
8+
#
9+
# or in the "license" file accompanying this file. This file is
10+
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
11+
# ANY KIND, either express or implied. See the License for the specific
12+
# language governing permissions and limitations under the License.
13+
14+
15+
"""This module contains code related to the Mlflow Tracking Server."""
16+
17+
from __future__ import absolute_import
18+
from typing import Optional, TYPE_CHECKING
19+
from sagemaker.apiutils import _utils
20+
21+
if TYPE_CHECKING:
22+
from sagemaker import Session
23+
24+
25+
def generate_mlflow_presigned_url(
26+
name: str,
27+
expires_in_seconds: Optional[int] = None,
28+
session_expiration_duration_in_seconds: Optional[int] = None,
29+
sagemaker_session: Optional["Session"] = None,
30+
) -> str:
31+
"""Generate a presigned url to acess the Mlflow UI.
32+
33+
Args:
34+
name (str): Name of the Mlflow Tracking Server
35+
expires_in_seconds (int): Expiration time of the first usage
36+
of the presigned url in seconds.
37+
session_expiration_duration_in_seconds (int): Session duration of the presigned url in
38+
seconds after the first use.
39+
sagemaker_session (sagemaker.session.Session): Session object which
40+
manages interactions with Amazon SageMaker APIs and any other
41+
AWS services needed. If not specified, one is created using the
42+
default AWS configuration chain.
43+
Returns:
44+
(str): Authorized Url to acess the Mlflow UI.
45+
"""
46+
session = sagemaker_session or _utils.default_session()
47+
api_response = session.create_presigned_mlflow_tracking_server_url(
48+
name, expires_in_seconds, session_expiration_duration_in_seconds
49+
)
50+
return api_response["AuthorizedUrl"]

src/sagemaker/session.py

+30
Original file line numberDiff line numberDiff line change
@@ -6705,6 +6705,36 @@ def wait_for_inference_recommendations_job(
67056705
_check_job_status(job_name, desc, "Status")
67066706
return desc
67076707

6708+
def create_presigned_mlflow_tracking_server_url(
6709+
self,
6710+
tracking_server_name: str,
6711+
expires_in_seconds: int = None,
6712+
session_expiration_duration_in_seconds: int = None,
6713+
) -> Dict[str, Any]:
6714+
"""Creates a Presigned Url to acess the Mlflow UI.
6715+
6716+
Args:
6717+
tracking_server_name (str): Name of the Mlflow Tracking Server.
6718+
expires_in_seconds (int): Expiration duration of the URL.
6719+
session_expiration_duration_in_seconds (int): Session duration of the URL.
6720+
Returns:
6721+
(dict): Return value from the ``CreatePresignedMlflowTrackingServerUrl`` API.
6722+
6723+
"""
6724+
6725+
create_presigned_url_args = {"TrackingServerName": tracking_server_name}
6726+
if expires_in_seconds is not None:
6727+
create_presigned_url_args["ExpiresInSeconds"] = expires_in_seconds
6728+
6729+
if session_expiration_duration_in_seconds is not None:
6730+
create_presigned_url_args["SessionExpirationDurationInSeconds"] = (
6731+
session_expiration_duration_in_seconds
6732+
)
6733+
6734+
return self.sagemaker_client.create_presigned_mlflow_tracking_server_url(
6735+
**create_presigned_url_args
6736+
)
6737+
67086738

67096739
def get_model_package_args(
67106740
content_types=None,

tests/unit/sagemaker/mlflow/__init__.py

Whitespace-only changes.
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,42 @@
1+
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License"). You
4+
# may not use this file except in compliance with the License. A copy of
5+
# the License is located at
6+
#
7+
# http://aws.amazon.com/apache2.0/
8+
#
9+
# or in the "license" file accompanying this file. This file is
10+
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
11+
# ANY KIND, either express or implied. See the License for the specific
12+
# language governing permissions and limitations under the License.
13+
14+
from __future__ import absolute_import
15+
from sagemaker.mlflow.tracking_server import generate_mlflow_presigned_url
16+
17+
18+
def test_generate_presigned_url(sagemaker_session):
19+
client = sagemaker_session.sagemaker_client
20+
client.create_presigned_mlflow_tracking_server_url.return_value = {
21+
"AuthorizedUrl": "https://t-wo.example.com",
22+
}
23+
url = generate_mlflow_presigned_url(
24+
"w",
25+
expires_in_seconds=10,
26+
session_expiration_duration_in_seconds=5,
27+
sagemaker_session=sagemaker_session,
28+
)
29+
client.create_presigned_mlflow_tracking_server_url.assert_called_with(
30+
TrackingServerName="w", ExpiresInSeconds=10, SessionExpirationDurationInSeconds=5
31+
)
32+
assert url == "https://t-wo.example.com"
33+
34+
35+
def test_generate_presigned_url_minimal(sagemaker_session):
36+
client = sagemaker_session.sagemaker_client
37+
client.create_presigned_mlflow_tracking_server_url.return_value = {
38+
"AuthorizedUrl": "https://t-wo.example.com",
39+
}
40+
url = generate_mlflow_presigned_url("w", sagemaker_session=sagemaker_session)
41+
client.create_presigned_mlflow_tracking_server_url.assert_called_with(TrackingServerName="w")
42+
assert url == "https://t-wo.example.com"

tests/unit/test_session.py

+18
Original file line numberDiff line numberDiff line change
@@ -6263,6 +6263,24 @@ def test_create_inference_recommendations_job_propogate_other_exception(
62636263
assert "AccessDeniedException" in str(error)
62646264

62656265

6266+
def test_create_presigned_mlflow_tracking_server_url(sagemaker_session):
6267+
sagemaker_session.create_presigned_mlflow_tracking_server_url("ts", 1, 2)
6268+
assert (
6269+
sagemaker_session.sagemaker_client.create_presigned_mlflow_tracking_server_url.called_with(
6270+
TrackingServerName="ts", ExpiresInSeconds=1, SessionExpirationDurationInSeconds=2
6271+
)
6272+
)
6273+
6274+
6275+
def test_create_presigned_mlflow_tracking_server_url_minimal(sagemaker_session):
6276+
sagemaker_session.create_presigned_mlflow_tracking_server_url("ts")
6277+
assert (
6278+
sagemaker_session.sagemaker_client.create_presigned_mlflow_tracking_server_url.called_with(
6279+
TrackingServerName="ts"
6280+
)
6281+
)
6282+
6283+
62666284
DEFAULT_LOG_EVENTS_INFERENCE_RECOMMENDER = [
62676285
MockBotoException("ResourceNotFoundException"),
62686286
{"nextForwardToken": None, "events": [{"timestamp": 1, "message": "hi there #1"}]},

0 commit comments

Comments
 (0)