From c41c8c5176afddaf3a41d5b19e82af25a692c01c Mon Sep 17 00:00:00 2001 From: Mattt Zmuda Date: Wed, 1 Nov 2023 14:30:38 -0700 Subject: [PATCH 1/3] Add support for hardware.list endpoint Signed-off-by: Mattt Zmuda --- replicate/__init__.py | 3 +- replicate/client.py | 22 ++++++--- replicate/hardware.py | 50 ++++++++++++++++++++ tests/cassettes/hardware-list.yaml | 74 ++++++++++++++++++++++++++++++ tests/test_hardware.py | 15 ++++++ 5 files changed, 156 insertions(+), 8 deletions(-) create mode 100644 replicate/hardware.py create mode 100644 tests/cassettes/hardware-list.yaml create mode 100644 tests/test_hardware.py diff --git a/replicate/__init__.py b/replicate/__init__.py index d7cfeb8d..f03a4519 100644 --- a/replicate/__init__.py +++ b/replicate/__init__.py @@ -2,7 +2,8 @@ default_client = Client() run = default_client.run +hardware = default_client.hardware +deployments = default_client.deployments models = default_client.models predictions = default_client.predictions trainings = default_client.trainings -deployments = default_client.deployments diff --git a/replicate/client.py b/replicate/client.py index dccc1433..0f2277c2 100644 --- a/replicate/client.py +++ b/replicate/client.py @@ -17,6 +17,7 @@ from replicate.__about__ import __version__ from replicate.deployment import DeploymentCollection from replicate.exceptions import ModelError, ReplicateError +from replicate.hardware import HardwareCollection from replicate.model import ModelCollection from replicate.prediction import PredictionCollection from replicate.schema import make_schema_backwards_compatible @@ -83,6 +84,20 @@ def _request(self, method: str, path: str, **kwargs) -> httpx.Response: return resp + @property + def deployments(self) -> DeploymentCollection: + """ + Namespace for operations related to deployments. + """ + return DeploymentCollection(client=self) + + @property + def hardware(self) -> HardwareCollection: + """ + Namespace for operations related to hardware. 
+ """ + return HardwareCollection(client=self) + @property def models(self) -> ModelCollection: """ @@ -104,13 +119,6 @@ def trainings(self) -> TrainingCollection: """ return TrainingCollection(client=self) - @property - def deployments(self) -> DeploymentCollection: - """ - Namespace for operations related to deployments. - """ - return DeploymentCollection(client=self) - def run(self, model_version: str, **kwargs) -> Union[Any, Iterator[Any]]: # noqa: ANN401 """ Run a model and wait for its output. diff --git a/replicate/hardware.py b/replicate/hardware.py new file mode 100644 index 00000000..5feee199 --- /dev/null +++ b/replicate/hardware.py @@ -0,0 +1,50 @@ +from typing import Dict, List, Union + +from replicate.base_model import BaseModel +from replicate.collection import Collection + + +class Hardware(BaseModel): + """ + Hardware for running a model on Replicate. + """ + + sku: str + """ + The SKU of the hardware. + """ + + name: str + """ + The name of the hardware. + """ + + +class HardwareCollection(Collection): + """ + Namespace for operations related to hardware. + """ + + model = Hardware + + def list(self) -> List[Hardware]: + """ + List all available hardware. + + Returns: + A list of hardware. 
+ """ + + resp = self._client._request("GET", "/v1/hardware") + hardware = resp.json() + return [self._prepare_model(obj) for obj in hardware] + + def _prepare_model(self, attrs: Union[Hardware, Dict]) -> Hardware: + if isinstance(attrs, BaseModel): + attrs.id = attrs.sku + elif isinstance(attrs, dict): + attrs["id"] = attrs["sku"] + + hardware = super()._prepare_model(attrs) + + return hardware diff --git a/tests/cassettes/hardware-list.yaml b/tests/cassettes/hardware-list.yaml new file mode 100644 index 00000000..11e2988e --- /dev/null +++ b/tests/cassettes/hardware-list.yaml @@ -0,0 +1,74 @@ +interactions: +- request: + body: '' + headers: + accept: + - '*/*' + accept-encoding: + - gzip, deflate + connection: + - keep-alive + host: + - api.replicate.com + user-agent: + - replicate-python/0.15.5 + method: GET + uri: https://api.replicate.com/v1/hardware + response: + content: '[{"sku":"cpu","name":"CPU"},{"sku":"gpu-t4","name":"Nvidia T4 GPU"},{"sku":"gpu-a40-small","name":"Nvidia + A40 GPU"},{"sku":"gpu-a40-large","name":"Nvidia A40 (Large) GPU"}]' + headers: + CF-Cache-Status: + - DYNAMIC + CF-RAY: + - 81fbfed29fe1c58a-SEA + Connection: + - keep-alive + Content-Encoding: + - gzip + Content-Type: + - application/json + Date: + - Thu, 02 Nov 2023 11:21:41 GMT + Server: + - cloudflare + Strict-Transport-Security: + - max-age=15552000 + Transfer-Encoding: + - chunked + allow: + - OPTIONS, GET + content-security-policy-report-only: + - 'connect-src ''report-sample'' ''self'' https://replicate.delivery https://*.replicate.delivery + https://*.rudderlabs.com https://*.rudderstack.com https://*.mux.com https://*.sentry.io; + worker-src ''none''; script-src ''report-sample'' ''self'' https://cdn.rudderlabs.com/v1.1/rudder-analytics.min.js; + style-src ''report-sample'' ''self'' ''unsafe-inline''; font-src ''report-sample'' + ''self'' data:; img-src ''report-sample'' ''self'' data: https://replicate.delivery + https://*.replicate.delivery https://*.githubusercontent.com 
https://github.com; + default-src ''self''; media-src ''report-sample'' ''self'' https://replicate.delivery + https://*.replicate.delivery https://*.mux.com https://*.sentry.io; report-uri' + cross-origin-opener-policy: + - same-origin + nel: + - '{"report_to":"heroku-nel","max_age":3600,"success_fraction":0.005,"failure_fraction":0.05,"response_headers":["Via"]}' + ratelimit-remaining: + - '2999' + ratelimit-reset: + - '1' + referrer-policy: + - same-origin + report-to: + - '{"group":"heroku-nel","max_age":3600,"endpoints":[{"url":"https://nel.heroku.com/reports?ts=1698924101&sid=1b10b0ff-8a76-4548-befa-353fc6c6c045&s=lMEXEYwO4dAJOZgt0b6ihblK5I4BwDBadrW6odcdYW8%3D"}]}' + reporting-endpoints: + - heroku-nel=https://nel.heroku.com/reports?ts=1698924101&sid=1b10b0ff-8a76-4548-befa-353fc6c6c045&s=lMEXEYwO4dAJOZgt0b6ihblK5I4BwDBadrW6odcdYW8%3D + vary: + - Cookie, origin + via: + - 1.1 vegur, 1.1 google + x-content-type-options: + - nosniff + x-frame-options: + - DENY + http_version: HTTP/1.1 + status_code: 200 +version: 1 diff --git a/tests/test_hardware.py b/tests/test_hardware.py new file mode 100644 index 00000000..828a368d --- /dev/null +++ b/tests/test_hardware.py @@ -0,0 +1,15 @@ +import httpx +import pytest +import respx + +import replicate + + +@pytest.mark.vcr("hardware-list.yaml") +@pytest.mark.asyncio +async def test_hardware_list(mock_replicate_api_token): + hardware = replicate.hardware.list() + + assert hardware is not None + assert isinstance(hardware, list) + assert len(hardware) > 0 From fed75e2e61851cf7570147e2d1a5f33bad3b9316 Mon Sep 17 00:00:00 2001 From: Mattt Zmuda Date: Thu, 2 Nov 2023 05:28:49 -0700 Subject: [PATCH 2/3] Add support for models.create endpoint Signed-off-by: Mattt Zmuda --- replicate/model.py | 57 +++++++++++++++++++++ tests/cassettes/models-create.yaml | 80 ++++++++++++++++++++++++++++++ tests/test_hardware.py | 2 - tests/test_model.py | 31 ++++++++++++ 4 files changed, 168 insertions(+), 2 deletions(-) create mode 100644 
tests/cassettes/models-create.yaml diff --git a/replicate/model.py b/replicate/model.py index babe9931..485e2c83 100644 --- a/replicate/model.py +++ b/replicate/model.py @@ -150,6 +150,63 @@ def get(self, key: str) -> Model: resp = self._client._request("GET", f"/v1/models/{key}") return self._prepare_model(resp.json()) + def create( # pylint: disable=arguments-differ disable=too-many-arguments + self, + owner: str, + name: str, + *, + visibility: str, + hardware: str, + description: Optional[str] = None, + github_url: Optional[str] = None, + paper_url: Optional[str] = None, + license_url: Optional[str] = None, + cover_image_url: Optional[str] = None, + ) -> Model: + """ + Create a model. + + Args: + owner: The name of the user or organization that will own the model. + name: The name of the model. + visibility: Whether the model should be public or private. + hardware: The SKU for the hardware used to run the model. Possible values can be found by calling `replicate.hardware.list()`. + description: A description of the model. + github_url: A URL for the model's source code on GitHub. + paper_url: A URL for the model's paper. + license_url: A URL for the model's license. + cover_image_url: A URL for the model's cover image. + + Returns: + The created model. 
+ """ + + body = { + "owner": owner, + "name": name, + "visibility": visibility, + "hardware": hardware, + } + + if description is not None: + body["description"] = description + + if github_url is not None: + body["github_url"] = github_url + + if paper_url is not None: + body["paper_url"] = paper_url + + if license_url is not None: + body["license_url"] = license_url + + if cover_image_url is not None: + body["cover_image_url"] = cover_image_url + + resp = self._client._request("POST", "/v1/models", json=body) + + return self._prepare_model(resp.json()) + def _prepare_model(self, attrs: Union[Model, Dict]) -> Model: if isinstance(attrs, BaseModel): attrs.id = f"{attrs.owner}/{attrs.name}" diff --git a/tests/cassettes/models-create.yaml b/tests/cassettes/models-create.yaml new file mode 100644 index 00000000..c7a8639f --- /dev/null +++ b/tests/cassettes/models-create.yaml @@ -0,0 +1,80 @@ +interactions: +- request: + body: '{"owner": "test", "name": "python-example", "visibility": "private", "hardware": + "cpu", "description": "An example model"}' + headers: + accept: + - '*/*' + accept-encoding: + - gzip, deflate + connection: + - keep-alive + content-length: + - '123' + content-type: + - application/json + host: + - api.replicate.com + user-agent: + - replicate-python/0.15.6 + method: POST + uri: https://api.replicate.com/v1/models + response: + content: '{"url": "https://replicate.com/test/python-example", "owner": "test", + "name": "python-example", "description": "An example model", "visibility": "private", + "github_url": null, "paper_url": null, "license_url": null, "run_count": 0, + "cover_image_url": null, "default_example": null, "latest_version": null}' + headers: + CF-Cache-Status: + - DYNAMIC + CF-RAY: + - 81ff2e098ec0eb5b-SEA + Connection: + - keep-alive + Content-Length: + - '307' + Content-Type: + - application/json + Date: + - Thu, 02 Nov 2023 20:38:12 GMT + Server: + - cloudflare + Strict-Transport-Security: + - max-age=15552000 + allow: + - GET, 
POST, HEAD, OPTIONS + content-security-policy-report-only: + - 'font-src ''report-sample'' ''self'' data:; img-src ''report-sample'' ''self'' + data: https://replicate.delivery https://*.replicate.delivery https://*.githubusercontent.com + https://github.com; script-src ''report-sample'' ''self'' https://cdn.rudderlabs.com/v1.1/rudder-analytics.min.js; + style-src ''report-sample'' ''self'' ''unsafe-inline''; connect-src ''report-sample'' + ''self'' https://replicate.delivery https://*.replicate.delivery https://*.rudderlabs.com + https://*.rudderstack.com https://*.mux.com https://*.sentry.io; worker-src + ''none''; media-src ''report-sample'' ''self'' https://replicate.delivery + https://*.replicate.delivery https://*.mux.com https://*.sentry.io; default-src + ''self''; report-uri' + cross-origin-opener-policy: + - same-origin + nel: + - '{"report_to":"heroku-nel","max_age":3600,"success_fraction":0.005,"failure_fraction":0.05,"response_headers":["Via"]}' + ratelimit-remaining: + - '2999' + ratelimit-reset: + - '1' + referrer-policy: + - same-origin + report-to: + - '{"group":"heroku-nel","max_age":3600,"endpoints":[{"url":"https://nel.heroku.com/reports?ts=1698957492&sid=1b10b0ff-8a76-4548-befa-353fc6c6c045&s=m%2Fs583uNWdN4J4bm1G3JZoilUVMbh89egg%2FAEcTPZm4%3D"}]}' + reporting-endpoints: + - heroku-nel=https://nel.heroku.com/reports?ts=1698957492&sid=1b10b0ff-8a76-4548-befa-353fc6c6c045&s=m%2Fs583uNWdN4J4bm1G3JZoilUVMbh89egg%2FAEcTPZm4%3D + vary: + - Cookie, origin + via: + - 1.1 vegur, 1.1 google + x-content-type-options: + - nosniff + x-frame-options: + - DENY + http_version: HTTP/1.1 + status_code: 201 +version: 1 diff --git a/tests/test_hardware.py b/tests/test_hardware.py index 828a368d..15da0e79 100644 --- a/tests/test_hardware.py +++ b/tests/test_hardware.py @@ -1,6 +1,4 @@ -import httpx import pytest -import respx import replicate diff --git a/tests/test_model.py b/tests/test_model.py index a228b6d2..2ae5a42e 100644 --- a/tests/test_model.py +++ 
b/tests/test_model.py @@ -27,3 +27,34 @@ async def test_models_list(mock_replicate_api_token): assert models[0].owner is not None assert models[0].name is not None assert models[0].visibility == "public" + + +@pytest.mark.vcr("models-create.yaml") +@pytest.mark.asyncio +async def test_models_create(mock_replicate_api_token): + model = replicate.models.create( + owner="test", + name="python-example", + visibility="private", + hardware="cpu", + description="An example model", + ) + + assert model.owner == "test" + assert model.name == "python-example" + assert model.visibility == "private" + + +@pytest.mark.vcr("models-create.yaml") +@pytest.mark.asyncio +async def test_models_create_with_positional_arguments(mock_replicate_api_token): + model = replicate.models.create( + "test", + "python-example", + visibility="private", + hardware="cpu", + ) + + assert model.owner == "test" + assert model.name == "python-example" + assert model.visibility == "private" From 041089af37f7ca710d6c6bea215c6dd1f3286ed5 Mon Sep 17 00:00:00 2001 From: Mattt Zmuda Date: Sun, 5 Nov 2023 05:16:07 -0800 Subject: [PATCH 3/3] Update README Signed-off-by: Mattt Zmuda --- README.md | 23 +++++++++++++++++++++++ 1 file changed, 23 insertions(+) diff --git a/README.md b/README.md index c4ad4d7f..c627414b 100644 --- a/README.md +++ b/README.md @@ -178,6 +178,29 @@ urlretrieve(out[0], "/tmp/out.png") background = Image.open("/tmp/out.png") ``` +## Create a model + +You can create a model for a user or organization +with a given name, visibility, and hardware SKU: + +```python +import replicate + +model = replicate.models.create( + owner="your-username", + name="my-model", + visibility="public", + hardware="gpu-a40-large" +) +``` + +Here's how to list all the available hardware for running models on Replicate: + +```python +>>> [hw.sku for hw in replicate.hardware.list()] +['cpu', 'gpu-t4', 'gpu-a40-small', 'gpu-a40-large'] +``` + ## Development See [CONTRIBUTING.md](CONTRIBUTING.md)