
Commit a63e78d

Add support for batching (#241)

* Add support for automated batching
  * Add support for inspection and eviction to queue
  * Mock run info batching
  * Mock run info batching
  * Make TF tests work
  * Add batching for ONNX and ONNX-ML
  * Fix torch API, still WIP
  * Fix torch backend
  * Fixes after rebasing
  * Add auto-batching to TFLite backend
  * Fix from rebase
  * Add batching args to command and change API accordingly
  * Add batching heuristics [WIP]
  * Fix TFLite test by accessing first tensor in first batch safely
  * Temporarily comment out wrong_bg test check
  * Implement batching heuristics
  * Introduce autobatch tests, tflite still fails
  * Fix segfault when error was generated from the backend
  * Fix tflite autobatch test
  * Updated documentation with auto batching
  * Remove stale comments
  * Avoid making extra copies of inputs and outputs when batch count is 1
* Address review comments re const-correctness
1 parent 0c5c8ba · commit a63e78d

22 files changed (+1115, -198 lines)

docs/commands.md

Lines changed: 19 additions & 1 deletion
````diff
@@ -87,12 +87,22 @@ AI.TENSORGET foo BLOB
 Set a model.
 
 ```sql
-AI.MODELSET model_key backend device [INPUTS name1 name2 ... OUTPUTS name1 name2 ...] model_blob
+AI.MODELSET model_key backend device [BATCHSIZE n [MINBATCHSIZE m]] [INPUTS name1 name2 ... OUTPUTS name1 name2 ...] model_blob
 ```
 
 * model_key - Key for storing the model
 * backend - The backend corresponding to the model being set. Allowed values: `TF`, `TORCH`, `ONNX`.
 * device - Device where the model is loaded and where the computation will run. Allowed values: `CPU`, `GPU`.
+* BATCHSIZE n - Batch incoming requests from multiple clients if they hit the same model and if input tensors have the same
+               shape. Upon MODELRUN, the request queue is visited, input tensors from compatible requests are concatenated
+               along the 0-th (batch) dimension, up until BATCHSIZE is exceeded. The model is then run for the entire batch,
+               results are unpacked back among the individual requests and the respective clients are unblocked.
+               If the batch size of the inputs to the first request in the queue exceeds BATCHSIZE, the request is served
+               in any case. Default is 0 (no batching).
+* MINBATCHSIZE m - Do not execute a MODELRUN until the batch size has reached MINBATCHSIZE. This is primarily used to force
+                  batching during testing, but it can also be used under normal operation. In this case, note that requests
+                  for which MINBATCHSIZE is not reached will hang indefinitely.
+                  Default is 0 (no minimum batch size).
 * INPUTS name1 name2 ... - Name of the nodes in the provided graph corresponding to inputs [`TF` backend only]
 * OUTPUTS name1 name2 ... - Name of the nodes in the provided graph corresponding to outputs [`TF` backend only]
 * model_blob - Binary buffer containing the model protobuf saved from a supported backend
@@ -111,6 +121,14 @@ AI.MODELSET resnet18 TF CPU INPUTS in1 OUTPUTS linear4 < foo.pb
 AI.MODELSET mnist_net ONNX CPU < mnist.onnx
 ```
 
+```sql
+AI.MODELSET mnist_net ONNX CPU BATCHSIZE 10 < mnist.onnx
+```
+
+```sql
+AI.MODELSET resnet18 TF CPU BATCHSIZE 10 MINBATCHSIZE 6 INPUTS in1 OUTPUTS linear4 < foo.pb
+```
+
 ## AI.MODELGET
 
 Get a model.
````
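
The BATCHSIZE and MINBATCHSIZE options documented above are supplied entirely at AI.MODELSET time. The following is a minimal, hypothetical sketch (not part of this commit) of issuing the same command from C via hiredis; the model file `foo.pb`, the key `resnet18`, the node names and the batch values are illustrative assumptions taken from the documentation examples above:

```c
/* Hypothetical sketch: setting a model with auto-batching enabled via hiredis.
 * Assumes a TF graph saved as foo.pb with input node "in1" and output node
 * "linear4", as in the documentation example above. Not part of this commit. */
#include <stdio.h>
#include <stdlib.h>
#include <hiredis/hiredis.h>

int main(void) {
    redisContext *c = redisConnect("127.0.0.1", 6379);
    if (c == NULL || c->err) {
        fprintf(stderr, "connection error\n");
        return 1;
    }

    /* Read the serialized model into memory (error handling omitted). */
    FILE *f = fopen("foo.pb", "rb");
    fseek(f, 0, SEEK_END);
    long len = ftell(f);
    fseek(f, 0, SEEK_SET);
    char *blob = malloc(len);
    fread(blob, 1, len, f);
    fclose(f);

    /* Requests hitting this model with same-shaped inputs will be batched
     * along dimension 0, up to 10 samples per run, waiting for at least 2. */
    redisReply *r = redisCommand(c,
        "AI.MODELSET resnet18 TF CPU BATCHSIZE 10 MINBATCHSIZE 2 "
        "INPUTS in1 OUTPUTS linear4 %b", blob, (size_t)len);
    if (r != NULL) {
        printf("MODELSET: %s\n", r->str ? r->str : "(nil)");
        freeReplyObject(r);
    }

    free(blob);
    redisFree(c);
    return 0;
}
```

With MINBATCHSIZE set as in this sketch, a subsequent AI.MODELRUN blocks until enough compatible requests accumulate, per the semantics described in the documentation diff.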

src/backends.c

Lines changed: 4 additions & 4 deletions
```diff
@@ -74,7 +74,7 @@ int RAI_LoadBackend_TensorFlow(RedisModuleCtx *ctx, const char *path) {
   }
   init_backend(RedisModule_GetApi);
 
-  backend.model_create_with_nodes = (RAI_Model* (*)(RAI_Backend, const char*,
+  backend.model_create_with_nodes = (RAI_Model* (*)(RAI_Backend, const char*, RAI_ModelOpts,
                                      size_t, const char**, size_t, const char**,
                                      const char*, size_t, RAI_Error*))
                                      (unsigned long) dlsym(handle, "RAI_ModelCreateTF");
@@ -140,7 +140,7 @@ int RAI_LoadBackend_TFLite(RedisModuleCtx *ctx, const char *path) {
   }
   init_backend(RedisModule_GetApi);
 
-  backend.model_create = (RAI_Model* (*)(RAI_Backend, const char*,
+  backend.model_create = (RAI_Model* (*)(RAI_Backend, const char*, RAI_ModelOpts,
                           const char*, size_t, RAI_Error*))
                           (unsigned long) dlsym(handle, "RAI_ModelCreateTFLite");
   if (backend.model_create == NULL) {
@@ -205,7 +205,7 @@ int RAI_LoadBackend_Torch(RedisModuleCtx *ctx, const char *path) {
   }
   init_backend(RedisModule_GetApi);
 
-  backend.model_create = (RAI_Model* (*)(RAI_Backend, const char*,
+  backend.model_create = (RAI_Model* (*)(RAI_Backend, const char*, RAI_ModelOpts,
                           const char*, size_t, RAI_Error*))
                           (unsigned long) dlsym(handle, "RAI_ModelCreateTorch");
   if (backend.model_create == NULL) {
@@ -294,7 +294,7 @@ int RAI_LoadBackend_ONNXRuntime(RedisModuleCtx *ctx, const char *path) {
   }
   init_backend(RedisModule_GetApi);
 
-  backend.model_create = (RAI_Model* (*)(RAI_Backend, const char*,
+  backend.model_create = (RAI_Model* (*)(RAI_Backend, const char*, RAI_ModelOpts,
                           const char*, size_t, RAI_Error*))
                           (unsigned long) dlsym(handle, "RAI_ModelCreateORT");
   if (backend.model_create == NULL) {
```

src/backends.h

Lines changed: 2 additions & 2 deletions
```diff
@@ -8,10 +8,10 @@
 #include "err.h"
 
 typedef struct RAI_LoadedBackend {
-  RAI_Model* (*model_create_with_nodes)(RAI_Backend, const char*,
+  RAI_Model* (*model_create_with_nodes)(RAI_Backend, const char*, RAI_ModelOpts,
                                         size_t, const char**, size_t, const char**,
                                         const char*, size_t, RAI_Error*);
-  RAI_Model* (*model_create)(RAI_Backend, const char*,
+  RAI_Model* (*model_create)(RAI_Backend, const char*, RAI_ModelOpts,
                              const char*, size_t, RAI_Error*);
   void (*model_free)(RAI_Model*, RAI_Error*);
   int (*model_run)(RAI_ModelRunCtx*, RAI_Error*);
```
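
In both src/backends.c and src/backends.h above, each backend's model-construction entry point gains a RAI_ModelOpts argument immediately after the device string; this is how the BATCHSIZE / MINBATCHSIZE values parsed at the command layer reach the backend. Below is a minimal sketch of the calling side under stated assumptions: it presumes this repository's headers are on the include path, and the RAI_ModelOpts field names (batchsize, minbatchsize) and the RAI_BACKEND_ONNXRUNTIME enumerator are assumptions, not verified against this commit.

```c
/* Sketch only: invoking a dynamically loaded backend through the updated
 * RAI_LoadedBackend table. RAI_ModelOpts field names and the backend
 * enumerator below are assumptions, not taken from this commit. */
#include "backends.h"  /* assumed to declare RAI_LoadedBackend, RAI_Model, RAI_Error */

RAI_Model *create_model_with_batching(RAI_LoadedBackend *backend,
                                      const char *devicestr,
                                      const char *modeldef, size_t modellen,
                                      size_t batchsize, size_t minbatchsize,
                                      RAI_Error *err) {
  RAI_ModelOpts opts = {
    .batchsize = batchsize,       /* BATCHSIZE n; 0 disables auto-batching  */
    .minbatchsize = minbatchsize, /* MINBATCHSIZE m; 0 means no lower bound */
  };
  /* The options are now the third argument of model_create, matching the
   * updated function-pointer signature in src/backends.h. */
  return backend->model_create(RAI_BACKEND_ONNXRUNTIME, devicestr, opts,
                               modeldef, modellen, err);
}
```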
