diff --git a/src/backends/onnxruntime.c b/src/backends/onnxruntime.c index 7281e419e..7b12d8062 100644 --- a/src/backends/onnxruntime.c +++ b/src/backends/onnxruntime.c @@ -285,11 +285,6 @@ RAI_Tensor* RAI_TensorCreateFromOrtValue(OrtValue* v, size_t batch_offset, size_ return NULL; } -typedef struct RAI_ONNXBuffer { - char* data; - size_t len; -} RAI_ONNXBuffer; - OrtEnv* env = NULL; RAI_Model *RAI_ModelCreateORT(RAI_Backend backend, const char* devicestr, RAI_ModelOpts opts, @@ -368,10 +363,6 @@ RAI_Model *RAI_ModelCreateORT(RAI_Backend backend, const char* devicestr, RAI_Mo char* buffer = RedisModule_Calloc(modellen, sizeof(*buffer)); memcpy(buffer, modeldef, modellen); - RAI_ONNXBuffer* onnxbuffer = RedisModule_Calloc(1, sizeof(*onnxbuffer)); - onnxbuffer->data = buffer; - onnxbuffer->len = modellen; - RAI_Model* ret = RedisModule_Calloc(1, sizeof(*ret)); ret->model = NULL; ret->session = session; @@ -379,7 +370,8 @@ RAI_Model *RAI_ModelCreateORT(RAI_Backend backend, const char* devicestr, RAI_Mo ret->devicestr = RedisModule_Strdup(devicestr); ret->refCount = 1; ret->opts = opts; - ret->data = onnxbuffer; + ret->data = buffer; + ret->datalen = modellen; return ret; @@ -392,7 +384,6 @@ RAI_Model *RAI_ModelCreateORT(RAI_Backend backend, const char* devicestr, RAI_Mo void RAI_ModelFreeORT(RAI_Model* model, RAI_Error* error) { const OrtApi* ort = OrtGetApiBase()->GetApi(1); - RedisModule_Free(((RAI_ONNXBuffer*)(model->data))->data); RedisModule_Free(model->data); RedisModule_Free(model->devicestr); ort->ReleaseSession(model->session); @@ -570,10 +561,9 @@ int RAI_ModelRunORT(RAI_ModelRunCtx **mctxs, RAI_Error *error) } int RAI_ModelSerializeORT(RAI_Model *model, char **buffer, size_t *len, RAI_Error *error) { - RAI_ONNXBuffer* onnxbuffer = (RAI_ONNXBuffer*)model->data; - *buffer = RedisModule_Calloc(onnxbuffer->len, sizeof(char)); - memcpy(*buffer, onnxbuffer->data, onnxbuffer->len); - *len = onnxbuffer->len; + *buffer = RedisModule_Calloc(model->datalen, sizeof(char)); + memcpy(*buffer, model->data, model->datalen); + *len = model->datalen; return 0; } diff --git a/src/backends/tensorflow.c b/src/backends/tensorflow.c index 4444ae267..750206679 100644 --- a/src/backends/tensorflow.c +++ b/src/backends/tensorflow.c @@ -238,13 +238,13 @@ RAI_Model *RAI_ModelCreateTF(RAI_Backend backend, const char* devicestr, RAI_Mod TF_ImportGraphDefOptions* options = TF_NewImportGraphDefOptions(); - TF_Buffer *buffer = TF_NewBuffer(); - buffer->length = modellen; - buffer->data = modeldef; + TF_Buffer *tfbuffer = TF_NewBuffer(); + tfbuffer->length = modellen; + tfbuffer->data = modeldef; TF_Status *status = TF_NewStatus(); - TF_GraphImportGraphDef(model, buffer, options, status); + TF_GraphImportGraphDef(model, tfbuffer, options, status); if (TF_GetCode(status) != TF_OK) { char* errorMessage = RedisModule_Strdup(TF_Message(status)); @@ -276,7 +276,7 @@ RAI_Model *RAI_ModelCreateTF(RAI_Backend backend, const char* devicestr, RAI_Mod } TF_DeleteImportGraphDefOptions(options); - TF_DeleteBuffer(buffer); + TF_DeleteBuffer(tfbuffer); TF_DeleteStatus(status); TF_Status *optionsStatus = TF_NewStatus(); @@ -394,6 +394,9 @@ RAI_Model *RAI_ModelCreateTF(RAI_Backend backend, const char* devicestr, RAI_Mod array_append(outputs_, RedisModule_Strdup(outputs[i])); } + char* buffer = RedisModule_Calloc(modellen, sizeof(*buffer)); + memcpy(buffer, modeldef, modellen); + RAI_Model* ret = RedisModule_Calloc(1, sizeof(*ret)); ret->model = model; ret->session = session; @@ -403,7 +406,9 @@ RAI_Model *RAI_ModelCreateTF(RAI_Backend backend, const char* devicestr, RAI_Mod ret->outputs = outputs_; ret->opts = opts; ret->refCount = 1; - + ret->data = buffer; + ret->datalen = modellen; + return ret; } @@ -445,6 +450,10 @@ void RAI_ModelFreeTF(RAI_Model* model, RAI_Error* error) { array_free(model->outputs); } + if (model->data) { + RedisModule_Free(model->data); + } + TF_DeleteStatus(status); } @@ -534,24 +543,32 @@ int RAI_ModelRunTF(RAI_ModelRunCtx** mctxs, RAI_Error *error) { } int RAI_ModelSerializeTF(RAI_Model *model, char **buffer, size_t *len, RAI_Error *error) { - TF_Buffer *tf_buffer = TF_NewBuffer(); - TF_Status *status = TF_NewStatus(); - TF_GraphToGraphDef(model->model, tf_buffer, status); + if (model->data) { + *buffer = RedisModule_Calloc(model->datalen, sizeof(char)); + memcpy(*buffer, model->data, model->datalen); + *len = model->datalen; + } + else { + TF_Buffer *tf_buffer = TF_NewBuffer(); + TF_Status *status = TF_NewStatus(); + + TF_GraphToGraphDef(model->model, tf_buffer, status); + + if (TF_GetCode(status) != TF_OK) { + RAI_SetError(error, RAI_EMODELSERIALIZE, "ERR Error serializing TF model"); + TF_DeleteBuffer(tf_buffer); + TF_DeleteStatus(status); + return 1; + } + + *buffer = RedisModule_Alloc(tf_buffer->length); + memcpy(*buffer, tf_buffer->data, tf_buffer->length); + *len = tf_buffer->length; - if (TF_GetCode(status) != TF_OK) { - RAI_SetError(error, RAI_EMODELSERIALIZE, "ERR Error serializing TF model"); TF_DeleteBuffer(tf_buffer); TF_DeleteStatus(status); - return 1; } - *buffer = RedisModule_Alloc(tf_buffer->length); - memcpy(*buffer, tf_buffer->data, tf_buffer->length); - *len = tf_buffer->length; - - TF_DeleteBuffer(tf_buffer); - TF_DeleteStatus(status); - return 0; } diff --git a/src/backends/tflite.c b/src/backends/tflite.c index 78f83e294..8e921daa2 100644 --- a/src/backends/tflite.c +++ b/src/backends/tflite.c @@ -14,11 +14,6 @@ int RAI_InitBackendTFLite(int (*get_api_fn)(const char *, void *)) { return REDISMODULE_OK; } -typedef struct RAI_TfLiteBuffer { - char* data; - size_t len; -} RAI_TfLiteBuffer; - RAI_Model *RAI_ModelCreateTFLite(RAI_Backend backend, const char* devicestr, RAI_ModelOpts opts, const char *modeldef, size_t modellen, RAI_Error *error) { @@ -55,10 +50,6 @@ RAI_Model *RAI_ModelCreateTFLite(RAI_Backend backend, const char* devicestr, RAI char* buffer = RedisModule_Calloc(modellen, sizeof(*buffer)); memcpy(buffer, modeldef, modellen); - RAI_TfLiteBuffer* tflitebuffer = RedisModule_Calloc(1, sizeof(*tflitebuffer)); - tflitebuffer->data = buffer; - tflitebuffer->len = modellen; - RAI_Model* ret = RedisModule_Calloc(1, sizeof(*ret)); ret->model = model; ret->session = NULL; @@ -68,13 +59,13 @@ RAI_Model *RAI_ModelCreateTFLite(RAI_Backend backend, const char* devicestr, RAI ret->outputs = NULL; ret->refCount = 1; ret->opts = opts; - ret->data = tflitebuffer; + ret->data = buffer; + ret->datalen = modellen; return ret; } void RAI_ModelFreeTFLite(RAI_Model* model, RAI_Error *error) { - RedisModule_Free(((RAI_TfLiteBuffer*)(model->data))->data); RedisModule_Free(model->data); RedisModule_Free(model->devicestr); tfliteDeallocContext(model->model); @@ -178,10 +169,9 @@ int RAI_ModelRunTFLite(RAI_ModelRunCtx** mctxs, RAI_Error *error) { } int RAI_ModelSerializeTFLite(RAI_Model *model, char **buffer, size_t *len, RAI_Error *error) { - RAI_TfLiteBuffer* tflitebuffer = (RAI_TfLiteBuffer*)model->data; - *buffer = RedisModule_Calloc(tflitebuffer->len, sizeof(char)); - memcpy(*buffer, tflitebuffer->data, tflitebuffer->len); - *len = tflitebuffer->len; + *buffer = RedisModule_Calloc(model->datalen, sizeof(char)); + memcpy(*buffer, model->data, model->datalen); + *len = model->datalen; return 0; } diff --git a/src/backends/torch.c b/src/backends/torch.c index 671192a65..61c676fc4 100644 --- a/src/backends/torch.c +++ b/src/backends/torch.c @@ -48,6 +48,9 @@ RAI_Model *RAI_ModelCreateTorch(RAI_Backend backend, const char* devicestr, RAI_ return NULL; } + char* buffer = RedisModule_Calloc(modellen, sizeof(*buffer)); + memcpy(buffer, modeldef, modellen); + RAI_Model* ret = RedisModule_Calloc(1, sizeof(*ret)); ret->model = model; ret->session = NULL; @@ -57,7 +60,9 @@ RAI_Model *RAI_ModelCreateTorch(RAI_Backend backend, const char* devicestr, RAI_ ret->outputs = NULL; ret->opts = opts; ret->refCount = 1; - + ret->data = buffer; + ret->datalen = modellen; + return ret; } @@ -65,6 +70,9 @@ void RAI_ModelFreeTorch(RAI_Model* model, RAI_Error *error) { if(model->devicestr){ RedisModule_Free(model->devicestr); } + if (model->data) { + RedisModule_Free(model->data); + } torchDeallocContext(model->model); } @@ -157,13 +165,21 @@ int RAI_ModelRunTorch(RAI_ModelRunCtx** mctxs, RAI_Error *error) { } int RAI_ModelSerializeTorch(RAI_Model *model, char **buffer, size_t *len, RAI_Error *error) { - char* error_descr = NULL; - torchSerializeModel(model->model, buffer, len, &error_descr, RedisModule_Alloc); - if (*buffer == NULL) { - RAI_SetError(error, RAI_EMODELSERIALIZE, error_descr); - RedisModule_Free(error_descr); - return 1; + if (model->data) { + *buffer = RedisModule_Calloc(model->datalen, sizeof(char)); + memcpy(*buffer, model->data, model->datalen); + *len = model->datalen; + } + else { + char* error_descr = NULL; + torchSerializeModel(model->model, buffer, len, &error_descr, RedisModule_Alloc); + + if (*buffer == NULL) { + RAI_SetError(error, RAI_EMODELSERIALIZE, error_descr); + RedisModule_Free(error_descr); + return 1; + } } return 0; diff --git a/src/model_struct.h b/src/model_struct.h index 1c0381a89..3b7024246 100644 --- a/src/model_struct.h +++ b/src/model_struct.h @@ -29,7 +29,8 @@ typedef struct RAI_Model { char **outputs; size_t noutputs; long long refCount; - void* data; + char* data; + long long datalen; void* infokey; } RAI_Model;