Skip to content

Commit 614372e

Browse files
committed
Cache model blobs for faster serialization and thread-safety
1 parent 6f36dab commit 614372e

File tree

5 files changed: 71 additions (+71) and 57 deletions (-57).

src/backends/onnxruntime.c

Lines changed: 5 additions & 15 deletions
Original file line number | Diff line number | Diff line change
@@ -285,11 +285,6 @@ RAI_Tensor* RAI_TensorCreateFromOrtValue(OrtValue* v, size_t batch_offset, size_
285285
return NULL;
286286
}
287287

288-
typedef struct RAI_ONNXBuffer {
289-
char* data;
290-
size_t len;
291-
} RAI_ONNXBuffer;
292-
293288
OrtEnv* env = NULL;
294289

295290
RAI_Model *RAI_ModelCreateORT(RAI_Backend backend, const char* devicestr, RAI_ModelOpts opts,
@@ -368,18 +363,15 @@ RAI_Model *RAI_ModelCreateORT(RAI_Backend backend, const char* devicestr, RAI_Mo
368363
char* buffer = RedisModule_Calloc(modellen, sizeof(*buffer));
369364
memcpy(buffer, modeldef, modellen);
370365

371-
RAI_ONNXBuffer* onnxbuffer = RedisModule_Calloc(1, sizeof(*onnxbuffer));
372-
onnxbuffer->data = buffer;
373-
onnxbuffer->len = modellen;
374-
375366
RAI_Model* ret = RedisModule_Calloc(1, sizeof(*ret));
376367
ret->model = NULL;
377368
ret->session = session;
378369
ret->backend = backend;
379370
ret->devicestr = RedisModule_Strdup(devicestr);
380371
ret->refCount = 1;
381372
ret->opts = opts;
382-
ret->data = onnxbuffer;
373+
ret->data = buffer;
374+
ret->datalen = modellen;
383375

384376
return ret;
385377

@@ -392,7 +384,6 @@ RAI_Model *RAI_ModelCreateORT(RAI_Backend backend, const char* devicestr, RAI_Mo
392384
void RAI_ModelFreeORT(RAI_Model* model, RAI_Error* error) {
393385
const OrtApi* ort = OrtGetApiBase()->GetApi(1);
394386

395-
RedisModule_Free(((RAI_ONNXBuffer*)(model->data))->data);
396387
RedisModule_Free(model->data);
397388
RedisModule_Free(model->devicestr);
398389
ort->ReleaseSession(model->session);
@@ -570,10 +561,9 @@ int RAI_ModelRunORT(RAI_ModelRunCtx **mctxs, RAI_Error *error)
570561
}
571562

572563
int RAI_ModelSerializeORT(RAI_Model *model, char **buffer, size_t *len, RAI_Error *error) {
573-
RAI_ONNXBuffer* onnxbuffer = (RAI_ONNXBuffer*)model->data;
574-
*buffer = RedisModule_Calloc(onnxbuffer->len, sizeof(char));
575-
memcpy(*buffer, onnxbuffer->data, onnxbuffer->len);
576-
*len = onnxbuffer->len;
564+
*buffer = RedisModule_Calloc(model->datalen, sizeof(char));
565+
memcpy(*buffer, model->data, model->datalen);
566+
*len = model->datalen;
577567

578568
return 0;
579569
}

src/backends/tensorflow.c

Lines changed: 36 additions & 19 deletions
Original file line number | Diff line number | Diff line change
@@ -238,13 +238,13 @@ RAI_Model *RAI_ModelCreateTF(RAI_Backend backend, const char* devicestr, RAI_Mod
238238

239239
TF_ImportGraphDefOptions* options = TF_NewImportGraphDefOptions();
240240

241-
TF_Buffer *buffer = TF_NewBuffer();
242-
buffer->length = modellen;
243-
buffer->data = modeldef;
241+
TF_Buffer *tfbuffer = TF_NewBuffer();
242+
tfbuffer->length = modellen;
243+
tfbuffer->data = modeldef;
244244

245245
TF_Status *status = TF_NewStatus();
246246

247-
TF_GraphImportGraphDef(model, buffer, options, status);
247+
TF_GraphImportGraphDef(model, tfbuffer, options, status);
248248

249249
if (TF_GetCode(status) != TF_OK) {
250250
char* errorMessage = RedisModule_Strdup(TF_Message(status));
@@ -276,7 +276,7 @@ RAI_Model *RAI_ModelCreateTF(RAI_Backend backend, const char* devicestr, RAI_Mod
276276
}
277277

278278
TF_DeleteImportGraphDefOptions(options);
279-
TF_DeleteBuffer(buffer);
279+
TF_DeleteBuffer(tfbuffer);
280280
TF_DeleteStatus(status);
281281

282282
TF_Status *optionsStatus = TF_NewStatus();
@@ -394,6 +394,9 @@ RAI_Model *RAI_ModelCreateTF(RAI_Backend backend, const char* devicestr, RAI_Mod
394394
array_append(outputs_, RedisModule_Strdup(outputs[i]));
395395
}
396396

397+
char* buffer = RedisModule_Calloc(modellen, sizeof(*buffer));
398+
memcpy(buffer, modeldef, modellen);
399+
397400
RAI_Model* ret = RedisModule_Calloc(1, sizeof(*ret));
398401
ret->model = model;
399402
ret->session = session;
@@ -403,7 +406,9 @@ RAI_Model *RAI_ModelCreateTF(RAI_Backend backend, const char* devicestr, RAI_Mod
403406
ret->outputs = outputs_;
404407
ret->opts = opts;
405408
ret->refCount = 1;
406-
409+
ret->data = buffer;
410+
ret->datalen = modellen;
411+
407412
return ret;
408413
}
409414

@@ -445,6 +450,10 @@ void RAI_ModelFreeTF(RAI_Model* model, RAI_Error* error) {
445450
array_free(model->outputs);
446451
}
447452

453+
if (model->data) {
454+
RedisModule_Free(model->data);
455+
}
456+
448457
TF_DeleteStatus(status);
449458
}
450459

@@ -534,24 +543,32 @@ int RAI_ModelRunTF(RAI_ModelRunCtx** mctxs, RAI_Error *error) {
534543
}
535544

536545
int RAI_ModelSerializeTF(RAI_Model *model, char **buffer, size_t *len, RAI_Error *error) {
537-
TF_Buffer *tf_buffer = TF_NewBuffer();
538-
TF_Status *status = TF_NewStatus();
539546

540-
TF_GraphToGraphDef(model->model, tf_buffer, status);
547+
if (model->data) {
548+
*buffer = RedisModule_Calloc(model->datalen, sizeof(char));
549+
memcpy(*buffer, model->data, model->datalen);
550+
*len = model->datalen;
551+
}
552+
else {
553+
TF_Buffer *tf_buffer = TF_NewBuffer();
554+
TF_Status *status = TF_NewStatus();
555+
556+
TF_GraphToGraphDef(model->model, tf_buffer, status);
557+
558+
if (TF_GetCode(status) != TF_OK) {
559+
RAI_SetError(error, RAI_EMODELSERIALIZE, "ERR Error serializing TF model");
560+
TF_DeleteBuffer(tf_buffer);
561+
TF_DeleteStatus(status);
562+
return 1;
563+
}
564+
565+
*buffer = RedisModule_Alloc(tf_buffer->length);
566+
memcpy(*buffer, tf_buffer->data, tf_buffer->length);
567+
*len = tf_buffer->length;
541568

542-
if (TF_GetCode(status) != TF_OK) {
543-
RAI_SetError(error, RAI_EMODELSERIALIZE, "ERR Error serializing TF model");
544569
TF_DeleteBuffer(tf_buffer);
545570
TF_DeleteStatus(status);
546-
return 1;
547571
}
548572

549-
*buffer = RedisModule_Alloc(tf_buffer->length);
550-
memcpy(*buffer, tf_buffer->data, tf_buffer->length);
551-
*len = tf_buffer->length;
552-
553-
TF_DeleteBuffer(tf_buffer);
554-
TF_DeleteStatus(status);
555-
556573
return 0;
557574
}

src/backends/tflite.c

Lines changed: 5 additions & 15 deletions
Original file line number | Diff line number | Diff line change
@@ -14,11 +14,6 @@ int RAI_InitBackendTFLite(int (*get_api_fn)(const char *, void *)) {
1414
return REDISMODULE_OK;
1515
}
1616

17-
typedef struct RAI_TfLiteBuffer {
18-
char* data;
19-
size_t len;
20-
} RAI_TfLiteBuffer;
21-
2217
RAI_Model *RAI_ModelCreateTFLite(RAI_Backend backend, const char* devicestr, RAI_ModelOpts opts,
2318
const char *modeldef, size_t modellen,
2419
RAI_Error *error) {
@@ -55,10 +50,6 @@ RAI_Model *RAI_ModelCreateTFLite(RAI_Backend backend, const char* devicestr, RAI
5550
char* buffer = RedisModule_Calloc(modellen, sizeof(*buffer));
5651
memcpy(buffer, modeldef, modellen);
5752

58-
RAI_TfLiteBuffer* tflitebuffer = RedisModule_Calloc(1, sizeof(*tflitebuffer));
59-
tflitebuffer->data = buffer;
60-
tflitebuffer->len = modellen;
61-
6253
RAI_Model* ret = RedisModule_Calloc(1, sizeof(*ret));
6354
ret->model = model;
6455
ret->session = NULL;
@@ -68,13 +59,13 @@ RAI_Model *RAI_ModelCreateTFLite(RAI_Backend backend, const char* devicestr, RAI
6859
ret->outputs = NULL;
6960
ret->refCount = 1;
7061
ret->opts = opts;
71-
ret->data = tflitebuffer;
62+
ret->data = buffer;
63+
ret->datalen = modellen;
7264

7365
return ret;
7466
}
7567

7668
void RAI_ModelFreeTFLite(RAI_Model* model, RAI_Error *error) {
77-
RedisModule_Free(((RAI_TfLiteBuffer*)(model->data))->data);
7869
RedisModule_Free(model->data);
7970
RedisModule_Free(model->devicestr);
8071
tfliteDeallocContext(model->model);
@@ -178,10 +169,9 @@ int RAI_ModelRunTFLite(RAI_ModelRunCtx** mctxs, RAI_Error *error) {
178169
}
179170

180171
int RAI_ModelSerializeTFLite(RAI_Model *model, char **buffer, size_t *len, RAI_Error *error) {
181-
RAI_TfLiteBuffer* tflitebuffer = (RAI_TfLiteBuffer*)model->data;
182-
*buffer = RedisModule_Calloc(tflitebuffer->len, sizeof(char));
183-
memcpy(*buffer, tflitebuffer->data, tflitebuffer->len);
184-
*len = tflitebuffer->len;
172+
*buffer = RedisModule_Calloc(model->datalen, sizeof(char));
173+
memcpy(*buffer, model->data, model->datalen);
174+
*len = model->datalen;
185175

186176
return 0;
187177
}

src/backends/torch.c

Lines changed: 23 additions & 7 deletions
Original file line number | Diff line number | Diff line change
@@ -48,6 +48,9 @@ RAI_Model *RAI_ModelCreateTorch(RAI_Backend backend, const char* devicestr, RAI_
4848
return NULL;
4949
}
5050

51+
char* buffer = RedisModule_Calloc(modellen, sizeof(*buffer));
52+
memcpy(buffer, modeldef, modellen);
53+
5154
RAI_Model* ret = RedisModule_Calloc(1, sizeof(*ret));
5255
ret->model = model;
5356
ret->session = NULL;
@@ -57,14 +60,19 @@ RAI_Model *RAI_ModelCreateTorch(RAI_Backend backend, const char* devicestr, RAI_
5760
ret->outputs = NULL;
5861
ret->opts = opts;
5962
ret->refCount = 1;
60-
63+
ret->data = buffer;
64+
ret->datalen = modellen;
65+
6166
return ret;
6267
}
6368

6469
void RAI_ModelFreeTorch(RAI_Model* model, RAI_Error *error) {
6570
if(model->devicestr){
6671
RedisModule_Free(model->devicestr);
6772
}
73+
if (model->data) {
74+
RedisModule_Free(model->data);
75+
}
6876
torchDeallocContext(model->model);
6977
}
7078

@@ -157,13 +165,21 @@ int RAI_ModelRunTorch(RAI_ModelRunCtx** mctxs, RAI_Error *error) {
157165
}
158166

159167
int RAI_ModelSerializeTorch(RAI_Model *model, char **buffer, size_t *len, RAI_Error *error) {
160-
char* error_descr = NULL;
161-
torchSerializeModel(model->model, buffer, len, &error_descr, RedisModule_Alloc);
162168

163-
if (*buffer == NULL) {
164-
RAI_SetError(error, RAI_EMODELSERIALIZE, error_descr);
165-
RedisModule_Free(error_descr);
166-
return 1;
169+
if (model->data) {
170+
*buffer = RedisModule_Calloc(model->datalen, sizeof(char));
171+
memcpy(*buffer, model->data, model->datalen);
172+
*len = model->datalen;
173+
}
174+
else {
175+
char* error_descr = NULL;
176+
torchSerializeModel(model->model, buffer, len, &error_descr, RedisModule_Alloc);
177+
178+
if (*buffer == NULL) {
179+
RAI_SetError(error, RAI_EMODELSERIALIZE, error_descr);
180+
RedisModule_Free(error_descr);
181+
return 1;
182+
}
167183
}
168184

169185
return 0;

src/model_struct.h

Lines changed: 2 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -29,7 +29,8 @@ typedef struct RAI_Model {
2929
char **outputs;
3030
size_t noutputs;
3131
long long refCount;
32-
void* data;
32+
char* data;
33+
long long datalen;
3334
void* infokey;
3435
} RAI_Model;
3536

0 commit comments

Comments
 (0)