Skip to content
New issue

Have a question about this project? # for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “#”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? # to your account

Add support for models.create and hardware.list endpoints #184

Merged
merged 3 commits into from
Nov 6, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 23 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -178,6 +178,29 @@ urlretrieve(out[0], "/tmp/out.png")
background = Image.open("/tmp/out.png")
```

## Create a model

You can create a model for a user or organization
with a given name, visibility, and hardware SKU:

```python
import replicate

model = replicate.models.create(
owner="your-username",
name="my-model",
visibility="public",
hardware="gpu-a40-large"
)
```

Here's how to list all of the available hardware for running models on Replicate:

```python
>>> [hw.sku for hw in replicate.hardware.list()]
['cpu', 'gpu-t4', 'gpu-a40-small', 'gpu-a40-large']
```

## Development

See [CONTRIBUTING.md](CONTRIBUTING.md)
3 changes: 2 additions & 1 deletion replicate/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,8 @@

# Shared client instance that backs the module-level shortcuts below.
default_client = Client()
# Module-level aliases so callers can use e.g. `replicate.run(...)` or
# `replicate.models` without constructing a Client themselves.
run = default_client.run
# NOTE(review): this attribute has the same name as the `replicate.hardware`
# submodule, so `replicate.hardware` resolves to this collection rather than
# to the module — confirm that is intended.
hardware = default_client.hardware
deployments = default_client.deployments
models = default_client.models
predictions = default_client.predictions
trainings = default_client.trainings
# NOTE(review): `deployments` is already bound above; this repeat looks like
# the pre-change line carried over from the diff view — confirm before keeping.
deployments = default_client.deployments
22 changes: 15 additions & 7 deletions replicate/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
from replicate.__about__ import __version__
from replicate.deployment import DeploymentCollection
from replicate.exceptions import ModelError, ReplicateError
from replicate.hardware import HardwareCollection
from replicate.model import ModelCollection
from replicate.prediction import PredictionCollection
from replicate.schema import make_schema_backwards_compatible
Expand Down Expand Up @@ -83,6 +84,20 @@ def _request(self, method: str, path: str, **kwargs) -> httpx.Response:

return resp

@property
def deployments(self) -> DeploymentCollection:
    """
    Namespace for operations related to deployments.

    Returns:
        A `DeploymentCollection` constructed with this client. A new
        collection object is created on every access.
    """
    return DeploymentCollection(client=self)

@property
def hardware(self) -> HardwareCollection:
    """
    Namespace for operations related to hardware.

    Returns:
        A `HardwareCollection` constructed with this client. A new
        collection object is created on every access.
    """
    return HardwareCollection(client=self)

@property
def models(self) -> ModelCollection:
"""
Expand All @@ -104,13 +119,6 @@ def trainings(self) -> TrainingCollection:
"""
return TrainingCollection(client=self)

@property
def deployments(self) -> DeploymentCollection:
    """
    Namespace for operations related to deployments.

    Returns:
        A `DeploymentCollection` constructed with this client. A new
        collection object is created on every access.
    """
    return DeploymentCollection(client=self)

def run(self, model_version: str, **kwargs) -> Union[Any, Iterator[Any]]: # noqa: ANN401
"""
Run a model and wait for its output.
Expand Down
50 changes: 50 additions & 0 deletions replicate/hardware.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
from typing import Dict, List, Union

from replicate.base_model import BaseModel
from replicate.collection import Collection


class Hardware(BaseModel):
    """
    Hardware for running a model on Replicate.

    Each instance carries the two string fields declared below.
    """

    sku: str
    """
    The SKU of the hardware, e.g. `"gpu-a40-large"`.
    """

    name: str
    """
    The name of the hardware, e.g. `"Nvidia A40 (Large) GPU"`.
    """


class HardwareCollection(Collection):
    """
    Namespace for operations related to hardware.
    """

    model = Hardware

    def list(self) -> List[Hardware]:
        """
        List all available hardware.

        Sends `GET /v1/hardware` through the client and converts every
        entry of the JSON response into a `Hardware` object.

        Returns:
            A list of `Hardware` objects, one per entry in the response.
        """

        resp = self._client._request("GET", "/v1/hardware")
        return [self._prepare_model(obj) for obj in resp.json()]

    def _prepare_model(self, attrs: Union[Hardware, Dict]) -> Hardware:
        """
        Set `id` from `sku`, then delegate to the base implementation.

        Args:
            attrs: Either an existing model instance or a dict of attributes;
                it is modified in place so that `id` mirrors `sku`.

        Returns:
            Whatever the base class `_prepare_model` returns for `attrs`.
        """
        # Hardware entries have no `id` field of their own, so the SKU is
        # copied into `id` before the base class processes the attributes.
        if isinstance(attrs, BaseModel):
            attrs.id = attrs.sku
        elif isinstance(attrs, dict):
            attrs["id"] = attrs["sku"]

        return super()._prepare_model(attrs)
57 changes: 57 additions & 0 deletions replicate/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -150,6 +150,63 @@ def get(self, key: str) -> Model:
resp = self._client._request("GET", f"/v1/models/{key}")
return self._prepare_model(resp.json())

def create(  # pylint: disable=arguments-differ disable=too-many-arguments
    self,
    owner: str,
    name: str,
    *,
    visibility: str,
    hardware: str,
    description: Optional[str] = None,
    github_url: Optional[str] = None,
    paper_url: Optional[str] = None,
    license_url: Optional[str] = None,
    cover_image_url: Optional[str] = None,
) -> Model:
    """
    Create a model.

    The four required fields are always sent; each optional field is added
    to the request body only when a value other than `None` is supplied.

    Args:
        owner: The name of the user or organization that will own the model.
        name: The name of the model.
        visibility: Whether the model should be public or private.
        hardware: The SKU for the hardware used to run the model. Possible values can be found by calling `replicate.hardware.list()`.
        description: A description of the model.
        github_url: A URL for the model's source code on GitHub.
        paper_url: A URL for the model's paper.
        license_url: A URL for the model's license.
        cover_image_url: A URL for the model's cover image.

    Returns:
        The created model.
    """

    # Declared in the order the keys should appear in the request body.
    optional_fields = {
        "description": description,
        "github_url": github_url,
        "paper_url": paper_url,
        "license_url": license_url,
        "cover_image_url": cover_image_url,
    }

    payload = {
        "owner": owner,
        "name": name,
        "visibility": visibility,
        "hardware": hardware,
        # Omit unset options entirely rather than sending explicit nulls.
        **{
            field: value
            for field, value in optional_fields.items()
            if value is not None
        },
    }

    response = self._client._request("POST", "/v1/models", json=payload)
    return self._prepare_model(response.json())

def _prepare_model(self, attrs: Union[Model, Dict]) -> Model:
if isinstance(attrs, BaseModel):
attrs.id = f"{attrs.owner}/{attrs.name}"
Expand Down
74 changes: 74 additions & 0 deletions tests/cassettes/hardware-list.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,74 @@
interactions:
- request:
body: ''
headers:
accept:
- '*/*'
accept-encoding:
- gzip, deflate
connection:
- keep-alive
host:
- api.replicate.com
user-agent:
- replicate-python/0.15.5
method: GET
uri: https://api.replicate.com/v1/hardware
response:
content: '[{"sku":"cpu","name":"CPU"},{"sku":"gpu-t4","name":"Nvidia T4 GPU"},{"sku":"gpu-a40-small","name":"Nvidia
A40 GPU"},{"sku":"gpu-a40-large","name":"Nvidia A40 (Large) GPU"}]'
headers:
CF-Cache-Status:
- DYNAMIC
CF-RAY:
- 81fbfed29fe1c58a-SEA
Connection:
- keep-alive
Content-Encoding:
- gzip
Content-Type:
- application/json
Date:
- Thu, 02 Nov 2023 11:21:41 GMT
Server:
- cloudflare
Strict-Transport-Security:
- max-age=15552000
Transfer-Encoding:
- chunked
allow:
- OPTIONS, GET
content-security-policy-report-only:
- 'connect-src ''report-sample'' ''self'' https://replicate.delivery https://*.replicate.delivery
https://*.rudderlabs.com https://*.rudderstack.com https://*.mux.com https://*.sentry.io;
worker-src ''none''; script-src ''report-sample'' ''self'' https://cdn.rudderlabs.com/v1.1/rudder-analytics.min.js;
style-src ''report-sample'' ''self'' ''unsafe-inline''; font-src ''report-sample''
''self'' data:; img-src ''report-sample'' ''self'' data: https://replicate.delivery
https://*.replicate.delivery https://*.githubusercontent.com https://github.com;
default-src ''self''; media-src ''report-sample'' ''self'' https://replicate.delivery
https://*.replicate.delivery https://*.mux.com https://*.sentry.io; report-uri'
cross-origin-opener-policy:
- same-origin
nel:
- '{"report_to":"heroku-nel","max_age":3600,"success_fraction":0.005,"failure_fraction":0.05,"response_headers":["Via"]}'
ratelimit-remaining:
- '2999'
ratelimit-reset:
- '1'
referrer-policy:
- same-origin
report-to:
- '{"group":"heroku-nel","max_age":3600,"endpoints":[{"url":"https://nel.heroku.com/reports?ts=1698924101&sid=1b10b0ff-8a76-4548-befa-353fc6c6c045&s=lMEXEYwO4dAJOZgt0b6ihblK5I4BwDBadrW6odcdYW8%3D"}]}'
reporting-endpoints:
- heroku-nel=https://nel.heroku.com/reports?ts=1698924101&sid=1b10b0ff-8a76-4548-befa-353fc6c6c045&s=lMEXEYwO4dAJOZgt0b6ihblK5I4BwDBadrW6odcdYW8%3D
vary:
- Cookie, origin
via:
- 1.1 vegur, 1.1 google
x-content-type-options:
- nosniff
x-frame-options:
- DENY
http_version: HTTP/1.1
status_code: 200
version: 1
80 changes: 80 additions & 0 deletions tests/cassettes/models-create.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,80 @@
interactions:
- request:
body: '{"owner": "test", "name": "python-example", "visibility": "private", "hardware":
"cpu", "description": "An example model"}'
headers:
accept:
- '*/*'
accept-encoding:
- gzip, deflate
connection:
- keep-alive
content-length:
- '123'
content-type:
- application/json
host:
- api.replicate.com
user-agent:
- replicate-python/0.15.6
method: POST
uri: https://api.replicate.com/v1/models
response:
content: '{"url": "https://replicate.com/test/python-example", "owner": "test",
"name": "python-example", "description": "An example model", "visibility": "private",
"github_url": null, "paper_url": null, "license_url": null, "run_count": 0,
"cover_image_url": null, "default_example": null, "latest_version": null}'
headers:
CF-Cache-Status:
- DYNAMIC
CF-RAY:
- 81ff2e098ec0eb5b-SEA
Connection:
- keep-alive
Content-Length:
- '307'
Content-Type:
- application/json
Date:
- Thu, 02 Nov 2023 20:38:12 GMT
Server:
- cloudflare
Strict-Transport-Security:
- max-age=15552000
allow:
- GET, POST, HEAD, OPTIONS
content-security-policy-report-only:
- 'font-src ''report-sample'' ''self'' data:; img-src ''report-sample'' ''self''
data: https://replicate.delivery https://*.replicate.delivery https://*.githubusercontent.com
https://github.com; script-src ''report-sample'' ''self'' https://cdn.rudderlabs.com/v1.1/rudder-analytics.min.js;
style-src ''report-sample'' ''self'' ''unsafe-inline''; connect-src ''report-sample''
''self'' https://replicate.delivery https://*.replicate.delivery https://*.rudderlabs.com
https://*.rudderstack.com https://*.mux.com https://*.sentry.io; worker-src
''none''; media-src ''report-sample'' ''self'' https://replicate.delivery
https://*.replicate.delivery https://*.mux.com https://*.sentry.io; default-src
''self''; report-uri'
cross-origin-opener-policy:
- same-origin
nel:
- '{"report_to":"heroku-nel","max_age":3600,"success_fraction":0.005,"failure_fraction":0.05,"response_headers":["Via"]}'
ratelimit-remaining:
- '2999'
ratelimit-reset:
- '1'
referrer-policy:
- same-origin
report-to:
- '{"group":"heroku-nel","max_age":3600,"endpoints":[{"url":"https://nel.heroku.com/reports?ts=1698957492&sid=1b10b0ff-8a76-4548-befa-353fc6c6c045&s=m%2Fs583uNWdN4J4bm1G3JZoilUVMbh89egg%2FAEcTPZm4%3D"}]}'
reporting-endpoints:
- heroku-nel=https://nel.heroku.com/reports?ts=1698957492&sid=1b10b0ff-8a76-4548-befa-353fc6c6c045&s=m%2Fs583uNWdN4J4bm1G3JZoilUVMbh89egg%2FAEcTPZm4%3D
vary:
- Cookie, origin
via:
- 1.1 vegur, 1.1 google
x-content-type-options:
- nosniff
x-frame-options:
- DENY
http_version: HTTP/1.1
status_code: 201
version: 1
13 changes: 13 additions & 0 deletions tests/test_hardware.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
import pytest

import replicate


@pytest.mark.vcr("hardware-list.yaml")
@pytest.mark.asyncio
async def test_hardware_list(mock_replicate_api_token):
    """Listing hardware against the recorded cassette yields a non-empty list."""
    listed = replicate.hardware.list()

    # A list instance is necessarily not None, so this covers both checks.
    assert isinstance(listed, list)
    # An empty list is falsy, so this fails exactly when nothing was returned.
    assert listed
Loading