diff --git a/dataikuapi/apinode_client.py b/dataikuapi/apinode_client.py index 0528277e..fe75a1fc 100644 --- a/dataikuapi/apinode_client.py +++ b/dataikuapi/apinode_client.py @@ -16,7 +16,7 @@ def __init__(self, uri, service_id, api_key=None): """ DSSBaseClient.__init__(self, "%s/%s" % (uri, "public/api/v1/%s" % service_id), api_key) - def predict_record(self, endpoint_id, features, forced_generation=None, dispatch_key=None, context=None): + def predict_record(self, endpoint_id, features, forced_generation=None, dispatch_key=None, context=None, partition=None): """ Predicts a single record on a DSS API node endpoint (standard or custom prediction) @@ -25,6 +25,7 @@ def predict_record(self, endpoint_id, features, forced_generation=None, dispatch :param forced_generation: See documentation about multi-version prediction :param dispatch_key: See documentation about multi-version prediction :param context: Optional, Python dictionary of additional context information. The context information is logged, but not directly used. + :param partition: Optional, partition id of partitioned model to use. Guessed otherwise from record if needed. :return: a Python dict of the API answer. The answer contains a "result" key (itself a dict) """ @@ -37,10 +38,12 @@ def predict_record(self, endpoint_id, features, forced_generation=None, dispatch obj["dispatch"] = {"forcedGeneration" : forced_generation } elif dispatch_key is not None: obj["dispatch"] = {"dispatchKey" : dispatch_key } + if partition is not None: + obj["partition"] = partition return self._perform_json("POST", "%s/predict" % endpoint_id, body = obj) - def predict_records(self, endpoint_id, records, forced_generation=None, dispatch_key=None): + def predict_records(self, endpoint_id, records, forced_generation=None, dispatch_key=None, partition=None): """ Predicts a batch of records on a DSS API node endpoint (standard or custom prediction) @@ -48,6 +51,7 @@ def predict_records(self, endpoint_id, records, forced_generation=None, dispatch :param records: Python list of records. Each record must be a Python dict. Each record must contain a "features" dict (see predict_record) and optionally a "context" dict. :param forced_generation: See documentation about multi-version prediction :param dispatch_key: See documentation about multi-version prediction + :param partition: Optional, partition id of partitioned model to use for all records. Guessed otherwise from each record if needed. :return: a Python dict of the API answer. The answer contains a "results" key (which is an array of result objects) """ @@ -64,6 +68,8 @@ def predict_records(self, endpoint_id, records, forced_generation=None, dispatch obj["dispatch"] = {"forcedGeneration" : forced_generation } elif dispatch_key is not None: obj["dispatch"] = {"dispatchKey" : dispatch_key } + if partition is not None: + obj["partition"] = partition return self._perform_json("POST", "%s/predict-multi" % endpoint_id, body = obj)