Skip to content

expose model inputs and outputs with respect to model definition #552

New issue

Have a question about this project? # for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “#”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? # to your account

Merged
merged 9 commits into from
Jan 13, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
68 changes: 68 additions & 0 deletions src/backends/onnxruntime.c
Original file line number Diff line number Diff line change
Expand Up @@ -288,6 +288,8 @@ RAI_Model *RAI_ModelCreateORT(RAI_Backend backend, const char *devicestr, RAI_Mo

RAI_Device device;
int64_t deviceid;
char **inputs_ = NULL;
char **outputs_ = NULL;

if (!parseDeviceStr(devicestr, &device, &deviceid)) {
RAI_SetError(error, RAI_EMODELCREATE, "ERR unsupported device");
Expand Down Expand Up @@ -352,6 +354,41 @@ RAI_Model *RAI_ModelCreateORT(RAI_Backend backend, const char *devicestr, RAI_Mo
goto error;
}

size_t n_input_nodes;
status = ort->SessionGetInputCount(session, &n_input_nodes);
if (status != NULL) {
goto error;
}

size_t n_output_nodes;
status = ort->SessionGetOutputCount(session, &n_output_nodes);
if (status != NULL) {
goto error;
}

OrtAllocator *allocator;
status = ort->GetAllocatorWithDefaultOptions(&allocator);

inputs_ = array_new(char *, n_input_nodes);
for (long long i = 0; i < n_input_nodes; i++) {
char *input_name;
status = ort->SessionGetInputName(session, i, allocator, &input_name);
if (status != NULL) {
goto error;
}
inputs_ = array_append(inputs_, input_name);
}

outputs_ = array_new(char *, n_output_nodes);
for (long long i = 0; i < n_output_nodes; i++) {
char *output_name;
status = ort->SessionGetOutputName(session, i, allocator, &output_name);
if (status != NULL) {
goto error;
}
outputs_ = array_append(outputs_, output_name);
}

// Since ONNXRuntime doesn't have a re-serialization function,
// we cache the blob in order to re-serialize it.
// Not optimal for storage purposes, but again, it may be temporary
Expand All @@ -367,11 +404,29 @@ RAI_Model *RAI_ModelCreateORT(RAI_Backend backend, const char *devicestr, RAI_Mo
ret->opts = opts;
ret->data = buffer;
ret->datalen = modellen;
ret->ninputs = n_input_nodes;
ret->noutputs = n_output_nodes;
ret->inputs = inputs_;
ret->outputs = outputs_;

return ret;

error:
RAI_SetError(error, RAI_EMODELCREATE, ort->GetErrorMessage(status));
if (inputs_) {
n_input_nodes = array_len(inputs_);
for (uint32_t i = 0; i < n_input_nodes; i++) {
status = ort->AllocatorFree(allocator, inputs_[i]);
}
array_free(inputs_);
}
if (outputs_) {
n_output_nodes = array_len(outputs_);
for (uint32_t i = 0; i < n_output_nodes; i++) {
status = ort->AllocatorFree(allocator, outputs_[i]);
}
array_free(outputs_);
}
ort->ReleaseStatus(status);
return NULL;
}
Expand All @@ -381,6 +436,19 @@ void RAI_ModelFreeORT(RAI_Model *model, RAI_Error *error) {

RedisModule_Free(model->data);
RedisModule_Free(model->devicestr);
OrtAllocator *allocator;
OrtStatus *status = NULL;
status = ort->GetAllocatorWithDefaultOptions(&allocator);
for (uint32_t i = 0; i < model->ninputs; i++) {
status = ort->AllocatorFree(allocator, model->inputs[i]);
}
array_free(model->inputs);

for (uint32_t i = 0; i < model->noutputs; i++) {
status = ort->AllocatorFree(allocator, model->outputs[i]);
}
array_free(model->outputs);

ort->ReleaseSession(model->session);

model->model = NULL;
Expand Down
2 changes: 2 additions & 0 deletions src/backends/tensorflow.c
Original file line number Diff line number Diff line change
Expand Up @@ -390,7 +390,9 @@ RAI_Model *RAI_ModelCreateTF(RAI_Backend backend, const char *devicestr, RAI_Mod
ret->session = session;
ret->backend = backend;
ret->devicestr = RedisModule_Strdup(devicestr);
ret->ninputs = ninputs;
ret->inputs = inputs_;
ret->noutputs = noutputs;
ret->outputs = outputs_;
ret->opts = opts;
ret->refCount = 1;
Expand Down
70 changes: 66 additions & 4 deletions src/backends/tflite.c
Original file line number Diff line number Diff line change
Expand Up @@ -18,9 +18,10 @@ int RAI_InitBackendTFLite(int (*get_api_fn)(const char *, void *)) {
RAI_Model *RAI_ModelCreateTFLite(RAI_Backend backend, const char *devicestr, RAI_ModelOpts opts,
const char *modeldef, size_t modellen, RAI_Error *error) {
DLDeviceType dl_device;

RAI_Device device;
int64_t deviceid;
char **inputs_ = NULL;
char **outputs_ = NULL;
if (!parseDeviceStr(devicestr, &device, &deviceid)) {
RAI_SetError(error, RAI_EMODELCONFIGURE, "ERR Unsupported device");
return NULL;
Expand All @@ -47,6 +48,36 @@ RAI_Model *RAI_ModelCreateTFLite(RAI_Backend backend, const char *devicestr, RAI
return NULL;
}

size_t ninputs = tfliteModelNumInputs(model, &error_descr);
if (error_descr) {
goto cleanup;
}

size_t noutputs = tfliteModelNumOutputs(model, &error_descr);
if (error_descr) {
goto cleanup;
}

inputs_ = array_new(char *, ninputs);
outputs_ = array_new(char *, noutputs);

for (size_t i = 0; i < ninputs; i++) {
const char *input = tfliteModelInputNameAtIndex(model, i, &error_descr);
if (error_descr) {
goto cleanup;
}
inputs_ = array_append(inputs_, RedisModule_Strdup(input));
}

for (size_t i = 0; i < noutputs; i++) {
const char *output = tfliteModelOutputNameAtIndex(model, i, &error_descr);
;
if (error_descr) {
goto cleanup;
}
outputs_ = array_append(outputs_, RedisModule_Strdup(output));
}

char *buffer = RedisModule_Calloc(modellen, sizeof(*buffer));
memcpy(buffer, modeldef, modellen);

Expand All @@ -55,20 +86,51 @@ RAI_Model *RAI_ModelCreateTFLite(RAI_Backend backend, const char *devicestr, RAI
ret->session = NULL;
ret->backend = backend;
ret->devicestr = RedisModule_Strdup(devicestr);
ret->inputs = NULL;
ret->outputs = NULL;
ret->ninputs = ninputs;
ret->inputs = inputs_;
ret->noutputs = noutputs;
ret->outputs = outputs_;
ret->refCount = 1;
ret->opts = opts;
ret->data = buffer;
ret->datalen = modellen;

return ret;

cleanup:
RAI_SetError(error, RAI_EMODELCREATE, error_descr);
RedisModule_Free(error_descr);
if (inputs_) {
ninputs = array_len(inputs_);
for (size_t i = 0; i < ninputs; i++) {
RedisModule_Free(inputs_[i]);
}
array_free(inputs_);
}
if (outputs_) {
noutputs = array_len(outputs_);
for (size_t i = 0; i < noutputs; i++) {
RedisModule_Free(outputs_[i]);
}
array_free(outputs_);
}
return NULL;
}

void RAI_ModelFreeTFLite(RAI_Model *model, RAI_Error *error) {
RedisModule_Free(model->data);
RedisModule_Free(model->devicestr);
tfliteDeallocContext(model->model);
size_t ninputs = model->ninputs;
for (size_t i = 0; i < ninputs; i++) {
RedisModule_Free(model->inputs[i]);
}
array_free(model->inputs);

size_t noutputs = model->noutputs;
for (size_t i = 0; i < noutputs; i++) {
RedisModule_Free(model->outputs[i]);
}
array_free(model->outputs);

model->model = NULL;
}
Expand Down
78 changes: 70 additions & 8 deletions src/backends/torch.c
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,9 @@ RAI_Model *RAI_ModelCreateTorch(RAI_Backend backend, const char *devicestr, RAI_
RAI_Device device = RAI_DEVICE_CPU;
int64_t deviceid = 0;

char **inputs_ = NULL;
char **outputs_ = NULL;

if (!parseDeviceStr(devicestr, &device, &deviceid)) {
RAI_SetError(error, RAI_EMODELCONFIGURE, "ERR unsupported device");
return NULL;
Expand Down Expand Up @@ -53,7 +56,7 @@ RAI_Model *RAI_ModelCreateTorch(RAI_Backend backend, const char *devicestr, RAI_
if (opts.backends_intra_op_parallelism > 0) {
torchSetIntraOpThreads(opts.backends_intra_op_parallelism, &error_descr, RedisModule_Alloc);
}
if (error_descr != NULL) {
if (error_descr) {
RAI_SetError(error, RAI_EMODELCREATE, error_descr);
RedisModule_Free(error_descr);
return NULL;
Expand All @@ -62,10 +65,37 @@ RAI_Model *RAI_ModelCreateTorch(RAI_Backend backend, const char *devicestr, RAI_
void *model =
torchLoadModel(modeldef, modellen, dl_device, deviceid, &error_descr, RedisModule_Alloc);

if (model == NULL) {
RAI_SetError(error, RAI_EMODELCREATE, error_descr);
RedisModule_Free(error_descr);
return NULL;
if (error_descr) {
goto cleanup;
}

size_t ninputs = torchModelNumInputs(model, &error_descr);
if (error_descr) {
goto cleanup;
}

size_t noutputs = torchModelNumOutputs(model, &error_descr);
if (error_descr) {
goto cleanup;
}

inputs_ = array_new(char *, ninputs);
outputs_ = array_new(char *, noutputs);

for (size_t i = 0; i < ninputs; i++) {
const char *input = torchModelInputNameAtIndex(model, i, &error_descr);
if (error_descr) {
goto cleanup;
}
inputs_ = array_append(inputs_, RedisModule_Strdup(input));
}

for (size_t i = 0; i < noutputs; i++) {
const char *output = "";
if (error_descr) {
goto cleanup;
}
outputs_ = array_append(outputs_, RedisModule_Strdup(output));
}

char *buffer = RedisModule_Calloc(modellen, sizeof(*buffer));
Expand All @@ -76,14 +106,34 @@ RAI_Model *RAI_ModelCreateTorch(RAI_Backend backend, const char *devicestr, RAI_
ret->session = NULL;
ret->backend = backend;
ret->devicestr = RedisModule_Strdup(devicestr);
ret->inputs = NULL;
ret->outputs = NULL;
ret->ninputs = ninputs;
ret->inputs = inputs_;
ret->noutputs = noutputs;
ret->outputs = outputs_;
ret->opts = opts;
ret->refCount = 1;
ret->data = buffer;
ret->datalen = modellen;

return ret;

cleanup:
RAI_SetError(error, RAI_EMODELCREATE, error_descr);
RedisModule_Free(error_descr);
if (inputs_) {
ninputs = array_len(inputs_);
for (size_t i = 0; i < ninputs; i++) {
RedisModule_Free(inputs_[i]);
}
array_free(inputs_);
}
if (outputs_) {
noutputs = array_len(outputs_);
for (size_t i = 0; i < noutputs; i++) {
RedisModule_Free(outputs_[i]);
}
array_free(outputs_);
}
return NULL;
}

void RAI_ModelFreeTorch(RAI_Model *model, RAI_Error *error) {
Expand All @@ -93,6 +143,18 @@ void RAI_ModelFreeTorch(RAI_Model *model, RAI_Error *error) {
if (model->data) {
RedisModule_Free(model->data);
}
size_t ninputs = model->ninputs;
for (size_t i = 0; i < ninputs; i++) {
RedisModule_Free(model->inputs[i]);
}
array_free(model->inputs);

size_t noutputs = model->noutputs;
for (size_t i = 0; i < noutputs; i++) {
RedisModule_Free(model->outputs[i]);
}
array_free(model->outputs);

torchDeallocContext(model->model);
}

Expand Down
10 changes: 4 additions & 6 deletions src/command_parser.c
Original file line number Diff line number Diff line change
Expand Up @@ -68,17 +68,15 @@ static int _ModelRunCommand_ParseArgs(RedisModuleCtx *ctx, RedisModuleString **a
}
}
}
if ((*model)->inputs && (*model)->ninputs != ninputs) {
if ((*model)->ninputs != ninputs) {
RAI_SetError(error, RAI_EMODELRUN,
"Number of names given as INPUTS during MODELSET and keys given as "
"INPUTS here do not match");
"Number of keys given as INPUTS here does not match model definition");
return REDISMODULE_ERR;
}

if ((*model)->outputs && (*model)->noutputs != noutputs) {
if ((*model)->noutputs != noutputs) {
RAI_SetError(error, RAI_EMODELRUN,
"Number of names given as OUTPUTS during MODELSET and keys given as "
"OUTPUTS here do not match");
"Number of keys given as OUTPUTS here does not match model definition");
return REDISMODULE_ERR;
}
return REDISMODULE_OK;
Expand Down
Loading