-
Notifications
You must be signed in to change notification settings - Fork 34
/
Copy pathtest_model_repack.py
63 lines (54 loc) · 3.34 KB
/
test_model_repack.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
import logging
import os
from pathlib import Path
import pytest
import sagemaker.config
from sagemaker import Model
from sagemaker_ssh_helper.wrapper import SSHEstimatorWrapper
# Dedicated logger for this test module, so its records can be filtered by name.
_LOGGER_NAME = 'sagemaker-ssh-helper:test_model_repack'
logger = logging.getLogger(_LOGGER_NAME)
@pytest.mark.skipif(os.getenv('PYTEST_IGNORE_SKIPS', "false") == "false",
                    reason="Not working yet")
def test_model_repacking_from_scratch():
    """Check that a model created without `model_data` exposes repacked model data.

    The test is skipped unless the PYTEST_IGNORE_SKIPS environment variable is
    set to something other than "false", because the final assertion is known
    to fail (see the FIXME notes below).
    """
    # Resolve the script location up front instead of binding it with a walrus
    # inside one keyword argument and reading it in the next one.
    entry_point_path = Path('source_dir/inference_hf_accelerate/inference_ssh.py')

    model = Model(
        image_uri="763104351884.dkr.ecr.eu-west-1.amazonaws.com/djl-inference:0.20.0-deepspeed0.7.5-cu116",
        role="arn:aws:iam::555555555555:role/service-role/AmazonSageMaker-ExecutionRole-Example",
        entry_point=entry_point_path.name,
        source_dir=str(entry_point_path.parent),
        dependencies=[SSHEstimatorWrapper.dependency_dir()],
        # FIXME: otherwise AttributeError: 'NoneType' object has no attribute 'config'
        sagemaker_session=sagemaker.Session(),
    )

    # Only the side effect on `model` matters here; the container definition is discarded.
    _ = model.prepare_container_def(instance_type='ml.m5.xlarge')

    # Log through the module-level logger rather than the root logger.
    logger.info("Model data: %s", model.repacked_model_data)

    assert model.repacked_model_data is not None  # FIXME: not working
    # FIXME: SAGEMAKER_SUBMIT_DIRECTORY = file://source_dir/inference_hf_accelerate instead of /opt/ml/model/code
@pytest.mark.skipif(os.getenv('PYTEST_IGNORE_SKIPS', "false") == "false",
                    reason="Manual test so far, because needs existing model")
def test_model_repacking_with_existing_model():
    """Check that a model created from an existing S3 archive exposes repacked model data.

    The test is skipped unless the PYTEST_IGNORE_SKIPS environment variable is
    set to something other than "false", because it needs the model archive
    referenced by `model_data` to already exist.
    """
    # Resolve the script location up front instead of binding it with a walrus
    # inside one keyword argument and reading it in the next one.
    entry_point_path = Path('source_dir/inference_hf_accelerate/inference_ssh.py')

    model = Model(
        model_data="s3://sagemaker-eu-west-1-169264033083/data/acc_model.tar.gz",
        image_uri="763104351884.dkr.ecr.eu-west-1.amazonaws.com/djl-inference:0.20.0-deepspeed0.7.5-cu116",
        role="arn:aws:iam::555555555555:role/service-role/AmazonSageMaker-ExecutionRole-Example",
        entry_point=entry_point_path.name,
        source_dir=str(entry_point_path.parent),
        dependencies=[SSHEstimatorWrapper.dependency_dir()],
        # FIXME: otherwise AttributeError: 'NoneType' object has no attribute 'config'
        sagemaker_session=sagemaker.Session(),
    )

    # Only the side effect on `model` matters here; the container definition is discarded.
    _ = model.prepare_container_def(instance_type='ml.m5.xlarge')

    # Log through the module-level logger rather than the root logger.
    logger.info("Model data: %s", model.repacked_model_data)

    assert model.repacked_model_data is not None
@pytest.mark.skipif(os.getenv('PYTEST_IGNORE_SKIPS', "false") == "false",
                    reason="Not working yet")
def test_model_repacking_default_entry_point_with_existing_model():
    """Check repacking of an existing S3 model archive when no `entry_point` is passed.

    The test is skipped unless the PYTEST_IGNORE_SKIPS environment variable is
    set to something other than "false", because the final assertion is known
    to fail (see the FIXME notes below).
    """
    model = Model(
        model_data="s3://sagemaker-eu-west-1-169264033083/data/acc_model.tar.gz",
        image_uri="763104351884.dkr.ecr.eu-west-1.amazonaws.com/djl-inference:0.20.0-deepspeed0.7.5-cu116",
        role="arn:aws:iam::555555555555:role/service-role/AmazonSageMaker-ExecutionRole-Example",
        source_dir=str(Path('source_dir/inference_hf_accelerate/')),
        # entry_point is omitted on purpose: it is specified in the DJL serving.properties file
        dependencies=[SSHEstimatorWrapper.dependency_dir()],
        # FIXME: otherwise AttributeError: 'NoneType' object has no attribute 'config'
        sagemaker_session=sagemaker.Session(),
    )

    # Only the side effect on `model` matters here; the container definition is discarded.
    _ = model.prepare_container_def(instance_type='ml.m5.xlarge')

    # Log through the module-level logger rather than the root logger.
    logger.info("Model data: %s", model.repacked_model_data)

    assert model.repacked_model_data is not None  # FIXME: not working
    # FIXME: SAGEMAKER_SUBMIT_DIRECTORY = file://source_dir/inference_hf_accelerate instead of /opt/ml/model/code