Skip to content

Commit af92171

Browse files
lantigafilipecosta90
authored andcommitted
Add support for variadic arguments to SCRIPT (#395)
* Add support for variadic arguments to SCRIPT * Add negative errors
1 parent 87df138 commit af92171

File tree

8 files changed

+174
-10
lines changed

8 files changed

+174
-10
lines changed

docs/commands.md

Lines changed: 27 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -460,14 +460,16 @@ The **`AI.SCRIPTRUN`** command runs a script stored as a key's value on its spec
460460
**Redis API**
461461

462462
```
463-
AI.SCRIPTRUN <key> <function> INPUTS <input> [input ...] OUTPUTS <output> [output ...]
463+
AI.SCRIPTRUN <key> <function> INPUTS <input> [input ...] [$ input ...] OUTPUTS <output> [output ...]
464464
```
465465

466466
_Arguments_
467467

468468
* **key**: the script's key name
469469
* **function**: the name of the function to run
470-
* **INPUTS**: denotes the beginning of the input tensors keys' list, followed by one or more key names
470+
* **INPUTS**: denotes the beginning of the input tensors keys' list, followed by one or more key names;
471+
variadic arguments are supported by prepending the list with `$`, in this case the
472+
script is expected an argument of type `List[Tensor]` as its last argument
471473
* **OUTPUTS**: denotes the beginning of the output tensors keys' list, followed by one or more key names
472474

473475
_Return_
@@ -491,6 +493,29 @@ redis> AI.TENSORGET result VALUES
491493
3) 1) "42"
492494
```
493495

496+
If 'myscript' supports variadic arguments:
497+
```python
498+
def addn(a, args : List[Tensor]):
499+
return a + torch.stack(args).sum()
500+
```
501+
502+
then one can provide an arbitrary number of inputs after the `$` sign:
503+
504+
```
505+
redis> AI.TENSORSET mytensor1 FLOAT 1 VALUES 40
506+
OK
507+
redis> AI.TENSORSET mytensor2 FLOAT 1 VALUES 1
508+
OK
509+
redis> AI.TENSORSET mytensor3 FLOAT 1 VALUES 1
510+
OK
511+
redis> AI.SCRIPTRUN myscript addn INPUTS mytensor1 $ mytensor2 mytensor3 OUTPUTS result
512+
OK
513+
redis> AI.TENSORGET result VALUES
514+
1) FLOAT
515+
2) 1) (integer) 1
516+
3) 1) "42"
517+
```
518+
494519
!!! warning "Intermediate memory overhead"
495520
The execution of scripts may generate intermediate tensors that are not allocated by the Redis allocator, but by whatever allocator is used in the backends (which may act on main memory or GPU memory, depending on the device), thus not being limited by `maxmemory` configuration settings of Redis.
496521

src/backends/torch.c

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -252,6 +252,7 @@ int RAI_ScriptRunTorch(RAI_ScriptRunCtx* sctx, RAI_Error* error) {
252252

253253
char* error_descr = NULL;
254254
torchRunScript(sctx->script->script, sctx->fnname,
255+
sctx->variadic,
255256
nInputs, inputs, nOutputs, outputs,
256257
&error_descr, RedisModule_Alloc);
257258

src/libtorch_c/torch_c.cpp

Lines changed: 18 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -190,7 +190,7 @@ struct ModuleContext {
190190
int64_t device_id;
191191
};
192192

193-
void torchRunModule(ModuleContext* ctx, const char* fnName,
193+
void torchRunModule(ModuleContext* ctx, const char* fnName, int variadic,
194194
long nInputs, DLManagedTensor** inputs,
195195
long nOutputs, DLManagedTensor** outputs) {
196196
// Checks device, if GPU then move input to GPU before running
@@ -214,11 +214,25 @@ void torchRunModule(ModuleContext* ctx, const char* fnName,
214214
torch::jit::Stack stack;
215215

216216
for (int i=0; i<nInputs; i++) {
217+
if (i == variadic) {
218+
break;
219+
}
217220
DLTensor* input = &(inputs[i]->dl_tensor);
218221
torch::Tensor tensor = fromDLPack(input);
219222
stack.push_back(tensor.to(device));
220223
}
221224

225+
if (variadic != -1 ) {
226+
std::vector<torch::Tensor> args;
227+
for (int i=variadic; i<nInputs; i++) {
228+
DLTensor* input = &(inputs[i]->dl_tensor);
229+
torch::Tensor tensor = fromDLPack(input);
230+
tensor.to(device);
231+
args.emplace_back(tensor);
232+
}
233+
stack.push_back(args);
234+
}
235+
222236
if (ctx->module) {
223237
torch::NoGradGuard guard;
224238
torch::jit::script::Method method = ctx->module->get_method(fnName);
@@ -351,14 +365,14 @@ extern "C" void* torchLoadModel(const char* graph, size_t graphlen, DLDeviceType
351365
return ctx;
352366
}
353367

354-
extern "C" void torchRunScript(void* scriptCtx, const char* fnName,
368+
extern "C" void torchRunScript(void* scriptCtx, const char* fnName, int variadic,
355369
long nInputs, DLManagedTensor** inputs,
356370
long nOutputs, DLManagedTensor** outputs,
357371
char **error, void* (*alloc)(size_t))
358372
{
359373
ModuleContext* ctx = (ModuleContext*)scriptCtx;
360374
try {
361-
torchRunModule(ctx, fnName, nInputs, inputs, nOutputs, outputs);
375+
torchRunModule(ctx, fnName, variadic, nInputs, inputs, nOutputs, outputs);
362376
}
363377
catch(std::exception& e) {
364378
const size_t len = strlen(e.what());
@@ -376,7 +390,7 @@ extern "C" void torchRunModel(void* modelCtx,
376390
{
377391
ModuleContext* ctx = (ModuleContext*)modelCtx;
378392
try {
379-
torchRunModule(ctx, "forward", nInputs, inputs, nOutputs, outputs);
393+
torchRunModule(ctx, "forward", -1, nInputs, inputs, nOutputs, outputs);
380394
}
381395
catch(std::exception& e) {
382396
const size_t len = strlen(e.what());

src/libtorch_c/torch_c.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@ void* torchCompileScript(const char* script, DLDeviceType device, int64_t device
1919
void* torchLoadModel(const char* model, size_t modellen, DLDeviceType device, int64_t device_id,
2020
char **error, void* (*alloc)(size_t));
2121

22-
void torchRunScript(void* scriptCtx, const char* fnName,
22+
void torchRunScript(void* scriptCtx, const char* fnName, int variadic,
2323
long nInputs, DLManagedTensor** inputs,
2424
long nOutputs, DLManagedTensor** outputs,
2525
char **error, void* (*alloc)(size_t));

src/script.c

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -150,6 +150,7 @@ RAI_ScriptRunCtx* RAI_ScriptRunCtxCreate(RAI_Script* script,
150150
sctx->inputs = array_new(RAI_ScriptCtxParam, PARAM_INITIAL_SIZE);
151151
sctx->outputs = array_new(RAI_ScriptCtxParam, PARAM_INITIAL_SIZE);
152152
sctx->fnname = RedisModule_Strdup(fnname);
153+
sctx->variadic = -1;
153154
return sctx;
154155
}
155156

@@ -288,6 +289,10 @@ int RedisAI_Parse_ScriptRun_RedisCommand(RedisModuleCtx *ctx,
288289
is_input = 1;
289290
outputs_flag_count = 1;
290291
} else {
292+
if (!strcasecmp(arg_string, "$")) {
293+
(*sctx)->variadic = argpos - 4;
294+
continue;
295+
}
291296
RedisModule_RetainString(ctx, argv[argpos]);
292297
if (is_input == 0) {
293298
RAI_Tensor *inputTensor;
@@ -302,18 +307,18 @@ int RedisAI_Parse_ScriptRun_RedisCommand(RedisModuleCtx *ctx,
302307
RedisModule_CloseKey(tensorKey);
303308
} else {
304309
const int get_result = RAI_getTensorFromLocalContext(
305-
ctx, *localContextDict, arg_string, &inputTensor,error);
310+
ctx, *localContextDict, arg_string, &inputTensor, error);
306311
if (get_result == REDISMODULE_ERR) {
307312
return -1;
308313
}
309314
}
310315
if (!RAI_ScriptRunCtxAddInput(*sctx, inputTensor)) {
311-
RedisAI_ReplyOrSetError(ctx,error,RAI_ESCRIPTRUN, "ERR Input key not found");
316+
RedisAI_ReplyOrSetError(ctx, error, RAI_ESCRIPTRUN, "ERR Input key not found");
312317
return -1;
313318
}
314319
} else {
315320
if (!RAI_ScriptRunCtxAddOutput(*sctx)) {
316-
RedisAI_ReplyOrSetError(ctx,error,RAI_ESCRIPTRUN, "ERR Output key not found");
321+
RedisAI_ReplyOrSetError(ctx, error, RAI_ESCRIPTRUN, "ERR Output key not found");
317322
return -1;
318323
}
319324
*outkeys=array_append(*outkeys,argv[argpos]);

src/script_struct.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@ typedef struct RAI_ScriptRunCtx {
2727
char* fnname;
2828
RAI_ScriptCtxParam* inputs;
2929
RAI_ScriptCtxParam* outputs;
30+
int variadic;
3031
} RAI_ScriptRunCtx;
3132

3233
#endif /* SRC_SCRIPT_STRUCT_H_ */

test/test_data/script.txt

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,2 +1,5 @@
11
def bar(a, b):
22
return a + b
3+
4+
def bar_variadic(a, args : List[Tensor]):
5+
return args[0] + args[1]

test/tests_pytorch.py

Lines changed: 115 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -446,6 +446,61 @@ def test_pytorch_scriptrun(env):
446446
values2 = con2.execute_command('AI.TENSORGET', 'c', 'VALUES')
447447
env.assertEqual(values2, values)
448448

449+
450+
def test_pytorch_scriptrun_variadic(env):
451+
if not TEST_PT:
452+
env.debugPrint("skipping {} since TEST_PT=0".format(sys._getframe().f_code.co_name), force=True)
453+
return
454+
455+
con = env.getConnection()
456+
457+
test_data_path = os.path.join(os.path.dirname(__file__), 'test_data')
458+
script_filename = os.path.join(test_data_path, 'script.txt')
459+
460+
with open(script_filename, 'rb') as f:
461+
script = f.read()
462+
463+
ret = con.execute_command('AI.SCRIPTSET', 'myscript', DEVICE, 'TAG', 'version1', 'SOURCE', script)
464+
env.assertEqual(ret, b'OK')
465+
466+
ret = con.execute_command('AI.TENSORSET', 'a', 'FLOAT', 2, 2, 'VALUES', 2, 3, 2, 3)
467+
env.assertEqual(ret, b'OK')
468+
ret = con.execute_command('AI.TENSORSET', 'b1', 'FLOAT', 2, 2, 'VALUES', 2, 3, 2, 3)
469+
env.assertEqual(ret, b'OK')
470+
ret = con.execute_command('AI.TENSORSET', 'b2', 'FLOAT', 2, 2, 'VALUES', 2, 3, 2, 3)
471+
env.assertEqual(ret, b'OK')
472+
473+
ensureSlaveSynced(con, env)
474+
475+
for _ in range( 0,100):
476+
ret = con.execute_command('AI.SCRIPTRUN', 'myscript', 'bar_variadic', 'INPUTS', 'a', '$', 'b1', 'b2', 'OUTPUTS', 'c')
477+
env.assertEqual(ret, b'OK')
478+
479+
ensureSlaveSynced(con, env)
480+
481+
info = con.execute_command('AI.INFO', 'myscript')
482+
info_dict_0 = info_to_dict(info)
483+
484+
env.assertEqual(info_dict_0['key'], 'myscript')
485+
env.assertEqual(info_dict_0['type'], 'SCRIPT')
486+
env.assertEqual(info_dict_0['backend'], 'TORCH')
487+
env.assertEqual(info_dict_0['tag'], 'version1')
488+
env.assertTrue(info_dict_0['duration'] > 0)
489+
env.assertEqual(info_dict_0['samples'], -1)
490+
env.assertEqual(info_dict_0['calls'], 100)
491+
env.assertEqual(info_dict_0['errors'], 0)
492+
493+
values = con.execute_command('AI.TENSORGET', 'c', 'VALUES')
494+
env.assertEqual(values, [b'4', b'6', b'4', b'6'])
495+
496+
ensureSlaveSynced(con, env)
497+
498+
if env.useSlaves:
499+
con2 = env.getSlaveConnection()
500+
values2 = con2.execute_command('AI.TENSORGET', 'c', 'VALUES')
501+
env.assertEqual(values2, values)
502+
503+
449504
def test_pytorch_scriptrun_errors(env):
450505
if not TEST_PT:
451506
env.debugPrint("skipping {} since TEST_PT=0".format(sys._getframe().f_code.co_name), force=True)
@@ -548,6 +603,66 @@ def test_pytorch_scriptrun_errors(env):
548603
env.assertEqual(type(exception), redis.exceptions.ResponseError)
549604

550605

606+
def test_pytorch_scriptrun_errors(env):
607+
if not TEST_PT:
608+
env.debugPrint("skipping {} since TEST_PT=0".format(sys._getframe().f_code.co_name), force=True)
609+
return
610+
611+
con = env.getConnection()
612+
613+
test_data_path = os.path.join(os.path.dirname(__file__), 'test_data')
614+
script_filename = os.path.join(test_data_path, 'script.txt')
615+
616+
with open(script_filename, 'rb') as f:
617+
script = f.read()
618+
619+
ret = con.execute_command('AI.SCRIPTSET', 'ket', DEVICE, 'TAG', 'asdf', 'SOURCE', script)
620+
env.assertEqual(ret, b'OK')
621+
622+
ret = con.execute_command('AI.TENSORSET', 'a', 'FLOAT', 2, 2, 'VALUES', 2, 3, 2, 3)
623+
env.assertEqual(ret, b'OK')
624+
ret = con.execute_command('AI.TENSORSET', 'b', 'FLOAT', 2, 2, 'VALUES', 2, 3, 2, 3)
625+
env.assertEqual(ret, b'OK')
626+
627+
ensureSlaveSynced(con, env)
628+
629+
# ERR Variadic input key is empty
630+
try:
631+
con.execute_command('DEL', 'EMPTY')
632+
con.execute_command('AI.SCRIPTRUN', 'ket', 'bar_variadic', 'INPUTS', 'a', '$', 'EMPTY', 'b', 'OUTPUTS', 'c')
633+
except Exception as e:
634+
exception = e
635+
env.assertEqual(type(exception), redis.exceptions.ResponseError)
636+
env.assertEqual("tensor key is empty", exception.__str__())
637+
638+
# ERR Variadic input key not tensor
639+
try:
640+
con.execute_command('SET', 'NOT_TENSOR', 'BAR')
641+
con.execute_command('AI.SCRIPTRUN', 'ket', 'bar_variadic', 'INPUTS', 'a', '$' , 'NOT_TENSOR', 'b', 'OUTPUTS', 'c')
642+
except Exception as e:
643+
exception = e
644+
env.assertEqual(type(exception), redis.exceptions.ResponseError)
645+
env.assertEqual("WRONGTYPE Operation against a key holding the wrong kind of value", exception.__str__())
646+
647+
try:
648+
con.execute_command('AI.SCRIPTRUN', 'ket', 'bar_variadic', 'INPUTS', 'b', '$', 'OUTPUTS', 'c')
649+
except Exception as e:
650+
exception = e
651+
env.assertEqual(type(exception), redis.exceptions.ResponseError)
652+
653+
try:
654+
con.execute_command('AI.SCRIPTRUN', 'ket', 'bar_variadic', 'INPUTS', 'b', '$', 'OUTPUTS')
655+
except Exception as e:
656+
exception = e
657+
env.assertEqual(type(exception), redis.exceptions.ResponseError)
658+
659+
try:
660+
con.execute_command('AI.SCRIPTRUN', 'ket', 'bar_variadic', 'INPUTS', '$', 'OUTPUTS')
661+
except Exception as e:
662+
exception = e
663+
env.assertEqual(type(exception), redis.exceptions.ResponseError)
664+
665+
551666
def test_pytorch_scriptinfo(env):
552667
if not TEST_PT:
553668
env.debugPrint("skipping {} since TEST_PT=0".format(sys._getframe().f_code.co_name), force=True)

0 commit comments

Comments
 (0)