From c36e28e87d1650bf1204d6c5d074b115d9ee6ed3 Mon Sep 17 00:00:00 2001 From: DvirDukhan Date: Sun, 3 Jan 2021 21:54:15 +0200 Subject: [PATCH 1/6] expose model inputs and outputs with respect to model definition --- src/backends/onnxruntime.c | 68 +++++++++++++++++++++++++++++ src/backends/tensorflow.c | 2 + src/backends/tflite.c | 69 +++++++++++++++++++++++++++-- src/backends/torch.c | 79 ++++++++++++++++++++++++++++++---- src/command_parser.c | 6 +-- src/libtflite_c/tflite_c.cpp | 50 +++++++++++++++++++++ src/libtflite_c/tflite_c.h | 8 ++++ src/libtorch_c/torch_c.cpp | 66 ++++++++++++++++++++++++++++ src/libtorch_c/torch_c.h | 7 +++ src/model.c | 2 - tests/flow/tests_onnx.py | 19 ++++---- tests/flow/tests_pytorch.py | 4 +- tests/flow/tests_tensorflow.py | 2 +- tests/flow/tests_tflite.py | 12 +++--- 14 files changed, 356 insertions(+), 38 deletions(-) diff --git a/src/backends/onnxruntime.c b/src/backends/onnxruntime.c index db15994cb..e5bf2bc48 100644 --- a/src/backends/onnxruntime.c +++ b/src/backends/onnxruntime.c @@ -288,6 +288,8 @@ RAI_Model *RAI_ModelCreateORT(RAI_Backend backend, const char *devicestr, RAI_Mo RAI_Device device; int64_t deviceid; + char** inputs_ = NULL; + char** outputs_ = NULL; if (!parseDeviceStr(devicestr, &device, &deviceid)) { RAI_SetError(error, RAI_EMODELCREATE, "ERR unsupported device"); @@ -352,6 +354,41 @@ RAI_Model *RAI_ModelCreateORT(RAI_Backend backend, const char *devicestr, RAI_Mo goto error; } + size_t n_input_nodes; + status = ort->SessionGetInputCount(session, &n_input_nodes); + if (status != NULL) { + goto error; + } + + size_t n_output_nodes; + status = ort->SessionGetOutputCount(session, &n_output_nodes); + if (status != NULL) { + goto error; + } + + OrtAllocator *allocator; + status = ort->GetAllocatorWithDefaultOptions(&allocator); + + inputs_ = array_new(char*, n_input_nodes); + for (long long i = 0; i < n_input_nodes; i++) { + char* input_name; + status = ort->SessionGetInputName(session, i, allocator, &input_name); + if (status != NULL) { + goto error; + } + inputs_ = array_append(inputs_, input_name); + } + + outputs_ = array_new(char *, n_output_nodes); + for (long long i = 0; i < n_output_nodes; i++) { + char* output_name; + status = ort->SessionGetOutputName(session, i, allocator, &output_name); + if (status != NULL) { + goto error; + } + outputs_ = array_append(outputs_, output_name); + } + // Since ONNXRuntime doesn't have a re-serialization function, // we cache the blob in order to re-serialize it. 
// Not optimal for storage purposes, but again, it may be temporary @@ -367,11 +404,29 @@ RAI_Model *RAI_ModelCreateORT(RAI_Backend backend, const char *devicestr, RAI_Mo ret->opts = opts; ret->data = buffer; ret->datalen = modellen; + ret->ninputs = n_input_nodes; + ret->noutputs = n_output_nodes; + ret->inputs = inputs_; + ret->outputs = outputs_; return ret; error: RAI_SetError(error, RAI_EMODELCREATE, ort->GetErrorMessage(status)); + if(inputs_) { + n_input_nodes = array_len(inputs_); + for(uint32_t i = 0; i < n_input_nodes; i++) { + status = ort->AllocatorFree(allocator, inputs_[i]); + } + array_free(inputs_); + } + if(outputs_){ + n_output_nodes = array_len(outputs_); + for(uint32_t i = 0; i < n_output_nodes; i++) { + status = ort->AllocatorFree(allocator, outputs_[i]); + } + array_free(outputs_); + } ort->ReleaseStatus(status); return NULL; } @@ -381,6 +436,19 @@ void RAI_ModelFreeORT(RAI_Model *model, RAI_Error *error) { RedisModule_Free(model->data); RedisModule_Free(model->devicestr); + OrtAllocator *allocator; + OrtStatus *status = NULL; + status = ort->GetAllocatorWithDefaultOptions(&allocator); + for(uint32_t i = 0; i < model->ninputs; i++) { + status = ort->AllocatorFree(allocator, model->inputs[i]); + } + array_free(model->inputs); + + for(uint32_t i = 0; i < model->noutputs; i++) { + status = ort->AllocatorFree(allocator, model->outputs[i]); + } + array_free(model->outputs); + ort->ReleaseSession(model->session); model->model = NULL; diff --git a/src/backends/tensorflow.c b/src/backends/tensorflow.c index 12820d98b..a681e326a 100644 --- a/src/backends/tensorflow.c +++ b/src/backends/tensorflow.c @@ -390,7 +390,9 @@ RAI_Model *RAI_ModelCreateTF(RAI_Backend backend, const char *devicestr, RAI_Mod ret->session = session; ret->backend = backend; ret->devicestr = RedisModule_Strdup(devicestr); + ret->ninputs = ninputs; ret->inputs = inputs_; + ret->noutputs = noutputs; ret->outputs = outputs_; ret->opts = opts; ret->refCount = 1; diff --git a/src/backends/tflite.c b/src/backends/tflite.c index e006d3894..3bdbe2d1f 100644 --- a/src/backends/tflite.c +++ b/src/backends/tflite.c @@ -18,9 +18,10 @@ int RAI_InitBackendTFLite(int (*get_api_fn)(const char *, void *)) { RAI_Model *RAI_ModelCreateTFLite(RAI_Backend backend, const char *devicestr, RAI_ModelOpts opts, const char *modeldef, size_t modellen, RAI_Error *error) { DLDeviceType dl_device; - RAI_Device device; int64_t deviceid; + char** inputs_ = NULL; + char** outputs_ = NULL; if (!parseDeviceStr(devicestr, &device, &deviceid)) { RAI_SetError(error, RAI_EMODELCONFIGURE, "ERR Unsupported device"); return NULL; } @@ -48,6 +49,35 @@ RAI_Model *RAI_ModelCreateTFLite(RAI_Backend backend, const char *devicestr, RAI return NULL; } + size_t ninputs = tfliteModelNumInputs(model, &error_descr, RedisModule_Alloc); + if(error_descr) { + goto cleanup; + } + + size_t noutputs = tfliteModelNumOutputs(model, &error_descr, RedisModule_Alloc); + if(error_descr) { + goto cleanup; + } + + inputs_ = array_new(char*, ninputs); + outputs_ = array_new(char*, noutputs); + + for (size_t i = 0; i < ninputs; i++) { + const char* input = tfliteModelInputNameAtIndex(model, i, &error_descr, RedisModule_Alloc); + if(error_descr) { + goto cleanup; + } + inputs_ = array_append(inputs_, RedisModule_Strdup(input)); + } + + for (size_t i = 0; i < noutputs; i++) { + const char* output = tfliteModelOutputNameAtIndex(model, i, &error_descr, RedisModule_Alloc);; + if(error_descr) { + goto cleanup; + } + outputs_ = array_append(outputs_, RedisModule_Strdup(output)); + } + char *buffer = RedisModule_Calloc(modellen, sizeof(*buffer)); memcpy(buffer,
modeldef, modellen); @@ -56,20 +86,51 @@ RAI_Model *RAI_ModelCreateTFLite(RAI_Backend backend, const char *devicestr, RAI ret->session = NULL; ret->backend = backend; ret->devicestr = RedisModule_Strdup(devicestr); - ret->inputs = NULL; - ret->outputs = NULL; + ret->ninputs = ninputs; + ret->inputs = inputs_; + ret->noutputs = noutputs; + ret->outputs = outputs_; ret->refCount = 1; ret->opts = opts; ret->data = buffer; ret->datalen = modellen; - return ret; + +cleanup: + RAI_SetError(error, RAI_EMODELCREATE, error_descr); + RedisModule_Free(error_descr); + if(inputs_) { + ninputs = array_len(inputs_); + for(size_t i =0 ; i < ninputs; i++) { + RedisModule_Free(inputs_[i]); + } + array_free(inputs_); + } + if(outputs_) { + noutputs = array_len(outputs_); + for(size_t i =0 ; i < noutputs; i++) { + RedisModule_Free(outputs_[i]); + } + array_free(outputs_); + } + return NULL; } void RAI_ModelFreeTFLite(RAI_Model *model, RAI_Error *error) { RedisModule_Free(model->data); RedisModule_Free(model->devicestr); tfliteDeallocContext(model->model); + size_t ninputs = model->ninputs; + for(size_t i =0 ; i < ninputs; i++) { + RedisModule_Free(model->inputs[i]); + } + array_free(model->inputs); + + size_t noutputs = model->noutputs; + for(size_t i =0 ; i < noutputs; i++) { + RedisModule_Free(model->outputs[i]); + } + array_free(model->outputs); model->model = NULL; } diff --git a/src/backends/torch.c b/src/backends/torch.c index f1d67a0c3..3b84e508c 100644 --- a/src/backends/torch.c +++ b/src/backends/torch.c @@ -22,6 +22,9 @@ RAI_Model *RAI_ModelCreateTorch(RAI_Backend backend, const char *devicestr, RAI_ RAI_Device device = RAI_DEVICE_CPU; int64_t deviceid = 0; + char** inputs_ = NULL; + char** outputs_ = NULL; + if (!parseDeviceStr(devicestr, &device, &deviceid)) { RAI_SetError(error, RAI_EMODELCONFIGURE, "ERR unsupported device"); return NULL; @@ -53,7 +56,7 @@ RAI_Model *RAI_ModelCreateTorch(RAI_Backend backend, const char *devicestr, RAI_ if (opts.backends_intra_op_parallelism > 0) { torchSetIntraOpThreads(opts.backends_intra_op_parallelism, &error_descr, RedisModule_Alloc); } - if (error_descr != NULL) { + if (error_descr) { RAI_SetError(error, RAI_EMODELCREATE, error_descr); RedisModule_Free(error_descr); return NULL; @@ -62,28 +65,76 @@ RAI_Model *RAI_ModelCreateTorch(RAI_Backend backend, const char *devicestr, RAI_ void *model = torchLoadModel(modeldef, modellen, dl_device, deviceid, &error_descr, RedisModule_Alloc); - if (model == NULL) { - RAI_SetError(error, RAI_EMODELCREATE, error_descr); - RedisModule_Free(error_descr); - return NULL; + if (error_descr) { + goto cleanup; + } + + size_t ninputs = torchModelNumInputs(model, &error_descr); + if(error_descr) { + goto cleanup; + } + + size_t noutputs = torchModelNumOutputs(model, &error_descr); + if(error_descr) { + goto cleanup; + } + + inputs_ = array_new(char*, ninputs); + outputs_ = array_new(char*, noutputs); + + for (size_t i = 0; i < ninputs; i++) { + const char* input = torchModelInputNameAtIndex(model, i, &error_descr); + if(error_descr) { + goto cleanup; + } + inputs_ = array_append(inputs_, RedisModule_Strdup(input)); + } + + for (size_t i = 0; i < noutputs; i++) { + const char* output =""; + if(error_descr) { + goto cleanup; + } + outputs_ = array_append(outputs_, RedisModule_Strdup(output)); } char *buffer = RedisModule_Calloc(modellen, sizeof(*buffer)); memcpy(buffer, modeldef, modellen); + RAI_Model *ret = RedisModule_Calloc(1, sizeof(*ret)); ret->model = model; ret->session = NULL; ret->backend = backend; ret->devicestr = 
RedisModule_Strdup(devicestr); - ret->inputs = NULL; - ret->outputs = NULL; + ret->ninputs = ninputs; + ret->inputs = inputs_; + ret->noutputs = noutputs; + ret->outputs = outputs_; ret->opts = opts; ret->refCount = 1; ret->data = buffer; ret->datalen = modellen; - return ret; + +cleanup: + RAI_SetError(error, RAI_EMODELCREATE, error_descr); + RedisModule_Free(error_descr); + if(inputs_) { + ninputs = array_len(inputs_); + for(size_t i =0 ; i < ninputs; i++) { + RedisModule_Free(inputs_[i]); + } + array_free(inputs_); + } + if(outputs_) { + noutputs = array_len(outputs_); + for(size_t i =0 ; i < noutputs; i++) { + RedisModule_Free(outputs_[i]); + } + array_free(outputs_); + } + return NULL; } void RAI_ModelFreeTorch(RAI_Model *model, RAI_Error *error) { @@ -93,6 +144,18 @@ void RAI_ModelFreeTorch(RAI_Model *model, RAI_Error *error) { if (model->data) { RedisModule_Free(model->data); } + size_t ninputs = model->ninputs; + for(size_t i =0 ; i < ninputs; i++) { + RedisModule_Free(model->inputs[i]); + } + array_free(model->inputs); + + size_t noutputs = model->noutputs; + for(size_t i =0 ; i < noutputs; i++) { + RedisModule_Free(model->outputs[i]); + } + array_free(model->outputs); + torchDeallocContext(model->model); } diff --git a/src/command_parser.c b/src/command_parser.c index 2c92c14f8..0caf9bc14 100644 --- a/src/command_parser.c +++ b/src/command_parser.c @@ -71,15 +71,13 @@ static int _ModelRunCommand_ParseArgs(RedisModuleCtx *ctx, RedisModuleString **a } if ((*model)->inputs && (*model)->ninputs != ninputs) { RAI_SetError(error, RAI_EMODELRUN, - "Number of names given as INPUTS during MODELSET and keys given as " - "INPUTS here do not match"); + "Number of keys given as INPUTS here does not match model definition"); return REDISMODULE_ERR; } if ((*model)->outputs && (*model)->noutputs != noutputs) { RAI_SetError(error, RAI_EMODELRUN, - "Number of names given as OUTPUTS during MODELSET and keys given as " - "OUTPUTS here do not match"); + "Number of keys given as OUTPUTS here does not match model definition"); return REDISMODULE_ERR; } return REDISMODULE_OK; diff --git a/src/libtflite_c/tflite_c.cpp b/src/libtflite_c/tflite_c.cpp index 637cfd469..f7c8d2dab 100644 --- a/src/libtflite_c/tflite_c.cpp +++ b/src/libtflite_c/tflite_c.cpp @@ -261,6 +261,56 @@ extern "C" void *tfliteLoadModel(const char *graph, size_t graphlen, DLDeviceTyp return ctx; } +extern "C" size_t tfliteModelNumInputs(void* ctx, char** error, void *(*alloc)(size_t)) { + ModelContext *ctx_ = (ModelContext*) ctx; + size_t ret = 0; + try { + auto interpreter = ctx_->interpreter; + ret = interpreter->inputs().size(); + } + catch(std::exception ex) { + setError(ex.what(), error, alloc); + } + return ret; +} + +extern "C" const char* tfliteModelInputNameAtIndex(void* modelCtx, size_t index, char** error, void *(*alloc)(size_t)) { + ModelContext *ctx_ = (ModelContext*) modelCtx; + const char* ret = NULL; + try { + ret = ctx_->interpreter->GetInputName(index); + } + catch(std::exception ex) { + setError(ex.what(), error, alloc); + } + return ret; +} + +extern "C" size_t tfliteModelNumOutputs(void* ctx, char** error, void *(*alloc)(size_t)) { + ModelContext *ctx_ = (ModelContext*) ctx; + size_t ret = 0; + try { + auto interpreter = ctx_->interpreter; + ret = interpreter->outputs().size(); + } + catch(std::exception ex) { + setError(ex.what(), error, alloc); + } + return ret; +} + +extern "C" const char* tfliteModelOutputNameAtIndex(void* modelCtx, size_t index, char** error, void *(*alloc)(size_t)) { + ModelContext *ctx_ = 
(ModelContext*) modelCtx; + const char* ret = NULL; + try { + ret = ctx_->interpreter->GetOutputName(index); + } + catch(std::exception ex) { + setError(ex.what(), error, alloc); + } + return ret; +} + extern "C" void tfliteRunModel(void *ctx, long n_inputs, DLManagedTensor **inputs, long n_outputs, DLManagedTensor **outputs, char **error, void *(*alloc)(size_t)) { ModelContext *ctx_ = (ModelContext *)ctx; diff --git a/src/libtflite_c/tflite_c.h b/src/libtflite_c/tflite_c.h index 147bbf216..55a9e37ba 100644 --- a/src/libtflite_c/tflite_c.h +++ b/src/libtflite_c/tflite_c.h @@ -20,6 +20,14 @@ void tfliteSerializeModel(void *ctx, char **buffer, size_t *len, char **error, void tfliteDeallocContext(void *ctx); +size_t tfliteModelNumInputs(void* ctx, char** error, void *(*alloc)(size_t)); + +const char* tfliteModelInputNameAtIndex(void* modelCtx, size_t index, char** error, void *(*alloc)(size_t)); + +size_t tfliteModelNumOutputs(void* ctx, char** error, void *(*alloc)(size_t)); + +const char* tfliteModelOutputNameAtIndex(void* modelCtx, size_t index, char** error, void *(*alloc)(size_t)); + #ifdef __cplusplus } #endif diff --git a/src/libtorch_c/torch_c.cpp b/src/libtorch_c/torch_c.cpp index 228d9ed9b..94a1216fd 100644 --- a/src/libtorch_c/torch_c.cpp +++ b/src/libtorch_c/torch_c.cpp @@ -437,3 +437,69 @@ extern "C" void torchSetIntraOpThreads(int num_threads, char **error, void *(*al } } } + +extern "C" size_t torchModelNumInputs(void *modelCtx, char** error) { + ModuleContext *ctx = (ModuleContext *)modelCtx; + size_t ninputs = 0; + try { + const c10::FunctionSchema& schema = ctx->module->get_method("forward").function().getSchema(); + // First argument is `self` + ninputs = schema.arguments().size() - 1; + } + catch(std::exception ex) { + int printed = asprintf(error, "Error while trying to retrieve model inputs number: %s", ex.what()); + } + return ninputs; +} + +static int getArgumentTensorCount(const c10::Argument& arg){ + switch (arg.type()->kind()) + { + case c10::TypeKind::TensorType: + return 1; + break; + case c10::TypeKind::TupleType: { + int count = 0; + for(auto const& obj: arg.type()->containedTypes()) { + if(obj->kind() == c10::TypeKind::TensorType) { + count++; + } + } + return count; + } + case c10::TypeKind::ListType: { + return arg.N().value(); + } + + default: + return 0; + } +} + +extern "C" size_t torchModelNumOutputs(void *modelCtx, char** error) { + ModuleContext *ctx = (ModuleContext *)modelCtx; + size_t noutputs = 0; + try { + const c10::FunctionSchema& schema = ctx->module->get_method("forward").function().getSchema(); + for (auto const& arg :schema.returns()){ + noutputs += getArgumentTensorCount(arg); + } + } + catch(std::exception ex) { + int printed = asprintf(error, "Error while trying to retrieve model outputs number: %s", ex.what()); + } + return noutputs; +} + +extern "C" const char* torchModelInputNameAtIndex(void* modelCtx, size_t index, char** error) { + ModuleContext *ctx = (ModuleContext *)modelCtx; + const char* ret = NULL; + try { + const c10::FunctionSchema& schema = ctx->module->get_method("forward").function().getSchema(); + ret = schema.arguments()[index + 1].name().c_str(); + } + catch(std::exception ex) { + int printed = asprintf(error, "Error while trying to retrieve model input at index %zu: %s", index, ex.what()); + } + return ret; +} diff --git a/src/libtorch_c/torch_c.h b/src/libtorch_c/torch_c.h index 96a111129..ae9b99890 100644 --- a/src/libtorch_c/torch_c.h +++ b/src/libtorch_c/torch_c.h @@ -34,6 +34,13 @@ void torchSetInterOpThreads(int
num_threads, char **error, void *(*alloc)(size_t void torchSetIntraOpThreads(int num_threadsm, char **error, void *(*alloc)(size_t)); +size_t torchModelNumInputs(void *modelCtx, char **error); + +const char* torchModelInputNameAtIndex(void* modelCtx, size_t index, char** error); + +size_t torchModelNumOutputs(void *modelCtx, char** error); + + #ifdef __cplusplus } #endif diff --git a/src/model.c b/src/model.c index e9fd53a59..477043906 100644 --- a/src/model.c +++ b/src/model.c @@ -334,8 +334,6 @@ RAI_Model *RAI_ModelCreate(RAI_Backend backend, const char *devicestr, RedisModu } else { model->tag = RedisModule_CreateString(NULL, "", 0); } - model->ninputs = ninputs; - model->noutputs = noutputs; } return model; diff --git a/tests/flow/tests_onnx.py b/tests/flow/tests_onnx.py index 1aea211ee..6e0fe7362 100644 --- a/tests/flow/tests_onnx.py +++ b/tests/flow/tests_onnx.py @@ -38,9 +38,8 @@ def test_onnx_modelrun_mnist(env): ret = con.execute_command('AI.MODELGET', 'm{1}', 'META') env.assertEqual(len(ret), 14) env.assertEqual(ret[5], b'') - # assert there are no inputs or outputs - env.assertEqual(len(ret[11]), 0) - env.assertEqual(len(ret[13]), 0) + env.assertEqual(len(ret[11]), 1) + env.assertEqual(len(ret[13]), 1) ret = con.execute_command('AI.MODELSET', 'm{1}', 'ONNX', DEVICE, 'TAG', 'version:2', 'BLOB', model_pb) env.assertEqual(ret, b'OK') @@ -54,9 +53,8 @@ def test_onnx_modelrun_mnist(env): env.assertEqual(ret[1], b'ONNX') env.assertEqual(ret[3], b'CPU') env.assertEqual(ret[5], b'version:2') - # assert there are no inputs or outputs - env.assertEqual(len(ret[11]), 0) - env.assertEqual(len(ret[13]), 0) + env.assertEqual(len(ret[11]), 1) + env.assertEqual(len(ret[13]), 1) try: con.execute_command('AI.MODELSET', 'm{1}', 'ONNX', DEVICE, 'BLOB', wrong_model_pb) @@ -121,7 +119,7 @@ def test_onnx_modelrun_mnist(env): except Exception as e: exception = e env.assertEqual(type(exception), redis.exceptions.ResponseError) - env.assertEqual("tensor key is empty", str(exception)) + env.assertEqual("Number of keys given as INPUTS here does not match model definition", str(exception)) try: con.execute_command('AI.MODELRUN', 'm_1{1}', 'INPUTS', 'OUTPUTS') @@ -143,7 +141,7 @@ def test_onnx_modelrun_mnist(env): except Exception as e: exception = e env.assertEqual(type(exception), redis.exceptions.ResponseError) - env.assertEqual('Expected 1 inputs but got 2', str(exception)) + env.assertEqual('Number of keys given as INPUTS here does not match model definition', str(exception)) con.execute_command('AI.MODELRUN', 'm{1}', 'INPUTS', 'a{1}', 'OUTPUTS', 'b{1}') @@ -214,9 +212,8 @@ def test_onnx_modelrun_mnist_autobatch(env): env.assertEqual(ret[5], b'') env.assertEqual(ret[7], 2) env.assertEqual(ret[9], 2) - # assert there are no inputs or outputs - env.assertEqual(len(ret[11]), 0) - env.assertEqual(len(ret[13]), 0) + env.assertEqual(len(ret[11]), 1) + env.assertEqual(len(ret[13]), 1) con.execute_command('AI.TENSORSET', 'a{1}', 'FLOAT', 1, 1, 28, 28, 'BLOB', sample_raw) con.execute_command('AI.TENSORSET', 'c{1}', 'FLOAT', 1, 1, 28, 28, 'BLOB', sample_raw) diff --git a/tests/flow/tests_pytorch.py b/tests/flow/tests_pytorch.py index 3ba4ea527..212b8b858 100644 --- a/tests/flow/tests_pytorch.py +++ b/tests/flow/tests_pytorch.py @@ -83,8 +83,8 @@ def test_pytorch_modelrun(env): env.assertEqual(ret[7], 0) env.assertEqual(ret[9], 0) # assert there are no inputs or outputs - env.assertEqual(len(ret[11]), 0) - env.assertEqual(len(ret[13]), 0) + env.assertEqual(len(ret[11]), 2) + env.assertEqual(len(ret[13]), 1) ret = 
con.execute_command('AI.MODELSET', 'm{1}', 'TORCH', DEVICE, 'TAG', 'my:tag:v3', 'BLOB', model_pb) env.assertEqual(ret, b'OK') diff --git a/tests/flow/tests_tensorflow.py b/tests/flow/tests_tensorflow.py index ecbe4f146..98c65b33f 100644 --- a/tests/flow/tests_tensorflow.py +++ b/tests/flow/tests_tensorflow.py @@ -458,7 +458,7 @@ def test_run_tf_model_errors(env): except Exception as e: exception = e env.assertEqual(type(exception), redis.exceptions.ResponseError) - env.assertEqual("Number of names given as OUTPUTS during MODELSET and keys given as OUTPUTS here do not match", exception.__str__()) + env.assertEqual("Number of keys given as OUTPUTS here does not match model definition", exception.__str__()) try: con.execute_command('AI.MODELRUN', 'm{1}', 'OUTPUTS', 'c{1}') diff --git a/tests/flow/tests_tflite.py b/tests/flow/tests_tflite.py index cb72d15c3..e03c0f56e 100644 --- a/tests/flow/tests_tflite.py +++ b/tests/flow/tests_tflite.py @@ -120,7 +120,7 @@ def test_run_tflite_model_errors(env): except Exception as e: exception = e env.assertEqual(type(exception), redis.exceptions.ResponseError) - env.assertEqual("tensor key is empty", exception.__str__()) + env.assertEqual("Number of keys given as OUTPUTS here does not match model definition", exception.__str__()) try: con.execute_command('AI.MODELRUN', 'm_2{1}') @@ -141,7 +141,7 @@ def test_run_tflite_model_errors(env): except Exception as e: exception = e env.assertEqual(type(exception), redis.exceptions.ResponseError) - env.assertEqual("tensor key is empty", exception.__str__()) + env.assertEqual("Number of keys given as INPUTS here does not match model definition", exception.__str__()) try: con.execute_command('AI.MODELRUN', 'm_2{1}', 'a{1}', 'b{1}', 'c{1}') @@ -169,28 +169,28 @@ def test_run_tflite_model_errors(env): except Exception as e: exception = e env.assertEqual(type(exception), redis.exceptions.ResponseError) - env.assertEqual("tensor key is empty", exception.__str__()) + env.assertEqual("Number of keys given as INPUTS here does not match model definition", exception.__str__()) try: con.execute_command('AI.MODELRUN', 'm{1}', 'INPUTS', 'OUTPUTS') except Exception as e: exception = e env.assertEqual(type(exception), redis.exceptions.ResponseError) - env.assertEqual("Inconsistent number of inputs", exception.__str__()) + env.assertEqual("Number of keys given as INPUTS here does not match model definition", exception.__str__()) try: con.execute_command('AI.MODELRUN', 'm{1}', 'INPUTS', 'a{1}', 'OUTPUTS') except Exception as e: exception = e env.assertEqual(type(exception), redis.exceptions.ResponseError) - env.assertEqual("Inconsistent number of outputs", exception.__str__()) + env.assertEqual("Number of keys given as OUTPUTS here does not match model definition", exception.__str__()) try: con.execute_command('AI.MODELRUN', 'm{1}', 'INPUTS', 'a{1}', 'OUTPUTS', 'b{1}') except Exception as e: exception = e env.assertEqual(type(exception), redis.exceptions.ResponseError) - env.assertEqual("Inconsistent number of outputs", exception.__str__()) + env.assertEqual("Number of keys given as OUTPUTS here does not match model definition", exception.__str__()) # TODO: Autobatch is tricky with TFLITE because TFLITE expects a fixed batch From 21cac0c8cf3775217b2ec07be736bb104731bb67 Mon Sep 17 00:00:00 2001 From: DvirDukhan Date: Sun, 3 Jan 2021 21:55:29 +0200 Subject: [PATCH 2/6] make format --- src/backends/onnxruntime.c | 22 ++++++++++----------- src/backends/tflite.c | 38 +++++++++++++++++++------------------ src/backends/torch.c | 39 
+++++++++++++++++++++++++++++++++++++++++++------------------- src/libtflite_c/tflite_c.h | 10 ++++---- src/libtorch_c/torch_c.h | 5 ++--- 5 files changed, 58 insertions(+), 56 deletions(-) diff --git a/src/backends/onnxruntime.c b/src/backends/onnxruntime.c index e5bf2bc48..0402c4b88 100644 --- a/src/backends/onnxruntime.c +++ b/src/backends/onnxruntime.c @@ -288,8 +288,8 @@ RAI_Model *RAI_ModelCreateORT(RAI_Backend backend, const char *devicestr, RAI_Mo RAI_Device device; int64_t deviceid; - char** inputs_ = NULL; - char** outputs_ = NULL; + char **inputs_ = NULL; + char **outputs_ = NULL; if (!parseDeviceStr(devicestr, &device, &deviceid)) { RAI_SetError(error, RAI_EMODELCREATE, "ERR unsupported device"); @@ -369,9 +369,9 @@ RAI_Model *RAI_ModelCreateORT(RAI_Backend backend, const char *devicestr, RAI_Mo OrtAllocator *allocator; status = ort->GetAllocatorWithDefaultOptions(&allocator); - inputs_ = array_new(char*, n_input_nodes); + inputs_ = array_new(char *, n_input_nodes); for (long long i = 0; i < n_input_nodes; i++) { - char* input_name; + char *input_name; status = ort->SessionGetInputName(session, i, allocator, &input_name); if (status != NULL) { goto error; @@ -381,7 +381,7 @@ RAI_Model *RAI_ModelCreateORT(RAI_Backend backend, const char *devicestr, RAI_Mo outputs_ = array_new(char *, n_output_nodes); for (long long i = 0; i < n_output_nodes; i++) { - char* output_name; + char *output_name; status = ort->SessionGetOutputName(session, i, allocator, &output_name); if (status != NULL) { goto error; @@ -413,16 +413,16 @@ RAI_Model *RAI_ModelCreateORT(RAI_Backend backend, const char *devicestr, RAI_Mo error: RAI_SetError(error, RAI_EMODELCREATE, ort->GetErrorMessage(status)); - if(inputs_) { + if (inputs_) { n_input_nodes = array_len(inputs_); - for(uint32_t i = 0; i < n_input_nodes; i++) { + for (uint32_t i = 0; i < n_input_nodes; i++) { status = ort->AllocatorFree(allocator, inputs_[i]); } array_free(inputs_); } - if(outputs_){ + if (outputs_) { n_output_nodes = array_len(outputs_); - for(uint32_t i = 0; i < n_output_nodes; i++) { + for (uint32_t i = 0; i < n_output_nodes; i++) { status = ort->AllocatorFree(allocator, outputs_[i]); } array_free(outputs_); } @@ -439,12 +439,12 @@ void RAI_ModelFreeORT(RAI_Model *model, RAI_Error *error) { OrtAllocator *allocator; OrtStatus *status = NULL; status = ort->GetAllocatorWithDefaultOptions(&allocator); - for(uint32_t i = 0; i < model->ninputs; i++) { + for (uint32_t i = 0; i < model->ninputs; i++) { status = ort->AllocatorFree(allocator, model->inputs[i]); } array_free(model->inputs); - for(uint32_t i = 0; i < model->noutputs; i++) { + for (uint32_t i = 0; i < model->noutputs; i++) { status = ort->AllocatorFree(allocator, model->outputs[i]); } array_free(model->outputs); diff --git a/src/backends/tflite.c b/src/backends/tflite.c index 3bdbe2d1f..82a5fbdc0 100644 --- a/src/backends/tflite.c +++ b/src/backends/tflite.c @@ -20,8 +20,8 @@ RAI_Model *RAI_ModelCreateTFLite(RAI_Backend backend, const char *devicestr, RAI DLDeviceType dl_device; RAI_Device device; int64_t deviceid; - char** inputs_ = NULL; - char** outputs_ = NULL; + char **inputs_ = NULL; + char **outputs_ = NULL; if (!parseDeviceStr(devicestr, &device, &deviceid)) { RAI_SetError(error, RAI_EMODELCONFIGURE, "ERR Unsupported device"); return NULL; @@ -50,29 +50,31 @@ RAI_Model *RAI_ModelCreateTFLite(RAI_Backend backend, const char *devicestr, RAI } size_t ninputs = tfliteModelNumInputs(model, &error_descr, RedisModule_Alloc); - if(error_descr) { + if (error_descr) { goto cleanup; } size_t noutputs = tfliteModelNumOutputs(model, &error_descr, RedisModule_Alloc); - if(error_descr) { - goto cleanup; + if (error_descr) { + goto cleanup; } - inputs_ = array_new(char*, ninputs); -
outputs_ = array_new(char*, noutputs); + inputs_ = array_new(char *, ninputs); + outputs_ = array_new(char *, noutputs); for (size_t i = 0; i < ninputs; i++) { - const char* input = tfliteModelInputNameAtIndex(model, i, &error_descr, RedisModule_Alloc); - if(error_descr) { + const char *input = tfliteModelInputNameAtIndex(model, i, &error_descr, RedisModule_Alloc); + if (error_descr) { goto cleanup; } inputs_ = array_append(inputs_, RedisModule_Strdup(input)); } - for (size_t i = 0; i < noutputs; i++) { - const char* output = tfliteModelOutputNameAtIndex(model, i, &error_descr, RedisModule_Alloc);; - if(error_descr) { + for (size_t i = 0; i < noutputs; i++) { + const char *output = + tfliteModelOutputNameAtIndex(model, i, &error_descr, RedisModule_Alloc); + ; + if (error_descr) { goto cleanup; } outputs_ = array_append(outputs_, RedisModule_Strdup(output)); @@ -99,16 +101,16 @@ RAI_Model *RAI_ModelCreateTFLite(RAI_Backend backend, const char *devicestr, RAI cleanup: RAI_SetError(error, RAI_EMODELCREATE, error_descr); RedisModule_Free(error_descr); - if(inputs_) { + if (inputs_) { ninputs = array_len(inputs_); - for(size_t i =0 ; i < ninputs; i++) { + for (size_t i = 0; i < ninputs; i++) { RedisModule_Free(inputs_[i]); } array_free(inputs_); } - if(outputs_) { + if (outputs_) { noutputs = array_len(outputs_); - for(size_t i =0 ; i < noutputs; i++) { + for (size_t i = 0; i < noutputs; i++) { RedisModule_Free(outputs_[i]); } array_free(outputs_); @@ -121,13 +123,13 @@ void RAI_ModelFreeTFLite(RAI_Model *model, RAI_Error *error) { RedisModule_Free(model->devicestr); tfliteDeallocContext(model->model); size_t ninputs = model->ninputs; - for(size_t i =0 ; i < ninputs; i++) { + for (size_t i = 0; i < ninputs; i++) { RedisModule_Free(model->inputs[i]); } array_free(model->inputs); size_t noutputs = model->noutputs; - for(size_t i =0 ; i < noutputs; i++) { + for (size_t i = 0; i < noutputs; i++) { RedisModule_Free(model->outputs[i]); } array_free(model->outputs); diff --git a/src/backends/torch.c b/src/backends/torch.c index 3b84e508c..319367c3b 100644 --- a/src/backends/torch.c +++ b/src/backends/torch.c @@ -22,8 +22,8 @@ RAI_Model *RAI_ModelCreateTorch(RAI_Backend backend, const char *devicestr, RAI_ RAI_Device device = RAI_DEVICE_CPU; int64_t deviceid = 0; - char** inputs_ = NULL; - char** outputs_ = NULL; + char **inputs_ = NULL; + char **outputs_ = NULL; if (!parseDeviceStr(devicestr, &device, &deviceid)) { RAI_SetError(error, RAI_EMODELCONFIGURE, "ERR unsupported device"); @@ -66,33 +66,33 @@ RAI_Model *RAI_ModelCreateTorch(RAI_Backend backend, const char *devicestr, RAI_ torchLoadModel(modeldef, modellen, dl_device, deviceid, &error_descr, RedisModule_Alloc); if (error_descr) { - goto cleanup; + goto cleanup; } size_t ninputs = torchModelNumInputs(model, &error_descr); - if(error_descr) { + if (error_descr) { goto cleanup; } size_t noutputs = torchModelNumOutputs(model, &error_descr); - if(error_descr) { - goto cleanup; + if (error_descr) { + goto cleanup; } - inputs_ = array_new(char*, ninputs); - outputs_ = array_new(char*, noutputs); + inputs_ = array_new(char *, ninputs); + outputs_ = array_new(char *, noutputs); for (size_t i = 0; i < ninputs; i++) { - const char* input = torchModelInputNameAtIndex(model, i, &error_descr); - if(error_descr) { + const char *input = torchModelInputNameAtIndex(model, i, &error_descr); + if (error_descr) { goto cleanup; } inputs_ = array_append(inputs_, RedisModule_Strdup(input)); } - for (size_t i = 0; i < noutputs; i++) { - const char* output =""; - 
if(error_descr) { + for (size_t i = 0; i < noutputs; i++) { + const char *output = ""; + if (error_descr) { goto cleanup; } outputs_ = array_append(outputs_, RedisModule_Strdup(output)); @@ -101,7 +101,6 @@ RAI_Model *RAI_ModelCreateTorch(RAI_Backend backend, const char *devicestr, RAI_ char *buffer = RedisModule_Calloc(modellen, sizeof(*buffer)); memcpy(buffer, modeldef, modellen); - RAI_Model *ret = RedisModule_Calloc(1, sizeof(*ret)); ret->model = model; ret->session = NULL; @@ -120,16 +119,16 @@ RAI_Model *RAI_ModelCreateTorch(RAI_Backend backend, const char *devicestr, RAI_ cleanup: RAI_SetError(error, RAI_EMODELCREATE, error_descr); RedisModule_Free(error_descr); - if(inputs_) { + if (inputs_) { ninputs = array_len(inputs_); - for(size_t i =0 ; i < ninputs; i++) { + for (size_t i = 0; i < ninputs; i++) { RedisModule_Free(inputs_[i]); } array_free(inputs_); } - if(outputs_) { + if (outputs_) { noutputs = array_len(outputs_); - for(size_t i =0 ; i < noutputs; i++) { + for (size_t i = 0; i < noutputs; i++) { RedisModule_Free(outputs_[i]); } array_free(outputs_); @@ -145,13 +144,13 @@ void RAI_ModelFreeTorch(RAI_Model *model, RAI_Error *error) { RedisModule_Free(model->data); } size_t ninputs = model->ninputs; - for(size_t i =0 ; i < ninputs; i++) { + for (size_t i = 0; i < ninputs; i++) { RedisModule_Free(model->inputs[i]); } array_free(model->inputs); size_t noutputs = model->noutputs; - for(size_t i =0 ; i < noutputs; i++) { + for (size_t i = 0; i < noutputs; i++) { RedisModule_Free(model->outputs[i]); } array_free(model->outputs); diff --git a/src/libtflite_c/tflite_c.h b/src/libtflite_c/tflite_c.h index 55a9e37ba..bab455a44 100644 --- a/src/libtflite_c/tflite_c.h +++ b/src/libtflite_c/tflite_c.h @@ -20,13 +20,15 @@ void tfliteSerializeModel(void *ctx, char **buffer, size_t *len, char **error, void tfliteDeallocContext(void *ctx); -size_t tfliteModelNumInputs(void* ctx, char** error, void *(*alloc)(size_t)); +size_t tfliteModelNumInputs(void *ctx, char **error, void *(*alloc)(size_t)); -const char* tfliteModelInputNameAtIndex(void* modelCtx, size_t index, char** error, void *(*alloc)(size_t)); +const char *tfliteModelInputNameAtIndex(void *modelCtx, size_t index, char **error, + void *(*alloc)(size_t)); -size_t tfliteModelNumOutputs(void* ctx, char** error, void *(*alloc)(size_t)); +size_t tfliteModelNumOutputs(void *ctx, char **error, void *(*alloc)(size_t)); -const char* tfliteModelOutputNameAtIndex(void* modelCtx, size_t index, char** error, void *(*alloc)(size_t)); +const char *tfliteModelOutputNameAtIndex(void *modelCtx, size_t index, char **error, + void *(*alloc)(size_t)); #ifdef __cplusplus } diff --git a/src/libtorch_c/torch_c.h b/src/libtorch_c/torch_c.h index ae9b99890..45e6bfc99 100644 --- a/src/libtorch_c/torch_c.h +++ b/src/libtorch_c/torch_c.h @@ -36,10 +36,9 @@ void torchSetIntraOpThreads(int num_threadsm, char **error, void *(*alloc)(size_ size_t torchModelNumInputs(void *modelCtx, char **error); -const char* torchModelInputNameAtIndex(void* modelCtx, size_t index, char** error); - -size_t torchModelNumOutputs(void *modelCtx, char** error); +const char *torchModelInputNameAtIndex(void *modelCtx, size_t index, char **error); +size_t torchModelNumOutputs(void *modelCtx, char **error); #ifdef __cplusplus } From 452bf1a814ec7a14a18e5be9dafcce18e56d77a2 Mon Sep 17 00:00:00 2001 From: DvirDukhan Date: Wed, 13 Jan 2021 09:29:19 +0200 Subject: [PATCH 3/6] fixed tests --- tests/flow/test_serializations.py | 8 ++++---- tests/flow/tests_dag.py | 4 ++-- 2 files changed, 6 
insertions(+), 6 deletions(-) diff --git a/tests/flow/test_serializations.py b/tests/flow/test_serializations.py index 765aa00d3..bffd90d6d 100644 --- a/tests/flow/test_serializations.py +++ b/tests/flow/test_serializations.py @@ -78,7 +78,7 @@ def test_v0_torch_model(self): model_rdb = b'\x07\x81\x00\x8f\xff0\xe0\xc4,\x00\x02\x02\x05\x04CPU\x00\x05\x0ePT_MINIMAL_V0\x00\x02\x00\x02\x00\x02\x00\x02\x00\x05\xc3C\x0eEH\x0ePK\x03\x04\x00\x00\x08\x08\x00\x00\x86\xb0zO\x00\xe0\x02\x00\x1a\x12\x00\x10\x00pt-minimal/versionFB\x0c\x00Z\xe0\x02\x00\n1\nPK\x07\x08S\xfcQg\x02 ;@\x03\x00P Q\x00\x14 Q\x00\x08\xe0\x08Q\x02\x1c\x004\xe0\x03Q\x13code/__torch__.pyFB0\xe0\x04[\xe0\x1b\x00\x1f5LK\n\x830\x10\xdd{\x8a\xb7T\xb0\x82\xdb\x80\xbd\x81\xbb\xeeJ\t\xa3\x19\xab\x90fd\x12[z\xfb\x1f\x06\xad\xab\xf7\x7f\xb2\xda7k\\$\xd8\xc8\t\x1d\xdab\xf4\x14#\xfao/n\xf3\\\x1eP\x99\x02\xb0v\x1f%\xa5\x17\xa7\xbc\xb06\x97\xef\x8f\xec&\xa5%,\xe1\t\x83A\xc4g\xc7\xf1\x84I\xf4C\xea\xca\xc8~2\x1fy\x99D\xc7\xd9\xda\xe6\xfc\xads\x0f \x83\x1b\x87(z\xc8\xe1\x94\x15.\xd7?5{\xa2\x9c6\r\xd8_\r\x1ar\xae\xa4\x1aC\r\xf2\xebL][\x15?A\x0b\x04a\xc1#I\x8e!\x0b\x00\xc8 \x03\xe1\x11\x0b\x02&\x00\x1e\xe1\x14\x0b\x0c.debug_pklFB\x1a\xe1\x12\x15\x1f5\x8eA\n\xc20\x10EcU\x90\x82+/0\xcb\x8a%\x07p\xe5V\x06t\xdb\x9d\xa4mB"m\xd3\x1f\xa4\x11q\xe7\xca\x1e\xc7S\xa8\xd72)\xe4m\x06\xde\x87?\xff\x99d\xb8)\x88\xc7\x10I\x90\x8cf\x86\xe1\x1f\xbc\xf0]\x9c\xbd\x05\xcf\xc1i[IzU\x8e\x0e\x95U\xbd\xbb\xb4\xdcI]\xa7!\xac\xb9\x00\xa1\xed\x9d\xd9\x1f:\x1bx#r`9\x94\xdb\xfd\x14\x06,w7\xdb\x01\x83\x1d\x94\xa9I\x8a\xb5o\r\x15\xaaS-kh\x15\xff0s\\\x8df\x81G<\xf9\xb7\x1f\x19\x07|et\xbf\xe8\x9cY\xd2a\x08\x04\xa7\x94\x1a\x02\x97!\x04\x00\xb2 \x03A\x08\xe2\rf\x02\x18\x00#\xe1\x05\x08\x07nstants.`\xfa\x00\x1f\xe0\x12\xfa`\x00\x03\x80\x02).Au\x03m/\tW a\x00\x00@\x03\xe0\x11l\x02\x13\x00;\xe0\x03l\x03data\x80g\x007\xe0\x17g\xe0\x0f\x00\x02\x80\x02c\xe2\x00\xc2\x10\nMyModule\nq\x00)\x81}(X#U\x0f\x00trainingq\x01\x88ubq\x02`\xac\x04z\xb8\x18\x811 \x1b@\x03\x02PK\x01C5#0\x83\x82\xe3\x03J\x00\x12 \x17\xe0\x05\x00\xe3\t\x90\x80?\xe3\x01p\xe2\x03~\x00\x1c\xe0\x04<\x00R \r\xe0\x02?\xe3\x08~\xe0\x07I\xe1\x03\xbf\x00& ;\xe0\x01\x00\x01^\x01\xe0\x15I\xe2\x01\xbc\x80S\xe0\x01\xdd\xe1\x03\xa6\x00\x18 \x17\xe0\x01\x00D?\xe0\x04\x9d\xe2\x02\x07\xe0\x07E\xe1\x03?\x00\x13\xe0\x01B \x00\x00\xd4!K\xe0\x02E\xc1\xe0\x04PK\x06\x06, \x1e@\x00\x02\x1e\x03-@\x06`\x00\x00\x05`\x05\xe0\x01\x07\x00e \xd8@\x00\x01\x81\x03@\x05A\x9c\x01\x06\x07 \x06\x01\x00\xe6BV \x00@\x1e\x03PK\x05\x06 \n ;\x00\x05`/@+\x01\x00\x00\x00\t\x00MQ\xab\x8e\xfdc\x97>' con.restore(key_name, 0, model_rdb, True) _, backend, _, device, _, tag, _, batchsize, _, minbatchsize, _ , inputs, _, outputs = con.execute_command("AI.MODELGET", key_name, "META") - self.env.assertEqual([backend, device, tag, batchsize, minbatchsize, inputs, outputs], [b"TORCH", b"CPU", b"PT_MINIMAL_V0", 0, 0, [], []]) + self.env.assertEqual([backend, device, tag, batchsize, minbatchsize, inputs, outputs], [b"TORCH", b"CPU", b"PT_MINIMAL_V0", 0, 0, [b'a', b'b'], [b'']]) torch_model_run(self.env, key_name) def test_v0_troch_script(self): @@ -96,7 +96,7 @@ def test_v0_onnx_model(self): model_rdb = 
b'\x07\x81\x00\x8f\xff0\xe0\xc4,\x00\x02\x03\x05\x04CPU\x00\x05\x14ONNX_LINEAR_IRIS_V0\x00\x02\x00\x02\x00\x02\x00\x02\x00\x05\xc3@\xe6A\x15\x17\x08\x05\x12\x08skl2onnx\x1a\x051.4.9"\x07ai.@\x0f\x1f(\x002\x00:\xe2\x01\n\x82\x01\n\x0bfloat_input\x12\x08variabl\x12e\x1a\x0fLinearRegressor"\xe0\x07\x10\x1f*%\n\x0ccoefficients=K\xfe\xc2\xbd=\xf7\xbe\x1c\xbd=/ii>=\x12\xe81\x1a?\xa0\x01\x06*\x14\n\nintercep $\x03\xa8\x1d\xb7= \x15\x01:\n\xa0\x88\x1f.ml\x12 2d76caf265cd4138a74199640a1\x06fc408Z\x1d\xe0\x05\xa5\n\x0e\n\x0c\x08\x01\x12\x08\n\x02\x08\x01 \x03\x03\x04b\x1a\n\xe0\x00\xb7\xe0\x06\x1b\x03\x01B\x0e\n\xe0\x02j\x01\x10\x01\x00\t\x00\x04EU\x04\xd8\\\xdb\x99' con.restore(key_name, 0, model_rdb, True) _, backend, _, device, _, tag, _, batchsize, _, minbatchsize, _ , inputs, _, outputs = con.execute_command("AI.MODELGET", key_name, "META") - self.env.assertEqual([backend, device, tag, batchsize, minbatchsize, inputs, outputs], [b"ONNX", b"CPU", b"ONNX_LINEAR_IRIS_V0", 0, 0, [], []]) + self.env.assertEqual([backend, device, tag, batchsize, minbatchsize, inputs, outputs], [b"ONNX", b"CPU", b"ONNX_LINEAR_IRIS_V0", 0, 0, [b'float_input'], [b'variable']]) onnx_model_run(self.env, key_name) def test_v0_tensor(self): @@ -130,7 +130,7 @@ def test_v1_torch_model(self): model_rdb = b'\x07\x81\x00\x8f\xff0\xe0\xc4,\x01\x02\x02\x05\x04CPU\x00\x05\rPT_MINIMAL_V1\x02\x00\x02\x00\x02\x00\x02\x00\x02EH\x02\x01\x05\xc3C\x0eEH\x0ePK\x03\x04\x00\x00\x08\x08\x00\x00\x86\xb0zO\x00\xe0\x02\x00\x1a\x12\x00\x10\x00pt-minimal/versionFB\x0c\x00Z\xe0\x02\x00\n1\nPK\x07\x08S\xfcQg\x02 ;@\x03\x00P Q\x00\x14 Q\x00\x08\xe0\x08Q\x02\x1c\x004\xe0\x03Q\x13code/__torch__.pyFB0\xe0\x04[\xe0\x1b\x00\x1f5LK\n\x830\x10\xdd{\x8a\xb7T\xb0\x82\xdb\x80\xbd\x81\xbb\xeeJ\t\xa3\x19\xab\x90fd\x12[z\xfb\x1f\x06\xad\xab\xf7\x7f\xb2\xda7k\\$\xd8\xc8\t\x1d\xdab\xf4\x14#\xfao/n\xf3\\\x1eP\x99\x02\xb0v\x1f%\xa5\x17\xa7\xbc\xb06\x97\xef\x8f\xec&\xa5%,\xe1\t\x83A\xc4g\xc7\xf1\x84I\xf4C\xea\xca\xc8~2\x1fy\x99D\xc7\xd9\xda\xe6\xfc\xads\x0f \x83\x1b\x87(z\xc8\xe1\x94\x15.\xd7?5{\xa2\x9c6\r\xd8_\r\x1ar\xae\xa4\x1aC\r\xf2\xebL][\x15?A\x0b\x04a\xc1#I\x8e!\x0b\x00\xc8 \x03\xe1\x11\x0b\x02&\x00\x1e\xe1\x14\x0b\x0c.debug_pklFB\x1a\xe1\x12\x15\x1f5\x8eA\n\xc20\x10EcU\x90\x82+/0\xcb\x8a%\x07p\xe5V\x06t\xdb\x9d\xa4mB"m\xd3\x1f\xa4\x11q\xe7\xca\x1e\xc7S\xa8\xd72)\xe4m\x06\xde\x87?\xff\x99d\xb8)\x88\xc7\x10I\x90\x8cf\x86\xe1\x1f\xbc\xf0]\x9c\xbd\x05\xcf\xc1i[IzU\x8e\x0e\x95U\xbd\xbb\xb4\xdcI]\xa7!\xac\xb9\x00\xa1\xed\x9d\xd9\x1f:\x1bx#r`9\x94\xdb\xfd\x14\x06,w7\xdb\x01\x83\x1d\x94\xa9I\x8a\xb5o\r\x15\xaaS-kh\x15\xff0s\\\x8df\x81G<\xf9\xb7\x1f\x19\x07|et\xbf\xe8\x9cY\xd2a\x08\x04\xa7\x94\x1a\x02\x97!\x04\x00\xb2 \x03A\x08\xe2\rf\x02\x18\x00#\xe1\x05\x08\x07nstants.`\xfa\x00\x1f\xe0\x12\xfa`\x00\x03\x80\x02).Au\x03m/\tW a\x00\x00@\x03\xe0\x11l\x02\x13\x00;\xe0\x03l\x03data\x80g\x007\xe0\x17g\xe0\x0f\x00\x02\x80\x02c\xe2\x00\xc2\x10\nMyModule\nq\x00)\x81}(X#U\x0f\x00trainingq\x01\x88ubq\x02`\xac\x04z\xb8\x18\x811 \x1b@\x03\x02PK\x01C5#0\x83\x82\xe3\x03J\x00\x12 \x17\xe0\x05\x00\xe3\t\x90\x80?\xe3\x01p\xe2\x03~\x00\x1c\xe0\x04<\x00R \r\xe0\x02?\xe3\x08~\xe0\x07I\xe1\x03\xbf\x00& ;\xe0\x01\x00\x01^\x01\xe0\x15I\xe2\x01\xbc\x80S\xe0\x01\xdd\xe1\x03\xa6\x00\x18 \x17\xe0\x01\x00D?\xe0\x04\x9d\xe2\x02\x07\xe0\x07E\xe1\x03?\x00\x13\xe0\x01B \x00\x00\xd4!K\xe0\x02E\xc1\xe0\x04PK\x06\x06, \x1e@\x00\x02\x1e\x03-@\x06`\x00\x00\x05`\x05\xe0\x01\x07\x00e \xd8@\x00\x01\x81\x03@\x05A\x9c\x01\x06\x07 \x06\x01\x00\xe6BV \x00@\x1e\x03PK\x05\x06 \n 
;\x00\x05`/@+\x01\x00\x00\x00\t\x00\xa4D\x13\x90\xf6\\x@' con.restore(key_name, 0, model_rdb, True) _, backend, _, device, _, tag, _, batchsize, _, minbatchsize, _ , inputs, _, outputs = con.execute_command("AI.MODELGET", key_name, "META") - self.env.assertEqual([backend, device, tag, batchsize, minbatchsize, inputs, outputs], [b"TORCH", b"CPU", b"PT_MINIMAL_V1", 0, 0, [], []]) + self.env.assertEqual([backend, device, tag, batchsize, minbatchsize, inputs, outputs], [b"TORCH", b"CPU", b"PT_MINIMAL_V1", 0, 0, [b'a', b'b'], [b'']]) torch_model_run(self.env, key_name) @@ -149,7 +149,7 @@ def test_v1_onnx_model(self): model_rdb = b'\x07\x81\x00\x8f\xff0\xe0\xc4,\x01\x02\x03\x05\x04CPU\x00\x05\x13ONNX_LINEAR_IRIS_V1\x02\x00\x02\x00\x02\x00\x02\x00\x02A\x15\x02\x01\x05\xc3@\xe6A\x15\x17\x08\x05\x12\x08skl2onnx\x1a\x051.4.9"\x07ai.@\x0f\x1f(\x002\x00:\xe2\x01\n\x82\x01\n\x0bfloat_input\x12\x08variabl\x12e\x1a\x0fLinearRegressor"\xe0\x07\x10\x1f*%\n\x0ccoefficients=K\xfe\xc2\xbd=\xf7\xbe\x1c\xbd=/ii>=\x12\xe81\x1a?\xa0\x01\x06*\x14\n\nintercep $\x03\xa8\x1d\xb7= \x15\x01:\n\xa0\x88\x1f.ml\x12 2d76caf265cd4138a74199640a1\x06fc408Z\x1d\xe0\x05\xa5\n\x0e\n\x0c\x08\x01\x12\x08\n\x02\x08\x01 \x03\x03\x04b\x1a\n\xe0\x00\xb7\xe0\x06\x1b\x03\x01B\x0e\n\xe0\x02j\x01\x10\x01\x00\t\x00\xd4\x0f\xa0F\x851\xdb\xa0' con.restore(key_name, 0, model_rdb, True) _, backend, _, device, _, tag, _, batchsize, _, minbatchsize, _ , inputs, _, outputs = con.execute_command("AI.MODELGET", key_name, "META") - self.env.assertEqual([backend, device, tag, batchsize, minbatchsize, inputs, outputs], [b"ONNX", b"CPU", b"ONNX_LINEAR_IRIS_V1", 0, 0, [], []]) + self.env.assertEqual([backend, device, tag, batchsize, minbatchsize, inputs, outputs], [b"ONNX", b"CPU", b"ONNX_LINEAR_IRIS_V1", 0, 0, [b'float_input'], [b'variable']]) onnx_model_run(self.env, key_name) def test_v1_tensor(self): diff --git a/tests/flow/tests_dag.py b/tests/flow/tests_dag.py index 87c6d3edd..4d574e116 100644 --- a/tests/flow/tests_dag.py +++ b/tests/flow/tests_dag.py @@ -315,7 +315,7 @@ def test_dag_modelrun_financialNet_errors(env): except Exception as e: exception = e env.assertEqual(type(exception), redis.exceptions.ResponseError) - env.assertEqual("Number of names given as INPUTS during MODELSET and keys given as INPUTS here do not match",exception.__str__()) + env.assertEqual("Number of keys given as INPUTS here does not match model definition",exception.__str__()) def test_dag_local_tensorset(env): @@ -941,7 +941,7 @@ def test_dagrun_modelrun_multidevice_resnet(env): except Exception as e: exception = e env.assertEqual(type(exception), redis.exceptions.ResponseError) - env.assertEqual("Number of names given as INPUTS during MODELSET and keys given as INPUTS here do not match", exception.__str__()) + env.assertEqual("Number of keys given as INPUTS here does not match model definition", exception.__str__()) ret = con.execute_command( 'AI.DAGRUN', From 0f58880faa80fc506ed195c89bcb873ff54b40f3 Mon Sep 17 00:00:00 2001 From: DvirDukhan Date: Wed, 13 Jan 2021 09:38:59 +0200 Subject: [PATCH 4/6] added test for torch model with tuple output --- tests/flow/tests_pytorch.py | 21 +++++++++++++++++++++ 1 file changed, 21 insertions(+) diff --git a/tests/flow/tests_pytorch.py b/tests/flow/tests_pytorch.py index 212b8b858..1116f08f5 100644 --- a/tests/flow/tests_pytorch.py +++ b/tests/flow/tests_pytorch.py @@ -962,3 +962,24 @@ def test_parallelism(): for k in con.execute_command("INFO MODULES").decode().split("#")[3].split()[1:]} 
env.assertEqual(load_time_config["ai_inter_op_parallelism"], "2") env.assertEqual(load_time_config["ai_intra_op_parallelism"], "2") + +def test_modelget_for_tuple_output(env): + if not TEST_PT: + env.debugPrint("skipping {} since TEST_PT=0".format(sys._getframe().f_code.co_name), force=True) + return + con = env.getConnection() + + test_data_path = os.path.join(os.path.dirname(__file__), 'test_data') + model_filename = os.path.join(test_data_path, 'pt-minimal-bb.pt') + with open(model_filename, 'rb') as f: + model_pb = f.read() + ret = con.execute_command('AI.MODELSET', 'm{1}', 'TORCH', DEVICE, 'BLOB', model_pb) + ensureSlaveSynced(con, env) + env.assertEqual(b'OK', ret) + ret = con.execute_command('AI.MODELGET', 'm{1}', 'META') + env.assertEqual(ret[1], b'TORCH') + env.assertEqual(ret[5], b'') + env.assertEqual(ret[7], 0) + env.assertEqual(ret[9], 0) + env.assertEqual(len(ret[11]), 2) + env.assertEqual(len(ret[13]), 2) From 887d7acb61242b97b38ef82ad5583d0281c94ec0 Mon Sep 17 00:00:00 2001 From: DvirDukhan Date: Wed, 13 Jan 2021 09:43:39 +0200 Subject: [PATCH 5/6] fixed PR comment --- src/command_parser.c | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/command_parser.c b/src/command_parser.c index 82989d84f..adc8cc542 100644 --- a/src/command_parser.c +++ b/src/command_parser.c @@ -68,13 +68,13 @@ static int _ModelRunCommand_ParseArgs(RedisModuleCtx *ctx, RedisModuleString **a } } } - if ((*model)->inputs && (*model)->ninputs != ninputs) { + if ((*model)->ninputs != ninputs) { RAI_SetError(error, RAI_EMODELRUN, "Number of keys given as INPUTS here does not match model definition"); return REDISMODULE_ERR; } - if ((*model)->outputs && (*model)->noutputs != noutputs) { + if ((*model)->noutputs != noutputs) { RAI_SetError(error, RAI_EMODELRUN, "Number of keys given as OUTPUTS here does not match model definition"); return REDISMODULE_ERR; From 21b39ccc2169db383e921acc064e444aa4725947 Mon Sep 17 00:00:00 2001 From: DvirDukhan Date: Wed, 13 Jan 2021 12:58:51 +0200 Subject: [PATCH 6/6] fixed tflite files --- src/backends/tflite.c | 9 ++++----- src/libtflite_c/tflite_c.cpp | 16 ++++++++-------- src/libtflite_c/tflite_c.h | 10 ++++------ 3 files changed, 16 insertions(+), 19 deletions(-) diff --git a/src/backends/tflite.c b/src/backends/tflite.c index 865fcd525..3fe88b6e0 100644 --- a/src/backends/tflite.c +++ b/src/backends/tflite.c @@ -48,12 +48,12 @@ RAI_Model *RAI_ModelCreateTFLite(RAI_Backend backend, const char *devicestr, RAI return NULL; } - size_t ninputs = tfliteModelNumInputs(model, &error_descr, RedisModule_Alloc); + size_t ninputs = tfliteModelNumInputs(model, &error_descr); if (error_descr) { goto cleanup; } - size_t noutputs = tfliteModelNumOutputs(model, &error_descr, RedisModule_Alloc); + size_t noutputs = tfliteModelNumOutputs(model, &error_descr); if (error_descr) { goto cleanup; } @@ -62,7 +62,7 @@ RAI_Model *RAI_ModelCreateTFLite(RAI_Backend backend, const char *devicestr, RAI outputs_ = array_new(char *, noutputs); for (size_t i = 0; i < ninputs; i++) { - const char *input = tfliteModelInputNameAtIndex(model, i, &error_descr, RedisModule_Alloc); + const char *input = tfliteModelInputNameAtIndex(model, i, &error_descr); if (error_descr) { goto cleanup; } @@ -70,8 +70,7 @@ RAI_Model *RAI_ModelCreateTFLite(RAI_Backend backend, const char *devicestr, RAI } for (size_t i = 0; i < noutputs; i++) { - const char *output = - tfliteModelOutputNameAtIndex(model, i, &error_descr, RedisModule_Alloc); + const char *output = tfliteModelOutputNameAtIndex(model, i, 
&error_descr); ; if (error_descr) { goto cleanup; diff --git a/src/libtflite_c/tflite_c.cpp b/src/libtflite_c/tflite_c.cpp index 82a490a2b..167beb508 100644 --- a/src/libtflite_c/tflite_c.cpp +++ b/src/libtflite_c/tflite_c.cpp @@ -263,7 +263,7 @@ extern "C" void *tfliteLoadModel(const char *graph, size_t graphlen, DLDeviceTyp return ctx; } -extern "C" size_t tfliteModelNumInputs(void* ctx, char** error, void *(*alloc)(size_t)) { +extern "C" size_t tfliteModelNumInputs(void* ctx, char** error) { ModelContext *ctx_ = (ModelContext*) ctx; size_t ret = 0; try { @@ -271,24 +271,24 @@ extern "C" size_t tfliteModelNumInputs(void* ctx, char** error, void *(*alloc)(s ret = interpreter->inputs().size(); } catch(std::exception ex) { - setError(ex.what(), error, alloc); + _setError(ex.what(), error); } return ret; } -extern "C" const char* tfliteModelInputNameAtIndex(void* modelCtx, size_t index, char** error, void *(*alloc)(size_t)) { +extern "C" const char* tfliteModelInputNameAtIndex(void* modelCtx, size_t index, char** error) { ModelContext *ctx_ = (ModelContext*) modelCtx; const char* ret = NULL; try { ret = ctx_->interpreter->GetInputName(index); } catch(std::exception ex) { - setError(ex.what(), error, alloc); + _setError(ex.what(), error); } return ret; } -extern "C" size_t tfliteModelNumOutputs(void* ctx, char** error, void *(*alloc)(size_t)) { +extern "C" size_t tfliteModelNumOutputs(void* ctx, char** error) { ModelContext *ctx_ = (ModelContext*) ctx; size_t ret = 0; try { @@ -296,19 +296,19 @@ extern "C" size_t tfliteModelNumOutputs(void* ctx, char** error, void *(*alloc)( ret = interpreter->outputs().size(); } catch(std::exception ex) { - setError(ex.what(), error, alloc); + _setError(ex.what(), error); } return ret; } -extern "C" const char* tfliteModelOutputNameAtIndex(void* modelCtx, size_t index, char** error, void *(*alloc)(size_t)) { +extern "C" const char* tfliteModelOutputNameAtIndex(void* modelCtx, size_t index, char** error) { ModelContext *ctx_ = (ModelContext*) modelCtx; const char* ret = NULL; try { ret = ctx_->interpreter->GetOutputName(index); } catch(std::exception ex) { - setError(ex.what(), error, alloc); + _setError(ex.what(), error); } return ret; } diff --git a/src/libtflite_c/tflite_c.h b/src/libtflite_c/tflite_c.h index 16fc0bcc5..3368e6db9 100644 --- a/src/libtflite_c/tflite_c.h +++ b/src/libtflite_c/tflite_c.h @@ -19,15 +19,13 @@ void tfliteSerializeModel(void *ctx, char **buffer, size_t *len, char **error); void tfliteDeallocContext(void *ctx); -size_t tfliteModelNumInputs(void *ctx, char **error, void *(*alloc)(size_t)); +size_t tfliteModelNumInputs(void *ctx, char **error); -const char *tfliteModelInputNameAtIndex(void *modelCtx, size_t index, char **error, - void *(*alloc)(size_t)); +const char *tfliteModelInputNameAtIndex(void *modelCtx, size_t index, char **error); -size_t tfliteModelNumOutputs(void *ctx, char **error, void *(*alloc)(size_t)); +size_t tfliteModelNumOutputs(void *ctx, char **error); -const char *tfliteModelOutputNameAtIndex(void *modelCtx, size_t index, char **error, - void *(*alloc)(size_t)); +const char *tfliteModelOutputNameAtIndex(void *modelCtx, size_t index, char **error); #ifdef __cplusplus }
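
As a closing usage sketch (not part of the patch series itself): after these patches, AI.MODELGET ... META reports the input and output names read from the model definition (ONNX session metadata, TFLite interpreter tensors, the TorchScript forward schema). The hiredis snippet below shows how a client might read them; the key name "m", the localhost connection, and the label-based scan of the reply are assumptions of this sketch.

/* Minimal hiredis client: print the INPUTS/OUTPUTS name arrays that the
 * series exposes through AI.MODELGET META. Assumes RedisAI is loaded and
 * a model was already stored under the (hypothetical) key "m". */
#include <stdio.h>
#include <strings.h>
#include <hiredis/hiredis.h>

int main(void) {
    redisContext *c = redisConnect("127.0.0.1", 6379);
    if (c == NULL || c->err)
        return 1;
    redisReply *reply = redisCommand(c, "AI.MODELGET %s META", "m");
    if (reply != NULL && reply->type == REDIS_REPLY_ARRAY) {
        /* The META reply alternates field labels and values; scanning by
         * label avoids hard-coding the reply indices the flow tests use. */
        for (size_t i = 0; i + 1 < reply->elements; i += 2) {
            redisReply *label = reply->element[i];
            redisReply *value = reply->element[i + 1];
            if (label->type != REDIS_REPLY_STRING)
                continue;
            if (strcasecmp(label->str, "INPUTS") == 0 ||
                strcasecmp(label->str, "OUTPUTS") == 0) {
                printf("%s:", label->str);
                for (size_t j = 0; j < value->elements; j++)
                    printf(" %s", value->element[j]->str);
                printf("\n");
            }
        }
    }
    if (reply != NULL)
        freeReplyObject(reply);
    redisFree(c);
    return 0;
}

Since patch 5 removes the (*model)->inputs guard in command_parser.c, these per-backend name arrays are also what AI.MODELRUN now validates the INPUTS and OUTPUTS key counts against.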