Skip to content

remove s3 output location requirement from hub class init #5081

New issue

Have a question about this project? # for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “#”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? # to your account

Merged
merged 6 commits into from
Mar 12, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
69 changes: 16 additions & 53 deletions src/sagemaker/jumpstart/hub/hub.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,35 +16,25 @@
from datetime import datetime
import logging
from typing import Optional, Dict, List, Any, Union
from botocore import exceptions

from sagemaker.jumpstart.constants import JUMPSTART_MODEL_HUB_NAME
from sagemaker.jumpstart.enums import JumpStartScriptScope
from sagemaker.session import Session

from sagemaker.jumpstart.constants import (
JUMPSTART_LOGGER,
)
from sagemaker.jumpstart.types import (
HubContentType,
)
from sagemaker.jumpstart.filters import Constant, Operator, BooleanValues
from sagemaker.jumpstart.hub.utils import (
get_hub_model_version,
get_info_from_hub_resource_arn,
create_hub_bucket_if_it_does_not_exist,
generate_default_hub_bucket_name,
create_s3_object_reference_from_uri,
construct_hub_arn_from_name,
)

from sagemaker.jumpstart.notebook_utils import (
list_jumpstart_models,
)

from sagemaker.jumpstart.hub.types import (
S3ObjectLocation,
)
from sagemaker.jumpstart.hub.interfaces import (
DescribeHubResponse,
DescribeHubContentResponse,
Expand All @@ -66,8 +56,8 @@ class Hub:
def __init__(
self,
hub_name: str,
sagemaker_session: Session,
bucket_name: Optional[str] = None,
sagemaker_session: Optional[Session] = None,
) -> None:
"""Instantiates a SageMaker ``Hub``.

Expand All @@ -78,41 +68,11 @@ def __init__(
"""
self.hub_name = hub_name
self.region = sagemaker_session.boto_region_name
self.bucket_name = bucket_name
self._sagemaker_session = (
sagemaker_session
or utils.get_default_jumpstart_session_with_user_agent_suffix(is_hub_content=True)
)
self.hub_storage_location = self._generate_hub_storage_location(bucket_name)

def _fetch_hub_bucket_name(self) -> str:
"""Retrieves hub bucket name from Hub config if exists"""
try:
hub_response = self._sagemaker_session.describe_hub(hub_name=self.hub_name)
hub_output_location = hub_response["S3StorageConfig"].get("S3OutputPath")
if hub_output_location:
location = create_s3_object_reference_from_uri(hub_output_location)
return location.bucket
default_bucket_name = generate_default_hub_bucket_name(self._sagemaker_session)
JUMPSTART_LOGGER.warning(
"There is not a Hub bucket associated with %s. Using %s",
self.hub_name,
default_bucket_name,
)
return default_bucket_name
except exceptions.ClientError:
hub_bucket_name = generate_default_hub_bucket_name(self._sagemaker_session)
JUMPSTART_LOGGER.warning(
"There is not a Hub bucket associated with %s. Using %s",
self.hub_name,
hub_bucket_name,
)
return hub_bucket_name

def _generate_hub_storage_location(self, bucket_name: Optional[str] = None) -> None:
"""Generates an ``S3ObjectLocation`` given a Hub name."""
hub_bucket_name = bucket_name or self._fetch_hub_bucket_name()
curr_timestamp = datetime.now().timestamp()
return S3ObjectLocation(bucket=hub_bucket_name, key=f"{self.hub_name}-{curr_timestamp}")

def _get_latest_model_version(self, model_id: str) -> str:
"""Populates the lastest version of a model from specs no matter what is passed.
Expand All @@ -132,19 +92,22 @@ def create(
tags: Optional[str] = None,
) -> Dict[str, str]:
"""Creates a hub with the given description"""
curr_timestamp = datetime.now().timestamp()

create_hub_bucket_if_it_does_not_exist(
self.hub_storage_location.bucket, self._sagemaker_session
)
request = {
"hub_name": self.hub_name,
"hub_description": description,
"hub_display_name": display_name,
"hub_search_keywords": search_keywords,
"tags": tags,
}

return self._sagemaker_session.create_hub(
hub_name=self.hub_name,
hub_description=description,
hub_display_name=display_name,
hub_search_keywords=search_keywords,
s3_storage_config={"S3OutputPath": self.hub_storage_location.get_uri()},
tags=tags,
)
if self.bucket_name:
request["s3_storage_config"] = {
"S3OutputPath": (f"s3://{self.bucket_name}/{self.hub_name}-{curr_timestamp}")
}

return self._sagemaker_session.create_hub(**request)

def describe(self, hub_name: Optional[str] = None) -> DescribeHubResponse:
"""Returns descriptive information about the Hub"""
Expand Down
57 changes: 0 additions & 57 deletions src/sagemaker/jumpstart/hub/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,6 @@
from __future__ import absolute_import
import re
from typing import Optional, List, Any
from sagemaker.jumpstart.hub.types import S3ObjectLocation
from sagemaker.s3_utils import parse_s3_url
from sagemaker.session import Session
from sagemaker.utils import aws_partition
from sagemaker.jumpstart.types import HubContentType, HubArnExtractedInfo
Expand Down Expand Up @@ -139,61 +137,6 @@ def generate_hub_arn_for_init_kwargs(
return hub_arn


def generate_default_hub_bucket_name(
sagemaker_session: Session = constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION,
) -> str:
"""Return the name of the default bucket to use in relevant Amazon SageMaker Hub interactions.

Returns:
str: The name of the default bucket. If the name was not explicitly specified through
the Session or sagemaker_config, the bucket will take the form:
``sagemaker-hubs-{region}-{AWS account ID}``.
"""

region: str = sagemaker_session.boto_region_name
account_id: str = sagemaker_session.account_id()

# TODO: Validate and fast fail

return f"sagemaker-hubs-{region}-{account_id}"


def create_s3_object_reference_from_uri(s3_uri: Optional[str]) -> Optional[S3ObjectLocation]:
"""Utiity to help generate an S3 object reference"""
if not s3_uri:
return None

bucket, key = parse_s3_url(s3_uri)

return S3ObjectLocation(
bucket=bucket,
key=key,
)


def create_hub_bucket_if_it_does_not_exist(
bucket_name: Optional[str] = None,
sagemaker_session: Session = constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION,
) -> str:
"""Creates the default SageMaker Hub bucket if it does not exist.

Returns:
str: The name of the default bucket. Takes the form:
``sagemaker-hubs-{region}-{AWS account ID}``.
"""

region: str = sagemaker_session.boto_region_name
if bucket_name is None:
bucket_name: str = generate_default_hub_bucket_name(sagemaker_session)

sagemaker_session._create_s3_bucket_if_it_does_not_exist(
bucket_name=bucket_name,
region=region,
)

return bucket_name


def is_gated_bucket(bucket_name: str) -> bool:
"""Returns true if the bucket name is the JumpStart gated bucket."""
return bucket_name in constants.JUMPSTART_GATED_BUCKET_NAME_SET
Expand Down
31 changes: 9 additions & 22 deletions tests/unit/sagemaker/jumpstart/hub/test_hub.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@
import pytest
from mock import Mock
from sagemaker.jumpstart.hub.hub import Hub
from sagemaker.jumpstart.hub.types import S3ObjectLocation


REGION = "us-east-1"
Expand Down Expand Up @@ -60,48 +59,34 @@ def test_instantiates(sagemaker_session):


@pytest.mark.parametrize(
("hub_name,hub_description,hub_bucket_name,hub_display_name,hub_search_keywords,tags"),
("hub_name,hub_description,,hub_display_name,hub_search_keywords,tags"),
[
pytest.param("MockHub1", "this is my sagemaker hub", None, None, None, None),
pytest.param("MockHub1", "this is my sagemaker hub", None, None, None),
pytest.param(
"MockHub2",
"this is my sagemaker hub two",
None,
"DisplayMockHub2",
["mock", "hub", "123"],
[{"Key": "tag-key-1", "Value": "tag-value-1"}],
),
],
)
@patch("sagemaker.jumpstart.hub.hub.Hub._generate_hub_storage_location")
def test_create_with_no_bucket_name(
mock_generate_hub_storage_location,
sagemaker_session,
hub_name,
hub_description,
hub_bucket_name,
hub_display_name,
hub_search_keywords,
tags,
):
storage_location = S3ObjectLocation(
"sagemaker-hubs-us-east-1-123456789123", f"{hub_name}-{FAKE_TIME.timestamp()}"
)
mock_generate_hub_storage_location.return_value = storage_location
create_hub = {"HubArn": f"arn:aws:sagemaker:us-east-1:123456789123:hub/{hub_name}"}
sagemaker_session.create_hub = Mock(return_value=create_hub)
sagemaker_session.describe_hub.return_value = {
"S3StorageConfig": {"S3OutputPath": f"s3://{hub_bucket_name}/{storage_location.key}"}
}
hub = Hub(hub_name=hub_name, sagemaker_session=sagemaker_session)
request = {
"hub_name": hub_name,
"hub_description": hub_description,
"hub_display_name": hub_display_name,
"hub_search_keywords": hub_search_keywords,
"s3_storage_config": {
"S3OutputPath": f"s3://sagemaker-hubs-us-east-1-123456789123/{storage_location.key}"
},
"tags": tags,
}
response = hub.create(
Expand All @@ -128,9 +113,9 @@ def test_create_with_no_bucket_name(
),
],
)
@patch("sagemaker.jumpstart.hub.hub.Hub._generate_hub_storage_location")
@patch("sagemaker.jumpstart.hub.hub.datetime")
def test_create_with_bucket_name(
mock_generate_hub_storage_location,
mock_datetime,
sagemaker_session,
hub_name,
hub_description,
Expand All @@ -139,8 +124,8 @@ def test_create_with_bucket_name(
hub_search_keywords,
tags,
):
storage_location = S3ObjectLocation(hub_bucket_name, f"{hub_name}-{FAKE_TIME.timestamp()}")
mock_generate_hub_storage_location.return_value = storage_location
mock_datetime.now.return_value = FAKE_TIME

create_hub = {"HubArn": f"arn:aws:sagemaker:us-east-1:123456789123:hub/{hub_name}"}
sagemaker_session.create_hub = Mock(return_value=create_hub)
hub = Hub(hub_name=hub_name, sagemaker_session=sagemaker_session, bucket_name=hub_bucket_name)
Expand All @@ -149,7 +134,9 @@ def test_create_with_bucket_name(
"hub_description": hub_description,
"hub_display_name": hub_display_name,
"hub_search_keywords": hub_search_keywords,
"s3_storage_config": {"S3OutputPath": f"s3://mock-bucket-123/{storage_location.key}"},
"s3_storage_config": {
"S3OutputPath": f"s3://mock-bucket-123/{hub_name}-{FAKE_TIME.timestamp()}"
},
"tags": tags,
}
response = hub.create(
Expand Down
41 changes: 0 additions & 41 deletions tests/unit/sagemaker/jumpstart/hub/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -173,30 +173,6 @@ def test_generate_hub_arn_for_init_kwargs():
assert utils.generate_hub_arn_for_init_kwargs(hub_arn, None, mock_custom_session) == hub_arn


def test_create_hub_bucket_if_it_does_not_exist_hub_arn():
mock_sagemaker_session = Mock()
mock_sagemaker_session.account_id.return_value = "123456789123"
mock_sagemaker_session.client("sts").get_caller_identity.return_value = {
"Account": "123456789123"
}
hub_arn = "arn:aws:sagemaker:us-west-2:12346789123:hub/my-awesome-hub"
# Mock custom session with custom values
mock_custom_session = Mock()
mock_custom_session.account_id.return_value = "000000000000"
mock_custom_session.boto_region_name = "us-east-2"
mock_sagemaker_session.boto_session.resource("s3").Bucket().creation_date = None
mock_sagemaker_session.boto_region_name = "us-east-1"

bucket_name = "sagemaker-hubs-us-east-1-123456789123"
created_hub_bucket_name = utils.create_hub_bucket_if_it_does_not_exist(
sagemaker_session=mock_sagemaker_session
)

mock_sagemaker_session.boto_session.resource("s3").create_bucketassert_called_once()
assert created_hub_bucket_name == bucket_name
assert utils.generate_hub_arn_for_init_kwargs(hub_arn, None, mock_custom_session) == hub_arn


def test_is_gated_bucket():
assert utils.is_gated_bucket("jumpstart-private-cache-prod-us-west-2") is True

Expand All @@ -207,23 +183,6 @@ def test_is_gated_bucket():
assert utils.is_gated_bucket("") is False


def test_create_hub_bucket_if_it_does_not_exist():
mock_sagemaker_session = Mock()
mock_sagemaker_session.account_id.return_value = "123456789123"
mock_sagemaker_session.client("sts").get_caller_identity.return_value = {
"Account": "123456789123"
}
mock_sagemaker_session.boto_session.resource("s3").Bucket().creation_date = None
mock_sagemaker_session.boto_region_name = "us-east-1"
bucket_name = "sagemaker-hubs-us-east-1-123456789123"
created_hub_bucket_name = utils.create_hub_bucket_if_it_does_not_exist(
sagemaker_session=mock_sagemaker_session
)

mock_sagemaker_session.boto_session.resource("s3").create_bucketassert_called_once()
assert created_hub_bucket_name == bucket_name


@patch("sagemaker.session.Session")
def test_get_hub_model_version_success(mock_session):
hub_name = "test_hub"
Expand Down