
Commit 81b528b

Add default behaviour for AI.MODELGET + documentation
1 parent 65cad8e commit 81b528b

File tree

3 files changed: +30 -25 lines

docs/commands.md
src/redisai.c
tests/flow/tests_commands.py

docs/commands.md

Lines changed: 3 additions & 3 deletions
@@ -222,8 +222,8 @@ AI.MODELGET <key> [META] [BLOB]
 _Arguments_
 
 * **key**: the model's key name
-* **META**: will return the model's meta information on backend, device, tag and batching parameters
-* **BLOB**: will return the model's blob containing the serialized model
+* **META**: will return only the model's meta information on backend, device, tag and batching parameters
+* **BLOB**: will return only the model's blob containing the serialized model
 
 _Return_
 
@@ -237,7 +237,7 @@ An array of alternating key-value pairs as follows:
 1. **INPUTS**: array reply with one or more names of the model's input nodes (applicable only for TensorFlow models)
 1. **OUTPUTS**: array reply with one or more names of the model's output nodes (applicable only for TensorFlow models)
 1. **MINBATCHTIMEOUT**: The time in milliseconds for which the engine will wait before executing a request to run the model, when the number of incoming requests is lower than `MINBATCHSIZE`. When `MINBATCHTIMEOUT` is 0, the engine will not run the model before it receives at least `MINBATCHSIZE` requests.
-1. **BLOB**: a blob containing the serialized model (when called with the `BLOB` argument) as a String. If the size of the serialized model exceeds `MODEL_CHUNK_SIZE` (see `AI.CONFIG` command), then an array of chunks is returned. The full serialized model can be obtained by concatenating the chunks.
+1. **BLOB**: a blob containing the serialized model as a String. If the size of the serialized model exceeds `MODEL_CHUNK_SIZE` (see `AI.CONFIG` command), then an array of chunks is returned. The full serialized model can be obtained by concatenating the chunks.
 
 **Examples**
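
For context, a quick client-side sketch of what the documented default means in practice. This is illustrative only (not part of the diff) and assumes redis-py, a RedisAI-enabled server on localhost:6379, and a model already stored under the hypothetical key `my_model{1}`:

```python
import redis

# Assumes a RedisAI-enabled Redis server on localhost:6379 and a model
# already stored under the hypothetical key 'my_model{1}'.
r = redis.Redis()

# Default (this commit): calling with no flag returns the meta fields plus
# the serialized blob -> 9 key-value pairs, i.e. 18 reply entries.
full_reply = r.execute_command('AI.MODELGET', 'my_model{1}')

# META only -> just the 8 meta pairs (16 entries), no 'blob' pair.
meta_reply = r.execute_command('AI.MODELGET', 'my_model{1}', 'META')

print(len(full_reply), len(meta_reply))  # expected: 18 16
```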

src/redisai.c

Lines changed: 15 additions & 19 deletions
@@ -416,31 +416,24 @@ int RedisAI_ModelGet_RedisCommand(RedisModuleCtx *ctx, RedisModuleString **argv,
         return REDISMODULE_ERR;
     }
 
-    int meta = 0;
-    int blob = 0;
+    int meta = false;
+    int blob = false;
     for (int i = 2; i < argc; i++) {
         const char *optstr = RedisModule_StringPtrLen(argv[i], NULL);
         if (!strcasecmp(optstr, "META")) {
-            meta = 1;
+            meta = true;
         } else if (!strcasecmp(optstr, "BLOB")) {
-            blob = 1;
+            blob = true;
         }
     }
 
-    if (!meta && !blob) {
-        return RedisModule_ReplyWithError(ctx, "ERR no META or BLOB specified");
-    }
-
     char *buffer = NULL;
     size_t len = 0;
 
-    if (blob) {
+    if (!meta || blob) {
         RAI_ModelSerialize(mto, &buffer, &len, &err);
-        if (err.code != RAI_OK) {
-#ifdef RAI_PRINT_BACKEND_ERRORS
-            printf("ERR: %s\n", err.detail);
-#endif
-            int ret = RedisModule_ReplyWithError(ctx, err.detail);
+        if (RAI_GetErrorCode(&err) != RAI_OK) {
+            int ret = RedisModule_ReplyWithError(ctx, RAI_GetErrorOneLine(&err));
             RAI_ClearError(&err);
             if (*buffer) {
                 RedisModule_Free(buffer);
@@ -455,12 +448,14 @@ int RedisAI_ModelGet_RedisCommand(RedisModuleCtx *ctx, RedisModuleString **argv,
         return REDISMODULE_OK;
     }
 
-    const int outentries = blob ? 18 : 16;
-    RedisModule_ReplyWithArray(ctx, outentries);
+    // The only case where we return only META is when META is given but BLOB
+    // was not. Otherwise, we return both META+BLOB.
+    const int out_entries = (meta && !blob) ? 16 : 18;
+    RedisModule_ReplyWithArray(ctx, out_entries);
 
     RedisModule_ReplyWithCString(ctx, "backend");
-    const char *backendstr = RAI_GetBackendName(mto->backend);
-    RedisModule_ReplyWithCString(ctx, backendstr);
+    const char *backend_str = RAI_GetBackendName(mto->backend);
+    RedisModule_ReplyWithCString(ctx, backend_str);
 
     RedisModule_ReplyWithCString(ctx, "device");
     RedisModule_ReplyWithCString(ctx, mto->devicestr);
@@ -495,7 +490,8 @@ int RedisAI_ModelGet_RedisCommand(RedisModuleCtx *ctx, RedisModuleString **argv,
     RedisModule_ReplyWithCString(ctx, "minbatchtimeout");
     RedisModule_ReplyWithLongLong(ctx, (long)mto->opts.minbatchtimeout);
 
-    if (meta && blob) {
+    // This condition is the negation of (meta && !blob).
+    if (!meta || blob) {
        RedisModule_ReplyWithCString(ctx, "blob");
        RAI_ReplyWithChunks(ctx, buffer, len);
        RedisModule_Free(buffer);
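
To spell out the new control flow, here is a small truth-table sketch (Python, illustrative only, not part of the commit) that mirrors the `(meta && !blob)` / `(!meta || blob)` decisions above:

```python
# Mirrors the reply-shape logic of RedisAI_ModelGet_RedisCommand after this
# commit: the blob is serialized and included unless META was requested
# without BLOB; only that case shrinks the reply from 18 entries to 16.
def modelget_reply_shape(meta: bool, blob: bool):
    include_blob = (not meta) or blob            # mirrors `if (!meta || blob)`
    out_entries = 16 if (meta and not blob) else 18
    return include_blob, out_entries

for meta in (False, True):
    for blob in (False, True):
        include_blob, out_entries = modelget_reply_shape(meta, blob)
        print(f"META={meta} BLOB={blob} -> blob included: {include_blob}, entries: {out_entries}")
```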

tests/flow/tests_commands.py

Lines changed: 12 additions & 3 deletions
@@ -54,7 +54,7 @@ def test_modelstore_errors(env):
                         'AI.MODELSTORE', 'm{1}', 'TORCH', DEVICE, 'BATCHSIZE', 2, 'BLOB')
 
 
-def test_modelget_errors(env):
+def test_modelget(env):
     if not TEST_TF:
         env.debugPrint("Skipping test since TF is not available", force=True)
         return
@@ -69,8 +69,17 @@ def test_modelget_errors(env):
 
     # ERR model key is empty
     con.execute_command('DEL', 'DONT_EXIST{1}')
-    check_error_message(env, con, "model key is empty",
-                        'AI.MODELGET', 'DONT_EXIST{1}')
+    check_error_message(env, con, "model key is empty", 'AI.MODELGET', 'DONT_EXIST{1}')
+
+    # The default behaviour on success is to return META+BLOB
+    model_pb = load_file_content('graph.pb')
+    con.execute_command('AI.MODELSTORE', 'm{1}', 'TF', DEVICE, 'INPUTS', 2, 'a', 'b', 'OUTPUTS', 1, 'mul',
+                        'BLOB', model_pb)
+    _, backend, _, device, _, tag, _, batchsize, _, minbatchsize, _, inputs, _, outputs, _, minbatchtimeout, _, blob = \
+        con.execute_command('AI.MODELGET', 'm{1}')
+    env.assertEqual([backend, device, tag, batchsize, minbatchsize, minbatchtimeout, inputs, outputs],
+                    [b"TF", bytes(DEVICE, "utf8"), b"", 0, 0, 0, [b"a", b"b"], [b"mul"]])
+    env.assertEqual(blob, model_pb)
 
 
 def test_modelexecute_errors(env):
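
The test above unpacks the reply positionally; since AI.MODELGET returns alternating key-value pairs, a client can also fold it into a dict. A minimal sketch (not part of the diff), again assuming redis-py and a model stored under `m{1}` as in the test:

```python
import redis

r = redis.Redis()

# Assumes the model was stored under 'm{1}' as in the test above.
reply = r.execute_command('AI.MODELGET', 'm{1}')

# Pair up the alternating field names and values into a dict.
info = {reply[i].decode(): reply[i + 1] for i in range(0, len(reply), 2)}

print(info['backend'])  # e.g. b'TF'
# Note: 'blob' is a single bytes value here, but per the docs it becomes a
# list of chunks when the serialized model exceeds MODEL_CHUNK_SIZE.
print(len(info['blob']))
```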
