Skip to content

[WIP] Dataiku Python API for stratified models when forcing partition #64

New issue

Have a question about this project? # for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “#”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? # to your account

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 8 additions & 2 deletions dataikuapi/apinode_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ def __init__(self, uri, service_id, api_key=None):
"""
DSSBaseClient.__init__(self, "%s/%s" % (uri, "public/api/v1/%s" % service_id), api_key)

def predict_record(self, endpoint_id, features, forced_generation=None, dispatch_key=None, context=None):
def predict_record(self, endpoint_id, features, forced_generation=None, dispatch_key=None, context=None, partition=None):
"""
Predicts a single record on a DSS API node endpoint (standard or custom prediction)

Expand All @@ -25,6 +25,7 @@ def predict_record(self, endpoint_id, features, forced_generation=None, dispatch
:param forced_generation: See documentation about multi-version prediction
:param dispatch_key: See documentation about multi-version prediction
:param context: Optional, Python dictionary of additional context information. The context information is logged, but not directly used.
:param partition: Optional, partition id of partitioned model to use. Guessed otherwise from record if needed.

:return: a Python dict of the API answer. The answer contains a "result" key (itself a dict)
"""
Expand All @@ -37,17 +38,20 @@ def predict_record(self, endpoint_id, features, forced_generation=None, dispatch
obj["dispatch"] = {"forcedGeneration" : forced_generation }
elif dispatch_key is not None:
obj["dispatch"] = {"dispatchKey" : dispatch_key }
if partition is not None:
obj["partition"] = partition

return self._perform_json("POST", "%s/predict" % endpoint_id, body = obj)

def predict_records(self, endpoint_id, records, forced_generation=None, dispatch_key=None):
def predict_records(self, endpoint_id, records, forced_generation=None, dispatch_key=None, partition=None):
"""
Predicts a batch of records on a DSS API node endpoint (standard or custom prediction)

:param str endpoint_id: Identifier of the endpoint to query
:param records: Python list of records. Each record must be a Python dict. Each record must contain a "features" dict (see predict_record) and optionally a "context" dict.
:param forced_generation: See documentation about multi-version prediction
:param dispatch_key: See documentation about multi-version prediction
:param partition: Optional, partition id of partitioned model to use for all records. Guessed otherwise from each record if needed.

:return: a Python dict of the API answer. The answer contains a "results" key (which is an array of result objects)
"""
Expand All @@ -64,6 +68,8 @@ def predict_records(self, endpoint_id, records, forced_generation=None, dispatch
obj["dispatch"] = {"forcedGeneration" : forced_generation }
elif dispatch_key is not None:
obj["dispatch"] = {"dispatchKey" : dispatch_key }
if partition is not None:
obj["partition"] = partition

return self._perform_json("POST", "%s/predict-multi" % endpoint_id, body = obj)

Expand Down