Add support for variadic arguments to SCRIPT #395


Merged (2 commits, Jun 1, 2020)
29 changes: 27 additions & 2 deletions docs/commands.md
@@ -458,14 +458,16 @@ The **`AI.SCRIPTRUN`** command runs a script stored as a key's value on its spec
**Redis API**

```
AI.SCRIPTRUN <key> <function> INPUTS <input> [input ...] OUTPUTS <output> [output ...]
AI.SCRIPTRUN <key> <function> INPUTS <input> [input ...] [$ input ...] OUTPUTS <output> [output ...]
```

_Arguments_

* **key**: the script's key name
* **function**: the name of the function to run
* **INPUTS**: denotes the beginning of the input tensors keys' list, followed by one or more key names
* **INPUTS**: denotes the beginning of the input tensors' key list, followed by one or more key names;
  variadic arguments are supported by placing `$` before the trailing inputs, in which case the
  script is expected to accept an argument of type `List[Tensor]` as its last argument
* **OUTPUTS**: denotes the beginning of the output tensors keys' list, followed by one or more key names

_Return_
@@ -489,6 +491,29 @@ redis> AI.TENSORGET result VALUES
3) 1) "42"
```

If 'myscript' stores a function that accepts variadic arguments:
```python
def addn(a, args : List[Tensor]):
return a + torch.stack(args).sum()
```

then one can provide an arbitrary number of inputs after the `$` sign:

```
redis> AI.TENSORSET mytensor1 FLOAT 1 VALUES 40
OK
redis> AI.TENSORSET mytensor2 FLOAT 1 VALUES 1
OK
redis> AI.TENSORSET mytensor3 FLOAT 1 VALUES 1
OK
redis> AI.SCRIPTRUN myscript addn INPUTS mytensor1 $ mytensor2 mytensor3 OUTPUTS result
OK
redis> AI.TENSORGET result VALUES
1) FLOAT
2) 1) (integer) 1
3) 1) "42"
```
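
For reference, the same flow from a Python client might look as follows. This is a minimal sketch that assumes the `redis-py` package and a server with the RedisAI module loaded; key and function names mirror the example above:

```python
import redis

# Minimal sketch: assumes a local Redis with RedisAI loaded and 'myscript'
# already stored with AI.SCRIPTSET as in the example above.
r = redis.Redis()

r.execute_command('AI.TENSORSET', 'mytensor1', 'FLOAT', 1, 'VALUES', 40)
r.execute_command('AI.TENSORSET', 'mytensor2', 'FLOAT', 1, 'VALUES', 1)
r.execute_command('AI.TENSORSET', 'mytensor3', 'FLOAT', 1, 'VALUES', 1)

# Every input listed after '$' is packed into the script's List[Tensor] argument.
r.execute_command('AI.SCRIPTRUN', 'myscript', 'addn',
                  'INPUTS', 'mytensor1', '$', 'mytensor2', 'mytensor3',
                  'OUTPUTS', 'result')

# The reply layout depends on the RedisAI version; the summed value should be 42.
print(r.execute_command('AI.TENSORGET', 'result', 'VALUES'))
```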

!!! warning "Intermediate memory overhead"
The execution of scripts may generate intermediate tensors that are not allocated by the Redis allocator, but by whatever allocator is used in the backends (which may act on main memory or GPU memory, depending on the device), thus not being limited by `maxmemory` configuration settings of Redis.

1 change: 1 addition & 0 deletions src/backends/torch.c
@@ -252,6 +252,7 @@ int RAI_ScriptRunTorch(RAI_ScriptRunCtx* sctx, RAI_Error* error) {

char* error_descr = NULL;
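  // sctx->variadic is the index of the first variadic input, or -1 when the call has none.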
torchRunScript(sctx->script->script, sctx->fnname,
sctx->variadic,
nInputs, inputs, nOutputs, outputs,
&error_descr, RedisModule_Alloc);

22 changes: 18 additions & 4 deletions src/libtorch_c/torch_c.cpp
@@ -190,7 +190,7 @@ struct ModuleContext {
int64_t device_id;
};

void torchRunModule(ModuleContext* ctx, const char* fnName,
void torchRunModule(ModuleContext* ctx, const char* fnName, int variadic,
long nInputs, DLManagedTensor** inputs,
long nOutputs, DLManagedTensor** outputs) {
// Checks device, if GPU then move input to GPU before running
@@ -214,11 +214,25 @@ void torchRunModule(ModuleContext* ctx, const char* fnName,
torch::jit::Stack stack;

for (int i=0; i<nInputs; i++) {
if (i == variadic) {
break;
}
DLTensor* input = &(inputs[i]->dl_tensor);
torch::Tensor tensor = fromDLPack(input);
stack.push_back(tensor.to(device));
}

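  // Inputs from position 'variadic' onward are collected into a single
  // std::vector<torch::Tensor>, which the TorchScript function receives as
  // its trailing List[Tensor] argument.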
  if (variadic != -1) {
std::vector<torch::Tensor> args;
for (int i=variadic; i<nInputs; i++) {
DLTensor* input = &(inputs[i]->dl_tensor);
torch::Tensor tensor = fromDLPack(input);
      args.emplace_back(tensor.to(device));
}
stack.push_back(args);
}

if (ctx->module) {
torch::NoGradGuard guard;
torch::jit::script::Method method = ctx->module->get_method(fnName);
@@ -351,14 +365,14 @@ extern "C" void* torchLoadModel(const char* graph, size_t graphlen, DLDeviceType
return ctx;
}

extern "C" void torchRunScript(void* scriptCtx, const char* fnName,
extern "C" void torchRunScript(void* scriptCtx, const char* fnName, int variadic,
long nInputs, DLManagedTensor** inputs,
long nOutputs, DLManagedTensor** outputs,
char **error, void* (*alloc)(size_t))
{
ModuleContext* ctx = (ModuleContext*)scriptCtx;
try {
torchRunModule(ctx, fnName, nInputs, inputs, nOutputs, outputs);
torchRunModule(ctx, fnName, variadic, nInputs, inputs, nOutputs, outputs);
}
catch(std::exception& e) {
const size_t len = strlen(e.what());
@@ -376,7 +390,7 @@ extern "C" void torchRunModel(void* modelCtx,
{
ModuleContext* ctx = (ModuleContext*)modelCtx;
try {
torchRunModule(ctx, "forward", nInputs, inputs, nOutputs, outputs);
torchRunModule(ctx, "forward", -1, nInputs, inputs, nOutputs, outputs);
}
catch(std::exception& e) {
const size_t len = strlen(e.what());
2 changes: 1 addition & 1 deletion src/libtorch_c/torch_c.h
@@ -19,7 +19,7 @@ void* torchCompileScript(const char* script, DLDeviceType device, int64_t device
void* torchLoadModel(const char* model, size_t modellen, DLDeviceType device, int64_t device_id,
char **error, void* (*alloc)(size_t));

void torchRunScript(void* scriptCtx, const char* fnName,
void torchRunScript(void* scriptCtx, const char* fnName, int variadic,
long nInputs, DLManagedTensor** inputs,
long nOutputs, DLManagedTensor** outputs,
char **error, void* (*alloc)(size_t));
11 changes: 8 additions & 3 deletions src/script.c
@@ -150,6 +150,7 @@ RAI_ScriptRunCtx* RAI_ScriptRunCtxCreate(RAI_Script* script,
sctx->inputs = array_new(RAI_ScriptCtxParam, PARAM_INITIAL_SIZE);
sctx->outputs = array_new(RAI_ScriptCtxParam, PARAM_INITIAL_SIZE);
sctx->fnname = RedisModule_Strdup(fnname);
sctx->variadic = -1;
return sctx;
}

@@ -285,6 +286,10 @@ int RedisAI_Parse_ScriptRun_RedisCommand(RedisModuleCtx *ctx,
is_input = 1;
outputs_flag_count = 1;
} else {
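        // A '$' marker starts the variadic inputs: record the index of the first
        // variadic input within the input list (argv layout is
        // AI.SCRIPTRUN <key> <fn> INPUTS <input> ..., so inputs start at argv[4])
        // and skip the marker itself.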
if (!strcasecmp(arg_string, "$")) {
(*sctx)->variadic = argpos - 4;
continue;
}
RedisModule_RetainString(ctx, argv[argpos]);
if (is_input == 0) {
RAI_Tensor *inputTensor;
@@ -299,18 +304,18 @@
RedisModule_CloseKey(tensorKey);
} else {
const int get_result = RAI_getTensorFromLocalContext(
ctx, *localContextDict, arg_string, &inputTensor,error);
ctx, *localContextDict, arg_string, &inputTensor, error);
if (get_result == REDISMODULE_ERR) {
return -1;
}
}
if (!RAI_ScriptRunCtxAddInput(*sctx, inputTensor)) {
RedisAI_ReplyOrSetError(ctx,error,RAI_ESCRIPTRUN, "ERR Input key not found");
RedisAI_ReplyOrSetError(ctx, error, RAI_ESCRIPTRUN, "ERR Input key not found");
return -1;
}
} else {
if (!RAI_ScriptRunCtxAddOutput(*sctx)) {
RedisAI_ReplyOrSetError(ctx,error,RAI_ESCRIPTRUN, "ERR Output key not found");
RedisAI_ReplyOrSetError(ctx, error, RAI_ESCRIPTRUN, "ERR Output key not found");
return -1;
}
*outkeys=array_append(*outkeys,argv[argpos]);
1 change: 1 addition & 0 deletions src/script_struct.h
@@ -27,6 +27,7 @@ typedef struct RAI_ScriptRunCtx {
char* fnname;
RAI_ScriptCtxParam* inputs;
RAI_ScriptCtxParam* outputs;
  int variadic;  // index of the first variadic input in 'inputs'; -1 when unused
} RAI_ScriptRunCtx;

#endif /* SRC_SCRIPT_STRUCT_H_ */
3 changes: 3 additions & 0 deletions test/test_data/script.txt
@@ -1,2 +1,5 @@
def bar(a, b):
return a + b

def bar_variadic(a, args : List[Tensor]):
return args[0] + args[1]
115 changes: 115 additions & 0 deletions test/tests_pytorch.py
@@ -426,6 +426,61 @@ def test_pytorch_scriptrun(env):
values2 = con2.execute_command('AI.TENSORGET', 'c', 'VALUES')
env.assertEqual(values2, values)


def test_pytorch_scriptrun_variadic(env):
if not TEST_PT:
env.debugPrint("skipping {} since TEST_PT=0".format(sys._getframe().f_code.co_name), force=True)
return

con = env.getConnection()

test_data_path = os.path.join(os.path.dirname(__file__), 'test_data')
script_filename = os.path.join(test_data_path, 'script.txt')

with open(script_filename, 'rb') as f:
script = f.read()

ret = con.execute_command('AI.SCRIPTSET', 'myscript', DEVICE, 'TAG', 'version1', 'SOURCE', script)
env.assertEqual(ret, b'OK')

ret = con.execute_command('AI.TENSORSET', 'a', 'FLOAT', 2, 2, 'VALUES', 2, 3, 2, 3)
env.assertEqual(ret, b'OK')
ret = con.execute_command('AI.TENSORSET', 'b1', 'FLOAT', 2, 2, 'VALUES', 2, 3, 2, 3)
env.assertEqual(ret, b'OK')
ret = con.execute_command('AI.TENSORSET', 'b2', 'FLOAT', 2, 2, 'VALUES', 2, 3, 2, 3)
env.assertEqual(ret, b'OK')

ensureSlaveSynced(con, env)

    for _ in range(0, 100):
ret = con.execute_command('AI.SCRIPTRUN', 'myscript', 'bar_variadic', 'INPUTS', 'a', '$', 'b1', 'b2', 'OUTPUTS', 'c')
env.assertEqual(ret, b'OK')

ensureSlaveSynced(con, env)

info = con.execute_command('AI.INFO', 'myscript')
info_dict_0 = info_to_dict(info)

env.assertEqual(info_dict_0['key'], 'myscript')
env.assertEqual(info_dict_0['type'], 'SCRIPT')
env.assertEqual(info_dict_0['backend'], 'TORCH')
env.assertEqual(info_dict_0['tag'], 'version1')
env.assertTrue(info_dict_0['duration'] > 0)
env.assertEqual(info_dict_0['samples'], -1)
env.assertEqual(info_dict_0['calls'], 100)
env.assertEqual(info_dict_0['errors'], 0)

values = con.execute_command('AI.TENSORGET', 'c', 'VALUES')
env.assertEqual(values, [b'4', b'6', b'4', b'6'])

ensureSlaveSynced(con, env)

if env.useSlaves:
con2 = env.getSlaveConnection()
values2 = con2.execute_command('AI.TENSORGET', 'c', 'VALUES')
env.assertEqual(values2, values)


def test_pytorch_scriptrun_errors(env):
if not TEST_PT:
env.debugPrint("skipping {} since TEST_PT=0".format(sys._getframe().f_code.co_name), force=True)
@@ -528,6 +583,66 @@ def test_pytorch_scriptrun_errors(env):
env.assertEqual(type(exception), redis.exceptions.ResponseError)


def test_pytorch_scriptrun_variadic_errors(env):
if not TEST_PT:
env.debugPrint("skipping {} since TEST_PT=0".format(sys._getframe().f_code.co_name), force=True)
return

con = env.getConnection()

test_data_path = os.path.join(os.path.dirname(__file__), 'test_data')
script_filename = os.path.join(test_data_path, 'script.txt')

with open(script_filename, 'rb') as f:
script = f.read()

ret = con.execute_command('AI.SCRIPTSET', 'ket', DEVICE, 'TAG', 'asdf', 'SOURCE', script)
env.assertEqual(ret, b'OK')

ret = con.execute_command('AI.TENSORSET', 'a', 'FLOAT', 2, 2, 'VALUES', 2, 3, 2, 3)
env.assertEqual(ret, b'OK')
ret = con.execute_command('AI.TENSORSET', 'b', 'FLOAT', 2, 2, 'VALUES', 2, 3, 2, 3)
env.assertEqual(ret, b'OK')

ensureSlaveSynced(con, env)

# ERR Variadic input key is empty
try:
con.execute_command('DEL', 'EMPTY')
con.execute_command('AI.SCRIPTRUN', 'ket', 'bar_variadic', 'INPUTS', 'a', '$', 'EMPTY', 'b', 'OUTPUTS', 'c')
except Exception as e:
exception = e
env.assertEqual(type(exception), redis.exceptions.ResponseError)
env.assertEqual("tensor key is empty", exception.__str__())

# ERR Variadic input key not tensor
try:
con.execute_command('SET', 'NOT_TENSOR', 'BAR')
        con.execute_command('AI.SCRIPTRUN', 'ket', 'bar_variadic', 'INPUTS', 'a', '$', 'NOT_TENSOR', 'b', 'OUTPUTS', 'c')
except Exception as e:
exception = e
env.assertEqual(type(exception), redis.exceptions.ResponseError)
env.assertEqual("WRONGTYPE Operation against a key holding the wrong kind of value", exception.__str__())

try:
con.execute_command('AI.SCRIPTRUN', 'ket', 'bar_variadic', 'INPUTS', 'b', '$', 'OUTPUTS', 'c')
except Exception as e:
exception = e
env.assertEqual(type(exception), redis.exceptions.ResponseError)

try:
con.execute_command('AI.SCRIPTRUN', 'ket', 'bar_variadic', 'INPUTS', 'b', '$', 'OUTPUTS')
except Exception as e:
exception = e
env.assertEqual(type(exception), redis.exceptions.ResponseError)

try:
con.execute_command('AI.SCRIPTRUN', 'ket', 'bar_variadic', 'INPUTS', '$', 'OUTPUTS')
except Exception as e:
exception = e
env.assertEqual(type(exception), redis.exceptions.ResponseError)


def test_pytorch_scriptinfo(env):
if not TEST_PT:
env.debugPrint("skipping {} since TEST_PT=0".format(sys._getframe().f_code.co_name), force=True)