Skip to content

Cache model blobs for faster serialization and thread-safety #331

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Apr 20, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 5 additions & 15 deletions src/backends/onnxruntime.c
Original file line number Diff line number Diff line change
Expand Up @@ -285,11 +285,6 @@ RAI_Tensor* RAI_TensorCreateFromOrtValue(OrtValue* v, size_t batch_offset, size_
return NULL;
}

typedef struct RAI_ONNXBuffer {
char* data;
size_t len;
} RAI_ONNXBuffer;

OrtEnv* env = NULL;

RAI_Model *RAI_ModelCreateORT(RAI_Backend backend, const char* devicestr, RAI_ModelOpts opts,
Expand Down Expand Up @@ -368,18 +363,15 @@ RAI_Model *RAI_ModelCreateORT(RAI_Backend backend, const char* devicestr, RAI_Mo
char* buffer = RedisModule_Calloc(modellen, sizeof(*buffer));
memcpy(buffer, modeldef, modellen);

RAI_ONNXBuffer* onnxbuffer = RedisModule_Calloc(1, sizeof(*onnxbuffer));
onnxbuffer->data = buffer;
onnxbuffer->len = modellen;

RAI_Model* ret = RedisModule_Calloc(1, sizeof(*ret));
ret->model = NULL;
ret->session = session;
ret->backend = backend;
ret->devicestr = RedisModule_Strdup(devicestr);
ret->refCount = 1;
ret->opts = opts;
ret->data = onnxbuffer;
ret->data = buffer;
ret->datalen = modellen;

return ret;

Expand All @@ -392,7 +384,6 @@ RAI_Model *RAI_ModelCreateORT(RAI_Backend backend, const char* devicestr, RAI_Mo
void RAI_ModelFreeORT(RAI_Model* model, RAI_Error* error) {
const OrtApi* ort = OrtGetApiBase()->GetApi(1);

RedisModule_Free(((RAI_ONNXBuffer*)(model->data))->data);
RedisModule_Free(model->data);
RedisModule_Free(model->devicestr);
ort->ReleaseSession(model->session);
Expand Down Expand Up @@ -570,10 +561,9 @@ int RAI_ModelRunORT(RAI_ModelRunCtx **mctxs, RAI_Error *error)
}

int RAI_ModelSerializeORT(RAI_Model *model, char **buffer, size_t *len, RAI_Error *error) {
RAI_ONNXBuffer* onnxbuffer = (RAI_ONNXBuffer*)model->data;
*buffer = RedisModule_Calloc(onnxbuffer->len, sizeof(char));
memcpy(*buffer, onnxbuffer->data, onnxbuffer->len);
*len = onnxbuffer->len;
*buffer = RedisModule_Calloc(model->datalen, sizeof(char));
memcpy(*buffer, model->data, model->datalen);
*len = model->datalen;

return 0;
}
55 changes: 36 additions & 19 deletions src/backends/tensorflow.c
Original file line number Diff line number Diff line change
Expand Up @@ -238,13 +238,13 @@ RAI_Model *RAI_ModelCreateTF(RAI_Backend backend, const char* devicestr, RAI_Mod

TF_ImportGraphDefOptions* options = TF_NewImportGraphDefOptions();

TF_Buffer *buffer = TF_NewBuffer();
buffer->length = modellen;
buffer->data = modeldef;
TF_Buffer *tfbuffer = TF_NewBuffer();
tfbuffer->length = modellen;
tfbuffer->data = modeldef;

TF_Status *status = TF_NewStatus();

TF_GraphImportGraphDef(model, buffer, options, status);
TF_GraphImportGraphDef(model, tfbuffer, options, status);

if (TF_GetCode(status) != TF_OK) {
char* errorMessage = RedisModule_Strdup(TF_Message(status));
Expand Down Expand Up @@ -276,7 +276,7 @@ RAI_Model *RAI_ModelCreateTF(RAI_Backend backend, const char* devicestr, RAI_Mod
}

TF_DeleteImportGraphDefOptions(options);
TF_DeleteBuffer(buffer);
TF_DeleteBuffer(tfbuffer);
TF_DeleteStatus(status);

TF_Status *optionsStatus = TF_NewStatus();
Expand Down Expand Up @@ -394,6 +394,9 @@ RAI_Model *RAI_ModelCreateTF(RAI_Backend backend, const char* devicestr, RAI_Mod
array_append(outputs_, RedisModule_Strdup(outputs[i]));
}

char* buffer = RedisModule_Calloc(modellen, sizeof(*buffer));
memcpy(buffer, modeldef, modellen);

RAI_Model* ret = RedisModule_Calloc(1, sizeof(*ret));
ret->model = model;
ret->session = session;
Expand All @@ -403,7 +406,9 @@ RAI_Model *RAI_ModelCreateTF(RAI_Backend backend, const char* devicestr, RAI_Mod
ret->outputs = outputs_;
ret->opts = opts;
ret->refCount = 1;

ret->data = buffer;
ret->datalen = modellen;

return ret;
}

Expand Down Expand Up @@ -445,6 +450,10 @@ void RAI_ModelFreeTF(RAI_Model* model, RAI_Error* error) {
array_free(model->outputs);
}

if (model->data) {
RedisModule_Free(model->data);
}

TF_DeleteStatus(status);
}

Expand Down Expand Up @@ -534,24 +543,32 @@ int RAI_ModelRunTF(RAI_ModelRunCtx** mctxs, RAI_Error *error) {
}

int RAI_ModelSerializeTF(RAI_Model *model, char **buffer, size_t *len, RAI_Error *error) {
TF_Buffer *tf_buffer = TF_NewBuffer();
TF_Status *status = TF_NewStatus();

TF_GraphToGraphDef(model->model, tf_buffer, status);
if (model->data) {
*buffer = RedisModule_Calloc(model->datalen, sizeof(char));
memcpy(*buffer, model->data, model->datalen);
*len = model->datalen;
}
else {
TF_Buffer *tf_buffer = TF_NewBuffer();
TF_Status *status = TF_NewStatus();

TF_GraphToGraphDef(model->model, tf_buffer, status);

if (TF_GetCode(status) != TF_OK) {
RAI_SetError(error, RAI_EMODELSERIALIZE, "ERR Error serializing TF model");
TF_DeleteBuffer(tf_buffer);
TF_DeleteStatus(status);
return 1;
}

*buffer = RedisModule_Alloc(tf_buffer->length);
memcpy(*buffer, tf_buffer->data, tf_buffer->length);
*len = tf_buffer->length;

if (TF_GetCode(status) != TF_OK) {
RAI_SetError(error, RAI_EMODELSERIALIZE, "ERR Error serializing TF model");
TF_DeleteBuffer(tf_buffer);
TF_DeleteStatus(status);
return 1;
}

*buffer = RedisModule_Alloc(tf_buffer->length);
memcpy(*buffer, tf_buffer->data, tf_buffer->length);
*len = tf_buffer->length;

TF_DeleteBuffer(tf_buffer);
TF_DeleteStatus(status);

return 0;
}
20 changes: 5 additions & 15 deletions src/backends/tflite.c
Original file line number Diff line number Diff line change
Expand Up @@ -14,11 +14,6 @@ int RAI_InitBackendTFLite(int (*get_api_fn)(const char *, void *)) {
return REDISMODULE_OK;
}

typedef struct RAI_TfLiteBuffer {
char* data;
size_t len;
} RAI_TfLiteBuffer;

RAI_Model *RAI_ModelCreateTFLite(RAI_Backend backend, const char* devicestr, RAI_ModelOpts opts,
const char *modeldef, size_t modellen,
RAI_Error *error) {
Expand Down Expand Up @@ -55,10 +50,6 @@ RAI_Model *RAI_ModelCreateTFLite(RAI_Backend backend, const char* devicestr, RAI
char* buffer = RedisModule_Calloc(modellen, sizeof(*buffer));
memcpy(buffer, modeldef, modellen);

RAI_TfLiteBuffer* tflitebuffer = RedisModule_Calloc(1, sizeof(*tflitebuffer));
tflitebuffer->data = buffer;
tflitebuffer->len = modellen;

RAI_Model* ret = RedisModule_Calloc(1, sizeof(*ret));
ret->model = model;
ret->session = NULL;
Expand All @@ -68,13 +59,13 @@ RAI_Model *RAI_ModelCreateTFLite(RAI_Backend backend, const char* devicestr, RAI
ret->outputs = NULL;
ret->refCount = 1;
ret->opts = opts;
ret->data = tflitebuffer;
ret->data = buffer;
ret->datalen = modellen;

return ret;
}

void RAI_ModelFreeTFLite(RAI_Model* model, RAI_Error *error) {
RedisModule_Free(((RAI_TfLiteBuffer*)(model->data))->data);
RedisModule_Free(model->data);
RedisModule_Free(model->devicestr);
tfliteDeallocContext(model->model);
Expand Down Expand Up @@ -178,10 +169,9 @@ int RAI_ModelRunTFLite(RAI_ModelRunCtx** mctxs, RAI_Error *error) {
}

int RAI_ModelSerializeTFLite(RAI_Model *model, char **buffer, size_t *len, RAI_Error *error) {
RAI_TfLiteBuffer* tflitebuffer = (RAI_TfLiteBuffer*)model->data;
*buffer = RedisModule_Calloc(tflitebuffer->len, sizeof(char));
memcpy(*buffer, tflitebuffer->data, tflitebuffer->len);
*len = tflitebuffer->len;
*buffer = RedisModule_Calloc(model->datalen, sizeof(char));
memcpy(*buffer, model->data, model->datalen);
*len = model->datalen;

return 0;
}
30 changes: 23 additions & 7 deletions src/backends/torch.c
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,9 @@ RAI_Model *RAI_ModelCreateTorch(RAI_Backend backend, const char* devicestr, RAI_
return NULL;
}

char* buffer = RedisModule_Calloc(modellen, sizeof(*buffer));
memcpy(buffer, modeldef, modellen);

RAI_Model* ret = RedisModule_Calloc(1, sizeof(*ret));
ret->model = model;
ret->session = NULL;
Expand All @@ -57,14 +60,19 @@ RAI_Model *RAI_ModelCreateTorch(RAI_Backend backend, const char* devicestr, RAI_
ret->outputs = NULL;
ret->opts = opts;
ret->refCount = 1;

ret->data = buffer;
ret->datalen = modellen;

return ret;
}

/*
 * Release the resources held by a Torch-backed model.
 *
 * Frees the device string and the model data buffer when they are set, then
 * passes the backend handle (model->model) to torchDeallocContext().
 * The RAI_Model struct itself is not freed here, and `error` is not written
 * to by this function.
 */
void RAI_ModelFreeTorch(RAI_Model* model, RAI_Error *error) {
  /* Device string: only released when one is attached to the model. */
  if (model->devicestr != NULL) {
    RedisModule_Free(model->devicestr);
  }

  /* Model data buffer: only released when one is attached to the model. */
  if (model->data != NULL) {
    RedisModule_Free(model->data);
  }

  /* Hand the backend context back to the Torch wrapper last. */
  torchDeallocContext(model->model);
}

Expand Down Expand Up @@ -157,13 +165,21 @@ int RAI_ModelRunTorch(RAI_ModelRunCtx** mctxs, RAI_Error *error) {
}

int RAI_ModelSerializeTorch(RAI_Model *model, char **buffer, size_t *len, RAI_Error *error) {
char* error_descr = NULL;
torchSerializeModel(model->model, buffer, len, &error_descr, RedisModule_Alloc);

if (*buffer == NULL) {
RAI_SetError(error, RAI_EMODELSERIALIZE, error_descr);
RedisModule_Free(error_descr);
return 1;
if (model->data) {
*buffer = RedisModule_Calloc(model->datalen, sizeof(char));
memcpy(*buffer, model->data, model->datalen);
*len = model->datalen;
}
else {
char* error_descr = NULL;
torchSerializeModel(model->model, buffer, len, &error_descr, RedisModule_Alloc);

if (*buffer == NULL) {
RAI_SetError(error, RAI_EMODELSERIALIZE, error_descr);
RedisModule_Free(error_descr);
return 1;
}
}

return 0;
Expand Down
3 changes: 2 additions & 1 deletion src/model_struct.h
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,8 @@ typedef struct RAI_Model {
char **outputs;
size_t noutputs;
long long refCount;
void* data;
char* data;
long long datalen;
void* infokey;
} RAI_Model;

Expand Down