From 209400b67c57dabca3c575349ec6d43a9aff301d Mon Sep 17 00:00:00 2001 From: mikhailandreev Date: Tue, 6 Jun 2023 16:01:17 -0700 Subject: [PATCH 1/2] extract_url_param helper function in client Signed-off-by: mikhailandreev --- replicate/client.py | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/replicate/client.py b/replicate/client.py index a08dacf0..869e0abc 100644 --- a/replicate/client.py +++ b/replicate/client.py @@ -2,6 +2,7 @@ import re from json import JSONDecodeError from typing import Any, Dict, Iterator, Optional, Union +from urllib.parse import parse_qs, urlparse import requests from requests.adapters import HTTPAdapter, Retry @@ -82,6 +83,14 @@ def _request(self, method: str, path: str, **kwargs) -> requests.Response: raise ReplicateError(f"HTTP error: {resp.status_code, resp.reason}") return resp + def _extract_url_param(self, url: str, param: str) -> str: + """Extract a query parameter from a URL. First used in replicate/prediction.py + to extract pagination cursors from API-returned URLs.""" + parsed_url = urlparse(url) + params = parse_qs(parsed_url.query) + cursor = params.get(param, [None])[0] + return cursor + def _headers(self) -> Dict[str, str]: return { "Authorization": f"Token {self._api_token()}", From fb50936692c42dd320c22d3990d2eb83d9c529ec Mon Sep 17 00:00:00 2001 From: mikhailandreev Date: Tue, 6 Jun 2023 16:01:35 -0700 Subject: [PATCH 2/2] pagination for predictions.list + predictions.list_after_date fn Signed-off-by: mikhailandreev --- replicate/prediction.py | 53 +++++++++++++++++++++++++++++++++++++---- 1 file changed, 48 insertions(+), 5 deletions(-) diff --git a/replicate/prediction.py b/replicate/prediction.py index dd7a593c..7f3b2478 100644 --- a/replicate/prediction.py +++ b/replicate/prediction.py @@ -1,4 +1,5 @@ import time +from datetime import datetime from typing import Any, Dict, Iterator, List, Optional from replicate.base_model import BaseModel @@ -54,14 +55,56 @@ def cancel(self) -> None: class 
# Reconstructed from patch 2/2 ("pagination for predictions.list +
# predictions.list_after_date fn"). ``list`` and ``list_after_date`` are
# methods of ``PredictionCollection`` in replicate/prediction.py.
from datetime import datetime, timezone
from typing import List, Optional, Tuple, Union
from urllib.parse import urlencode


def _parse_iso8601(value: str) -> datetime:
    """Parse an ISO 8601 timestamp into a timezone-aware ``datetime``.

    A trailing ``Z`` is rewritten to ``+00:00`` because
    ``datetime.fromisoformat`` only accepts it natively on newer Python
    versions. A timestamp without any offset is treated as UTC so that it can
    be compared with aware timestamps without raising ``TypeError``.
    """
    parsed = datetime.fromisoformat(value.replace("Z", "+00:00"))
    if parsed.tzinfo is None:
        parsed = parsed.replace(tzinfo=timezone.utc)
    return parsed


def list(
    self, paginate=False, cursor=None
) -> "Union[List[Prediction], Tuple[List[Prediction], Optional[str]]]":
    """Fetch one page of predictions.

    Args:
        paginate: When false (the default) only the list of predictions is
            returned, as before pagination was added. When true a
            ``(predictions, next_cursor)`` tuple is returned.
        cursor: Pagination cursor from a previous call. Only used when
            ``paginate`` is true; ``None`` requests the first page.

    Returns:
        ``predictions`` when ``paginate`` is false, otherwise
        ``(predictions, next_cursor)`` where ``next_cursor`` is ``None`` when
        the response has no usable ``"next"`` URL (no more results).
    """
    path = "/v1/predictions"
    if paginate and cursor is not None:
        # The cursor arrives percent-decoded from _extract_url_param, so it
        # has to be re-encoded before going back into a query string.
        path = f"{path}?{urlencode({'cursor': cursor})}"

    payload = self._client._request("GET", path).json()

    records = payload["results"]
    for record in records:
        # HACK: resolve this? make it lazy somehow?
        record.pop("version", None)
    predictions = [self.prepare_model(record) for record in records]

    # Backwards compatibility for non-paginated callers.
    if not paginate:
        return predictions

    # Tolerate a response where "next" is missing entirely or null.
    next_url = payload.get("next")
    next_cursor = (
        self._client._extract_url_param(next_url, "cursor") if next_url else None
    )
    return predictions, next_cursor


def list_after_date(self, date: str) -> "List[Prediction]":
    """List predictions created at or after ``date``.

    Pages are fetched one after another until a page contains a prediction
    older than ``date``, a page is empty, or there is no next cursor, so this
    may issue many requests. Stopping at the first older prediction assumes
    the API returns predictions newest-first (the original loop relied on the
    same assumption; confirm against the API documentation).

    Args:
        date: Cutoff in ISO 8601 format, e.g.
            ``'2023-06-06T20:25:13.031191Z'``. A value without a UTC offset
            is interpreted as UTC.

    Returns:
        Every fetched prediction whose ``created_at`` is not before ``date``,
        in the order the API returned them.
    """
    cutoff = _parse_iso8601(date)
    kept = []
    cursor = None  # None asks ``list`` for the first page.

    while True:
        page, cursor = self.list(paginate=True, cursor=cursor)

        reached_cutoff = False
        for prediction in page:
            if _parse_iso8601(prediction.created_at) >= cutoff:
                kept.append(prediction)
            else:
                reached_cutoff = True

        # Stop before issuing another request when nothing more can match.
        if reached_cutoff or not page or cursor is None:
            break

    return kept