Skip to content

Commit b66585b

Browse files
DvirDukhanfilipecosta90
authored andcommitted
Llapi updates (#400)
* added low level api return redis types * added variadic to llapi * fixed memory issue for params array re-alloc
1 parent f0fdbf1 commit b66585b

File tree

9 files changed

+105
-16
lines changed

9 files changed

+105
-16
lines changed

src/model.c

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -648,3 +648,7 @@ int RedisAI_Parse_ModelRun_RedisCommand(RedisModuleCtx *ctx,
648648
}
649649
return argpos;
650650
}
651+
652+
RedisModuleType *RAI_ModelRedisType(void) {
653+
return RedisAI_ModelType;
654+
}

src/model.h

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -226,4 +226,11 @@ int RedisAI_Parse_ModelRun_RedisCommand(
226226
RAI_ModelRunCtx** mctx, RedisModuleString*** outkeys, RAI_Model** mto,
227227
int useLocalContext, AI_dict** localContextDict, int use_chaining_operator,
228228
const char* chaining_operator, RAI_Error* error);
229+
230+
/**
231+
* @brief Returns the redis module type representing a model.
232+
* @return redis module type representing a model.
233+
*/
234+
RedisModuleType *RAI_ModelRedisType(void);
235+
229236
#endif /* SRC_MODEL_H_ */

