Added pagination for predictions.list + a new predictions.list_before_date function #108

Closed
9 changes: 9 additions & 0 deletions replicate/client.py
@@ -2,6 +2,7 @@
 import re
 from json import JSONDecodeError
 from typing import Any, Dict, Iterator, Optional, Union
+from urllib.parse import parse_qs, urlparse

 import requests
 from requests.adapters import HTTPAdapter, Retry
@@ -82,6 +83,14 @@ def _request(self, method: str, path: str, **kwargs) -> requests.Response:
             raise ReplicateError(f"HTTP error: {resp.status_code, resp.reason}")
         return resp

+    def _extract_url_param(self, url: str, param: str) -> str:
+        """Extract a query parameter from a URL. First used in replicate/prediction.py
+        to extract pagination cursors from API-returned URLs."""
+        parsed_url = urlparse(url)
+        params = parse_qs(parsed_url.query)
+        cursor = params.get(param, [None])[0]
+        return cursor
+
     def _headers(self) -> Dict[str, str]:
         return {
             "Authorization": f"Token {self._api_token()}",
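For context, the new helper just pulls a named query parameter out of a URL with the standard library. Below is a minimal sketch of the same parsing, using a hypothetical "next" URL of the shape the paginated predictions endpoint is assumed to return; it is an illustration, not part of the diff.

from urllib.parse import parse_qs, urlparse

# Hypothetical "next" URL assumed to come back from GET /v1/predictions.
next_url = "https://api.replicate.com/v1/predictions?cursor=cD0yMDIzLTA2LTA2"

# Same steps the new Client._extract_url_param helper performs:
params = parse_qs(urlparse(next_url).query)
cursor = params.get("cursor", [None])[0]
print(cursor)  # prints: cD0yMDIzLTA2LTA2
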
53 changes: 48 additions & 5 deletions replicate/prediction.py
@@ -1,4 +1,5 @@
 import time
+from datetime import datetime
 from typing import Any, Dict, Iterator, List, Optional

 from replicate.base_model import BaseModel
@@ -54,14 +55,56 @@ def cancel(self) -> None:
 class PredictionCollection(Collection):
     model = Prediction

-    def list(self) -> List[Prediction]:
-        resp = self._client._request("GET", "/v1/predictions")
-        # TODO: paginate
-        predictions = resp.json()["results"]
+    def list(self, paginate=False, cursor=None) -> List[Prediction]:
+        # if paginating, use passed in cursor
+        req_url = "/v1/predictions" if not paginate or cursor is None else f"/v1/predictions?cursor={cursor}"
+
+        resp = self._client._request("GET", req_url)
+        resp_dict = resp.json()
+
+        predictions = resp_dict["results"]
         for prediction in predictions:
             # HACK: resolve this? make it lazy somehow?
             del prediction["version"]
-        return [self.prepare_model(obj) for obj in predictions]
+
+        predictions_results = [self.prepare_model(obj) for obj in predictions]
+
+        # backwards compatibility for non-paginated results
+        if paginate:
+            # make sure this code can handle entirely missing "next" field in API response
+            if "next" in resp_dict and resp_dict["next"] is not None:
+                next_cursor = self._client._extract_url_param(resp_dict["next"], "cursor")
+                return predictions_results, next_cursor
+            else:
+                # None cursor is treated as "no more results"
+                return predictions_results, None
+        else:
+            return predictions_results
+
+    def list_after_date(self, date: str) -> List[Prediction]:
+        """List predictions created after a given date (in ISO 8601 format, e.g. '2023-06-06T20:25:13.031191Z').
+        Will continuously get pages of size 100 until the date is reached (might take a while)."""
+        date = datetime.fromisoformat(date.replace("Z", "+00:00"))
+        results, cursor = self.list(paginate=True)
+        while True:
+            next_results, cursor = self.list(paginate=True, cursor=cursor)
+            results.extend(next_results)
+
+            if cursor is None:
+                break
+
+            datetime_objects = [datetime.fromisoformat(prediction.created_at.replace("Z", "+00:00")) for prediction in next_results]
+            earliest_datetime = min(datetime_objects)
+            if earliest_datetime < date:
+                break
+
+        # filter out predictions created before date
+        results = [prediction for prediction in results if datetime.fromisoformat(prediction.created_at.replace("Z", "+00:00")) >= date]
+
+        return results




     def get(self, id: str) -> Prediction:
         resp = self._client._request("GET", f"/v1/predictions/{id}")
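As a usage note, here is a minimal sketch of how a caller might drive the new cursor-based pagination and the date-bounded listing added by this PR. It assumes a Client configured via the REPLICATE_API_TOKEN environment variable and an illustrative cutoff date; it is a sketch, not part of the diff.

import replicate

client = replicate.Client()  # assumes REPLICATE_API_TOKEN is set in the environment

# Walk every page manually with the new paginate/cursor interface.
all_predictions, cursor = client.predictions.list(paginate=True)
while cursor is not None:
    page, cursor = client.predictions.list(paginate=True, cursor=cursor)
    all_predictions.extend(page)

# Or let the new helper keep paging until the cutoff date is reached.
recent = client.predictions.list_after_date("2023-06-06T20:25:13.031191Z")
print(f"{len(recent)} predictions created on or after the cutoff")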