src/redisai.c

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -966,6 +966,7 @@ static int RedisAI_RegisterApi(RedisModuleCtx* ctx) {
966966
REGISTER_API(TensorDim, ctx);
967967
REGISTER_API(TensorByteSize, ctx);
968968
REGISTER_API(TensorData, ctx);
969+
REGISTER_API(TensorRedisType, ctx);
969970

970971
REGISTER_API(ModelCreate, ctx);
971972
REGISTER_API(ModelFree, ctx);
@@ -978,17 +979,20 @@ static int RedisAI_RegisterApi(RedisModuleCtx* ctx) {
978979
REGISTER_API(ModelRun, ctx);
979980
REGISTER_API(ModelSerialize, ctx);
980981
REGISTER_API(ModelGetShallowCopy, ctx);
982+
REGISTER_API(ModelRedisType, ctx);
981983

982984
REGISTER_API(ScriptCreate, ctx);
983985
REGISTER_API(ScriptFree, ctx);
984986
REGISTER_API(ScriptRunCtxCreate, ctx);
985987
REGISTER_API(ScriptRunCtxAddInput, ctx);
988+
REGISTER_API(ScriptRunCtxAddInputList, ctx);
986989
REGISTER_API(ScriptRunCtxAddOutput, ctx);
987990
REGISTER_API(ScriptRunCtxNumOutputs, ctx);
988991
REGISTER_API(ScriptRunCtxOutputTensor, ctx);
989992
REGISTER_API(ScriptRunCtxFree, ctx);
990993
REGISTER_API(ScriptRun, ctx);
991994
REGISTER_API(ScriptGetShallowCopy, ctx);
995+
REGISTER_API(ScriptRedisType, ctx);
992996

993997
return REDISMODULE_OK;
994998
}

src/redisai.h

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -78,6 +78,7 @@ int MODULE_API_FUNC(RedisAI_TensorNumDims)(RAI_Tensor* t);
7878
long long MODULE_API_FUNC(RedisAI_TensorDim)(RAI_Tensor* t, int dim);
7979
size_t MODULE_API_FUNC(RedisAI_TensorByteSize)(RAI_Tensor* t);
8080
char* MODULE_API_FUNC(RedisAI_TensorData)(RAI_Tensor* t);
81+
RedisModuleType* MODULE_API_FUNC(RedisAI_TensorRedisType)(void);
8182

8283
RAI_Model* MODULE_API_FUNC(RedisAI_ModelCreate)(int backend, char* devicestr, char* tag, RAI_ModelOpts opts,
8384
size_t ninputs, const char **inputs,
@@ -93,17 +94,20 @@ void MODULE_API_FUNC(RedisAI_ModelRunCtxFree)(RAI_ModelRunCtx* mctx);
9394
int MODULE_API_FUNC(RedisAI_ModelRun)(RAI_ModelRunCtx** mctx, long long n, RAI_Error* err);
9495
RAI_Model* MODULE_API_FUNC(RedisAI_ModelGetShallowCopy)(RAI_Model* model);
9596
int MODULE_API_FUNC(RedisAI_ModelSerialize)(RAI_Model *model, char **buffer, size_t *len, RAI_Error *err);
97+
RedisModuleType* MODULE_API_FUNC(RedisAI_ModelRedisType)(void);
9698

9799
RAI_Script* MODULE_API_FUNC(RedisAI_ScriptCreate)(char* devicestr, char* tag, const char* scriptdef, RAI_Error* err);
98100
void MODULE_API_FUNC(RedisAI_ScriptFree)(RAI_Script* script, RAI_Error* err);
99101
RAI_ScriptRunCtx* MODULE_API_FUNC(RedisAI_ScriptRunCtxCreate)(RAI_Script* script, const char *fnname);
100-
int MODULE_API_FUNC(RedisAI_ScriptRunCtxAddInput)(RAI_ScriptRunCtx* sctx, RAI_Tensor* inputTensor);
102+
int MODULE_API_FUNC(RedisAI_ScriptRunCtxAddInput)(RAI_ScriptRunCtx* sctx, RAI_Tensor* inputTensor, RAI_Error* err);
103+
int MODULE_API_FUNC(RedisAI_ScriptRunCtxAddInputList)(RAI_ScriptRunCtx* sctx, RAI_Tensor** inputTensors, size_t len, RAI_Error* err);
101104
int MODULE_API_FUNC(RedisAI_ScriptRunCtxAddOutput)(RAI_ScriptRunCtx* sctx);
102105
size_t MODULE_API_FUNC(RedisAI_ScriptRunCtxNumOutputs)(RAI_ScriptRunCtx* sctx);
103106
RAI_Tensor* MODULE_API_FUNC(RedisAI_ScriptRunCtxOutputTensor)(RAI_ScriptRunCtx* sctx, size_t index);
104107
void MODULE_API_FUNC(RedisAI_ScriptRunCtxFree)(RAI_ScriptRunCtx* sctx);
105108
int MODULE_API_FUNC(RedisAI_ScriptRun)(RAI_ScriptRunCtx* sctx, RAI_Error* err);
106109
RAI_Script* MODULE_API_FUNC(RedisAI_ScriptGetShallowCopy)(RAI_Script* script);
110+
RedisModuleType* MODULE_API_FUNC(RedisAI_ScriptRedisType)(void);
107111

108112
int MODULE_API_FUNC(RedisAI_GetLLAPIVersion)();
109113

@@ -145,6 +149,7 @@ static int RedisAI_Initialize(RedisModuleCtx* ctx){
145149
REDISAI_MODULE_INIT_FUNCTION(ctx, TensorDim);
146150
REDISAI_MODULE_INIT_FUNCTION(ctx, TensorByteSize);
147151
REDISAI_MODULE_INIT_FUNCTION(ctx, TensorData);
152+
REDISAI_MODULE_INIT_FUNCTION(ctx, TensorRedisType);
148153

149154
REDISAI_MODULE_INIT_FUNCTION(ctx, ModelCreate);
150155
REDISAI_MODULE_INIT_FUNCTION(ctx, ModelFree);
@@ -157,17 +162,20 @@ static int RedisAI_Initialize(RedisModuleCtx* ctx){
157162
REDISAI_MODULE_INIT_FUNCTION(ctx, ModelRun);
158163
REDISAI_MODULE_INIT_FUNCTION(ctx, ModelGetShallowCopy);
159164
REDISAI_MODULE_INIT_FUNCTION(ctx, ModelSerialize);
165+
REDISAI_MODULE_INIT_FUNCTION(ctx, ModelRedisType);
160166

161167
REDISAI_MODULE_INIT_FUNCTION(ctx, ScriptCreate);
162168
REDISAI_MODULE_INIT_FUNCTION(ctx, ScriptFree);
163169
REDISAI_MODULE_INIT_FUNCTION(ctx, ScriptRunCtxCreate);
164170
REDISAI_MODULE_INIT_FUNCTION(ctx, ScriptRunCtxAddInput);
171+
REDISAI_MODULE_INIT_FUNCTION(ctx, ScriptRunCtxAddInputList);
165172
REDISAI_MODULE_INIT_FUNCTION(ctx, ScriptRunCtxAddOutput);
166173
REDISAI_MODULE_INIT_FUNCTION(ctx, ScriptRunCtxNumOutputs);
167174
REDISAI_MODULE_INIT_FUNCTION(ctx, ScriptRunCtxOutputTensor);
168175
REDISAI_MODULE_INIT_FUNCTION(ctx, ScriptRunCtxFree);
169176
REDISAI_MODULE_INIT_FUNCTION(ctx, ScriptRun);
170177
REDISAI_MODULE_INIT_FUNCTION(ctx, ScriptGetShallowCopy);
178+
REDISAI_MODULE_INIT_FUNCTION(ctx, ScriptRedisType);
171179

172180
if(RedisAI_GetLLAPIVersion() < REDISAI_LLAPI_VERSION){
173181
return REDISMODULE_ERR;

src/script.c

Lines changed: 39 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -156,21 +156,40 @@ RAI_ScriptRunCtx* RAI_ScriptRunCtxCreate(RAI_Script* script,
156156
}
157157

158158
static int Script_RunCtxAddParam(RAI_ScriptRunCtx* sctx,
159-
RAI_ScriptCtxParam* paramArr,
159+
RAI_ScriptCtxParam** paramArr,
160160
RAI_Tensor* tensor) {
161161
RAI_ScriptCtxParam param = {
162162
.tensor = tensor ? RAI_TensorGetShallowCopy(tensor) : NULL,
163163
};
164-
paramArr = array_append(paramArr, param);
164+
*paramArr = array_append(*paramArr, param);
165165
return 1;
166166
}
167167

168-
int RAI_ScriptRunCtxAddInput(RAI_ScriptRunCtx* sctx, RAI_Tensor* inputTensor) {
169-
return Script_RunCtxAddParam(sctx, sctx->inputs, inputTensor);
168+
int RAI_ScriptRunCtxAddInput(RAI_ScriptRunCtx* sctx, RAI_Tensor* inputTensor, RAI_Error* err) {
169+
if(sctx->variadic != -1) {
170+
RAI_SetError(err, RAI_EBACKENDNOTLOADED, "ERR Already encountered a variable size list of tensors");
171+
return 0;
172+
}
173+
return Script_RunCtxAddParam(sctx, &sctx->inputs, inputTensor);
174+
}
175+
176+
int RAI_ScriptRunCtxAddInputList(RAI_ScriptRunCtx* sctx, RAI_Tensor** inputTensors, size_t len, RAI_Error* err) {
177+
// If this is the first time a list is added, set the variadic, else return an error.
178+
if(sctx->variadic == -1) {
179+
sctx->variadic = array_len(sctx->inputs);
180+
}
181+
else {
182+
RAI_SetError(err, RAI_EBACKENDNOTLOADED, "ERR Already encountered a variable size list of tensors");
183+
return 0;
184+
}
185+
for(size_t i=0; i < len; i++){
186+
Script_RunCtxAddParam(sctx, &sctx->inputs, inputTensors[i]);
187+
}
188+
return 1;
170189
}
171190

172191
int RAI_ScriptRunCtxAddOutput(RAI_ScriptRunCtx* sctx) {
173-
return Script_RunCtxAddParam(sctx, sctx->outputs, NULL);
192+
return Script_RunCtxAddParam(sctx, &sctx->outputs, NULL);
174193
}
175194

176195
size_t RAI_ScriptRunCtxNumOutputs(RAI_ScriptRunCtx* sctx) {
@@ -274,7 +293,8 @@ int RedisAI_Parse_ScriptRun_RedisCommand(RedisModuleCtx *ctx,
274293
int is_input = 0;
275294
int outputs_flag_count = 0;
276295
size_t argpos = 4;
277-
296+
// Keep variadic local variable as the calls for RAI_ScriptRunCtxAddInput check if (*sctx)->variadic already assigned.
297+
size_t variadic = (*sctx)->variadic;
278298
for (; argpos <= argc - 1; argpos++) {
279299
const char *arg_string = RedisModule_StringPtrLen(argv[argpos], NULL);
280300
if(!arg_string){
@@ -291,7 +311,11 @@ int RedisAI_Parse_ScriptRun_RedisCommand(RedisModuleCtx *ctx,
291311
outputs_flag_count = 1;
292312
} else {
293313
if (!strcasecmp(arg_string, "$")) {
294-
(*sctx)->variadic = argpos - 4;
314+
if(variadic > -1) {
315+
RedisAI_ReplyOrSetError(ctx,error,RAI_ESCRIPTRUN, "ERR Already encountered a variable size list of tensors");
316+
return -1;
317+
}
318+
variadic = argpos - 4;
295319
continue;
296320
}
297321
RedisModule_RetainString(ctx, argv[argpos]);
@@ -313,10 +337,7 @@ int RedisAI_Parse_ScriptRun_RedisCommand(RedisModuleCtx *ctx,
313337
return -1;
314338
}
315339
}
316-
if (!RAI_ScriptRunCtxAddInput(*sctx, inputTensor)) {
317-
RedisAI_ReplyOrSetError(ctx, error, RAI_ESCRIPTRUN, "ERR Input key not found");
318-
return -1;
319-
}
340+
if (!RAI_ScriptRunCtxAddInput(*sctx, inputTensor, error)) return -1;
320341
} else {
321342
if (!RAI_ScriptRunCtxAddOutput(*sctx)) {
322343
RedisAI_ReplyOrSetError(ctx, error, RAI_ESCRIPTRUN, "ERR Output key not found");
@@ -326,6 +347,8 @@ int RedisAI_Parse_ScriptRun_RedisCommand(RedisModuleCtx *ctx,
326347
}
327348
}
328349
}
350+
// In case variadic position found, set it in the context.
351+
(*sctx)->variadic = variadic;
329352
return argpos;
330353
}
331354

@@ -338,4 +361,8 @@ void RedisAI_ReplyOrSetError(RedisModuleCtx *ctx, RAI_Error *error, RAI_ErrorCod
338361
} else {
339362
RedisModule_ReplyWithError(ctx, errorMessage);
340363
}
341-
}
364+
}
365+
366+
RedisModuleType *RAI_ScriptRedisType(void) {
367+
return RedisAI_ScriptType;
368+
}

src/script.h

Lines changed: 24 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -67,9 +67,25 @@ RAI_ScriptRunCtx* RAI_ScriptRunCtxCreate(RAI_Script* script,
6767
*
6868
* @param sctx input RAI_ScriptRunCtx to add the input tensor
6969
* @param inputTensor input tensor structure
70-
* @return returns 1 on success ( always returns success )
70+
* @param err error data structure to store error message in the case of
71+
* failures
72+
* @return returns 1 on success, 0 in case of error.
7173
*/
72-
int RAI_ScriptRunCtxAddInput(RAI_ScriptRunCtx* sctx, RAI_Tensor* inputTensor);
74+
int RAI_ScriptRunCtxAddInput(RAI_ScriptRunCtx* sctx, RAI_Tensor* inputTensor, RAI_Error* err);
75+
76+
/**
77+
* For each Allocates a RAI_ScriptCtxParam data structure, and enforces a shallow copy of
78+
* the provided input tensor, adding it to the input tensors array of the
79+
* RAI_ScriptRunCtx.
80+
*
81+
* @param sctx input RAI_ScriptRunCtx to add the input tensor
82+
* @param inputTensors input tensors array
83+
* @param len input tensors array len
84+
* @param err error data structure to store error message in the case of
85+
* failures
86+
* @return returns 1 on success, 0 in case of error.
87+
*/
88+
int RAI_ScriptRunCtxAddInputList(RAI_ScriptRunCtx* sctx, RAI_Tensor** inputTensors, size_t len, RAI_Error* err);
7389

7490
/**
7591
* Allocates a RAI_ScriptCtxParam data structure, and sets the tensor reference
@@ -193,4 +209,10 @@ int RedisAI_Parse_ScriptRun_RedisCommand(RedisModuleCtx *ctx,
193209
*/
194210
void RedisAI_ReplyOrSetError(RedisModuleCtx *ctx, RAI_Error *error, RAI_ErrorCode code, const char* errorMessage );
195211

212+
/**
213+
* @brief Returns the redis module type representing a script.
214+
* @return redis module type representing a script.
215+
*/
216+
RedisModuleType *RAI_ScriptRedisType(void);
217+
196218
#endif /* SRC_SCRIPT_H_ */

src/tensor.c

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1072,4 +1072,8 @@ int RAI_parseTensorGetArgs(RedisModuleCtx *ctx, RedisModuleString **argv, int ar
10721072

10731073
// return command arity as the number of processed args
10741074
return argc;
1075-
}
1075+
}
1076+
1077+
RedisModuleType *RAI_TensorRedisType(void) {
1078+
return RedisAI_TensorType;
1079+
}

src/tensor.h

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -378,4 +378,10 @@ int RAI_parseTensorSetArgs(RedisModuleCtx* ctx, RedisModuleString** argv,
378378
int RAI_parseTensorGetArgs(RedisModuleCtx* ctx, RedisModuleString** argv,
379379
int argc, RAI_Tensor* t);
380380

381+
/**
382+
* @brief Returns the redis module type representing a tensor.
383+
* @return redis module type representing a tensor.
384+
*/
385+
RedisModuleType *RAI_TensorRedisType(void);
386+
381387
#endif /* SRC_TENSOR_H_ */

test/tests_pytorch.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -661,6 +661,13 @@ def test_pytorch_scriptrun_errors(env):
661661
except Exception as e:
662662
exception = e
663663
env.assertEqual(type(exception), redis.exceptions.ResponseError)
664+
665+
# "ERR Already encountered a variable size list of tensors"
666+
try:
667+
con.execute_command('AI.SCRIPTRUN', 'ket', 'bar_variadic', 'INPUTS', '$', 'a', '$', 'b' 'OUTPUTS')
668+
except Exception as e:
669+
exception = e
670+
env.assertEqual(type(exception), redis.exceptions.ResponseError)
664671

665672

666673
def test_pytorch_scriptinfo(env):

0 commit comments

Comments
 (0)