From 040043d2b3c77093f42db82e38601f7ccd3ac3f7 Mon Sep 17 00:00:00 2001 From: alonre24 Date: Wed, 26 May 2021 10:41:36 +0300 Subject: [PATCH 01/27] Introduce kill switch mechanism for onnxruntime sessions (not ready yet) --- src/backends/backends.c | 11 +++++++ src/backends/backends.h | 3 ++ src/backends/onnxruntime.c | 5 ++- src/config/config.h | 2 +- src/execution/background_workers.c | 2 ++ src/execution/onnx_timeout.c | 49 ++++++++++++++++++++++++++++++ src/execution/onnx_timeout.h | 21 +++++++++++++ src/redisai.c | 8 +++++ 8 files changed, 99 insertions(+), 2 deletions(-) create mode 100644 src/execution/onnx_timeout.c create mode 100644 src/execution/onnx_timeout.h diff --git a/src/backends/backends.c b/src/backends/backends.c index f647fa585..7b15ab09f 100644 --- a/src/backends/backends.c +++ b/src/backends/backends.c @@ -14,6 +14,7 @@ #include #include #include +#include "execution/onnx_timeout.h" #include "redismodule.h" @@ -469,9 +470,19 @@ int RAI_LoadBackend_ONNXRuntime(RedisModuleCtx *ctx, const char *path) { "not loaded from %s", path); } + backend.terminate_run_session = + (void (*)(void *))(unsigned long)dlsym(handle, "RAI_TerminateRunSessionORT"); + if (backend.terminate_run_session == NULL) { + dlclose(handle); + RedisModule_Log(ctx, "warning", + "Backend does not export RAI_TerminateRunSessionORT. ONNX backend " + "not loaded from %s", + path); + } RAI_backends.onnx = backend; RedisModule_Log(ctx, "notice", "ONNX backend loaded from %s", path); + return REDISMODULE_OK; } diff --git a/src/backends/backends.h b/src/backends/backends.h index 6fd4d1e80..a93596c35 100644 --- a/src/backends/backends.h +++ b/src/backends/backends.h @@ -81,6 +81,9 @@ typedef struct RAI_LoadedBackend { // Returns the number of times that Redis accessed backend allocator. unsigned long long (*get_memory_access_num)(void); + + // Kill run session (for stopping long runs). 
+ void (*terminate_run_session)(void *); } RAI_LoadedBackend; typedef struct RAI_LoadedBackends { diff --git a/src/backends/onnxruntime.c b/src/backends/onnxruntime.c index 1d493205f..408650ba1 100644 --- a/src/backends/onnxruntime.c +++ b/src/backends/onnxruntime.c @@ -5,6 +5,7 @@ #include "util/arr.h" #include "backends/onnxruntime.h" #include "redis_ai_objects/tensor.h" +#include "execution/onnx_timeout.h" #include "onnxruntime_c_api.h" @@ -554,10 +555,12 @@ int RAI_ModelRunORT(RAI_ModelRunCtx **mctxs, RAI_Error *error) { outputs = array_append(outputs, NULL); } - OrtRunOptions *run_options = NULL; + OrtRunOptions *run_options; + ONNX_VALIDATE_STATUS(ort->CreateRunOptions(&run_options)); ONNX_VALIDATE_STATUS(ort->Run(session, run_options, input_names, (const OrtValue *const *)inputs, n_input_nodes, output_names, n_output_nodes, outputs)); + //ReplaceRunSessionCtx() for (uint32_t i = 0; i < ninputs; i++) { status = ort->AllocatorFree(global_allocator, (void *)input_names[i]); diff --git a/src/config/config.h b/src/config/config.h index d8fc461fe..a05391810 100644 --- a/src/config/config.h +++ b/src/config/config.h @@ -19,7 +19,7 @@ typedef enum { RAI_DEVICE_CPU = 0, RAI_DEVICE_GPU = 1 } RAI_Device; //#define RAI_COPY_RUN_INPUT #define RAI_COPY_RUN_OUTPUT #define RAI_PRINT_BACKEND_ERRORS -#define REDISAI_DEFAULT_THREADS_PER_QUEUE 1 +#define REDISAI_DEFAULT_THREADS_PER_QUEUE 4 #define REDISAI_DEFAULT_INTRA_OP_PARALLELISM 0 #define REDISAI_DEFAULT_INTER_OP_PARALLELISM 0 #define REDISAI_DEFAULT_MODEL_CHUNK_SIZE 535822336 // (511 * 1024 * 1024) diff --git a/src/execution/background_workers.c b/src/execution/background_workers.c index b1b394c72..1db09ceaf 100644 --- a/src/execution/background_workers.c +++ b/src/execution/background_workers.c @@ -16,6 +16,8 @@ #include #include #include +#include +#include #include "redisai.h" #include "run_info.h" #include "background_workers.h" diff --git a/src/execution/onnx_timeout.c b/src/execution/onnx_timeout.c new file mode 
100644 index 000000000..92d2be423 --- /dev/null +++ b/src/execution/onnx_timeout.c @@ -0,0 +1,49 @@ +#include "onnx_timeout.h" +#include "util/arr.h" +#include + +// Gets the current time in milliseconds. +long long _mstime(void) { + struct timeval tv; + long long ust; + + gettimeofday(&tv, NULL); + ust = ((long long)tv.tv_sec) * 1000000; + ust += tv.tv_usec; + return ust/1000; +} + +int CreateGlobalOnnxRunSessions(pthread_t *working_thread_ids, size_t size) { + OnnxRunSessions = array_new(onnxRunSessionCtx *, size); + for (size_t i = 0; i < size; i++) { + OnnxRunSessions = array_append(OnnxRunSessions, NULL); + } + return REDISMODULE_OK; +} + +void OnnxEnforceTimeoutCallback(RedisModuleCtx *ctx, RedisModuleEvent eid, + uint64_t subevent, void *data) { + + const OrtApi *ort = OrtGetApiBase()->GetApi(1); + size_t len = array_len(OnnxRunSessions); + for (size_t i = 0; i < len; i++) { + if (OnnxRunSessions[i] == NULL) { + continue; + } + long long currTime = _mstime(); + if (currTime - OnnxRunSessions[i]->queuingTime > ONNX_MAX_RUNTIME) { + ort->RunOptionsSetTerminate(OnnxRunSessions[i]->runOptions); + } + } +} + +void ReplaceRunSessionCtx(size_t index, OrtRunOptions *newRunOptions) { + const OrtApi *ort = OrtGetApiBase()->GetApi(1); + if (OnnxRunSessions[index] != NULL) { + ort->ReleaseRunOptions(OnnxRunSessions[index]->runOptions); + RedisModule_Free(OnnxRunSessions[index]); + } + onnxRunSessionCtx *runSessionCtx = RedisModule_Alloc(sizeof(onnxRunSessionCtx)); + runSessionCtx->runOptions = newRunOptions; + runSessionCtx->queuingTime = _mstime(); +} diff --git a/src/execution/onnx_timeout.h b/src/execution/onnx_timeout.h new file mode 100644 index 000000000..ddae72791 --- /dev/null +++ b/src/execution/onnx_timeout.h @@ -0,0 +1,21 @@ +#pragma once + +#include "backends/onnxruntime.h" +#include "onnxruntime_c_api.h" + +// The maximum time in milliseconds before killing onnx run session. 
+#define ONNX_MAX_RUNTIME 5000 + +typedef struct onnxRunSessionCtx { + long long queuingTime; + OrtRunOptions* runOptions; +} onnxRunSessionCtx; + +onnxRunSessionCtx **OnnxRunSessions; + +int CreateGlobalOnnxRunSessions(pthread_t *working_thread_ids, size_t size) + +void OnnxEnforceTimeoutCallback(RedisModuleCtx *ctx, RedisModuleEvent eid, + uint64_t subevent, void *data); + +void ReplaceRunSessionCtx(size_t index, OrtRunOptions *runOptions); diff --git a/src/redisai.c b/src/redisai.c index 0baff87f5..d6a58770f 100644 --- a/src/redisai.c +++ b/src/redisai.c @@ -26,6 +26,7 @@ #include #include #include +#include #include "rmutil/alloc.h" #include "rmutil/args.h" @@ -1476,6 +1477,13 @@ int RedisModule_OnLoad(RedisModuleCtx *ctx, RedisModuleString **argv, int argc) RedisModule_Log(ctx, "warning", "Queue not initialized for device CPU"); return REDISMODULE_ERR; } + for (size_t i = 0; i < perqueueThreadPoolSize; i++) { + RedisModule_Log(ctx, "warning", "thread id in index %zu is %lu", i, + run_queue_info->threads[i]); + } + // Create a global array of onnx runSessions, with an entry for every working thread. 
+ CreateGlobalOnnxRunSessions(run_queue_info->threads, perqueueThreadPoolSize); + RedisModule_SubscribeToServerEvent(ctx, RedisModuleEvent_CronLoop, OnnxEnforceTimeoutCallback); run_stats = AI_dictCreate(&AI_dictTypeHeapRStrings, NULL); From 0961e6688ab55816e415ec8b2b1ad63abaafff35 Mon Sep 17 00:00:00 2001 From: alonre24 Date: Thu, 27 May 2021 15:03:55 +0300 Subject: [PATCH 02/27] Putting login in onnx backend (not ready yet) --- src/CMakeLists.txt | 2 ++ src/backends/backends.c | 12 +++++++---- src/backends/backends.h | 5 +++-- src/backends/onnxruntime.c | 24 +++++++++++++++++++--- src/backends/onnxruntime.h | 1 + src/execution/background_workers.c | 20 ++++++++++++++++--- src/execution/background_workers.h | 5 +++++ src/execution/onnx_timeout.c | 32 +++++++++++++++++------------- src/execution/onnx_timeout.h | 12 ++++++----- src/redisai.c | 4 +--- 10 files changed, 83 insertions(+), 34 deletions(-) diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt index 901cc7dbd..13e894e20 100644 --- a/src/CMakeLists.txt +++ b/src/CMakeLists.txt @@ -83,8 +83,10 @@ ENDIF() IF(BUILD_ORT) ADD_LIBRARY(redisai_onnxruntime_obj OBJECT backends/onnxruntime.c + execution/onnx_timeout.c ${BACKEND_COMMON_SRC} ) + SET_PROPERTY(TARGET redisai_onnxruntime_obj PROPERTY ENABLE_EXPORTS 1) ENDIF() INCLUDE_DIRECTORIES(${CMAKE_CURRENT_SOURCE_DIR}) diff --git a/src/backends/backends.c b/src/backends/backends.c index 7b15ab09f..dd71329b7 100644 --- a/src/backends/backends.c +++ b/src/backends/backends.c @@ -14,6 +14,7 @@ #include #include #include +#include #include "execution/onnx_timeout.h" #include "redismodule.h" @@ -470,16 +471,19 @@ int RAI_LoadBackend_ONNXRuntime(RedisModuleCtx *ctx, const char *path) { "not loaded from %s", path); } - backend.terminate_run_session = - (void (*)(void *))(unsigned long)dlsym(handle, "RAI_TerminateRunSessionORT"); - if (backend.terminate_run_session == NULL) { + + backend.enforce_runtime_duration = + (void (*)(RedisModuleCtx *, RedisModuleEvent, uint64_t, 
void *))(unsigned long)dlsym(handle, "OnnxEnforceTimeoutCallback"); + if (backend.enforce_runtime_duration == NULL) { dlclose(handle); RedisModule_Log(ctx, "warning", - "Backend does not export RAI_TerminateRunSessionORT. ONNX backend " + "Backend does not export OnnxEnforceTimeoutCallback. ONNX backend " "not loaded from %s", path); } + RedisModule_SubscribeToServerEvent(ctx, RedisModuleEvent_CronLoop, backend.enforce_runtime_duration); + RAI_backends.onnx = backend; RedisModule_Log(ctx, "notice", "ONNX backend loaded from %s", path); diff --git a/src/backends/backends.h b/src/backends/backends.h index a93596c35..697408453 100644 --- a/src/backends/backends.h +++ b/src/backends/backends.h @@ -82,8 +82,9 @@ typedef struct RAI_LoadedBackend { // Returns the number of times that Redis accessed backend allocator. unsigned long long (*get_memory_access_num)(void); - // Kill run session (for stopping long runs). - void (*terminate_run_session)(void *); + // Kill run session callback (for stopping long runs). 
+ void (*enforce_runtime_duration)(RedisModuleCtx *, RedisModuleEvent, + uint64_t, void *); } RAI_LoadedBackend; typedef struct RAI_LoadedBackends { diff --git a/src/backends/onnxruntime.c b/src/backends/onnxruntime.c index 408650ba1..ef48e2485 100644 --- a/src/backends/onnxruntime.c +++ b/src/backends/onnxruntime.c @@ -2,10 +2,11 @@ #include #include "backends/util.h" #include +#include +#include #include "util/arr.h" #include "backends/onnxruntime.h" #include "redis_ai_objects/tensor.h" -#include "execution/onnx_timeout.h" #include "onnxruntime_c_api.h" @@ -25,6 +26,10 @@ OrtAllocator *global_allocator = NULL; unsigned long long OnnxMemory = 0; unsigned long long OnnxMemoryAccessCounter = 0; +// Globals from RedisAI to use for handling sessions timeouts/ +long long perqueueThreadPoolSize = 4; +pthread_key_t tls_id_key; + const OrtMemoryInfo *AllocatorInfo(const OrtAllocator *allocator) { (void)allocator; const OrtApi *ort = OrtGetApiBase()->GetApi(1); @@ -88,6 +93,12 @@ int RAI_InitBackendORT(int (*get_api_fn)(const char *, void *)) { get_api_fn("RedisModule_GetThreadSafeContext", ((void **)&RedisModule_GetThreadSafeContext)); get_api_fn("RedisModule_FreeThreadSafeContext", ((void **)&RedisModule_FreeThreadSafeContext)); get_api_fn("RedisModule_MallocSize", ((void **)&RedisModule_MallocSize)); + + + // Create a global array of onnx runSessions, with an entry for every working thread. 
+ long long size = perqueueThreadPoolSize; + CreateGlobalOnnxRunSessions(size); + return REDISMODULE_OK; } @@ -508,6 +519,7 @@ int RAI_ModelRunORT(RAI_ModelRunCtx **mctxs, RAI_Error *error) { array_new_on_stack(const char *, 5, output_names); array_new_on_stack(OrtValue *, 5, inputs); array_new_on_stack(OrtValue *, 5, outputs); + OrtRunOptions *run_options = NULL; OrtTensorTypeAndShapeInfo *info = NULL; { size_t n_input_nodes; @@ -555,12 +567,15 @@ int RAI_ModelRunORT(RAI_ModelRunCtx **mctxs, RAI_Error *error) { outputs = array_append(outputs, NULL); } - OrtRunOptions *run_options; ONNX_VALIDATE_STATUS(ort->CreateRunOptions(&run_options)); + int *thread_ind = (int *)pthread_getspecific(tls_id_key); + SetRunSessionCtx(*thread_ind, run_options); + ONNX_VALIDATE_STATUS(ort->Run(session, run_options, input_names, (const OrtValue *const *)inputs, n_input_nodes, output_names, n_output_nodes, outputs)); - //ReplaceRunSessionCtx() + ClearRunSessionCtx(*thread_ind); + run_options = NULL; for (uint32_t i = 0; i < ninputs; i++) { status = ort->AllocatorFree(global_allocator, (void *)input_names[i]); @@ -648,6 +663,9 @@ int RAI_ModelRunORT(RAI_ModelRunCtx **mctxs, RAI_Error *error) { if (info) { ort->ReleaseTensorTypeAndShapeInfo(info); } + if (run_options) { + ort->ReleaseRunOptions(run_options); + } return REDISMODULE_ERR; } diff --git a/src/backends/onnxruntime.h b/src/backends/onnxruntime.h index ec282bac3..f2b0ae8ce 100644 --- a/src/backends/onnxruntime.h +++ b/src/backends/onnxruntime.h @@ -5,6 +5,7 @@ #include "redis_ai_objects/tensor_struct.h" #include "redis_ai_objects/model_struct.h" + unsigned long long RAI_GetMemoryInfoORT(void); unsigned long long RAI_GetMemoryAccessORT(void); diff --git a/src/execution/background_workers.c b/src/execution/background_workers.c index 1db09ceaf..cfed55091 100644 --- a/src/execution/background_workers.c +++ b/src/execution/background_workers.c @@ -40,6 +40,8 @@ int pthread_setname_np(const char *name); #endif #endif +pthread_key_t 
tls_id_key; + int freeRunQueueInfo(RunQueueInfo *info) { int result = REDISMODULE_OK; if (info->run_queue) { @@ -96,8 +98,11 @@ int ensureRunQueue(const char *devicestr, RunQueueInfo **run_queue_info) { (pthread_t *)RedisModule_Alloc(sizeof(pthread_t) * perqueueThreadPoolSize); /* create threads */ for (int i = 0; i < perqueueThreadPoolSize; i++) { + WorkerThreadInfo *thread_info = RedisModule_Alloc(sizeof(WorkerThreadInfo)); + thread_info->run_queue_info = *run_queue_info; + thread_info->id = i; if (pthread_create(&((*run_queue_info)->threads[i]), NULL, RedisAI_Run_ThreadMain, - *run_queue_info) != 0) { + thread_info) != 0) { freeRunQueueInfo(*run_queue_info); return REDISMODULE_ERR; } @@ -111,6 +116,13 @@ int ensureRunQueue(const char *devicestr, RunQueueInfo **run_queue_info) { return result; } +void _SaveThreadId(int id) { + pthread_key_create(&tls_id_key, RedisModule_Free); + int *id_value = RedisModule_Alloc(sizeof(int)); + *id_value = id; + pthread_setspecific(tls_id_key, id_value); +} + /** * @brief In case a DAG Op can express a MINBATCHSIZE > 0 with a MINBATCHTIMEOUT * in milliseconds, we will use a timedwait of one millisecond to evaluate @@ -301,8 +313,10 @@ static RedisAI_RunInfo **_BGThread_BatchOperations(RunQueueInfo *run_queue_info, } void *RedisAI_Run_ThreadMain(void *arg) { - RunQueueInfo *run_queue_info = (RunQueueInfo *)arg; - RAI_PTHREAD_SETNAME("redisai_bthread"); + WorkerThreadInfo *thread_info = (WorkerThreadInfo *)arg; + RunQueueInfo *run_queue_info = thread_info->run_queue_info; + _SaveThreadId(thread_info->id); + RedisModule_Free(thread_info); RedisAI_RunInfo **batch_rinfo = array_new(RedisAI_RunInfo *, 1); pthread_mutex_lock(&run_queue_info->run_queue_mutex); diff --git a/src/execution/background_workers.h b/src/execution/background_workers.h index 43d7099b0..da0bd4fe1 100644 --- a/src/execution/background_workers.h +++ b/src/execution/background_workers.h @@ -41,6 +41,11 @@ typedef struct RunQueueInfo { char *devicestr; } RunQueueInfo; 
+typedef struct WorkerThreadInfo { + RunQueueInfo *run_queue_info; + int id; +} WorkerThreadInfo; + int freeRunQueueInfo(RunQueueInfo *info); /* Ensure that the the run queue for the device exists. diff --git a/src/execution/onnx_timeout.c b/src/execution/onnx_timeout.c index 92d2be423..9546965a7 100644 --- a/src/execution/onnx_timeout.c +++ b/src/execution/onnx_timeout.c @@ -3,7 +3,7 @@ #include // Gets the current time in milliseconds. -long long _mstime(void) { +static long long _mstime(void) { struct timeval tv; long long ust; @@ -13,10 +13,11 @@ long long _mstime(void) { return ust/1000; } -int CreateGlobalOnnxRunSessions(pthread_t *working_thread_ids, size_t size) { - OnnxRunSessions = array_new(onnxRunSessionCtx *, size); +int CreateGlobalOnnxRunSessions(long long size) { + OnnxRunSessions = array_new(OnnxRunSessionCtx *, size); for (size_t i = 0; i < size; i++) { - OnnxRunSessions = array_append(OnnxRunSessions, NULL); + OnnxRunSessionCtx *entry = RedisModule_Calloc(1, sizeof(OnnxRunSessionCtx)); + OnnxRunSessions = array_append(OnnxRunSessions, entry); } return REDISMODULE_OK; } @@ -27,23 +28,26 @@ void OnnxEnforceTimeoutCallback(RedisModuleCtx *ctx, RedisModuleEvent eid, const OrtApi *ort = OrtGetApiBase()->GetApi(1); size_t len = array_len(OnnxRunSessions); for (size_t i = 0; i < len; i++) { - if (OnnxRunSessions[i] == NULL) { + if (OnnxRunSessions[i]->runOptions == NULL) { continue; } - long long currTime = _mstime(); - if (currTime - OnnxRunSessions[i]->queuingTime > ONNX_MAX_RUNTIME) { + long long curr_time = _mstime(); + if (curr_time - OnnxRunSessions[i]->queuingTime > ONNX_MAX_RUNTIME) { ort->RunOptionsSetTerminate(OnnxRunSessions[i]->runOptions); } } } -void ReplaceRunSessionCtx(size_t index, OrtRunOptions *newRunOptions) { - const OrtApi *ort = OrtGetApiBase()->GetApi(1); - if (OnnxRunSessions[index] != NULL) { - ort->ReleaseRunOptions(OnnxRunSessions[index]->runOptions); - RedisModule_Free(OnnxRunSessions[index]); - } - onnxRunSessionCtx 
*runSessionCtx = RedisModule_Alloc(sizeof(onnxRunSessionCtx)); +void SetRunSessionCtx(size_t index, OrtRunOptions *newRunOptions) { + OnnxRunSessionCtx *runSessionCtx = OnnxRunSessions[index]; + RedisModule_Assert(runSessionCtx->runOptions == NULL); runSessionCtx->runOptions = newRunOptions; runSessionCtx->queuingTime = _mstime(); } + +void ClearRunSessionCtx(size_t index) { + const OrtApi *ort = OrtGetApiBase()->GetApi(1); + OnnxRunSessionCtx *runSessionCtx = OnnxRunSessions[index]; + ort->ReleaseRunOptions(runSessionCtx->runOptions); + runSessionCtx->runOptions = NULL; +} diff --git a/src/execution/onnx_timeout.h b/src/execution/onnx_timeout.h index ddae72791..e4cd70bd7 100644 --- a/src/execution/onnx_timeout.h +++ b/src/execution/onnx_timeout.h @@ -6,16 +6,18 @@ // The maximum time in milliseconds before killing onnx run session. #define ONNX_MAX_RUNTIME 5000 -typedef struct onnxRunSessionCtx { +typedef struct OnnxRunSessionCtx { long long queuingTime; OrtRunOptions* runOptions; -} onnxRunSessionCtx; +} OnnxRunSessionCtx; -onnxRunSessionCtx **OnnxRunSessions; +OnnxRunSessionCtx **OnnxRunSessions; -int CreateGlobalOnnxRunSessions(pthread_t *working_thread_ids, size_t size) +int CreateGlobalOnnxRunSessions(long long size); void OnnxEnforceTimeoutCallback(RedisModuleCtx *ctx, RedisModuleEvent eid, uint64_t subevent, void *data); -void ReplaceRunSessionCtx(size_t index, OrtRunOptions *runOptions); +void SetRunSessionCtx(size_t index, OrtRunOptions *newRunOptions); + +void ClearRunSessionCtx(size_t index); diff --git a/src/redisai.c b/src/redisai.c index d6a58770f..9e6472cba 100644 --- a/src/redisai.c +++ b/src/redisai.c @@ -54,6 +54,7 @@ #endif #endif + extern int redisMajorVersion; extern int redisMinorVersion; extern int redisPatchVersion; @@ -1481,9 +1482,6 @@ int RedisModule_OnLoad(RedisModuleCtx *ctx, RedisModuleString **argv, int argc) RedisModule_Log(ctx, "warning", "thread id in index %zu is %lu", i, run_queue_info->threads[i]); } - // Create a global array 
of onnx runSessions, with an entry for every working thread. - CreateGlobalOnnxRunSessions(run_queue_info->threads, perqueueThreadPoolSize); - RedisModule_SubscribeToServerEvent(ctx, RedisModuleEvent_CronLoop, OnnxEnforceTimeoutCallback); run_stats = AI_dictCreate(&AI_dictTypeHeapRStrings, NULL); From 9d5fbf8700a413c875165f174646939f8c868445 Mon Sep 17 00:00:00 2001 From: alonre24 Date: Thu, 27 May 2021 17:37:04 +0300 Subject: [PATCH 03/27] WIP --- src/backends/backends.c | 18 +++++++++++++++--- src/backends/onnxruntime.c | 17 +++++++++++------ src/backends/onnxruntime.h | 4 ++-- src/execution/background_workers.c | 16 ++++++++++++++-- src/execution/background_workers.h | 5 +++++ 5 files changed, 47 insertions(+), 13 deletions(-) diff --git a/src/backends/backends.c b/src/backends/backends.c index dd71329b7..97bd702ba 100644 --- a/src/backends/backends.c +++ b/src/backends/backends.c @@ -19,6 +19,18 @@ #include "redismodule.h" +int RAI_GetApi(const char *funcname, void **targetPtrPtr) { + + if (strcmp("ThreadIdKey", funcname) == 0) { + *targetPtrPtr = ThreadIdKey; + } else if (strcmp("NumThreadsPerQueue", funcname) == 0) { + *targetPtrPtr = NumThreadsPerQueue; + } else { + return REDISMODULE_ERR; + } + return REDISMODULE_OK; +} + RedisModuleString *RAI_GetModulePath(RedisModuleCtx *ctx) { Dl_info info; RedisModuleString *module_path = NULL; @@ -384,9 +396,9 @@ int RAI_LoadBackend_ONNXRuntime(RedisModuleCtx *ctx, const char *path) { RAI_LoadedBackend backend = {0}; - int (*init_backend)(int (*)(const char *, void *)); + int (*init_backend)(int (*)(const char *, void *), int (*)(const char *, void *)); init_backend = - (int (*)(int (*)(const char *, void *)))(unsigned long)dlsym(handle, "RAI_InitBackendORT"); + (int (*)(int (*)(const char *, void *), int (*)(const char *, void *)))(unsigned long)dlsym(handle, "RAI_InitBackendORT"); if (init_backend == NULL) { dlclose(handle); RedisModule_Log(ctx, "warning", @@ -395,7 +407,7 @@ int 
RAI_LoadBackend_ONNXRuntime(RedisModuleCtx *ctx, const char *path) { path); return REDISMODULE_ERR; } - init_backend(RedisModule_GetApi); + init_backend(RedisModule_GetApi, (int (*)(const char*, void*)) RAI_GetApi); backend.model_create = (RAI_Model * (*)(RAI_Backend, const char *, RAI_ModelOpts, const char *, size_t, diff --git a/src/backends/onnxruntime.c b/src/backends/onnxruntime.c index ef48e2485..aaf5130c8 100644 --- a/src/backends/onnxruntime.c +++ b/src/backends/onnxruntime.c @@ -3,6 +3,7 @@ #include "backends/util.h" #include #include +#include "execution/background_workers.h" #include #include "util/arr.h" #include "backends/onnxruntime.h" @@ -26,9 +27,8 @@ OrtAllocator *global_allocator = NULL; unsigned long long OnnxMemory = 0; unsigned long long OnnxMemoryAccessCounter = 0; -// Globals from RedisAI to use for handling sessions timeouts/ -long long perqueueThreadPoolSize = 4; -pthread_key_t tls_id_key; +pthread_key_t (*RedisAI_ThreadIdKey)(void); +long long (*RedisAI_NumThreadsPerQueue)(void); const OrtMemoryInfo *AllocatorInfo(const OrtAllocator *allocator) { (void)allocator; @@ -83,7 +83,9 @@ unsigned long long RAI_GetMemoryInfoORT() { return OnnxMemory; } unsigned long long RAI_GetMemoryAccessORT() { return OnnxMemoryAccessCounter; } -int RAI_InitBackendORT(int (*get_api_fn)(const char *, void *)) { +int RAI_InitBackendORT(int (*get_api_fn)(const char *, void *), + int (*get_api_fn_rai)(const char *, void *)) { + // Export redis callbacks. get_api_fn("RedisModule_Alloc", ((void **)&RedisModule_Alloc)); get_api_fn("RedisModule_Calloc", ((void **)&RedisModule_Calloc)); get_api_fn("RedisModule_Free", ((void **)&RedisModule_Free)); @@ -94,9 +96,12 @@ int RAI_InitBackendORT(int (*get_api_fn)(const char *, void *)) { get_api_fn("RedisModule_FreeThreadSafeContext", ((void **)&RedisModule_FreeThreadSafeContext)); get_api_fn("RedisModule_MallocSize", ((void **)&RedisModule_MallocSize)); + // Export RedisAI callbacks. 
+ get_api_fn_rai("ThreadIdKey", ((void **)&RedisAI_ThreadIdKey)); + get_api_fn_rai("NumThreadsPerQueue", ((void **)&RedisAI_NumThreadsPerQueue)); // Create a global array of onnx runSessions, with an entry for every working thread. - long long size = perqueueThreadPoolSize; + long long size = RedisAI_NumThreadsPerQueue(); CreateGlobalOnnxRunSessions(size); return REDISMODULE_OK; @@ -568,7 +573,7 @@ int RAI_ModelRunORT(RAI_ModelRunCtx **mctxs, RAI_Error *error) { } ONNX_VALIDATE_STATUS(ort->CreateRunOptions(&run_options)); - int *thread_ind = (int *)pthread_getspecific(tls_id_key); + int *thread_ind = (int *)pthread_getspecific(RedisAI_ThreadIdKey()); SetRunSessionCtx(*thread_ind, run_options); ONNX_VALIDATE_STATUS(ort->Run(session, run_options, input_names, diff --git a/src/backends/onnxruntime.h b/src/backends/onnxruntime.h index f2b0ae8ce..65b3e480d 100644 --- a/src/backends/onnxruntime.h +++ b/src/backends/onnxruntime.h @@ -5,12 +5,12 @@ #include "redis_ai_objects/tensor_struct.h" #include "redis_ai_objects/model_struct.h" - unsigned long long RAI_GetMemoryInfoORT(void); unsigned long long RAI_GetMemoryAccessORT(void); -int RAI_InitBackendORT(int (*get_api_fn)(const char *, void *)); +int RAI_InitBackendORT(int (*get_api_fn)(const char *, void *), + int (*get_api_fn_rai)(const char *, void *)); RAI_Model *RAI_ModelCreateORT(RAI_Backend backend, const char *devicestr, RAI_ModelOpts opts, const char *modeldef, size_t modellen, RAI_Error *err); diff --git a/src/execution/background_workers.c b/src/execution/background_workers.c index cfed55091..922079be1 100644 --- a/src/execution/background_workers.c +++ b/src/execution/background_workers.c @@ -40,7 +40,7 @@ int pthread_setname_np(const char *name); #endif #endif -pthread_key_t tls_id_key; +pthread_key_t *tls_id_keys; // tls_id_keys[0] is CPU threads int freeRunQueueInfo(RunQueueInfo *info) { int result = REDISMODULE_OK; @@ -96,6 +96,9 @@ int ensureRunQueue(const char *devicestr, RunQueueInfo **run_queue_info) { 
pthread_mutex_init(&(*run_queue_info)->run_queue_mutex, NULL); (*run_queue_info)->threads = (pthread_t *)RedisModule_Alloc(sizeof(pthread_t) * perqueueThreadPoolSize); + pthread_key_t thread_id_key; + pthread_key_create(&thread_id_key, RedisModule_Free); + (*run_queue_info)->thread_id_key = thread_id_key; /* create threads */ for (int i = 0; i < perqueueThreadPoolSize; i++) { WorkerThreadInfo *thread_info = RedisModule_Alloc(sizeof(WorkerThreadInfo)); @@ -116,13 +119,22 @@ int ensureRunQueue(const char *devicestr, RunQueueInfo **run_queue_info) { return result; } +pthread_key_t ThreadIdKey() { + return tls_id_key; +} + +long long NumThreadsPerQueue() { + return perqueueThreadPoolSize; +} + void _SaveThreadId(int id) { - pthread_key_create(&tls_id_key, RedisModule_Free); int *id_value = RedisModule_Alloc(sizeof(int)); *id_value = id; pthread_setspecific(tls_id_key, id_value); } + + /** * @brief In case a DAG Op can express a MINBATCHSIZE > 0 with a MINBATCHTIMEOUT * in milliseconds, we will use a timedwait of one millisecond to evaluate diff --git a/src/execution/background_workers.h b/src/execution/background_workers.h index da0bd4fe1..a68b5ea34 100644 --- a/src/execution/background_workers.h +++ b/src/execution/background_workers.h @@ -38,6 +38,7 @@ typedef struct RunQueueInfo { pthread_cond_t queue_condition_var; queue *run_queue; pthread_t *threads; + pthread_key_t thread_id_key; // A key for getting the thread id from its local storage. char *devicestr; } RunQueueInfo; @@ -51,3 +52,7 @@ int freeRunQueueInfo(RunQueueInfo *info); /* Ensure that the the run queue for the device exists. * If not, create it. 
*/ int ensureRunQueue(const char *devicestr, RunQueueInfo **run_queue_info); + +pthread_key_t ThreadIdKey(void); + +long long NumThreadsPerQueue(void); From 22662bdd2653959da348af1c7cb12b9b2de57a3b Mon Sep 17 00:00:00 2001 From: alonre24 Date: Sat, 29 May 2021 23:53:07 +0300 Subject: [PATCH 04/27] Refactor background workers + add support to kill switch in onnx (for multiple devices). Tests pass (new feature hasn't tested yet) --- src/CMakeLists.txt | 2 + src/backends/backends.c | 48 +- src/backends/backends.h | 6 +- src/backends/onnxruntime.c | 20 +- src/backends/onnxruntime.h | 5 +- src/config/config.c | 6 +- src/execution/DAG/dag_execute.c | 11 +- src/execution/background_workers.c | 162 +-- src/execution/background_workers.h | 22 +- src/execution/onnx_timeout.c | 85 +- src/execution/onnx_timeout.h | 20 +- src/execution/parsing/deprecated.c | 15 +- src/redisai.c | 106 +- src/util/queue.c | 2 +- src/util/queue.h | 2 +- src/util/rax.c | 2003 ++++++++++++++++++++++++++++ src/util/rax.h | 218 +++ src/util/rax_malloc.h | 43 + src/util/string_utils.c | 15 + src/util/string_utils.h | 1 + 20 files changed, 2557 insertions(+), 235 deletions(-) create mode 100644 src/util/rax.c create mode 100644 src/util/rax.h create mode 100644 src/util/rax_malloc.h diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt index 13e894e20..c4fc5f000 100644 --- a/src/CMakeLists.txt +++ b/src/CMakeLists.txt @@ -26,6 +26,7 @@ ADD_LIBRARY(redisai_obj OBJECT util/dictionaries.c util/queue.c util/string_utils.c + util/rax.c redisai.c execution/command_parser.c execution/parsing/deprecated.c @@ -84,6 +85,7 @@ IF(BUILD_ORT) ADD_LIBRARY(redisai_onnxruntime_obj OBJECT backends/onnxruntime.c execution/onnx_timeout.c + util/rax.c ${BACKEND_COMMON_SRC} ) SET_PROPERTY(TARGET redisai_onnxruntime_obj PROPERTY ENABLE_EXPORTS 1) diff --git a/src/backends/backends.c b/src/backends/backends.c index 97bd702ba..691e82e1e 100644 --- a/src/backends/backends.c +++ b/src/backends/backends.c @@ -19,16 +19,16 @@ 
#include "redismodule.h" -int RAI_GetApi(const char *funcname, void **targetPtrPtr) { - - if (strcmp("ThreadIdKey", funcname) == 0) { - *targetPtrPtr = ThreadIdKey; - } else if (strcmp("NumThreadsPerQueue", funcname) == 0) { - *targetPtrPtr = NumThreadsPerQueue; - } else { - return REDISMODULE_ERR; - } - return REDISMODULE_OK; +int RAI_GetApi(const char *func_name, void **targetPtrPtr) { + + if (strcmp("ThreadIdKey", func_name) == 0) { + *targetPtrPtr = GetQueueThreadIdKey; + } else if (strcmp("NumThreadsPerQueue", func_name) == 0) { + *targetPtrPtr = GetNumThreadsPerQueue; + } else { + return REDISMODULE_ERR; + } + return REDISMODULE_OK; } RedisModuleString *RAI_GetModulePath(RedisModuleCtx *ctx) { @@ -397,8 +397,8 @@ int RAI_LoadBackend_ONNXRuntime(RedisModuleCtx *ctx, const char *path) { RAI_LoadedBackend backend = {0}; int (*init_backend)(int (*)(const char *, void *), int (*)(const char *, void *)); - init_backend = - (int (*)(int (*)(const char *, void *), int (*)(const char *, void *)))(unsigned long)dlsym(handle, "RAI_InitBackendORT"); + init_backend = (int (*)(int (*)(const char *, void *), int (*)(const char *, void *)))( + unsigned long)dlsym(handle, "RAI_InitBackendORT"); if (init_backend == NULL) { dlclose(handle); RedisModule_Log(ctx, "warning", @@ -407,7 +407,7 @@ int RAI_LoadBackend_ONNXRuntime(RedisModuleCtx *ctx, const char *path) { path); return REDISMODULE_ERR; } - init_backend(RedisModule_GetApi, (int (*)(const char*, void*)) RAI_GetApi); + init_backend(RedisModule_GetApi, (int (*)(const char *, void *))RAI_GetApi); backend.model_create = (RAI_Model * (*)(RAI_Backend, const char *, RAI_ModelOpts, const char *, size_t, @@ -485,16 +485,28 @@ int RAI_LoadBackend_ONNXRuntime(RedisModuleCtx *ctx, const char *path) { } backend.enforce_runtime_duration = - (void (*)(RedisModuleCtx *, RedisModuleEvent, uint64_t, void *))(unsigned long)dlsym(handle, "OnnxEnforceTimeoutCallback"); + (void (*)(RedisModuleCtx *, RedisModuleEvent, uint64_t, void 
*))(unsigned long)dlsym( + handle, "OnnxEnforceTimeoutCallback"); if (backend.enforce_runtime_duration == NULL) { dlclose(handle); RedisModule_Log(ctx, "warning", - "Backend does not export OnnxEnforceTimeoutCallback. ONNX backend " - "not loaded from %s", - path); + "Backend does not export OnnxEnforceTimeoutCallback. ONNX backend " + "not loaded from %s", + path); + } + + backend.add_new_device = + (int (*)(const char *))(unsigned long)dlsym(handle, "AddDeviceToGlobalRunSessions"); + if (backend.add_new_device == NULL) { + dlclose(handle); + RedisModule_Log(ctx, "warning", + "Backend does not export AddDeviceToGlobalRunSessions. ONNX backend " + "not loaded from %s", + path); } - RedisModule_SubscribeToServerEvent(ctx, RedisModuleEvent_CronLoop, backend.enforce_runtime_duration); + RedisModule_SubscribeToServerEvent(ctx, RedisModuleEvent_CronLoop, + backend.enforce_runtime_duration); RAI_backends.onnx = backend; RedisModule_Log(ctx, "notice", "ONNX backend loaded from %s", path); diff --git a/src/backends/backends.h b/src/backends/backends.h index 697408453..bf67c87b6 100644 --- a/src/backends/backends.h +++ b/src/backends/backends.h @@ -82,9 +82,11 @@ typedef struct RAI_LoadedBackend { // Returns the number of times that Redis accessed backend allocator. unsigned long long (*get_memory_access_num)(void); + // A callback for to use whenever a new device is introduced. + int (*add_new_device)(const char *); + // Kill run session callback (for stopping long runs). 
- void (*enforce_runtime_duration)(RedisModuleCtx *, RedisModuleEvent, - uint64_t, void *); + void (*enforce_runtime_duration)(RedisModuleCtx *, RedisModuleEvent, uint64_t, void *); } RAI_LoadedBackend; typedef struct RAI_LoadedBackends { diff --git a/src/backends/onnxruntime.c b/src/backends/onnxruntime.c index aaf5130c8..581de22f5 100644 --- a/src/backends/onnxruntime.c +++ b/src/backends/onnxruntime.c @@ -27,9 +27,6 @@ OrtAllocator *global_allocator = NULL; unsigned long long OnnxMemory = 0; unsigned long long OnnxMemoryAccessCounter = 0; -pthread_key_t (*RedisAI_ThreadIdKey)(void); -long long (*RedisAI_NumThreadsPerQueue)(void); - const OrtMemoryInfo *AllocatorInfo(const OrtAllocator *allocator) { (void)allocator; const OrtApi *ort = OrtGetApiBase()->GetApi(1); @@ -84,7 +81,7 @@ unsigned long long RAI_GetMemoryInfoORT() { return OnnxMemory; } unsigned long long RAI_GetMemoryAccessORT() { return OnnxMemoryAccessCounter; } int RAI_InitBackendORT(int (*get_api_fn)(const char *, void *), - int (*get_api_fn_rai)(const char *, void *)) { + int (*get_api_fn_rai)(const char *, void *)) { // Export redis callbacks. get_api_fn("RedisModule_Alloc", ((void **)&RedisModule_Alloc)); get_api_fn("RedisModule_Calloc", ((void **)&RedisModule_Calloc)); @@ -97,12 +94,11 @@ int RAI_InitBackendORT(int (*get_api_fn)(const char *, void *), get_api_fn("RedisModule_MallocSize", ((void **)&RedisModule_MallocSize)); // Export RedisAI callbacks. - get_api_fn_rai("ThreadIdKey", ((void **)&RedisAI_ThreadIdKey)); - get_api_fn_rai("NumThreadsPerQueue", ((void **)&RedisAI_NumThreadsPerQueue)); + get_api_fn_rai("ThreadIdKey", ((void **)&RedisAI_ThreadIdKey)); + get_api_fn_rai("NumThreadsPerQueue", ((void **)&RedisAI_NumThreadsPerQueue)); // Create a global array of onnx runSessions, with an entry for every working thread. 
- long long size = RedisAI_NumThreadsPerQueue(); - CreateGlobalOnnxRunSessions(size); + CreateGlobalOnnxRunSessions(); return REDISMODULE_OK; } @@ -573,13 +569,13 @@ int RAI_ModelRunORT(RAI_ModelRunCtx **mctxs, RAI_Error *error) { } ONNX_VALIDATE_STATUS(ort->CreateRunOptions(&run_options)); - int *thread_ind = (int *)pthread_getspecific(RedisAI_ThreadIdKey()); - SetRunSessionCtx(*thread_ind, run_options); - + // Set the created run option in the global RunSessions and return it. + OnnxRunSessionCtx *run_session_ctx = + SetGetRunSessionCtx(mctxs[0]->model->devicestr, run_options); ONNX_VALIDATE_STATUS(ort->Run(session, run_options, input_names, (const OrtValue *const *)inputs, n_input_nodes, output_names, n_output_nodes, outputs)); - ClearRunSessionCtx(*thread_ind); + ClearRunSessionCtx(run_session_ctx); run_options = NULL; for (uint32_t i = 0; i < ninputs; i++) { diff --git a/src/backends/onnxruntime.h b/src/backends/onnxruntime.h index 65b3e480d..eae2aec8b 100644 --- a/src/backends/onnxruntime.h +++ b/src/backends/onnxruntime.h @@ -9,8 +9,11 @@ unsigned long long RAI_GetMemoryInfoORT(void); unsigned long long RAI_GetMemoryAccessORT(void); +pthread_key_t (*RedisAI_ThreadIdKey)(void); +long long (*RedisAI_NumThreadsPerQueue)(void); + int RAI_InitBackendORT(int (*get_api_fn)(const char *, void *), - int (*get_api_fn_rai)(const char *, void *)); + int (*get_api_fn_rai)(const char *, void *)); RAI_Model *RAI_ModelCreateORT(RAI_Backend backend, const char *devicestr, RAI_ModelOpts opts, const char *modeldef, size_t modellen, RAI_Error *err); diff --git a/src/config/config.c b/src/config/config.c index a8e43825a..e21a2402d 100644 --- a/src/config/config.c +++ b/src/config/config.c @@ -147,11 +147,11 @@ int RedisAI_Config_BackendsPath(RedisModuleCtx *ctx, const char *path) { * @return REDISMODULE_OK on success, or REDISMODULE_ERR if failed */ int RedisAI_Config_QueueThreads(RedisModuleString *num_threads_string) { - int result = 
RedisModule_StringToLongLong(num_threads_string, &perqueueThreadPoolSize); + int result = RedisModule_StringToLongLong(num_threads_string, &ThreadPoolSizePerQueue); // make sure the number of threads is a positive integer // if not set the value to the default - if (result == REDISMODULE_OK && perqueueThreadPoolSize < 1) { - perqueueThreadPoolSize = REDISAI_DEFAULT_THREADS_PER_QUEUE; + if (result == REDISMODULE_OK && ThreadPoolSizePerQueue < 1) { + ThreadPoolSizePerQueue = REDISAI_DEFAULT_THREADS_PER_QUEUE; result = REDISMODULE_ERR; } return result; diff --git a/src/execution/DAG/dag_execute.c b/src/execution/DAG/dag_execute.c index e676581a7..78e5390a3 100644 --- a/src/execution/DAG/dag_execute.c +++ b/src/execution/DAG/dag_execute.c @@ -105,9 +105,8 @@ int DAG_InsertDAGToQueue(RedisAI_RunInfo *rinfo) { RunQueueInfo **run_queues_info = array_new(RunQueueInfo *, ndevices); for (long long i = 0; i < ndevices; i++) { - const char *devicestr = devices[i]; - RunQueueInfo *run_queue_info = NULL; - if (ensureRunQueue(devicestr, &run_queue_info) == REDISMODULE_ERR) { + const char *device_str = devices[i]; + if (!IsRunQueueExists(device_str) == REDISMODULE_ERR) { // A device run queue was not created properly, so we free everything, // set an error and finish. 
array_free(devices); @@ -119,6 +118,12 @@ int DAG_InsertDAGToQueue(RedisAI_RunInfo *rinfo) { RAI_SetError(rinfo->err, RAI_EDAGRUN, "ERR Queue not initialized for device"); return REDISMODULE_ERR; } + + size_t device_str_len = strlen(device_str); + char upper_device_str[device_str_len + 1]; + String_ToUpper(device_str, upper_device_str, &device_str_len); + RunQueueInfo *run_queue_info = + raxFind(RunQueues, (unsigned char *)upper_device_str, device_str_len); run_queues_info = array_append(run_queues_info, run_queue_info); } for (long long i = 0; i < ndevices; i++) { diff --git a/src/execution/background_workers.c b/src/execution/background_workers.c index 922079be1..99f103e98 100644 --- a/src/execution/background_workers.c +++ b/src/execution/background_workers.c @@ -17,10 +17,13 @@ #include #include #include +#include "string_utils.h" +#include "backends/backends.h" #include #include "redisai.h" #include "run_info.h" #include "background_workers.h" +#include "onnx_timeout.h" /* Define for RedisAI thread name setter */ #ifdef __linux__ @@ -40,101 +43,103 @@ int pthread_setname_np(const char *name); #endif #endif -pthread_key_t *tls_id_keys; // tls_id_keys[0] is CPU threads +void *RedisAI_Run_ThreadMain(void *arg); -int freeRunQueueInfo(RunQueueInfo *info) { - int result = REDISMODULE_OK; - if (info->run_queue) { - RedisModule_Free(info->run_queue); +RunQueueInfo *CreateRunQueue(const char *device_str) { + + size_t device_str_len = strlen(device_str); + char upper_device_str[device_str_len + 1]; + String_ToUpper(device_str, upper_device_str, &device_str_len); + + // Create new run queue and initialize its inner fields. 
+ RunQueueInfo *run_queue_info = RedisModule_Alloc(sizeof(RunQueueInfo)); + run_queue_info->run_queue = queueCreate(); + run_queue_info->device_str = RedisModule_Strdup(upper_device_str); + pthread_cond_init(&(run_queue_info->queue_condition_var), NULL); + pthread_mutex_init(&(run_queue_info->run_queue_mutex), NULL); + run_queue_info->threads = + (pthread_t *)RedisModule_Alloc(sizeof(pthread_t) * ThreadPoolSizePerQueue); + pthread_key_create(&(run_queue_info->thread_id_key), RedisModule_Free); + + // Save device with its associate run queue info in rax. + /*todo: this should be protected from parallel writs. can add a lock, + or calling this function from main thread only in modelstore command.*/ + if (raxInsert(RunQueues, (unsigned char *)upper_device_str, device_str_len, run_queue_info, + NULL) != 1) { + RunQueueInfoFree(run_queue_info); + return NULL; } - RedisModule_Free(info->devicestr); - if (info->threads) { - /* Wait for workers to exit */ - for (int i = 0; i < perqueueThreadPoolSize; i++) { - const int rtn = pthread_join(info->threads[i], NULL); - if (rtn != 0) { - result = REDISMODULE_ERR; - } + + // Create worker threads. 
+ for (int i = 0; i < ThreadPoolSizePerQueue; i++) { + WorkerThreadInfo *thread_info = RedisModule_Alloc(sizeof(WorkerThreadInfo)); + thread_info->run_queue_info = run_queue_info; + thread_info->id = i; + if (pthread_create(&(run_queue_info->threads[i]), NULL, RedisAI_Run_ThreadMain, + thread_info) != 0) { + raxRemove(RunQueues, (unsigned char *)upper_device_str, device_str_len, NULL); + RunQueueInfoFree(run_queue_info); + return NULL; } - /* Now free pool structure */ - RedisModule_Free(info->threads); } - RedisModule_Free(info); - return result; -} - -void *RedisAI_Run_ThreadMain(void *arg); -char *strToUpper(const char *input) { - char *output = RedisModule_Strdup(input); - size_t output_len = strlen(output); - for (long long i = 0; i < output_len; i++) { - output[i] = toupper(output[i]); + // Add the new device worker threads to onnx run sessions monitoring. + if (RAI_backends.onnx.add_new_device) { + RAI_backends.onnx.add_new_device(device_str); } - return output; + return run_queue_info; } -/* Ensure that the the run queue for the device exists. - * If not, create it. 
*/ -int ensureRunQueue(const char *devicestr, RunQueueInfo **run_queue_info) { - int result = REDISMODULE_ERR; - if (run_queues == NULL) { - return result; - } - - char *devicestr_ = strToUpper(devicestr); - - AI_dictEntry *entry = AI_dictFind(run_queues, devicestr_); - if (entry) { - *run_queue_info = AI_dictGetVal(entry); - result = REDISMODULE_OK; - } else { - *run_queue_info = RedisModule_Alloc(sizeof(RunQueueInfo)); - (*run_queue_info)->run_queue = queueCreate(); - (*run_queue_info)->devicestr = RedisModule_Strdup(devicestr_); - pthread_cond_init(&(*run_queue_info)->queue_condition_var, NULL); - pthread_mutex_init(&(*run_queue_info)->run_queue_mutex, NULL); - (*run_queue_info)->threads = - (pthread_t *)RedisModule_Alloc(sizeof(pthread_t) * perqueueThreadPoolSize); - pthread_key_t thread_id_key; - pthread_key_create(&thread_id_key, RedisModule_Free); - (*run_queue_info)->thread_id_key = thread_id_key; - /* create threads */ - for (int i = 0; i < perqueueThreadPoolSize; i++) { - WorkerThreadInfo *thread_info = RedisModule_Alloc(sizeof(WorkerThreadInfo)); - thread_info->run_queue_info = *run_queue_info; - thread_info->id = i; - if (pthread_create(&((*run_queue_info)->threads[i]), NULL, RedisAI_Run_ThreadMain, - thread_info) != 0) { - freeRunQueueInfo(*run_queue_info); - return REDISMODULE_ERR; - } - } - AI_dictAdd(run_queues, (void *)devicestr_, (void *)*run_queue_info); - result = REDISMODULE_OK; +bool IsRunQueueExists(const char *device_str) { + size_t device_str_len = strlen(device_str); + char upper_device_str[device_str_len + 1]; + String_ToUpper(device_str, upper_device_str, &device_str_len); + if (raxFind(RunQueues, (unsigned char *)upper_device_str, device_str_len) == raxNotFound) { + return false; } + return true; +} - RedisModule_Free(devicestr_); +pthread_key_t GetQueueThreadIdKey(const char *device_str) { + size_t device_str_len = strlen(device_str); + char upper_device_str[device_str_len + 1]; + String_ToUpper(device_str, upper_device_str, 
&device_str_len); - return result; + RunQueueInfo *run_queue_info = + (RunQueueInfo *)raxFind(RunQueues, (unsigned char *)upper_device_str, device_str_len); + RedisModule_Assert(run_queue_info != raxNotFound); + return run_queue_info->thread_id_key; } -pthread_key_t ThreadIdKey() { - return tls_id_key; -} +long long GetNumThreadsPerQueue() { return ThreadPoolSizePerQueue; } + +void RunQueueInfoFree(RunQueueInfo *run_queue_info) { + RedisModule_Assert(queueLength(run_queue_info->run_queue) == 0); + RedisModule_Free(run_queue_info->run_queue); + RedisModule_Free(run_queue_info->device_str); -long long NumThreadsPerQueue() { - return perqueueThreadPoolSize; + // Wait for workers to exit and free the pool. + for (int i = 0; i < ThreadPoolSizePerQueue; i++) { + RedisModule_Assert(pthread_join(run_queue_info->threads[i], NULL) == 0); + RedisModule_Free(run_queue_info->threads); + } + pthread_mutex_destroy(&(run_queue_info->run_queue_mutex)); + pthread_cond_destroy(&(run_queue_info->queue_condition_var)); + pthread_key_delete(run_queue_info->thread_id_key); + RedisModule_Free(run_queue_info); } -void _SaveThreadId(int id) { +/** + * @brief Save the id for some working thread in thread local storage. Every + * device has a designated id key saved within its run_queue_info, which is used + * for storing and retrieving the id in the thread local storage. 
+ */ +static void _SaveThreadId(pthread_key_t thread_id_key, int id) { int *id_value = RedisModule_Alloc(sizeof(int)); *id_value = id; - pthread_setspecific(tls_id_key, id_value); + pthread_setspecific(thread_id_key, id_value); } - - /** * @brief In case a DAG Op can express a MINBATCHSIZE > 0 with a MINBATCHTIMEOUT * in milliseconds, we will use a timedwait of one millisecond to evaluate @@ -203,9 +208,9 @@ static void _BGThread_Execute(RunQueueInfo *run_queue_info, RedisAI_RunInfo **ba // For simplicity, we call into different functions whether the run // is batched or not if (batched_run) { - RedisAI_BatchedDagRunSessionStep(batch_rinfo, run_queue_info->devicestr); + RedisAI_BatchedDagRunSessionStep(batch_rinfo, run_queue_info->device_str); } else { - RedisAI_DagRunSessionStep(batch_rinfo[0], run_queue_info->devicestr); + RedisAI_DagRunSessionStep(batch_rinfo[0], run_queue_info->device_str); } } } @@ -327,8 +332,9 @@ static RedisAI_RunInfo **_BGThread_BatchOperations(RunQueueInfo *run_queue_info, void *RedisAI_Run_ThreadMain(void *arg) { WorkerThreadInfo *thread_info = (WorkerThreadInfo *)arg; RunQueueInfo *run_queue_info = thread_info->run_queue_info; - _SaveThreadId(thread_info->id); + _SaveThreadId(run_queue_info->thread_id_key, thread_info->id); RedisModule_Free(thread_info); + RedisAI_RunInfo **batch_rinfo = array_new(RedisAI_RunInfo *, 1); pthread_mutex_lock(&run_queue_info->run_queue_mutex); diff --git a/src/execution/background_workers.h b/src/execution/background_workers.h index a68b5ea34..147f07c82 100644 --- a/src/execution/background_workers.h +++ b/src/execution/background_workers.h @@ -27,19 +27,19 @@ #include "redis_ai_objects/stats.h" #include "redis_ai_objects/tensor.h" #include "util/arr.h" -#include "util/dict.h" +#include "util/rax.h" #include "util/queue.h" -AI_dict *run_queues; -long long perqueueThreadPoolSize; +rax *RunQueues; +long long ThreadPoolSizePerQueue; typedef struct RunQueueInfo { pthread_mutex_t run_queue_mutex; pthread_cond_t 
queue_condition_var; queue *run_queue; pthread_t *threads; - pthread_key_t thread_id_key; // A key for getting the thread id from its local storage. - char *devicestr; + pthread_key_t thread_id_key; // A key for getting the thread id from its local storage. + char *device_str; } RunQueueInfo; typedef struct WorkerThreadInfo { @@ -47,12 +47,12 @@ typedef struct WorkerThreadInfo { int id; } WorkerThreadInfo; -int freeRunQueueInfo(RunQueueInfo *info); +void RunQueueInfoFree(RunQueueInfo *info); -/* Ensure that the the run queue for the device exists. - * If not, create it. */ -int ensureRunQueue(const char *devicestr, RunQueueInfo **run_queue_info); +RunQueueInfo *CreateRunQueue(const char *device_str); -pthread_key_t ThreadIdKey(void); +bool IsRunQueueExists(const char *device_str); -long long NumThreadsPerQueue(void); +pthread_key_t GetQueueThreadIdKey(const char *device_str); + +long long GetNumThreadsPerQueue(void); diff --git a/src/execution/onnx_timeout.c b/src/execution/onnx_timeout.c index 9546965a7..269a6f697 100644 --- a/src/execution/onnx_timeout.c +++ b/src/execution/onnx_timeout.c @@ -1,6 +1,9 @@ #include "onnx_timeout.h" #include "util/arr.h" #include +#include +#include "util/rax.h" +#include "util/string_utils.h" // Gets the current time in milliseconds. static long long _mstime(void) { @@ -10,44 +13,78 @@ static long long _mstime(void) { gettimeofday(&tv, NULL); ust = ((long long)tv.tv_sec) * 1000000; ust += tv.tv_usec; - return ust/1000; + return ust / 1000; } -int CreateGlobalOnnxRunSessions(long long size) { - OnnxRunSessions = array_new(OnnxRunSessionCtx *, size); +int CreateGlobalOnnxRunSessions() { + OnnxRunSessions = raxNew(); + if (OnnxRunSessions == NULL) { + return REDISMODULE_ERR; + } + return AddDeviceToGlobalRunSessions("CPU"); +} + +int AddDeviceToGlobalRunSessions(const char *device) { + + size_t size = RedisAI_NumThreadsPerQueue(); + // Create array with an entry for every working thread, initialized to NULL. 
+ OnnxRunSessionCtx **device_run_sessions = array_new(OnnxRunSessionCtx *, size); for (size_t i = 0; i < size; i++) { OnnxRunSessionCtx *entry = RedisModule_Calloc(1, sizeof(OnnxRunSessionCtx)); - OnnxRunSessions = array_append(OnnxRunSessions, entry); + device_run_sessions = array_append(device_run_sessions, entry); + } + // Add the array to the global rax that holds onnx run sessions per device. + size_t device_str_len = strlen(device); + char upper_device_str[device_str_len + 1]; + String_ToUpper(device, upper_device_str, &device_str_len); + if (raxInsert(OnnxRunSessions, (unsigned char *)upper_device_str, device_str_len, + device_run_sessions, NULL) != 1) { + return REDISMODULE_ERR; } return REDISMODULE_OK; } -void OnnxEnforceTimeoutCallback(RedisModuleCtx *ctx, RedisModuleEvent eid, - uint64_t subevent, void *data) { - +void OnnxEnforceTimeoutCallback(RedisModuleCtx *ctx, RedisModuleEvent eid, uint64_t subevent, + void *data) { const OrtApi *ort = OrtGetApiBase()->GetApi(1); - size_t len = array_len(OnnxRunSessions); - for (size_t i = 0; i < len; i++) { - if (OnnxRunSessions[i]->runOptions == NULL) { - continue; - } - long long curr_time = _mstime(); - if (curr_time - OnnxRunSessions[i]->queuingTime > ONNX_MAX_RUNTIME) { - ort->RunOptionsSetTerminate(OnnxRunSessions[i]->runOptions); + + raxIterator rax_it; + raxStart(&rax_it, OnnxRunSessions); + raxSeek(&rax_it, "^", NULL, 0); + + // Go over all the possible existing run sessions for every device. + while (raxNext(&rax_it)) { + OnnxRunSessionCtx **onnx_run_sessions_per_device = rax_it.data; + size_t threads_per_device = array_len(onnx_run_sessions_per_device); + for (size_t i = 0; i < threads_per_device; i++) { + if (onnx_run_sessions_per_device[i]->runOptions == NULL) { + continue; + } + long long curr_time = _mstime(); + // Check if a sessions is running for too long, and kill it if so. 
+ if (curr_time - onnx_run_sessions_per_device[i]->queuingTime > ONNX_MAX_RUNTIME) { + ort->RunOptionsSetTerminate(onnx_run_sessions_per_device[i]->runOptions); + } } } } -void SetRunSessionCtx(size_t index, OrtRunOptions *newRunOptions) { - OnnxRunSessionCtx *runSessionCtx = OnnxRunSessions[index]; - RedisModule_Assert(runSessionCtx->runOptions == NULL); - runSessionCtx->runOptions = newRunOptions; - runSessionCtx->queuingTime = _mstime(); +OnnxRunSessionCtx *SetGetRunSessionCtx(const char *device, OrtRunOptions *new_run_options) { + + int *thread_ind = (int *)pthread_getspecific(RedisAI_ThreadIdKey()); + + OnnxRunSessionCtx **device_run_sessions = + raxFind(OnnxRunSessions, (unsigned char *)device, strlen(device)); + RedisModule_Assert(device_run_sessions != raxNotFound); + RedisModule_Assert(device_run_sessions[*thread_ind]->runOptions == NULL); + + device_run_sessions[*thread_ind]->runOptions = new_run_options; + device_run_sessions[*thread_ind]->queuingTime = _mstime(); + return device_run_sessions[*thread_ind]; } -void ClearRunSessionCtx(size_t index) { +void ClearRunSessionCtx(OnnxRunSessionCtx *run_session_ctx) { const OrtApi *ort = OrtGetApiBase()->GetApi(1); - OnnxRunSessionCtx *runSessionCtx = OnnxRunSessions[index]; - ort->ReleaseRunOptions(runSessionCtx->runOptions); - runSessionCtx->runOptions = NULL; + ort->ReleaseRunOptions(run_session_ctx->runOptions); + run_session_ctx->runOptions = NULL; } diff --git a/src/execution/onnx_timeout.h b/src/execution/onnx_timeout.h index e4cd70bd7..dbb585432 100644 --- a/src/execution/onnx_timeout.h +++ b/src/execution/onnx_timeout.h @@ -2,22 +2,28 @@ #include "backends/onnxruntime.h" #include "onnxruntime_c_api.h" +#include "util/rax.h" // The maximum time in milliseconds before killing onnx run session. 
+// todo: make it a load time config #define ONNX_MAX_RUNTIME 5000 typedef struct OnnxRunSessionCtx { long long queuingTime; - OrtRunOptions* runOptions; + OrtRunOptions *runOptions; } OnnxRunSessionCtx; -OnnxRunSessionCtx **OnnxRunSessions; +// This is a global rax that holds an array of OnnxRunSessionCtx for every device +// that onnx models may run on. +rax *OnnxRunSessions; -int CreateGlobalOnnxRunSessions(long long size); +int CreateGlobalOnnxRunSessions(void); -void OnnxEnforceTimeoutCallback(RedisModuleCtx *ctx, RedisModuleEvent eid, - uint64_t subevent, void *data); +int AddDeviceToGlobalRunSessions(const char *device); -void SetRunSessionCtx(size_t index, OrtRunOptions *newRunOptions); +void OnnxEnforceTimeoutCallback(RedisModuleCtx *ctx, RedisModuleEvent eid, uint64_t subevent, + void *data); -void ClearRunSessionCtx(size_t index); +OnnxRunSessionCtx *SetGetRunSessionCtx(const char *device, OrtRunOptions *new_run_options); + +void ClearRunSessionCtx(OnnxRunSessionCtx *run_session_ctx); diff --git a/src/execution/parsing/deprecated.c b/src/execution/parsing/deprecated.c index 45e1c2fc5..7a709f84a 100644 --- a/src/execution/parsing/deprecated.c +++ b/src/execution/parsing/deprecated.c @@ -305,17 +305,12 @@ int ModelSetCommand(RedisModuleCtx *ctx, RedisModuleString **argv, int argc) { } // TODO: if backend loaded, make sure there's a queue - RunQueueInfo *run_queue_info = NULL; - if (ensureRunQueue(devicestr, &run_queue_info) != REDISMODULE_OK) { - RAI_ModelFree(model, &err); - if (err.code != RAI_OK) { - RedisModule_Log(ctx, "warning", "%s", err.detail); - int ret = RedisModule_ReplyWithError(ctx, err.detail_oneline); - RAI_ClearError(&err); - return ret; + if (!IsRunQueueExists(devicestr)) { + RunQueueInfo *run_queue_info = CreateRunQueue(devicestr); + if (run_queue_info == NULL) { + RAI_ModelFree(model, &err); + RedisModule_ReplyWithError(ctx, "ERR Could not initialize queue on requested device"); } - return RedisModule_ReplyWithError(ctx, - "ERR Could 
not initialize queue on requested device"); } RedisModuleKey *key = RedisModule_OpenKey(ctx, keystr, REDISMODULE_READ | REDISMODULE_WRITE); diff --git a/src/redisai.c b/src/redisai.c index 9e6472cba..6c173303a 100644 --- a/src/redisai.c +++ b/src/redisai.c @@ -54,7 +54,6 @@ #endif #endif - extern int redisMajorVersion; extern int redisMinorVersion; extern int redisPatchVersion; @@ -354,17 +353,12 @@ int RedisAI_ModelStore_RedisCommand(RedisModuleCtx *ctx, RedisModuleString **arg } // TODO: if backend loaded, make sure there's a queue - RunQueueInfo *run_queue_info = NULL; - if (ensureRunQueue(devicestr, &run_queue_info) != REDISMODULE_OK) { - RAI_ModelFree(model, &err); - if (err.code != RAI_OK) { - RedisModule_Log(ctx, "warning", "%s", err.detail); - int ret = RedisModule_ReplyWithError(ctx, err.detail_oneline); - RAI_ClearError(&err); - return ret; + if (!IsRunQueueExists(devicestr)) { + RunQueueInfo *run_queue_info = CreateRunQueue(devicestr); + if (run_queue_info == NULL) { + RAI_ModelFree(model, &err); + RedisModule_ReplyWithError(ctx, "ERR Could not initialize queue on requested device"); } - return RedisModule_ReplyWithError(ctx, - "ERR Could not initialize queue on requested device"); } RedisModuleKey *key = RedisModule_OpenKey(ctx, keystr, REDISMODULE_READ | REDISMODULE_WRITE); @@ -782,20 +776,12 @@ int RedisAI_ScriptSet_RedisCommand(RedisModuleCtx *ctx, RedisModuleString **argv return ret; } - RunQueueInfo *run_queue_info = NULL; - // If the queue does not exist, initialize it - if (ensureRunQueue(devicestr, &run_queue_info) == REDISMODULE_ERR) { - RAI_ScriptFree(script, &err); - if (err.code != RAI_OK) { -#ifdef RAI_PRINT_BACKEND_ERRORS - printf("ERR: %s\n", err.detail); -#endif - int ret = RedisModule_ReplyWithError(ctx, err.detail_oneline); - RAI_ClearError(&err); - return ret; + if (!IsRunQueueExists(devicestr)) { + RunQueueInfo *run_queue_info = CreateRunQueue(devicestr); + if (run_queue_info == NULL) { + RAI_ScriptFree(script, &err); + 
RedisModule_ReplyWithError(ctx, "ERR Could not initialize queue on requested device"); } - return RedisModule_ReplyWithError(ctx, - "ERR Could not initialize queue on requested device"); } RedisModuleKey *key = RedisModule_OpenKey(ctx, keystr, REDISMODULE_READ | REDISMODULE_WRITE); @@ -1215,7 +1201,7 @@ void RAI_moduleInfoFunc(RedisModuleInfoCtx *ctx, int for_crash_report) { RedisModule_InfoAddSection(ctx, "git"); RedisModule_InfoAddFieldCString(ctx, "git_sha", REDISAI_GIT_SHA); RedisModule_InfoAddSection(ctx, "load_time_configs"); - RedisModule_InfoAddFieldLongLong(ctx, "threads_per_queue", perqueueThreadPoolSize); + RedisModule_InfoAddFieldLongLong(ctx, "threads_per_queue", ThreadPoolSizePerQueue); RedisModule_InfoAddFieldLongLong(ctx, "inter_op_parallelism", getBackendsInterOpParallelism()); RedisModule_InfoAddFieldLongLong(ctx, "intra_op_parallelism", getBackendsIntraOpParallelism()); RedisModule_InfoAddSection(ctx, "memory_usage"); @@ -1278,45 +1264,42 @@ void RAI_moduleInfoFunc(RedisModuleInfoCtx *ctx, int for_crash_report) { RedisModule_FreeString(NULL, main_thread_used_cpu_sys); RedisModule_FreeString(NULL, main_thread_used_cpu_user); - AI_dictIterator *iter = AI_dictGetSafeIterator(run_queues); - AI_dictEntry *entry = AI_dictNext(iter); - while (entry) { - char *queue_name = (char *)AI_dictGetKey(entry); - RunQueueInfo *run_queue_info = (RunQueueInfo *)AI_dictGetVal(entry); - if (run_queue_info) { - for (int i = 0; i < perqueueThreadPoolSize; i++) { - pthread_t current_bg_threads = run_queue_info->threads[i]; - struct timespec ts; - clockid_t cid; - RedisModuleString *queue_used_cpu_total = RedisModule_CreateStringPrintf( - NULL, "queue_%s_bthread_n%d_used_cpu_total", queue_name, i + 1); - RedisModuleString *bthread_used_cpu_total = NULL; + raxIterator rax_it; + raxStart(&rax_it, RunQueues); + raxSeek(&rax_it, "^", NULL, 0); + while (raxNext(&rax_it)) { + char *queue_name = (char *)rax_it.key; + RunQueueInfo *run_queue_info = (RunQueueInfo *)rax_it.data; + 
for (int i = 0; i < ThreadPoolSizePerQueue; i++) { + pthread_t current_bg_threads = run_queue_info->threads[i]; + struct timespec ts; + clockid_t cid; + RedisModuleString *queue_used_cpu_total = RedisModule_CreateStringPrintf( + NULL, "queue_%s_bthread_n%d_used_cpu_total", queue_name, i + 1); + RedisModuleString *bthread_used_cpu_total = NULL; #if (!defined(_POSIX_C_SOURCE) && !defined(_XOPEN_SOURCE)) || defined(_DARWIN_C_SOURCE) || \ defined(__cplusplus) - const int status = -1; + const int status = -1; #else - const int status = pthread_getcpuclockid(current_bg_threads, &cid); + const int status = pthread_getcpuclockid(current_bg_threads, &cid); #endif - if (status != 0) { + if (status != 0) { + bthread_used_cpu_total = RedisModule_CreateStringPrintf(NULL, "N/A"); + } else { + if (clock_gettime(cid, &ts) == -1) { bthread_used_cpu_total = RedisModule_CreateStringPrintf(NULL, "N/A"); } else { - if (clock_gettime(cid, &ts) == -1) { - bthread_used_cpu_total = RedisModule_CreateStringPrintf(NULL, "N/A"); - } else { - bthread_used_cpu_total = RedisModule_CreateStringPrintf( - NULL, "%ld.%06ld", (long)ts.tv_sec, (long)(ts.tv_nsec / 1000)); - } + bthread_used_cpu_total = RedisModule_CreateStringPrintf( + NULL, "%ld.%06ld", (long)ts.tv_sec, (long)(ts.tv_nsec / 1000)); } - RedisModule_InfoAddFieldString( - ctx, (char *)RedisModule_StringPtrLen(queue_used_cpu_total, NULL), - bthread_used_cpu_total); - RedisModule_FreeString(NULL, queue_used_cpu_total); - RedisModule_FreeString(NULL, bthread_used_cpu_total); } + RedisModule_InfoAddFieldString( + ctx, (char *)RedisModule_StringPtrLen(queue_used_cpu_total, NULL), + bthread_used_cpu_total); + RedisModule_FreeString(NULL, queue_used_cpu_total); + RedisModule_FreeString(NULL, bthread_used_cpu_total); } - entry = AI_dictNext(iter); } - AI_dictReleaseIterator(iter); } int RedisModule_OnLoad(RedisModuleCtx *ctx, RedisModuleString **argv, int argc) { @@ -1465,24 +1448,19 @@ int RedisModule_OnLoad(RedisModuleCtx *ctx, 
RedisModuleString **argv, int argc) // Default configs RAI_BackendsPath = NULL; - perqueueThreadPoolSize = REDISAI_DEFAULT_THREADS_PER_QUEUE; + ThreadPoolSizePerQueue = REDISAI_DEFAULT_THREADS_PER_QUEUE; setBackendsInterOpParallelism(REDISAI_DEFAULT_INTER_OP_PARALLELISM); setBackendsIntraOpParallelism(REDISAI_DEFAULT_INTRA_OP_PARALLELISM); setModelChunkSize(REDISAI_DEFAULT_MODEL_CHUNK_SIZE); RAI_loadTimeConfig(ctx, argv, argc); - run_queues = AI_dictCreate(&AI_dictTypeHeapStrings, NULL); - RunQueueInfo *run_queue_info = NULL; - if (ensureRunQueue("CPU", &run_queue_info) != REDISMODULE_OK) { - RedisModule_Log(ctx, "warning", "Queue not initialized for device CPU"); + RunQueues = raxNew(); + RunQueueInfo *cpu_run_queue_info = CreateRunQueue("CPU"); + if (cpu_run_queue_info == NULL) { + RedisModule_Log(ctx, "warning", "RedisAI could not initialize run queue for CPU"); return REDISMODULE_ERR; } - for (size_t i = 0; i < perqueueThreadPoolSize; i++) { - RedisModule_Log(ctx, "warning", "thread id in index %zu is %lu", i, - run_queue_info->threads[i]); - } - run_stats = AI_dictCreate(&AI_dictTypeHeapRStrings, NULL); return REDISMODULE_OK; diff --git a/src/util/queue.c b/src/util/queue.c index 3c22488e0..b68c1a15c 100644 --- a/src/util/queue.c +++ b/src/util/queue.c @@ -90,7 +90,7 @@ queueItem *queueEvict(queue *queue, queueItem *item) { return item; } -long long queueLength(queue *queue) { return queue->len; } +long queueLength(queue *queue) { return queue->len; } void queueRelease(queue *queue) { unsigned long len; diff --git a/src/util/queue.h b/src/util/queue.h index af755181d..a03845182 100644 --- a/src/util/queue.h +++ b/src/util/queue.h @@ -28,5 +28,5 @@ queueItem *queuePop(queue *queue); queueItem *queueFront(queue *queue); queueItem *queueNext(queueItem *item); queueItem *queueEvict(queue *queue, queueItem *item); -long long queueLength(queue *queue); +long queueLength(queue *queue); void queueRelease(queue *queue); diff --git a/src/util/rax.c b/src/util/rax.c new 
file mode 100644 index 000000000..221258c58 --- /dev/null +++ b/src/util/rax.c @@ -0,0 +1,2003 @@ +/* Rax -- A radix tree implementation. + * + * Version 1.0 -- 14 November 2019 + * + * Copyright (c) 2017-2019, Salvatore Sanfilippo + * All rights reserved. + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * * Redistributions of source code must retain the above copyright notice, + * this list of conditions and the following disclaimer. + * * Redistributions in binary form must reproduce the above copyright + * notice, this list of conditions and the following disclaimer in the + * documentation and/or other materials provided with the distribution. + * * Neither the name of Redis nor the names of its contributors may be used + * to endorse or promote products derived from this software without + * specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE + * ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE + * LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR + * CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF + * SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS + * INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN + * CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) + * ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE + * POSSIBILITY OF SUCH DAMAGE. 
+ */ + +#include +#include +#include +#include +#include +#include +#include "rax.h" + +#ifndef RAX_MALLOC_INCLUDE +#define RAX_MALLOC_INCLUDE "rax_malloc.h" +#endif + +#include RAX_MALLOC_INCLUDE + +/* This is a special pointer that is guaranteed to never have the same value + * of a radix tree node. It's used in order to report "not found" error without + * requiring the function to have multiple return values. */ +void *raxNotFound = (void *)"rax-not-found-pointer"; + +/* -------------------------------- Debugging ------------------------------ */ + +void raxDebugShowNode(const char *msg, raxNode *n); + +/* Turn debugging messages on/off by compiling with RAX_DEBUG_MSG macro on. + * When RAX_DEBUG_MSG is defined by default Rax operations will emit a lot + * of debugging info to the standard output, however you can still turn + * debugging on/off in order to enable it only when you suspect there is an + * operation causing a bug using the function raxSetDebugMsg(). */ +#ifdef RAX_DEBUG_MSG +#define debugf(...) \ + if (raxDebugMsg) { \ + printf("%s:%s:%d:\t", __FILE__, __FUNCTION__, __LINE__); \ + printf(__VA_ARGS__); \ + fflush(stdout); \ + } + +#define debugnode(msg, n) raxDebugShowNode(msg, n) +#else +#define debugf(...) +#define debugnode(msg, n) +#endif + +/* By default log debug info if RAX_DEBUG_MSG is defined. */ +static int raxDebugMsg = 1; + +/* When debug messages are enabled, turn them on/off dynamically. By + * default they are enabled. Set the state to 0 to disable, and 1 to + * re-enable. */ +void raxSetDebugMsg(int onoff) { raxDebugMsg = onoff; } + +/* ------------------------- raxStack functions -------------------------- + * The raxStack is a simple stack of pointers that is capable of switching + * from using a stack-allocated array to dynamic heap once a given number of + * items are reached. 
It is used in order to retain the list of parent nodes + * while walking the radix tree in order to implement certain operations that + * need to navigate the tree upward. + * ------------------------------------------------------------------------- */ + +/* Initialize the stack. */ +static inline void raxStackInit(raxStack *ts) { + ts->stack = ts->static_items; + ts->items = 0; + ts->maxitems = RAX_STACK_STATIC_ITEMS; + ts->oom = 0; +} + +/* Push an item into the stack, returns 1 on success, 0 on out of memory. */ +static inline int raxStackPush(raxStack *ts, void *ptr) { + if (ts->items == ts->maxitems) { + if (ts->stack == ts->static_items) { + ts->stack = rax_malloc(sizeof(void *) * ts->maxitems * 2); + if (ts->stack == NULL) { + ts->stack = ts->static_items; + ts->oom = 1; + errno = ENOMEM; + return 0; + } + memcpy(ts->stack, ts->static_items, sizeof(void *) * ts->maxitems); + } else { + void **newalloc = rax_realloc(ts->stack, sizeof(void *) * ts->maxitems * 2); + if (newalloc == NULL) { + ts->oom = 1; + errno = ENOMEM; + return 0; + } + ts->stack = newalloc; + } + ts->maxitems *= 2; + } + ts->stack[ts->items] = ptr; + ts->items++; + return 1; +} + +/* Pop an item from the stack, the function returns NULL if there are no + * items to pop. */ +static inline void *raxStackPop(raxStack *ts) { + if (ts->items == 0) + return NULL; + ts->items--; + return ts->stack[ts->items]; +} + +/* Return the stack item at the top of the stack without actually consuming + * it. */ +static inline void *raxStackPeek(raxStack *ts) { + if (ts->items == 0) + return NULL; + return ts->stack[ts->items - 1]; +} + +/* Free the stack in case we used heap allocation. 
*/ +static inline void raxStackFree(raxStack *ts) { + if (ts->stack != ts->static_items) + rax_free(ts->stack); +} + +/* ---------------------------------------------------------------------------- + * Radix tree implementation + * --------------------------------------------------------------------------*/ + +/* Return the padding needed in the characters section of a node having size + * 'nodesize'. The padding is needed to store the child pointers to aligned + * addresses. Note that we add 4 to the node size because the node has a four + * bytes header. */ +#define raxPadding(nodesize) \ + ((sizeof(void *) - ((nodesize + 4) % sizeof(void *))) & (sizeof(void *) - 1)) + +/* Return the pointer to the last child pointer in a node. For the compressed + * nodes this is the only child pointer. */ +#define raxNodeLastChildPtr(n) \ + ((raxNode **)(((char *)(n)) + raxNodeCurrentLength(n) - sizeof(raxNode *) - \ + (((n)->iskey && !(n)->isnull) ? sizeof(void *) : 0))) + +/* Return the pointer to the first child pointer. */ +#define raxNodeFirstChildPtr(n) ((raxNode **)((n)->data + (n)->size + raxPadding((n)->size))) + +/* Return the current total size of the node. Note that the second line + * computes the padding after the string of characters, needed in order to + * save pointers to aligned addresses. */ +#define raxNodeCurrentLength(n) \ + (sizeof(raxNode) + (n)->size + raxPadding((n)->size) + \ + ((n)->iscompr ? sizeof(raxNode *) : sizeof(raxNode *) * (n)->size) + \ + (((n)->iskey && !(n)->isnull) * sizeof(void *))) + +/* Allocate a new non compressed node with the specified number of children. + * If datafiled is true, the allocation is made large enough to hold the + * associated data pointer. + * Returns the new node pointer. On out of memory NULL is returned. 
*/ +raxNode *raxNewNode(size_t children, int datafield) { + size_t nodesize = + sizeof(raxNode) + children + raxPadding(children) + sizeof(raxNode *) * children; + if (datafield) + nodesize += sizeof(void *); + raxNode *node = rax_malloc(nodesize); + if (node == NULL) + return NULL; + node->iskey = 0; + node->isnull = 0; + node->iscompr = 0; + node->size = children; + return node; +} + +/* Allocate a new rax and return its pointer. On out of memory the function + * returns NULL. */ +rax *raxNew(void) { + rax *rax = rax_malloc(sizeof(*rax)); + if (rax == NULL) + return NULL; + rax->numele = 0; + rax->numnodes = 1; + rax->head = raxNewNode(0, 0); + if (rax->head == NULL) { + rax_free(rax); + return NULL; + } else { + return rax; + } +} + +/* realloc the node to make room for auxiliary data in order + * to store an item in that node. On out of memory NULL is returned. */ +raxNode *raxReallocForData(raxNode *n, void *data) { + if (data == NULL) + return n; /* No reallocation needed, setting isnull=1 */ + size_t curlen = raxNodeCurrentLength(n); + return rax_realloc(n, curlen + sizeof(void *)); +} + +/* Set the node auxiliary data to the specified pointer. */ +void raxSetData(raxNode *n, void *data) { + n->iskey = 1; + if (data != NULL) { + n->isnull = 0; + void **ndata = (void **)((char *)n + raxNodeCurrentLength(n) - sizeof(void *)); + memcpy(ndata, &data, sizeof(data)); + } else { + n->isnull = 1; + } +} + +/* Get the node auxiliary data. */ +void *raxGetData(raxNode *n) { + if (n->isnull) + return NULL; + void **ndata = (void **)((char *)n + raxNodeCurrentLength(n) - sizeof(void *)); + void *data; + memcpy(&data, ndata, sizeof(data)); + return data; +} + +/* Add a new child to the node 'n' representing the character 'c' and return + * its new pointer, as well as the child pointer by reference. 
Additionally + * '***parentlink' is populated with the raxNode pointer-to-pointer of where + * the new child was stored, which is useful for the caller to replace the + * child pointer if it gets reallocated. + * + * On success the new parent node pointer is returned (it may change because + * of the realloc, so the caller should discard 'n' and use the new value). + * On out of memory NULL is returned, and the old node is still valid. */ +raxNode *raxAddChild(raxNode *n, unsigned char c, raxNode **childptr, raxNode ***parentlink) { + assert(n->iscompr == 0); + + size_t curlen = raxNodeCurrentLength(n); + n->size++; + size_t newlen = raxNodeCurrentLength(n); + n->size--; /* For now restore the orignal size. We'll update it only on + success at the end. */ + + /* Alloc the new child we will link to 'n'. */ + raxNode *child = raxNewNode(0, 0); + if (child == NULL) + return NULL; + + /* Make space in the original node. */ + raxNode *newn = rax_realloc(n, newlen); + if (newn == NULL) { + rax_free(child); + return NULL; + } + n = newn; + + /* After the reallocation, we have up to 8/16 (depending on the system + * pointer size, and the required node padding) bytes at the end, that is, + * the additional char in the 'data' section, plus one pointer to the new + * child, plus the padding needed in order to store addresses into aligned + * locations. + * + * So if we start with the following node, having "abde" edges. + * + * Note: + * - We assume 4 bytes pointer for simplicity. + * - Each space below corresponds to one byte + * + * [HDR*][abde][Aptr][Bptr][Dptr][Eptr]|AUXP| + * + * After the reallocation we need: 1 byte for the new edge character + * plus 4 bytes for a new child pointer (assuming 32 bit machine). + * However after adding 1 byte to the edge char, the header + the edge + * characters are no longer aligned, so we also need 3 bytes of padding. 
+ * In total the reallocation will add 1+4+3 bytes = 8 bytes: + * + * (Blank bytes are represented by ".") + * + * [HDR*][abde][Aptr][Bptr][Dptr][Eptr]|AUXP|[....][....] + * + * Let's find where to insert the new child in order to make sure + * it is inserted in-place lexicographically. Assuming we are adding + * a child "c" in our case pos will be = 2 after the end of the following + * loop. */ + int pos; + for (pos = 0; pos < n->size; pos++) { + if (n->data[pos] > c) + break; + } + + /* Now, if present, move auxiliary data pointer at the end + * so that we can mess with the other data without overwriting it. + * We will obtain something like that: + * + * [HDR*][abde][Aptr][Bptr][Dptr][Eptr][....][....]|AUXP| + */ + unsigned char *src, *dst; + if (n->iskey && !n->isnull) { + src = ((unsigned char *)n + curlen - sizeof(void *)); + dst = ((unsigned char *)n + newlen - sizeof(void *)); + memmove(dst, src, sizeof(void *)); + } + + /* Compute the "shift", that is, how many bytes we need to move the + * pointers section forward because of the addition of the new child + * byte in the string section. Note that if we had no padding, that + * would be always "1", since we are adding a single byte in the string + * section of the node (where now there is "abde" basically). + * + * However we have padding, so it could be zero, or up to 8. + * + * Another way to think at the shift is, how many bytes we need to + * move child pointers forward *other than* the obvious sizeof(void*) + * needed for the additional pointer itself. */ + size_t shift = newlen - curlen - sizeof(void *); + + /* We said we are adding a node with edge 'c'. The insertion + * point is between 'b' and 'd', so the 'pos' variable value is + * the index of the first child pointer that we need to move forward + * to make space for our new pointer. 
+ * + * To start, move all the child pointers after the insertion point + * of shift+sizeof(pointer) bytes on the right, to obtain: + * + * [HDR*][abde][Aptr][Bptr][....][....][Dptr][Eptr]|AUXP| + */ + src = n->data + n->size + raxPadding(n->size) + sizeof(raxNode *) * pos; + memmove(src + shift + sizeof(raxNode *), src, sizeof(raxNode *) * (n->size - pos)); + + /* Move the pointers to the left of the insertion position as well. Often + * we don't need to do anything if there was already some padding to use. In + * that case the final destination of the pointers will be the same, however + * in our example there was no pre-existing padding, so we added one byte + * plus thre bytes of padding. After the next memmove() things will look + * like thata: + * + * [HDR*][abde][....][Aptr][Bptr][....][Dptr][Eptr]|AUXP| + */ + if (shift) { + src = (unsigned char *)raxNodeFirstChildPtr(n); + memmove(src + shift, src, sizeof(raxNode *) * pos); + } + + /* Now make the space for the additional char in the data section, + * but also move the pointers before the insertion point to the right + * by shift bytes, in order to obtain the following: + * + * [HDR*][ab.d][e...][Aptr][Bptr][....][Dptr][Eptr]|AUXP| + */ + src = n->data + pos; + memmove(src + 1, src, n->size - pos); + + /* We can now set the character and its child node pointer to get: + * + * [HDR*][abcd][e...][Aptr][Bptr][....][Dptr][Eptr]|AUXP| + * [HDR*][abcd][e...][Aptr][Bptr][Cptr][Dptr][Eptr]|AUXP| + */ + n->data[pos] = c; + n->size++; + src = (unsigned char *)raxNodeFirstChildPtr(n); + raxNode **childfield = (raxNode **)(src + sizeof(raxNode *) * pos); + memcpy(childfield, &child, sizeof(child)); + *childptr = child; + *parentlink = childfield; + return n; +} + +/* Turn the node 'n', that must be a node without any children, into a + * compressed node representing a set of nodes linked one after the other + * and having exactly one child each. 
The node can be a key or not: this + * property and the associated value if any will be preserved. + * + * The function also returns a child node, since the last node of the + * compressed chain cannot be part of the chain: it has zero children while + * we can only compress inner nodes with exactly one child each. */ +raxNode *raxCompressNode(raxNode *n, unsigned char *s, size_t len, raxNode **child) { + assert(n->size == 0 && n->iscompr == 0); + void *data = NULL; /* Initialized only to avoid warnings. */ + size_t newsize; + + debugf("Compress node: %.*s\n", (int)len, s); + + /* Allocate the child to link to this node. */ + *child = raxNewNode(0, 0); + if (*child == NULL) + return NULL; + + /* Make space in the parent node. */ + newsize = sizeof(raxNode) + len + raxPadding(len) + sizeof(raxNode *); + if (n->iskey) { + data = raxGetData(n); /* To restore it later. */ + if (!n->isnull) + newsize += sizeof(void *); + } + raxNode *newn = rax_realloc(n, newsize); + if (newn == NULL) { + rax_free(*child); + return NULL; + } + n = newn; + + n->iscompr = 1; + n->size = len; + memcpy(n->data, s, len); + if (n->iskey) + raxSetData(n, data); + raxNode **childfield = raxNodeLastChildPtr(n); + memcpy(childfield, child, sizeof(*child)); + return n; +} + +/* Low level function that walks the tree looking for the string + * 's' of 'len' bytes. The function returns the number of characters + * of the key that was possible to process: if the returned integer + * is the same as 'len', then it means that the node corresponding to the + * string was found (however it may not be a key in case the node->iskey is + * zero or if simply we stopped in the middle of a compressed node, so that + * 'splitpos' is non zero). + * + * Otherwise if the returned integer is not the same as 'len', there was an + * early stop during the tree walk because of a character mismatch. 
+ * + * The node where the search ended (because the full string was processed + * or because there was an early stop) is returned by reference as + * '*stopnode' if the passed pointer is not NULL. This node link in the + * parent's node is returned as '*plink' if not NULL. Finally, if the + * search stopped in a compressed node, '*splitpos' returns the index + * inside the compressed node where the search ended. This is useful to + * know where to split the node for insertion. + * + * Note that when we stop in the middle of a compressed node with + * a perfect match, this function will return a length equal to the + * 'len' argument (all the key matched), and will return a *splitpos which is + * always positive (that will represent the index of the character immediately + * *after* the last match in the current compressed node). + * + * When instead we stop at a compressed node and *splitpos is zero, it + * means that the current node represents the key (that is, none of the + * compressed node characters are needed to represent the key, just all + * its parents nodes). */ +static inline size_t raxLowWalk(rax *rax, unsigned char *s, size_t len, raxNode **stopnode, + raxNode ***plink, int *splitpos, raxStack *ts) { + raxNode *h = rax->head; + raxNode **parentlink = &rax->head; + + size_t i = 0; /* Position in the string. */ + size_t j = 0; /* Position in the node children (or bytes if compressed).*/ + while (h->size && i < len) { + debugnode("Lookup current node", h); + unsigned char *v = h->data; + + if (h->iscompr) { + for (j = 0; j < h->size && i < len; j++, i++) { + if (v[j] != s[i]) + break; + } + if (j != h->size) + break; + } else { + /* Even when h->size is large, linear scan provides good + * performances compared to other approaches that are in theory + * more sounding, like performing a binary search. 
*/ + for (j = 0; j < h->size; j++) { + if (v[j] == s[i]) + break; + } + if (j == h->size) + break; + i++; + } + + if (ts) + raxStackPush(ts, h); /* Save stack of parent nodes. */ + raxNode **children = raxNodeFirstChildPtr(h); + if (h->iscompr) + j = 0; /* Compressed node only child is at index 0. */ + memcpy(&h, children + j, sizeof(h)); + parentlink = children + j; + j = 0; /* If the new node is compressed and we do not + iterate again (since i == l) set the split + position to 0 to signal this node represents + the searched key. */ + } + debugnode("Lookup stop node is", h); + if (stopnode) + *stopnode = h; + if (plink) + *plink = parentlink; + if (splitpos && h->iscompr) + *splitpos = j; + return i; +} + +/* Insert the element 's' of size 'len', setting as auxiliary data + * the pointer 'data'. If the element is already present, the associated + * data is updated (only if 'overwrite' is set to 1), and 0 is returned, + * otherwise the element is inserted and 1 is returned. On out of memory the + * function returns 0 as well but sets errno to ENOMEM, otherwise errno will + * be set to 0. + */ +int raxGenericInsert(rax *rax, unsigned char *s, size_t len, void *data, void **old, + int overwrite) { + size_t i; + int j = 0; /* Split position. If raxLowWalk() stops in a compressed + node, the index 'j' represents the char we stopped within the + compressed node, that is, the position where to split the + node for insertion. */ + raxNode *h, **parentlink; + + debugf("### Insert %.*s with value %p\n", (int)len, s, data); + i = raxLowWalk(rax, s, len, &h, &parentlink, &j, NULL); + + /* If i == len we walked following the whole string. If we are not + * in the middle of a compressed node, the string is either already + * inserted or this middle node is currently not a key, but can represent + * our key. We have just to reallocate the node and make space for the + * data pointer. 
*/ + if (i == len && (!h->iscompr || j == 0 /* not in the middle if j is 0 */)) { + debugf("### Insert: node representing key exists\n"); + /* Make space for the value pointer if needed. */ + if (!h->iskey || (h->isnull && overwrite)) { + h = raxReallocForData(h, data); + if (h) + memcpy(parentlink, &h, sizeof(h)); + } + if (h == NULL) { + errno = ENOMEM; + return 0; + } + + /* Update the existing key if there is already one. */ + if (h->iskey) { + if (old) + *old = raxGetData(h); + if (overwrite) + raxSetData(h, data); + errno = 0; + return 0; /* Element already exists. */ + } + + /* Otherwise set the node as a key. Note that raxSetData() + * will set h->iskey. */ + raxSetData(h, data); + rax->numele++; + return 1; /* Element inserted. */ + } + + /* If the node we stopped at is a compressed node, we need to + * split it before to continue. + * + * Splitting a compressed node have a few possible cases. + * Imagine that the node 'h' we are currently at is a compressed + * node contaning the string "ANNIBALE" (it means that it represents + * nodes A -> N -> N -> I -> B -> A -> L -> E with the only child + * pointer of this node pointing at the 'E' node, because remember that + * we have characters at the edges of the graph, not inside the nodes + * themselves. + * + * In order to show a real case imagine our node to also point to + * another compressed node, that finally points at the node without + * children, representing 'O': + * + * "ANNIBALE" -> "SCO" -> [] + * + * When inserting we may face the following cases. Note that all the cases + * require the insertion of a non compressed node with exactly two + * children, except for the last case which just requires splitting a + * compressed node. + * + * 1) Inserting "ANNIENTARE" + * + * |B| -> "ALE" -> "SCO" -> [] + * "ANNI" -> |-| + * |E| -> (... continue algo ...) "NTARE" -> [] + * + * 2) Inserting "ANNIBALI" + * + * |E| -> "SCO" -> [] + * "ANNIBAL" -> |-| + * |I| -> (... continue algo ...) 
[] + * + * 3) Inserting "AGO" (Like case 1, but set iscompr = 0 into original node) + * + * |N| -> "NIBALE" -> "SCO" -> [] + * |A| -> |-| + * |G| -> (... continue algo ...) |O| -> [] + * + * 4) Inserting "CIAO" + * + * |A| -> "NNIBALE" -> "SCO" -> [] + * |-| + * |C| -> (... continue algo ...) "IAO" -> [] + * + * 5) Inserting "ANNI" + * + * "ANNI" -> "BALE" -> "SCO" -> [] + * + * The final algorithm for insertion covering all the above cases is as + * follows. + * + * ============================= ALGO 1 ============================= + * + * For the above cases 1 to 4, that is, all cases where we stopped in + * the middle of a compressed node for a character mismatch, do: + * + * Let $SPLITPOS be the zero-based index at which, in the + * compressed node array of characters, we found the mismatching + * character. For example if the node contains "ANNIBALE" and we add + * "ANNIENTARE" the $SPLITPOS is 4, that is, the index at which the + * mismatching character is found. + * + * 1. Save the current compressed node $NEXT pointer (the pointer to the + * child element, that is always present in compressed nodes). + * + * 2. Create "split node" having as child the non common letter + * at the compressed node. The other non common letter (at the key) + * will be added later as we continue the normal insertion algorithm + * at step "6". + * + * 3a. IF $SPLITPOS == 0: + * Replace the old node with the split node, by copying the auxiliary + * data if any. Fix parent's reference. Free old node eventually + * (we still need its data for the next steps of the algorithm). + * + * 3b. IF $SPLITPOS != 0: + * Trim the compressed node (reallocating it as well) in order to + * contain $splitpos characters. Change chilid pointer in order to link + * to the split node. If new compressed node len is just 1, set + * iscompr to 0 (layout is the same). Fix parent's reference. + * + * 4a. 
IF the postfix len (the length of the remaining string of the + * original compressed node after the split character) is non zero, + * create a "postfix node". If the postfix node has just one character + * set iscompr to 0, otherwise iscompr to 1. Set the postfix node + * child pointer to $NEXT. + * + * 4b. IF the postfix len is zero, just use $NEXT as postfix pointer. + * + * 5. Set child[0] of split node to postfix node. + * + * 6. Set the split node as the current node, set current index at child[1] + * and continue insertion algorithm as usually. + * + * ============================= ALGO 2 ============================= + * + * For case 5, that is, if we stopped in the middle of a compressed + * node but no mismatch was found, do: + * + * Let $SPLITPOS be the zero-based index at which, in the + * compressed node array of characters, we stopped iterating because + * there were no more keys character to match. So in the example of + * the node "ANNIBALE", addig the string "ANNI", the $SPLITPOS is 4. + * + * 1. Save the current compressed node $NEXT pointer (the pointer to the + * child element, that is always present in compressed nodes). + * + * 2. Create a "postfix node" containing all the characters from $SPLITPOS + * to the end. Use $NEXT as the postfix node child pointer. + * If the postfix node length is 1, set iscompr to 0. + * Set the node as a key with the associated value of the new + * inserted key. + * + * 3. Trim the current node to contain the first $SPLITPOS characters. + * As usually if the new node length is just 1, set iscompr to 0. + * Take the iskey / associated value as it was in the orignal node. + * Fix the parent's reference. + * + * 4. Set the postfix node as the only child pointer of the trimmed + * node created at step 1. 
+ */ + + /* ------------------------- ALGORITHM 1 --------------------------- */ + if (h->iscompr && i != len) { + debugf("ALGO 1: Stopped at compressed node %.*s (%p)\n", h->size, h->data, (void *)h); + debugf("Still to insert: %.*s\n", (int)(len - i), s + i); + debugf("Splitting at %d: '%c'\n", j, ((char *)h->data)[j]); + debugf("Other (key) letter is '%c'\n", s[i]); + + /* 1: Save next pointer. */ + raxNode **childfield = raxNodeLastChildPtr(h); + raxNode *next; + memcpy(&next, childfield, sizeof(next)); + debugf("Next is %p\n", (void *)next); + debugf("iskey %d\n", h->iskey); + if (h->iskey) { + debugf("key value is %p\n", raxGetData(h)); + } + + /* Set the length of the additional nodes we will need. */ + size_t trimmedlen = j; + size_t postfixlen = h->size - j - 1; + int split_node_is_key = !trimmedlen && h->iskey && !h->isnull; + size_t nodesize; + + /* 2: Create the split node. Also allocate the other nodes we'll need + * ASAP, so that it will be simpler to handle OOM. */ + raxNode *splitnode = raxNewNode(1, split_node_is_key); + raxNode *trimmed = NULL; + raxNode *postfix = NULL; + + if (trimmedlen) { + nodesize = sizeof(raxNode) + trimmedlen + raxPadding(trimmedlen) + sizeof(raxNode *); + if (h->iskey && !h->isnull) + nodesize += sizeof(void *); + trimmed = rax_malloc(nodesize); + } + + if (postfixlen) { + nodesize = sizeof(raxNode) + postfixlen + raxPadding(postfixlen) + sizeof(raxNode *); + postfix = rax_malloc(nodesize); + } + + /* OOM? Abort now that the tree is untouched. */ + if (splitnode == NULL || (trimmedlen && trimmed == NULL) || + (postfixlen && postfix == NULL)) { + rax_free(splitnode); + rax_free(trimmed); + rax_free(postfix); + errno = ENOMEM; + return 0; + } + splitnode->data[0] = h->data[j]; + + if (j == 0) { + /* 3a: Replace the old node with the split node. 
*/ + if (h->iskey) { + void *ndata = raxGetData(h); + raxSetData(splitnode, ndata); + } + memcpy(parentlink, &splitnode, sizeof(splitnode)); + } else { + /* 3b: Trim the compressed node. */ + trimmed->size = j; + memcpy(trimmed->data, h->data, j); + trimmed->iscompr = j > 1 ? 1 : 0; + trimmed->iskey = h->iskey; + trimmed->isnull = h->isnull; + if (h->iskey && !h->isnull) { + void *ndata = raxGetData(h); + raxSetData(trimmed, ndata); + } + raxNode **cp = raxNodeLastChildPtr(trimmed); + memcpy(cp, &splitnode, sizeof(splitnode)); + memcpy(parentlink, &trimmed, sizeof(trimmed)); + parentlink = cp; /* Set parentlink to splitnode parent. */ + rax->numnodes++; + } + + /* 4: Create the postfix node: what remains of the original + * compressed node after the split. */ + if (postfixlen) { + /* 4a: create a postfix node. */ + postfix->iskey = 0; + postfix->isnull = 0; + postfix->size = postfixlen; + postfix->iscompr = postfixlen > 1; + memcpy(postfix->data, h->data + j + 1, postfixlen); + raxNode **cp = raxNodeLastChildPtr(postfix); + memcpy(cp, &next, sizeof(next)); + rax->numnodes++; + } else { + /* 4b: just use next as postfix node. */ + postfix = next; + } + + /* 5: Set splitnode first child as the postfix node. */ + raxNode **splitchild = raxNodeLastChildPtr(splitnode); + memcpy(splitchild, &postfix, sizeof(postfix)); + + /* 6. Continue insertion: this will cause the splitnode to + * get a new child (the non common character at the currently + * inserted key). */ + rax_free(h); + h = splitnode; + } else if (h->iscompr && i == len) { + /* ------------------------- ALGORITHM 2 --------------------------- */ + debugf("ALGO 2: Stopped at compressed node %.*s (%p) j = %d\n", h->size, h->data, (void *)h, + j); + + /* Allocate postfix & trimmed nodes ASAP to fail for OOM gracefully. 
*/ + size_t postfixlen = h->size - j; + size_t nodesize = sizeof(raxNode) + postfixlen + raxPadding(postfixlen) + sizeof(raxNode *); + if (data != NULL) + nodesize += sizeof(void *); + raxNode *postfix = rax_malloc(nodesize); + + nodesize = sizeof(raxNode) + j + raxPadding(j) + sizeof(raxNode *); + if (h->iskey && !h->isnull) + nodesize += sizeof(void *); + raxNode *trimmed = rax_malloc(nodesize); + + if (postfix == NULL || trimmed == NULL) { + rax_free(postfix); + rax_free(trimmed); + errno = ENOMEM; + return 0; + } + + /* 1: Save next pointer. */ + raxNode **childfield = raxNodeLastChildPtr(h); + raxNode *next; + memcpy(&next, childfield, sizeof(next)); + + /* 2: Create the postfix node. */ + postfix->size = postfixlen; + postfix->iscompr = postfixlen > 1; + postfix->iskey = 1; + postfix->isnull = 0; + memcpy(postfix->data, h->data + j, postfixlen); + raxSetData(postfix, data); + raxNode **cp = raxNodeLastChildPtr(postfix); + memcpy(cp, &next, sizeof(next)); + rax->numnodes++; + + /* 3: Trim the compressed node. */ + trimmed->size = j; + trimmed->iscompr = j > 1; + trimmed->iskey = 0; + trimmed->isnull = 0; + memcpy(trimmed->data, h->data, j); + memcpy(parentlink, &trimmed, sizeof(trimmed)); + if (h->iskey) { + void *aux = raxGetData(h); + raxSetData(trimmed, aux); + } + + /* Fix the trimmed node child pointer to point to + * the postfix node. */ + cp = raxNodeLastChildPtr(trimmed); + memcpy(cp, &postfix, sizeof(postfix)); + + /* Finish! We don't need to continue with the insertion + * algorithm for ALGO 2. The key is already inserted. */ + rax->numele++; + rax_free(h); + return 1; /* Key inserted. */ + } + + /* We walked the radix tree as far as we could, but still there are left + * chars in our string. We need to insert the missing nodes. 
*/ + while (i < len) { + raxNode *child; + + /* If this node is going to have a single child, and there + * are other characters, so that that would result in a chain + * of single-childed nodes, turn it into a compressed node. */ + if (h->size == 0 && len - i > 1) { + debugf("Inserting compressed node\n"); + size_t comprsize = len - i; + if (comprsize > RAX_NODE_MAX_SIZE) + comprsize = RAX_NODE_MAX_SIZE; + raxNode *newh = raxCompressNode(h, s + i, comprsize, &child); + if (newh == NULL) + goto oom; + h = newh; + memcpy(parentlink, &h, sizeof(h)); + parentlink = raxNodeLastChildPtr(h); + i += comprsize; + } else { + debugf("Inserting normal node\n"); + raxNode **new_parentlink; + raxNode *newh = raxAddChild(h, s[i], &child, &new_parentlink); + if (newh == NULL) + goto oom; + h = newh; + memcpy(parentlink, &h, sizeof(h)); + parentlink = new_parentlink; + i++; + } + rax->numnodes++; + h = child; + } + raxNode *newh = raxReallocForData(h, data); + if (newh == NULL) + goto oom; + h = newh; + if (!h->iskey) + rax->numele++; + raxSetData(h, data); + memcpy(parentlink, &h, sizeof(h)); + return 1; /* Element inserted. */ + +oom: + /* This code path handles out of memory after part of the sub-tree was + * already modified. Set the node as a key, and then remove it. However we + * do that only if the node is a terminal node, otherwise if the OOM + * happened reallocating a node in the middle, we don't need to free + * anything. */ + if (h->size == 0) { + h->isnull = 1; + h->iskey = 1; + rax->numele++; /* Compensate the next remove. */ + assert(raxRemove(rax, s, i, NULL) != 0); + } + errno = ENOMEM; + return 0; +} + +/* Overwriting insert. Just a wrapper for raxGenericInsert() that will + * update the element if there is already one for the same key. 
*/ +int raxInsert(rax *rax, unsigned char *s, size_t len, void *data, void **old) { + return raxGenericInsert(rax, s, len, data, old, 1); +} + +/* Non overwriting insert function: this if an element with the same key + * exists, the value is not updated and the function returns 0. + * This is a just a wrapper for raxGenericInsert(). */ +int raxTryInsert(rax *rax, unsigned char *s, size_t len, void *data, void **old) { + return raxGenericInsert(rax, s, len, data, old, 0); +} + +/* Find a key in the rax, returns raxNotFound special void pointer value + * if the item was not found, otherwise the value associated with the + * item is returned. */ +void *raxFind(rax *rax, unsigned char *s, size_t len) { + raxNode *h; + + debugf("### Lookup: %.*s\n", (int)len, s); + int splitpos = 0; + size_t i = raxLowWalk(rax, s, len, &h, NULL, &splitpos, NULL); + if (i != len || (h->iscompr && splitpos != 0) || !h->iskey) + return raxNotFound; + return raxGetData(h); +} + +/* Return the memory address where the 'parent' node stores the specified + * 'child' pointer, so that the caller can update the pointer with another + * one if needed. The function assumes it will find a match, otherwise the + * operation is an undefined behavior (it will continue scanning the + * memory without any bound checking). */ +raxNode **raxFindParentLink(raxNode *parent, raxNode *child) { + raxNode **cp = raxNodeFirstChildPtr(parent); + raxNode *c; + while (1) { + memcpy(&c, cp, sizeof(c)); + if (c == child) + break; + cp++; + } + return cp; +} + +/* Low level child removal from node. The new node pointer (after the child + * removal) is returned. Note that this function does not fix the pointer + * of the parent node in its parent, so this task is up to the caller. + * The function never fails for out of memory. 
*/ +raxNode *raxRemoveChild(raxNode *parent, raxNode *child) { + debugnode("raxRemoveChild before", parent); + /* If parent is a compressed node (having a single child, as for definition + * of the data structure), the removal of the child consists into turning + * it into a normal node without children. */ + if (parent->iscompr) { + void *data = NULL; + if (parent->iskey) + data = raxGetData(parent); + parent->isnull = 0; + parent->iscompr = 0; + parent->size = 0; + if (parent->iskey) + raxSetData(parent, data); + debugnode("raxRemoveChild after", parent); + return parent; + } + + /* Otherwise we need to scan for the child pointer and memmove() + * accordingly. + * + * 1. To start we seek the first element in both the children + * pointers and edge bytes in the node. */ + raxNode **cp = raxNodeFirstChildPtr(parent); + raxNode **c = cp; + unsigned char *e = parent->data; + + /* 2. Search the child pointer to remove inside the array of children + * pointers. */ + while (1) { + raxNode *aux; + memcpy(&aux, c, sizeof(aux)); + if (aux == child) + break; + c++; + e++; + } + + /* 3. Remove the edge and the pointer by memmoving the remaining children + * pointer and edge bytes one position before. */ + int taillen = parent->size - (e - parent->data) - 1; + debugf("raxRemoveChild tail len: %d\n", taillen); + memmove(e, e + 1, taillen); + + /* Compute the shift, that is the amount of bytes we should move our + * child pointers to the left, since the removal of one edge character + * and the corresponding padding change, may change the layout. + * We just check if in the old version of the node there was at the + * end just a single byte and all padding: in that case removing one char + * will remove a whole sizeof(void*) word. */ + size_t shift = ((parent->size + 4) % sizeof(void *)) == 1 ? sizeof(void *) : 0; + + /* Move the children pointers before the deletion point. 
*/ + if (shift) + memmove(((char *)cp) - shift, cp, (parent->size - taillen - 1) * sizeof(raxNode **)); + + /* Move the remaining "tail" pointers at the right position as well. */ + size_t valuelen = (parent->iskey && !parent->isnull) ? sizeof(void *) : 0; + memmove(((char *)c) - shift, c + 1, taillen * sizeof(raxNode **) + valuelen); + + /* 4. Update size. */ + parent->size--; + + /* realloc the node according to the theoretical memory usage, to free + * data if we are over-allocating right now. */ + raxNode *newnode = rax_realloc(parent, raxNodeCurrentLength(parent)); + if (newnode) { + debugnode("raxRemoveChild after", newnode); + } + /* Note: if rax_realloc() fails we just return the old address, which + * is valid. */ + return newnode ? newnode : parent; +} + +/* Remove the specified item. Returns 1 if the item was found and + * deleted, 0 otherwise. */ +int raxRemove(rax *rax, unsigned char *s, size_t len, void **old) { + raxNode *h; + raxStack ts; + + debugf("### Delete: %.*s\n", (int)len, s); + raxStackInit(&ts); + int splitpos = 0; + size_t i = raxLowWalk(rax, s, len, &h, NULL, &splitpos, &ts); + if (i != len || (h->iscompr && splitpos != 0) || !h->iskey) { + raxStackFree(&ts); + return 0; + } + if (old) + *old = raxGetData(h); + h->iskey = 0; + rax->numele--; + + /* If this node has no children, the deletion needs to reclaim the + * no longer used nodes. This is an iterative process that needs to + * walk the three upward, deleting all the nodes with just one child + * that are not keys, until the head of the rax is reached or the first + * node with more than one child is found. */ + + int trycompress = 0; /* Will be set to 1 if we should try to optimize the + tree resulting from the deletion. */ + + if (h->size == 0) { + debugf("Key deleted in node without children. 
Cleanup needed.\n"); + raxNode *child = NULL; + while (h != rax->head) { + child = h; + debugf("Freeing child %p [%.*s] key:%d\n", (void *)child, (int)child->size, + (char *)child->data, child->iskey); + rax_free(child); + rax->numnodes--; + h = raxStackPop(&ts); + /* If this node has more then one child, or actually holds + * a key, stop here. */ + if (h->iskey || (!h->iscompr && h->size != 1)) + break; + } + if (child) { + debugf("Unlinking child %p from parent %p\n", (void *)child, (void *)h); + raxNode *new = raxRemoveChild(h, child); + if (new != h) { + raxNode *parent = raxStackPeek(&ts); + raxNode **parentlink; + if (parent == NULL) { + parentlink = &rax->head; + } else { + parentlink = raxFindParentLink(parent, h); + } + memcpy(parentlink, &new, sizeof(new)); + } + + /* If after the removal the node has just a single child + * and is not a key, we need to try to compress it. */ + if (new->size == 1 && new->iskey == 0) { + trycompress = 1; + h = new; + } + } + } else if (h->size == 1) { + /* If the node had just one child, after the removal of the key + * further compression with adjacent nodes is pontentially possible. */ + trycompress = 1; + } + + /* Don't try node compression if our nodes pointers stack is not + * complete because of OOM while executing raxLowWalk() */ + if (trycompress && ts.oom) + trycompress = 0; + + /* Recompression: if trycompress is true, 'h' points to a radix tree node + * that changed in a way that could allow to compress nodes in this + * sub-branch. Compressed nodes represent chains of nodes that are not + * keys and have a single child, so there are two deletion events that + * may alter the tree so that further compression is needed: + * + * 1) A node with a single child was a key and now no longer is a key. + * 2) A node with two children now has just one child. 
+ * + * We try to navigate upward till there are other nodes that can be + * compressed, when we reach the upper node which is not a key and has + * a single child, we scan the chain of children to collect the + * compressable part of the tree, and replace the current node with the + * new one, fixing the child pointer to reference the first non + * compressable node. + * + * Example of case "1". A tree stores the keys "FOO" = 1 and + * "FOOBAR" = 2: + * + * + * "FOO" -> "BAR" -> [] (2) + * (1) + * + * After the removal of "FOO" the tree can be compressed as: + * + * "FOOBAR" -> [] (2) + * + * + * Example of case "2". A tree stores the keys "FOOBAR" = 1 and + * "FOOTER" = 2: + * + * |B| -> "AR" -> [] (1) + * "FOO" -> |-| + * |T| -> "ER" -> [] (2) + * + * After the removal of "FOOTER" the resulting tree is: + * + * "FOO" -> |B| -> "AR" -> [] (1) + * + * That can be compressed into: + * + * "FOOBAR" -> [] (1) + */ + if (trycompress) { + debugf("After removing %.*s:\n", (int)len, s); + debugnode("Compression may be needed", h); + debugf("Seek start node\n"); + + /* Try to reach the upper node that is compressible. + * At the end of the loop 'h' will point to the first node we + * can try to compress and 'parent' to its parent. */ + raxNode *parent; + while (1) { + parent = raxStackPop(&ts); + if (!parent || parent->iskey || (!parent->iscompr && parent->size != 1)) + break; + h = parent; + debugnode("Going up to", h); + } + raxNode *start = h; /* Compression starting node. */ + + /* Scan chain of nodes we can compress. */ + size_t comprsize = h->size; + int nodes = 1; + while (h->size != 0) { + raxNode **cp = raxNodeLastChildPtr(h); + memcpy(&h, cp, sizeof(h)); + if (h->iskey || (!h->iscompr && h->size != 1)) + break; + /* Stop here if going to the next node would result into + * a compressed node larger than h->size can hold. 
*/ + if (comprsize + h->size > RAX_NODE_MAX_SIZE) + break; + nodes++; + comprsize += h->size; + } + if (nodes > 1) { + /* If we can compress, create the new node and populate it. */ + size_t nodesize = + sizeof(raxNode) + comprsize + raxPadding(comprsize) + sizeof(raxNode *); + raxNode *new = rax_malloc(nodesize); + /* An out of memory here just means we cannot optimize this + * node, but the tree is left in a consistent state. */ + if (new == NULL) { + raxStackFree(&ts); + return 1; + } + new->iskey = 0; + new->isnull = 0; + new->iscompr = 1; + new->size = comprsize; + rax->numnodes++; + + /* Scan again, this time to populate the new node content and + * to fix the new node child pointer. At the same time we free + * all the nodes that we'll no longer use. */ + comprsize = 0; + h = start; + while (h->size != 0) { + memcpy(new->data + comprsize, h->data, h->size); + comprsize += h->size; + raxNode **cp = raxNodeLastChildPtr(h); + raxNode *tofree = h; + memcpy(&h, cp, sizeof(h)); + rax_free(tofree); + rax->numnodes--; + if (h->iskey || (!h->iscompr && h->size != 1)) + break; + } + debugnode("New node", new); + + /* Now 'h' points to the first node that we still need to use, + * so our new node child pointer will point to it. */ + raxNode **cp = raxNodeLastChildPtr(new); + memcpy(cp, &h, sizeof(h)); + + /* Fix parent link. */ + if (parent) { + raxNode **parentlink = raxFindParentLink(parent, start); + memcpy(parentlink, &new, sizeof(new)); + } else { + rax->head = new; + } + + debugf("Compressed %d nodes, %d total bytes\n", nodes, (int)comprsize); + } + } + raxStackFree(&ts); + return 1; +} + +/* This is the core of raxFree(): performs a depth-first scan of the + * tree and releases all the nodes found. */ +void raxRecursiveFree(rax *rax, raxNode *n, void (*free_callback)(void *)) { + debugnode("free traversing", n); + int numchildren = n->iscompr ? 
1 : n->size; + raxNode **cp = raxNodeLastChildPtr(n); + while (numchildren--) { + raxNode *child; + memcpy(&child, cp, sizeof(child)); + raxRecursiveFree(rax, child, free_callback); + cp--; + } + debugnode("free depth-first", n); + if (free_callback && n->iskey && !n->isnull) + free_callback(raxGetData(n)); + rax_free(n); + rax->numnodes--; +} + +/* Free a whole radix tree, calling the specified callback in order to + * free the auxiliary data. */ +void raxFreeWithCallback(rax *rax, void (*free_callback)(void *)) { + raxRecursiveFree(rax, rax->head, free_callback); + assert(rax->numnodes == 0); + rax_free(rax); +} + +/* Free a whole radix tree. */ +void raxFree(rax *rax) { raxFreeWithCallback(rax, NULL); } + +/* ------------------------------- Iterator --------------------------------- */ + +/* Initialize a Rax iterator. This call should be performed a single time + * to initialize the iterator, and must be followed by a raxSeek() call, + * otherwise the raxPrev()/raxNext() functions will just return EOF. */ +void raxStart(raxIterator *it, rax *rt) { + it->flags = RAX_ITER_EOF; /* No crash if the iterator is not seeked. */ + it->rt = rt; + it->key_len = 0; + it->key = it->key_static_string; + it->key_max = RAX_ITER_STATIC_LEN; + it->data = NULL; + it->node_cb = NULL; + raxStackInit(&it->stack); +} + +/* Append characters at the current key string of the iterator 'it'. This + * is a low level function used to implement the iterator, not callable by + * the user. Returns 0 on out of memory, otherwise 1 is returned. */ +int raxIteratorAddChars(raxIterator *it, unsigned char *s, size_t len) { + if (it->key_max < it->key_len + len) { + unsigned char *old = (it->key == it->key_static_string) ? NULL : it->key; + size_t new_max = (it->key_len + len) * 2; + it->key = rax_realloc(old, new_max); + if (it->key == NULL) { + it->key = (!old) ? 
it->key_static_string : old; + errno = ENOMEM; + return 0; + } + if (old == NULL) + memcpy(it->key, it->key_static_string, it->key_len); + it->key_max = new_max; + } + /* Use memmove since there could be an overlap between 's' and + * it->key when we use the current key in order to re-seek. */ + memmove(it->key + it->key_len, s, len); + it->key_len += len; + return 1; +} + +/* Remove the specified number of chars from the right of the current + * iterator key. */ +void raxIteratorDelChars(raxIterator *it, size_t count) { it->key_len -= count; } + +/* Do an iteration step towards the next element. At the end of the step the + * iterator key will represent the (new) current key. If it is not possible + * to step in the specified direction since there are no longer elements, the + * iterator is flagged with RAX_ITER_EOF. + * + * If 'noup' is true the function starts directly scanning for the next + * lexicographically smaller children, and the current node is already assumed + * to be the parent of the last key node, so the first operation to go back to + * the parent will be skipped. This option is used by raxSeek() when + * implementing seeking a non existing element with the ">" or "<" options: + * the starting node is not a key in that particular case, so we start the scan + * from a node that does not represent the key set. + * + * The function returns 1 on success or 0 on out of memory. */ +int raxIteratorNextStep(raxIterator *it, int noup) { + if (it->flags & RAX_ITER_EOF) { + return 1; + } else if (it->flags & RAX_ITER_JUST_SEEKED) { + it->flags &= ~RAX_ITER_JUST_SEEKED; + return 1; + } + + /* Save key len, stack items and the node where we are currently + * so that on iterator EOF we can restore the current key and state. */ + size_t orig_key_len = it->key_len; + size_t orig_stack_items = it->stack.items; + raxNode *orig_node = it->node; + + while (1) { + int children = it->node->iscompr ? 
1 : it->node->size; + if (!noup && children) { + debugf("GO DEEPER\n"); + /* Seek the lexicographically smaller key in this subtree, which + * is the first one found always going torwards the first child + * of every successive node. */ + if (!raxStackPush(&it->stack, it->node)) + return 0; + raxNode **cp = raxNodeFirstChildPtr(it->node); + if (!raxIteratorAddChars(it, it->node->data, it->node->iscompr ? it->node->size : 1)) + return 0; + memcpy(&it->node, cp, sizeof(it->node)); + /* Call the node callback if any, and replace the node pointer + * if the callback returns true. */ + if (it->node_cb && it->node_cb(&it->node)) + memcpy(cp, &it->node, sizeof(it->node)); + /* For "next" step, stop every time we find a key along the + * way, since the key is lexicograhically smaller compared to + * what follows in the sub-children. */ + if (it->node->iskey) { + it->data = raxGetData(it->node); + return 1; + } + } else { + /* If we finished exporing the previous sub-tree, switch to the + * new one: go upper until a node is found where there are + * children representing keys lexicographically greater than the + * current key. */ + while (1) { + int old_noup = noup; + + /* Already on head? Can't go up, iteration finished. */ + if (!noup && it->node == it->rt->head) { + it->flags |= RAX_ITER_EOF; + it->stack.items = orig_stack_items; + it->key_len = orig_key_len; + it->node = orig_node; + return 1; + } + /* If there are no children at the current node, try parent's + * next child. */ + unsigned char prevchild = it->key[it->key_len - 1]; + if (!noup) { + it->node = raxStackPop(&it->stack); + } else { + noup = 0; + } + /* Adjust the current key to represent the node we are + * at. */ + int todel = it->node->iscompr ? it->node->size : 1; + raxIteratorDelChars(it, todel); + + /* Try visiting the next child if there was at least one + * additional child. */ + if (!it->node->iscompr && it->node->size > (old_noup ? 
0 : 1)) { + raxNode **cp = raxNodeFirstChildPtr(it->node); + int i = 0; + while (i < it->node->size) { + debugf("SCAN NEXT %c\n", it->node->data[i]); + if (it->node->data[i] > prevchild) + break; + i++; + cp++; + } + if (i != it->node->size) { + debugf("SCAN found a new node\n"); + raxIteratorAddChars(it, it->node->data + i, 1); + if (!raxStackPush(&it->stack, it->node)) + return 0; + memcpy(&it->node, cp, sizeof(it->node)); + /* Call the node callback if any, and replace the node + * pointer if the callback returns true. */ + if (it->node_cb && it->node_cb(&it->node)) + memcpy(cp, &it->node, sizeof(it->node)); + if (it->node->iskey) { + it->data = raxGetData(it->node); + return 1; + } + break; + } + } + } + } + } +} + +/* Seek the greatest key in the subtree at the current node. Return 0 on + * out of memory, otherwise 1. This is an helper function for different + * iteration functions below. */ +int raxSeekGreatest(raxIterator *it) { + while (it->node->size) { + if (it->node->iscompr) { + if (!raxIteratorAddChars(it, it->node->data, it->node->size)) + return 0; + } else { + if (!raxIteratorAddChars(it, it->node->data + it->node->size - 1, 1)) + return 0; + } + raxNode **cp = raxNodeLastChildPtr(it->node); + if (!raxStackPush(&it->stack, it->node)) + return 0; + memcpy(&it->node, cp, sizeof(it->node)); + } + return 1; +} + +/* Like raxIteratorNextStep() but implements an iteration step moving + * to the lexicographically previous element. The 'noup' option has a similar + * effect to the one of raxIteratorNextStep(). */ +int raxIteratorPrevStep(raxIterator *it, int noup) { + if (it->flags & RAX_ITER_EOF) { + return 1; + } else if (it->flags & RAX_ITER_JUST_SEEKED) { + it->flags &= ~RAX_ITER_JUST_SEEKED; + return 1; + } + + /* Save key len, stack items and the node where we are currently + * so that on iterator EOF we can restore the current key and state. 
*/ + size_t orig_key_len = it->key_len; + size_t orig_stack_items = it->stack.items; + raxNode *orig_node = it->node; + + while (1) { + int old_noup = noup; + + /* Already on head? Can't go up, iteration finished. */ + if (!noup && it->node == it->rt->head) { + it->flags |= RAX_ITER_EOF; + it->stack.items = orig_stack_items; + it->key_len = orig_key_len; + it->node = orig_node; + return 1; + } + + unsigned char prevchild = it->key[it->key_len - 1]; + if (!noup) { + it->node = raxStackPop(&it->stack); + } else { + noup = 0; + } + + /* Adjust the current key to represent the node we are + * at. */ + int todel = it->node->iscompr ? it->node->size : 1; + raxIteratorDelChars(it, todel); + + /* Try visiting the prev child if there is at least one + * child. */ + if (!it->node->iscompr && it->node->size > (old_noup ? 0 : 1)) { + raxNode **cp = raxNodeLastChildPtr(it->node); + int i = it->node->size - 1; + while (i >= 0) { + debugf("SCAN PREV %c\n", it->node->data[i]); + if (it->node->data[i] < prevchild) + break; + i--; + cp--; + } + /* If we found a new subtree to explore in this node, + * go deeper following all the last children in order to + * find the key lexicographically greater. */ + if (i != -1) { + debugf("SCAN found a new node\n"); + /* Enter the node we just found. */ + if (!raxIteratorAddChars(it, it->node->data + i, 1)) + return 0; + if (!raxStackPush(&it->stack, it->node)) + return 0; + memcpy(&it->node, cp, sizeof(it->node)); + /* Seek sub-tree max. */ + if (!raxSeekGreatest(it)) + return 0; + } + } + + /* Return the key: this could be the key we found scanning a new + * subtree, or if we did not find a new subtree to explore here, + * before giving up with this node, check if it's a key itself. */ + if (it->node->iskey) { + it->data = raxGetData(it->node); + return 1; + } + } +} + +/* Seek an iterator at the specified element. + * Return 0 if the seek failed for syntax error or out of memory. Otherwise + * 1 is returned. 
When 0 is returned for out of memory, errno is set to + * the ENOMEM value. */ +int raxSeek(raxIterator *it, const char *op, unsigned char *ele, size_t len) { + int eq = 0, lt = 0, gt = 0, first = 0, last = 0; + + it->stack.items = 0; /* Just resetting. Intialized by raxStart(). */ + it->flags |= RAX_ITER_JUST_SEEKED; + it->flags &= ~RAX_ITER_EOF; + it->key_len = 0; + it->node = NULL; + + /* Set flags according to the operator used to perform the seek. */ + if (op[0] == '>') { + gt = 1; + if (op[1] == '=') + eq = 1; + } else if (op[0] == '<') { + lt = 1; + if (op[1] == '=') + eq = 1; + } else if (op[0] == '=') { + eq = 1; + } else if (op[0] == '^') { + first = 1; + } else if (op[0] == '$') { + last = 1; + } else { + errno = 0; + return 0; /* Error. */ + } + + /* If there are no elements, set the EOF condition immediately and + * return. */ + if (it->rt->numele == 0) { + it->flags |= RAX_ITER_EOF; + return 1; + } + + if (first) { + /* Seeking the first key greater or equal to the empty string + * is equivalent to seeking the smaller key available. */ + return raxSeek(it, ">=", NULL, 0); + } + + if (last) { + /* Find the greatest key taking always the last child till a + * final node is found. */ + it->node = it->rt->head; + if (!raxSeekGreatest(it)) + return 0; + assert(it->node->iskey); + it->data = raxGetData(it->node); + return 1; + } + + /* We need to seek the specified key. What we do here is to actually + * perform a lookup, and later invoke the prev/next key code that + * we already use for iteration. */ + int splitpos = 0; + size_t i = raxLowWalk(it->rt, ele, len, &it->node, NULL, &splitpos, &it->stack); + + /* Return OOM on incomplete stack info. */ + if (it->stack.oom) + return 0; + + if (eq && i == len && (!it->node->iscompr || splitpos == 0) && it->node->iskey) { + /* We found our node, since the key matches and we have an + * "equal" condition. */ + if (!raxIteratorAddChars(it, ele, len)) + return 0; /* OOM. 
*/ + it->data = raxGetData(it->node); + } else if (lt || gt) { + /* Exact key not found or eq flag not set. We have to set as current + * key the one represented by the node we stopped at, and perform + * a next/prev operation to seek. To reconstruct the key at this node + * we start from the parent and go to the current node, accumulating + * the characters found along the way. */ + if (!raxStackPush(&it->stack, it->node)) + return 0; + for (size_t j = 1; j < it->stack.items; j++) { + raxNode *parent = it->stack.stack[j - 1]; + raxNode *child = it->stack.stack[j]; + if (parent->iscompr) { + if (!raxIteratorAddChars(it, parent->data, parent->size)) + return 0; + } else { + raxNode **cp = raxNodeFirstChildPtr(parent); + unsigned char *p = parent->data; + while (1) { + raxNode *aux; + memcpy(&aux, cp, sizeof(aux)); + if (aux == child) + break; + cp++; + p++; + } + if (!raxIteratorAddChars(it, p, 1)) + return 0; + } + } + raxStackPop(&it->stack); + + /* We need to set the iterator in the correct state to call next/prev + * step in order to seek the desired element. */ + debugf("After initial seek: i=%d len=%d key=%.*s\n", (int)i, (int)len, (int)it->key_len, + it->key); + if (i != len && !it->node->iscompr) { + /* If we stopped in the middle of a normal node because of a + * mismatch, add the mismatching character to the current key + * and call the iterator with the 'noup' flag so that it will try + * to seek the next/prev child in the current node directly based + * on the mismatching character. */ + if (!raxIteratorAddChars(it, ele + i, 1)) + return 0; + debugf("Seek normal node on mismatch: %.*s\n", (int)it->key_len, (char *)it->key); + + it->flags &= ~RAX_ITER_JUST_SEEKED; + if (lt && !raxIteratorPrevStep(it, 1)) + return 0; + if (gt && !raxIteratorNextStep(it, 1)) + return 0; + it->flags |= RAX_ITER_JUST_SEEKED; /* Ignore next call. 
*/ + } else if (i != len && it->node->iscompr) { + debugf("Compressed mismatch: %.*s\n", (int)it->key_len, (char *)it->key); + /* In case of a mismatch within a compressed node. */ + int nodechar = it->node->data[splitpos]; + int keychar = ele[i]; + it->flags &= ~RAX_ITER_JUST_SEEKED; + if (gt) { + /* If the key the compressed node represents is greater + * than our seek element, continue forward, otherwise set the + * state in order to go back to the next sub-tree. */ + if (nodechar > keychar) { + if (!raxIteratorNextStep(it, 0)) + return 0; + } else { + if (!raxIteratorAddChars(it, it->node->data, it->node->size)) + return 0; + if (!raxIteratorNextStep(it, 1)) + return 0; + } + } + if (lt) { + /* If the key the compressed node represents is smaller + * than our seek element, seek the greater key in this + * subtree, otherwise set the state in order to go back to + * the previous sub-tree. */ + if (nodechar < keychar) { + if (!raxSeekGreatest(it)) + return 0; + it->data = raxGetData(it->node); + } else { + if (!raxIteratorAddChars(it, it->node->data, it->node->size)) + return 0; + if (!raxIteratorPrevStep(it, 1)) + return 0; + } + } + it->flags |= RAX_ITER_JUST_SEEKED; /* Ignore next call. */ + } else { + debugf("No mismatch: %.*s\n", (int)it->key_len, (char *)it->key); + /* If there was no mismatch we are into a node representing the + * key, (but which is not a key or the seek operator does not + * include 'eq'), or we stopped in the middle of a compressed node + * after processing all the key. Continue iterating as this was + * a legitimate key we stopped at. */ + it->flags &= ~RAX_ITER_JUST_SEEKED; + if (it->node->iscompr && it->node->iskey && splitpos && lt) { + /* If we stopped in the middle of a compressed node with + * perfect match, and the condition is to seek a key "<" than + * the specified one, then if this node is a key it already + * represents our match. 
For instance we may have nodes: + * + * "f" -> "oobar" = 1 -> "" = 2 + * + * Representing keys "f" = 1, "foobar" = 2. A seek for + * the key < "foo" will stop in the middle of the "oobar" + * node, but will be our match, representing the key "f". + * + * So in that case, we don't seek backward. */ + it->data = raxGetData(it->node); + } else { + if (gt && !raxIteratorNextStep(it, 0)) + return 0; + if (lt && !raxIteratorPrevStep(it, 0)) + return 0; + } + it->flags |= RAX_ITER_JUST_SEEKED; /* Ignore next call. */ + } + } else { + /* If we are here just eq was set but no match was found. */ + it->flags |= RAX_ITER_EOF; + return 1; + } + return 1; +} + +/* Go to the next element in the scope of the iterator 'it'. + * If EOF (or out of memory) is reached, 0 is returned, otherwise 1 is + * returned. In case 0 is returned because of OOM, errno is set to ENOMEM. */ +int raxNext(raxIterator *it) { + if (!raxIteratorNextStep(it, 0)) { + errno = ENOMEM; + return 0; + } + if (it->flags & RAX_ITER_EOF) { + errno = 0; + return 0; + } + return 1; +} + +/* Go to the previous element in the scope of the iterator 'it'. + * If EOF (or out of memory) is reached, 0 is returned, otherwise 1 is + * returned. In case 0 is returned because of OOM, errno is set to ENOMEM. */ +int raxPrev(raxIterator *it) { + if (!raxIteratorPrevStep(it, 0)) { + errno = ENOMEM; + return 0; + } + if (it->flags & RAX_ITER_EOF) { + errno = 0; + return 0; + } + return 1; +} + +/* Perform a random walk starting in the current position of the iterator. + * Return 0 if the tree is empty or on out of memory. Otherwise 1 is returned + * and the iterator is set to the node reached after doing a random walk + * of 'steps' steps. If the 'steps' argument is 0, the random walk is performed + * using a random number of steps between 1 and two times the logarithm of + * the number of elements. + * + * NOTE: if you use this function to generate random elements from the radix + * tree, expect a disappointing distribution. 
A random walk produces good + * random elements if the tree is not sparse, however in the case of a radix + * tree certain keys will be reported much more often than others. At least + * this function should be able to expore every possible element eventually. */ +int raxRandomWalk(raxIterator *it, size_t steps) { + if (it->rt->numele == 0) { + it->flags |= RAX_ITER_EOF; + return 0; + } + + if (steps == 0) { + size_t fle = floor(log(it->rt->numele)); + fle *= 2; + steps = 1 + rand() % fle; + } + + raxNode *n = it->node; + while (steps > 0 || !n->iskey) { + int numchildren = n->iscompr ? 1 : n->size; + int r = rand() % (numchildren + (n != it->rt->head)); + + if (r == numchildren) { + /* Go up to parent. */ + n = raxStackPop(&it->stack); + int todel = n->iscompr ? n->size : 1; + raxIteratorDelChars(it, todel); + } else { + /* Select a random child. */ + if (n->iscompr) { + if (!raxIteratorAddChars(it, n->data, n->size)) + return 0; + } else { + if (!raxIteratorAddChars(it, n->data + r, 1)) + return 0; + } + raxNode **cp = raxNodeFirstChildPtr(n) + r; + if (!raxStackPush(&it->stack, n)) + return 0; + memcpy(&n, cp, sizeof(n)); + } + if (n->iskey) + steps--; + } + it->node = n; + return 1; +} + +/* Compare the key currently pointed by the iterator to the specified + * key according to the specified operator. Returns 1 if the comparison is + * true, otherwise 0 is returned. */ +int raxCompare(raxIterator *iter, const char *op, unsigned char *key, size_t key_len) { + int eq = 0, lt = 0, gt = 0; + + if (op[0] == '=' || op[1] == '=') + eq = 1; + if (op[0] == '>') + gt = 1; + else if (op[0] == '<') + lt = 1; + else if (op[1] != '=') + return 0; /* Syntax error. */ + + size_t minlen = key_len < iter->key_len ? key_len : iter->key_len; + int cmp = memcmp(iter->key, key, minlen); + + /* Handle == */ + if (lt == 0 && gt == 0) + return cmp == 0 && key_len == iter->key_len; + + /* Handle >, >=, <, <= */ + if (cmp == 0) { + /* Same prefix: longer wins. 
*/ + if (eq && key_len == iter->key_len) + return 1; + else if (lt) + return iter->key_len < key_len; + else if (gt) + return iter->key_len > key_len; + else + return 0; /* Avoid warning, just 'eq' is handled before. */ + } else if (cmp > 0) { + return gt ? 1 : 0; + } else /* (cmp < 0) */ { + return lt ? 1 : 0; + } +} + +/* Free the iterator. */ +void raxStop(raxIterator *it) { + if (it->key != it->key_static_string) + rax_free(it->key); + raxStackFree(&it->stack); +} + +/* Return if the iterator is in an EOF state. This happens when raxSeek() + * failed to seek an appropriate element, so that raxNext() or raxPrev() + * will return zero, or when an EOF condition was reached while iterating + * with raxNext() and raxPrev(). */ +int raxEOF(raxIterator *it) { return it->flags & RAX_ITER_EOF; } + +/* Return the number of elements inside the radix tree. */ +uint64_t raxSize(rax *rax) { return rax->numele; } + +/* ----------------------------- Introspection ------------------------------ */ + +/* This function is mostly used for debugging and learning purposes. + * It shows an ASCII representation of a tree on standard output, outling + * all the nodes and the contained keys. + * + * The representation is as follow: + * + * "foobar" (compressed node) + * [abc] (normal node with three children) + * [abc]=0x12345678 (node is a key, pointing to value 0x12345678) + * [] (a normal empty node) + * + * Children are represented in new idented lines, each children prefixed by + * the "`-(x)" string, where "x" is the edge byte. + * + * [abc] + * `-(a) "ladin" + * `-(b) [kj] + * `-(c) [] + * + * However when a node has a single child the following representation + * is used instead: + * + * [abc] -> "ladin" -> [] + */ + +/* The actual implementation of raxShow(). */ +void raxRecursiveShow(int level, int lpad, raxNode *n) { + char s = n->iscompr ? '"' : '['; + char e = n->iscompr ? 
'"' : ']'; + + int numchars = printf("%c%.*s%c", s, n->size, n->data, e); + if (n->iskey) { + numchars += printf("=%p", raxGetData(n)); + } + + int numchildren = n->iscompr ? 1 : n->size; + /* Note that 7 and 4 magic constants are the string length + * of " `-(x) " and " -> " respectively. */ + if (level) { + lpad += (numchildren > 1) ? 7 : 4; + if (numchildren == 1) + lpad += numchars; + } + raxNode **cp = raxNodeFirstChildPtr(n); + for (int i = 0; i < numchildren; i++) { + char *branch = " `-(%c) "; + if (numchildren > 1) { + printf("\n"); + for (int j = 0; j < lpad; j++) + putchar(' '); + printf(branch, n->data[i]); + } else { + printf(" -> "); + } + raxNode *child; + memcpy(&child, cp, sizeof(child)); + raxRecursiveShow(level + 1, lpad, child); + cp++; + } +} + +/* Show a tree, as outlined in the comment above. */ +void raxShow(rax *rax) { + raxRecursiveShow(0, 0, rax->head); + putchar('\n'); +} + +/* Used by debugnode() macro to show info about a given node. */ +void raxDebugShowNode(const char *msg, raxNode *n) { + if (raxDebugMsg == 0) + return; + printf("%s: %p [%.*s] key:%d size:%d children:", msg, (void *)n, (int)n->size, (char *)n->data, + n->iskey, n->size); + int numcld = n->iscompr ? 1 : n->size; + raxNode **cldptr = raxNodeLastChildPtr(n) - (numcld - 1); + while (numcld--) { + raxNode *child; + memcpy(&child, cldptr, sizeof(child)); + cldptr++; + printf("%p ", (void *)child); + } + printf("\n"); + fflush(stdout); +} + +/* Touch all the nodes of a tree returning a check sum. This is useful + * in order to make Valgrind detect if there is something wrong while + * reading the data structure. + * + * This function was used in order to identify Rax bugs after a big refactoring + * using this technique: + * + * 1. The rax-test is executed using Valgrind, adding a printf() so that for + * the fuzz tester we see what iteration in the loop we are in. + * 2. 
After every modification of the radix tree made by the fuzz tester + * in rax-test.c, we add a call to raxTouch(). + * 3. Now as soon as an operation will corrupt the tree, raxTouch() will + * detect it (via Valgrind) immediately. We can add more calls to narrow + * the state. + * 4. At this point a good idea is to enable Rax debugging messages immediately + * before the moment the tree is corrupted, to see what happens. + */ +unsigned long raxTouch(raxNode *n) { + debugf("Touching %p\n", (void *)n); + unsigned long sum = 0; + if (n->iskey) { + sum += (unsigned long)raxGetData(n); + } + + int numchildren = n->iscompr ? 1 : n->size; + raxNode **cp = raxNodeFirstChildPtr(n); + int count = 0; + for (int i = 0; i < numchildren; i++) { + if (numchildren > 1) { + sum += (long)n->data[i]; + } + raxNode *child; + memcpy(&child, cp, sizeof(child)); + if (child == (void *)0x65d1760) + count++; + if (count > 1) + exit(1); + sum += raxTouch(child); + cp++; + } + return sum; +} diff --git a/src/util/rax.h b/src/util/rax.h new file mode 100644 index 000000000..6ccb69200 --- /dev/null +++ b/src/util/rax.h @@ -0,0 +1,218 @@ +/* Rax -- A radix tree implementation. + * + * Copyright (c) 2017-2018, Salvatore Sanfilippo + * All rights reserved. + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * * Redistributions of source code must retain the above copyright notice, + * this list of conditions and the following disclaimer. + * * Redistributions in binary form must reproduce the above copyright + * notice, this list of conditions and the following disclaimer in the + * documentation and/or other materials provided with the distribution. + * * Neither the name of Redis nor the names of its contributors may be used + * to endorse or promote products derived from this software without + * specific prior written permission. 
+ * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE + * ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE + * LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR + * CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF + * SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS + * INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN + * CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) + * ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE + * POSSIBILITY OF SUCH DAMAGE. + */ + +#ifndef RAX_H +#define RAX_H + +#include + +/* Representation of a radix tree as implemented in this file, that contains + * the strings "foo", "foobar" and "footer" after the insertion of each + * word. When the node represents a key inside the radix tree, we write it + * between [], otherwise it is written between (). + * + * This is the vanilla representation: + * + * (f) "" + * \ + * (o) "f" + * \ + * (o) "fo" + * \ + * [t b] "foo" + * / \ + * "foot" (e) (a) "foob" + * / \ + * "foote" (r) (r) "fooba" + * / \ + * "footer" [] [] "foobar" + * + * However, this implementation implements a very common optimization where + * successive nodes having a single child are "compressed" into the node + * itself as a string of characters, each representing a next-level child, + * and only the link to the node representing the last character node is + * provided inside the representation. So the above representation is turend + * into: + * + * ["foo"] "" + * | + * [t b] "foo" + * / \ + * "foot" ("er") ("ar") "foob" + * / \ + * "footer" [] [] "foobar" + * + * However this optimization makes the implementation a bit more complex. 
+ * For instance if a key "first" is added in the above radix tree, a + * "node splitting" operation is needed, since the "foo" prefix is no longer + * composed of nodes having a single child one after the other. This is the + * above tree and the resulting node splitting after this event happens: + * + * + * (f) "" + * / + * (i o) "f" + * / \ + * "firs" ("rst") (o) "fo" + * / \ + * "first" [] [t b] "foo" + * / \ + * "foot" ("er") ("ar") "foob" + * / \ + * "footer" [] [] "foobar" + * + * Similarly after deletion, if a new chain of nodes having a single child + * is created (the chain must also not include nodes that represent keys), + * it must be compressed back into a single node. + * + */ + +#define RAX_NODE_MAX_SIZE ((1 << 29) - 1) +typedef struct raxNode { + uint32_t iskey : 1; /* Does this node contain a key? */ + uint32_t isnull : 1; /* Associated value is NULL (don't store it). */ + uint32_t iscompr : 1; /* Node is compressed. */ + uint32_t size : 29; /* Number of children, or compressed string len. */ + /* Data layout is as follows: + * + * If node is not compressed we have 'size' bytes, one for each children + * character, and 'size' raxNode pointers, point to each child node. + * Note how the character is not stored in the children but in the + * edge of the parents: + * + * [header iscompr=0][abc][a-ptr][b-ptr][c-ptr](value-ptr?) + * + * if node is compressed (iscompr bit is 1) the node has 1 children. + * In that case the 'size' bytes of the string stored immediately at + * the start of the data section, represent a sequence of successive + * nodes linked one after the other, for which only the last one in + * the sequence is actually represented as a node, and pointed to by + * the current compressed node. + * + * [header iscompr=1][xyz][z-ptr](value-ptr?) + * + * Both compressed and not compressed nodes can represent a key + * with associated data in the radix tree at any level (not just terminal + * nodes). 
+ * + * If the node has an associated key (iskey=1) and is not NULL + * (isnull=0), then after the raxNode pointers poiting to the + * children, an additional value pointer is present (as you can see + * in the representation above as "value-ptr" field). + */ + unsigned char data[]; +} raxNode; + +typedef struct rax { + raxNode *head; + uint64_t numele; + uint64_t numnodes; +} rax; + +/* Stack data structure used by raxLowWalk() in order to, optionally, return + * a list of parent nodes to the caller. The nodes do not have a "parent" + * field for space concerns, so we use the auxiliary stack when needed. */ +#define RAX_STACK_STATIC_ITEMS 32 +typedef struct raxStack { + void **stack; /* Points to static_items or an heap allocated array. */ + size_t items, maxitems; /* Number of items contained and total space. */ + /* Up to RAXSTACK_STACK_ITEMS items we avoid to allocate on the heap + * and use this static array of pointers instead. */ + void *static_items[RAX_STACK_STATIC_ITEMS]; + int oom; /* True if pushing into this stack failed for OOM at some point. */ +} raxStack; + +/* Optional callback used for iterators and be notified on each rax node, + * including nodes not representing keys. If the callback returns true + * the callback changed the node pointer in the iterator structure, and the + * iterator implementation will have to replace the pointer in the radix tree + * internals. This allows the callback to reallocate the node to perform + * very special operations, normally not needed by normal applications. + * + * This callback is used to perform very low level analysis of the radix tree + * structure, scanning each possible node (but the root node), or in order to + * reallocate the nodes to reduce the allocation fragmentation (this is the + * Redis application for this callback). 
+ * + * This is currently only supported in forward iterations (raxNext) */ +typedef int (*raxNodeCallback)(raxNode **noderef); + +/* Radix tree iterator state is encapsulated into this data structure. */ +#define RAX_ITER_STATIC_LEN 128 +#define RAX_ITER_JUST_SEEKED \ + (1 << 0) /* Iterator was just seeked. Return current \ + element for the first iteration and \ + clear the flag. */ +#define RAX_ITER_EOF (1 << 1) /* End of iteration reached. */ +#define RAX_ITER_SAFE \ + (1 << 2) /* Safe iterator, allows operations while \ + iterating. But it is slower. */ +typedef struct raxIterator { + int flags; + rax *rt; /* Radix tree we are iterating. */ + unsigned char *key; /* The current string. */ + void *data; /* Data associated to this key. */ + size_t key_len; /* Current key length. */ + size_t key_max; /* Max key len the current key buffer can hold. */ + unsigned char key_static_string[RAX_ITER_STATIC_LEN]; + raxNode *node; /* Current node. Only for unsafe iteration. */ + raxStack stack; /* Stack used for unsafe iteration. */ + raxNodeCallback node_cb; /* Optional node callback. Normally set to NULL. */ +} raxIterator; + +/* A special pointer returned for not found items. */ +extern void *raxNotFound; + +/* Exported API. 
*/ +rax *raxNew(void); +int raxInsert(rax *rax, unsigned char *s, size_t len, void *data, void **old); +int raxTryInsert(rax *rax, unsigned char *s, size_t len, void *data, void **old); +int raxRemove(rax *rax, unsigned char *s, size_t len, void **old); +void *raxFind(rax *rax, unsigned char *s, size_t len); +void raxFree(rax *rax); +void raxFreeWithCallback(rax *rax, void (*free_callback)(void *)); +void raxStart(raxIterator *it, rax *rt); +int raxSeek(raxIterator *it, const char *op, unsigned char *ele, size_t len); +int raxNext(raxIterator *it); +int raxPrev(raxIterator *it); +int raxRandomWalk(raxIterator *it, size_t steps); +int raxCompare(raxIterator *iter, const char *op, unsigned char *key, size_t key_len); +void raxStop(raxIterator *it); +int raxEOF(raxIterator *it); +void raxShow(rax *rax); +uint64_t raxSize(rax *rax); +unsigned long raxTouch(raxNode *n); +void raxSetDebugMsg(int onoff); + +/* Internal API. May be used by the node callback in order to access rax nodes + * in a low level way, so this function is exported as well. */ +void raxSetData(raxNode *n, void *data); + +#endif diff --git a/src/util/rax_malloc.h b/src/util/rax_malloc.h new file mode 100644 index 000000000..c4e92199e --- /dev/null +++ b/src/util/rax_malloc.h @@ -0,0 +1,43 @@ +/* Rax -- A radix tree implementation. + * + * Copyright (c) 2017, Salvatore Sanfilippo + * All rights reserved. + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * * Redistributions of source code must retain the above copyright notice, + * this list of conditions and the following disclaimer. + * * Redistributions in binary form must reproduce the above copyright + * notice, this list of conditions and the following disclaimer in the + * documentation and/or other materials provided with the distribution. 
+ * * Neither the name of Redis nor the names of its contributors may be used + * to endorse or promote products derived from this software without + * specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE + * ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE + * LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR + * CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF + * SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS + * INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN + * CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) + * ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE + * POSSIBILITY OF SUCH DAMAGE. + */ + +/* Allocator selection. + * + * This file is used in order to change the Rax allocator at compile time. + * Just define the following defines to what you want to use. Also add + * the include of your alternate allocator if needed (not needed in order + * to use the default libc allocator). 
*/ + +#ifndef RAX_ALLOC_H +#define RAX_ALLOC_H +#define rax_malloc malloc +#define rax_realloc realloc +#define rax_free free +#endif diff --git a/src/util/string_utils.c b/src/util/string_utils.c index 6b567c3ed..0df2ee28f 100644 --- a/src/util/string_utils.c +++ b/src/util/string_utils.c @@ -1,6 +1,7 @@ #include "string_utils.h" #include "dict.h" #include +#include #include "util/redisai_memory.h" RedisModuleString *RAI_HoldString(RedisModuleString *str) { @@ -52,3 +53,17 @@ void RAI_RStringsKeyDestructor(void *privdata, void *key) { void *RAI_RStringsKeyDup(void *privdata, const void *key) { return RedisModule_CreateStringFromString(NULL, (RedisModuleString *)key); } + +void String_ToUpper(const char *str, char *upper, size_t *upper_len) { + size_t str_len = strlen(str); + // Avoid overflow + RedisModule_Assert(*upper_len >= str_len); + + // Update the upper string buffer len. + *upper_len = str_len; + + for (size_t i = 0; i < str_len; i++) { + upper[i] = (char)toupper(str[i]); + } + upper[str_len] = 0; +} diff --git a/src/util/string_utils.h b/src/util/string_utils.h index 835fc45e6..d2c60614b 100644 --- a/src/util/string_utils.h +++ b/src/util/string_utils.h @@ -2,6 +2,7 @@ #include "dict.h" RedisModuleString *RAI_HoldString(RedisModuleString *str); +void String_ToUpper(const char *str, char *upper, size_t *upper_len); uint64_t RAI_StringsHashFunction(const void *key); int RAI_StringsKeyCompare(void *privdata, const void *key1, const void *key2); From c3a45e9fc80c2b01048d94378b2238dcdc3058a2 Mon Sep 17 00:00:00 2001 From: alonre24 Date: Sun, 30 May 2021 14:58:22 +0300 Subject: [PATCH 05/27] Refactor - do not use rax, extend onnxRunSessions array whenever a new device is introduced and use rwlock to synchronise. 
--- src/backends/backends.c | 4 +- src/backends/onnxruntime.c | 10 ++-- src/backends/onnxruntime.h | 2 +- src/config/config.h | 2 +- src/execution/DAG/dag_execute.c | 19 +----- src/execution/background_workers.c | 55 ++++++++---------- src/execution/background_workers.h | 14 ++--- src/execution/onnx_timeout.c | 92 ++++++++++++++---------------- src/execution/onnx_timeout.h | 31 +++++++--- src/redisai.c | 58 ++++++++++--------- 10 files changed, 138 insertions(+), 149 deletions(-) diff --git a/src/backends/backends.c b/src/backends/backends.c index 691e82e1e..414f70ea3 100644 --- a/src/backends/backends.c +++ b/src/backends/backends.c @@ -22,7 +22,7 @@ int RAI_GetApi(const char *func_name, void **targetPtrPtr) { if (strcmp("ThreadIdKey", func_name) == 0) { - *targetPtrPtr = GetQueueThreadIdKey; + *targetPtrPtr = GetThreadId; } else if (strcmp("NumThreadsPerQueue", func_name) == 0) { *targetPtrPtr = GetNumThreadsPerQueue; } else { @@ -496,7 +496,7 @@ int RAI_LoadBackend_ONNXRuntime(RedisModuleCtx *ctx, const char *path) { } backend.add_new_device = - (int (*)(const char *))(unsigned long)dlsym(handle, "AddDeviceToGlobalRunSessions"); + (int (*)(const char *))(unsigned long)dlsym(handle, "ExtendGlobalRunSessions"); if (backend.add_new_device == NULL) { dlclose(handle); RedisModule_Log(ctx, "warning", diff --git a/src/backends/onnxruntime.c b/src/backends/onnxruntime.c index 581de22f5..4b24b7ecc 100644 --- a/src/backends/onnxruntime.c +++ b/src/backends/onnxruntime.c @@ -94,7 +94,7 @@ int RAI_InitBackendORT(int (*get_api_fn)(const char *, void *), get_api_fn("RedisModule_MallocSize", ((void **)&RedisModule_MallocSize)); // Export RedisAI callbacks. - get_api_fn_rai("ThreadIdKey", ((void **)&RedisAI_ThreadIdKey)); + get_api_fn_rai("ThreadIdKey", ((void **)&RedisAI_ThreadId)); get_api_fn_rai("NumThreadsPerQueue", ((void **)&RedisAI_NumThreadsPerQueue)); // Create a global array of onnx runSessions, with an entry for every working thread. 
@@ -569,13 +569,13 @@ int RAI_ModelRunORT(RAI_ModelRunCtx **mctxs, RAI_Error *error) { } ONNX_VALIDATE_STATUS(ort->CreateRunOptions(&run_options)); - // Set the created run option in the global RunSessions and return it. - OnnxRunSessionCtx *run_session_ctx = - SetGetRunSessionCtx(mctxs[0]->model->devicestr, run_options); + // Set the created run option in the global RunSessions and save its index. + size_t run_session_index; + SetRunSessionCtx(run_options, &run_session_index); ONNX_VALIDATE_STATUS(ort->Run(session, run_options, input_names, (const OrtValue *const *)inputs, n_input_nodes, output_names, n_output_nodes, outputs)); - ClearRunSessionCtx(run_session_ctx); + ClearRunSessionCtx(run_session_index); run_options = NULL; for (uint32_t i = 0; i < ninputs; i++) { diff --git a/src/backends/onnxruntime.h b/src/backends/onnxruntime.h index eae2aec8b..27cfa8037 100644 --- a/src/backends/onnxruntime.h +++ b/src/backends/onnxruntime.h @@ -9,7 +9,7 @@ unsigned long long RAI_GetMemoryInfoORT(void); unsigned long long RAI_GetMemoryAccessORT(void); -pthread_key_t (*RedisAI_ThreadIdKey)(void); +pthread_key_t (*RedisAI_ThreadId)(void); long long (*RedisAI_NumThreadsPerQueue)(void); int RAI_InitBackendORT(int (*get_api_fn)(const char *, void *), diff --git a/src/config/config.h b/src/config/config.h index a05391810..d8fc461fe 100644 --- a/src/config/config.h +++ b/src/config/config.h @@ -19,7 +19,7 @@ typedef enum { RAI_DEVICE_CPU = 0, RAI_DEVICE_GPU = 1 } RAI_Device; //#define RAI_COPY_RUN_INPUT #define RAI_COPY_RUN_OUTPUT #define RAI_PRINT_BACKEND_ERRORS -#define REDISAI_DEFAULT_THREADS_PER_QUEUE 4 +#define REDISAI_DEFAULT_THREADS_PER_QUEUE 1 #define REDISAI_DEFAULT_INTRA_OP_PARALLELISM 0 #define REDISAI_DEFAULT_INTER_OP_PARALLELISM 0 #define REDISAI_DEFAULT_MODEL_CHUNK_SIZE 535822336 // (511 * 1024 * 1024) diff --git a/src/execution/DAG/dag_execute.c b/src/execution/DAG/dag_execute.c index 78e5390a3..1b3e88d49 100644 --- a/src/execution/DAG/dag_execute.c +++ 
b/src/execution/DAG/dag_execute.c @@ -106,24 +106,7 @@ int DAG_InsertDAGToQueue(RedisAI_RunInfo *rinfo) { RunQueueInfo **run_queues_info = array_new(RunQueueInfo *, ndevices); for (long long i = 0; i < ndevices; i++) { const char *device_str = devices[i]; - if (!IsRunQueueExists(device_str) == REDISMODULE_ERR) { - // A device run queue was not created properly, so we free everything, - // set an error and finish. - array_free(devices); - for (int j = 0; j < ndevices; j++) { - RAI_DagRunInfoFreeShallowCopy(rinfo_copies[j]); - } - array_free(rinfo_copies); - array_free(run_queues_info); - RAI_SetError(rinfo->err, RAI_EDAGRUN, "ERR Queue not initialized for device"); - return REDISMODULE_ERR; - } - - size_t device_str_len = strlen(device_str); - char upper_device_str[device_str_len + 1]; - String_ToUpper(device_str, upper_device_str, &device_str_len); - RunQueueInfo *run_queue_info = - raxFind(RunQueues, (unsigned char *)upper_device_str, device_str_len); + RunQueueInfo *run_queue_info = GetRunQueueInfo(device_str); run_queues_info = array_append(run_queues_info, run_queue_info); } for (long long i = 0; i < ndevices; i++) { diff --git a/src/execution/background_workers.c b/src/execution/background_workers.c index 99f103e98..9e802410c 100644 --- a/src/execution/background_workers.c +++ b/src/execution/background_workers.c @@ -59,25 +59,18 @@ RunQueueInfo *CreateRunQueue(const char *device_str) { pthread_mutex_init(&(run_queue_info->run_queue_mutex), NULL); run_queue_info->threads = (pthread_t *)RedisModule_Alloc(sizeof(pthread_t) * ThreadPoolSizePerQueue); - pthread_key_create(&(run_queue_info->thread_id_key), RedisModule_Free); - // Save device with its associate run queue info in rax. - /*todo: this should be protected from parallel writs. 
can add a lock, - or calling this function from main thread only in modelstore command.*/ - if (raxInsert(RunQueues, (unsigned char *)upper_device_str, device_str_len, run_queue_info, - NULL) != 1) { + // Save device with its associate run queue info in the dictionary. + if (AI_dictAdd(RunQueues, upper_device_str, run_queue_info) != DICT_OK) { RunQueueInfoFree(run_queue_info); return NULL; } // Create worker threads. for (int i = 0; i < ThreadPoolSizePerQueue; i++) { - WorkerThreadInfo *thread_info = RedisModule_Alloc(sizeof(WorkerThreadInfo)); - thread_info->run_queue_info = run_queue_info; - thread_info->id = i; if (pthread_create(&(run_queue_info->threads[i]), NULL, RedisAI_Run_ThreadMain, - thread_info) != 0) { - raxRemove(RunQueues, (unsigned char *)upper_device_str, device_str_len, NULL); + run_queue_info) != 0) { + AI_dictDelete(RunQueues, upper_device_str); RunQueueInfoFree(run_queue_info); return NULL; } @@ -90,25 +83,27 @@ RunQueueInfo *CreateRunQueue(const char *device_str) { return run_queue_info; } -bool IsRunQueueExists(const char *device_str) { +RunQueueInfo *GetRunQueueInfo(const char *device_str) { size_t device_str_len = strlen(device_str); char upper_device_str[device_str_len + 1]; String_ToUpper(device_str, upper_device_str, &device_str_len); - if (raxFind(RunQueues, (unsigned char *)upper_device_str, device_str_len) == raxNotFound) { - return false; - } - return true; + AI_dictEntry *entry = AI_dictFind(RunQueues, upper_device_str); + RedisModule_Assert(entry != NULL); + return AI_dictGetVal(entry); } -pthread_key_t GetQueueThreadIdKey(const char *device_str) { +bool IsRunQueueExists(const char *device_str) { size_t device_str_len = strlen(device_str); char upper_device_str[device_str_len + 1]; String_ToUpper(device_str, upper_device_str, &device_str_len); + if (AI_dictFind(RunQueues, upper_device_str) == NULL) { + return false; + } + return true; +} - RunQueueInfo *run_queue_info = - (RunQueueInfo *)raxFind(RunQueues, (unsigned char 
*)upper_device_str, device_str_len); - RedisModule_Assert(run_queue_info != raxNotFound); - return run_queue_info->thread_id_key; +uintptr_t GetThreadId() { + return *(uintptr_t *)pthread_getspecific(thread_id_key); } long long GetNumThreadsPerQueue() { return ThreadPoolSizePerQueue; } @@ -125,18 +120,16 @@ void RunQueueInfoFree(RunQueueInfo *run_queue_info) { } pthread_mutex_destroy(&(run_queue_info->run_queue_mutex)); pthread_cond_destroy(&(run_queue_info->queue_condition_var)); - pthread_key_delete(run_queue_info->thread_id_key); RedisModule_Free(run_queue_info); } /** - * @brief Save the id for some working thread in thread local storage. Every - * device has a designated id key saved within its run_queue_info, which is used - * for storing and retrieving the id in the thread local storage. + * @brief Save the id for some working thread in thread local storage. */ -static void _SaveThreadId(pthread_key_t thread_id_key, int id) { - int *id_value = RedisModule_Alloc(sizeof(int)); - *id_value = id; +static void _SaveThreadId() { + uintptr_t *id_value = RedisModule_Alloc(sizeof(uintptr_t)); + // Let the current thread have the next available id, and increase the counter. 
+ *id_value = __atomic_fetch_add(&BGWorkersCounter, 1, __ATOMIC_RELAXED); pthread_setspecific(thread_id_key, id_value); } @@ -330,10 +323,8 @@ static RedisAI_RunInfo **_BGThread_BatchOperations(RunQueueInfo *run_queue_info, } void *RedisAI_Run_ThreadMain(void *arg) { - WorkerThreadInfo *thread_info = (WorkerThreadInfo *)arg; - RunQueueInfo *run_queue_info = thread_info->run_queue_info; - _SaveThreadId(run_queue_info->thread_id_key, thread_info->id); - RedisModule_Free(thread_info); + _SaveThreadId(); + RunQueueInfo *run_queue_info = (RunQueueInfo *)arg; RedisAI_RunInfo **batch_rinfo = array_new(RedisAI_RunInfo *, 1); pthread_mutex_lock(&run_queue_info->run_queue_mutex); diff --git a/src/execution/background_workers.h b/src/execution/background_workers.h index 147f07c82..8457d5016 100644 --- a/src/execution/background_workers.h +++ b/src/execution/background_workers.h @@ -30,29 +30,27 @@ #include "util/rax.h" #include "util/queue.h" -rax *RunQueues; +AI_dict *RunQueues; long long ThreadPoolSizePerQueue; +uintptr_t BGWorkersCounter; +pthread_key_t thread_id_key; // A key for getting the thread id from its local storage. typedef struct RunQueueInfo { pthread_mutex_t run_queue_mutex; pthread_cond_t queue_condition_var; queue *run_queue; pthread_t *threads; - pthread_key_t thread_id_key; // A key for getting the thread id from its local storage. 
char *device_str; } RunQueueInfo; -typedef struct WorkerThreadInfo { - RunQueueInfo *run_queue_info; - int id; -} WorkerThreadInfo; - void RunQueueInfoFree(RunQueueInfo *info); RunQueueInfo *CreateRunQueue(const char *device_str); bool IsRunQueueExists(const char *device_str); -pthread_key_t GetQueueThreadIdKey(const char *device_str); +RunQueueInfo *GetRunQueueInfo(const char *device_str); + +uintptr_t GetThreadId(void); long long GetNumThreadsPerQueue(void); diff --git a/src/execution/onnx_timeout.c b/src/execution/onnx_timeout.c index 269a6f697..fb0441232 100644 --- a/src/execution/onnx_timeout.c +++ b/src/execution/onnx_timeout.c @@ -17,74 +17,70 @@ static long long _mstime(void) { } int CreateGlobalOnnxRunSessions() { - OnnxRunSessions = raxNew(); - if (OnnxRunSessions == NULL) { - return REDISMODULE_ERR; - } - return AddDeviceToGlobalRunSessions("CPU"); + onnx_global_run_sessions = RedisModule_Alloc(sizeof(struct OnnxGlobalRunSessions)); + OnnxRunSessionCtx **onnx_run_sessions = + array_new(OnnxRunSessionCtx *, RedisAI_NumThreadsPerQueue()); + onnx_global_run_sessions->OnnxRunSessions = onnx_run_sessions; + pthread_rwlock_init(&(onnx_global_run_sessions->rwlock), NULL); + return ExtendGlobalRunSessions("CPU"); // Add entries for CPU threads. } -int AddDeviceToGlobalRunSessions(const char *device) { +int ExtendGlobalRunSessions(const char *device_str) { + + // Acquire write lock, as we might reallocate the array while extending it. + pthread_rwlock_wrlock(&(onnx_global_run_sessions->rwlock)); + OnnxRunSessionCtx **run_sessions_array = onnx_global_run_sessions->OnnxRunSessions; + // Extend the array with an entry for every working thread on the new device, initialized to NULL. size_t size = RedisAI_NumThreadsPerQueue(); - // Create array with an entry for every working thread, initialized to NULL. 
- OnnxRunSessionCtx **device_run_sessions = array_new(OnnxRunSessionCtx *, size); for (size_t i = 0; i < size; i++) { OnnxRunSessionCtx *entry = RedisModule_Calloc(1, sizeof(OnnxRunSessionCtx)); - device_run_sessions = array_append(device_run_sessions, entry); - } - // Add the array to the global rax that holds onnx run sessions per device. - size_t device_str_len = strlen(device); - char upper_device_str[device_str_len + 1]; - String_ToUpper(device, upper_device_str, &device_str_len); - if (raxInsert(OnnxRunSessions, (unsigned char *)upper_device_str, device_str_len, - device_run_sessions, NULL) != 1) { - return REDISMODULE_ERR; + run_sessions_array = array_append(run_sessions_array, entry); } + pthread_rwlock_unlock(&(onnx_global_run_sessions->rwlock)); return REDISMODULE_OK; } void OnnxEnforceTimeoutCallback(RedisModuleCtx *ctx, RedisModuleEvent eid, uint64_t subevent, void *data) { const OrtApi *ort = OrtGetApiBase()->GetApi(1); - - raxIterator rax_it; - raxStart(&rax_it, OnnxRunSessions); - raxSeek(&rax_it, "^", NULL, 0); - - // Go over all the possible existing run sessions for every device. - while (raxNext(&rax_it)) { - OnnxRunSessionCtx **onnx_run_sessions_per_device = rax_it.data; - size_t threads_per_device = array_len(onnx_run_sessions_per_device); - for (size_t i = 0; i < threads_per_device; i++) { - if (onnx_run_sessions_per_device[i]->runOptions == NULL) { - continue; - } - long long curr_time = _mstime(); - // Check if a sessions is running for too long, and kill it if so. 
- if (curr_time - onnx_run_sessions_per_device[i]->queuingTime > ONNX_MAX_RUNTIME) { - ort->RunOptionsSetTerminate(onnx_run_sessions_per_device[i]->runOptions); - } + pthread_rwlock_rdlock(&(onnx_global_run_sessions->rwlock)); + OnnxRunSessionCtx **run_sessions_ctx = onnx_global_run_sessions->OnnxRunSessions; + size_t len = array_len(run_sessions_ctx); + for (size_t i = 0; i < len; i++) { + if (run_sessions_ctx[i]->runOptions == NULL) { + continue; + } + long long curr_time = _mstime(); + // Check if a sessions is running for too long, and kill it if so. + if (curr_time - run_sessions_ctx[i]->queuingTime > ONNX_MAX_RUNTIME) { + ort->RunOptionsSetTerminate(run_sessions_ctx[i]->runOptions); } } + pthread_rwlock_unlock(&(onnx_global_run_sessions->rwlock)); } -OnnxRunSessionCtx *SetGetRunSessionCtx(const char *device, OrtRunOptions *new_run_options) { - - int *thread_ind = (int *)pthread_getspecific(RedisAI_ThreadIdKey()); +void SetRunSessionCtx(OrtRunOptions *new_run_options, + size_t *run_session_index) { - OnnxRunSessionCtx **device_run_sessions = - raxFind(OnnxRunSessions, (unsigned char *)device, strlen(device)); - RedisModule_Assert(device_run_sessions != raxNotFound); - RedisModule_Assert(device_run_sessions[*thread_ind]->runOptions == NULL); + pthread_rwlock_rdlock(&(onnx_global_run_sessions->rwlock)); + // Get the thread index (which is the correspondent index in the global sessions array). + *run_session_index = (size_t)RedisAI_ThreadId(); + OnnxRunSessionCtx *entry = + onnx_global_run_sessions->OnnxRunSessions[*run_session_index]; + RedisModule_Assert(entry->runOptions == NULL); - device_run_sessions[*thread_ind]->runOptions = new_run_options; - device_run_sessions[*thread_ind]->queuingTime = _mstime(); - return device_run_sessions[*thread_ind]; + // Update the entry with the current session data. 
+ entry->runOptions = new_run_options; + entry->queuingTime = _mstime(); + pthread_rwlock_unlock(&(onnx_global_run_sessions->rwlock)); } -void ClearRunSessionCtx(OnnxRunSessionCtx *run_session_ctx) { +void ClearRunSessionCtx(size_t run_session_index) { const OrtApi *ort = OrtGetApiBase()->GetApi(1); - ort->ReleaseRunOptions(run_session_ctx->runOptions); - run_session_ctx->runOptions = NULL; + pthread_rwlock_rdlock(&(onnx_global_run_sessions->rwlock)); + OnnxRunSessionCtx *entry = onnx_global_run_sessions->OnnxRunSessions[run_session_index]; + ort->ReleaseRunOptions(entry->runOptions); + entry->runOptions = NULL; + pthread_rwlock_unlock(&(onnx_global_run_sessions->rwlock)); } diff --git a/src/execution/onnx_timeout.h b/src/execution/onnx_timeout.h index dbb585432..96f9b76e6 100644 --- a/src/execution/onnx_timeout.h +++ b/src/execution/onnx_timeout.h @@ -13,17 +13,34 @@ typedef struct OnnxRunSessionCtx { OrtRunOptions *runOptions; } OnnxRunSessionCtx; -// This is a global rax that holds an array of OnnxRunSessionCtx for every device -// that onnx models may run on. -rax *OnnxRunSessions; - +// This is a global array of OnnxRunSessionCtx. Contains an entry for every thread +// (on every device) that onnx models may run on. +typedef struct OnnxGlobalRunSessions { + OnnxRunSessionCtx **OnnxRunSessions; + pthread_rwlock_t rwlock; +} OnnxGlobalRunSessions; + +OnnxGlobalRunSessions *onnx_global_run_sessions; + +/** + * @brief This is called whenever Onnx backend is loaded. It creates the global + * OnnxGlobalRunSessions structure with entry-per-thread (for CPU threads at first), + * so that every thread will have a designated entry to update with the onnx session + * that it's going to run. + */ int CreateGlobalOnnxRunSessions(void); -int AddDeviceToGlobalRunSessions(const char *device); +/** + * @brief This is called whenever RedisAI gets a request to store a model that run + * on a new device, and creates some more working thread, as configured in + * ThreadPerQueue. 
Thus, the global array of onnx sessions that has an + * entry-per-thread is extended accordingly. + */ +int ExtendGlobalRunSessions(const char *device_str); void OnnxEnforceTimeoutCallback(RedisModuleCtx *ctx, RedisModuleEvent eid, uint64_t subevent, void *data); -OnnxRunSessionCtx *SetGetRunSessionCtx(const char *device, OrtRunOptions *new_run_options); +void SetRunSessionCtx(OrtRunOptions *new_run_options, size_t *run_session_index); -void ClearRunSessionCtx(OnnxRunSessionCtx *run_session_ctx); +void ClearRunSessionCtx(size_t run_session_index); diff --git a/src/redisai.c b/src/redisai.c index 6c173303a..f52f1c7a0 100644 --- a/src/redisai.c +++ b/src/redisai.c @@ -1264,42 +1264,45 @@ void RAI_moduleInfoFunc(RedisModuleInfoCtx *ctx, int for_crash_report) { RedisModule_FreeString(NULL, main_thread_used_cpu_sys); RedisModule_FreeString(NULL, main_thread_used_cpu_user); - raxIterator rax_it; - raxStart(&rax_it, RunQueues); - raxSeek(&rax_it, "^", NULL, 0); - while (raxNext(&rax_it)) { - char *queue_name = (char *)rax_it.key; - RunQueueInfo *run_queue_info = (RunQueueInfo *)rax_it.data; - for (int i = 0; i < ThreadPoolSizePerQueue; i++) { - pthread_t current_bg_threads = run_queue_info->threads[i]; - struct timespec ts; - clockid_t cid; - RedisModuleString *queue_used_cpu_total = RedisModule_CreateStringPrintf( - NULL, "queue_%s_bthread_n%d_used_cpu_total", queue_name, i + 1); - RedisModuleString *bthread_used_cpu_total = NULL; + AI_dictIterator *iter = AI_dictGetSafeIterator(RunQueues); + AI_dictEntry *entry = AI_dictNext(iter); + while (entry) { + char *queue_name = (char *)AI_dictGetKey(entry); + RunQueueInfo *run_queue_info = (RunQueueInfo *)AI_dictGetVal(entry); + if (run_queue_info) { + for (int i = 0; i < ThreadPoolSizePerQueue; i++) { + pthread_t current_bg_threads = run_queue_info->threads[i]; + struct timespec ts; + clockid_t cid; + RedisModuleString *queue_used_cpu_total = RedisModule_CreateStringPrintf( + NULL, "queue_%s_bthread_n%d_used_cpu_total", 
queue_name, i + 1); + RedisModuleString *bthread_used_cpu_total = NULL; #if (!defined(_POSIX_C_SOURCE) && !defined(_XOPEN_SOURCE)) || defined(_DARWIN_C_SOURCE) || \ defined(__cplusplus) - const int status = -1; + const int status = -1; #else - const int status = pthread_getcpuclockid(current_bg_threads, &cid); + const int status = pthread_getcpuclockid(current_bg_threads, &cid); #endif - if (status != 0) { - bthread_used_cpu_total = RedisModule_CreateStringPrintf(NULL, "N/A"); - } else { - if (clock_gettime(cid, &ts) == -1) { + if (status != 0) { bthread_used_cpu_total = RedisModule_CreateStringPrintf(NULL, "N/A"); } else { - bthread_used_cpu_total = RedisModule_CreateStringPrintf( - NULL, "%ld.%06ld", (long)ts.tv_sec, (long)(ts.tv_nsec / 1000)); + if (clock_gettime(cid, &ts) == -1) { + bthread_used_cpu_total = RedisModule_CreateStringPrintf(NULL, "N/A"); + } else { + bthread_used_cpu_total = RedisModule_CreateStringPrintf( + NULL, "%ld.%06ld", (long)ts.tv_sec, (long)(ts.tv_nsec / 1000)); + } } + RedisModule_InfoAddFieldString( + ctx, (char *)RedisModule_StringPtrLen(queue_used_cpu_total, NULL), + bthread_used_cpu_total); + RedisModule_FreeString(NULL, queue_used_cpu_total); + RedisModule_FreeString(NULL, bthread_used_cpu_total); } - RedisModule_InfoAddFieldString( - ctx, (char *)RedisModule_StringPtrLen(queue_used_cpu_total, NULL), - bthread_used_cpu_total); - RedisModule_FreeString(NULL, queue_used_cpu_total); - RedisModule_FreeString(NULL, bthread_used_cpu_total); } + entry = AI_dictNext(iter); } + AI_dictReleaseIterator(iter); } int RedisModule_OnLoad(RedisModuleCtx *ctx, RedisModuleString **argv, int argc) { @@ -1455,7 +1458,8 @@ int RedisModule_OnLoad(RedisModuleCtx *ctx, RedisModuleString **argv, int argc) RAI_loadTimeConfig(ctx, argv, argc); - RunQueues = raxNew(); + RunQueues = AI_dictCreate(&AI_dictTypeHeapStrings, NULL); + pthread_key_create(&thread_id_key, RedisModule_Free); RunQueueInfo *cpu_run_queue_info = CreateRunQueue("CPU"); if 
(cpu_run_queue_info == NULL) { RedisModule_Log(ctx, "warning", "RedisAI could not initialize run queue for CPU"); From 26848529a1f1c086148bb925854bbe36726e14a0 Mon Sep 17 00:00:00 2001 From: alonre24 Date: Sun, 30 May 2021 16:25:37 +0300 Subject: [PATCH 06/27] Refactor backends loading --- src/backends/backends.c | 339 +++++++----------------- src/backends/backends.h | 5 +- src/backends/onnxruntime.c | 7 +- src/backends/onnxruntime.h | 3 +- src/execution/onnx_timeout.c | 6 +- src/execution/onnx_timeout.h | 22 +- src/redisai.c | 4 +- src/serialization/AOF/rai_aof_rewrite.c | 2 +- 8 files changed, 135 insertions(+), 253 deletions(-) diff --git a/src/backends/backends.c b/src/backends/backends.c index 414f70ea3..3031520aa 100644 --- a/src/backends/backends.c +++ b/src/backends/backends.c @@ -15,10 +15,17 @@ #include #include #include -#include "execution/onnx_timeout.h" #include "redismodule.h" +static bool _ValidateAPICreated(RedisModuleCtx *ctx, void *func_ptr, const char *func_name) { + if (func_ptr == NULL) { + RedisModule_Log(ctx, "warning", "Backend does not export %s", func_name); + return false; + } + return true; +} + int RAI_GetApi(const char *func_name, void **targetPtrPtr) { if (strcmp("ThreadIdKey", func_name) == 0) { @@ -26,7 +33,7 @@ int RAI_GetApi(const char *func_name, void **targetPtrPtr) { } else if (strcmp("NumThreadsPerQueue", func_name) == 0) { *targetPtrPtr = GetNumThreadsPerQueue; } else { - return REDISMODULE_ERR; + return RedisModule_GetApi(func_name, targetPtrPtr); } return REDISMODULE_OK; } @@ -59,7 +66,7 @@ RedisModuleString *RAI_GetBackendsPath(RedisModuleCtx *ctx) { return backends_path; } -const char *RAI_BackendName(int backend) { +const char *GetBackendName(RAI_Backend backend) { switch (backend) { case RAI_BACKEND_TENSORFLOW: return "TF"; @@ -80,24 +87,17 @@ int RAI_LoadBackend_TensorFlow(RedisModuleCtx *ctx, const char *path) { } void *handle = dlopen(path, RTLD_NOW | RTLD_LOCAL); - if (handle == NULL) { RedisModule_Log(ctx, 
"warning", "Could not load TF backend from %s: %s", path, dlerror()); return REDISMODULE_ERR; } - - RAI_LoadedBackend backend = {0}; + RAI_LoadedBackend backend = {0}; // Initialize all the callbacks to NULL. int (*init_backend)(int (*)(const char *, void *)); init_backend = (int (*)(int (*)(const char *, void *)))(unsigned long)dlsym(handle, "RAI_InitBackendTF"); - if (init_backend == NULL) { - dlclose(handle); - RedisModule_Log(ctx, "warning", - "Backend does not export RAI_InitBackendTF. TF backend not " - "loaded from %s", - path); - return REDISMODULE_ERR; + if (!_ValidateAPICreated(ctx, init_backend, "RAI_InitBackendTF")) { + goto error; } init_backend(RedisModule_GetApi); @@ -105,63 +105,42 @@ int RAI_LoadBackend_TensorFlow(RedisModuleCtx *ctx, const char *path) { (RAI_Model * (*)(RAI_Backend, const char *, RAI_ModelOpts, size_t, const char **, size_t, const char **, const char *, size_t, RAI_Error *))(unsigned long)dlsym(handle, "RAI_ModelCreateTF"); - if (backend.model_create_with_nodes == NULL) { - dlclose(handle); - RedisModule_Log(ctx, "warning", - "Backend does not export RAI_ModelCreateTF. TF backend not " - "loaded from %s", - path); - return REDISMODULE_ERR; + if (!_ValidateAPICreated(ctx, backend.model_create_with_nodes, "RAI_ModelCreateTF")) { + goto error; } backend.model_free = (void (*)(RAI_Model *, RAI_Error *))(unsigned long)dlsym(handle, "RAI_ModelFreeTF"); - if (backend.model_free == NULL) { - dlclose(handle); - RedisModule_Log(ctx, "warning", - "Backend does not export RAI_ModelFreeTF. TF backend not " - "loaded from %s", - path); - return REDISMODULE_ERR; + if (!_ValidateAPICreated(ctx, backend.model_free, "RAI_ModelFreeTF")) { + goto error; } backend.model_run = (int (*)(RAI_ModelRunCtx **, RAI_Error *))(unsigned long)dlsym(handle, "RAI_ModelRunTF"); - if (backend.model_run == NULL) { - dlclose(handle); - RedisModule_Log(ctx, "warning", - "Backend does not export RAI_ModelRunTF. 
TF backend not loaded from %s", - path); - return REDISMODULE_ERR; + if (!_ValidateAPICreated(ctx, backend.model_run, "RAI_ModelRunTF")) { + goto error; } backend.model_serialize = (int (*)(RAI_Model *, char **, size_t *, RAI_Error *))( (unsigned long)dlsym(handle, "RAI_ModelSerializeTF")); - if (backend.model_serialize == NULL) { - dlclose(handle); - RedisModule_Log(ctx, "warning", - "Backend does not export RAI_ModelSerializeTF. TF backend " - "not loaded from %s", - path); - return REDISMODULE_ERR; + if (!_ValidateAPICreated(ctx, backend.model_serialize, "RAI_ModelSerializeTF")) { + goto error; } backend.get_version = (const char *(*)(void))(unsigned long)dlsym(handle, "RAI_GetBackendVersionTF"); - if (backend.get_version == NULL) { - dlclose(handle); - RedisModule_Log(ctx, "warning", - "Backend does not export RAI_GetBackendVersionTF. TF backend " - "not loaded from %s", - path); - return REDISMODULE_ERR; + if (!_ValidateAPICreated(ctx, backend.get_version, "RAI_GetBackendVersionTF")) { + goto error; } RAI_backends.tf = backend; - RedisModule_Log(ctx, "notice", "TF backend loaded from %s", path); - return REDISMODULE_OK; + + error: + dlclose(handle); + RedisModule_Log(ctx, "warning", "TF backend not loaded from %s", path); + return REDISMODULE_ERR; } int RAI_LoadBackend_TFLite(RedisModuleCtx *ctx, const char *path) { @@ -177,83 +156,55 @@ int RAI_LoadBackend_TFLite(RedisModuleCtx *ctx, const char *path) { dlerror()); return REDISMODULE_ERR; } - - RAI_LoadedBackend backend = {0}; + RAI_LoadedBackend backend = {0}; // Initialize all the callbacks to NULL. int (*init_backend)(int (*)(const char *, void *)); init_backend = (int (*)(int (*)(const char *, void *)))(unsigned long)dlsym( handle, "RAI_InitBackendTFLite"); - if (init_backend == NULL) { - dlclose(handle); - RedisModule_Log(ctx, "warning", - "Backend does not export RAI_InitBackendTFLite. 
TFLITE " - "backend not loaded from %s", - path); - return REDISMODULE_ERR; + if (!_ValidateAPICreated(ctx, init_backend, "RAI_InitBackendTFLite")) { + goto error; } init_backend(RedisModule_GetApi); backend.model_create = (RAI_Model * (*)(RAI_Backend, const char *, RAI_ModelOpts, const char *, size_t, RAI_Error *))(unsigned long)dlsym(handle, "RAI_ModelCreateTFLite"); - if (backend.model_create == NULL) { - dlclose(handle); - RedisModule_Log(ctx, "warning", - "Backend does not export RAI_ModelCreateTFLite. TFLITE " - "backend not loaded from %s", - path); - return REDISMODULE_ERR; + if (!_ValidateAPICreated(ctx, backend.model_create, "RAI_ModelCreateTFLite")) { + goto error; } backend.model_free = (void (*)(RAI_Model *, RAI_Error *))(unsigned long)dlsym(handle, "RAI_ModelFreeTFLite"); - if (backend.model_free == NULL) { - dlclose(handle); - RedisModule_Log(ctx, "warning", - "Backend does not export RAI_ModelFreeTFLite. TFLITE " - "backend not loaded from %s", - path); - return REDISMODULE_ERR; + if (!_ValidateAPICreated(ctx, backend.model_free, "RAI_ModelFreeTFLite")) { + goto error; } backend.model_run = (int (*)(RAI_ModelRunCtx **, RAI_Error *))(unsigned long)dlsym( handle, "RAI_ModelRunTFLite"); - if (backend.model_run == NULL) { - dlclose(handle); - RedisModule_Log(ctx, "warning", - "Backend does not export RAI_ModelRunTFLite. TFLITE " - "backend not loaded from %s", - path); - return REDISMODULE_ERR; + if (!_ValidateAPICreated(ctx, backend.model_run, "RAI_ModelRunTFLite")) { + goto error; } backend.model_serialize = (int (*)(RAI_Model *, char **, size_t *, RAI_Error *))( unsigned long)dlsym(handle, "RAI_ModelSerializeTFLite"); - if (backend.model_serialize == NULL) { - dlclose(handle); - RedisModule_Log(ctx, "warning", - "Backend does not export RAI_ModelSerializeTFLite. 
TFLITE " - "backend not loaded from %s", - path); - return REDISMODULE_ERR; + if (!_ValidateAPICreated(ctx, backend.model_serialize, "RAI_ModelSerializeTFLite")) { + goto error; } backend.get_version = (const char *(*)(void))(unsigned long)dlsym(handle, "RAI_GetBackendVersionTFLite"); - if (backend.get_version == NULL) { - dlclose(handle); - RedisModule_Log(ctx, "warning", - "Backend does not export RAI_GetBackendVersionTFLite. TFLite backend " - "not loaded from %s", - path); - return REDISMODULE_ERR; + if (!_ValidateAPICreated(ctx, backend.get_version, "RAI_GetBackendVersionTFLite")) { + goto error; } RAI_backends.tflite = backend; - RedisModule_Log(ctx, "notice", "TFLITE backend loaded from %s", path); - return REDISMODULE_OK; + + error: + dlclose(handle); + RedisModule_Log(ctx, "warning", "TFLITE backend not loaded from %s", path); + return REDISMODULE_ERR; } int RAI_LoadBackend_Torch(RedisModuleCtx *ctx, const char *path) { @@ -263,122 +214,79 @@ int RAI_LoadBackend_Torch(RedisModuleCtx *ctx, const char *path) { } void *handle = dlopen(path, RTLD_NOW | RTLD_LOCAL); - if (handle == NULL) { RedisModule_Log(ctx, "warning", "Could not load TORCH backend from %s: %s", path, dlerror()); return REDISMODULE_ERR; } - RAI_LoadedBackend backend = {0}; + RAI_LoadedBackend backend = {0}; // Initialize all the callbacks to NULL. int (*init_backend)(int (*)(const char *, void *)); init_backend = (int (*)(int (*)(const char *, void *)))(unsigned long)dlsym( handle, "RAI_InitBackendTorch"); - if (init_backend == NULL) { - dlclose(handle); - RedisModule_Log(ctx, "warning", - "Backend does not export RAI_InitBackendTorch. 
TORCH " - "backend not loaded from %s", - path); - return REDISMODULE_ERR; + if (!_ValidateAPICreated(ctx, init_backend, "RAI_InitBackendTorch")) { + goto error; } init_backend(RedisModule_GetApi); backend.model_create = (RAI_Model * (*)(RAI_Backend, const char *, RAI_ModelOpts, const char *, size_t, RAI_Error *))(unsigned long)dlsym(handle, "RAI_ModelCreateTorch"); - if (backend.model_create == NULL) { - dlclose(handle); - RedisModule_Log(ctx, "warning", - "Backend does not export RAI_ModelCreateTorch. TORCH " - "backend not loaded from %s", - path); - return REDISMODULE_ERR; + if (!_ValidateAPICreated(ctx, backend.model_create, "RAI_ModelCreateTorch")) { + goto error; } backend.model_free = (void (*)(RAI_Model *, RAI_Error *))(unsigned long)dlsym(handle, "RAI_ModelFreeTorch"); - if (backend.model_free == NULL) { - dlclose(handle); - RedisModule_Log(ctx, "warning", - "Backend does not export RAI_ModelFreeTorch. TORCH backend " - "not loaded from %s", - path); - return REDISMODULE_ERR; + if (!_ValidateAPICreated(ctx, backend.model_free, "RAI_ModelFreeTorch")) { + goto error; } backend.model_run = (int (*)(RAI_ModelRunCtx **, RAI_Error *))(unsigned long)dlsym(handle, "RAI_ModelRunTorch"); - if (backend.model_run == NULL) { - dlclose(handle); - RedisModule_Log(ctx, "warning", - "Backend does not export RAI_ModelRunTorch. TORCH backend " - "not loaded from %s", - path); - return REDISMODULE_ERR; + if (!_ValidateAPICreated(ctx, backend.model_run, "RAI_ModelRunTorch")) { + goto error; } backend.model_serialize = (int (*)(RAI_Model *, char **, size_t *, RAI_Error *))( unsigned long)dlsym(handle, "RAI_ModelSerializeTorch"); - if (backend.model_serialize == NULL) { - dlclose(handle); - RedisModule_Log(ctx, "warning", - "Backend does not export RAI_ModelSerializeTorch. 
TORCH " - "backend not loaded from %s", - path); - return REDISMODULE_ERR; + if (!_ValidateAPICreated(ctx, backend.model_serialize, "RAI_ModelSerializeTorch")) { + goto error; } backend.script_create = (RAI_Script * (*)(const char *, const char *, RAI_Error *))( unsigned long)dlsym(handle, "RAI_ScriptCreateTorch"); - if (backend.script_create == NULL) { - dlclose(handle); - RedisModule_Log(ctx, "warning", - "Backend does not export RAI_ScriptCreateTorch. TORCH " - "backend not loaded from %s", - path); - return REDISMODULE_ERR; + if (!_ValidateAPICreated(ctx, backend.script_create, "RAI_ScriptCreateTorch")) { + goto error; } backend.script_free = (void (*)(RAI_Script *, RAI_Error *))(unsigned long)dlsym(handle, "RAI_ScriptFreeTorch"); - if (backend.script_free == NULL) { - dlclose(handle); - RedisModule_Log(ctx, "warning", - "Backend does not export RAI_ScriptFreeTorch. TORCH " - "backend not loaded from %s", - path); - return REDISMODULE_ERR; + if (!_ValidateAPICreated(ctx, backend.script_free, "RAI_ScriptFreeTorch")) { + goto error; } backend.script_run = (int (*)(RAI_ScriptRunCtx *, RAI_Error *))(unsigned long)dlsym( handle, "RAI_ScriptRunTorch"); - if (backend.script_run == NULL) { - dlclose(handle); - RedisModule_Log(ctx, "warning", - "Backend does not export RAI_ScriptRunTorch. TORCH backend " - "not loaded from %s", - path); - return REDISMODULE_ERR; + if (!_ValidateAPICreated(ctx, backend.script_run, "RAI_ScriptRunTorch")) { + goto error; } backend.get_version = (const char *(*)(void))(unsigned long)dlsym(handle, "RAI_GetBackendVersionTorch"); - if (backend.get_version == NULL) { - dlclose(handle); - RedisModule_Log(ctx, "warning", - "Backend does not export RAI_GetBackendVersionTorch. 
TORCH backend " - "not loaded from %s", - path); - return REDISMODULE_ERR; + if (!_ValidateAPICreated(ctx, backend.get_version, "RAI_GetBackendVersionTorch")) { + goto error; } RAI_backends.torch = backend; - RedisModule_Log(ctx, "notice", "TORCH backend loaded from %s", path); - return REDISMODULE_OK; + + error: + dlclose(handle); + RedisModule_Log(ctx, "warning", "TORCH backend not loaded from %s", path); + return REDISMODULE_ERR; } int RAI_LoadBackend_ONNXRuntime(RedisModuleCtx *ctx, const char *path) { @@ -393,125 +301,82 @@ int RAI_LoadBackend_ONNXRuntime(RedisModuleCtx *ctx, const char *path) { RedisModule_Log(ctx, "warning", "Could not load ONNX backend from %s: %s", path, dlerror()); return REDISMODULE_ERR; } - RAI_LoadedBackend backend = {0}; - int (*init_backend)(int (*)(const char *, void *), int (*)(const char *, void *)); - init_backend = (int (*)(int (*)(const char *, void *), int (*)(const char *, void *)))( + int (*init_backend)(int (*)(const char *, void **)); + init_backend = (int (*) (int (*)(const char *, void **)))( unsigned long)dlsym(handle, "RAI_InitBackendORT"); - if (init_backend == NULL) { - dlclose(handle); - RedisModule_Log(ctx, "warning", - "Backend does not export RAI_InitBackendORT. ONNX backend " - "not loaded from %s", - path); - return REDISMODULE_ERR; + if (!_ValidateAPICreated(ctx, init_backend, "RAI_InitBackendORT")) { + goto error; } - init_backend(RedisModule_GetApi, (int (*)(const char *, void *))RAI_GetApi); + init_backend(RAI_GetApi); backend.model_create = (RAI_Model * (*)(RAI_Backend, const char *, RAI_ModelOpts, const char *, size_t, RAI_Error *))(unsigned long)dlsym(handle, "RAI_ModelCreateORT"); - if (backend.model_create == NULL) { - dlclose(handle); - RedisModule_Log(ctx, "warning", - "Backend does not export RAI_ModelCreateORT. 
ONNX backend " - "not loaded from %s", - path); - return REDISMODULE_ERR; + if (!_ValidateAPICreated(ctx, backend.model_create, "RAI_ModelCreateORT")) { + goto error; } backend.model_free = (void (*)(RAI_Model *, RAI_Error *))(unsigned long)dlsym(handle, "RAI_ModelFreeORT"); - if (backend.model_free == NULL) { - dlclose(handle); - RedisModule_Log(ctx, "warning", - "Backend does not export RAI_ModelFreeORT. ONNX backend " - "not loaded from %s", - path); - return REDISMODULE_ERR; + if (!_ValidateAPICreated(ctx, backend.model_free, "RAI_ModelFreeORT")) { + goto error; } backend.model_run = (int (*)(RAI_ModelRunCtx **, RAI_Error *))(unsigned long)dlsym(handle, "RAI_ModelRunORT"); - if (backend.model_run == NULL) { - dlclose(handle); - RedisModule_Log(ctx, "warning", - "Backend does not export RAI_ModelRunORT. ONNX backend not " - "loaded from %s", - path); - return REDISMODULE_ERR; + if (!_ValidateAPICreated(ctx, backend.model_run, "RAI_ModelRunORT")) { + goto error; } backend.model_serialize = (int (*)(RAI_Model *, char **, size_t *, RAI_Error *))( unsigned long)dlsym(handle, "RAI_ModelSerializeORT"); - if (backend.model_serialize == NULL) { - dlclose(handle); - RedisModule_Log(ctx, "warning", - "Backend does not export RAI_ModelSerializeORT. ONNX " - "backend not loaded from %s", - path); - return REDISMODULE_ERR; + if (!_ValidateAPICreated(ctx, backend.model_serialize, "RAI_ModelSerializeORT")) { + goto error; } backend.get_version = (const char *(*)(void))(unsigned long)dlsym(handle, "RAI_GetBackendVersionORT"); - if (backend.get_version == NULL) { - dlclose(handle); - RedisModule_Log(ctx, "warning", - "Backend does not export RAI_GetBackendVersionORT. 
ONNX backend " - "not loaded from %s", - path); - return REDISMODULE_ERR; + if (!_ValidateAPICreated(ctx, backend.get_version, "RAI_GetBackendVersionORT")) { + goto error; } backend.get_memory_info = (unsigned long long (*)(void))(unsigned long)dlsym(handle, "RAI_GetMemoryInfoORT"); - if (backend.get_memory_info == NULL) { - dlclose(handle); - RedisModule_Log(ctx, "warning", - "Backend does not export RAI_GetMemoryInfoORT. ONNX backend " - "not loaded from %s", - path); + if (!_ValidateAPICreated(ctx, backend.get_memory_info, "RAI_GetMemoryInfoORT")) { + goto error; } + backend.get_memory_access_num = (unsigned long long (*)(void))(unsigned long)dlsym(handle, "RAI_GetMemoryAccessORT"); - if (backend.get_memory_access_num == NULL) { - dlclose(handle); - RedisModule_Log(ctx, "warning", - "Backend does not export RAI_GetMemoryAccessORT. ONNX backend " - "not loaded from %s", - path); + if (!_ValidateAPICreated(ctx, backend.get_memory_access_num, "RAI_GetMemoryAccessORT")) { + goto error; } backend.enforce_runtime_duration = (void (*)(RedisModuleCtx *, RedisModuleEvent, uint64_t, void *))(unsigned long)dlsym( - handle, "OnnxEnforceTimeoutCallback"); - if (backend.enforce_runtime_duration == NULL) { - dlclose(handle); - RedisModule_Log(ctx, "warning", - "Backend does not export OnnxEnforceTimeoutCallback. ONNX backend " - "not loaded from %s", - path); + handle, "RAI_EnforceTimeoutORT"); + if (!_ValidateAPICreated(ctx, backend.enforce_runtime_duration, "RAI_EnforceTimeoutORT")) { + goto error; } backend.add_new_device = - (int (*)(const char *))(unsigned long)dlsym(handle, "ExtendGlobalRunSessions"); - if (backend.add_new_device == NULL) { - dlclose(handle); - RedisModule_Log(ctx, "warning", - "Backend does not export AddDeviceToGlobalRunSessions. 
ONNX backend " - "not loaded from %s", - path); + (int (*)(const char *))(unsigned long)dlsym(handle, "RAI_AddNewDeviceORT"); + if (!_ValidateAPICreated(ctx, backend.add_new_device, "RAI_AddNewDeviceORT")) { + goto error; } RedisModule_SubscribeToServerEvent(ctx, RedisModuleEvent_CronLoop, - backend.enforce_runtime_duration); - + backend.enforce_runtime_duration); RAI_backends.onnx = backend; RedisModule_Log(ctx, "notice", "ONNX backend loaded from %s", path); - return REDISMODULE_OK; + + error: + dlclose(handle); + RedisModule_Log(ctx, "warning", "ONNX backend not loaded from %s", path); + return REDISMODULE_ERR; } int RAI_LoadBackend(RedisModuleCtx *ctx, int backend, const char *path) { diff --git a/src/backends/backends.h b/src/backends/backends.h index bf67c87b6..4e00cfe1e 100644 --- a/src/backends/backends.h +++ b/src/backends/backends.h @@ -102,4 +102,7 @@ char *RAI_BackendsPath; int RAI_LoadBackend(RedisModuleCtx *ctx, int backend, const char *path); int RAI_LoadDefaultBackend(RedisModuleCtx *ctx, int backend); -const char *RAI_BackendName(int backend); +/** + * @brief Returns the backend name as string. + */ +const char *GetBackendName(RAI_Backend backend); diff --git a/src/backends/onnxruntime.c b/src/backends/onnxruntime.c index 4b24b7ecc..179b82d57 100644 --- a/src/backends/onnxruntime.c +++ b/src/backends/onnxruntime.c @@ -80,8 +80,7 @@ unsigned long long RAI_GetMemoryInfoORT() { return OnnxMemory; } unsigned long long RAI_GetMemoryAccessORT() { return OnnxMemoryAccessCounter; } -int RAI_InitBackendORT(int (*get_api_fn)(const char *, void *), - int (*get_api_fn_rai)(const char *, void *)) { +int RAI_InitBackendORT(int (*get_api_fn)(const char *, void **)) { // Export redis callbacks. 
get_api_fn("RedisModule_Alloc", ((void **)&RedisModule_Alloc)); get_api_fn("RedisModule_Calloc", ((void **)&RedisModule_Calloc)); @@ -94,8 +93,8 @@ int RAI_InitBackendORT(int (*get_api_fn)(const char *, void *), get_api_fn("RedisModule_MallocSize", ((void **)&RedisModule_MallocSize)); // Export RedisAI callbacks. - get_api_fn_rai("ThreadIdKey", ((void **)&RedisAI_ThreadId)); - get_api_fn_rai("NumThreadsPerQueue", ((void **)&RedisAI_NumThreadsPerQueue)); + get_api_fn("ThreadIdKey", ((void **)&RedisAI_ThreadId)); + get_api_fn("NumThreadsPerQueue", ((void **)&RedisAI_NumThreadsPerQueue)); // Create a global array of onnx runSessions, with an entry for every working thread. CreateGlobalOnnxRunSessions(); diff --git a/src/backends/onnxruntime.h b/src/backends/onnxruntime.h index 27cfa8037..8fcc7c3d7 100644 --- a/src/backends/onnxruntime.h +++ b/src/backends/onnxruntime.h @@ -12,8 +12,7 @@ unsigned long long RAI_GetMemoryAccessORT(void); pthread_key_t (*RedisAI_ThreadId)(void); long long (*RedisAI_NumThreadsPerQueue)(void); -int RAI_InitBackendORT(int (*get_api_fn)(const char *, void *), - int (*get_api_fn_rai)(const char *, void *)); +int RAI_InitBackendORT(int (*get_api_fn)(const char *, void **)); RAI_Model *RAI_ModelCreateORT(RAI_Backend backend, const char *devicestr, RAI_ModelOpts opts, const char *modeldef, size_t modellen, RAI_Error *err); diff --git a/src/execution/onnx_timeout.c b/src/execution/onnx_timeout.c index fb0441232..abf5d276d 100644 --- a/src/execution/onnx_timeout.c +++ b/src/execution/onnx_timeout.c @@ -22,10 +22,10 @@ int CreateGlobalOnnxRunSessions() { array_new(OnnxRunSessionCtx *, RedisAI_NumThreadsPerQueue()); onnx_global_run_sessions->OnnxRunSessions = onnx_run_sessions; pthread_rwlock_init(&(onnx_global_run_sessions->rwlock), NULL); - return ExtendGlobalRunSessions("CPU"); // Add entries for CPU threads. + return RAI_AddNewDeviceORT("CPU"); // Add entries for CPU threads. 
} -int ExtendGlobalRunSessions(const char *device_str) { +int RAI_AddNewDeviceORT(const char *device_str) { // Acquire write lock, as we might reallocate the array while extending it. pthread_rwlock_wrlock(&(onnx_global_run_sessions->rwlock)); @@ -41,7 +41,7 @@ int ExtendGlobalRunSessions(const char *device_str) { return REDISMODULE_OK; } -void OnnxEnforceTimeoutCallback(RedisModuleCtx *ctx, RedisModuleEvent eid, uint64_t subevent, +void RAI_EnforceTimeoutORT(RedisModuleCtx *ctx, RedisModuleEvent eid, uint64_t subevent, void *data) { const OrtApi *ort = OrtGetApiBase()->GetApi(1); pthread_rwlock_rdlock(&(onnx_global_run_sessions->rwlock)); diff --git a/src/execution/onnx_timeout.h b/src/execution/onnx_timeout.h index 96f9b76e6..8e1c35bf1 100644 --- a/src/execution/onnx_timeout.h +++ b/src/execution/onnx_timeout.h @@ -2,7 +2,6 @@ #include "backends/onnxruntime.h" #include "onnxruntime_c_api.h" -#include "util/rax.h" // The maximum time in milliseconds before killing onnx run session. // todo: make it a load time config @@ -36,11 +35,28 @@ int CreateGlobalOnnxRunSessions(void); * ThreadPerQueue. Thus, the global array of onnx sessions that has an * entry-per-thread is extended accordingly. */ -int ExtendGlobalRunSessions(const char *device_str); +int RAI_AddNewDeviceORT(const char *device_str); -void OnnxEnforceTimeoutCallback(RedisModuleCtx *ctx, RedisModuleEvent eid, uint64_t subevent, +/** + * @brief A callback that is registered to RedisCron event, that is, it is called + * periodically and go over all the (possibly running) onnx sessions, and kill + * those that exceeds the timeout. + */ +void RAI_EnforceTimeoutORT(RedisModuleCtx *ctx, RedisModuleEvent eid, uint64_t subevent, void *data); +/** + * @brief Set a new OrtRunOptions in the global structure, to allow us to + * "terminate" the run session from the cron callback. + * @param new_run_options - The newly created OrtRunOptions to store. 
+ * @param run_session_index - placeholder for the index of the running thread + * in the global array, to have a quick access later to clean this entry. + */ void SetRunSessionCtx(OrtRunOptions *new_run_options, size_t *run_session_index); +/** + * @brief Release the OrtRunOptions of a session that finished its run and + * reset the corresponding entry in the global structure. + * @param run_session_index - The entry index where OrtRunOptions was stored. + */ void ClearRunSessionCtx(size_t run_session_index); diff --git a/src/redisai.c b/src/redisai.c index f52f1c7a0..155ba6af3 100644 --- a/src/redisai.c +++ b/src/redisai.c @@ -463,7 +463,7 @@ int RedisAI_ModelGet_RedisCommand(RedisModuleCtx *ctx, RedisModuleString **argv, RedisModule_ReplyWithArray(ctx, outentries); RedisModule_ReplyWithCString(ctx, "backend"); - const char *backendstr = RAI_BackendName(mto->backend); + const char *backendstr = GetBackendName(mto->backend); RedisModule_ReplyWithCString(ctx, backendstr); RedisModule_ReplyWithCString(ctx, "device"); @@ -938,7 +938,7 @@ int RedisAI_Info_RedisCommand(RedisModuleCtx *ctx, RedisModuleString **argv, int RedisModule_ReplyWithCString(ctx, "SCRIPT"); } RedisModule_ReplyWithCString(ctx, "backend"); - RedisModule_ReplyWithCString(ctx, RAI_BackendName(rstats->backend)); + RedisModule_ReplyWithCString(ctx, GetBackendName(rstats->backend)); RedisModule_ReplyWithCString(ctx, "device"); RedisModule_ReplyWithCString(ctx, rstats->devicestr); RedisModule_ReplyWithCString(ctx, "tag"); diff --git a/src/serialization/AOF/rai_aof_rewrite.c b/src/serialization/AOF/rai_aof_rewrite.c index 813877cf2..13512479f 100644 --- a/src/serialization/AOF/rai_aof_rewrite.c +++ b/src/serialization/AOF/rai_aof_rewrite.c @@ -73,7 +73,7 @@ void RAI_AOFRewriteModel(RedisModuleIO *aof, RedisModuleString *key, void *value RedisModule_Free(buffer); } - const char *backendstr = RAI_BackendName(model->backend); + const char *backendstr = GetBackendName(model->backend); RedisModule_EmitAOF(aof, 
"AI.MODELSET", "slccclclcvcvcv", key, backendstr, model->devicestr, model->tag, "BATCHSIZE", model->opts.batchsize, "MINBATCHSIZE", From cd9baa1e7b8f469a3e6e1da4d92b892ef7b72646 Mon Sep 17 00:00:00 2001 From: alonre24 Date: Mon, 31 May 2021 10:35:09 +0300 Subject: [PATCH 07/27] Start testing - not finished --- tests/flow/test_data/model_with_infinite_loop.onnx | 3 +++ tests/flow/tests_onnx.py | 12 ++++++++++++ 2 files changed, 15 insertions(+) create mode 100644 tests/flow/test_data/model_with_infinite_loop.onnx diff --git a/tests/flow/test_data/model_with_infinite_loop.onnx b/tests/flow/test_data/model_with_infinite_loop.onnx new file mode 100644 index 000000000..c8052d646 --- /dev/null +++ b/tests/flow/test_data/model_with_infinite_loop.onnx @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:dbbdd3e85efb57ab9f8a3848b31cc226f56419fc56a0c55acf2d55f311fcad61 +size 546 diff --git a/tests/flow/tests_onnx.py b/tests/flow/tests_onnx.py index 743bfd485..9e87394b8 100644 --- a/tests/flow/tests_onnx.py +++ b/tests/flow/tests_onnx.py @@ -463,3 +463,15 @@ def test_onnx_use_custom_allocator_with_GPU(env): for k in con.execute_command("INFO MODULES").decode().split("#")[4].split()[1:]} env.assertEqual(int(ai_memory_config["ai_onnxruntime_memory_access_num"]), 11) + +def test_onnx_kill_switch(env): + con = env.getConnection() + model_with_inf_loop = load_file_content("model_with_infinite_loop.onnx") + ret = con.execute_command('AI.MODELSTORE', 'm{1}', 'ONNX', DEVICE, 'BLOB', model_with_inf_loop) + env.assertEqual(ret, b'OK') + model = con.execute_command('AI.MODELGET', 'm{1}', 'META', 'BLOB') + env.debugPrint(str(model), force=True) + + # Set tensors according to the model inputs + + ret = con.execute_command('AI.TENSORSET', 'in{1}', 'int64', DEVICE, 'BLOB', model_with_inf_loop) From 5c091068dc9596336ee8c6b253aef158475ef416 Mon Sep 17 00:00:00 2001 From: alonre24 Date: Mon, 31 May 2021 16:43:00 +0300 Subject: [PATCH 08/27] Support bool type tensor --- 
src/redis_ai_objects/tensor.c | 43 +++++++++++++++++++++++++++++++---- src/redis_ai_objects/tensor.h | 1 + tests/flow/tests_common.py | 12 ++++++---- 3 files changed, 48 insertions(+), 8 deletions(-) diff --git a/src/redis_ai_objects/tensor.c b/src/redis_ai_objects/tensor.c index c010c9cf1..e4bb25ab8 100644 --- a/src/redis_ai_objects/tensor.c +++ b/src/redis_ai_objects/tensor.c @@ -14,6 +14,7 @@ #include "tensor.h" #include "err.h" #include "arr.h" +#include "math.h" #include "redisai.h" #include "version.h" #include "tensor_struct.h" @@ -24,6 +25,24 @@ extern RedisModuleType *RedisAI_TensorType; + +// Check if the given value is in the range of the tensor type. +bool _ValOverflow(long long val, RAI_Tensor *t) { + DLDataType dtype = t->tensor.dl_tensor.dtype; + if (dtype.code == kDLInt) { + uint max_abs_val = (uint) 1 << (uint) (dtype.bits - 1); + if (val >= (long long)max_abs_val || val <= (long long) -1 * max_abs_val) { + return true; + } + } else if (dtype.code == kDLUInt) { + uint max_val = (uint) 1 << dtype.bits; + if(val >= max_val || val < 0) { + return true; + } + } + return false; +} + DLDataType RAI_TensorDataTypeFromString(const char *typestr) { if (strcasecmp(typestr, RAI_DATATYPE_STR_FLOAT) == 0) { return (DLDataType){.code = kDLFloat, .bits = 32, .lanes = 1}; @@ -55,10 +74,13 @@ DLDataType RAI_TensorDataTypeFromString(const char *typestr) { return (DLDataType){.code = kDLUInt, .bits = 16, .lanes = 1}; } } + if (strcasecmp(typestr, "BOOL") == 0) { + return (DLDataType){.code = kDLUInt, .bits = 1, .lanes = 1}; + } return (DLDataType){.bits = 0}; } -static size_t Tensor_DataTypeSize(DLDataType dtype) { return dtype.bits / 8; } +static size_t Tensor_DataTypeSize(DLDataType dtype) { return ceil((double)dtype.bits / 8); } int Tensor_DataTypeStr(DLDataType dtype, char *dtypestr) { int result = REDISMODULE_ERR; @@ -92,6 +114,9 @@ int Tensor_DataTypeStr(DLDataType dtype, char *dtypestr) { } else if (dtype.bits == 16) { strcpy(dtypestr, 
RAI_DATATYPE_STR_UINT16); result = REDISMODULE_OK; + } else if (dtype.bits == 1) { + strcpy(dtypestr, RAI_DATATYPE_STR_BOOL); + result = REDISMODULE_OK; } } return result; @@ -129,9 +154,10 @@ RAI_Tensor *RAI_TensorCreateWithDLDataType(DLDataType dtype, long long *dims, in DLDevice device = (DLDevice){.device_type = kDLCPU, .device_id = 0}; // If we return an empty tensor, we initialize the data with zeros to avoid security - // issues. Otherwise, we only allocate without initializing (for better performance) + // issues. Otherwise, we only allocate without initializing (for better performance). + // Also, for boolean tensors we should initialize data with zeros. void *data; - if (empty) { + if (empty || dtype.bits == 1) { data = RedisModule_Calloc(len, dtypeSize); } else { data = RedisModule_Alloc(len * dtypeSize); @@ -414,6 +440,12 @@ int RAI_TensorSetValueFromLongLong(RAI_Tensor *t, long long i, long long val) { } } else if (dtype.code == kDLUInt) { switch (dtype.bits) { + case 1: + // If the val is 1, set the corresponding bit. 
+ if (val % 2) { + ((uint8_t *)data)[i/8] |= ((uint)val << (uint)(i%8)); + } + break; case 8: ((uint8_t *)data)[i] = val; break; @@ -503,6 +535,9 @@ int RAI_TensorGetValueAsLongLong(RAI_Tensor *t, long long i, long long *val) { } } else if (dtype.code == kDLUInt) { switch (dtype.bits) { + case 1: + *val = (((uint8_t *)data)[i/8] >> (uint)(i%8)) % 2; + break; case 8: *val = ((uint8_t *)data)[i]; break; @@ -707,7 +742,7 @@ int RAI_parseTensorSetArgs(RedisModuleString **argv, int argc, RAI_Tensor **t, i } else { long long val; const int retval = RedisModule_StringToLongLong(argv[argpos], &val); - if (retval != REDISMODULE_OK) { + if (retval != REDISMODULE_OK || _ValOverflow(val, *t)) { RAI_TensorFree(*t); array_free(dims); RAI_SetError(error, RAI_ETENSORSET, "ERR invalid value"); diff --git a/src/redis_ai_objects/tensor.h b/src/redis_ai_objects/tensor.h index f220ba25c..d72a3a3a9 100644 --- a/src/redis_ai_objects/tensor.h +++ b/src/redis_ai_objects/tensor.h @@ -31,6 +31,7 @@ static const char *RAI_DATATYPE_STR_INT32 = "INT32"; static const char *RAI_DATATYPE_STR_INT64 = "INT64"; static const char *RAI_DATATYPE_STR_UINT8 = "UINT8"; static const char *RAI_DATATYPE_STR_UINT16 = "UINT16"; +static const char *RAI_DATATYPE_STR_BOOL = "BOOL"; #define TENSOR_NONE 0 #define TENSOR_VALUES (1 << 0) diff --git a/tests/flow/tests_common.py b/tests/flow/tests_common.py index ec37df748..62c574738 100644 --- a/tests/flow/tests_common.py +++ b/tests/flow/tests_common.py @@ -10,7 +10,7 @@ def test_common_tensorset(env): con = env.getConnection() - tested_datatypes = ["FLOAT", "DOUBLE", "INT8", "INT16", "INT32", "INT64", "UINT8", "UINT16"] + tested_datatypes = ["FLOAT", "DOUBLE", "INT8", "INT16", "INT32", "INT64", "UINT8", "UINT16", "BOOL"] for datatype in tested_datatypes: ret = con.execute_command('AI.TENSORSET', 'tensor_{0}'.format(datatype), datatype, 2, 'VALUES', 1, 1) env.assertEqual(ret, b'OK') @@ -164,10 +164,11 @@ def test_common_tensorset_error_replies(env): def 
test_common_tensorget(env): con = env.getConnection() - tested_datatypes = ["FLOAT", "DOUBLE", "INT8", "INT16", "INT32", "INT64", "UINT8", "UINT16"] + tested_datatypes = ["FLOAT", "DOUBLE", "INT8", "INT16", "INT32", "INT64", "UINT8", "UINT16", "BOOL"] tested_datatypes_fp = ["FLOAT", "DOUBLE"] - tested_datatypes_int = ["INT8", "INT16", "INT32", "INT64", "UINT8", "UINT16"] + tested_datatypes_int = ["INT8", "INT16", "INT32", "INT64", "UINT8", "UINT16", "BOOL"] for datatype in tested_datatypes: + env.debugPrint(datatype, force=True) ret = con.execute_command('AI.TENSORSET', 'tensor_{0}'.format(datatype), datatype, 2, 'VALUES', 1, 1) env.assertEqual(ret, b'OK') @@ -204,7 +205,10 @@ def test_common_tensorget(env): if datatype in tested_datatypes_fp: env.assertEqual([b'1', b'1'], tensor_values) if datatype in tested_datatypes_int: - env.assertEqual([1, 1], tensor_values) + if datatype == "BOOL": + env.assertEqual([1, 1], tensor_values) + else: + env.assertEqual([1, 1], tensor_values) # Confirm that the output is the expected for BLOB for datatype in tested_datatypes: From 5d3dd2c2aadc7502a9fb8f1f2b11ceca6bdec1f7 Mon Sep 17 00:00:00 2001 From: alonre24 Date: Mon, 31 May 2021 17:06:00 +0300 Subject: [PATCH 09/27] Support tensors of type bool. Add validation that a input value doesn't overflows than the tensor type in TENSORSET. --- src/redis_ai_objects/tensor.c | 13 ++++++------- tests/flow/tests_common.py | 23 +++++++++++++++++++---- 2 files changed, 25 insertions(+), 11 deletions(-) diff --git a/src/redis_ai_objects/tensor.c b/src/redis_ai_objects/tensor.c index e4bb25ab8..5e47ac3f1 100644 --- a/src/redis_ai_objects/tensor.c +++ b/src/redis_ai_objects/tensor.c @@ -25,18 +25,17 @@ extern RedisModuleType *RedisAI_TensorType; - // Check if the given value is in the range of the tensor type. 
bool _ValOverflow(long long val, RAI_Tensor *t) { DLDataType dtype = t->tensor.dl_tensor.dtype; if (dtype.code == kDLInt) { - uint max_abs_val = (uint) 1 << (uint) (dtype.bits - 1); - if (val >= (long long)max_abs_val || val <= (long long) -1 * max_abs_val) { + uint max_abs_val = (uint)1 << (uint)(dtype.bits - 1); + if (val >= (long long)max_abs_val || val <= (long long)-1 * max_abs_val) { return true; } } else if (dtype.code == kDLUInt) { - uint max_val = (uint) 1 << dtype.bits; - if(val >= max_val || val < 0) { + uint max_val = (uint)1 << dtype.bits; + if (val >= max_val || val < 0) { return true; } } @@ -443,7 +442,7 @@ int RAI_TensorSetValueFromLongLong(RAI_Tensor *t, long long i, long long val) { case 1: // If the val is 1, set the corresponding bit. if (val % 2) { - ((uint8_t *)data)[i/8] |= ((uint)val << (uint)(i%8)); + ((uint8_t *)data)[i / 8] |= ((uint)val << (uint)(i % 8)); } break; case 8: @@ -536,7 +535,7 @@ int RAI_TensorGetValueAsLongLong(RAI_Tensor *t, long long i, long long *val) { } else if (dtype.code == kDLUInt) { switch (dtype.bits) { case 1: - *val = (((uint8_t *)data)[i/8] >> (uint)(i%8)) % 2; + *val = (((uint8_t *)data)[i / 8] >> (uint)(i % 8)) % 2; break; case 8: *val = ((uint8_t *)data)[i]; diff --git a/tests/flow/tests_common.py b/tests/flow/tests_common.py index 62c574738..027ada678 100644 --- a/tests/flow/tests_common.py +++ b/tests/flow/tests_common.py @@ -101,6 +101,24 @@ def test_common_tensorset_error_replies(env): env.assertEqual(type(exception), redis.exceptions.ResponseError) env.assertEqual(exception.__str__(), "invalid value") + # ERR invalid value - overflow + try: + con.execute_command('AI.TENSORSET', 'z', 'BOOL', 2, 'VALUES', 1, 2) + env.assertFalse(True) + except Exception as e: + exception = e + env.assertEqual(type(exception), redis.exceptions.ResponseError) + env.assertEqual(exception.__str__(), "invalid value") + + # ERR invalid value - overflow + try: + con.execute_command('AI.TENSORSET', 'z', 'INT8', 2, 'VALUES', -1, 
-128) + env.assertFalse(True) + except Exception as e: + exception = e + env.assertEqual(type(exception), redis.exceptions.ResponseError) + env.assertEqual(exception.__str__(), "invalid value") + try: con.execute_command('AI.TENSORSET', 1) env.assertFalse(True) @@ -205,10 +223,7 @@ def test_common_tensorget(env): if datatype in tested_datatypes_fp: env.assertEqual([b'1', b'1'], tensor_values) if datatype in tested_datatypes_int: - if datatype == "BOOL": - env.assertEqual([1, 1], tensor_values) - else: - env.assertEqual([1, 1], tensor_values) + env.assertEqual([1, 1], tensor_values) # Confirm that the output is the expected for BLOB for datatype in tested_datatypes: From 04dac08f9ea5bf5c3e45dae7a353d2a165840b77 Mon Sep 17 00:00:00 2001 From: alonre24 Date: Tue, 1 Jun 2021 18:00:43 +0300 Subject: [PATCH 10/27] Support tensor of type bool in ONNX, Add tests for kill switch --- src/backends/onnxruntime.c | 11 +++- src/redis_ai_objects/tensor.c | 4 +- .../test_data/model_with_infinite_loop.onnx | 4 +- tests/flow/tests_onnx.py | 56 +++++++++++++++++-- 4 files changed, 63 insertions(+), 12 deletions(-) diff --git a/src/backends/onnxruntime.c b/src/backends/onnxruntime.c index 179b82d57..123aea260 100644 --- a/src/backends/onnxruntime.c +++ b/src/backends/onnxruntime.c @@ -5,6 +5,7 @@ #include #include "execution/background_workers.h" #include +#include #include "util/arr.h" #include "backends/onnxruntime.h" #include "redis_ai_objects/tensor.h" @@ -157,6 +158,8 @@ ONNXTensorElementDataType RAI_GetOrtDataTypeFromDL(DLDataType dtype) { } } else if (dtype.code == kDLUInt) { switch (dtype.bits) { + case 1: + return ONNX_TENSOR_ELEMENT_DATA_TYPE_BOOL; case 8: return ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8; case 16: @@ -186,6 +189,8 @@ DLDataType RAI_GetDLDataTypeFromORT(ONNXTensorElementDataType dtype) { return (DLDataType){.code = kDLUInt, .bits = 8, .lanes = 1}; case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT16: return (DLDataType){.code = kDLUInt, .bits = 16, .lanes = 1}; + case 
ONNX_TENSOR_ELEMENT_DATA_TYPE_BOOL: + return (DLDataType){.code = kDLUInt, .bits = 1, .lanes = 1}; default: return (DLDataType){.bits = 0}; } @@ -293,7 +298,7 @@ RAI_Tensor *RAI_TensorCreateFromOrtValue(OrtValue *v, size_t batch_offset, long size_t elem_count; ONNX_VALIDATE_STATUS(ort->GetTensorShapeElementCount(info, &elem_count)) - const size_t len = dtype.bits * elem_count / 8; + const size_t len = ceil((double)dtype.bits * elem_count / 8); const size_t total_bytesize = len * sizeof(char); const size_t sample_bytesize = total_bytesize / total_batch_size; const size_t batch_bytesize = sample_bytesize * batch_size; @@ -520,6 +525,7 @@ int RAI_ModelRunORT(RAI_ModelRunCtx **mctxs, RAI_Error *error) { array_new_on_stack(OrtValue *, 5, inputs); array_new_on_stack(OrtValue *, 5, outputs); OrtRunOptions *run_options = NULL; + size_t run_session_index; OrtTensorTypeAndShapeInfo *info = NULL; { size_t n_input_nodes; @@ -569,7 +575,6 @@ int RAI_ModelRunORT(RAI_ModelRunCtx **mctxs, RAI_Error *error) { ONNX_VALIDATE_STATUS(ort->CreateRunOptions(&run_options)); // Set the created run option in the global RunSessions and save its index. 
- size_t run_session_index; SetRunSessionCtx(run_options, &run_session_index); ONNX_VALIDATE_STATUS(ort->Run(session, run_options, input_names, (const OrtValue *const *)inputs, n_input_nodes, output_names, @@ -664,7 +669,7 @@ int RAI_ModelRunORT(RAI_ModelRunCtx **mctxs, RAI_Error *error) { ort->ReleaseTensorTypeAndShapeInfo(info); } if (run_options) { - ort->ReleaseRunOptions(run_options); + ClearRunSessionCtx(run_session_index); } return REDISMODULE_ERR; } diff --git a/src/redis_ai_objects/tensor.c b/src/redis_ai_objects/tensor.c index 5e47ac3f1..e94f460a6 100644 --- a/src/redis_ai_objects/tensor.c +++ b/src/redis_ai_objects/tensor.c @@ -29,8 +29,8 @@ extern RedisModuleType *RedisAI_TensorType; bool _ValOverflow(long long val, RAI_Tensor *t) { DLDataType dtype = t->tensor.dl_tensor.dtype; if (dtype.code == kDLInt) { - uint max_abs_val = (uint)1 << (uint)(dtype.bits - 1); - if (val >= (long long)max_abs_val || val <= (long long)-1 * max_abs_val) { + unsigned long long max_abs_val = (unsigned long long)1 << (uint)(dtype.bits - 1); + if (val >= max_abs_val || val <= -1 * (long long) max_abs_val) { return true; } } else if (dtype.code == kDLUInt) { diff --git a/tests/flow/test_data/model_with_infinite_loop.onnx b/tests/flow/test_data/model_with_infinite_loop.onnx index c8052d646..a4475545e 100644 --- a/tests/flow/test_data/model_with_infinite_loop.onnx +++ b/tests/flow/test_data/model_with_infinite_loop.onnx @@ -1,3 +1,3 @@ version https://git-lfs.github.com/spec/v1 -oid sha256:dbbdd3e85efb57ab9f8a3848b31cc226f56419fc56a0c55acf2d55f311fcad61 -size 546 +oid sha256:85848c4e1f96f47b62a178d67a6785b3734a5387c6f42064673d794430290862 +size 692 diff --git a/tests/flow/tests_onnx.py b/tests/flow/tests_onnx.py index 9e87394b8..be9a430db 100644 --- a/tests/flow/tests_onnx.py +++ b/tests/flow/tests_onnx.py @@ -464,14 +464,60 @@ def test_onnx_use_custom_allocator_with_GPU(env): env.assertEqual(int(ai_memory_config["ai_onnxruntime_memory_access_num"]), 11) -def 
test_onnx_kill_switch(env): +def test_onnx_kill_switch_basic(env): con = env.getConnection() model_with_inf_loop = load_file_content("model_with_infinite_loop.onnx") ret = con.execute_command('AI.MODELSTORE', 'm{1}', 'ONNX', DEVICE, 'BLOB', model_with_inf_loop) env.assertEqual(ret, b'OK') - model = con.execute_command('AI.MODELGET', 'm{1}', 'META', 'BLOB') - env.debugPrint(str(model), force=True) - # Set tensors according to the model inputs + # Set tensors according to the model inputs. This model consists of two operations to type 'Identity' + # (i.e., just output the input), where the second op is wrapped with another op of type 'Loop'. Overall, this model + # runs a very large number of iterations without doing anything, until it is caught with the kill switch. + con.execute_command('AI.TENSORSET', 'iterations{1}', 'INT64', 1, 'VALUES', 9223372036854775807) + con.execute_command('AI.TENSORSET', 'loop_cond{1}', 'BOOL', 1, 'VALUES', 1) + con.execute_command('AI.TENSORSET', 'loop_input{1}', 'FLOAT', 1, 'VALUES', 42) + con.execute_command('AI.TENSORSET', 'outer_scope_input{1}', 'FLOAT', 1, 'VALUES', 42) + + try: + con.execute_command('AI.MODELEXECUTE', 'm{1}', 'INPUTS', 4, 'outer_scope_input{1}', 'iterations{1}', + 'loop_cond{1}', 'loop_input{1}', 'OUTPUTS', 2, 'outer_scope_output{1}', 'loop_output{1}') + env.assertTrue(False) + except Exception as exception: + env.assertEqual(type(exception), redis.exceptions.ResponseError) + env.assertTrue(str(exception).find("Exiting due to terminate flag being set to true") != -1) + + +def test_onnx_kill_switch_multiple_working_threads(): + env = Env(moduleArgs='THREADS_PER_QUEUE 8') + con = env.getConnection() + model_with_inf_loop = load_file_content("model_with_infinite_loop.onnx") + ret = con.execute_command('AI.MODELSTORE', 'inf_loop_model{1}', 'ONNX', DEVICE, 'BLOB', model_with_inf_loop) + env.assertEqual(ret, b'OK') + + # Set tensors according to the model inputs. 
This model consists of two operations to type 'Identity' + # (i.e., just output the input), where the second op is wrapped with another op of type 'Loop'. Overall, this model + # runs a very large number of iterations without doing anything, until it is caught with the kill switch. + con.execute_command('AI.TENSORSET', 'iterations{1}', 'INT64', 1, 'VALUES', 9223372036854775807) + con.execute_command('AI.TENSORSET', 'loop_cond{1}', 'BOOL', 1, 'VALUES', 1) + con.execute_command('AI.TENSORSET', 'loop_input{1}', 'FLOAT', 1, 'VALUES', 42) + con.execute_command('AI.TENSORSET', 'outer_scope_input{1}', 'FLOAT', 1, 'VALUES', 42) + + # Load another onnx model only on CPU (to test multiple devices when DEVICE = GPU + model_pb = load_file_content('mnist.onnx') + sample_raw = load_file_content('one.raw') + ret = con.execute_command('AI.MODELSTORE', 'm{1}', 'ONNX', 'CPU:1', 'BLOB', model_pb) + env.assertEqual(ret, b'OK') + con.execute_command('AI.TENSORSET', 'a{1}', 'FLOAT', 1, 1, 28, 28, 'BLOB', sample_raw) - ret = con.execute_command('AI.TENSORSET', 'in{1}', 'int64', DEVICE, 'BLOB', model_with_inf_loop) + def run_parallel_onnx_sessions(con): + ret = con.execute_command('AI.MODELEXECUTE', 'm{1}', 'INPUTS', 1, 'a{1}', 'OUTPUTS', 1, 'b{1}') + env.assertEqual(ret, b'OK') + try: + con.execute_command('AI.MODELEXECUTE', 'inf_loop_model{1}', 'INPUTS', 4, 'outer_scope_input{1}', 'iterations{1}', + 'loop_cond{1}', 'loop_input{1}', 'OUTPUTS', 2, 'outer_scope_output{1}', 'loop_output{1}') + except Exception as exception: + env.assertEqual(type(exception), redis.exceptions.ResponseError) + env.assertTrue(str(exception).find("Exiting due to terminate flag being set to true") != -1) + ret = con.execute_command('AI.MODELEXECUTE', 'm{1}', 'INPUTS', 1, 'a{1}', 'OUTPUTS', 1, 'b{1}') + env.assertEqual(ret, b'OK') + run_test_multiproc(env, 8, run_parallel_onnx_sessions) From 1d6b3ed08489acfe5ed6702aa4418093f3590816 Mon Sep 17 00:00:00 2001 From: alonre24 Date: Tue, 1 Jun 2021 19:01:21 +0300 
Subject: [PATCH 11/27] Add load time config for ONNX_TIMEOUT. Parallel tests seems not to work. --- src/backends/backends.c | 2 ++ src/backends/onnxruntime.c | 1 + src/backends/onnxruntime.h | 1 + src/config/config.c | 34 +++++++++++++++++++++++++++++++++- src/config/config.h | 13 +++++++++++++ src/execution/onnx_timeout.c | 2 +- src/execution/onnx_timeout.h | 4 ---- tests/flow/tests_onnx.py | 2 +- 8 files changed, 52 insertions(+), 7 deletions(-) diff --git a/src/backends/backends.c b/src/backends/backends.c index 3031520aa..87368bd42 100644 --- a/src/backends/backends.c +++ b/src/backends/backends.c @@ -32,6 +32,8 @@ int RAI_GetApi(const char *func_name, void **targetPtrPtr) { *targetPtrPtr = GetThreadId; } else if (strcmp("NumThreadsPerQueue", func_name) == 0) { *targetPtrPtr = GetNumThreadsPerQueue; + } else if (strcmp("OnnxTimeout", func_name) == 0) { + *targetPtrPtr = GetOnnxTimeout; } else { return RedisModule_GetApi(func_name, targetPtrPtr); } diff --git a/src/backends/onnxruntime.c b/src/backends/onnxruntime.c index 123aea260..0e3c33894 100644 --- a/src/backends/onnxruntime.c +++ b/src/backends/onnxruntime.c @@ -96,6 +96,7 @@ int RAI_InitBackendORT(int (*get_api_fn)(const char *, void **)) { // Export RedisAI callbacks. get_api_fn("ThreadIdKey", ((void **)&RedisAI_ThreadId)); get_api_fn("NumThreadsPerQueue", ((void **)&RedisAI_NumThreadsPerQueue)); + get_api_fn("OnnxTimeout", ((void **)&RedisAI_OnnxTimeout)); // Create a global array of onnx runSessions, with an entry for every working thread. 
CreateGlobalOnnxRunSessions(); diff --git a/src/backends/onnxruntime.h b/src/backends/onnxruntime.h index 8fcc7c3d7..2dcbab5b9 100644 --- a/src/backends/onnxruntime.h +++ b/src/backends/onnxruntime.h @@ -11,6 +11,7 @@ unsigned long long RAI_GetMemoryAccessORT(void); pthread_key_t (*RedisAI_ThreadId)(void); long long (*RedisAI_NumThreadsPerQueue)(void); +long long (*RedisAI_OnnxTimeout)(void); int RAI_InitBackendORT(int (*get_api_fn)(const char *, void **)); diff --git a/src/config/config.c b/src/config/config.c index e21a2402d..c8486a7c3 100644 --- a/src/config/config.c +++ b/src/config/config.c @@ -18,7 +18,10 @@ long long backends_intra_op_parallelism; // number of threads used within an // individual op for parallelism. long long backends_inter_op_parallelism; // number of threads used for parallelism // between independent operations. -long long model_chunk_size; // size of chunks used to break up model payloads. +long long model_chunk_size; // size of chunks used to break up model payloads. + +long long onnx_max_runtime; // The maximum time in milliseconds + // before killing onnx run session. /** * @@ -86,6 +89,17 @@ int setModelChunkSize(long long size) { return result; } +long long GetOnnxTimeout () { return onnx_max_runtime; } + +int SetOnnxTimeout(long long timeout) { + int result = 1; + if (timeout > 0) { + onnx_max_runtime = timeout; + result = 0; + } + return result; +} + /** * Helper method for AI.CONFIG LOADBACKEND * @@ -209,6 +223,18 @@ int RedisAI_Config_ModelChunkSize(RedisModuleString *chunk_size_string) { return result; } +int RedisAI_Config_OnnxTimeout(RedisModuleString *onnx_timeout) { + long long temp; + int result = RedisModule_StringToLongLong(onnx_timeout, &temp); + // make sure that the timeout is a positive integer, if not set the value to the default. 
+ if (result == REDISMODULE_OK && temp < 1) { + temp = ONNX_DEFAULT_MAX_RUNTIME; + result = REDISMODULE_ERR; + } + result = SetOnnxTimeout(temp); + return result; +} + /** * * @param ctx Context in which Redis modules operate @@ -253,6 +279,12 @@ int RAI_configParamParse(RedisModuleCtx *ctx, const char *key, const char *val, RedisModule_Log(ctx, "notice", "%s: %lld", REDISAI_INFOMSG_MODEL_CHUNK_SIZE, getModelChunkSize()); } + } else if (strcasecmp((key), "ONNX_TIMEOUT") == 0) { + ret = RedisAI_Config_OnnxTimeout(rsval); + if (ret == REDISMODULE_OK) { + RedisModule_Log(ctx, "notice", "%s: %lld", REDISAI_INFOMSG_MODEL_CHUNK_SIZE, + GetOnnxTimeout()); + } } else if (strcasecmp((key), "BACKENDSPATH") == 0) { // already taken care of } else { diff --git a/src/config/config.h b/src/config/config.h index d8fc461fe..af4ad2fea 100644 --- a/src/config/config.h +++ b/src/config/config.h @@ -23,6 +23,7 @@ typedef enum { RAI_DEVICE_CPU = 0, RAI_DEVICE_GPU = 1 } RAI_Device; #define REDISAI_DEFAULT_INTRA_OP_PARALLELISM 0 #define REDISAI_DEFAULT_INTER_OP_PARALLELISM 0 #define REDISAI_DEFAULT_MODEL_CHUNK_SIZE 535822336 // (511 * 1024 * 1024) +#define ONNX_DEFAULT_MAX_RUNTIME 5000 #define REDISAI_ERRORMSG_PROCESSING_ARG "ERR error processing argument" #define REDISAI_ERRORMSG_THREADS_PER_QUEUE "ERR error setting THREADS_PER_QUEUE to" #define REDISAI_ERRORMSG_INTRA_OP_PARALLELISM "ERR error setting INTRA_OP_PARALLELISM to" @@ -91,6 +92,11 @@ int setModelChunkSize(long long size); * @param argc Redis command number of arguments * @return REDISMODULE_OK on success, or REDISMODULE_ERR if the DAGRUN failed */ + +long long GetOnnxTimeout(void); + +int SetOnnxTimeout(long long timeout); + int RedisAI_Config_LoadBackend(RedisModuleCtx *ctx, RedisModuleString **argv, int argc); /** @@ -139,6 +145,13 @@ int RedisAI_Config_IntraOperationParallelism(RedisModuleString *num_threads_stri */ int RedisAI_Config_ModelChunkSize(RedisModuleString *chunk_size_string); +/** + * Set the maximum time in ms 
that onnx backend allow running a model. + * @param onnx_max_runtime - string containing the max runtime (in ms) + * @return REDISMODULE_OK on success, or REDISMODULE_ERR if failed + */ +int RedisAI_Config_OnnxTimeout(RedisModuleString *onnx_timeout); + /** * * @param ctx Context in which Redis modules operate diff --git a/src/execution/onnx_timeout.c b/src/execution/onnx_timeout.c index abf5d276d..147016480 100644 --- a/src/execution/onnx_timeout.c +++ b/src/execution/onnx_timeout.c @@ -53,7 +53,7 @@ void RAI_EnforceTimeoutORT(RedisModuleCtx *ctx, RedisModuleEvent eid, uint64_t s } long long curr_time = _mstime(); // Check if a sessions is running for too long, and kill it if so. - if (curr_time - run_sessions_ctx[i]->queuingTime > ONNX_MAX_RUNTIME) { + if (curr_time - run_sessions_ctx[i]->queuingTime > RedisAI_OnnxTimeout()) { ort->RunOptionsSetTerminate(run_sessions_ctx[i]->runOptions); } } diff --git a/src/execution/onnx_timeout.h b/src/execution/onnx_timeout.h index 8e1c35bf1..2084dc150 100644 --- a/src/execution/onnx_timeout.h +++ b/src/execution/onnx_timeout.h @@ -3,10 +3,6 @@ #include "backends/onnxruntime.h" #include "onnxruntime_c_api.h" -// The maximum time in milliseconds before killing onnx run session. 
-// todo: make it a load time config -#define ONNX_MAX_RUNTIME 5000 - typedef struct OnnxRunSessionCtx { long long queuingTime; OrtRunOptions *runOptions; diff --git a/tests/flow/tests_onnx.py b/tests/flow/tests_onnx.py index be9a430db..dc933114b 100644 --- a/tests/flow/tests_onnx.py +++ b/tests/flow/tests_onnx.py @@ -505,7 +505,7 @@ def test_onnx_kill_switch_multiple_working_threads(): # Load another onnx model only on CPU (to test multiple devices when DEVICE = GPU model_pb = load_file_content('mnist.onnx') sample_raw = load_file_content('one.raw') - ret = con.execute_command('AI.MODELSTORE', 'm{1}', 'ONNX', 'CPU:1', 'BLOB', model_pb) + ret = con.execute_command('AI.MODELSTORE', 'm{1}', 'ONNX', 'CPU', 'BLOB', model_pb) env.assertEqual(ret, b'OK') con.execute_command('AI.TENSORSET', 'a{1}', 'FLOAT', 1, 1, 28, 28, 'BLOB', sample_raw) From ea3c174a42392c4380ebddbfacd56becfb625933 Mon Sep 17 00:00:00 2001 From: alonre24 Date: Wed, 2 Jun 2021 12:13:55 +0300 Subject: [PATCH 12/27] Some fixes --- src/execution/onnx_timeout.c | 4 +++- src/redisai.c | 1 + tests/flow/tests_onnx.py | 4 ++-- 3 files changed, 6 insertions(+), 3 deletions(-) diff --git a/src/execution/onnx_timeout.c b/src/execution/onnx_timeout.c index 147016480..771dd377a 100644 --- a/src/execution/onnx_timeout.c +++ b/src/execution/onnx_timeout.c @@ -37,6 +37,7 @@ int RAI_AddNewDeviceORT(const char *device_str) { OnnxRunSessionCtx *entry = RedisModule_Calloc(1, sizeof(OnnxRunSessionCtx)); run_sessions_array = array_append(run_sessions_array, entry); } + onnx_global_run_sessions->OnnxRunSessions = run_sessions_array; pthread_rwlock_unlock(&(onnx_global_run_sessions->rwlock)); return REDISMODULE_OK; } @@ -52,8 +53,9 @@ void RAI_EnforceTimeoutORT(RedisModuleCtx *ctx, RedisModuleEvent eid, uint64_t s continue; } long long curr_time = _mstime(); + long long timeout = RedisAI_OnnxTimeout(); // Check if a sessions is running for too long, and kill it if so. 
- if (curr_time - run_sessions_ctx[i]->queuingTime > RedisAI_OnnxTimeout()) { + if (curr_time - run_sessions_ctx[i]->queuingTime > timeout) { ort->RunOptionsSetTerminate(run_sessions_ctx[i]->runOptions); } } diff --git a/src/redisai.c b/src/redisai.c index 155ba6af3..6606ad5aa 100644 --- a/src/redisai.c +++ b/src/redisai.c @@ -1455,6 +1455,7 @@ int RedisModule_OnLoad(RedisModuleCtx *ctx, RedisModuleString **argv, int argc) setBackendsInterOpParallelism(REDISAI_DEFAULT_INTER_OP_PARALLELISM); setBackendsIntraOpParallelism(REDISAI_DEFAULT_INTRA_OP_PARALLELISM); setModelChunkSize(REDISAI_DEFAULT_MODEL_CHUNK_SIZE); + SetOnnxTimeout(ONNX_DEFAULT_MAX_RUNTIME); RAI_loadTimeConfig(ctx, argv, argc); diff --git a/tests/flow/tests_onnx.py b/tests/flow/tests_onnx.py index dc933114b..5f722506b 100644 --- a/tests/flow/tests_onnx.py +++ b/tests/flow/tests_onnx.py @@ -502,10 +502,10 @@ def test_onnx_kill_switch_multiple_working_threads(): con.execute_command('AI.TENSORSET', 'loop_input{1}', 'FLOAT', 1, 'VALUES', 42) con.execute_command('AI.TENSORSET', 'outer_scope_input{1}', 'FLOAT', 1, 'VALUES', 42) - # Load another onnx model only on CPU (to test multiple devices when DEVICE = GPU + # Load another onnx model as if it runs on a different device (to test existence of multiple queues) model_pb = load_file_content('mnist.onnx') sample_raw = load_file_content('one.raw') - ret = con.execute_command('AI.MODELSTORE', 'm{1}', 'ONNX', 'CPU', 'BLOB', model_pb) + ret = con.execute_command('AI.MODELSTORE', 'm{1}', 'ONNX', 'CPU:1', 'BLOB', model_pb) env.assertEqual(ret, b'OK') con.execute_command('AI.TENSORSET', 'a{1}', 'FLOAT', 1, 1, 28, 28, 'BLOB', sample_raw) From 4bbfbcd519a80c5be89b25c5ad70efe58cf8ba4a Mon Sep 17 00:00:00 2001 From: alonre24 Date: Sun, 6 Jun 2021 10:59:59 +0300 Subject: [PATCH 13/27] Remove debug print --- tests/flow/tests_common.py | 1 - 1 file changed, 1 deletion(-) diff --git a/tests/flow/tests_common.py b/tests/flow/tests_common.py index 027ada678..9ace58f08 100644 
--- a/tests/flow/tests_common.py +++ b/tests/flow/tests_common.py @@ -186,7 +186,6 @@ def test_common_tensorget(env): tested_datatypes_fp = ["FLOAT", "DOUBLE"] tested_datatypes_int = ["INT8", "INT16", "INT32", "INT64", "UINT8", "UINT16", "BOOL"] for datatype in tested_datatypes: - env.debugPrint(datatype, force=True) ret = con.execute_command('AI.TENSORSET', 'tensor_{0}'.format(datatype), datatype, 2, 'VALUES', 1, 1) env.assertEqual(ret, b'OK') From 4aed8cabd33fd366ecfbd129e0aec2c1cbe3eb20 Mon Sep 17 00:00:00 2001 From: alonre24 Date: Sun, 6 Jun 2021 15:46:50 +0300 Subject: [PATCH 14/27] Some fixes and documentation complement. --- src/CMakeLists.txt | 3 - src/execution/background_workers.c | 4 +- src/execution/background_workers.h | 28 +- src/execution/onnx_timeout.c | 1 - src/redisai.c | 2 +- src/util/rax.c | 2003 ---------------------------- src/util/rax.h | 218 --- src/util/rax_malloc.h | 43 - 8 files changed, 27 insertions(+), 2275 deletions(-) delete mode 100644 src/util/rax.c delete mode 100644 src/util/rax.h delete mode 100644 src/util/rax_malloc.h diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt index 5fd0bf0e4..b6e6dcbe9 100644 --- a/src/CMakeLists.txt +++ b/src/CMakeLists.txt @@ -30,7 +30,6 @@ ADD_LIBRARY(redisai_obj OBJECT util/dictionaries.c util/queue.c util/string_utils.c - util/rax.c redisai.c execution/command_parser.c execution/parsing/deprecated.c @@ -90,10 +89,8 @@ IF(BUILD_ORT) ADD_LIBRARY(redisai_onnxruntime_obj OBJECT backends/onnxruntime.c execution/onnx_timeout.c - util/rax.c ${BACKEND_COMMON_SRC} ) - SET_PROPERTY(TARGET redisai_onnxruntime_obj PROPERTY ENABLE_EXPORTS 1) ENDIF() INCLUDE_DIRECTORIES(${CMAKE_CURRENT_SOURCE_DIR}) diff --git a/src/execution/background_workers.c b/src/execution/background_workers.c index 9e802410c..8a195f6fd 100644 --- a/src/execution/background_workers.c +++ b/src/execution/background_workers.c @@ -103,7 +103,7 @@ bool IsRunQueueExists(const char *device_str) { } uintptr_t GetThreadId() { - return 
*(uintptr_t *)pthread_getspecific(thread_id_key); + return *(uintptr_t *)pthread_getspecific(ThreadIdKey); } long long GetNumThreadsPerQueue() { return ThreadPoolSizePerQueue; } @@ -130,7 +130,7 @@ static void _SaveThreadId() { uintptr_t *id_value = RedisModule_Alloc(sizeof(uintptr_t)); // Let the current thread have the next available id, and increase the counter. *id_value = __atomic_fetch_add(&BGWorkersCounter, 1, __ATOMIC_RELAXED); - pthread_setspecific(thread_id_key, id_value); + pthread_setspecific(ThreadIdKey, id_value); } /** diff --git a/src/execution/background_workers.h b/src/execution/background_workers.h index 8457d5016..2ae800f00 100644 --- a/src/execution/background_workers.h +++ b/src/execution/background_workers.h @@ -27,13 +27,12 @@ #include "redis_ai_objects/stats.h" #include "redis_ai_objects/tensor.h" #include "util/arr.h" -#include "util/rax.h" #include "util/queue.h" AI_dict *RunQueues; -long long ThreadPoolSizePerQueue; -uintptr_t BGWorkersCounter; -pthread_key_t thread_id_key; // A key for getting the thread id from its local storage. +long long ThreadPoolSizePerQueue; // Number of working threads for device. +uintptr_t BGWorkersCounter; // Total number of BG threads running currently. +pthread_key_t ThreadIdKey; // Holds the thread id in its local storage. typedef struct RunQueueInfo { pthread_mutex_t run_queue_mutex; @@ -43,14 +42,35 @@ typedef struct RunQueueInfo { char *device_str; } RunQueueInfo; + +/** + * @brief Terminate all working threads and free the run queue with its inner fields. + */ void RunQueueInfoFree(RunQueueInfo *info); +/** + * @brief Create a new run queue for a device. + */ RunQueueInfo *CreateRunQueue(const char *device_str); +/** + * @brief Return true if a ru queue exists for this particular device. + */ bool IsRunQueueExists(const char *device_str); +/** + * @brief Return the RunQueueInfo saved in the global RunQueues dict for a certain + * device name, or NULL if doesn't exist. 
+ */ RunQueueInfo *GetRunQueueInfo(const char *device_str); +/** + * @brief Return the thread id from its local storage by accessing the value + * saved under ThreadIdKey. + */ uintptr_t GetThreadId(void); +/** + * @brief Return the number of working threads per device in RedisAI. + */ long long GetNumThreadsPerQueue(void); diff --git a/src/execution/onnx_timeout.c b/src/execution/onnx_timeout.c index 771dd377a..f8978b988 100644 --- a/src/execution/onnx_timeout.c +++ b/src/execution/onnx_timeout.c @@ -2,7 +2,6 @@ #include "util/arr.h" #include #include -#include "util/rax.h" #include "util/string_utils.h" // Gets the current time in milliseconds. diff --git a/src/redisai.c b/src/redisai.c index 738c33a13..b8ded0dd9 100644 --- a/src/redisai.c +++ b/src/redisai.c @@ -1471,7 +1471,7 @@ int RedisModule_OnLoad(RedisModuleCtx *ctx, RedisModuleString **argv, int argc) RAI_loadTimeConfig(ctx, argv, argc); RunQueues = AI_dictCreate(&AI_dictTypeHeapStrings, NULL); - pthread_key_create(&thread_id_key, RedisModule_Free); + pthread_key_create(&ThreadIdKey, RedisModule_Free); RunQueueInfo *cpu_run_queue_info = CreateRunQueue("CPU"); if (cpu_run_queue_info == NULL) { RedisModule_Log(ctx, "warning", "RedisAI could not initialize run queue for CPU"); diff --git a/src/util/rax.c b/src/util/rax.c deleted file mode 100644 index 221258c58..000000000 --- a/src/util/rax.c +++ /dev/null @@ -1,2003 +0,0 @@ -/* Rax -- A radix tree implementation. - * - * Version 1.0 -- 14 November 2019 - * - * Copyright (c) 2017-2019, Salvatore Sanfilippo - * All rights reserved. - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * * Redistributions of source code must retain the above copyright notice, - * this list of conditions and the following disclaimer. 
- * * Redistributions in binary form must reproduce the above copyright - * notice, this list of conditions and the following disclaimer in the - * documentation and/or other materials provided with the distribution. - * * Neither the name of Redis nor the names of its contributors may be used - * to endorse or promote products derived from this software without - * specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE - * ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE - * LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR - * CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF - * SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS - * INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN - * CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) - * ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE - * POSSIBILITY OF SUCH DAMAGE. - */ - -#include -#include -#include -#include -#include -#include -#include "rax.h" - -#ifndef RAX_MALLOC_INCLUDE -#define RAX_MALLOC_INCLUDE "rax_malloc.h" -#endif - -#include RAX_MALLOC_INCLUDE - -/* This is a special pointer that is guaranteed to never have the same value - * of a radix tree node. It's used in order to report "not found" error without - * requiring the function to have multiple return values. */ -void *raxNotFound = (void *)"rax-not-found-pointer"; - -/* -------------------------------- Debugging ------------------------------ */ - -void raxDebugShowNode(const char *msg, raxNode *n); - -/* Turn debugging messages on/off by compiling with RAX_DEBUG_MSG macro on. 
- * When RAX_DEBUG_MSG is defined by default Rax operations will emit a lot - * of debugging info to the standard output, however you can still turn - * debugging on/off in order to enable it only when you suspect there is an - * operation causing a bug using the function raxSetDebugMsg(). */ -#ifdef RAX_DEBUG_MSG -#define debugf(...) \ - if (raxDebugMsg) { \ - printf("%s:%s:%d:\t", __FILE__, __FUNCTION__, __LINE__); \ - printf(__VA_ARGS__); \ - fflush(stdout); \ - } - -#define debugnode(msg, n) raxDebugShowNode(msg, n) -#else -#define debugf(...) -#define debugnode(msg, n) -#endif - -/* By default log debug info if RAX_DEBUG_MSG is defined. */ -static int raxDebugMsg = 1; - -/* When debug messages are enabled, turn them on/off dynamically. By - * default they are enabled. Set the state to 0 to disable, and 1 to - * re-enable. */ -void raxSetDebugMsg(int onoff) { raxDebugMsg = onoff; } - -/* ------------------------- raxStack functions -------------------------- - * The raxStack is a simple stack of pointers that is capable of switching - * from using a stack-allocated array to dynamic heap once a given number of - * items are reached. It is used in order to retain the list of parent nodes - * while walking the radix tree in order to implement certain operations that - * need to navigate the tree upward. - * ------------------------------------------------------------------------- */ - -/* Initialize the stack. */ -static inline void raxStackInit(raxStack *ts) { - ts->stack = ts->static_items; - ts->items = 0; - ts->maxitems = RAX_STACK_STATIC_ITEMS; - ts->oom = 0; -} - -/* Push an item into the stack, returns 1 on success, 0 on out of memory. 
*/ -static inline int raxStackPush(raxStack *ts, void *ptr) { - if (ts->items == ts->maxitems) { - if (ts->stack == ts->static_items) { - ts->stack = rax_malloc(sizeof(void *) * ts->maxitems * 2); - if (ts->stack == NULL) { - ts->stack = ts->static_items; - ts->oom = 1; - errno = ENOMEM; - return 0; - } - memcpy(ts->stack, ts->static_items, sizeof(void *) * ts->maxitems); - } else { - void **newalloc = rax_realloc(ts->stack, sizeof(void *) * ts->maxitems * 2); - if (newalloc == NULL) { - ts->oom = 1; - errno = ENOMEM; - return 0; - } - ts->stack = newalloc; - } - ts->maxitems *= 2; - } - ts->stack[ts->items] = ptr; - ts->items++; - return 1; -} - -/* Pop an item from the stack, the function returns NULL if there are no - * items to pop. */ -static inline void *raxStackPop(raxStack *ts) { - if (ts->items == 0) - return NULL; - ts->items--; - return ts->stack[ts->items]; -} - -/* Return the stack item at the top of the stack without actually consuming - * it. */ -static inline void *raxStackPeek(raxStack *ts) { - if (ts->items == 0) - return NULL; - return ts->stack[ts->items - 1]; -} - -/* Free the stack in case we used heap allocation. */ -static inline void raxStackFree(raxStack *ts) { - if (ts->stack != ts->static_items) - rax_free(ts->stack); -} - -/* ---------------------------------------------------------------------------- - * Radix tree implementation - * --------------------------------------------------------------------------*/ - -/* Return the padding needed in the characters section of a node having size - * 'nodesize'. The padding is needed to store the child pointers to aligned - * addresses. Note that we add 4 to the node size because the node has a four - * bytes header. */ -#define raxPadding(nodesize) \ - ((sizeof(void *) - ((nodesize + 4) % sizeof(void *))) & (sizeof(void *) - 1)) - -/* Return the pointer to the last child pointer in a node. For the compressed - * nodes this is the only child pointer. 
*/ -#define raxNodeLastChildPtr(n) \ - ((raxNode **)(((char *)(n)) + raxNodeCurrentLength(n) - sizeof(raxNode *) - \ - (((n)->iskey && !(n)->isnull) ? sizeof(void *) : 0))) - -/* Return the pointer to the first child pointer. */ -#define raxNodeFirstChildPtr(n) ((raxNode **)((n)->data + (n)->size + raxPadding((n)->size))) - -/* Return the current total size of the node. Note that the second line - * computes the padding after the string of characters, needed in order to - * save pointers to aligned addresses. */ -#define raxNodeCurrentLength(n) \ - (sizeof(raxNode) + (n)->size + raxPadding((n)->size) + \ - ((n)->iscompr ? sizeof(raxNode *) : sizeof(raxNode *) * (n)->size) + \ - (((n)->iskey && !(n)->isnull) * sizeof(void *))) - -/* Allocate a new non compressed node with the specified number of children. - * If datafiled is true, the allocation is made large enough to hold the - * associated data pointer. - * Returns the new node pointer. On out of memory NULL is returned. */ -raxNode *raxNewNode(size_t children, int datafield) { - size_t nodesize = - sizeof(raxNode) + children + raxPadding(children) + sizeof(raxNode *) * children; - if (datafield) - nodesize += sizeof(void *); - raxNode *node = rax_malloc(nodesize); - if (node == NULL) - return NULL; - node->iskey = 0; - node->isnull = 0; - node->iscompr = 0; - node->size = children; - return node; -} - -/* Allocate a new rax and return its pointer. On out of memory the function - * returns NULL. */ -rax *raxNew(void) { - rax *rax = rax_malloc(sizeof(*rax)); - if (rax == NULL) - return NULL; - rax->numele = 0; - rax->numnodes = 1; - rax->head = raxNewNode(0, 0); - if (rax->head == NULL) { - rax_free(rax); - return NULL; - } else { - return rax; - } -} - -/* realloc the node to make room for auxiliary data in order - * to store an item in that node. On out of memory NULL is returned. 
*/ -raxNode *raxReallocForData(raxNode *n, void *data) { - if (data == NULL) - return n; /* No reallocation needed, setting isnull=1 */ - size_t curlen = raxNodeCurrentLength(n); - return rax_realloc(n, curlen + sizeof(void *)); -} - -/* Set the node auxiliary data to the specified pointer. */ -void raxSetData(raxNode *n, void *data) { - n->iskey = 1; - if (data != NULL) { - n->isnull = 0; - void **ndata = (void **)((char *)n + raxNodeCurrentLength(n) - sizeof(void *)); - memcpy(ndata, &data, sizeof(data)); - } else { - n->isnull = 1; - } -} - -/* Get the node auxiliary data. */ -void *raxGetData(raxNode *n) { - if (n->isnull) - return NULL; - void **ndata = (void **)((char *)n + raxNodeCurrentLength(n) - sizeof(void *)); - void *data; - memcpy(&data, ndata, sizeof(data)); - return data; -} - -/* Add a new child to the node 'n' representing the character 'c' and return - * its new pointer, as well as the child pointer by reference. Additionally - * '***parentlink' is populated with the raxNode pointer-to-pointer of where - * the new child was stored, which is useful for the caller to replace the - * child pointer if it gets reallocated. - * - * On success the new parent node pointer is returned (it may change because - * of the realloc, so the caller should discard 'n' and use the new value). - * On out of memory NULL is returned, and the old node is still valid. */ -raxNode *raxAddChild(raxNode *n, unsigned char c, raxNode **childptr, raxNode ***parentlink) { - assert(n->iscompr == 0); - - size_t curlen = raxNodeCurrentLength(n); - n->size++; - size_t newlen = raxNodeCurrentLength(n); - n->size--; /* For now restore the orignal size. We'll update it only on - success at the end. */ - - /* Alloc the new child we will link to 'n'. */ - raxNode *child = raxNewNode(0, 0); - if (child == NULL) - return NULL; - - /* Make space in the original node. 
*/ - raxNode *newn = rax_realloc(n, newlen); - if (newn == NULL) { - rax_free(child); - return NULL; - } - n = newn; - - /* After the reallocation, we have up to 8/16 (depending on the system - * pointer size, and the required node padding) bytes at the end, that is, - * the additional char in the 'data' section, plus one pointer to the new - * child, plus the padding needed in order to store addresses into aligned - * locations. - * - * So if we start with the following node, having "abde" edges. - * - * Note: - * - We assume 4 bytes pointer for simplicity. - * - Each space below corresponds to one byte - * - * [HDR*][abde][Aptr][Bptr][Dptr][Eptr]|AUXP| - * - * After the reallocation we need: 1 byte for the new edge character - * plus 4 bytes for a new child pointer (assuming 32 bit machine). - * However after adding 1 byte to the edge char, the header + the edge - * characters are no longer aligned, so we also need 3 bytes of padding. - * In total the reallocation will add 1+4+3 bytes = 8 bytes: - * - * (Blank bytes are represented by ".") - * - * [HDR*][abde][Aptr][Bptr][Dptr][Eptr]|AUXP|[....][....] - * - * Let's find where to insert the new child in order to make sure - * it is inserted in-place lexicographically. Assuming we are adding - * a child "c" in our case pos will be = 2 after the end of the following - * loop. */ - int pos; - for (pos = 0; pos < n->size; pos++) { - if (n->data[pos] > c) - break; - } - - /* Now, if present, move auxiliary data pointer at the end - * so that we can mess with the other data without overwriting it. 
- * We will obtain something like that: - * - * [HDR*][abde][Aptr][Bptr][Dptr][Eptr][....][....]|AUXP| - */ - unsigned char *src, *dst; - if (n->iskey && !n->isnull) { - src = ((unsigned char *)n + curlen - sizeof(void *)); - dst = ((unsigned char *)n + newlen - sizeof(void *)); - memmove(dst, src, sizeof(void *)); - } - - /* Compute the "shift", that is, how many bytes we need to move the - * pointers section forward because of the addition of the new child - * byte in the string section. Note that if we had no padding, that - * would be always "1", since we are adding a single byte in the string - * section of the node (where now there is "abde" basically). - * - * However we have padding, so it could be zero, or up to 8. - * - * Another way to think at the shift is, how many bytes we need to - * move child pointers forward *other than* the obvious sizeof(void*) - * needed for the additional pointer itself. */ - size_t shift = newlen - curlen - sizeof(void *); - - /* We said we are adding a node with edge 'c'. The insertion - * point is between 'b' and 'd', so the 'pos' variable value is - * the index of the first child pointer that we need to move forward - * to make space for our new pointer. - * - * To start, move all the child pointers after the insertion point - * of shift+sizeof(pointer) bytes on the right, to obtain: - * - * [HDR*][abde][Aptr][Bptr][....][....][Dptr][Eptr]|AUXP| - */ - src = n->data + n->size + raxPadding(n->size) + sizeof(raxNode *) * pos; - memmove(src + shift + sizeof(raxNode *), src, sizeof(raxNode *) * (n->size - pos)); - - /* Move the pointers to the left of the insertion position as well. Often - * we don't need to do anything if there was already some padding to use. In - * that case the final destination of the pointers will be the same, however - * in our example there was no pre-existing padding, so we added one byte - * plus thre bytes of padding. 
After the next memmove() things will look - * like thata: - * - * [HDR*][abde][....][Aptr][Bptr][....][Dptr][Eptr]|AUXP| - */ - if (shift) { - src = (unsigned char *)raxNodeFirstChildPtr(n); - memmove(src + shift, src, sizeof(raxNode *) * pos); - } - - /* Now make the space for the additional char in the data section, - * but also move the pointers before the insertion point to the right - * by shift bytes, in order to obtain the following: - * - * [HDR*][ab.d][e...][Aptr][Bptr][....][Dptr][Eptr]|AUXP| - */ - src = n->data + pos; - memmove(src + 1, src, n->size - pos); - - /* We can now set the character and its child node pointer to get: - * - * [HDR*][abcd][e...][Aptr][Bptr][....][Dptr][Eptr]|AUXP| - * [HDR*][abcd][e...][Aptr][Bptr][Cptr][Dptr][Eptr]|AUXP| - */ - n->data[pos] = c; - n->size++; - src = (unsigned char *)raxNodeFirstChildPtr(n); - raxNode **childfield = (raxNode **)(src + sizeof(raxNode *) * pos); - memcpy(childfield, &child, sizeof(child)); - *childptr = child; - *parentlink = childfield; - return n; -} - -/* Turn the node 'n', that must be a node without any children, into a - * compressed node representing a set of nodes linked one after the other - * and having exactly one child each. The node can be a key or not: this - * property and the associated value if any will be preserved. - * - * The function also returns a child node, since the last node of the - * compressed chain cannot be part of the chain: it has zero children while - * we can only compress inner nodes with exactly one child each. */ -raxNode *raxCompressNode(raxNode *n, unsigned char *s, size_t len, raxNode **child) { - assert(n->size == 0 && n->iscompr == 0); - void *data = NULL; /* Initialized only to avoid warnings. */ - size_t newsize; - - debugf("Compress node: %.*s\n", (int)len, s); - - /* Allocate the child to link to this node. */ - *child = raxNewNode(0, 0); - if (*child == NULL) - return NULL; - - /* Make space in the parent node. 
*/ - newsize = sizeof(raxNode) + len + raxPadding(len) + sizeof(raxNode *); - if (n->iskey) { - data = raxGetData(n); /* To restore it later. */ - if (!n->isnull) - newsize += sizeof(void *); - } - raxNode *newn = rax_realloc(n, newsize); - if (newn == NULL) { - rax_free(*child); - return NULL; - } - n = newn; - - n->iscompr = 1; - n->size = len; - memcpy(n->data, s, len); - if (n->iskey) - raxSetData(n, data); - raxNode **childfield = raxNodeLastChildPtr(n); - memcpy(childfield, child, sizeof(*child)); - return n; -} - -/* Low level function that walks the tree looking for the string - * 's' of 'len' bytes. The function returns the number of characters - * of the key that was possible to process: if the returned integer - * is the same as 'len', then it means that the node corresponding to the - * string was found (however it may not be a key in case the node->iskey is - * zero or if simply we stopped in the middle of a compressed node, so that - * 'splitpos' is non zero). - * - * Otherwise if the returned integer is not the same as 'len', there was an - * early stop during the tree walk because of a character mismatch. - * - * The node where the search ended (because the full string was processed - * or because there was an early stop) is returned by reference as - * '*stopnode' if the passed pointer is not NULL. This node link in the - * parent's node is returned as '*plink' if not NULL. Finally, if the - * search stopped in a compressed node, '*splitpos' returns the index - * inside the compressed node where the search ended. This is useful to - * know where to split the node for insertion. - * - * Note that when we stop in the middle of a compressed node with - * a perfect match, this function will return a length equal to the - * 'len' argument (all the key matched), and will return a *splitpos which is - * always positive (that will represent the index of the character immediately - * *after* the last match in the current compressed node). 
- * - * When instead we stop at a compressed node and *splitpos is zero, it - * means that the current node represents the key (that is, none of the - * compressed node characters are needed to represent the key, just all - * its parents nodes). */ -static inline size_t raxLowWalk(rax *rax, unsigned char *s, size_t len, raxNode **stopnode, - raxNode ***plink, int *splitpos, raxStack *ts) { - raxNode *h = rax->head; - raxNode **parentlink = &rax->head; - - size_t i = 0; /* Position in the string. */ - size_t j = 0; /* Position in the node children (or bytes if compressed).*/ - while (h->size && i < len) { - debugnode("Lookup current node", h); - unsigned char *v = h->data; - - if (h->iscompr) { - for (j = 0; j < h->size && i < len; j++, i++) { - if (v[j] != s[i]) - break; - } - if (j != h->size) - break; - } else { - /* Even when h->size is large, linear scan provides good - * performances compared to other approaches that are in theory - * more sounding, like performing a binary search. */ - for (j = 0; j < h->size; j++) { - if (v[j] == s[i]) - break; - } - if (j == h->size) - break; - i++; - } - - if (ts) - raxStackPush(ts, h); /* Save stack of parent nodes. */ - raxNode **children = raxNodeFirstChildPtr(h); - if (h->iscompr) - j = 0; /* Compressed node only child is at index 0. */ - memcpy(&h, children + j, sizeof(h)); - parentlink = children + j; - j = 0; /* If the new node is compressed and we do not - iterate again (since i == l) set the split - position to 0 to signal this node represents - the searched key. */ - } - debugnode("Lookup stop node is", h); - if (stopnode) - *stopnode = h; - if (plink) - *plink = parentlink; - if (splitpos && h->iscompr) - *splitpos = j; - return i; -} - -/* Insert the element 's' of size 'len', setting as auxiliary data - * the pointer 'data'. If the element is already present, the associated - * data is updated (only if 'overwrite' is set to 1), and 0 is returned, - * otherwise the element is inserted and 1 is returned. 
On out of memory the - * function returns 0 as well but sets errno to ENOMEM, otherwise errno will - * be set to 0. - */ -int raxGenericInsert(rax *rax, unsigned char *s, size_t len, void *data, void **old, - int overwrite) { - size_t i; - int j = 0; /* Split position. If raxLowWalk() stops in a compressed - node, the index 'j' represents the char we stopped within the - compressed node, that is, the position where to split the - node for insertion. */ - raxNode *h, **parentlink; - - debugf("### Insert %.*s with value %p\n", (int)len, s, data); - i = raxLowWalk(rax, s, len, &h, &parentlink, &j, NULL); - - /* If i == len we walked following the whole string. If we are not - * in the middle of a compressed node, the string is either already - * inserted or this middle node is currently not a key, but can represent - * our key. We have just to reallocate the node and make space for the - * data pointer. */ - if (i == len && (!h->iscompr || j == 0 /* not in the middle if j is 0 */)) { - debugf("### Insert: node representing key exists\n"); - /* Make space for the value pointer if needed. */ - if (!h->iskey || (h->isnull && overwrite)) { - h = raxReallocForData(h, data); - if (h) - memcpy(parentlink, &h, sizeof(h)); - } - if (h == NULL) { - errno = ENOMEM; - return 0; - } - - /* Update the existing key if there is already one. */ - if (h->iskey) { - if (old) - *old = raxGetData(h); - if (overwrite) - raxSetData(h, data); - errno = 0; - return 0; /* Element already exists. */ - } - - /* Otherwise set the node as a key. Note that raxSetData() - * will set h->iskey. */ - raxSetData(h, data); - rax->numele++; - return 1; /* Element inserted. */ - } - - /* If the node we stopped at is a compressed node, we need to - * split it before to continue. - * - * Splitting a compressed node have a few possible cases. 
- * Imagine that the node 'h' we are currently at is a compressed - * node contaning the string "ANNIBALE" (it means that it represents - * nodes A -> N -> N -> I -> B -> A -> L -> E with the only child - * pointer of this node pointing at the 'E' node, because remember that - * we have characters at the edges of the graph, not inside the nodes - * themselves. - * - * In order to show a real case imagine our node to also point to - * another compressed node, that finally points at the node without - * children, representing 'O': - * - * "ANNIBALE" -> "SCO" -> [] - * - * When inserting we may face the following cases. Note that all the cases - * require the insertion of a non compressed node with exactly two - * children, except for the last case which just requires splitting a - * compressed node. - * - * 1) Inserting "ANNIENTARE" - * - * |B| -> "ALE" -> "SCO" -> [] - * "ANNI" -> |-| - * |E| -> (... continue algo ...) "NTARE" -> [] - * - * 2) Inserting "ANNIBALI" - * - * |E| -> "SCO" -> [] - * "ANNIBAL" -> |-| - * |I| -> (... continue algo ...) [] - * - * 3) Inserting "AGO" (Like case 1, but set iscompr = 0 into original node) - * - * |N| -> "NIBALE" -> "SCO" -> [] - * |A| -> |-| - * |G| -> (... continue algo ...) |O| -> [] - * - * 4) Inserting "CIAO" - * - * |A| -> "NNIBALE" -> "SCO" -> [] - * |-| - * |C| -> (... continue algo ...) "IAO" -> [] - * - * 5) Inserting "ANNI" - * - * "ANNI" -> "BALE" -> "SCO" -> [] - * - * The final algorithm for insertion covering all the above cases is as - * follows. - * - * ============================= ALGO 1 ============================= - * - * For the above cases 1 to 4, that is, all cases where we stopped in - * the middle of a compressed node for a character mismatch, do: - * - * Let $SPLITPOS be the zero-based index at which, in the - * compressed node array of characters, we found the mismatching - * character. 
For example if the node contains "ANNIBALE" and we add - * "ANNIENTARE" the $SPLITPOS is 4, that is, the index at which the - * mismatching character is found. - * - * 1. Save the current compressed node $NEXT pointer (the pointer to the - * child element, that is always present in compressed nodes). - * - * 2. Create "split node" having as child the non common letter - * at the compressed node. The other non common letter (at the key) - * will be added later as we continue the normal insertion algorithm - * at step "6". - * - * 3a. IF $SPLITPOS == 0: - * Replace the old node with the split node, by copying the auxiliary - * data if any. Fix parent's reference. Free old node eventually - * (we still need its data for the next steps of the algorithm). - * - * 3b. IF $SPLITPOS != 0: - * Trim the compressed node (reallocating it as well) in order to - * contain $splitpos characters. Change chilid pointer in order to link - * to the split node. If new compressed node len is just 1, set - * iscompr to 0 (layout is the same). Fix parent's reference. - * - * 4a. IF the postfix len (the length of the remaining string of the - * original compressed node after the split character) is non zero, - * create a "postfix node". If the postfix node has just one character - * set iscompr to 0, otherwise iscompr to 1. Set the postfix node - * child pointer to $NEXT. - * - * 4b. IF the postfix len is zero, just use $NEXT as postfix pointer. - * - * 5. Set child[0] of split node to postfix node. - * - * 6. Set the split node as the current node, set current index at child[1] - * and continue insertion algorithm as usually. - * - * ============================= ALGO 2 ============================= - * - * For case 5, that is, if we stopped in the middle of a compressed - * node but no mismatch was found, do: - * - * Let $SPLITPOS be the zero-based index at which, in the - * compressed node array of characters, we stopped iterating because - * there were no more keys character to match. 
So in the example of - * the node "ANNIBALE", addig the string "ANNI", the $SPLITPOS is 4. - * - * 1. Save the current compressed node $NEXT pointer (the pointer to the - * child element, that is always present in compressed nodes). - * - * 2. Create a "postfix node" containing all the characters from $SPLITPOS - * to the end. Use $NEXT as the postfix node child pointer. - * If the postfix node length is 1, set iscompr to 0. - * Set the node as a key with the associated value of the new - * inserted key. - * - * 3. Trim the current node to contain the first $SPLITPOS characters. - * As usually if the new node length is just 1, set iscompr to 0. - * Take the iskey / associated value as it was in the orignal node. - * Fix the parent's reference. - * - * 4. Set the postfix node as the only child pointer of the trimmed - * node created at step 1. - */ - - /* ------------------------- ALGORITHM 1 --------------------------- */ - if (h->iscompr && i != len) { - debugf("ALGO 1: Stopped at compressed node %.*s (%p)\n", h->size, h->data, (void *)h); - debugf("Still to insert: %.*s\n", (int)(len - i), s + i); - debugf("Splitting at %d: '%c'\n", j, ((char *)h->data)[j]); - debugf("Other (key) letter is '%c'\n", s[i]); - - /* 1: Save next pointer. */ - raxNode **childfield = raxNodeLastChildPtr(h); - raxNode *next; - memcpy(&next, childfield, sizeof(next)); - debugf("Next is %p\n", (void *)next); - debugf("iskey %d\n", h->iskey); - if (h->iskey) { - debugf("key value is %p\n", raxGetData(h)); - } - - /* Set the length of the additional nodes we will need. */ - size_t trimmedlen = j; - size_t postfixlen = h->size - j - 1; - int split_node_is_key = !trimmedlen && h->iskey && !h->isnull; - size_t nodesize; - - /* 2: Create the split node. Also allocate the other nodes we'll need - * ASAP, so that it will be simpler to handle OOM. 
*/ - raxNode *splitnode = raxNewNode(1, split_node_is_key); - raxNode *trimmed = NULL; - raxNode *postfix = NULL; - - if (trimmedlen) { - nodesize = sizeof(raxNode) + trimmedlen + raxPadding(trimmedlen) + sizeof(raxNode *); - if (h->iskey && !h->isnull) - nodesize += sizeof(void *); - trimmed = rax_malloc(nodesize); - } - - if (postfixlen) { - nodesize = sizeof(raxNode) + postfixlen + raxPadding(postfixlen) + sizeof(raxNode *); - postfix = rax_malloc(nodesize); - } - - /* OOM? Abort now that the tree is untouched. */ - if (splitnode == NULL || (trimmedlen && trimmed == NULL) || - (postfixlen && postfix == NULL)) { - rax_free(splitnode); - rax_free(trimmed); - rax_free(postfix); - errno = ENOMEM; - return 0; - } - splitnode->data[0] = h->data[j]; - - if (j == 0) { - /* 3a: Replace the old node with the split node. */ - if (h->iskey) { - void *ndata = raxGetData(h); - raxSetData(splitnode, ndata); - } - memcpy(parentlink, &splitnode, sizeof(splitnode)); - } else { - /* 3b: Trim the compressed node. */ - trimmed->size = j; - memcpy(trimmed->data, h->data, j); - trimmed->iscompr = j > 1 ? 1 : 0; - trimmed->iskey = h->iskey; - trimmed->isnull = h->isnull; - if (h->iskey && !h->isnull) { - void *ndata = raxGetData(h); - raxSetData(trimmed, ndata); - } - raxNode **cp = raxNodeLastChildPtr(trimmed); - memcpy(cp, &splitnode, sizeof(splitnode)); - memcpy(parentlink, &trimmed, sizeof(trimmed)); - parentlink = cp; /* Set parentlink to splitnode parent. */ - rax->numnodes++; - } - - /* 4: Create the postfix node: what remains of the original - * compressed node after the split. */ - if (postfixlen) { - /* 4a: create a postfix node. */ - postfix->iskey = 0; - postfix->isnull = 0; - postfix->size = postfixlen; - postfix->iscompr = postfixlen > 1; - memcpy(postfix->data, h->data + j + 1, postfixlen); - raxNode **cp = raxNodeLastChildPtr(postfix); - memcpy(cp, &next, sizeof(next)); - rax->numnodes++; - } else { - /* 4b: just use next as postfix node. 
*/ - postfix = next; - } - - /* 5: Set splitnode first child as the postfix node. */ - raxNode **splitchild = raxNodeLastChildPtr(splitnode); - memcpy(splitchild, &postfix, sizeof(postfix)); - - /* 6. Continue insertion: this will cause the splitnode to - * get a new child (the non common character at the currently - * inserted key). */ - rax_free(h); - h = splitnode; - } else if (h->iscompr && i == len) { - /* ------------------------- ALGORITHM 2 --------------------------- */ - debugf("ALGO 2: Stopped at compressed node %.*s (%p) j = %d\n", h->size, h->data, (void *)h, - j); - - /* Allocate postfix & trimmed nodes ASAP to fail for OOM gracefully. */ - size_t postfixlen = h->size - j; - size_t nodesize = sizeof(raxNode) + postfixlen + raxPadding(postfixlen) + sizeof(raxNode *); - if (data != NULL) - nodesize += sizeof(void *); - raxNode *postfix = rax_malloc(nodesize); - - nodesize = sizeof(raxNode) + j + raxPadding(j) + sizeof(raxNode *); - if (h->iskey && !h->isnull) - nodesize += sizeof(void *); - raxNode *trimmed = rax_malloc(nodesize); - - if (postfix == NULL || trimmed == NULL) { - rax_free(postfix); - rax_free(trimmed); - errno = ENOMEM; - return 0; - } - - /* 1: Save next pointer. */ - raxNode **childfield = raxNodeLastChildPtr(h); - raxNode *next; - memcpy(&next, childfield, sizeof(next)); - - /* 2: Create the postfix node. */ - postfix->size = postfixlen; - postfix->iscompr = postfixlen > 1; - postfix->iskey = 1; - postfix->isnull = 0; - memcpy(postfix->data, h->data + j, postfixlen); - raxSetData(postfix, data); - raxNode **cp = raxNodeLastChildPtr(postfix); - memcpy(cp, &next, sizeof(next)); - rax->numnodes++; - - /* 3: Trim the compressed node. 
*/ - trimmed->size = j; - trimmed->iscompr = j > 1; - trimmed->iskey = 0; - trimmed->isnull = 0; - memcpy(trimmed->data, h->data, j); - memcpy(parentlink, &trimmed, sizeof(trimmed)); - if (h->iskey) { - void *aux = raxGetData(h); - raxSetData(trimmed, aux); - } - - /* Fix the trimmed node child pointer to point to - * the postfix node. */ - cp = raxNodeLastChildPtr(trimmed); - memcpy(cp, &postfix, sizeof(postfix)); - - /* Finish! We don't need to continue with the insertion - * algorithm for ALGO 2. The key is already inserted. */ - rax->numele++; - rax_free(h); - return 1; /* Key inserted. */ - } - - /* We walked the radix tree as far as we could, but still there are left - * chars in our string. We need to insert the missing nodes. */ - while (i < len) { - raxNode *child; - - /* If this node is going to have a single child, and there - * are other characters, so that that would result in a chain - * of single-childed nodes, turn it into a compressed node. */ - if (h->size == 0 && len - i > 1) { - debugf("Inserting compressed node\n"); - size_t comprsize = len - i; - if (comprsize > RAX_NODE_MAX_SIZE) - comprsize = RAX_NODE_MAX_SIZE; - raxNode *newh = raxCompressNode(h, s + i, comprsize, &child); - if (newh == NULL) - goto oom; - h = newh; - memcpy(parentlink, &h, sizeof(h)); - parentlink = raxNodeLastChildPtr(h); - i += comprsize; - } else { - debugf("Inserting normal node\n"); - raxNode **new_parentlink; - raxNode *newh = raxAddChild(h, s[i], &child, &new_parentlink); - if (newh == NULL) - goto oom; - h = newh; - memcpy(parentlink, &h, sizeof(h)); - parentlink = new_parentlink; - i++; - } - rax->numnodes++; - h = child; - } - raxNode *newh = raxReallocForData(h, data); - if (newh == NULL) - goto oom; - h = newh; - if (!h->iskey) - rax->numele++; - raxSetData(h, data); - memcpy(parentlink, &h, sizeof(h)); - return 1; /* Element inserted. */ - -oom: - /* This code path handles out of memory after part of the sub-tree was - * already modified. 
Set the node as a key, and then remove it. However we - * do that only if the node is a terminal node, otherwise if the OOM - * happened reallocating a node in the middle, we don't need to free - * anything. */ - if (h->size == 0) { - h->isnull = 1; - h->iskey = 1; - rax->numele++; /* Compensate the next remove. */ - assert(raxRemove(rax, s, i, NULL) != 0); - } - errno = ENOMEM; - return 0; -} - -/* Overwriting insert. Just a wrapper for raxGenericInsert() that will - * update the element if there is already one for the same key. */ -int raxInsert(rax *rax, unsigned char *s, size_t len, void *data, void **old) { - return raxGenericInsert(rax, s, len, data, old, 1); -} - -/* Non overwriting insert function: this if an element with the same key - * exists, the value is not updated and the function returns 0. - * This is a just a wrapper for raxGenericInsert(). */ -int raxTryInsert(rax *rax, unsigned char *s, size_t len, void *data, void **old) { - return raxGenericInsert(rax, s, len, data, old, 0); -} - -/* Find a key in the rax, returns raxNotFound special void pointer value - * if the item was not found, otherwise the value associated with the - * item is returned. */ -void *raxFind(rax *rax, unsigned char *s, size_t len) { - raxNode *h; - - debugf("### Lookup: %.*s\n", (int)len, s); - int splitpos = 0; - size_t i = raxLowWalk(rax, s, len, &h, NULL, &splitpos, NULL); - if (i != len || (h->iscompr && splitpos != 0) || !h->iskey) - return raxNotFound; - return raxGetData(h); -} - -/* Return the memory address where the 'parent' node stores the specified - * 'child' pointer, so that the caller can update the pointer with another - * one if needed. The function assumes it will find a match, otherwise the - * operation is an undefined behavior (it will continue scanning the - * memory without any bound checking). 
*/ -raxNode **raxFindParentLink(raxNode *parent, raxNode *child) { - raxNode **cp = raxNodeFirstChildPtr(parent); - raxNode *c; - while (1) { - memcpy(&c, cp, sizeof(c)); - if (c == child) - break; - cp++; - } - return cp; -} - -/* Low level child removal from node. The new node pointer (after the child - * removal) is returned. Note that this function does not fix the pointer - * of the parent node in its parent, so this task is up to the caller. - * The function never fails for out of memory. */ -raxNode *raxRemoveChild(raxNode *parent, raxNode *child) { - debugnode("raxRemoveChild before", parent); - /* If parent is a compressed node (having a single child, as for definition - * of the data structure), the removal of the child consists into turning - * it into a normal node without children. */ - if (parent->iscompr) { - void *data = NULL; - if (parent->iskey) - data = raxGetData(parent); - parent->isnull = 0; - parent->iscompr = 0; - parent->size = 0; - if (parent->iskey) - raxSetData(parent, data); - debugnode("raxRemoveChild after", parent); - return parent; - } - - /* Otherwise we need to scan for the child pointer and memmove() - * accordingly. - * - * 1. To start we seek the first element in both the children - * pointers and edge bytes in the node. */ - raxNode **cp = raxNodeFirstChildPtr(parent); - raxNode **c = cp; - unsigned char *e = parent->data; - - /* 2. Search the child pointer to remove inside the array of children - * pointers. */ - while (1) { - raxNode *aux; - memcpy(&aux, c, sizeof(aux)); - if (aux == child) - break; - c++; - e++; - } - - /* 3. Remove the edge and the pointer by memmoving the remaining children - * pointer and edge bytes one position before. 
*/ - int taillen = parent->size - (e - parent->data) - 1; - debugf("raxRemoveChild tail len: %d\n", taillen); - memmove(e, e + 1, taillen); - - /* Compute the shift, that is the amount of bytes we should move our - * child pointers to the left, since the removal of one edge character - * and the corresponding padding change, may change the layout. - * We just check if in the old version of the node there was at the - * end just a single byte and all padding: in that case removing one char - * will remove a whole sizeof(void*) word. */ - size_t shift = ((parent->size + 4) % sizeof(void *)) == 1 ? sizeof(void *) : 0; - - /* Move the children pointers before the deletion point. */ - if (shift) - memmove(((char *)cp) - shift, cp, (parent->size - taillen - 1) * sizeof(raxNode **)); - - /* Move the remaining "tail" pointers at the right position as well. */ - size_t valuelen = (parent->iskey && !parent->isnull) ? sizeof(void *) : 0; - memmove(((char *)c) - shift, c + 1, taillen * sizeof(raxNode **) + valuelen); - - /* 4. Update size. */ - parent->size--; - - /* realloc the node according to the theoretical memory usage, to free - * data if we are over-allocating right now. */ - raxNode *newnode = rax_realloc(parent, raxNodeCurrentLength(parent)); - if (newnode) { - debugnode("raxRemoveChild after", newnode); - } - /* Note: if rax_realloc() fails we just return the old address, which - * is valid. */ - return newnode ? newnode : parent; -} - -/* Remove the specified item. Returns 1 if the item was found and - * deleted, 0 otherwise. 
*/ -int raxRemove(rax *rax, unsigned char *s, size_t len, void **old) { - raxNode *h; - raxStack ts; - - debugf("### Delete: %.*s\n", (int)len, s); - raxStackInit(&ts); - int splitpos = 0; - size_t i = raxLowWalk(rax, s, len, &h, NULL, &splitpos, &ts); - if (i != len || (h->iscompr && splitpos != 0) || !h->iskey) { - raxStackFree(&ts); - return 0; - } - if (old) - *old = raxGetData(h); - h->iskey = 0; - rax->numele--; - - /* If this node has no children, the deletion needs to reclaim the - * no longer used nodes. This is an iterative process that needs to - * walk the three upward, deleting all the nodes with just one child - * that are not keys, until the head of the rax is reached or the first - * node with more than one child is found. */ - - int trycompress = 0; /* Will be set to 1 if we should try to optimize the - tree resulting from the deletion. */ - - if (h->size == 0) { - debugf("Key deleted in node without children. Cleanup needed.\n"); - raxNode *child = NULL; - while (h != rax->head) { - child = h; - debugf("Freeing child %p [%.*s] key:%d\n", (void *)child, (int)child->size, - (char *)child->data, child->iskey); - rax_free(child); - rax->numnodes--; - h = raxStackPop(&ts); - /* If this node has more then one child, or actually holds - * a key, stop here. */ - if (h->iskey || (!h->iscompr && h->size != 1)) - break; - } - if (child) { - debugf("Unlinking child %p from parent %p\n", (void *)child, (void *)h); - raxNode *new = raxRemoveChild(h, child); - if (new != h) { - raxNode *parent = raxStackPeek(&ts); - raxNode **parentlink; - if (parent == NULL) { - parentlink = &rax->head; - } else { - parentlink = raxFindParentLink(parent, h); - } - memcpy(parentlink, &new, sizeof(new)); - } - - /* If after the removal the node has just a single child - * and is not a key, we need to try to compress it. 
*/ - if (new->size == 1 && new->iskey == 0) { - trycompress = 1; - h = new; - } - } - } else if (h->size == 1) { - /* If the node had just one child, after the removal of the key - * further compression with adjacent nodes is pontentially possible. */ - trycompress = 1; - } - - /* Don't try node compression if our nodes pointers stack is not - * complete because of OOM while executing raxLowWalk() */ - if (trycompress && ts.oom) - trycompress = 0; - - /* Recompression: if trycompress is true, 'h' points to a radix tree node - * that changed in a way that could allow to compress nodes in this - * sub-branch. Compressed nodes represent chains of nodes that are not - * keys and have a single child, so there are two deletion events that - * may alter the tree so that further compression is needed: - * - * 1) A node with a single child was a key and now no longer is a key. - * 2) A node with two children now has just one child. - * - * We try to navigate upward till there are other nodes that can be - * compressed, when we reach the upper node which is not a key and has - * a single child, we scan the chain of children to collect the - * compressable part of the tree, and replace the current node with the - * new one, fixing the child pointer to reference the first non - * compressable node. - * - * Example of case "1". A tree stores the keys "FOO" = 1 and - * "FOOBAR" = 2: - * - * - * "FOO" -> "BAR" -> [] (2) - * (1) - * - * After the removal of "FOO" the tree can be compressed as: - * - * "FOOBAR" -> [] (2) - * - * - * Example of case "2". 
A tree stores the keys "FOOBAR" = 1 and - * "FOOTER" = 2: - * - * |B| -> "AR" -> [] (1) - * "FOO" -> |-| - * |T| -> "ER" -> [] (2) - * - * After the removal of "FOOTER" the resulting tree is: - * - * "FOO" -> |B| -> "AR" -> [] (1) - * - * That can be compressed into: - * - * "FOOBAR" -> [] (1) - */ - if (trycompress) { - debugf("After removing %.*s:\n", (int)len, s); - debugnode("Compression may be needed", h); - debugf("Seek start node\n"); - - /* Try to reach the upper node that is compressible. - * At the end of the loop 'h' will point to the first node we - * can try to compress and 'parent' to its parent. */ - raxNode *parent; - while (1) { - parent = raxStackPop(&ts); - if (!parent || parent->iskey || (!parent->iscompr && parent->size != 1)) - break; - h = parent; - debugnode("Going up to", h); - } - raxNode *start = h; /* Compression starting node. */ - - /* Scan chain of nodes we can compress. */ - size_t comprsize = h->size; - int nodes = 1; - while (h->size != 0) { - raxNode **cp = raxNodeLastChildPtr(h); - memcpy(&h, cp, sizeof(h)); - if (h->iskey || (!h->iscompr && h->size != 1)) - break; - /* Stop here if going to the next node would result into - * a compressed node larger than h->size can hold. */ - if (comprsize + h->size > RAX_NODE_MAX_SIZE) - break; - nodes++; - comprsize += h->size; - } - if (nodes > 1) { - /* If we can compress, create the new node and populate it. */ - size_t nodesize = - sizeof(raxNode) + comprsize + raxPadding(comprsize) + sizeof(raxNode *); - raxNode *new = rax_malloc(nodesize); - /* An out of memory here just means we cannot optimize this - * node, but the tree is left in a consistent state. */ - if (new == NULL) { - raxStackFree(&ts); - return 1; - } - new->iskey = 0; - new->isnull = 0; - new->iscompr = 1; - new->size = comprsize; - rax->numnodes++; - - /* Scan again, this time to populate the new node content and - * to fix the new node child pointer. At the same time we free - * all the nodes that we'll no longer use. 
*/ - comprsize = 0; - h = start; - while (h->size != 0) { - memcpy(new->data + comprsize, h->data, h->size); - comprsize += h->size; - raxNode **cp = raxNodeLastChildPtr(h); - raxNode *tofree = h; - memcpy(&h, cp, sizeof(h)); - rax_free(tofree); - rax->numnodes--; - if (h->iskey || (!h->iscompr && h->size != 1)) - break; - } - debugnode("New node", new); - - /* Now 'h' points to the first node that we still need to use, - * so our new node child pointer will point to it. */ - raxNode **cp = raxNodeLastChildPtr(new); - memcpy(cp, &h, sizeof(h)); - - /* Fix parent link. */ - if (parent) { - raxNode **parentlink = raxFindParentLink(parent, start); - memcpy(parentlink, &new, sizeof(new)); - } else { - rax->head = new; - } - - debugf("Compressed %d nodes, %d total bytes\n", nodes, (int)comprsize); - } - } - raxStackFree(&ts); - return 1; -} - -/* This is the core of raxFree(): performs a depth-first scan of the - * tree and releases all the nodes found. */ -void raxRecursiveFree(rax *rax, raxNode *n, void (*free_callback)(void *)) { - debugnode("free traversing", n); - int numchildren = n->iscompr ? 1 : n->size; - raxNode **cp = raxNodeLastChildPtr(n); - while (numchildren--) { - raxNode *child; - memcpy(&child, cp, sizeof(child)); - raxRecursiveFree(rax, child, free_callback); - cp--; - } - debugnode("free depth-first", n); - if (free_callback && n->iskey && !n->isnull) - free_callback(raxGetData(n)); - rax_free(n); - rax->numnodes--; -} - -/* Free a whole radix tree, calling the specified callback in order to - * free the auxiliary data. */ -void raxFreeWithCallback(rax *rax, void (*free_callback)(void *)) { - raxRecursiveFree(rax, rax->head, free_callback); - assert(rax->numnodes == 0); - rax_free(rax); -} - -/* Free a whole radix tree. */ -void raxFree(rax *rax) { raxFreeWithCallback(rax, NULL); } - -/* ------------------------------- Iterator --------------------------------- */ - -/* Initialize a Rax iterator. 
This call should be performed a single time - * to initialize the iterator, and must be followed by a raxSeek() call, - * otherwise the raxPrev()/raxNext() functions will just return EOF. */ -void raxStart(raxIterator *it, rax *rt) { - it->flags = RAX_ITER_EOF; /* No crash if the iterator is not seeked. */ - it->rt = rt; - it->key_len = 0; - it->key = it->key_static_string; - it->key_max = RAX_ITER_STATIC_LEN; - it->data = NULL; - it->node_cb = NULL; - raxStackInit(&it->stack); -} - -/* Append characters at the current key string of the iterator 'it'. This - * is a low level function used to implement the iterator, not callable by - * the user. Returns 0 on out of memory, otherwise 1 is returned. */ -int raxIteratorAddChars(raxIterator *it, unsigned char *s, size_t len) { - if (it->key_max < it->key_len + len) { - unsigned char *old = (it->key == it->key_static_string) ? NULL : it->key; - size_t new_max = (it->key_len + len) * 2; - it->key = rax_realloc(old, new_max); - if (it->key == NULL) { - it->key = (!old) ? it->key_static_string : old; - errno = ENOMEM; - return 0; - } - if (old == NULL) - memcpy(it->key, it->key_static_string, it->key_len); - it->key_max = new_max; - } - /* Use memmove since there could be an overlap between 's' and - * it->key when we use the current key in order to re-seek. */ - memmove(it->key + it->key_len, s, len); - it->key_len += len; - return 1; -} - -/* Remove the specified number of chars from the right of the current - * iterator key. */ -void raxIteratorDelChars(raxIterator *it, size_t count) { it->key_len -= count; } - -/* Do an iteration step towards the next element. At the end of the step the - * iterator key will represent the (new) current key. If it is not possible - * to step in the specified direction since there are no longer elements, the - * iterator is flagged with RAX_ITER_EOF. 
- * - * If 'noup' is true the function starts directly scanning for the next - * lexicographically smaller children, and the current node is already assumed - * to be the parent of the last key node, so the first operation to go back to - * the parent will be skipped. This option is used by raxSeek() when - * implementing seeking a non existing element with the ">" or "<" options: - * the starting node is not a key in that particular case, so we start the scan - * from a node that does not represent the key set. - * - * The function returns 1 on success or 0 on out of memory. */ -int raxIteratorNextStep(raxIterator *it, int noup) { - if (it->flags & RAX_ITER_EOF) { - return 1; - } else if (it->flags & RAX_ITER_JUST_SEEKED) { - it->flags &= ~RAX_ITER_JUST_SEEKED; - return 1; - } - - /* Save key len, stack items and the node where we are currently - * so that on iterator EOF we can restore the current key and state. */ - size_t orig_key_len = it->key_len; - size_t orig_stack_items = it->stack.items; - raxNode *orig_node = it->node; - - while (1) { - int children = it->node->iscompr ? 1 : it->node->size; - if (!noup && children) { - debugf("GO DEEPER\n"); - /* Seek the lexicographically smaller key in this subtree, which - * is the first one found always going torwards the first child - * of every successive node. */ - if (!raxStackPush(&it->stack, it->node)) - return 0; - raxNode **cp = raxNodeFirstChildPtr(it->node); - if (!raxIteratorAddChars(it, it->node->data, it->node->iscompr ? it->node->size : 1)) - return 0; - memcpy(&it->node, cp, sizeof(it->node)); - /* Call the node callback if any, and replace the node pointer - * if the callback returns true. */ - if (it->node_cb && it->node_cb(&it->node)) - memcpy(cp, &it->node, sizeof(it->node)); - /* For "next" step, stop every time we find a key along the - * way, since the key is lexicograhically smaller compared to - * what follows in the sub-children. 
*/ - if (it->node->iskey) { - it->data = raxGetData(it->node); - return 1; - } - } else { - /* If we finished exporing the previous sub-tree, switch to the - * new one: go upper until a node is found where there are - * children representing keys lexicographically greater than the - * current key. */ - while (1) { - int old_noup = noup; - - /* Already on head? Can't go up, iteration finished. */ - if (!noup && it->node == it->rt->head) { - it->flags |= RAX_ITER_EOF; - it->stack.items = orig_stack_items; - it->key_len = orig_key_len; - it->node = orig_node; - return 1; - } - /* If there are no children at the current node, try parent's - * next child. */ - unsigned char prevchild = it->key[it->key_len - 1]; - if (!noup) { - it->node = raxStackPop(&it->stack); - } else { - noup = 0; - } - /* Adjust the current key to represent the node we are - * at. */ - int todel = it->node->iscompr ? it->node->size : 1; - raxIteratorDelChars(it, todel); - - /* Try visiting the next child if there was at least one - * additional child. */ - if (!it->node->iscompr && it->node->size > (old_noup ? 0 : 1)) { - raxNode **cp = raxNodeFirstChildPtr(it->node); - int i = 0; - while (i < it->node->size) { - debugf("SCAN NEXT %c\n", it->node->data[i]); - if (it->node->data[i] > prevchild) - break; - i++; - cp++; - } - if (i != it->node->size) { - debugf("SCAN found a new node\n"); - raxIteratorAddChars(it, it->node->data + i, 1); - if (!raxStackPush(&it->stack, it->node)) - return 0; - memcpy(&it->node, cp, sizeof(it->node)); - /* Call the node callback if any, and replace the node - * pointer if the callback returns true. */ - if (it->node_cb && it->node_cb(&it->node)) - memcpy(cp, &it->node, sizeof(it->node)); - if (it->node->iskey) { - it->data = raxGetData(it->node); - return 1; - } - break; - } - } - } - } - } -} - -/* Seek the greatest key in the subtree at the current node. Return 0 on - * out of memory, otherwise 1. 
This is an helper function for different - * iteration functions below. */ -int raxSeekGreatest(raxIterator *it) { - while (it->node->size) { - if (it->node->iscompr) { - if (!raxIteratorAddChars(it, it->node->data, it->node->size)) - return 0; - } else { - if (!raxIteratorAddChars(it, it->node->data + it->node->size - 1, 1)) - return 0; - } - raxNode **cp = raxNodeLastChildPtr(it->node); - if (!raxStackPush(&it->stack, it->node)) - return 0; - memcpy(&it->node, cp, sizeof(it->node)); - } - return 1; -} - -/* Like raxIteratorNextStep() but implements an iteration step moving - * to the lexicographically previous element. The 'noup' option has a similar - * effect to the one of raxIteratorNextStep(). */ -int raxIteratorPrevStep(raxIterator *it, int noup) { - if (it->flags & RAX_ITER_EOF) { - return 1; - } else if (it->flags & RAX_ITER_JUST_SEEKED) { - it->flags &= ~RAX_ITER_JUST_SEEKED; - return 1; - } - - /* Save key len, stack items and the node where we are currently - * so that on iterator EOF we can restore the current key and state. */ - size_t orig_key_len = it->key_len; - size_t orig_stack_items = it->stack.items; - raxNode *orig_node = it->node; - - while (1) { - int old_noup = noup; - - /* Already on head? Can't go up, iteration finished. */ - if (!noup && it->node == it->rt->head) { - it->flags |= RAX_ITER_EOF; - it->stack.items = orig_stack_items; - it->key_len = orig_key_len; - it->node = orig_node; - return 1; - } - - unsigned char prevchild = it->key[it->key_len - 1]; - if (!noup) { - it->node = raxStackPop(&it->stack); - } else { - noup = 0; - } - - /* Adjust the current key to represent the node we are - * at. */ - int todel = it->node->iscompr ? it->node->size : 1; - raxIteratorDelChars(it, todel); - - /* Try visiting the prev child if there is at least one - * child. */ - if (!it->node->iscompr && it->node->size > (old_noup ? 
0 : 1)) { - raxNode **cp = raxNodeLastChildPtr(it->node); - int i = it->node->size - 1; - while (i >= 0) { - debugf("SCAN PREV %c\n", it->node->data[i]); - if (it->node->data[i] < prevchild) - break; - i--; - cp--; - } - /* If we found a new subtree to explore in this node, - * go deeper following all the last children in order to - * find the key lexicographically greater. */ - if (i != -1) { - debugf("SCAN found a new node\n"); - /* Enter the node we just found. */ - if (!raxIteratorAddChars(it, it->node->data + i, 1)) - return 0; - if (!raxStackPush(&it->stack, it->node)) - return 0; - memcpy(&it->node, cp, sizeof(it->node)); - /* Seek sub-tree max. */ - if (!raxSeekGreatest(it)) - return 0; - } - } - - /* Return the key: this could be the key we found scanning a new - * subtree, or if we did not find a new subtree to explore here, - * before giving up with this node, check if it's a key itself. */ - if (it->node->iskey) { - it->data = raxGetData(it->node); - return 1; - } - } -} - -/* Seek an iterator at the specified element. - * Return 0 if the seek failed for syntax error or out of memory. Otherwise - * 1 is returned. When 0 is returned for out of memory, errno is set to - * the ENOMEM value. */ -int raxSeek(raxIterator *it, const char *op, unsigned char *ele, size_t len) { - int eq = 0, lt = 0, gt = 0, first = 0, last = 0; - - it->stack.items = 0; /* Just resetting. Intialized by raxStart(). */ - it->flags |= RAX_ITER_JUST_SEEKED; - it->flags &= ~RAX_ITER_EOF; - it->key_len = 0; - it->node = NULL; - - /* Set flags according to the operator used to perform the seek. */ - if (op[0] == '>') { - gt = 1; - if (op[1] == '=') - eq = 1; - } else if (op[0] == '<') { - lt = 1; - if (op[1] == '=') - eq = 1; - } else if (op[0] == '=') { - eq = 1; - } else if (op[0] == '^') { - first = 1; - } else if (op[0] == '$') { - last = 1; - } else { - errno = 0; - return 0; /* Error. */ - } - - /* If there are no elements, set the EOF condition immediately and - * return. 
*/ - if (it->rt->numele == 0) { - it->flags |= RAX_ITER_EOF; - return 1; - } - - if (first) { - /* Seeking the first key greater or equal to the empty string - * is equivalent to seeking the smaller key available. */ - return raxSeek(it, ">=", NULL, 0); - } - - if (last) { - /* Find the greatest key taking always the last child till a - * final node is found. */ - it->node = it->rt->head; - if (!raxSeekGreatest(it)) - return 0; - assert(it->node->iskey); - it->data = raxGetData(it->node); - return 1; - } - - /* We need to seek the specified key. What we do here is to actually - * perform a lookup, and later invoke the prev/next key code that - * we already use for iteration. */ - int splitpos = 0; - size_t i = raxLowWalk(it->rt, ele, len, &it->node, NULL, &splitpos, &it->stack); - - /* Return OOM on incomplete stack info. */ - if (it->stack.oom) - return 0; - - if (eq && i == len && (!it->node->iscompr || splitpos == 0) && it->node->iskey) { - /* We found our node, since the key matches and we have an - * "equal" condition. */ - if (!raxIteratorAddChars(it, ele, len)) - return 0; /* OOM. */ - it->data = raxGetData(it->node); - } else if (lt || gt) { - /* Exact key not found or eq flag not set. We have to set as current - * key the one represented by the node we stopped at, and perform - * a next/prev operation to seek. To reconstruct the key at this node - * we start from the parent and go to the current node, accumulating - * the characters found along the way. 
*/ - if (!raxStackPush(&it->stack, it->node)) - return 0; - for (size_t j = 1; j < it->stack.items; j++) { - raxNode *parent = it->stack.stack[j - 1]; - raxNode *child = it->stack.stack[j]; - if (parent->iscompr) { - if (!raxIteratorAddChars(it, parent->data, parent->size)) - return 0; - } else { - raxNode **cp = raxNodeFirstChildPtr(parent); - unsigned char *p = parent->data; - while (1) { - raxNode *aux; - memcpy(&aux, cp, sizeof(aux)); - if (aux == child) - break; - cp++; - p++; - } - if (!raxIteratorAddChars(it, p, 1)) - return 0; - } - } - raxStackPop(&it->stack); - - /* We need to set the iterator in the correct state to call next/prev - * step in order to seek the desired element. */ - debugf("After initial seek: i=%d len=%d key=%.*s\n", (int)i, (int)len, (int)it->key_len, - it->key); - if (i != len && !it->node->iscompr) { - /* If we stopped in the middle of a normal node because of a - * mismatch, add the mismatching character to the current key - * and call the iterator with the 'noup' flag so that it will try - * to seek the next/prev child in the current node directly based - * on the mismatching character. */ - if (!raxIteratorAddChars(it, ele + i, 1)) - return 0; - debugf("Seek normal node on mismatch: %.*s\n", (int)it->key_len, (char *)it->key); - - it->flags &= ~RAX_ITER_JUST_SEEKED; - if (lt && !raxIteratorPrevStep(it, 1)) - return 0; - if (gt && !raxIteratorNextStep(it, 1)) - return 0; - it->flags |= RAX_ITER_JUST_SEEKED; /* Ignore next call. */ - } else if (i != len && it->node->iscompr) { - debugf("Compressed mismatch: %.*s\n", (int)it->key_len, (char *)it->key); - /* In case of a mismatch within a compressed node. */ - int nodechar = it->node->data[splitpos]; - int keychar = ele[i]; - it->flags &= ~RAX_ITER_JUST_SEEKED; - if (gt) { - /* If the key the compressed node represents is greater - * than our seek element, continue forward, otherwise set the - * state in order to go back to the next sub-tree. 
*/ - if (nodechar > keychar) { - if (!raxIteratorNextStep(it, 0)) - return 0; - } else { - if (!raxIteratorAddChars(it, it->node->data, it->node->size)) - return 0; - if (!raxIteratorNextStep(it, 1)) - return 0; - } - } - if (lt) { - /* If the key the compressed node represents is smaller - * than our seek element, seek the greater key in this - * subtree, otherwise set the state in order to go back to - * the previous sub-tree. */ - if (nodechar < keychar) { - if (!raxSeekGreatest(it)) - return 0; - it->data = raxGetData(it->node); - } else { - if (!raxIteratorAddChars(it, it->node->data, it->node->size)) - return 0; - if (!raxIteratorPrevStep(it, 1)) - return 0; - } - } - it->flags |= RAX_ITER_JUST_SEEKED; /* Ignore next call. */ - } else { - debugf("No mismatch: %.*s\n", (int)it->key_len, (char *)it->key); - /* If there was no mismatch we are into a node representing the - * key, (but which is not a key or the seek operator does not - * include 'eq'), or we stopped in the middle of a compressed node - * after processing all the key. Continue iterating as this was - * a legitimate key we stopped at. */ - it->flags &= ~RAX_ITER_JUST_SEEKED; - if (it->node->iscompr && it->node->iskey && splitpos && lt) { - /* If we stopped in the middle of a compressed node with - * perfect match, and the condition is to seek a key "<" than - * the specified one, then if this node is a key it already - * represents our match. For instance we may have nodes: - * - * "f" -> "oobar" = 1 -> "" = 2 - * - * Representing keys "f" = 1, "foobar" = 2. A seek for - * the key < "foo" will stop in the middle of the "oobar" - * node, but will be our match, representing the key "f". - * - * So in that case, we don't seek backward. */ - it->data = raxGetData(it->node); - } else { - if (gt && !raxIteratorNextStep(it, 0)) - return 0; - if (lt && !raxIteratorPrevStep(it, 0)) - return 0; - } - it->flags |= RAX_ITER_JUST_SEEKED; /* Ignore next call. 
*/ - } - } else { - /* If we are here just eq was set but no match was found. */ - it->flags |= RAX_ITER_EOF; - return 1; - } - return 1; -} - -/* Go to the next element in the scope of the iterator 'it'. - * If EOF (or out of memory) is reached, 0 is returned, otherwise 1 is - * returned. In case 0 is returned because of OOM, errno is set to ENOMEM. */ -int raxNext(raxIterator *it) { - if (!raxIteratorNextStep(it, 0)) { - errno = ENOMEM; - return 0; - } - if (it->flags & RAX_ITER_EOF) { - errno = 0; - return 0; - } - return 1; -} - -/* Go to the previous element in the scope of the iterator 'it'. - * If EOF (or out of memory) is reached, 0 is returned, otherwise 1 is - * returned. In case 0 is returned because of OOM, errno is set to ENOMEM. */ -int raxPrev(raxIterator *it) { - if (!raxIteratorPrevStep(it, 0)) { - errno = ENOMEM; - return 0; - } - if (it->flags & RAX_ITER_EOF) { - errno = 0; - return 0; - } - return 1; -} - -/* Perform a random walk starting in the current position of the iterator. - * Return 0 if the tree is empty or on out of memory. Otherwise 1 is returned - * and the iterator is set to the node reached after doing a random walk - * of 'steps' steps. If the 'steps' argument is 0, the random walk is performed - * using a random number of steps between 1 and two times the logarithm of - * the number of elements. - * - * NOTE: if you use this function to generate random elements from the radix - * tree, expect a disappointing distribution. A random walk produces good - * random elements if the tree is not sparse, however in the case of a radix - * tree certain keys will be reported much more often than others. At least - * this function should be able to expore every possible element eventually. 
*/ -int raxRandomWalk(raxIterator *it, size_t steps) { - if (it->rt->numele == 0) { - it->flags |= RAX_ITER_EOF; - return 0; - } - - if (steps == 0) { - size_t fle = floor(log(it->rt->numele)); - fle *= 2; - steps = 1 + rand() % fle; - } - - raxNode *n = it->node; - while (steps > 0 || !n->iskey) { - int numchildren = n->iscompr ? 1 : n->size; - int r = rand() % (numchildren + (n != it->rt->head)); - - if (r == numchildren) { - /* Go up to parent. */ - n = raxStackPop(&it->stack); - int todel = n->iscompr ? n->size : 1; - raxIteratorDelChars(it, todel); - } else { - /* Select a random child. */ - if (n->iscompr) { - if (!raxIteratorAddChars(it, n->data, n->size)) - return 0; - } else { - if (!raxIteratorAddChars(it, n->data + r, 1)) - return 0; - } - raxNode **cp = raxNodeFirstChildPtr(n) + r; - if (!raxStackPush(&it->stack, n)) - return 0; - memcpy(&n, cp, sizeof(n)); - } - if (n->iskey) - steps--; - } - it->node = n; - return 1; -} - -/* Compare the key currently pointed by the iterator to the specified - * key according to the specified operator. Returns 1 if the comparison is - * true, otherwise 0 is returned. */ -int raxCompare(raxIterator *iter, const char *op, unsigned char *key, size_t key_len) { - int eq = 0, lt = 0, gt = 0; - - if (op[0] == '=' || op[1] == '=') - eq = 1; - if (op[0] == '>') - gt = 1; - else if (op[0] == '<') - lt = 1; - else if (op[1] != '=') - return 0; /* Syntax error. */ - - size_t minlen = key_len < iter->key_len ? key_len : iter->key_len; - int cmp = memcmp(iter->key, key, minlen); - - /* Handle == */ - if (lt == 0 && gt == 0) - return cmp == 0 && key_len == iter->key_len; - - /* Handle >, >=, <, <= */ - if (cmp == 0) { - /* Same prefix: longer wins. */ - if (eq && key_len == iter->key_len) - return 1; - else if (lt) - return iter->key_len < key_len; - else if (gt) - return iter->key_len > key_len; - else - return 0; /* Avoid warning, just 'eq' is handled before. */ - } else if (cmp > 0) { - return gt ? 
1 : 0; - } else /* (cmp < 0) */ { - return lt ? 1 : 0; - } -} - -/* Free the iterator. */ -void raxStop(raxIterator *it) { - if (it->key != it->key_static_string) - rax_free(it->key); - raxStackFree(&it->stack); -} - -/* Return if the iterator is in an EOF state. This happens when raxSeek() - * failed to seek an appropriate element, so that raxNext() or raxPrev() - * will return zero, or when an EOF condition was reached while iterating - * with raxNext() and raxPrev(). */ -int raxEOF(raxIterator *it) { return it->flags & RAX_ITER_EOF; } - -/* Return the number of elements inside the radix tree. */ -uint64_t raxSize(rax *rax) { return rax->numele; } - -/* ----------------------------- Introspection ------------------------------ */ - -/* This function is mostly used for debugging and learning purposes. - * It shows an ASCII representation of a tree on standard output, outling - * all the nodes and the contained keys. - * - * The representation is as follow: - * - * "foobar" (compressed node) - * [abc] (normal node with three children) - * [abc]=0x12345678 (node is a key, pointing to value 0x12345678) - * [] (a normal empty node) - * - * Children are represented in new idented lines, each children prefixed by - * the "`-(x)" string, where "x" is the edge byte. - * - * [abc] - * `-(a) "ladin" - * `-(b) [kj] - * `-(c) [] - * - * However when a node has a single child the following representation - * is used instead: - * - * [abc] -> "ladin" -> [] - */ - -/* The actual implementation of raxShow(). */ -void raxRecursiveShow(int level, int lpad, raxNode *n) { - char s = n->iscompr ? '"' : '['; - char e = n->iscompr ? '"' : ']'; - - int numchars = printf("%c%.*s%c", s, n->size, n->data, e); - if (n->iskey) { - numchars += printf("=%p", raxGetData(n)); - } - - int numchildren = n->iscompr ? 1 : n->size; - /* Note that 7 and 4 magic constants are the string length - * of " `-(x) " and " -> " respectively. */ - if (level) { - lpad += (numchildren > 1) ? 
7 : 4; - if (numchildren == 1) - lpad += numchars; - } - raxNode **cp = raxNodeFirstChildPtr(n); - for (int i = 0; i < numchildren; i++) { - char *branch = " `-(%c) "; - if (numchildren > 1) { - printf("\n"); - for (int j = 0; j < lpad; j++) - putchar(' '); - printf(branch, n->data[i]); - } else { - printf(" -> "); - } - raxNode *child; - memcpy(&child, cp, sizeof(child)); - raxRecursiveShow(level + 1, lpad, child); - cp++; - } -} - -/* Show a tree, as outlined in the comment above. */ -void raxShow(rax *rax) { - raxRecursiveShow(0, 0, rax->head); - putchar('\n'); -} - -/* Used by debugnode() macro to show info about a given node. */ -void raxDebugShowNode(const char *msg, raxNode *n) { - if (raxDebugMsg == 0) - return; - printf("%s: %p [%.*s] key:%d size:%d children:", msg, (void *)n, (int)n->size, (char *)n->data, - n->iskey, n->size); - int numcld = n->iscompr ? 1 : n->size; - raxNode **cldptr = raxNodeLastChildPtr(n) - (numcld - 1); - while (numcld--) { - raxNode *child; - memcpy(&child, cldptr, sizeof(child)); - cldptr++; - printf("%p ", (void *)child); - } - printf("\n"); - fflush(stdout); -} - -/* Touch all the nodes of a tree returning a check sum. This is useful - * in order to make Valgrind detect if there is something wrong while - * reading the data structure. - * - * This function was used in order to identify Rax bugs after a big refactoring - * using this technique: - * - * 1. The rax-test is executed using Valgrind, adding a printf() so that for - * the fuzz tester we see what iteration in the loop we are in. - * 2. After every modification of the radix tree made by the fuzz tester - * in rax-test.c, we add a call to raxTouch(). - * 3. Now as soon as an operation will corrupt the tree, raxTouch() will - * detect it (via Valgrind) immediately. We can add more calls to narrow - * the state. - * 4. At this point a good idea is to enable Rax debugging messages immediately - * before the moment the tree is corrupted, to see what happens. 
- */ -unsigned long raxTouch(raxNode *n) { - debugf("Touching %p\n", (void *)n); - unsigned long sum = 0; - if (n->iskey) { - sum += (unsigned long)raxGetData(n); - } - - int numchildren = n->iscompr ? 1 : n->size; - raxNode **cp = raxNodeFirstChildPtr(n); - int count = 0; - for (int i = 0; i < numchildren; i++) { - if (numchildren > 1) { - sum += (long)n->data[i]; - } - raxNode *child; - memcpy(&child, cp, sizeof(child)); - if (child == (void *)0x65d1760) - count++; - if (count > 1) - exit(1); - sum += raxTouch(child); - cp++; - } - return sum; -} diff --git a/src/util/rax.h b/src/util/rax.h deleted file mode 100644 index 6ccb69200..000000000 --- a/src/util/rax.h +++ /dev/null @@ -1,218 +0,0 @@ -/* Rax -- A radix tree implementation. - * - * Copyright (c) 2017-2018, Salvatore Sanfilippo - * All rights reserved. - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * * Redistributions of source code must retain the above copyright notice, - * this list of conditions and the following disclaimer. - * * Redistributions in binary form must reproduce the above copyright - * notice, this list of conditions and the following disclaimer in the - * documentation and/or other materials provided with the distribution. - * * Neither the name of Redis nor the names of its contributors may be used - * to endorse or promote products derived from this software without - * specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE - * ARE DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE - * LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR - * CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF - * SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS - * INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN - * CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) - * ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE - * POSSIBILITY OF SUCH DAMAGE. - */ - -#ifndef RAX_H -#define RAX_H - -#include - -/* Representation of a radix tree as implemented in this file, that contains - * the strings "foo", "foobar" and "footer" after the insertion of each - * word. When the node represents a key inside the radix tree, we write it - * between [], otherwise it is written between (). - * - * This is the vanilla representation: - * - * (f) "" - * \ - * (o) "f" - * \ - * (o) "fo" - * \ - * [t b] "foo" - * / \ - * "foot" (e) (a) "foob" - * / \ - * "foote" (r) (r) "fooba" - * / \ - * "footer" [] [] "foobar" - * - * However, this implementation implements a very common optimization where - * successive nodes having a single child are "compressed" into the node - * itself as a string of characters, each representing a next-level child, - * and only the link to the node representing the last character node is - * provided inside the representation. So the above representation is turend - * into: - * - * ["foo"] "" - * | - * [t b] "foo" - * / \ - * "foot" ("er") ("ar") "foob" - * / \ - * "footer" [] [] "foobar" - * - * However this optimization makes the implementation a bit more complex. - * For instance if a key "first" is added in the above radix tree, a - * "node splitting" operation is needed, since the "foo" prefix is no longer - * composed of nodes having a single child one after the other. 
This is the - * above tree and the resulting node splitting after this event happens: - * - * - * (f) "" - * / - * (i o) "f" - * / \ - * "firs" ("rst") (o) "fo" - * / \ - * "first" [] [t b] "foo" - * / \ - * "foot" ("er") ("ar") "foob" - * / \ - * "footer" [] [] "foobar" - * - * Similarly after deletion, if a new chain of nodes having a single child - * is created (the chain must also not include nodes that represent keys), - * it must be compressed back into a single node. - * - */ - -#define RAX_NODE_MAX_SIZE ((1 << 29) - 1) -typedef struct raxNode { - uint32_t iskey : 1; /* Does this node contain a key? */ - uint32_t isnull : 1; /* Associated value is NULL (don't store it). */ - uint32_t iscompr : 1; /* Node is compressed. */ - uint32_t size : 29; /* Number of children, or compressed string len. */ - /* Data layout is as follows: - * - * If node is not compressed we have 'size' bytes, one for each children - * character, and 'size' raxNode pointers, point to each child node. - * Note how the character is not stored in the children but in the - * edge of the parents: - * - * [header iscompr=0][abc][a-ptr][b-ptr][c-ptr](value-ptr?) - * - * if node is compressed (iscompr bit is 1) the node has 1 children. - * In that case the 'size' bytes of the string stored immediately at - * the start of the data section, represent a sequence of successive - * nodes linked one after the other, for which only the last one in - * the sequence is actually represented as a node, and pointed to by - * the current compressed node. - * - * [header iscompr=1][xyz][z-ptr](value-ptr?) - * - * Both compressed and not compressed nodes can represent a key - * with associated data in the radix tree at any level (not just terminal - * nodes). - * - * If the node has an associated key (iskey=1) and is not NULL - * (isnull=0), then after the raxNode pointers poiting to the - * children, an additional value pointer is present (as you can see - * in the representation above as "value-ptr" field). 
- */ - unsigned char data[]; -} raxNode; - -typedef struct rax { - raxNode *head; - uint64_t numele; - uint64_t numnodes; -} rax; - -/* Stack data structure used by raxLowWalk() in order to, optionally, return - * a list of parent nodes to the caller. The nodes do not have a "parent" - * field for space concerns, so we use the auxiliary stack when needed. */ -#define RAX_STACK_STATIC_ITEMS 32 -typedef struct raxStack { - void **stack; /* Points to static_items or an heap allocated array. */ - size_t items, maxitems; /* Number of items contained and total space. */ - /* Up to RAXSTACK_STACK_ITEMS items we avoid to allocate on the heap - * and use this static array of pointers instead. */ - void *static_items[RAX_STACK_STATIC_ITEMS]; - int oom; /* True if pushing into this stack failed for OOM at some point. */ -} raxStack; - -/* Optional callback used for iterators and be notified on each rax node, - * including nodes not representing keys. If the callback returns true - * the callback changed the node pointer in the iterator structure, and the - * iterator implementation will have to replace the pointer in the radix tree - * internals. This allows the callback to reallocate the node to perform - * very special operations, normally not needed by normal applications. - * - * This callback is used to perform very low level analysis of the radix tree - * structure, scanning each possible node (but the root node), or in order to - * reallocate the nodes to reduce the allocation fragmentation (this is the - * Redis application for this callback). - * - * This is currently only supported in forward iterations (raxNext) */ -typedef int (*raxNodeCallback)(raxNode **noderef); - -/* Radix tree iterator state is encapsulated into this data structure. */ -#define RAX_ITER_STATIC_LEN 128 -#define RAX_ITER_JUST_SEEKED \ - (1 << 0) /* Iterator was just seeked. Return current \ - element for the first iteration and \ - clear the flag. 
*/ -#define RAX_ITER_EOF (1 << 1) /* End of iteration reached. */ -#define RAX_ITER_SAFE \ - (1 << 2) /* Safe iterator, allows operations while \ - iterating. But it is slower. */ -typedef struct raxIterator { - int flags; - rax *rt; /* Radix tree we are iterating. */ - unsigned char *key; /* The current string. */ - void *data; /* Data associated to this key. */ - size_t key_len; /* Current key length. */ - size_t key_max; /* Max key len the current key buffer can hold. */ - unsigned char key_static_string[RAX_ITER_STATIC_LEN]; - raxNode *node; /* Current node. Only for unsafe iteration. */ - raxStack stack; /* Stack used for unsafe iteration. */ - raxNodeCallback node_cb; /* Optional node callback. Normally set to NULL. */ -} raxIterator; - -/* A special pointer returned for not found items. */ -extern void *raxNotFound; - -/* Exported API. */ -rax *raxNew(void); -int raxInsert(rax *rax, unsigned char *s, size_t len, void *data, void **old); -int raxTryInsert(rax *rax, unsigned char *s, size_t len, void *data, void **old); -int raxRemove(rax *rax, unsigned char *s, size_t len, void **old); -void *raxFind(rax *rax, unsigned char *s, size_t len); -void raxFree(rax *rax); -void raxFreeWithCallback(rax *rax, void (*free_callback)(void *)); -void raxStart(raxIterator *it, rax *rt); -int raxSeek(raxIterator *it, const char *op, unsigned char *ele, size_t len); -int raxNext(raxIterator *it); -int raxPrev(raxIterator *it); -int raxRandomWalk(raxIterator *it, size_t steps); -int raxCompare(raxIterator *iter, const char *op, unsigned char *key, size_t key_len); -void raxStop(raxIterator *it); -int raxEOF(raxIterator *it); -void raxShow(rax *rax); -uint64_t raxSize(rax *rax); -unsigned long raxTouch(raxNode *n); -void raxSetDebugMsg(int onoff); - -/* Internal API. May be used by the node callback in order to access rax nodes - * in a low level way, so this function is exported as well. 
*/ -void raxSetData(raxNode *n, void *data); - -#endif diff --git a/src/util/rax_malloc.h b/src/util/rax_malloc.h deleted file mode 100644 index c4e92199e..000000000 --- a/src/util/rax_malloc.h +++ /dev/null @@ -1,43 +0,0 @@ -/* Rax -- A radix tree implementation. - * - * Copyright (c) 2017, Salvatore Sanfilippo - * All rights reserved. - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * * Redistributions of source code must retain the above copyright notice, - * this list of conditions and the following disclaimer. - * * Redistributions in binary form must reproduce the above copyright - * notice, this list of conditions and the following disclaimer in the - * documentation and/or other materials provided with the distribution. - * * Neither the name of Redis nor the names of its contributors may be used - * to endorse or promote products derived from this software without - * specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE - * ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE - * LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR - * CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF - * SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS - * INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN - * CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) - * ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE - * POSSIBILITY OF SUCH DAMAGE. - */ - -/* Allocator selection. - * - * This file is used in order to change the Rax allocator at compile time. 
- * Just define the following defines to what you want to use. Also add - * the include of your alternate allocator if needed (not needed in order - * to use the default libc allocator). */ - -#ifndef RAX_ALLOC_H -#define RAX_ALLOC_H -#define rax_malloc malloc -#define rax_realloc realloc -#define rax_free free -#endif From 6cd96525bedcb177ef6443b4ff1101838b9ca36b Mon Sep 17 00:00:00 2001 From: alonre24 Date: Sun, 6 Jun 2021 18:00:16 +0300 Subject: [PATCH 15/27] Refactor load time config --- src/backends/backends.c | 30 +-- src/config/config.c | 305 ++++++++++------------------- src/config/config.h | 66 +++---- src/execution/background_workers.c | 6 +- src/execution/background_workers.h | 6 - src/execution/onnx_timeout.c | 15 +- src/execution/onnx_timeout.h | 2 +- src/redisai.c | 14 +- tests/flow/tests_onnx.py | 2 +- 9 files changed, 170 insertions(+), 276 deletions(-) diff --git a/src/backends/backends.c b/src/backends/backends.c index e9cb61bf2..0bd3be509 100644 --- a/src/backends/backends.c +++ b/src/backends/backends.c @@ -93,7 +93,7 @@ int RAI_LoadBackend_TensorFlow(RedisModuleCtx *ctx, const char *path) { RedisModule_Log(ctx, "warning", "Could not load TF backend from %s: %s", path, dlerror()); return REDISMODULE_ERR; } - RAI_LoadedBackend backend = {0}; // Initialize all the callbacks to NULL. + RAI_LoadedBackend backend = {0}; // Initialize all the callbacks to NULL. 
int (*init_backend)(int (*)(const char *, void *)); init_backend = @@ -118,7 +118,7 @@ int RAI_LoadBackend_TensorFlow(RedisModuleCtx *ctx, const char *path) { } backend.model_run = (int (*)(RAI_Model * model, RAI_ExecutionCtx * *ectxs, RAI_Error * error))( - unsigned long)dlsym(handle, "RAI_ModelRunTF"); + unsigned long)dlsym(handle, "RAI_ModelRunTF"); if (!_ValidateAPICreated(ctx, backend.model_run, "RAI_ModelRunTF")) { goto error; } @@ -139,7 +139,7 @@ int RAI_LoadBackend_TensorFlow(RedisModuleCtx *ctx, const char *path) { RedisModule_Log(ctx, "notice", "TF backend loaded from %s", path); return REDISMODULE_OK; - error: +error: dlclose(handle); RedisModule_Log(ctx, "warning", "TF backend not loaded from %s", path); return REDISMODULE_ERR; @@ -158,7 +158,7 @@ int RAI_LoadBackend_TFLite(RedisModuleCtx *ctx, const char *path) { dlerror()); return REDISMODULE_ERR; } - RAI_LoadedBackend backend = {0}; // Initialize all the callbacks to NULL. + RAI_LoadedBackend backend = {0}; // Initialize all the callbacks to NULL. 
int (*init_backend)(int (*)(const char *, void *)); init_backend = (int (*)(int (*)(const char *, void *)))(unsigned long)dlsym( @@ -182,7 +182,7 @@ int RAI_LoadBackend_TFLite(RedisModuleCtx *ctx, const char *path) { } backend.model_run = (int (*)(RAI_Model * model, RAI_ExecutionCtx * *ectxs, RAI_Error * error))( - unsigned long)dlsym(handle, "RAI_ModelRunTFLite"); + unsigned long)dlsym(handle, "RAI_ModelRunTFLite"); if (!_ValidateAPICreated(ctx, backend.model_run, "RAI_ModelRunTFLite")) { goto error; } @@ -203,7 +203,7 @@ int RAI_LoadBackend_TFLite(RedisModuleCtx *ctx, const char *path) { RedisModule_Log(ctx, "notice", "TFLITE backend loaded from %s", path); return REDISMODULE_OK; - error: +error: dlclose(handle); RedisModule_Log(ctx, "warning", "TFLITE backend not loaded from %s", path); return REDISMODULE_ERR; @@ -222,7 +222,7 @@ int RAI_LoadBackend_Torch(RedisModuleCtx *ctx, const char *path) { return REDISMODULE_ERR; } - RAI_LoadedBackend backend = {0}; // Initialize all the callbacks to NULL. + RAI_LoadedBackend backend = {0}; // Initialize all the callbacks to NULL. 
int (*init_backend)(int (*)(const char *, void *)); init_backend = (int (*)(int (*)(const char *, void *)))(unsigned long)dlsym( @@ -246,7 +246,7 @@ int RAI_LoadBackend_Torch(RedisModuleCtx *ctx, const char *path) { } backend.model_run = (int (*)(RAI_Model * model, RAI_ExecutionCtx * *ectxs, RAI_Error * error))( - unsigned long)dlsym(handle, "RAI_ModelRunTorch"); + unsigned long)dlsym(handle, "RAI_ModelRunTorch"); if (!_ValidateAPICreated(ctx, backend.model_run, "RAI_ModelRunTorch")) { goto error; } @@ -270,7 +270,7 @@ int RAI_LoadBackend_Torch(RedisModuleCtx *ctx, const char *path) { } backend.script_run = (int (*)(RAI_Script *, const char *, RAI_ExecutionCtx *, RAI_Error *))( - unsigned long)dlsym(handle, "RAI_ScriptRunTorch"); + unsigned long)dlsym(handle, "RAI_ScriptRunTorch"); if (!_ValidateAPICreated(ctx, backend.script_run, "RAI_ScriptRunTorch")) { goto error; } @@ -285,7 +285,7 @@ int RAI_LoadBackend_Torch(RedisModuleCtx *ctx, const char *path) { RedisModule_Log(ctx, "notice", "TORCH backend loaded from %s", path); return REDISMODULE_OK; - error: +error: dlclose(handle); RedisModule_Log(ctx, "warning", "TORCH backend not loaded from %s", path); return REDISMODULE_ERR; @@ -306,8 +306,8 @@ int RAI_LoadBackend_ONNXRuntime(RedisModuleCtx *ctx, const char *path) { RAI_LoadedBackend backend = {0}; int (*init_backend)(int (*)(const char *, void **)); - init_backend = (int (*) (int (*)(const char *, void **)))( - unsigned long)dlsym(handle, "RAI_InitBackendORT"); + init_backend = + (int (*)(int (*)(const char *, void **)))(unsigned long)dlsym(handle, "RAI_InitBackendORT"); if (!_ValidateAPICreated(ctx, init_backend, "RAI_InitBackendORT")) { goto error; } @@ -327,7 +327,7 @@ int RAI_LoadBackend_ONNXRuntime(RedisModuleCtx *ctx, const char *path) { } backend.model_run = (int (*)(RAI_Model * model, RAI_ExecutionCtx * *ectxs, RAI_Error * error))( - unsigned long)dlsym(handle, "RAI_ModelRunORT"); + unsigned long)dlsym(handle, "RAI_ModelRunORT"); if 
(!_ValidateAPICreated(ctx, backend.model_run, "RAI_ModelRunORT")) { goto error; } @@ -370,12 +370,12 @@ int RAI_LoadBackend_ONNXRuntime(RedisModuleCtx *ctx, const char *path) { } RedisModule_SubscribeToServerEvent(ctx, RedisModuleEvent_CronLoop, - backend.enforce_runtime_duration); + backend.enforce_runtime_duration); RAI_backends.onnx = backend; RedisModule_Log(ctx, "notice", "ONNX backend loaded from %s", path); return REDISMODULE_OK; - error: +error: dlclose(handle); RedisModule_Log(ctx, "warning", "ONNX backend not loaded from %s", path); return REDISMODULE_ERR; diff --git a/src/config/config.c b/src/config/config.c index c8486a7c3..11751748f 100644 --- a/src/config/config.c +++ b/src/config/config.c @@ -20,95 +20,112 @@ long long backends_inter_op_parallelism; // number of threads used for parallel // between independent operations. long long model_chunk_size; // size of chunks used to break up model payloads. -long long onnx_max_runtime; // The maximum time in milliseconds - // before killing onnx run session. +long long onnx_max_runtime; // The maximum time in milliseconds + // before killing onnx run session. + +static int _RAIConfig_LoadTimeParamParse(RedisModuleCtx *ctx, const char *key, const char *val, + RedisModuleString *rsval) { + int ret = REDISMODULE_OK; + if (strcasecmp((key), "TF") == 0) { + ret = RAI_LoadBackend(ctx, RAI_BACKEND_TENSORFLOW, (val)); + } else if (strcasecmp((key), "TFLITE") == 0) { + ret = RAI_LoadBackend(ctx, RAI_BACKEND_TFLITE, (val)); + } else if (strcasecmp((key), "TORCH") == 0) { + ret = RAI_LoadBackend(ctx, RAI_BACKEND_TORCH, (val)); + } else if (strcasecmp((key), "ONNX") == 0) { + ret = RAI_LoadBackend(ctx, RAI_BACKEND_ONNXRUNTIME, (val)); + } + // enable configuring the main thread to create a fixed number of worker + // threads up front per device. 
by default we'll use 1 + else if (strcasecmp((key), "THREADS_PER_QUEUE") == 0) { + ret = RedisAI_Config_QueueThreads(rsval); + if (ret == REDISMODULE_OK) { + RedisModule_Log(ctx, "notice", "%s: %s", REDISAI_INFOMSG_THREADS_PER_QUEUE, (val)); + } + } else if (strcasecmp((key), "INTRA_OP_PARALLELISM") == 0) { + ret = RedisAI_Config_IntraOperationParallelism(rsval); + if (ret == REDISMODULE_OK) { + RedisModule_Log(ctx, "notice", "%s: %lld", REDISAI_INFOMSG_INTRA_OP_PARALLELISM, + getBackendsIntraOpParallelism()); + } + } else if (strcasecmp((key), "INTER_OP_PARALLELISM") == 0) { + ret = RedisAI_Config_InterOperationParallelism(rsval); + if (ret == REDISMODULE_OK) { + RedisModule_Log(ctx, "notice", "%s: %lld", REDISAI_INFOMSG_INTER_OP_PARALLELISM, + getBackendsInterOpParallelism()); + } + } else if (strcasecmp((key), "MODEL_CHUNK_SIZE") == 0) { + ret = RedisAI_Config_ModelChunkSize(rsval); + if (ret == REDISMODULE_OK) { + RedisModule_Log(ctx, "notice", "%s: %lld", REDISAI_INFOMSG_MODEL_CHUNK_SIZE, + getModelChunkSize()); + } + } else if (strcasecmp((key), "ONNX_TIMEOUT") == 0) { + ret = RedisAI_Config_OnnxTimeout(rsval); + if (ret == REDISMODULE_OK) { + RedisModule_Log(ctx, "notice", "%s: %lld", REDISAI_INFOMSG_ONNX_TIMEOUT, + GetOnnxTimeout()); + } + } else if (strcasecmp((key), "BACKENDSPATH") == 0) { + // already taken care of + } else { + ret = REDISMODULE_ERR; + } + return ret; +} -/** - * - * @return number of threads used within an individual op for parallelism. - */ long long getBackendsInterOpParallelism() { return backends_inter_op_parallelism; } -/** - * Set number of threads used for parallelism between independent operations, by - * backend. 
- * - * @param num_threads - * @return 0 on success, or 1 if failed - */ int setBackendsInterOpParallelism(long long num_threads) { - int result = 1; - if (num_threads >= 0) { - backends_inter_op_parallelism = num_threads; - result = 0; + if (num_threads <= 0) { + return REDISMODULE_ERR; } - return result; + backends_inter_op_parallelism = num_threads; + return REDISMODULE_OK; } -/** - * - * @return - */ long long getBackendsIntraOpParallelism() { return backends_intra_op_parallelism; } -/** - * Set number of threads used within an individual op for parallelism, by - * backend. - * - * @param num_threads - * @return 0 on success, or 1 if failed - */ int setBackendsIntraOpParallelism(long long num_threads) { - int result = 1; - if (num_threads >= 0) { - backends_intra_op_parallelism = num_threads; - result = 0; + if (num_threads <= 0) { + return REDISMODULE_ERR; } - return result; + backends_intra_op_parallelism = num_threads; + return REDISMODULE_OK; } -/** - * @return size of chunks (in bytes) in which models are split for - * set, get, serialization and replication. - */ long long getModelChunkSize() { return model_chunk_size; } -/** - * Set size of chunks (in bytes) in which models are split for set, - * get, serialization and replication. 
- * - * @param size - * @return 0 on success, or 1 if failed - */ int setModelChunkSize(long long size) { - int result = 1; - if (size > 0) { - model_chunk_size = size; - result = 0; + if (size <= 0) { + return REDISMODULE_ERR; } - return result; + model_chunk_size = size; + return REDISMODULE_OK; } -long long GetOnnxTimeout () { return onnx_max_runtime; } +long long GetNumThreadsPerQueue() { return ThreadPoolSizePerQueue; } + +int SetNumThreadsPerQueue(long long num_threads) { + if (num_threads <= 0) { + return REDISMODULE_ERR; + } + ThreadPoolSizePerQueue = num_threads; + return REDISMODULE_OK; +} + +long long GetOnnxTimeout() { return onnx_max_runtime; } int SetOnnxTimeout(long long timeout) { - int result = 1; - if (timeout > 0) { - onnx_max_runtime = timeout; - result = 0; + // Timeout should not be lower than the time passing between two consecutive + // runs of Redis cron callback, that is no more than (1/CONFIG_MIN_HZ) + if (timeout < 1000) { + return REDISMODULE_ERR; } - return result; + onnx_max_runtime = timeout; + return REDISMODULE_OK; } -/** - * Helper method for AI.CONFIG LOADBACKEND - * - * - * @param ctx Context in which Redis modules operate - * @param argv Redis command arguments, as an array of strings - * @param argc Redis command number of arguments - * @return REDISMODULE_OK on success, or REDISMODULE_ERR if the DAGRUN failed - */ int RedisAI_Config_LoadBackend(RedisModuleCtx *ctx, RedisModuleString **argv, int argc) { if (argc < 3) return RedisModule_WrongArity(ctx); @@ -128,201 +145,85 @@ int RedisAI_Config_LoadBackend(RedisModuleCtx *ctx, RedisModuleString **argv, in } else { return RedisModule_ReplyWithError(ctx, "ERR unsupported backend"); } - if (result == REDISMODULE_OK) { return RedisModule_ReplyWithSimpleString(ctx, "OK"); } - return RedisModule_ReplyWithError(ctx, "ERR error loading backend"); } -/** - * Helper method for AI.CONFIG BACKENDSPATH - * - * - * @param ctx Context in which Redis modules operate - * @param path string 
containing backend path - * @return REDISMODULE_OK on success, or REDISMODULE_ERR if the DAGRUN failed - */ -int RedisAI_Config_BackendsPath(RedisModuleCtx *ctx, const char *path) { +void RedisAI_Config_BackendsPath(const char *path) { if (RAI_BackendsPath != NULL) { RedisModule_Free(RAI_BackendsPath); } RAI_BackendsPath = RedisModule_Strdup(path); - - return RedisModule_ReplyWithSimpleString(ctx, "OK"); } -/** - * Set number of threads used for parallelism between RedisAI independent - * blocking commands ( AI.DAGRUN, AI.SCRIPTRUN, AI.MODELRUN ). - * - * @param num_threads_string string containing thread number - * @return REDISMODULE_OK on success, or REDISMODULE_ERR if failed - */ int RedisAI_Config_QueueThreads(RedisModuleString *num_threads_string) { - int result = RedisModule_StringToLongLong(num_threads_string, &ThreadPoolSizePerQueue); - // make sure the number of threads is a positive integer - // if not set the value to the default - if (result == REDISMODULE_OK && ThreadPoolSizePerQueue < 1) { - ThreadPoolSizePerQueue = REDISAI_DEFAULT_THREADS_PER_QUEUE; - result = REDISMODULE_ERR; + long long temp; + int result = RedisModule_StringToLongLong(num_threads_string, &temp); + if (result != REDISMODULE_OK) { + return REDISMODULE_ERR; } - return result; + return SetNumThreadsPerQueue(temp); } -/** - * Set number of threads used for parallelism between independent operations, by - * backend. 
- * - * @param num_threads_string string containing thread number - * @return REDISMODULE_OK on success, or REDISMODULE_ERR if failed - */ int RedisAI_Config_InterOperationParallelism(RedisModuleString *num_threads_string) { long long temp; int result = RedisModule_StringToLongLong(num_threads_string, &temp); - if (result == REDISMODULE_OK) { - result = setBackendsInterOpParallelism(temp); + if (result != REDISMODULE_OK) { + return REDISMODULE_ERR; } - return result; + return setBackendsInterOpParallelism(temp); } -/** - * Set number of threads used within an individual op for parallelism, by - * backend. - * - * @param num_threads_string string containing thread number - * @return REDISMODULE_OK on success, or REDISMODULE_ERR if failed - */ int RedisAI_Config_IntraOperationParallelism(RedisModuleString *num_threads_string) { long long temp; int result = RedisModule_StringToLongLong(num_threads_string, &temp); - if (result == REDISMODULE_OK) { - result = setBackendsIntraOpParallelism(temp); + if (result != REDISMODULE_OK) { + return REDISMODULE_ERR; } - return result; + return setBackendsIntraOpParallelism(temp); } -/** - * Set size of chunks in which model payloads are split for set, - * get, serialization and replication. 
- * - * @param chunk_size_string string containing chunk size (in bytes) - * @return REDISMODULE_OK on success, or REDISMODULE_ERR if failed - */ int RedisAI_Config_ModelChunkSize(RedisModuleString *chunk_size_string) { long long temp; int result = RedisModule_StringToLongLong(chunk_size_string, &temp); - // make sure chunk size is a positive integer - // if not set the value to the default - if (result == REDISMODULE_OK && temp < 1) { - temp = REDISAI_DEFAULT_MODEL_CHUNK_SIZE; - result = REDISMODULE_ERR; + if (result != REDISMODULE_OK) { + return REDISMODULE_ERR; } - result = setModelChunkSize(temp); - return result; + return setModelChunkSize(temp); } int RedisAI_Config_OnnxTimeout(RedisModuleString *onnx_timeout) { long long temp; int result = RedisModule_StringToLongLong(onnx_timeout, &temp); - // make sure that the timeout is a positive integer, if not set the value to the default. - if (result == REDISMODULE_OK && temp < 1) { - temp = ONNX_DEFAULT_MAX_RUNTIME; - result = REDISMODULE_ERR; - } - result = SetOnnxTimeout(temp); - return result; -} - -/** - * - * @param ctx Context in which Redis modules operate - * @param key - * @param val - * @return REDISMODULE_OK on success, or REDISMODULE_ERR if failed - */ -int RAI_configParamParse(RedisModuleCtx *ctx, const char *key, const char *val, - RedisModuleString *rsval) { - int ret = REDISMODULE_OK; - if (strcasecmp((key), "TF") == 0) { - ret = RAI_LoadBackend(ctx, RAI_BACKEND_TENSORFLOW, (val)); - } else if (strcasecmp((key), "TFLITE") == 0) { - ret = RAI_LoadBackend(ctx, RAI_BACKEND_TFLITE, (val)); - } else if (strcasecmp((key), "TORCH") == 0) { - ret = RAI_LoadBackend(ctx, RAI_BACKEND_TORCH, (val)); - } else if (strcasecmp((key), "ONNX") == 0) { - ret = RAI_LoadBackend(ctx, RAI_BACKEND_ONNXRUNTIME, (val)); - } - // enable configuring the main thread to create a fixed number of worker - // threads up front per device. 
by default we'll use 1 - else if (strcasecmp((key), "THREADS_PER_QUEUE") == 0) { - ret = RedisAI_Config_QueueThreads(rsval); - if (ret == REDISMODULE_OK) { - RedisModule_Log(ctx, "notice", "%s: %s", REDISAI_INFOMSG_THREADS_PER_QUEUE, (val)); - } - } else if (strcasecmp((key), "INTRA_OP_PARALLELISM") == 0) { - ret = RedisAI_Config_IntraOperationParallelism(rsval); - if (ret == REDISMODULE_OK) { - RedisModule_Log(ctx, "notice", "%s: %lld", REDISAI_INFOMSG_INTRA_OP_PARALLELISM, - getBackendsIntraOpParallelism()); - } - } else if (strcasecmp((key), "INTER_OP_PARALLELISM") == 0) { - ret = RedisAI_Config_InterOperationParallelism(rsval); - if (ret == REDISMODULE_OK) { - RedisModule_Log(ctx, "notice", "%s: %lld", REDISAI_INFOMSG_INTER_OP_PARALLELISM, - getBackendsInterOpParallelism()); - } - } else if (strcasecmp((key), "MODEL_CHUNK_SIZE") == 0) { - ret = RedisAI_Config_ModelChunkSize(rsval); - if (ret == REDISMODULE_OK) { - RedisModule_Log(ctx, "notice", "%s: %lld", REDISAI_INFOMSG_MODEL_CHUNK_SIZE, - getModelChunkSize()); - } - } else if (strcasecmp((key), "ONNX_TIMEOUT") == 0) { - ret = RedisAI_Config_OnnxTimeout(rsval); - if (ret == REDISMODULE_OK) { - RedisModule_Log(ctx, "notice", "%s: %lld", REDISAI_INFOMSG_MODEL_CHUNK_SIZE, - GetOnnxTimeout()); - } - } else if (strcasecmp((key), "BACKENDSPATH") == 0) { - // already taken care of - } else { - ret = REDISMODULE_ERR; + if (result != REDISMODULE_OK) { + return REDISMODULE_ERR; } - return ret; + return SetOnnxTimeout(temp); } -/** - * Load time configuration parser - * - * @param ctx Context in which Redis modules operate - * @param argv Redis command arguments, as an array of strings - * @param argc Redis command number of arguments - * @return REDISMODULE_OK on success, or REDISMODULE_ERR if the DAGRUN failed - */ int RAI_loadTimeConfig(RedisModuleCtx *ctx, RedisModuleString *const *argv, int argc) { if (argc > 0 && argc % 2 != 0) { RedisModule_Log(ctx, "warning", "Even number of arguments provided to module. 
Please " "provide arguments as KEY VAL pairs"); + return REDISMODULE_ERR; } // need BACKENDSPATH set up before loading specific backends for (int i = 0; i < argc / 2; i++) { const char *key = RedisModule_StringPtrLen(argv[2 * i], NULL); const char *val = RedisModule_StringPtrLen(argv[2 * i + 1], NULL); - - int ret = REDISMODULE_OK; if (strcasecmp(key, "BACKENDSPATH") == 0) { - ret = RedisAI_Config_BackendsPath(ctx, val); + RedisAI_Config_BackendsPath(val); } } for (int i = 0; i < argc / 2; i++) { const char *key = RedisModule_StringPtrLen(argv[2 * i], NULL); const char *val = RedisModule_StringPtrLen(argv[2 * i + 1], NULL); - int ret = RAI_configParamParse(ctx, key, val, argv[2 * i + 1]); + int ret = _RAIConfig_LoadTimeParamParse(ctx, key, val, argv[2 * i + 1]); if (ret == REDISMODULE_ERR) { char *buffer = RedisModule_Alloc( diff --git a/src/config/config.h b/src/config/config.h index af4ad2fea..d17b8c147 100644 --- a/src/config/config.h +++ b/src/config/config.h @@ -33,19 +33,17 @@ typedef enum { RAI_DEVICE_CPU = 0, RAI_DEVICE_GPU = 1 } RAI_Device; #define REDISAI_INFOMSG_INTRA_OP_PARALLELISM "Setting INTRA_OP_PARALLELISM parameter to" #define REDISAI_INFOMSG_INTER_OP_PARALLELISM "Setting INTER_OP_PARALLELISM parameter to" #define REDISAI_INFOMSG_MODEL_CHUNK_SIZE "Setting MODEL_CHUNK_SIZE parameter to" +#define REDISAI_INFOMSG_ONNX_TIMEOUT "Setting ONNX_TIMEOUT parameter to" /** * Get number of threads used for parallelism between independent operations, by * backend. - * @return number of threads used for parallelism between independent - * operations, by backend */ long long getBackendsInterOpParallelism(); /** * Set number of threads used for parallelism between independent operations, by * backend. - * * @param num_threads * @return 0 on success, or 1 if failed */ @@ -54,15 +52,12 @@ int setBackendsInterOpParallelism(long long num_threads); /** * Get number of threads used within an individual op for parallelism, by * backend. 
- * @return number of threads used within an individual op for parallelism, by - * backend. */ long long getBackendsIntraOpParallelism(); /** * Set number of threads used within an individual op for parallelism, by * backend. - * * @param num_threads * @return 0 on success, or 1 if failed */ @@ -77,12 +72,36 @@ long long getModelChunkSize(); /** * Set size of chunks (in bytes) in which models are split for set, * get, serialization and replication. - * * @param size * @return 0 on success, or 1 if failed */ int setModelChunkSize(long long size); +/** + * @brief Return the number of working threads per device in RedisAI. + */ +long long GetNumThreadsPerQueue(void); + +/** + * Set the number of working threads per device in RedisAI. + * @param num_threads + * @return 0 on success, or 1 if failed + */ +int SetNumThreadsPerQueue(long long num_threads); + +/** + * @return Number of milliseconds that onnxruntime session is allowed to run + * before killing it + */ +long long GetOnnxTimeout(void); + +/** + * Set the maximal number of milliseconds that onnxruntime session is allowed to run + * @param timeout in ms + * @return 0 on success, or 1 if failed + */ +int SetOnnxTimeout(long long timeout); + /** * Helper method for AI.CONFIG LOADBACKEND * @@ -90,13 +109,8 @@ int setModelChunkSize(long long size); * @param ctx Context in which Redis modules operate * @param argv Redis command arguments, as an array of strings * @param argc Redis command number of arguments - * @return REDISMODULE_OK on success, or REDISMODULE_ERR if the DAGRUN failed + * @return REDISMODULE_OK on success, or REDISMODULE_ERR otherwise. 
*/ - -long long GetOnnxTimeout(void); - -int SetOnnxTimeout(long long timeout); - int RedisAI_Config_LoadBackend(RedisModuleCtx *ctx, RedisModuleString **argv, int argc); /** @@ -105,23 +119,20 @@ int RedisAI_Config_LoadBackend(RedisModuleCtx *ctx, RedisModuleString **argv, in * * @param ctx Context in which Redis modules operate * @param path string containing backend path - * @return REDISMODULE_OK on success, or REDISMODULE_ERR if the DAGRUN failed */ -int RedisAI_Config_BackendsPath(RedisModuleCtx *ctx, const char *path); +void RedisAI_Config_BackendsPath(const char *path); /** * Set number of threads used for parallelism between RedisAI independent - * blocking commands ( AI.DAGRUN, AI.SCRIPTRUN, AI.MODELRUN ). - * + * blocking commands (AI.DAGEXECUTE, AI.SCRIPTEXECUTE, AI.MODELEXECUTE). * @param num_threads_string string containing thread number - * @return REDISMODULE_OK on success, or REDISMODULE_ERR if failed + * @return REDISMODULE_OK on success, or REDISMODULE_ERR if failed */ int RedisAI_Config_QueueThreads(RedisModuleString *num_threads_string); /** * Set number of threads used for parallelism between independent operations, by * backend. - * * @param num_threads_string string containing thread number * @return REDISMODULE_OK on success, or REDISMODULE_ERR if failed */ @@ -130,7 +141,6 @@ int RedisAI_Config_InterOperationParallelism(RedisModuleString *num_threads_stri /** * Set number of threads used within an individual op for parallelism, by * backend. - * * @param num_threads_string string containing thread number * @return REDISMODULE_OK on success, or REDISMODULE_ERR if failed */ @@ -139,9 +149,8 @@ int RedisAI_Config_IntraOperationParallelism(RedisModuleString *num_threads_stri /** * Set size of chunks in which model payloads are split for set, * get, serialization and replication. 
- * * @param chunk_size_string string containing chunk size (in bytes) - * @return REDISMODULE_OK on success, or REDISMODULE_ERR if failed + * @return REDISMODULE_OK on success, or REDISMODULE_ERR if failed */ int RedisAI_Config_ModelChunkSize(RedisModuleString *chunk_size_string); @@ -152,22 +161,11 @@ int RedisAI_Config_ModelChunkSize(RedisModuleString *chunk_size_string); */ int RedisAI_Config_OnnxTimeout(RedisModuleString *onnx_timeout); -/** - * - * @param ctx Context in which Redis modules operate - * @param key - * @param val - * @return REDISMODULE_OK on success, or REDISMODULE_ERR if failed - */ -int RAI_configParamParse(RedisModuleCtx *ctx, const char *key, const char *val, - RedisModuleString *rsval); - /** * Load time configuration parser - * * @param ctx Context in which Redis modules operate * @param argv Redis command arguments, as an array of strings * @param argc Redis command number of arguments - * @return REDISMODULE_OK on success, or REDISMODULE_ERR if the DAGRUN failed + * @return REDISMODULE_OK on success, or REDISMODULE_ERR if failed */ int RAI_loadTimeConfig(RedisModuleCtx *ctx, RedisModuleString *const *argv, int argc); diff --git a/src/execution/background_workers.c b/src/execution/background_workers.c index 8a195f6fd..24db22de4 100644 --- a/src/execution/background_workers.c +++ b/src/execution/background_workers.c @@ -102,11 +102,7 @@ bool IsRunQueueExists(const char *device_str) { return true; } -uintptr_t GetThreadId() { - return *(uintptr_t *)pthread_getspecific(ThreadIdKey); -} - -long long GetNumThreadsPerQueue() { return ThreadPoolSizePerQueue; } +uintptr_t GetThreadId() { return *(uintptr_t *)pthread_getspecific(ThreadIdKey); } void RunQueueInfoFree(RunQueueInfo *run_queue_info) { RedisModule_Assert(queueLength(run_queue_info->run_queue) == 0); diff --git a/src/execution/background_workers.h b/src/execution/background_workers.h index 2ae800f00..71f979db4 100644 --- a/src/execution/background_workers.h +++ 
b/src/execution/background_workers.h @@ -42,7 +42,6 @@ typedef struct RunQueueInfo { char *device_str; } RunQueueInfo; - /** * @brief Terminate all working threads and free the run queue with its inner fields. */ @@ -69,8 +68,3 @@ RunQueueInfo *GetRunQueueInfo(const char *device_str); * saved under ThreadIdKey. */ uintptr_t GetThreadId(void); - -/** - * @brief Return the number of working threads per device in RedisAI. - */ -long long GetNumThreadsPerQueue(void); diff --git a/src/execution/onnx_timeout.c b/src/execution/onnx_timeout.c index f8978b988..0d91ff120 100644 --- a/src/execution/onnx_timeout.c +++ b/src/execution/onnx_timeout.c @@ -18,10 +18,10 @@ static long long _mstime(void) { int CreateGlobalOnnxRunSessions() { onnx_global_run_sessions = RedisModule_Alloc(sizeof(struct OnnxGlobalRunSessions)); OnnxRunSessionCtx **onnx_run_sessions = - array_new(OnnxRunSessionCtx *, RedisAI_NumThreadsPerQueue()); + array_new(OnnxRunSessionCtx *, RedisAI_NumThreadsPerQueue()); onnx_global_run_sessions->OnnxRunSessions = onnx_run_sessions; pthread_rwlock_init(&(onnx_global_run_sessions->rwlock), NULL); - return RAI_AddNewDeviceORT("CPU"); // Add entries for CPU threads. + return RAI_AddNewDeviceORT("CPU"); // Add entries for CPU threads. } int RAI_AddNewDeviceORT(const char *device_str) { @@ -30,7 +30,8 @@ int RAI_AddNewDeviceORT(const char *device_str) { pthread_rwlock_wrlock(&(onnx_global_run_sessions->rwlock)); OnnxRunSessionCtx **run_sessions_array = onnx_global_run_sessions->OnnxRunSessions; - // Extend the array with an entry for every working thread on the new device, initialized to NULL. + // Extend the array with an entry for every working thread on the new device, initialized to + // NULL. 
size_t size = RedisAI_NumThreadsPerQueue(); for (size_t i = 0; i < size; i++) { OnnxRunSessionCtx *entry = RedisModule_Calloc(1, sizeof(OnnxRunSessionCtx)); @@ -42,7 +43,7 @@ int RAI_AddNewDeviceORT(const char *device_str) { } void RAI_EnforceTimeoutORT(RedisModuleCtx *ctx, RedisModuleEvent eid, uint64_t subevent, - void *data) { + void *data) { const OrtApi *ort = OrtGetApiBase()->GetApi(1); pthread_rwlock_rdlock(&(onnx_global_run_sessions->rwlock)); OnnxRunSessionCtx **run_sessions_ctx = onnx_global_run_sessions->OnnxRunSessions; @@ -61,14 +62,12 @@ void RAI_EnforceTimeoutORT(RedisModuleCtx *ctx, RedisModuleEvent eid, uint64_t s pthread_rwlock_unlock(&(onnx_global_run_sessions->rwlock)); } -void SetRunSessionCtx(OrtRunOptions *new_run_options, - size_t *run_session_index) { +void SetRunSessionCtx(OrtRunOptions *new_run_options, size_t *run_session_index) { pthread_rwlock_rdlock(&(onnx_global_run_sessions->rwlock)); // Get the thread index (which is the correspondent index in the global sessions array). *run_session_index = (size_t)RedisAI_ThreadId(); - OnnxRunSessionCtx *entry = - onnx_global_run_sessions->OnnxRunSessions[*run_session_index]; + OnnxRunSessionCtx *entry = onnx_global_run_sessions->OnnxRunSessions[*run_session_index]; RedisModule_Assert(entry->runOptions == NULL); // Update the entry with the current session data. diff --git a/src/execution/onnx_timeout.h b/src/execution/onnx_timeout.h index 2084dc150..0d2785d9b 100644 --- a/src/execution/onnx_timeout.h +++ b/src/execution/onnx_timeout.h @@ -39,7 +39,7 @@ int RAI_AddNewDeviceORT(const char *device_str); * those that exceeds the timeout. 
*/ void RAI_EnforceTimeoutORT(RedisModuleCtx *ctx, RedisModuleEvent eid, uint64_t subevent, - void *data); + void *data); /** * @brief Set a new OrtRunOptions in the global structure, to allow us to diff --git a/src/redisai.c b/src/redisai.c index b8ded0dd9..83a63de40 100644 --- a/src/redisai.c +++ b/src/redisai.c @@ -986,7 +986,8 @@ int RedisAI_Config_RedisCommand(RedisModuleCtx *ctx, RedisModuleString **argv, i if (!strcasecmp(subcommand, "BACKENDSPATH")) { if (argc > 2) { - return RedisAI_Config_BackendsPath(ctx, RedisModule_StringPtrLen(argv[2], NULL)); + RedisAI_Config_BackendsPath(RedisModule_StringPtrLen(argv[2], NULL)); + return RedisModule_ReplyWithSimpleString(ctx, "OK"); } else { return RedisModule_ReplyWithError(ctx, "ERR BACKENDSPATH: missing path argument"); } @@ -994,8 +995,11 @@ int RedisAI_Config_RedisCommand(RedisModuleCtx *ctx, RedisModuleString **argv, i if (!strcasecmp(subcommand, "MODEL_CHUNK_SIZE")) { if (argc > 2) { - RedisAI_Config_ModelChunkSize(argv[2]); - return RedisModule_ReplyWithSimpleString(ctx, "OK"); + if (RedisAI_Config_ModelChunkSize(argv[2]) == REDISMODULE_OK) { + return RedisModule_ReplyWithSimpleString(ctx, "OK"); + } else { + return RedisModule_ReplyWithError(ctx, "ERR MODEL_CHUNK_SIZE: invalid chunk size"); + } } else { return RedisModule_ReplyWithError(ctx, "ERR MODEL_CHUNK_SIZE: missing chunk size"); } @@ -1211,6 +1215,7 @@ void RAI_moduleInfoFunc(RedisModuleInfoCtx *ctx, int for_crash_report) { RedisModule_InfoAddFieldLongLong(ctx, "threads_per_queue", ThreadPoolSizePerQueue); RedisModule_InfoAddFieldLongLong(ctx, "inter_op_parallelism", getBackendsInterOpParallelism()); RedisModule_InfoAddFieldLongLong(ctx, "intra_op_parallelism", getBackendsIntraOpParallelism()); + RedisModule_InfoAddFieldLongLong(ctx, "timeout_for_onnxruntime_sessions", GetOnnxTimeout()); RedisModule_InfoAddSection(ctx, "memory_usage"); if (RAI_backends.onnx.get_memory_info) { RedisModule_InfoAddFieldULongLong(ctx, "onnxruntime_memory", @@ -1468,7 +1473,9 @@ int RedisModule_OnLoad(RedisModuleCtx *ctx, RedisModuleString 
**argv, int argc) setModelChunkSize(REDISAI_DEFAULT_MODEL_CHUNK_SIZE); SetOnnxTimeout(ONNX_DEFAULT_MAX_RUNTIME); - RAI_loadTimeConfig(ctx, argv, argc); + if (RAI_loadTimeConfig(ctx, argv, argc) != REDISMODULE_OK) { + return REDISMODULE_ERR; + } RunQueues = AI_dictCreate(&AI_dictTypeHeapStrings, NULL); pthread_key_create(&ThreadIdKey, RedisModule_Free); diff --git a/tests/flow/tests_onnx.py b/tests/flow/tests_onnx.py index 5f722506b..c1d632ed9 100644 --- a/tests/flow/tests_onnx.py +++ b/tests/flow/tests_onnx.py @@ -488,7 +488,7 @@ def test_onnx_kill_switch_basic(env): def test_onnx_kill_switch_multiple_working_threads(): - env = Env(moduleArgs='THREADS_PER_QUEUE 8') + env = Env(moduleArgs='THREADS_PER_QUEUE 8 ONNX_TIMEOUT 1000') con = env.getConnection() model_with_inf_loop = load_file_content("model_with_infinite_loop.onnx") ret = con.execute_command('AI.MODELSTORE', 'inf_loop_model{1}', 'ONNX', DEVICE, 'BLOB', model_with_inf_loop) From 42059b839e03431d5777ce7e92acb1d2767cb43f Mon Sep 17 00:00:00 2001 From: alonre24 Date: Sun, 6 Jun 2021 23:50:45 +0300 Subject: [PATCH 16/27] Remove redundant include --- src/execution/background_workers.c | 1 - 1 file changed, 1 deletion(-) diff --git a/src/execution/background_workers.c b/src/execution/background_workers.c index 24db22de4..6ffb3492c 100644 --- a/src/execution/background_workers.c +++ b/src/execution/background_workers.c @@ -19,7 +19,6 @@ #include #include "string_utils.h" #include "backends/backends.h" -#include #include "redisai.h" #include "run_info.h" #include "background_workers.h" From 23749c488a658dc6e680a5201115dd8b83c60675 Mon Sep 17 00:00:00 2001 From: alonre24 Date: Thu, 10 Jun 2021 12:40:37 +0300 Subject: [PATCH 17/27] PR fixes part 1: refactor config and run queue info files (and all places that are affected) --- src/CMakeLists.txt | 3 +- src/backends/backends.c | 24 ++- src/backends/backends.h | 1 - src/{execution => backends}/onnx_timeout.c | 48 ++--- src/{execution => backends}/onnx_timeout.h | 2 +- 
src/backends/onnxruntime.c | 14 +- src/backends/onnxruntime.h | 7 +- src/config/config.c | 190 +++++++----------- src/config/config.h | 93 +++------ src/execution/DAG/dag_execute.c | 3 +- src/execution/background_workers.c | 125 ++---------- src/execution/background_workers.h | 38 +--- src/execution/parsing/deprecated.c | 9 +- src/execution/run_queue_info.c | 80 ++++++++ src/execution/run_queue_info.h | 42 ++++ src/execution/utils.c | 3 +- src/execution/utils.h | 1 + src/redisai.c | 48 ++--- src/serialization/AOF/rai_aof_rewrite.c | 2 +- .../RDB/decoder/current/v2/decode_v2.c | 4 +- .../RDB/decoder/previous/v0/decode_v0.c | 4 +- .../RDB/decoder/previous/v1/decode_v1.c | 4 +- src/serialization/RDB/encoder/v2/encode_v2.c | 2 +- src/util/queue.c | 2 +- src/util/queue.h | 2 +- tests/flow/tests_onnx.py | 2 +- 26 files changed, 346 insertions(+), 407 deletions(-) rename src/{execution => backends}/onnx_timeout.c (71%) rename src/{execution => backends}/onnx_timeout.h (97%) create mode 100644 src/execution/run_queue_info.c create mode 100644 src/execution/run_queue_info.h diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt index b6e6dcbe9..ddcb1340d 100644 --- a/src/CMakeLists.txt +++ b/src/CMakeLists.txt @@ -39,6 +39,7 @@ ADD_LIBRARY(redisai_obj OBJECT execution/parsing/parse_utils.c execution/run_info.c execution/background_workers.c + execution/run_queue_info.c execution/utils.c config/config.c execution/DAG/dag.c @@ -88,7 +89,7 @@ ENDIF() IF(BUILD_ORT) ADD_LIBRARY(redisai_onnxruntime_obj OBJECT backends/onnxruntime.c - execution/onnx_timeout.c + backends/onnx_timeout.c ${BACKEND_COMMON_SRC} ) ENDIF() diff --git a/src/backends/backends.c b/src/backends/backends.c index 0bd3be509..cf794c619 100644 --- a/src/backends/backends.c +++ b/src/backends/backends.c @@ -14,9 +14,11 @@ #include #include #include -#include - #include "redismodule.h" +#include "config/config.h" +#include "execution/background_workers.h" + + static bool _ValidateAPICreated(RedisModuleCtx *ctx, void 
*func_ptr, const char *func_name) { if (func_ptr == NULL) { @@ -28,12 +30,14 @@ static bool _ValidateAPICreated(RedisModuleCtx *ctx, void *func_ptr, const char int RAI_GetApi(const char *func_name, void **targetPtrPtr) { - if (strcmp("ThreadIdKey", func_name) == 0) { - *targetPtrPtr = GetThreadId; - } else if (strcmp("NumThreadsPerQueue", func_name) == 0) { - *targetPtrPtr = GetNumThreadsPerQueue; - } else if (strcmp("OnnxTimeout", func_name) == 0) { - *targetPtrPtr = GetOnnxTimeout; + if (strcmp("GetThreadId", func_name) == 0) { + *targetPtrPtr = BGWorker_GetThreadId; + } else if (strcmp("GetNumThreadsPerQueue", func_name) == 0) { + *targetPtrPtr = Config_GetNumThreadsPerQueue; + } else if (strcmp("GetModelExecutionTimeout", func_name) == 0) { + *targetPtrPtr = Config_GetModelExecutionTimeout; + } else if (strcmp("GetThreadsCount", func_name) == 0) { + *targetPtrPtr = BGWorker_GetThreadsCount; } else { return RedisModule_GetApi(func_name, targetPtrPtr); } @@ -56,8 +60,8 @@ RedisModuleString *RAI_GetModulePath(RedisModuleCtx *ctx) { RedisModuleString *RAI_GetBackendsPath(RedisModuleCtx *ctx) { Dl_info info; RedisModuleString *backends_path = NULL; - if (RAI_BackendsPath != NULL) { - backends_path = RedisModule_CreateString(ctx, RAI_BackendsPath, strlen(RAI_BackendsPath)); + if (Config_GetBackendsPath() != NULL) { + backends_path = RedisModule_CreateString(ctx, Config_GetBackendsPath(), strlen(Config_GetBackendsPath())); } else { RedisModuleString *module_path = RAI_GetModulePath(ctx); backends_path = RedisModule_CreateStringPrintf(ctx, "%s/backends", diff --git a/src/backends/backends.h b/src/backends/backends.h index 0989db42c..01ccef8b2 100644 --- a/src/backends/backends.h +++ b/src/backends/backends.h @@ -98,7 +98,6 @@ typedef struct RAI_LoadedBackends { } RAI_LoadedBackends; RAI_LoadedBackends RAI_backends; -char *RAI_BackendsPath; int RAI_LoadBackend(RedisModuleCtx *ctx, int backend, const char *path); int RAI_LoadDefaultBackend(RedisModuleCtx *ctx, int 
backend); diff --git a/src/execution/onnx_timeout.c b/src/backends/onnx_timeout.c similarity index 71% rename from src/execution/onnx_timeout.c rename to src/backends/onnx_timeout.c index 0d91ff120..7ce577c8a 100644 --- a/src/execution/onnx_timeout.c +++ b/src/backends/onnx_timeout.c @@ -1,27 +1,29 @@ #include "onnx_timeout.h" #include "util/arr.h" -#include +#include "util.h" +#include "execution/utils.h" +#include "config/config.h" #include #include "util/string_utils.h" +#include "redis_ai_objects/stats.h" -// Gets the current time in milliseconds. -static long long _mstime(void) { - struct timeval tv; - long long ust; - - gettimeofday(&tv, NULL); - ust = ((long long)tv.tv_sec) * 1000000; - ust += tv.tv_usec; - return ust / 1000; -} int CreateGlobalOnnxRunSessions() { onnx_global_run_sessions = RedisModule_Alloc(sizeof(struct OnnxGlobalRunSessions)); - OnnxRunSessionCtx **onnx_run_sessions = - array_new(OnnxRunSessionCtx *, RedisAI_NumThreadsPerQueue()); - onnx_global_run_sessions->OnnxRunSessions = onnx_run_sessions; + + // Initialize the array with a number of entries equal to the number of currently + // working threads in RedisAI (note that CPU threads must exist from the start). + size_t RAI_working_threads_num = RedisAI_GetThreadsCount(); + OnnxRunSessionCtx **run_sessions_array = + array_new(OnnxRunSessionCtx *, RAI_working_threads_num); + for (size_t i = 0; i < RAI_working_threads_num; i++) { + OnnxRunSessionCtx *entry = RedisModule_Calloc(1, sizeof(OnnxRunSessionCtx)); + run_sessions_array = array_append(run_sessions_array, entry); + } + onnx_global_run_sessions->OnnxRunSessions = run_sessions_array; pthread_rwlock_init(&(onnx_global_run_sessions->rwlock), NULL); - return RAI_AddNewDeviceORT("CPU"); // Add entries for CPU threads. 
+ + return REDISMODULE_OK; } int RAI_AddNewDeviceORT(const char *device_str) { @@ -30,9 +32,9 @@ int RAI_AddNewDeviceORT(const char *device_str) { pthread_rwlock_wrlock(&(onnx_global_run_sessions->rwlock)); OnnxRunSessionCtx **run_sessions_array = onnx_global_run_sessions->OnnxRunSessions; - // Extend the array with an entry for every working thread on the new device, initialized to - // NULL. - size_t size = RedisAI_NumThreadsPerQueue(); + // Extend the array with an entry for every working thread on the new device, + // initialized to NULL. + size_t size = RedisAI_GetNumThreadsPerQueue(); for (size_t i = 0; i < size; i++) { OnnxRunSessionCtx *entry = RedisModule_Calloc(1, sizeof(OnnxRunSessionCtx)); run_sessions_array = array_append(run_sessions_array, entry); @@ -52,8 +54,8 @@ void RAI_EnforceTimeoutORT(RedisModuleCtx *ctx, RedisModuleEvent eid, uint64_t s if (run_sessions_ctx[i]->runOptions == NULL) { continue; } - long long curr_time = _mstime(); - long long timeout = RedisAI_OnnxTimeout(); + long long curr_time = mstime(); + long long timeout = RedisAI_GetModelExecutionTimeout(); // Check if a sessions is running for too long, and kill it if so. if (curr_time - run_sessions_ctx[i]->queuingTime > timeout) { ort->RunOptionsSetTerminate(run_sessions_ctx[i]->runOptions); @@ -66,17 +68,17 @@ void SetRunSessionCtx(OrtRunOptions *new_run_options, size_t *run_session_index) pthread_rwlock_rdlock(&(onnx_global_run_sessions->rwlock)); // Get the thread index (which is the correspondent index in the global sessions array). - *run_session_index = (size_t)RedisAI_ThreadId(); + *run_session_index = RedisAI_GetThreadId(); OnnxRunSessionCtx *entry = onnx_global_run_sessions->OnnxRunSessions[*run_session_index]; RedisModule_Assert(entry->runOptions == NULL); // Update the entry with the current session data. 
entry->runOptions = new_run_options; - entry->queuingTime = _mstime(); + entry->queuingTime = mstime(); pthread_rwlock_unlock(&(onnx_global_run_sessions->rwlock)); } -void ClearRunSessionCtx(size_t run_session_index) { +void InvalidateRunSessionCtx(size_t run_session_index) { const OrtApi *ort = OrtGetApiBase()->GetApi(1); pthread_rwlock_rdlock(&(onnx_global_run_sessions->rwlock)); OnnxRunSessionCtx *entry = onnx_global_run_sessions->OnnxRunSessions[run_session_index]; diff --git a/src/execution/onnx_timeout.h b/src/backends/onnx_timeout.h similarity index 97% rename from src/execution/onnx_timeout.h rename to src/backends/onnx_timeout.h index 0d2785d9b..c16671f2a 100644 --- a/src/execution/onnx_timeout.h +++ b/src/backends/onnx_timeout.h @@ -55,4 +55,4 @@ void SetRunSessionCtx(OrtRunOptions *new_run_options, size_t *run_session_index) * reset the corresponding entry in the global structure. * @param run_session_index - The entry index where OrtRunOptions was stored. */ -void ClearRunSessionCtx(size_t run_session_index); +void InvalidateRunSessionCtx(size_t run_session_index); diff --git a/src/backends/onnxruntime.c b/src/backends/onnxruntime.c index d0149938b..482d6fcfd 100644 --- a/src/backends/onnxruntime.c +++ b/src/backends/onnxruntime.c @@ -2,7 +2,7 @@ #include #include "backends/util.h" #include -#include +#include #include "execution/background_workers.h" #include #include @@ -94,10 +94,10 @@ int RAI_InitBackendORT(int (*get_api_fn)(const char *, void **)) { get_api_fn("RedisModule_MallocSize", ((void **)&RedisModule_MallocSize)); // Export RedisAI callbacks. 
- get_api_fn("ThreadIdKey", ((void **)&RedisAI_ThreadId)); - get_api_fn("NumThreadsPerQueue", ((void **)&RedisAI_NumThreadsPerQueue)); - get_api_fn("OnnxTimeout", ((void **)&RedisAI_OnnxTimeout)); - + get_api_fn("GetThreadId", ((void **)&RedisAI_GetThreadId)); + get_api_fn("GetNumThreadsPerQueue", ((void **)&RedisAI_GetNumThreadsPerQueue)); + get_api_fn("GetModelExecutionTimeout", ((void **)&RedisAI_GetModelExecutionTimeout)); + get_api_fn("GetThreadsCount", ((void **)&RedisAI_GetThreadsCount)); // Create a global array of onnx runSessions, with an entry for every working thread. CreateGlobalOnnxRunSessions(); @@ -586,7 +586,7 @@ int RAI_ModelRunORT(RAI_Model *model, RAI_ExecutionCtx **ectxs, RAI_Error *error ONNX_VALIDATE_STATUS(ort->Run(session, run_options, input_names, (const OrtValue *const *)inputs, n_input_nodes, output_names, n_output_nodes, outputs)); - ClearRunSessionCtx(run_session_index); + InvalidateRunSessionCtx(run_session_index); run_options = NULL; for (uint32_t i = 0; i < ninputs; i++) { @@ -674,7 +674,7 @@ int RAI_ModelRunORT(RAI_Model *model, RAI_ExecutionCtx **ectxs, RAI_Error *error ort->ReleaseTensorTypeAndShapeInfo(info); } if (run_options) { - ClearRunSessionCtx(run_session_index); + InvalidateRunSessionCtx(run_session_index); } return REDISMODULE_ERR; } diff --git a/src/backends/onnxruntime.h b/src/backends/onnxruntime.h index 9cff6aecd..2f7da40fc 100644 --- a/src/backends/onnxruntime.h +++ b/src/backends/onnxruntime.h @@ -9,9 +9,10 @@ unsigned long long RAI_GetMemoryInfoORT(void); unsigned long long RAI_GetMemoryAccessORT(void); -pthread_key_t (*RedisAI_ThreadId)(void); -long long (*RedisAI_NumThreadsPerQueue)(void); -long long (*RedisAI_OnnxTimeout)(void); +uintptr_t (*RedisAI_GetThreadId)(void); +uintptr_t (*RedisAI_GetThreadsCount)(void); +long long (*RedisAI_GetNumThreadsPerQueue)(void); +long long (*RedisAI_GetModelExecutionTimeout)(void); int RAI_InitBackendORT(int (*get_api_fn)(const char *, void **)); diff --git 
a/src/config/config.c b/src/config/config.c index 11751748f..7a205467e 100644 --- a/src/config/config.c +++ b/src/config/config.c @@ -1,31 +1,27 @@ #include "config.h" - -#include #include -#include -#include - #include "redismodule.h" -#include "rmutil/alloc.h" -#include "backends/util.h" #include "backends/backends.h" -#include "util/dict.h" -#include "util/queue.h" -#include "util/arr.h" -#include "execution/background_workers.h" -long long backends_intra_op_parallelism; // number of threads used within an - // individual op for parallelism. -long long backends_inter_op_parallelism; // number of threads used for parallelism - // between independent operations. -long long model_chunk_size; // size of chunks used to break up model payloads. +// Default configs +char *BackendsPath = NULL; // Path to backends dir. + +long long BackendsIntraOpParallelism = 0; // number of threads used within an + // individual op for parallelism. +long long BackendsInterOpParallelism = 0; // number of threads used for parallelism + // between independent operations. +long long ModelChunkSize = 535822336; // size of chunks used to break up model payloads. + // default is 511 * 1024 * 1024 +long long ThreadPoolSizePerQueue = 1; // Number of working threads for device. -long long onnx_max_runtime; // The maximum time in milliseconds - // before killing onnx run session. +long long ModelExecutionTimeout = 5000; // The maximum time in milliseconds + // before killing onnx run session. 
-static int _RAIConfig_LoadTimeParamParse(RedisModuleCtx *ctx, const char *key, const char *val, + +static int _Config_LoadTimeParamParse(RedisModuleCtx *ctx, const char *key, const char *val, RedisModuleString *rsval) { int ret = REDISMODULE_OK; + // Successfully-applied settings are logged below using the raw string value provided. if (strcasecmp((key), "TF") == 0) { ret = RAI_LoadBackend(ctx, RAI_BACKEND_TENSORFLOW, (val)); } else if (strcasecmp((key), "TFLITE") == 0) { @@ -38,33 +34,33 @@ static int _RAIConfig_LoadTimeParamParse(RedisModuleCtx *ctx, const char *key, c // enable configuring the main thread to create a fixed number of worker // threads up front per device. by default we'll use 1 else if (strcasecmp((key), "THREADS_PER_QUEUE") == 0) { - ret = RedisAI_Config_QueueThreads(rsval); + ret = Config_SetQueueThreadsNum(rsval); if (ret == REDISMODULE_OK) { - RedisModule_Log(ctx, "notice", "%s: %s", REDISAI_INFOMSG_THREADS_PER_QUEUE, (val)); + RedisModule_Log(ctx, "notice", "%s: %s", REDISAI_INFOMSG_THREADS_PER_QUEUE, val); } } else if (strcasecmp((key), "INTRA_OP_PARALLELISM") == 0) { - ret = RedisAI_Config_IntraOperationParallelism(rsval); + ret = Config_SetIntraOperationParallelism(rsval); if (ret == REDISMODULE_OK) { - RedisModule_Log(ctx, "notice", "%s: %lld", REDISAI_INFOMSG_INTRA_OP_PARALLELISM, - getBackendsIntraOpParallelism()); + RedisModule_Log(ctx, "notice", "%s: %s", REDISAI_INFOMSG_INTRA_OP_PARALLELISM, + val); } } else if (strcasecmp((key), "INTER_OP_PARALLELISM") == 0) { - ret = RedisAI_Config_InterOperationParallelism(rsval); + ret = Config_SetInterOperationParallelism(rsval); if (ret == REDISMODULE_OK) { - RedisModule_Log(ctx, "notice", "%s: %lld", REDISAI_INFOMSG_INTER_OP_PARALLELISM, - getBackendsInterOpParallelism()); + RedisModule_Log(ctx, "notice", "%s: %s", REDISAI_INFOMSG_INTER_OP_PARALLELISM, + val); } } else if (strcasecmp((key), "MODEL_CHUNK_SIZE") == 0) { - ret = RedisAI_Config_ModelChunkSize(rsval); + ret = Config_SetModelChunkSize(rsval); if (ret == REDISMODULE_OK) { - 
RedisModule_Log(ctx, "notice", "%s: %lld", REDISAI_INFOMSG_MODEL_CHUNK_SIZE, - getModelChunkSize()); + RedisModule_Log(ctx, "notice", "%s: %s", REDISAI_INFOMSG_MODEL_CHUNK_SIZE, + val); } - } else if (strcasecmp((key), "ONNX_TIMEOUT") == 0) { - ret = RedisAI_Config_OnnxTimeout(rsval); + } else if (strcasecmp((key), "MODEL_EXECUTION_TIMEOUT") == 0) { + ret = Config_SetModelExecutionTimeout(rsval); if (ret == REDISMODULE_OK) { - RedisModule_Log(ctx, "notice", "%s: %lld", REDISAI_INFOMSG_ONNX_TIMEOUT, - GetOnnxTimeout()); + RedisModule_Log(ctx, "notice", "%s: %s", REDISAI_INFOMSG_MODEL_EXECUTION_TIMEOUT, + val); } } else if (strcasecmp((key), "BACKENDSPATH") == 0) { // already taken care of @@ -74,59 +70,19 @@ static int _RAIConfig_LoadTimeParamParse(RedisModuleCtx *ctx, const char *key, c return ret; } -long long getBackendsInterOpParallelism() { return backends_inter_op_parallelism; } - -int setBackendsInterOpParallelism(long long num_threads) { - if (num_threads <= 0) { - return REDISMODULE_ERR; - } - backends_inter_op_parallelism = num_threads; - return REDISMODULE_OK; -} - -long long getBackendsIntraOpParallelism() { return backends_intra_op_parallelism; } +long long Config_GetBackendsInterOpParallelism() { return BackendsInterOpParallelism; } -int setBackendsIntraOpParallelism(long long num_threads) { - if (num_threads <= 0) { - return REDISMODULE_ERR; - } - backends_intra_op_parallelism = num_threads; - return REDISMODULE_OK; -} - -long long getModelChunkSize() { return model_chunk_size; } - -int setModelChunkSize(long long size) { - if (size <= 0) { - return REDISMODULE_ERR; - } - model_chunk_size = size; - return REDISMODULE_OK; -} +long long Config_GetBackendsIntraOpParallelism() { return BackendsIntraOpParallelism; } -long long GetNumThreadsPerQueue() { return ThreadPoolSizePerQueue; } +long long Config_GetModelChunkSize() { return ModelChunkSize; } -int SetNumThreadsPerQueue(long long num_threads) { - if (num_threads <= 0) { - return REDISMODULE_ERR; - } - 
ThreadPoolSizePerQueue = num_threads; - return REDISMODULE_OK; -} +long long Config_GetNumThreadsPerQueue() { return ThreadPoolSizePerQueue; } -long long GetOnnxTimeout() { return onnx_max_runtime; } +long long Config_GetModelExecutionTimeout() { return ModelExecutionTimeout; } -int SetOnnxTimeout(long long timeout) { - // Timeout should not be lower than the time passing between two consecutive - // runs of Redis cron callback, that is no more than (1/CONFIG_MIN_HZ) - if (timeout < 1000) { - return REDISMODULE_ERR; - } - onnx_max_runtime = timeout; - return REDISMODULE_OK; -} +char *Config_GetBackendsPath() { return BackendsPath; } -int RedisAI_Config_LoadBackend(RedisModuleCtx *ctx, RedisModuleString **argv, int argc) { +int Config_LoadBackend(RedisModuleCtx *ctx, RedisModuleString **argv, int argc) { if (argc < 3) return RedisModule_WrongArity(ctx); @@ -151,59 +107,66 @@ int RedisAI_Config_LoadBackend(RedisModuleCtx *ctx, RedisModuleString **argv, in return RedisModule_ReplyWithError(ctx, "ERR error loading backend"); } -void RedisAI_Config_BackendsPath(const char *path) { - if (RAI_BackendsPath != NULL) { - RedisModule_Free(RAI_BackendsPath); +void Config_SetBackendsPath(const char *path) { + if (BackendsPath != NULL) { + RedisModule_Free(BackendsPath); } - RAI_BackendsPath = RedisModule_Strdup(path); + BackendsPath = RedisModule_Strdup(path); } -int RedisAI_Config_QueueThreads(RedisModuleString *num_threads_string) { - long long temp; - int result = RedisModule_StringToLongLong(num_threads_string, &temp); - if (result != REDISMODULE_OK) { +int Config_SetQueueThreadsNum(RedisModuleString *num_threads_string) { + long long val; + int result = RedisModule_StringToLongLong(num_threads_string, &val); + if (result != REDISMODULE_OK || val <= 0) { return REDISMODULE_ERR; } - return SetNumThreadsPerQueue(temp); + ThreadPoolSizePerQueue = val; + return REDISMODULE_OK; } -int RedisAI_Config_InterOperationParallelism(RedisModuleString *num_threads_string) { - long long 
temp; - int result = RedisModule_StringToLongLong(num_threads_string, &temp); - if (result != REDISMODULE_OK) { +int Config_SetInterOperationParallelism(RedisModuleString *num_threads_string) { + long long val; + int result = RedisModule_StringToLongLong(num_threads_string, &val); + if (result != REDISMODULE_OK || val <= 0) { return REDISMODULE_ERR; } - return setBackendsInterOpParallelism(temp); + BackendsInterOpParallelism = val; + return REDISMODULE_OK; } -int RedisAI_Config_IntraOperationParallelism(RedisModuleString *num_threads_string) { - long long temp; - int result = RedisModule_StringToLongLong(num_threads_string, &temp); - if (result != REDISMODULE_OK) { +int Config_SetIntraOperationParallelism(RedisModuleString *num_threads_string) { + long long val; + int result = RedisModule_StringToLongLong(num_threads_string, &val); + if (result != REDISMODULE_OK || val <= 0) { return REDISMODULE_ERR; } - return setBackendsIntraOpParallelism(temp); + BackendsIntraOpParallelism = val; + return REDISMODULE_OK; } -int RedisAI_Config_ModelChunkSize(RedisModuleString *chunk_size_string) { - long long temp; - int result = RedisModule_StringToLongLong(chunk_size_string, &temp); - if (result != REDISMODULE_OK) { +int Config_SetModelChunkSize(RedisModuleString *chunk_size_string) { + long long val; + int result = RedisModule_StringToLongLong(chunk_size_string, &val); + if (result != REDISMODULE_OK || val <= 0) { return REDISMODULE_ERR; } - return setModelChunkSize(temp); + ModelChunkSize = val; + return REDISMODULE_OK; } -int RedisAI_Config_OnnxTimeout(RedisModuleString *onnx_timeout) { - long long temp; - int result = RedisModule_StringToLongLong(onnx_timeout, &temp); - if (result != REDISMODULE_OK) { +int Config_SetModelExecutionTimeout(RedisModuleString *timeout) { + long long val; + int result = RedisModule_StringToLongLong(timeout, &val); + // Timeout should not be lower than the time passing between two consecutive + // runs of Redis cron callback, that is no more than 
(1/CONFIG_MIN_HZ) + if (result != REDISMODULE_OK || val < 1000) { return REDISMODULE_ERR; } - return SetOnnxTimeout(temp); + ModelExecutionTimeout = val; + return REDISMODULE_OK; } -int RAI_loadTimeConfig(RedisModuleCtx *ctx, RedisModuleString *const *argv, int argc) { +int Config_SetLoadTimeParams(RedisModuleCtx *ctx, RedisModuleString *const *argv, int argc) { if (argc > 0 && argc % 2 != 0) { RedisModule_Log(ctx, "warning", "Even number of arguments provided to module. Please " @@ -216,14 +179,14 @@ int RAI_loadTimeConfig(RedisModuleCtx *ctx, RedisModuleString *const *argv, int const char *key = RedisModule_StringPtrLen(argv[2 * i], NULL); const char *val = RedisModule_StringPtrLen(argv[2 * i + 1], NULL); if (strcasecmp(key, "BACKENDSPATH") == 0) { - RedisAI_Config_BackendsPath(val); + Config_SetBackendsPath(val); } } for (int i = 0; i < argc / 2; i++) { const char *key = RedisModule_StringPtrLen(argv[2 * i], NULL); const char *val = RedisModule_StringPtrLen(argv[2 * i + 1], NULL); - int ret = _RAIConfig_LoadTimeParamParse(ctx, key, val, argv[2 * i + 1]); + int ret = _Config_LoadTimeParamParse(ctx, key, val, argv[2 * i + 1]); if (ret == REDISMODULE_ERR) { char *buffer = RedisModule_Alloc( @@ -235,6 +198,5 @@ int RAI_loadTimeConfig(RedisModuleCtx *ctx, RedisModuleString *const *argv, int return ret; } } - return REDISMODULE_OK; } diff --git a/src/config/config.h b/src/config/config.h index d17b8c147..78cfb16ee 100644 --- a/src/config/config.h +++ b/src/config/config.h @@ -16,91 +16,50 @@ typedef enum { typedef enum { RAI_DEVICE_CPU = 0, RAI_DEVICE_GPU = 1 } RAI_Device; -//#define RAI_COPY_RUN_INPUT #define RAI_COPY_RUN_OUTPUT #define RAI_PRINT_BACKEND_ERRORS -#define REDISAI_DEFAULT_THREADS_PER_QUEUE 1 -#define REDISAI_DEFAULT_INTRA_OP_PARALLELISM 0 -#define REDISAI_DEFAULT_INTER_OP_PARALLELISM 0 -#define REDISAI_DEFAULT_MODEL_CHUNK_SIZE 535822336 // (511 * 1024 * 1024) -#define ONNX_DEFAULT_MAX_RUNTIME 5000 -#define REDISAI_ERRORMSG_PROCESSING_ARG "ERR error 
processing argument" -#define REDISAI_ERRORMSG_THREADS_PER_QUEUE "ERR error setting THREADS_PER_QUEUE to" -#define REDISAI_ERRORMSG_INTRA_OP_PARALLELISM "ERR error setting INTRA_OP_PARALLELISM to" -#define REDISAI_ERRORMSG_INTER_OP_PARALLELISM "ERR error setting INTER_OP_PARALLELISM to" - -#define REDISAI_INFOMSG_THREADS_PER_QUEUE "Setting THREADS_PER_QUEUE parameter to" -#define REDISAI_INFOMSG_INTRA_OP_PARALLELISM "Setting INTRA_OP_PARALLELISM parameter to" -#define REDISAI_INFOMSG_INTER_OP_PARALLELISM "Setting INTER_OP_PARALLELISM parameter to" -#define REDISAI_INFOMSG_MODEL_CHUNK_SIZE "Setting MODEL_CHUNK_SIZE parameter to" -#define REDISAI_INFOMSG_ONNX_TIMEOUT "Setting ONNX_TIMEOUT parameter to" -/** - * Get number of threads used for parallelism between independent operations, by - * backend. - */ -long long getBackendsInterOpParallelism(); +#define REDISAI_ERRORMSG_PROCESSING_ARG "ERR error processing argument" -/** - * Set number of threads used for parallelism between independent operations, by - * backend. - * @param num_threads - * @return 0 on success, or 1 if failed - */ -int setBackendsInterOpParallelism(long long num_threads); +#define REDISAI_INFOMSG_THREADS_PER_QUEUE "Setting THREADS_PER_QUEUE parameter to" +#define REDISAI_INFOMSG_INTRA_OP_PARALLELISM "Setting INTRA_OP_PARALLELISM parameter to" +#define REDISAI_INFOMSG_INTER_OP_PARALLELISM "Setting INTER_OP_PARALLELISM parameter to" +#define REDISAI_INFOMSG_MODEL_CHUNK_SIZE "Setting MODEL_CHUNK_SIZE parameter to" +#define REDISAI_INFOMSG_MODEL_EXECUTION_TIMEOUT "Setting MODEL_EXECUTION_TIMEOUT parameter to" /** - * Get number of threads used within an individual op for parallelism, by + * Get number of threads used for parallelism between independent operations, by * backend. 
*/ -long long getBackendsIntraOpParallelism(); +long long Config_GetBackendsInterOpParallelism(void); /** - * Set number of threads used within an individual op for parallelism, by + * Get number of threads used within an individual op for parallelism, by * backend. - * @param num_threads - * @return 0 on success, or 1 if failed */ -int setBackendsIntraOpParallelism(long long num_threads); +long long Config_GetBackendsIntraOpParallelism(void); /** * @return size of chunks (in bytes) in which models are split for * set, get, serialization and replication. */ -long long getModelChunkSize(); - -/** - * Set size of chunks (in bytes) in which models are split for set, - * get, serialization and replication. - * @param size - * @return 0 on success, or 1 if failed - */ -int setModelChunkSize(long long size); +long long Config_GetModelChunkSize(void); /** * @brief Return the number of working threads per device in RedisAI. */ -long long GetNumThreadsPerQueue(void); +long long Config_GetNumThreadsPerQueue(void); /** - * Set the number of working threads per device in RedisAI. - * @param num_threads - * @return 0 on success, or 1 if failed + * @return Number of milliseconds that a model session is allowed to run + * before killing it. Currently supported only for onnxruntime backend. */ -int SetNumThreadsPerQueue(long long num_threads); +long long Config_GetModelExecutionTimeout(void); /** - * @return Number of milliseconds that onnxruntime session is allowed to run - * before killing it + * @return Returns the backends path string. 
*/ -long long GetOnnxTimeout(void); - -/** - * Set the maximal number of milliseconds that onnxruntime session is allowed to run - * @param timeout in ms - * @return 0 on success, or 1 if failed - */ -int SetOnnxTimeout(long long timeout); +char *Config_GetBackendsPath(void); /** * Helper method for AI.CONFIG LOADBACKEND @@ -111,16 +70,14 @@ int SetOnnxTimeout(long long timeout); * @param argc Redis command number of arguments * @return REDISMODULE_OK on success, or REDISMODULE_ERR otherwise. */ -int RedisAI_Config_LoadBackend(RedisModuleCtx *ctx, RedisModuleString **argv, int argc); +int Config_LoadBackend(RedisModuleCtx *ctx, RedisModuleString **argv, int argc); /** * Helper method for AI.CONFIG BACKENDSPATH * - * - * @param ctx Context in which Redis modules operate * @param path string containing backend path */ -void RedisAI_Config_BackendsPath(const char *path); +void Config_SetBackendsPath(const char *path); /** * Set number of threads used for parallelism between RedisAI independent @@ -128,7 +85,7 @@ void RedisAI_Config_BackendsPath(const char *path); * @param num_threads_string string containing thread number * @return REDISMODULE_OK on success, or REDISMODULE_ERR if failed */ -int RedisAI_Config_QueueThreads(RedisModuleString *num_threads_string); +int Config_SetQueueThreadsNum(RedisModuleString *num_threads_string); /** * Set number of threads used for parallelism between independent operations, by @@ -136,7 +93,7 @@ int RedisAI_Config_QueueThreads(RedisModuleString *num_threads_string); * @param num_threads_string string containing thread number * @return REDISMODULE_OK on success, or REDISMODULE_ERR if failed */ -int RedisAI_Config_InterOperationParallelism(RedisModuleString *num_threads_string); +int Config_SetInterOperationParallelism(RedisModuleString *num_threads_string); /** * Set number of threads used within an individual op for parallelism, by @@ -144,7 +101,7 @@ int RedisAI_Config_InterOperationParallelism(RedisModuleString *num_threads_stri 
* @param num_threads_string string containing thread number * @return REDISMODULE_OK on success, or REDISMODULE_ERR if failed */ -int RedisAI_Config_IntraOperationParallelism(RedisModuleString *num_threads_string); +int Config_SetIntraOperationParallelism(RedisModuleString *num_threads_string); /** * Set size of chunks in which model payloads are split for set, @@ -152,14 +109,14 @@ int RedisAI_Config_IntraOperationParallelism(RedisModuleString *num_threads_stri * @param chunk_size_string string containing chunk size (in bytes) * @return REDISMODULE_OK on success, or REDISMODULE_ERR if failed */ -int RedisAI_Config_ModelChunkSize(RedisModuleString *chunk_size_string); +int Config_SetModelChunkSize(RedisModuleString *chunk_size_string); /** * Set the maximum time in ms that onnx backend allow running a model. * @param onnx_max_runtime - string containing the max runtime (in ms) * @return REDISMODULE_OK on success, or REDISMODULE_ERR if failed */ -int RedisAI_Config_OnnxTimeout(RedisModuleString *onnx_timeout); +int Config_SetModelExecutionTimeout(RedisModuleString *timeout); /** * Load time configuration parser @@ -168,4 +125,4 @@ int RedisAI_Config_OnnxTimeout(RedisModuleString *onnx_timeout); * @param argc Redis command number of arguments * @return REDISMODULE_OK on success, or REDISMODULE_ERR if failed */ -int RAI_loadTimeConfig(RedisModuleCtx *ctx, RedisModuleString *const *argv, int argc); +int Config_SetLoadTimeParams(RedisModuleCtx *ctx, RedisModuleString *const *argv, int argc); diff --git a/src/execution/DAG/dag_execute.c b/src/execution/DAG/dag_execute.c index 1b3e88d49..9a6dd2454 100644 --- a/src/execution/DAG/dag_execute.c +++ b/src/execution/DAG/dag_execute.c @@ -1,4 +1,5 @@ #include +#include #include "dag_execute.h" #include "util/string_utils.h" #include "execution/run_info.h" @@ -106,7 +107,7 @@ int DAG_InsertDAGToQueue(RedisAI_RunInfo *rinfo) { RunQueueInfo **run_queues_info = array_new(RunQueueInfo *, ndevices); for (long long i = 0; i < 
ndevices; i++) { const char *device_str = devices[i]; - RunQueueInfo *run_queue_info = GetRunQueueInfo(device_str); + RunQueueInfo *run_queue_info = RunQueue_GetInfo(device_str); run_queues_info = array_append(run_queues_info, run_queue_info); } for (long long i = 0; i < ndevices; i++) { diff --git a/src/execution/background_workers.c b/src/execution/background_workers.c index 6ffb3492c..c6498bd4e 100644 --- a/src/execution/background_workers.c +++ b/src/execution/background_workers.c @@ -1,28 +1,7 @@ -/** - * background_workers.c - * - * Contains the structure to manage the per-device queues, used for decoupling - * the work from the main thread to the background worker threads. For each of - * the incoming ModelRun, ScriptRun, and DagRun commands, the request is queued - * and evaded asynchronously to one the device queues. - * - */ - -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include "string_utils.h" -#include "backends/backends.h" -#include "redisai.h" +#include "sys/time.h" #include "run_info.h" -#include "background_workers.h" -#include "onnx_timeout.h" +#include "run_queue_info.h" +#include "execution/DAG/dag.h" /* Define for RedisAI thread name setter */ #ifdef __linux__ @@ -42,90 +21,17 @@ int pthread_setname_np(const char *name); #endif #endif -void *RedisAI_Run_ThreadMain(void *arg); - -RunQueueInfo *CreateRunQueue(const char *device_str) { - - size_t device_str_len = strlen(device_str); - char upper_device_str[device_str_len + 1]; - String_ToUpper(device_str, upper_device_str, &device_str_len); - - // Create new run queue and initialize its inner fields. 
- RunQueueInfo *run_queue_info = RedisModule_Alloc(sizeof(RunQueueInfo)); - run_queue_info->run_queue = queueCreate(); - run_queue_info->device_str = RedisModule_Strdup(upper_device_str); - pthread_cond_init(&(run_queue_info->queue_condition_var), NULL); - pthread_mutex_init(&(run_queue_info->run_queue_mutex), NULL); - run_queue_info->threads = - (pthread_t *)RedisModule_Alloc(sizeof(pthread_t) * ThreadPoolSizePerQueue); - - // Save device with its associate run queue info in the dictionary. - if (AI_dictAdd(RunQueues, upper_device_str, run_queue_info) != DICT_OK) { - RunQueueInfoFree(run_queue_info); - return NULL; - } - - // Create worker threads. - for (int i = 0; i < ThreadPoolSizePerQueue; i++) { - if (pthread_create(&(run_queue_info->threads[i]), NULL, RedisAI_Run_ThreadMain, - run_queue_info) != 0) { - AI_dictDelete(RunQueues, upper_device_str); - RunQueueInfoFree(run_queue_info); - return NULL; - } - } - - // Add the new device worker threads to onnx run sessions monitoring. - if (RAI_backends.onnx.add_new_device) { - RAI_backends.onnx.add_new_device(device_str); - } - return run_queue_info; -} - -RunQueueInfo *GetRunQueueInfo(const char *device_str) { - size_t device_str_len = strlen(device_str); - char upper_device_str[device_str_len + 1]; - String_ToUpper(device_str, upper_device_str, &device_str_len); - AI_dictEntry *entry = AI_dictFind(RunQueues, upper_device_str); - RedisModule_Assert(entry != NULL); - return AI_dictGetVal(entry); -} - -bool IsRunQueueExists(const char *device_str) { - size_t device_str_len = strlen(device_str); - char upper_device_str[device_str_len + 1]; - String_ToUpper(device_str, upper_device_str, &device_str_len); - if (AI_dictFind(RunQueues, upper_device_str) == NULL) { - return false; - } - return true; -} - -uintptr_t GetThreadId() { return *(uintptr_t *)pthread_getspecific(ThreadIdKey); } - -void RunQueueInfoFree(RunQueueInfo *run_queue_info) { - RedisModule_Assert(queueLength(run_queue_info->run_queue) == 0); - 
RedisModule_Free(run_queue_info->run_queue); - RedisModule_Free(run_queue_info->device_str); - - // Wait for workers to exit and free the pool. - for (int i = 0; i < ThreadPoolSizePerQueue; i++) { - RedisModule_Assert(pthread_join(run_queue_info->threads[i], NULL) == 0); - RedisModule_Free(run_queue_info->threads); - } - pthread_mutex_destroy(&(run_queue_info->run_queue_mutex)); - pthread_cond_destroy(&(run_queue_info->queue_condition_var)); - RedisModule_Free(run_queue_info); -} +uintptr_t BGWorkersCounter; // Total number of BG threads running currently. +pthread_key_t ThreadIdKey; // Key to hold thread id in its local storage. /** * @brief Save the id for some working thread in thread local storage. */ -static void _SaveThreadId() { - uintptr_t *id_value = RedisModule_Alloc(sizeof(uintptr_t)); +static void _BGWorker_SaveThreadId() { // Let the current thread have the next available id, and increase the counter. - *id_value = __atomic_fetch_add(&BGWorkersCounter, 1, __ATOMIC_RELAXED); - pthread_setspecific(ThreadIdKey, id_value); + uintptr_t id_value = __atomic_fetch_add(&BGWorkersCounter, 1, __ATOMIC_RELAXED); + // Convert the id value to a pointer and store it the thread local storage. 
+ pthread_setspecific(ThreadIdKey, (const void *)id_value); } /** @@ -317,10 +223,17 @@ static RedisAI_RunInfo **_BGThread_BatchOperations(RunQueueInfo *run_queue_info, return batch_rinfo; } -void *RedisAI_Run_ThreadMain(void *arg) { - _SaveThreadId(); - RunQueueInfo *run_queue_info = (RunQueueInfo *)arg; +uintptr_t BGWorker_GetThreadId() { + return (uintptr_t)pthread_getspecific(ThreadIdKey); +} +uintptr_t BGWorker_GetThreadsCount() { + return BGWorkersCounter; +} + +void *BGWorker_ThreadMain(void *arg) { + _BGWorker_SaveThreadId(); + RunQueueInfo *run_queue_info = (RunQueueInfo *)arg; RedisAI_RunInfo **batch_rinfo = array_new(RedisAI_RunInfo *, 1); pthread_mutex_lock(&run_queue_info->run_queue_mutex); diff --git a/src/execution/background_workers.h b/src/execution/background_workers.h index 71f979db4..0e8948c2a 100644 --- a/src/execution/background_workers.h +++ b/src/execution/background_workers.h @@ -29,42 +29,20 @@ #include "util/arr.h" #include "util/queue.h" -AI_dict *RunQueues; -long long ThreadPoolSizePerQueue; // Number of working threads for device. -uintptr_t BGWorkersCounter; // Total number of BG threads running currently. -pthread_key_t ThreadIdKey; // Holds the thread id in its local storage. - -typedef struct RunQueueInfo { - pthread_mutex_t run_queue_mutex; - pthread_cond_t queue_condition_var; - queue *run_queue; - pthread_t *threads; - char *device_str; -} RunQueueInfo; - -/** - * @brief Terminate all working threads and free the run queue with its inner fields. - */ -void RunQueueInfoFree(RunQueueInfo *info); - -/** - * @brief Create a new run queue for a device. - */ -RunQueueInfo *CreateRunQueue(const char *device_str); /** - * @brief Return true if a ru queue exists for this particular device. 
+ * @brief RedisAI main loop for every background working thread + * @param arg - This is the run queue info of the device on which this thread is + * running the AI model/script */ -bool IsRunQueueExists(const char *device_str); +void *BGWorker_ThreadMain(void *arg); /** - * @brief Return the RunQueueInfo saved in the global RunQueues dict for a certain - * device name, or NULL if doesn't exist. + * @brief Returns the thread id (among RedisAI working threads). */ -RunQueueInfo *GetRunQueueInfo(const char *device_str); +uintptr_t BGWorker_GetThreadId(void); /** - * @brief Return the thread id from its local storage by accessing the value - * saved under ThreadIdKey. + * @brief Returns the total number of RedisAI working threads (for all devices). */ -uintptr_t GetThreadId(void); +uintptr_t BGWorker_GetThreadsCount(void); \ No newline at end of file diff --git a/src/execution/parsing/deprecated.c b/src/execution/parsing/deprecated.c index e0b9ca302..e1d77bbb8 100644 --- a/src/execution/parsing/deprecated.c +++ b/src/execution/parsing/deprecated.c @@ -1,4 +1,5 @@ +#include #include "deprecated.h" #include "rmutil/args.h" #include "backends/backends.h" @@ -236,8 +237,8 @@ int ModelSetCommand(RedisModuleCtx *ctx, RedisModuleString **argv, int argc) { .batchsize = batchsize, .minbatchsize = minbatchsize, .minbatchtimeout = minbatchtimeout, - .backends_intra_op_parallelism = getBackendsIntraOpParallelism(), - .backends_inter_op_parallelism = getBackendsInterOpParallelism(), + .backends_intra_op_parallelism = Config_GetBackendsIntraOpParallelism(), + .backends_inter_op_parallelism = Config_GetBackendsInterOpParallelism(), }; RAI_Model *model = NULL; @@ -305,8 +306,8 @@ int ModelSetCommand(RedisModuleCtx *ctx, RedisModuleString **argv, int argc) { } // TODO: if backend loaded, make sure there's a queue - if (!IsRunQueueExists(devicestr)) { - RunQueueInfo *run_queue_info = CreateRunQueue(devicestr); + if (!RunQueue_IsExists(devicestr)) { + RunQueueInfo *run_queue_info = 
RunQueue_Create(devicestr); if (run_queue_info == NULL) { RAI_ModelFree(model, &err); RedisModule_ReplyWithError(ctx, "ERR Could not initialize queue on requested device"); diff --git a/src/execution/run_queue_info.c b/src/execution/run_queue_info.c new file mode 100644 index 000000000..57ce5d977 --- /dev/null +++ b/src/execution/run_queue_info.c @@ -0,0 +1,80 @@ +#include "string_utils.h" +#include "run_queue_info.h" +#include "backends/backends.h" +#include "background_workers.h" + +RunQueueInfo *RunQueue_Create(const char *device_str) { + + size_t device_str_len = strlen(device_str); + char upper_device_str[device_str_len + 1]; + String_ToUpper(device_str, upper_device_str, &device_str_len); + + // Create new run queue and initialize its inner fields. + RunQueueInfo *run_queue_info = RedisModule_Alloc(sizeof(RunQueueInfo)); + run_queue_info->run_queue = queueCreate(); + run_queue_info->device_str = RedisModule_Strdup(upper_device_str); + pthread_cond_init(&(run_queue_info->queue_condition_var), NULL); + pthread_mutex_init(&(run_queue_info->run_queue_mutex), NULL); + run_queue_info->threads = + (pthread_t *)RedisModule_Alloc(sizeof(pthread_t) * Config_GetNumThreadsPerQueue()); + + // Save device with its associate run queue info in the dictionary. + if (AI_dictAdd(RunQueues, upper_device_str, run_queue_info) != DICT_OK) { + RunQueue_Free(run_queue_info); + return NULL; + } + + // Create worker threads. + for (int i = 0; i < Config_GetNumThreadsPerQueue(); i++) { + if (pthread_create(&(run_queue_info->threads[i]), NULL, BGWorker_ThreadMain, + run_queue_info) != 0) { + AI_dictDelete(RunQueues, upper_device_str); + RunQueue_Free(run_queue_info); + return NULL; + } + } + + // Add the new device worker threads to onnx run sessions tracking. 
+ if (RAI_backends.onnx.add_new_device) { + RAI_backends.onnx.add_new_device(device_str); + } + return run_queue_info; +} + +RunQueueInfo *RunQueue_GetInfo(const char *device_str) { + size_t device_str_len = strlen(device_str); + char upper_device_str[device_str_len + 1]; + String_ToUpper(device_str, upper_device_str, &device_str_len); + AI_dictEntry *entry = AI_dictFind(RunQueues, upper_device_str); + RedisModule_Assert(entry != NULL); + return AI_dictGetVal(entry); +} + +bool RunQueue_IsExists(const char *device_str) { + size_t device_str_len = strlen(device_str); + char upper_device_str[device_str_len + 1]; + String_ToUpper(device_str, upper_device_str, &device_str_len); + if (AI_dictFind(RunQueues, upper_device_str) == NULL) { + return false; + } + return true; +} + +void RunQueue_Free(RunQueueInfo *run_queue_info) { + RedisModule_Assert(queueLength(run_queue_info->run_queue) == 0); + RedisModule_Free(run_queue_info->run_queue); + RedisModule_Free(run_queue_info->device_str); + + // Wait for workers to exit and free the pool. + for (int i = 0; i < Config_GetNumThreadsPerQueue(); i++) { + RedisModule_Assert(pthread_join(run_queue_info->threads[i], NULL) == 0); + RedisModule_Free(run_queue_info->threads); + } + pthread_mutex_destroy(&(run_queue_info->run_queue_mutex)); + pthread_cond_destroy(&(run_queue_info->queue_condition_var)); + RedisModule_Free(run_queue_info); +} + +AI_dict *RunQueue_GetRAIRunQueuesDict() { + return RunQueues; +} \ No newline at end of file diff --git a/src/execution/run_queue_info.h b/src/execution/run_queue_info.h new file mode 100644 index 000000000..e1e819f1e --- /dev/null +++ b/src/execution/run_queue_info.h @@ -0,0 +1,42 @@ +# pragma once + +/** + * Contains the structure to manage the per-device queues, used for decoupling + * the work from the main thread to the background worker threads. 
For each of + * the incoming ModelRun, ScriptRun, and DagRun commands, the request is queued + * and evaded asynchronously to one the device queues. + */ + +#include "utils.h" +#include "queue.h" + +AI_dict *RunQueues; + +typedef struct RunQueueInfo { + pthread_mutex_t run_queue_mutex; + pthread_cond_t queue_condition_var; + queue *run_queue; + pthread_t *threads; + char *device_str; +} RunQueueInfo; + +/** + * @brief Create a new run queue for a device. + */ +RunQueueInfo *RunQueue_Create(const char *device_str); + +/** + * @brief Return true if a ru queue exists for this particular device. + */ +bool RunQueue_IsExists(const char *device_str); + +/** + * @brief Return the RunQueueInfo saved in the global RunQueues dict for a certain + * device name, after asserting that it exists. + */ +RunQueueInfo *RunQueue_GetInfo(const char *device_str); + +/** + * @brief Terminate all working threads and free the run queue with its inner fields. + */ +void RunQueue_Free(RunQueueInfo *info); diff --git a/src/execution/utils.c b/src/execution/utils.c index 723ca661f..bdaeb5888 100644 --- a/src/execution/utils.c +++ b/src/execution/utils.c @@ -1,7 +1,8 @@ #include "utils.h" -#include "redis_ai_objects/tensor.h" +#include "background_workers.h" #include "redis_ai_objects/model.h" + int redisMajorVersion; int redisMinorVersion; int redisPatchVersion; diff --git a/src/execution/utils.h b/src/execution/utils.h index f4253a54c..fbe8ded3f 100644 --- a/src/execution/utils.h +++ b/src/execution/utils.h @@ -4,6 +4,7 @@ #include "redis_ai_objects/model_struct.h" #include "redis_ai_objects/err.h" #include +#include /** Use this to check if a command is given a key whose hash slot is not on the current * shard, when using enterprise cluster. 
diff --git a/src/redisai.c b/src/redisai.c index 83a63de40..684758724 100644 --- a/src/redisai.c +++ b/src/redisai.c @@ -26,7 +26,8 @@ #include #include #include -#include +#include +#include #include "rmutil/alloc.h" #include "rmutil/args.h" @@ -63,6 +64,8 @@ extern int rlecMinorVersion; extern int rlecPatchVersion; extern int rlecBuild; +extern pthread_key_t ThreadIdKey; + /* ----------------------- RedisAI Module Commands ------------------------- */ /** @@ -234,8 +237,8 @@ int RedisAI_ModelStore_RedisCommand(RedisModuleCtx *ctx, RedisModuleString **arg .batchsize = batchsize, .minbatchsize = minbatchsize, .minbatchtimeout = minbatchtimeout, - .backends_intra_op_parallelism = getBackendsIntraOpParallelism(), - .backends_inter_op_parallelism = getBackendsInterOpParallelism(), + .backends_intra_op_parallelism = Config_GetBackendsIntraOpParallelism(), + .backends_inter_op_parallelism = Config_GetBackendsInterOpParallelism(), }; if (AC_IsAtEnd(&ac)) { @@ -353,8 +356,8 @@ int RedisAI_ModelStore_RedisCommand(RedisModuleCtx *ctx, RedisModuleString **arg } // TODO: if backend loaded, make sure there's a queue - if (!IsRunQueueExists(devicestr)) { - RunQueueInfo *run_queue_info = CreateRunQueue(devicestr); + if (!RunQueue_IsExists(devicestr)) { + RunQueueInfo *run_queue_info = RunQueue_Create(devicestr); if (run_queue_info == NULL) { RAI_ModelFree(model, &err); RedisModule_ReplyWithError(ctx, "ERR Could not initialize queue on requested device"); @@ -391,7 +394,7 @@ int RedisAI_ModelStore_RedisCommand(RedisModuleCtx *ctx, RedisModuleString **arg } void RAI_ReplyWithChunks(RedisModuleCtx *ctx, const char *buffer, long long len) { - long long chunk_size = getModelChunkSize(); + long long chunk_size = Config_GetModelChunkSize(); const size_t n_chunks = len / chunk_size + 1; if (n_chunks > 1) { RedisModule_ReplyWithArray(ctx, (long)n_chunks); @@ -776,8 +779,8 @@ int RedisAI_ScriptSet_RedisCommand(RedisModuleCtx *ctx, RedisModuleString **argv return ret; } - if 
(!IsRunQueueExists(devicestr)) { - RunQueueInfo *run_queue_info = CreateRunQueue(devicestr); + if (!RunQueue_IsExists(devicestr)) { + RunQueueInfo *run_queue_info = RunQueue_Create(devicestr); if (run_queue_info == NULL) { RAI_ScriptFree(script, &err); RedisModule_ReplyWithError(ctx, "ERR Could not initialize queue on requested device"); @@ -981,11 +984,12 @@ int RedisAI_Config_RedisCommand(RedisModuleCtx *ctx, RedisModuleString **argv, i const char *subcommand = RedisModule_StringPtrLen(argv[1], NULL); if (!strcasecmp(subcommand, "LOADBACKEND")) { - return RedisAI_Config_LoadBackend(ctx, argv + 1, argc - 1); + return Config_LoadBackend(ctx, argv + 1, argc - 1); } if (!strcasecmp(subcommand, "BACKENDSPATH")) { if (argc > 2) { + Config_SetBackendsPath(RedisModule_StringPtrLen(argv[2], NULL)); return RedisModule_ReplyWithSimpleString(ctx, "OK"); } else { return RedisModule_ReplyWithError(ctx, "ERR BACKENDSPATH: missing path argument"); @@ -994,7 +998,7 @@ int RedisAI_Config_RedisCommand(RedisModuleCtx *ctx, RedisModuleString **argv, i if (!strcasecmp(subcommand, "MODEL_CHUNK_SIZE")) { if (argc > 2) { - if (RedisAI_Config_ModelChunkSize(argv[2]) == REDISMODULE_OK) { + if (Config_SetModelChunkSize(argv[2]) == REDISMODULE_OK) { return RedisModule_ReplyWithSimpleString(ctx, "OK"); } else { return RedisModule_ReplyWithError(ctx, "ERR MODEL_CHUNK_SIZE: invalid chunk size"); @@ -1211,10 +1215,10 @@ void RAI_moduleInfoFunc(RedisModuleInfoCtx *ctx, int for_crash_report) { RedisModule_InfoAddSection(ctx, "git"); RedisModule_InfoAddFieldCString(ctx, "git_sha", REDISAI_GIT_SHA); RedisModule_InfoAddSection(ctx, "load_time_configs"); - RedisModule_InfoAddFieldLongLong(ctx, "threads_per_queue", ThreadPoolSizePerQueue); - RedisModule_InfoAddFieldLongLong(ctx, "inter_op_parallelism", getBackendsInterOpParallelism()); - RedisModule_InfoAddFieldLongLong(ctx, "intra_op_parallelism", getBackendsIntraOpParallelism()); - RedisModule_InfoAddFieldLongLong(ctx, 
"timeout_for_onnxruntime_sessions", GetOnnxTimeout()); + RedisModule_InfoAddFieldLongLong(ctx, "threads_per_queue", Config_GetNumThreadsPerQueue()); + RedisModule_InfoAddFieldLongLong(ctx, "inter_op_parallelism", Config_GetBackendsInterOpParallelism()); + RedisModule_InfoAddFieldLongLong(ctx, "intra_op_parallelism", Config_GetBackendsIntraOpParallelism()); + RedisModule_InfoAddFieldLongLong(ctx, "model_execution_timeout", Config_GetModelExecutionTimeout()); RedisModule_InfoAddSection(ctx, "memory_usage"); if (RAI_backends.onnx.get_memory_info) { RedisModule_InfoAddFieldULongLong(ctx, "onnxruntime_memory", @@ -1281,7 +1285,7 @@ void RAI_moduleInfoFunc(RedisModuleInfoCtx *ctx, int for_crash_report) { char *queue_name = (char *)AI_dictGetKey(entry); RunQueueInfo *run_queue_info = (RunQueueInfo *)AI_dictGetVal(entry); if (run_queue_info) { - for (int i = 0; i < ThreadPoolSizePerQueue; i++) { + for (int i = 0; i < Config_GetNumThreadsPerQueue(); i++) { pthread_t current_bg_threads = run_queue_info->threads[i]; struct timespec ts; clockid_t cid; @@ -1464,21 +1468,13 @@ int RedisModule_OnLoad(RedisModuleCtx *ctx, RedisModuleString **argv, int argc) RedisModule_SetModuleOptions(ctx, REDISMODULE_OPTIONS_HANDLE_IO_ERRORS); - // Default configs - RAI_BackendsPath = NULL; - ThreadPoolSizePerQueue = REDISAI_DEFAULT_THREADS_PER_QUEUE; - setBackendsInterOpParallelism(REDISAI_DEFAULT_INTER_OP_PARALLELISM); - setBackendsIntraOpParallelism(REDISAI_DEFAULT_INTRA_OP_PARALLELISM); - setModelChunkSize(REDISAI_DEFAULT_MODEL_CHUNK_SIZE); - SetOnnxTimeout(ONNX_DEFAULT_MAX_RUNTIME); - - if (RAI_loadTimeConfig(ctx, argv, argc) != REDISMODULE_OK) { + if (Config_SetLoadTimeParams(ctx, argv, argc) != REDISMODULE_OK) { return REDISMODULE_ERR; } RunQueues = AI_dictCreate(&AI_dictTypeHeapStrings, NULL); - pthread_key_create(&ThreadIdKey, RedisModule_Free); - RunQueueInfo *cpu_run_queue_info = CreateRunQueue("CPU"); + pthread_key_create(&ThreadIdKey, NULL); + RunQueueInfo *cpu_run_queue_info = 
RunQueue_Create("CPU"); if (cpu_run_queue_info == NULL) { RedisModule_Log(ctx, "warning", "RedisAI could not initialize run queue for CPU"); return REDISMODULE_ERR; diff --git a/src/serialization/AOF/rai_aof_rewrite.c b/src/serialization/AOF/rai_aof_rewrite.c index a6b306dae..bc7aa9885 100644 --- a/src/serialization/AOF/rai_aof_rewrite.c +++ b/src/serialization/AOF/rai_aof_rewrite.c @@ -46,7 +46,7 @@ void RAI_AOFRewriteModel(RedisModuleIO *aof, RedisModuleString *key, void *value // [INPUTS name1 name2 ... OUTPUTS name1 name2 ...] // BLOB model_blob - long long chunk_size = getModelChunkSize(); + long long chunk_size = Config_GetModelChunkSize(); const size_t n_chunks = len / chunk_size + 1; RedisModuleString **buffers_ = array_new(RedisModuleString *, n_chunks); diff --git a/src/serialization/RDB/decoder/current/v2/decode_v2.c b/src/serialization/RDB/decoder/current/v2/decode_v2.c index 67aa043d9..714bcb9bf 100644 --- a/src/serialization/RDB/decoder/current/v2/decode_v2.c +++ b/src/serialization/RDB/decoder/current/v2/decode_v2.c @@ -111,8 +111,8 @@ void *RAI_RDBLoadModel_v2(RedisModuleIO *io) { .batchsize = batchsize, .minbatchsize = minbatchsize, .minbatchtimeout = minbatchtimeout, - .backends_intra_op_parallelism = getBackendsIntraOpParallelism(), - .backends_inter_op_parallelism = getBackendsInterOpParallelism(), + .backends_intra_op_parallelism = Config_GetBackendsIntraOpParallelism(), + .backends_inter_op_parallelism = Config_GetBackendsInterOpParallelism(), }; size_t len = RedisModule_LoadUnsigned(io); diff --git a/src/serialization/RDB/decoder/previous/v0/decode_v0.c b/src/serialization/RDB/decoder/previous/v0/decode_v0.c index 23c8a7e85..58cab96e5 100644 --- a/src/serialization/RDB/decoder/previous/v0/decode_v0.c +++ b/src/serialization/RDB/decoder/previous/v0/decode_v0.c @@ -105,8 +105,8 @@ void *RAI_RDBLoadModel_v0(RedisModuleIO *io) { RAI_ModelOpts opts = { .batchsize = batchsize, .minbatchsize = minbatchsize, - .backends_intra_op_parallelism = 
getBackendsIntraOpParallelism(), - .backends_inter_op_parallelism = getBackendsInterOpParallelism(), + .backends_intra_op_parallelism = Config_GetBackendsIntraOpParallelism(), + .backends_inter_op_parallelism = Config_GetBackendsInterOpParallelism(), }; buffer = RedisModule_LoadStringBuffer(io, &len); diff --git a/src/serialization/RDB/decoder/previous/v1/decode_v1.c b/src/serialization/RDB/decoder/previous/v1/decode_v1.c index 7b070f7c1..461d0cf3c 100644 --- a/src/serialization/RDB/decoder/previous/v1/decode_v1.c +++ b/src/serialization/RDB/decoder/previous/v1/decode_v1.c @@ -109,8 +109,8 @@ void *RAI_RDBLoadModel_v1(RedisModuleIO *io) { RAI_ModelOpts opts = { .batchsize = batchsize, .minbatchsize = minbatchsize, - .backends_intra_op_parallelism = getBackendsIntraOpParallelism(), - .backends_inter_op_parallelism = getBackendsInterOpParallelism(), + .backends_intra_op_parallelism = Config_GetBackendsIntraOpParallelism(), + .backends_inter_op_parallelism = Config_GetBackendsInterOpParallelism(), }; size_t len = RedisModule_LoadUnsigned(io); diff --git a/src/serialization/RDB/encoder/v2/encode_v2.c b/src/serialization/RDB/encoder/v2/encode_v2.c index bb0782635..f2e34ca2a 100644 --- a/src/serialization/RDB/encoder/v2/encode_v2.c +++ b/src/serialization/RDB/encoder/v2/encode_v2.c @@ -55,7 +55,7 @@ void RAI_RDBSaveModel_v2(RedisModuleIO *io, void *value) { for (size_t i = 0; i < model->noutputs; i++) { RedisModule_SaveStringBuffer(io, model->outputs[i], strlen(model->outputs[i]) + 1); } - long long chunk_size = getModelChunkSize(); + long long chunk_size = Config_GetModelChunkSize(); const size_t n_chunks = len / chunk_size + 1; RedisModule_SaveUnsigned(io, len); RedisModule_SaveUnsigned(io, n_chunks); diff --git a/src/util/queue.c b/src/util/queue.c index b68c1a15c..e2dffdad9 100644 --- a/src/util/queue.c +++ b/src/util/queue.c @@ -90,7 +90,7 @@ queueItem *queueEvict(queue *queue, queueItem *item) { return item; } -long queueLength(queue *queue) { return queue->len; } 
+unsigned long queueLength(queue *queue) { return queue->len; } void queueRelease(queue *queue) { unsigned long len; diff --git a/src/util/queue.h b/src/util/queue.h index a03845182..96aba4154 100644 --- a/src/util/queue.h +++ b/src/util/queue.h @@ -28,5 +28,5 @@ queueItem *queuePop(queue *queue); queueItem *queueFront(queue *queue); queueItem *queueNext(queueItem *item); queueItem *queueEvict(queue *queue, queueItem *item); -long queueLength(queue *queue); +unsigned long queueLength(queue *queue); void queueRelease(queue *queue); diff --git a/tests/flow/tests_onnx.py b/tests/flow/tests_onnx.py index c1d632ed9..d6a21e9d2 100644 --- a/tests/flow/tests_onnx.py +++ b/tests/flow/tests_onnx.py @@ -488,7 +488,7 @@ def test_onnx_kill_switch_basic(env): def test_onnx_kill_switch_multiple_working_threads(): - env = Env(moduleArgs='THREADS_PER_QUEUE 8 ONNX_TIMEOUT 1000') + env = Env(moduleArgs='THREADS_PER_QUEUE 8 MODEL_EXECUTION_TIMEOUT 1000') con = env.getConnection() model_with_inf_loop = load_file_content("model_with_infinite_loop.onnx") ret = con.execute_command('AI.MODELSTORE', 'inf_loop_model{1}', 'ONNX', DEVICE, 'BLOB', model_with_inf_loop) From 697faf91b26936a5ed61f40aab6982f1ff7c141a Mon Sep 17 00:00:00 2001 From: alonre24 Date: Thu, 10 Jun 2021 12:45:19 +0300 Subject: [PATCH 18/27] linter... 
--- src/backends/backends.c | 5 ++-- src/backends/onnx_timeout.c | 1 - src/config/config.c | 38 +++++++++++++----------------- src/config/config.h | 2 +- src/execution/background_workers.c | 12 ++++------ src/execution/background_workers.h | 1 - src/execution/run_queue_info.c | 8 +++---- src/execution/run_queue_info.h | 2 +- src/execution/utils.c | 1 - src/redisai.c | 9 ++++--- 10 files changed, 34 insertions(+), 45 deletions(-) diff --git a/src/backends/backends.c b/src/backends/backends.c index cf794c619..f01e657fa 100644 --- a/src/backends/backends.c +++ b/src/backends/backends.c @@ -18,8 +18,6 @@ #include "config/config.h" #include "execution/background_workers.h" - - static bool _ValidateAPICreated(RedisModuleCtx *ctx, void *func_ptr, const char *func_name) { if (func_ptr == NULL) { RedisModule_Log(ctx, "warning", "Backend does not export %s", func_name); @@ -61,7 +59,8 @@ RedisModuleString *RAI_GetBackendsPath(RedisModuleCtx *ctx) { Dl_info info; RedisModuleString *backends_path = NULL; if (Config_GetBackendsPath() != NULL) { - backends_path = RedisModule_CreateString(ctx, Config_GetBackendsPath(), strlen(Config_GetBackendsPath())); + backends_path = RedisModule_CreateString(ctx, Config_GetBackendsPath(), + strlen(Config_GetBackendsPath())); } else { RedisModuleString *module_path = RAI_GetModulePath(ctx); backends_path = RedisModule_CreateStringPrintf(ctx, "%s/backends", diff --git a/src/backends/onnx_timeout.c b/src/backends/onnx_timeout.c index 7ce577c8a..6f3676310 100644 --- a/src/backends/onnx_timeout.c +++ b/src/backends/onnx_timeout.c @@ -7,7 +7,6 @@ #include "util/string_utils.h" #include "redis_ai_objects/stats.h" - int CreateGlobalOnnxRunSessions() { onnx_global_run_sessions = RedisModule_Alloc(sizeof(struct OnnxGlobalRunSessions)); diff --git a/src/config/config.c b/src/config/config.c index 7a205467e..8bcf0f2dc 100644 --- a/src/config/config.c +++ b/src/config/config.c @@ -4,22 +4,21 @@ #include "backends/backends.h" // Default configs -char 
*BackendsPath = NULL; // Path to backends dir. +char *BackendsPath = NULL; // Path to backends dir. -long long BackendsIntraOpParallelism = 0; // number of threads used within an - // individual op for parallelism. -long long BackendsInterOpParallelism = 0; // number of threads used for parallelism - // between independent operations. -long long ModelChunkSize = 535822336; // size of chunks used to break up model payloads. - // default is 511 * 1024 * 1024 -long long ThreadPoolSizePerQueue = 1; // Number of working threads for device. - -long long ModelExecutionTimeout = 5000; // The maximum time in milliseconds - // before killing onnx run session. +long long BackendsIntraOpParallelism = 0; // number of threads used within an + // individual op for parallelism. +long long BackendsInterOpParallelism = 0; // number of threads used for parallelism + // between independent operations. +long long ModelChunkSize = 535822336; // size of chunks used to break up model payloads. + // default is 511 * 1024 * 1024 +long long ThreadPoolSizePerQueue = 1; // Number of working threads for device. +long long ModelExecutionTimeout = 5000; // The maximum time in milliseconds + // before killing onnx run session. 
static int _Config_LoadTimeParamParse(RedisModuleCtx *ctx, const char *key, const char *val, - RedisModuleString *rsval) { + RedisModuleString *rsval) { int ret = REDISMODULE_OK; long long param_val; if (strcasecmp((key), "TF") == 0) { @@ -36,31 +35,28 @@ static int _Config_LoadTimeParamParse(RedisModuleCtx *ctx, const char *key, cons else if (strcasecmp((key), "THREADS_PER_QUEUE") == 0) { ret = Config_SetQueueThreadsNum(rsval); if (ret == REDISMODULE_OK) { - RedisModule_Log(ctx, "notice", "%s: %lld", REDISAI_INFOMSG_THREADS_PER_QUEUE, (param_val)); + RedisModule_Log(ctx, "notice", "%s: %lld", REDISAI_INFOMSG_THREADS_PER_QUEUE, + (param_val)); } } else if (strcasecmp((key), "INTRA_OP_PARALLELISM") == 0) { ret = Config_SetIntraOperationParallelism(rsval); if (ret == REDISMODULE_OK) { - RedisModule_Log(ctx, "notice", "%s: %s", REDISAI_INFOMSG_INTRA_OP_PARALLELISM, - val); + RedisModule_Log(ctx, "notice", "%s: %s", REDISAI_INFOMSG_INTRA_OP_PARALLELISM, val); } } else if (strcasecmp((key), "INTER_OP_PARALLELISM") == 0) { ret = Config_SetInterOperationParallelism(rsval); if (ret == REDISMODULE_OK) { - RedisModule_Log(ctx, "notice", "%s: %s", REDISAI_INFOMSG_INTER_OP_PARALLELISM, - val); + RedisModule_Log(ctx, "notice", "%s: %s", REDISAI_INFOMSG_INTER_OP_PARALLELISM, val); } } else if (strcasecmp((key), "MODEL_CHUNK_SIZE") == 0) { ret = Config_SetModelChunkSize(rsval); if (ret == REDISMODULE_OK) { - RedisModule_Log(ctx, "notice", "%s: %s", REDISAI_INFOMSG_MODEL_CHUNK_SIZE, - val); + RedisModule_Log(ctx, "notice", "%s: %s", REDISAI_INFOMSG_MODEL_CHUNK_SIZE, val); } } else if (strcasecmp((key), "MODEL_EXECUTION_TIMEOUT") == 0) { ret = Config_SetModelExecutionTimeout(rsval); if (ret == REDISMODULE_OK) { - RedisModule_Log(ctx, "notice", "%s: %s", REDISAI_INFOMSG_MODEL_EXECUTION_TIMEOUT, - val); + RedisModule_Log(ctx, "notice", "%s: %s", REDISAI_INFOMSG_MODEL_EXECUTION_TIMEOUT, val); } } else if (strcasecmp((key), "BACKENDSPATH") == 0) { // already taken care of diff --git 
a/src/config/config.h b/src/config/config.h index 78cfb16ee..724cc509a 100644 --- a/src/config/config.h +++ b/src/config/config.h @@ -19,7 +19,7 @@ typedef enum { RAI_DEVICE_CPU = 0, RAI_DEVICE_GPU = 1 } RAI_Device; #define RAI_COPY_RUN_OUTPUT #define RAI_PRINT_BACKEND_ERRORS -#define REDISAI_ERRORMSG_PROCESSING_ARG "ERR error processing argument" +#define REDISAI_ERRORMSG_PROCESSING_ARG "ERR error processing argument" #define REDISAI_INFOMSG_THREADS_PER_QUEUE "Setting THREADS_PER_QUEUE parameter to" #define REDISAI_INFOMSG_INTRA_OP_PARALLELISM "Setting INTRA_OP_PARALLELISM parameter to" diff --git a/src/execution/background_workers.c b/src/execution/background_workers.c index c6498bd4e..902eb689e 100644 --- a/src/execution/background_workers.c +++ b/src/execution/background_workers.c @@ -21,8 +21,8 @@ int pthread_setname_np(const char *name); #endif #endif -uintptr_t BGWorkersCounter; // Total number of BG threads running currently. -pthread_key_t ThreadIdKey; // Key to hold thread id in its local storage. +uintptr_t BGWorkersCounter; // Total number of BG threads running currently. +pthread_key_t ThreadIdKey; // Key to hold thread id in its local storage. /** * @brief Save the id for some working thread in thread local storage. 
@@ -223,13 +223,9 @@ static RedisAI_RunInfo **_BGThread_BatchOperations(RunQueueInfo *run_queue_info, return batch_rinfo; } -uintptr_t BGWorker_GetThreadId() { - return (uintptr_t)pthread_getspecific(ThreadIdKey); -} +uintptr_t BGWorker_GetThreadId() { return (uintptr_t)pthread_getspecific(ThreadIdKey); } -uintptr_t BGWorker_GetThreadsCount() { - return BGWorkersCounter; -} +uintptr_t BGWorker_GetThreadsCount() { return BGWorkersCounter; } void *BGWorker_ThreadMain(void *arg) { _BGWorker_SaveThreadId(); diff --git a/src/execution/background_workers.h b/src/execution/background_workers.h index 0e8948c2a..f3dbf4ce6 100644 --- a/src/execution/background_workers.h +++ b/src/execution/background_workers.h @@ -29,7 +29,6 @@ #include "util/arr.h" #include "util/queue.h" - /** * @brief RedisAI main loop for every background working thread * @param arg - This is the run queue info of the device on which this thread is diff --git a/src/execution/run_queue_info.c b/src/execution/run_queue_info.c index 57ce5d977..ff3a39bd5 100644 --- a/src/execution/run_queue_info.c +++ b/src/execution/run_queue_info.c @@ -16,7 +16,7 @@ RunQueueInfo *RunQueue_Create(const char *device_str) { pthread_cond_init(&(run_queue_info->queue_condition_var), NULL); pthread_mutex_init(&(run_queue_info->run_queue_mutex), NULL); run_queue_info->threads = - (pthread_t *)RedisModule_Alloc(sizeof(pthread_t) * Config_GetNumThreadsPerQueue()); + (pthread_t *)RedisModule_Alloc(sizeof(pthread_t) * Config_GetNumThreadsPerQueue()); // Save device with its associate run queue info in the dictionary. if (AI_dictAdd(RunQueues, upper_device_str, run_queue_info) != DICT_OK) { @@ -27,7 +27,7 @@ RunQueueInfo *RunQueue_Create(const char *device_str) { // Create worker threads. 
for (int i = 0; i < Config_GetNumThreadsPerQueue(); i++) { if (pthread_create(&(run_queue_info->threads[i]), NULL, BGWorker_ThreadMain, - run_queue_info) != 0) { + run_queue_info) != 0) { AI_dictDelete(RunQueues, upper_device_str); RunQueue_Free(run_queue_info); return NULL; @@ -75,6 +75,4 @@ void RunQueue_Free(RunQueueInfo *run_queue_info) { RedisModule_Free(run_queue_info); } -AI_dict *RunQueue_GetRAIRunQueuesDict() { - return RunQueues; -} \ No newline at end of file +AI_dict *RunQueue_GetRAIRunQueuesDict() { return RunQueues; } \ No newline at end of file diff --git a/src/execution/run_queue_info.h b/src/execution/run_queue_info.h index e1e819f1e..abe1331cb 100644 --- a/src/execution/run_queue_info.h +++ b/src/execution/run_queue_info.h @@ -1,4 +1,4 @@ -# pragma once +#pragma once /** * Contains the structure to manage the per-device queues, used for decoupling diff --git a/src/execution/utils.c b/src/execution/utils.c index bdaeb5888..ea0fd501d 100644 --- a/src/execution/utils.c +++ b/src/execution/utils.c @@ -2,7 +2,6 @@ #include "background_workers.h" #include "redis_ai_objects/model.h" - int redisMajorVersion; int redisMinorVersion; int redisPatchVersion; diff --git a/src/redisai.c b/src/redisai.c index 684758724..c89583344 100644 --- a/src/redisai.c +++ b/src/redisai.c @@ -1216,9 +1216,12 @@ void RAI_moduleInfoFunc(RedisModuleInfoCtx *ctx, int for_crash_report) { RedisModule_InfoAddFieldCString(ctx, "git_sha", REDISAI_GIT_SHA); RedisModule_InfoAddSection(ctx, "load_time_configs"); RedisModule_InfoAddFieldLongLong(ctx, "threads_per_queue", Config_GetNumThreadsPerQueue()); - RedisModule_InfoAddFieldLongLong(ctx, "inter_op_parallelism", Config_GetBackendsInterOpParallelism()); - RedisModule_InfoAddFieldLongLong(ctx, "intra_op_parallelism", Config_GetBackendsIntraOpParallelism()); - RedisModule_InfoAddFieldLongLong(ctx, "model_execution_timeout", Config_GetModelExecutionTimeout()); + RedisModule_InfoAddFieldLongLong(ctx, "inter_op_parallelism", + 
Config_GetBackendsInterOpParallelism()); + RedisModule_InfoAddFieldLongLong(ctx, "intra_op_parallelism", + Config_GetBackendsIntraOpParallelism()); + RedisModule_InfoAddFieldLongLong(ctx, "model_execution_timeout", + Config_GetModelExecutionTimeout()); RedisModule_InfoAddSection(ctx, "memory_usage"); if (RAI_backends.onnx.get_memory_info) { RedisModule_InfoAddFieldULongLong(ctx, "onnxruntime_memory", From 1201cb2c2271cb26725b8fcae053ab257b274aa6 Mon Sep 17 00:00:00 2001 From: alonre24 Date: Thu, 10 Jun 2021 13:09:48 +0300 Subject: [PATCH 19/27] linter... --- src/execution/background_workers.c | 1 - 1 file changed, 1 deletion(-) diff --git a/src/execution/background_workers.c b/src/execution/background_workers.c index a63efed9c..f318027f4 100644 --- a/src/execution/background_workers.c +++ b/src/execution/background_workers.c @@ -252,7 +252,6 @@ static bool _BGThread_PrepareExecution(RunQueueInfo *run_queue_info, RedisAI_Run return true; } - uintptr_t BGWorker_GetThreadId() { return (uintptr_t)pthread_getspecific(ThreadIdKey); } uintptr_t BGWorker_GetThreadsCount() { return BGWorkersCounter; } From 21737e6ee4e6ffe28d421f620e80c7c27b8f1cfe Mon Sep 17 00:00:00 2001 From: alonre24 Date: Thu, 10 Jun 2021 17:36:00 +0300 Subject: [PATCH 20/27] More PR fixes, add the option to get the global run sessions array from backend len (supported only for onnx now) in INFO MODULES command. 
--- src/backends/backends.c | 9 +- src/backends/backends.h | 6 +- src/backends/onnx_timeout.c | 16 +++- src/backends/onnx_timeout.h | 11 ++- src/backends/onnxruntime.c | 8 +- src/execution/run_queue_info.c | 6 +- src/redisai.c | 21 +++-- src/serialization/AOF/rai_aof_rewrite.c | 2 +- src/util/string_utils.c | 12 +-- src/util/string_utils.h | 2 +- tests/flow/tests_onnx.py | 115 +++++++++++++----------- 11 files changed, 118 insertions(+), 90 deletions(-) diff --git a/src/backends/backends.c b/src/backends/backends.c index f01e657fa..c5082588a 100644 --- a/src/backends/backends.c +++ b/src/backends/backends.c @@ -71,7 +71,7 @@ RedisModuleString *RAI_GetBackendsPath(RedisModuleCtx *ctx) { return backends_path; } -const char *GetBackendName(RAI_Backend backend) { +const char *RAI_GetBackendName(RAI_Backend backend) { switch (backend) { case RAI_BACKEND_TENSORFLOW: return "TF"; @@ -372,6 +372,13 @@ int RAI_LoadBackend_ONNXRuntime(RedisModuleCtx *ctx, const char *path) { goto error; } + backend.get_global_run_sessions_len = + (size_t(*)(void))(unsigned long)dlsym(handle, "RAI_GetGlobalRunSessionsLenORT"); + if (!_ValidateAPICreated(ctx, backend.get_global_run_sessions_len, + "RAI_GetGlobalRunSessionsLenORT")) { + goto error; + } + RedisModule_SubscribeToServerEvent(ctx, RedisModuleEvent_CronLoop, backend.enforce_runtime_duration); RAI_backends.onnx = backend; diff --git a/src/backends/backends.h b/src/backends/backends.h index 01ccef8b2..eb9807156 100644 --- a/src/backends/backends.h +++ b/src/backends/backends.h @@ -88,6 +88,9 @@ typedef struct RAI_LoadedBackend { // Kill run session callback (for stopping long runs). void (*enforce_runtime_duration)(RedisModuleCtx *, RedisModuleEvent, uint64_t, void *); + + // Get the length of the global run sessions array (with entry per working thread). 
+ size_t (*get_global_run_sessions_len)(void); } RAI_LoadedBackend; typedef struct RAI_LoadedBackends { @@ -100,9 +103,10 @@ typedef struct RAI_LoadedBackends { RAI_LoadedBackends RAI_backends; int RAI_LoadBackend(RedisModuleCtx *ctx, int backend, const char *path); + int RAI_LoadDefaultBackend(RedisModuleCtx *ctx, int backend); /** * @brief Returns the backend name as string. */ -const char *GetBackendName(RAI_Backend backend); +const char *RAI_GetBackendName(RAI_Backend backend); diff --git a/src/backends/onnx_timeout.c b/src/backends/onnx_timeout.c index 6f3676310..7e2b218ce 100644 --- a/src/backends/onnx_timeout.c +++ b/src/backends/onnx_timeout.c @@ -7,8 +7,8 @@ #include "util/string_utils.h" #include "redis_ai_objects/stats.h" -int CreateGlobalOnnxRunSessions() { - onnx_global_run_sessions = RedisModule_Alloc(sizeof(struct OnnxGlobalRunSessions)); +int RAI_InitGlobalRunSessionsORT() { + onnx_global_run_sessions = RedisModule_Alloc(sizeof(OnnxGlobalRunSessions)); // Initialize the array with entries number equals to the number of currently // working threads in RedisAI (note that CPU threads must exist form the start). @@ -25,6 +25,13 @@ int CreateGlobalOnnxRunSessions() { return REDISMODULE_OK; } +size_t RAI_GetGlobalRunSessionsLenORT() { + pthread_rwlock_rdlock(&(onnx_global_run_sessions->rwlock)); + size_t len = array_len(onnx_global_run_sessions->OnnxRunSessions); + pthread_rwlock_unlock(&(onnx_global_run_sessions->rwlock)); + return len; +} + int RAI_AddNewDeviceORT(const char *device_str) { // Acquire write lock, as we might reallocate the array while extending it. 
@@ -45,6 +52,7 @@ int RAI_AddNewDeviceORT(const char *device_str) { void RAI_EnforceTimeoutORT(RedisModuleCtx *ctx, RedisModuleEvent eid, uint64_t subevent, void *data) { + RedisModule_Assert(eid.id == REDISMODULE_EVENT_CRON_LOOP); const OrtApi *ort = OrtGetApiBase()->GetApi(1); pthread_rwlock_rdlock(&(onnx_global_run_sessions->rwlock)); OnnxRunSessionCtx **run_sessions_ctx = onnx_global_run_sessions->OnnxRunSessions; @@ -63,7 +71,7 @@ void RAI_EnforceTimeoutORT(RedisModuleCtx *ctx, RedisModuleEvent eid, uint64_t s pthread_rwlock_unlock(&(onnx_global_run_sessions->rwlock)); } -void SetRunSessionCtx(OrtRunOptions *new_run_options, size_t *run_session_index) { +void RAI_SetRunSessionCtxORT(OrtRunOptions *new_run_options, size_t *run_session_index) { pthread_rwlock_rdlock(&(onnx_global_run_sessions->rwlock)); // Get the thread index (which is the correspondent index in the global sessions array). @@ -77,7 +85,7 @@ void SetRunSessionCtx(OrtRunOptions *new_run_options, size_t *run_session_index) pthread_rwlock_unlock(&(onnx_global_run_sessions->rwlock)); } -void InvalidateRunSessionCtx(size_t run_session_index) { +void RAI_InvalidateRunSessionCtxORT(size_t run_session_index) { const OrtApi *ort = OrtGetApiBase()->GetApi(1); pthread_rwlock_rdlock(&(onnx_global_run_sessions->rwlock)); OnnxRunSessionCtx *entry = onnx_global_run_sessions->OnnxRunSessions[run_session_index]; diff --git a/src/backends/onnx_timeout.h b/src/backends/onnx_timeout.h index c16671f2a..1bc43e638 100644 --- a/src/backends/onnx_timeout.h +++ b/src/backends/onnx_timeout.h @@ -23,7 +23,12 @@ OnnxGlobalRunSessions *onnx_global_run_sessions; * so that every thread will have a designated entry to update with the onnx session * that it's going to run. 
*/ -int CreateGlobalOnnxRunSessions(void); +int RAI_InitGlobalRunSessionsORT(void); + +/** + * @return The length of the global array (should be the number of current working threads) + */ +size_t RAI_GetGlobalRunSessionsLenORT(void); /** * @brief This is called whenever RedisAI gets a request to store a model that run @@ -48,11 +53,11 @@ void RAI_EnforceTimeoutORT(RedisModuleCtx *ctx, RedisModuleEvent eid, uint64_t s * @param run_session_index - placeholder for the index of the running thread * in the global array, to have a quick access later to clean this entry. */ -void SetRunSessionCtx(OrtRunOptions *new_run_options, size_t *run_session_index); +void RAI_SetRunSessionCtxORT(OrtRunOptions *new_run_options, size_t *run_session_index); /** * @brief Release the OrtRunOptions of a session that finished its run and * reset the corresponding entry in the global structure. * @param run_session_index - The entry index where OrtRunOptions was stored. */ -void InvalidateRunSessionCtx(size_t run_session_index); +void RAI_InvalidateRunSessionCtxORT(size_t run_session_index); diff --git a/src/backends/onnxruntime.c b/src/backends/onnxruntime.c index 482d6fcfd..6d979ccb5 100644 --- a/src/backends/onnxruntime.c +++ b/src/backends/onnxruntime.c @@ -99,7 +99,7 @@ int RAI_InitBackendORT(int (*get_api_fn)(const char *, void **)) { get_api_fn("GetModelExecutionTimeout", ((void **)&RedisAI_GetModelExecutionTimeout)); get_api_fn("GetThreadsCount", ((void **)&RedisAI_GetThreadsCount)); // Create a global array of onnx runSessions, with an entry for every working thread. - CreateGlobalOnnxRunSessions(); + RAI_InitGlobalRunSessionsORT(); return REDISMODULE_OK; } @@ -582,11 +582,11 @@ int RAI_ModelRunORT(RAI_Model *model, RAI_ExecutionCtx **ectxs, RAI_Error *error ONNX_VALIDATE_STATUS(ort->CreateRunOptions(&run_options)); // Set the created run option in the global RunSessions and save its index. 
- SetRunSessionCtx(run_options, &run_session_index); + RAI_SetRunSessionCtxORT(run_options, &run_session_index); ONNX_VALIDATE_STATUS(ort->Run(session, run_options, input_names, (const OrtValue *const *)inputs, n_input_nodes, output_names, n_output_nodes, outputs)); - InvalidateRunSessionCtx(run_session_index); + RAI_InvalidateRunSessionCtxORT(run_session_index); run_options = NULL; for (uint32_t i = 0; i < ninputs; i++) { @@ -674,7 +674,7 @@ int RAI_ModelRunORT(RAI_Model *model, RAI_ExecutionCtx **ectxs, RAI_Error *error ort->ReleaseTensorTypeAndShapeInfo(info); } if (run_options) { - InvalidateRunSessionCtx(run_session_index); + RAI_InvalidateRunSessionCtxORT(run_session_index); } return REDISMODULE_ERR; } diff --git a/src/execution/run_queue_info.c b/src/execution/run_queue_info.c index ff3a39bd5..426fac49c 100644 --- a/src/execution/run_queue_info.c +++ b/src/execution/run_queue_info.c @@ -7,7 +7,7 @@ RunQueueInfo *RunQueue_Create(const char *device_str) { size_t device_str_len = strlen(device_str); char upper_device_str[device_str_len + 1]; - String_ToUpper(device_str, upper_device_str, &device_str_len); + RAI_StringToUpper(device_str, upper_device_str, device_str_len + 1); // Create new run queue and initialize its inner fields. 
RunQueueInfo *run_queue_info = RedisModule_Alloc(sizeof(RunQueueInfo)); @@ -44,7 +44,7 @@ RunQueueInfo *RunQueue_Create(const char *device_str) { RunQueueInfo *RunQueue_GetInfo(const char *device_str) { size_t device_str_len = strlen(device_str); char upper_device_str[device_str_len + 1]; - String_ToUpper(device_str, upper_device_str, &device_str_len); + RAI_StringToUpper(device_str, upper_device_str, device_str_len + 1); AI_dictEntry *entry = AI_dictFind(RunQueues, upper_device_str); RedisModule_Assert(entry != NULL); return AI_dictGetVal(entry); @@ -53,7 +53,7 @@ RunQueueInfo *RunQueue_GetInfo(const char *device_str) { bool RunQueue_IsExists(const char *device_str) { size_t device_str_len = strlen(device_str); char upper_device_str[device_str_len + 1]; - String_ToUpper(device_str, upper_device_str, &device_str_len); + RAI_StringToUpper(device_str, upper_device_str, device_str_len + 1); if (AI_dictFind(RunQueues, upper_device_str) == NULL) { return false; } diff --git a/src/redisai.c b/src/redisai.c index c89583344..78cdfc35b 100644 --- a/src/redisai.c +++ b/src/redisai.c @@ -466,7 +466,7 @@ int RedisAI_ModelGet_RedisCommand(RedisModuleCtx *ctx, RedisModuleString **argv, RedisModule_ReplyWithArray(ctx, outentries); RedisModule_ReplyWithCString(ctx, "backend"); - const char *backendstr = GetBackendName(mto->backend); + const char *backendstr = RAI_GetBackendName(mto->backend); RedisModule_ReplyWithCString(ctx, backendstr); RedisModule_ReplyWithCString(ctx, "device"); @@ -854,25 +854,25 @@ void _RedisAI_Info(RedisModuleCtx *ctx) { RedisModule_CreateStringPrintf(ctx, "%d", REDISAI_LLAPI_VERSION); RedisModuleString *rdb_version = RedisModule_CreateStringPrintf(ctx, "%llu", REDISAI_ENC_VER); - int reponse_len = 6; + int response_len = 6; if (RAI_backends.tf.get_version) { - reponse_len += 2; + response_len += 2; } if (RAI_backends.torch.get_version) { - reponse_len += 2; + response_len += 2; } if (RAI_backends.tflite.get_version) { - reponse_len += 2; + response_len += 
2; } if (RAI_backends.onnx.get_version) { - reponse_len += 2; + response_len += 2; } - RedisModule_ReplyWithArray(ctx, reponse_len); + RedisModule_ReplyWithArray(ctx, response_len); RedisModule_ReplyWithSimpleString(ctx, "Version"); RedisModule_ReplyWithString(ctx, rai_version); @@ -948,7 +948,7 @@ int RedisAI_Info_RedisCommand(RedisModuleCtx *ctx, RedisModuleString **argv, int RedisModule_ReplyWithCString(ctx, "SCRIPT"); } RedisModule_ReplyWithCString(ctx, "backend"); - RedisModule_ReplyWithCString(ctx, GetBackendName(rstats->backend)); + RedisModule_ReplyWithCString(ctx, RAI_GetBackendName(rstats->backend)); RedisModule_ReplyWithCString(ctx, "device"); RedisModule_ReplyWithCString(ctx, rstats->devicestr); RedisModule_ReplyWithCString(ctx, "tag"); @@ -960,7 +960,7 @@ int RedisAI_Info_RedisCommand(RedisModuleCtx *ctx, RedisModuleString **argv, int RedisModule_ReplyWithCString(ctx, "duration"); RedisModule_ReplyWithLongLong(ctx, rstats->duration_us); RedisModule_ReplyWithCString(ctx, "samples"); - if (rstats->type == 0) { + if (rstats->type == RAI_MODEL) { RedisModule_ReplyWithLongLong(ctx, rstats->samples); } else { RedisModule_ReplyWithLongLong(ctx, -1); @@ -1228,9 +1228,12 @@ void RAI_moduleInfoFunc(RedisModuleInfoCtx *ctx, int for_crash_report) { RAI_backends.onnx.get_memory_info()); RedisModule_InfoAddFieldULongLong(ctx, "onnxruntime_memory_access_num", RAI_backends.onnx.get_memory_access_num()); + RedisModule_InfoAddFieldULongLong(ctx, "onnxruntime_global_run_sessions_length", + RAI_backends.onnx.get_global_run_sessions_len()); } else { RedisModule_InfoAddFieldULongLong(ctx, "onnxruntime_memory", 0); RedisModule_InfoAddFieldULongLong(ctx, "onnxruntime_memory_access_num", 0); + RedisModule_InfoAddFieldULongLong(ctx, "onnxruntime_global_run_sessions_length", 0); } struct rusage self_ru, c_ru; diff --git a/src/serialization/AOF/rai_aof_rewrite.c b/src/serialization/AOF/rai_aof_rewrite.c index bc7aa9885..29d8dd6e3 100644 --- 
a/src/serialization/AOF/rai_aof_rewrite.c +++ b/src/serialization/AOF/rai_aof_rewrite.c @@ -60,7 +60,7 @@ void RAI_AOFRewriteModel(RedisModuleIO *aof, RedisModuleString *key, void *value RedisModule_Free(buffer); } - const char *backendstr = GetBackendName(model->backend); + const char *backendstr = RAI_GetBackendName(model->backend); if (model->backend != RAI_BACKEND_TENSORFLOW) { diff --git a/src/util/string_utils.c b/src/util/string_utils.c index 0df2ee28f..49be828e4 100644 --- a/src/util/string_utils.c +++ b/src/util/string_utils.c @@ -54,16 +54,10 @@ void *RAI_RStringsKeyDup(void *privdata, const void *key) { return RedisModule_CreateStringFromString(NULL, (RedisModuleString *)key); } -void String_ToUpper(const char *str, char *upper, size_t *upper_len) { - size_t str_len = strlen(str); - // Avoid overflow - RedisModule_Assert(*upper_len >= str_len); - - // Update the upper string buffer len. - *upper_len = str_len; - +void RAI_StringToUpper(const char *str, char *upper, size_t str_len) { + // Assumption: upper buffer size is at least str_len. This can be used for + // every binary string, we do not assume that the string is null-terminated. 
for (size_t i = 0; i < str_len; i++) { upper[i] = (char)toupper(str[i]); } - upper[str_len] = 0; } diff --git a/src/util/string_utils.h b/src/util/string_utils.h index d2c60614b..9ba143e89 100644 --- a/src/util/string_utils.h +++ b/src/util/string_utils.h @@ -2,7 +2,7 @@ #include "dict.h" RedisModuleString *RAI_HoldString(RedisModuleString *str); -void String_ToUpper(const char *str, char *upper, size_t *upper_len); +void RAI_StringToUpper(const char *str, char *upper, size_t str_len); uint64_t RAI_StringsHashFunction(const void *key); int RAI_StringsKeyCompare(void *privdata, const void *key1, const void *key2); diff --git a/tests/flow/tests_onnx.py b/tests/flow/tests_onnx.py index d6a21e9d2..ba555f50d 100644 --- a/tests/flow/tests_onnx.py +++ b/tests/flow/tests_onnx.py @@ -464,60 +464,67 @@ def test_onnx_use_custom_allocator_with_GPU(env): env.assertEqual(int(ai_memory_config["ai_onnxruntime_memory_access_num"]), 11) -def test_onnx_kill_switch_basic(env): - con = env.getConnection() - model_with_inf_loop = load_file_content("model_with_infinite_loop.onnx") - ret = con.execute_command('AI.MODELSTORE', 'm{1}', 'ONNX', DEVICE, 'BLOB', model_with_inf_loop) - env.assertEqual(ret, b'OK') - - # Set tensors according to the model inputs. This model consists of two operations to type 'Identity' - # (i.e., just output the input), where the second op is wrapped with another op of type 'Loop'. Overall, this model - # runs a very large number of iterations without doing anything, until it is caught with the kill switch. 
- con.execute_command('AI.TENSORSET', 'iterations{1}', 'INT64', 1, 'VALUES', 9223372036854775807) - con.execute_command('AI.TENSORSET', 'loop_cond{1}', 'BOOL', 1, 'VALUES', 1) - con.execute_command('AI.TENSORSET', 'loop_input{1}', 'FLOAT', 1, 'VALUES', 42) - con.execute_command('AI.TENSORSET', 'outer_scope_input{1}', 'FLOAT', 1, 'VALUES', 42) - - try: - con.execute_command('AI.MODELEXECUTE', 'm{1}', 'INPUTS', 4, 'outer_scope_input{1}', 'iterations{1}', - 'loop_cond{1}', 'loop_input{1}', 'OUTPUTS', 2, 'outer_scope_output{1}', 'loop_output{1}') - env.assertTrue(False) - except Exception as exception: - env.assertEqual(type(exception), redis.exceptions.ResponseError) - env.assertTrue(str(exception).find("Exiting due to terminate flag being set to true") != -1) - - -def test_onnx_kill_switch_multiple_working_threads(): - env = Env(moduleArgs='THREADS_PER_QUEUE 8 MODEL_EXECUTION_TIMEOUT 1000') - con = env.getConnection() - model_with_inf_loop = load_file_content("model_with_infinite_loop.onnx") - ret = con.execute_command('AI.MODELSTORE', 'inf_loop_model{1}', 'ONNX', DEVICE, 'BLOB', model_with_inf_loop) - env.assertEqual(ret, b'OK') - - # Set tensors according to the model inputs. This model consists of two operations to type 'Identity' - # (i.e., just output the input), where the second op is wrapped with another op of type 'Loop'. Overall, this model - # runs a very large number of iterations without doing anything, until it is caught with the kill switch. 
- con.execute_command('AI.TENSORSET', 'iterations{1}', 'INT64', 1, 'VALUES', 9223372036854775807) - con.execute_command('AI.TENSORSET', 'loop_cond{1}', 'BOOL', 1, 'VALUES', 1) - con.execute_command('AI.TENSORSET', 'loop_input{1}', 'FLOAT', 1, 'VALUES', 42) - con.execute_command('AI.TENSORSET', 'outer_scope_input{1}', 'FLOAT', 1, 'VALUES', 42) - - # Load another onnx model as if it runs on a different device (to test existence of multiple queues) - model_pb = load_file_content('mnist.onnx') - sample_raw = load_file_content('one.raw') - ret = con.execute_command('AI.MODELSTORE', 'm{1}', 'ONNX', 'CPU:1', 'BLOB', model_pb) - env.assertEqual(ret, b'OK') - con.execute_command('AI.TENSORSET', 'a{1}', 'FLOAT', 1, 1, 28, 28, 'BLOB', sample_raw) - - def run_parallel_onnx_sessions(con): - ret = con.execute_command('AI.MODELEXECUTE', 'm{1}', 'INPUTS', 1, 'a{1}', 'OUTPUTS', 1, 'b{1}') - env.assertEqual(ret, b'OK') +class TestOnnxKillSwitch: + + def __init__(self): + self.env = Env(moduleArgs='THREADS_PER_QUEUE 8 MODEL_EXECUTION_TIMEOUT 1000') + con = self.env.getConnection() + model_with_inf_loop = load_file_content("model_with_infinite_loop.onnx") + ret = con.execute_command('AI.MODELSTORE', 'inf_loop_model{1}', 'ONNX', DEVICE, 'BLOB', model_with_inf_loop) + self.env.assertEqual(ret, b'OK') + + # Set tensors according to the model inputs. This model consists of two operations to type 'Identity' + # (i.e., just output the input), where the second op is wrapped with another op of type 'Loop'. Overall, this model + # runs a very large number of iterations without doing anything, until it is caught with the kill switch. 
+ con.execute_command('AI.TENSORSET', 'iterations{1}', 'INT64', 1, 'VALUES', 9223372036854775807) + con.execute_command('AI.TENSORSET', 'loop_cond{1}', 'BOOL', 1, 'VALUES', 1) + con.execute_command('AI.TENSORSET', 'loop_input{1}', 'FLOAT', 1, 'VALUES', 42) + con.execute_command('AI.TENSORSET', 'outer_scope_input{1}', 'FLOAT', 1, 'VALUES', 42) + + def test_basic(self): try: + con = self.env.getConnection() con.execute_command('AI.MODELEXECUTE', 'inf_loop_model{1}', 'INPUTS', 4, 'outer_scope_input{1}', 'iterations{1}', - 'loop_cond{1}', 'loop_input{1}', 'OUTPUTS', 2, 'outer_scope_output{1}', 'loop_output{1}') + 'loop_cond{1}', 'loop_input{1}', 'OUTPUTS', 2, 'outer_scope_output{1}', 'loop_output{1}') + self.env.assertTrue(False) except Exception as exception: - env.assertEqual(type(exception), redis.exceptions.ResponseError) - env.assertTrue(str(exception).find("Exiting due to terminate flag being set to true") != -1) - ret = con.execute_command('AI.MODELEXECUTE', 'm{1}', 'INPUTS', 1, 'a{1}', 'OUTPUTS', 1, 'b{1}') - env.assertEqual(ret, b'OK') - run_test_multiproc(env, 8, run_parallel_onnx_sessions) + self.env.assertEqual(type(exception), redis.exceptions.ResponseError) + self.env.assertTrue(str(exception).find("Exiting due to terminate flag being set to true") != -1) + + def test_multiple_working_threads(self): + con = self.env.getConnection() + + # Load another onnx model that will be executed on the same threads that use the kill switch + model_pb = load_file_content('mnist.onnx') + sample_raw = load_file_content('one.raw') + ret = con.execute_command('AI.MODELSTORE', 'mnist{1}', 'ONNX', DEVICE, 'BLOB', model_pb) + self.env.assertEqual(ret, b'OK') + con.execute_command('AI.TENSORSET', 'a{1}', 'FLOAT', 1, 1, 28, 28, 'BLOB', sample_raw) + + def run_parallel_onnx_sessions(con): + ret = con.execute_command('AI.MODELEXECUTE', 'mnist{1}', 'INPUTS', 1, 'a{1}', 'OUTPUTS', 1, 'b{1}') + self.env.assertEqual(ret, b'OK') + try: + con.execute_command('AI.MODELEXECUTE', 
'inf_loop_model{1}', 'INPUTS', 4, 'outer_scope_input{1}', 'iterations{1}', + 'loop_cond{1}', 'loop_input{1}', 'OUTPUTS', 2, 'outer_scope_output{1}', 'loop_output{1}') + except Exception as exception: + self.env.assertEqual(type(exception), redis.exceptions.ResponseError) + self.env.assertTrue(str(exception).find("Exiting due to terminate flag being set to true") != -1) + ret = con.execute_command('AI.MODELEXECUTE', 'mnist{1}', 'INPUTS', 1, 'a{1}', 'OUTPUTS', 1, 'b{1}') + self.env.assertEqual(ret, b'OK') + run_test_multiproc(self.env, 8, run_parallel_onnx_sessions) + + def test_multiple_devices(self): + con = self.env.getConnection() + memory_config = {k.split(":")[0]: k.split(":")[1] + for k in con.execute_command("INFO MODULES").decode().split("#")[4].split()[1:]} + self.env.assertEqual(memory_config['ai_onnxruntime_global_run_sessions_length'], '8') + + # Load another onnx model as if it runs on a different device (to test existence of multiple queues, and + # the extension of the global onnx run sessions array as a consequence.) 
+ model_pb = load_file_content('mnist.onnx') + ret = con.execute_command('AI.MODELSTORE', 'mnist_{1}', 'ONNX', 'CPU:1', 'BLOB', model_pb) + self.env.assertEqual(ret, b'OK') + memory_config = {k.split(":")[0]: k.split(":")[1] + for k in con.execute_command("INFO MODULES").decode().split("#")[4].split()[1:]} + self.env.assertEqual(memory_config['ai_onnxruntime_global_run_sessions_length'], '16') From 73f2a91f7cedc678cb2348dedc2d4caa3e29e2eb Mon Sep 17 00:00:00 2001 From: alonre24 Date: Thu, 10 Jun 2021 18:09:45 +0300 Subject: [PATCH 21/27] Minor fixes --- src/config/config.c | 4 +--- src/execution/run_queue_info.c | 2 -- 2 files changed, 1 insertion(+), 5 deletions(-) diff --git a/src/config/config.c b/src/config/config.c index 8bcf0f2dc..e51d43692 100644 --- a/src/config/config.c +++ b/src/config/config.c @@ -20,7 +20,6 @@ long long ModelExecutionTimeout = 5000; // The maximum time in milliseconds static int _Config_LoadTimeParamParse(RedisModuleCtx *ctx, const char *key, const char *val, RedisModuleString *rsval) { int ret = REDISMODULE_OK; - long long param_val; if (strcasecmp((key), "TF") == 0) { ret = RAI_LoadBackend(ctx, RAI_BACKEND_TENSORFLOW, (val)); } else if (strcasecmp((key), "TFLITE") == 0) { @@ -35,8 +34,7 @@ static int _Config_LoadTimeParamParse(RedisModuleCtx *ctx, const char *key, cons else if (strcasecmp((key), "THREADS_PER_QUEUE") == 0) { ret = Config_SetQueueThreadsNum(rsval); if (ret == REDISMODULE_OK) { - RedisModule_Log(ctx, "notice", "%s: %lld", REDISAI_INFOMSG_THREADS_PER_QUEUE, - (param_val)); + RedisModule_Log(ctx, "notice", "%s: %s", REDISAI_INFOMSG_THREADS_PER_QUEUE, (val)); } } else if (strcasecmp((key), "INTRA_OP_PARALLELISM") == 0) { ret = Config_SetIntraOperationParallelism(rsval); diff --git a/src/execution/run_queue_info.c b/src/execution/run_queue_info.c index 426fac49c..9e98dcb34 100644 --- a/src/execution/run_queue_info.c +++ b/src/execution/run_queue_info.c @@ -74,5 +74,3 @@ void RunQueue_Free(RunQueueInfo *run_queue_info) { 
pthread_cond_destroy(&(run_queue_info->queue_condition_var)); RedisModule_Free(run_queue_info); } - -AI_dict *RunQueue_GetRAIRunQueuesDict() { return RunQueues; } \ No newline at end of file From 3942d237df8d5b65c993af36f103d872cdc496cd Mon Sep 17 00:00:00 2001 From: alonre24 Date: Sun, 13 Jun 2021 16:17:47 +0300 Subject: [PATCH 22/27] More PR fixes, among that: - Add a state flag to every entry in the onnx run sessions array and update it atomically, to avoid situations where main threads and bg thread both access the runOptions field. - Refactor the info_modules section, and change AI.INFO command so it must receive a module/script key. The other info will be accessible as part of the info modules command. --- docs/commands.md | 16 +--- src/backends/backedns_api.h | 11 +++ src/backends/backends.c | 124 ++++++++++++++++----------- src/backends/backends.h | 8 +- src/backends/onnx_timeout.c | 37 +++++--- src/backends/onnx_timeout.h | 8 ++ src/backends/onnxruntime.c | 4 +- src/backends/onnxruntime.h | 5 -- src/execution/parsing/deprecated.c | 5 +- src/execution/run_queue_info.c | 20 ++--- src/execution/run_queue_info.h | 1 + src/execution/utils.c | 1 - src/execution/utils.h | 7 +- src/redisai.c | 132 ++++++++++------------------- tests/flow/includes.py | 9 ++ tests/flow/tests_common.py | 4 - tests/flow/tests_onnx.py | 85 +++++++------------ tests/flow/tests_pytorch.py | 15 ++-- tests/flow/tests_tensorflow.py | 13 ++- tests/flow/tests_tflite.py | 10 +-- 20 files changed, 244 insertions(+), 271 deletions(-) create mode 100644 src/backends/backedns_api.h diff --git a/docs/commands.md b/docs/commands.md index f698342ed..3e44edfc8 100644 --- a/docs/commands.md +++ b/docs/commands.md @@ -934,14 +934,14 @@ Because `AI.DAGRUN` provides the `PERSIST` option it is flagged as a 'write' com Refer to the Redis [`READONLY` command](https://redis.io/commands/readonly) for further information about read-only cluster replicas. 
## AI.INFO -The **`AI.INFO`** command returns general module information or information about the execution a model or a script. +The **`AI.INFO`** command returns information about the execution of a model or a script. -Runtime information is collected each time that [`AI.MODELRUN`](#aimodelrun) or [`AI.SCRIPTRUN`](#aiscriptrun) is called. The information is stored locally by the executing RedisAI engine, so when deployed in a cluster each shard stores its own runtime information. +Runtime information is collected each time that [`AI.MODELEXECUTE`](#aimodelrun) or [`AI.SCRIPTEXECUTE`](#aiscriptrun) is called. The information is stored locally by the executing RedisAI engine, so when deployed in a cluster each shard stores its own runtime information. **Redis API** ``` -AI.INFO [] [RESETSTAT] +AI.INFO [RESETSTAT] ``` _Arguments_ @@ -951,15 +951,7 @@ _Arguments_ _Return_ -For a module genernal information: An array with alternating entries that represent the following key-value pairs: - -* **Version**: a string showing the current module version. -* **Low level API Version**: a string showing the current module's low level api version. -* **RDB Encoding version**: a string showing the current module's RDB encoding version. -* **TensorFlow version**: a string showing the current loaded TesnorFlow backend version. -* **ONNX version**: a string showing the current loaded ONNX Runtime backend version. - -For model or script runtime information: An array with alternating entries that represent the following key-value pairs: +An array with alternating entries that represent the following key-value pairs: * **KEY**: a String of the name of the key storing the model or script value * **TYPE**: a String of the type of value (i.e. 
'MODEL' or 'SCRIPT') diff --git a/src/backends/backedns_api.h b/src/backends/backedns_api.h new file mode 100644 index 000000000..13ed409b7 --- /dev/null +++ b/src/backends/backedns_api.h @@ -0,0 +1,11 @@ +#pragma once + +#include + +uintptr_t (*RedisAI_GetThreadId)(void); + +uintptr_t (*RedisAI_GetThreadsCount)(void); + +long long (*RedisAI_GetNumThreadsPerQueue)(void); + +long long (*RedisAI_GetModelExecutionTimeout)(void); diff --git a/src/backends/backends.c b/src/backends/backends.c index c5082588a..490922462 100644 --- a/src/backends/backends.c +++ b/src/backends/backends.c @@ -18,26 +18,38 @@ #include "config/config.h" #include "execution/background_workers.h" -static bool _ValidateAPICreated(RedisModuleCtx *ctx, void *func_ptr, const char *func_name) { +static bool _ValidateFuncExists(RedisModuleCtx *ctx, void *func_ptr, const char *func_name, + const char *backend_name, const char *path) { if (func_ptr == NULL) { - RedisModule_Log(ctx, "warning", "Backend does not export %s", func_name); + RedisModule_Log(ctx, "warning", + "Backend does not export %s. %s backend" + " was not loaded from %s", + func_name, backend_name, path); return false; } return true; } -int RAI_GetApi(const char *func_name, void **targetPtrPtr) { +/** + * @brief Export a function from RedisAI to a backend. This will set a pointer + * to a function that has been declared in the backend to use the corresponding + * function in RedisAI. + * @param func_name A string that identifies the function to export. + * @param targetFuncPtr place holder for a function pointer coming from the + * backend to set the corresponding function from RedisAI into it. 
+ */ +int RAI_ExportFunc(const char *func_name, void **targetFuncPtr) { if (strcmp("GetThreadId", func_name) == 0) { - *targetPtrPtr = BGWorker_GetThreadId; + *targetFuncPtr = BGWorker_GetThreadId; } else if (strcmp("GetNumThreadsPerQueue", func_name) == 0) { - *targetPtrPtr = Config_GetNumThreadsPerQueue; + *targetFuncPtr = Config_GetNumThreadsPerQueue; } else if (strcmp("GetModelExecutionTimeout", func_name) == 0) { - *targetPtrPtr = Config_GetModelExecutionTimeout; + *targetFuncPtr = Config_GetModelExecutionTimeout; } else if (strcmp("GetThreadsCount", func_name) == 0) { - *targetPtrPtr = BGWorker_GetThreadsCount; + *targetFuncPtr = BGWorker_GetThreadsCount; } else { - return RedisModule_GetApi(func_name, targetPtrPtr); + return RedisModule_GetApi(func_name, targetFuncPtr); } return REDISMODULE_OK; } @@ -101,40 +113,43 @@ int RAI_LoadBackend_TensorFlow(RedisModuleCtx *ctx, const char *path) { int (*init_backend)(int (*)(const char *, void *)); init_backend = (int (*)(int (*)(const char *, void *)))(unsigned long)dlsym(handle, "RAI_InitBackendTF"); - if (!_ValidateAPICreated(ctx, init_backend, "RAI_InitBackendTF")) { + if (!_ValidateFuncExists(ctx, init_backend, "RAI_InitBackendTF", "TF", path)) { goto error; } + // Here we use the input callback to export functions from Redis to the backend, + // by setting the backend's function pointers to the corresponding functions in Redis. 
init_backend(RedisModule_GetApi); backend.model_create_with_nodes = (RAI_Model * (*)(RAI_Backend, const char *, RAI_ModelOpts, size_t, const char **, size_t, const char **, const char *, size_t, RAI_Error *))(unsigned long)dlsym(handle, "RAI_ModelCreateTF"); - if (!_ValidateAPICreated(ctx, backend.model_create_with_nodes, "RAI_ModelCreateTF")) { + if (!_ValidateFuncExists(ctx, backend.model_create_with_nodes, "RAI_ModelCreateTF", "TF", + path)) { goto error; } backend.model_free = (void (*)(RAI_Model *, RAI_Error *))(unsigned long)dlsym(handle, "RAI_ModelFreeTF"); - if (!_ValidateAPICreated(ctx, backend.model_free, "RAI_ModelFreeTF")) { + if (!_ValidateFuncExists(ctx, backend.model_free, "RAI_ModelFreeTF", "TF", path)) { goto error; } backend.model_run = (int (*)(RAI_Model * model, RAI_ExecutionCtx * *ectxs, RAI_Error * error))( unsigned long)dlsym(handle, "RAI_ModelRunTF"); - if (!_ValidateAPICreated(ctx, backend.model_run, "RAI_ModelRunTF")) { + if (!_ValidateFuncExists(ctx, backend.model_run, "RAI_ModelRunTF", "TF", path)) { goto error; } backend.model_serialize = (int (*)(RAI_Model *, char **, size_t *, RAI_Error *))( (unsigned long)dlsym(handle, "RAI_ModelSerializeTF")); - if (!_ValidateAPICreated(ctx, backend.model_serialize, "RAI_ModelSerializeTF")) { + if (!_ValidateFuncExists(ctx, backend.model_serialize, "RAI_ModelSerializeTF", "TF", path)) { goto error; } backend.get_version = (const char *(*)(void))(unsigned long)dlsym(handle, "RAI_GetBackendVersionTF"); - if (!_ValidateAPICreated(ctx, backend.get_version, "RAI_GetBackendVersionTF")) { + if (!_ValidateFuncExists(ctx, backend.get_version, "RAI_GetBackendVersionTF", "TF", path)) { goto error; } @@ -144,7 +159,6 @@ int RAI_LoadBackend_TensorFlow(RedisModuleCtx *ctx, const char *path) { error: dlclose(handle); - RedisModule_Log(ctx, "warning", "TF backend not loaded from %s", path); return REDISMODULE_ERR; } @@ -166,39 +180,43 @@ int RAI_LoadBackend_TFLite(RedisModuleCtx *ctx, const char *path) { int 
(*init_backend)(int (*)(const char *, void *)); init_backend = (int (*)(int (*)(const char *, void *)))(unsigned long)dlsym( handle, "RAI_InitBackendTFLite"); - if (!_ValidateAPICreated(ctx, init_backend, "RAI_InitBackendTFLite")) { + if (!_ValidateFuncExists(ctx, init_backend, "RAI_InitBackendTFLite", "TFLite", path)) { goto error; } + // Here we use the input callback to export functions from Redis to the backend, + // by setting the backend's function pointers to the corresponding functions in Redis. init_backend(RedisModule_GetApi); backend.model_create = (RAI_Model * (*)(RAI_Backend, const char *, RAI_ModelOpts, const char *, size_t, RAI_Error *))(unsigned long)dlsym(handle, "RAI_ModelCreateTFLite"); - if (!_ValidateAPICreated(ctx, backend.model_create, "RAI_ModelCreateTFLite")) { + if (!_ValidateFuncExists(ctx, backend.model_create, "RAI_ModelCreateTFLite", "TFLite", path)) { goto error; } backend.model_free = (void (*)(RAI_Model *, RAI_Error *))(unsigned long)dlsym(handle, "RAI_ModelFreeTFLite"); - if (!_ValidateAPICreated(ctx, backend.model_free, "RAI_ModelFreeTFLite")) { + if (!_ValidateFuncExists(ctx, backend.model_free, "RAI_ModelFreeTFLite", "TFLite", path)) { goto error; } backend.model_run = (int (*)(RAI_Model * model, RAI_ExecutionCtx * *ectxs, RAI_Error * error))( unsigned long)dlsym(handle, "RAI_ModelRunTFLite"); - if (!_ValidateAPICreated(ctx, backend.model_run, "RAI_ModelRunTFLite")) { + if (!_ValidateFuncExists(ctx, backend.model_run, "RAI_ModelRunTFLite", "TFLite", path)) { goto error; } backend.model_serialize = (int (*)(RAI_Model *, char **, size_t *, RAI_Error *))( unsigned long)dlsym(handle, "RAI_ModelSerializeTFLite"); - if (!_ValidateAPICreated(ctx, backend.model_serialize, "RAI_ModelSerializeTFLite")) { + if (!_ValidateFuncExists(ctx, backend.model_serialize, "RAI_ModelSerializeTFLite", "TFLite", + path)) { goto error; } backend.get_version = (const char *(*)(void))(unsigned long)dlsym(handle, "RAI_GetBackendVersionTFLite"); - if 
(!_ValidateAPICreated(ctx, backend.get_version, "RAI_GetBackendVersionTFLite")) { + if (!_ValidateFuncExists(ctx, backend.get_version, "RAI_GetBackendVersionTFLite", "TFLite", + path)) { goto error; } @@ -208,7 +226,6 @@ int RAI_LoadBackend_TFLite(RedisModuleCtx *ctx, const char *path) { error: dlclose(handle); - RedisModule_Log(ctx, "warning", "TFLITE backend not loaded from %s", path); return REDISMODULE_ERR; } @@ -230,57 +247,61 @@ int RAI_LoadBackend_Torch(RedisModuleCtx *ctx, const char *path) { int (*init_backend)(int (*)(const char *, void *)); init_backend = (int (*)(int (*)(const char *, void *)))(unsigned long)dlsym( handle, "RAI_InitBackendTorch"); - if (!_ValidateAPICreated(ctx, init_backend, "RAI_InitBackendTorch")) { + if (!_ValidateFuncExists(ctx, init_backend, "RAI_InitBackendTorch", "TORCH", path)) { goto error; } + // Here we use the input callback to export functions from Redis to the backend, + // by setting the backend's function pointers to the corresponding functions in Redis. 
init_backend(RedisModule_GetApi); backend.model_create = (RAI_Model * (*)(RAI_Backend, const char *, RAI_ModelOpts, const char *, size_t, RAI_Error *))(unsigned long)dlsym(handle, "RAI_ModelCreateTorch"); - if (!_ValidateAPICreated(ctx, backend.model_create, "RAI_ModelCreateTorch")) { + if (!_ValidateFuncExists(ctx, backend.model_create, "RAI_ModelCreateTorch", "TORCH", path)) { goto error; } backend.model_free = (void (*)(RAI_Model *, RAI_Error *))(unsigned long)dlsym(handle, "RAI_ModelFreeTorch"); - if (!_ValidateAPICreated(ctx, backend.model_free, "RAI_ModelFreeTorch")) { + if (!_ValidateFuncExists(ctx, backend.model_free, "RAI_ModelFreeTorch", "TORCH", path)) { goto error; } backend.model_run = (int (*)(RAI_Model * model, RAI_ExecutionCtx * *ectxs, RAI_Error * error))( unsigned long)dlsym(handle, "RAI_ModelRunTorch"); - if (!_ValidateAPICreated(ctx, backend.model_run, "RAI_ModelRunTorch")) { + if (!_ValidateFuncExists(ctx, backend.model_run, "RAI_ModelRunTorch", "TORCH", path)) { goto error; } backend.model_serialize = (int (*)(RAI_Model *, char **, size_t *, RAI_Error *))( unsigned long)dlsym(handle, "RAI_ModelSerializeTorch"); - if (!_ValidateAPICreated(ctx, backend.model_serialize, "RAI_ModelSerializeTorch")) { + if (!_ValidateFuncExists(ctx, backend.model_serialize, "RAI_ModelSerializeTorch", "TORCH", + path)) { goto error; } backend.script_create = (RAI_Script * (*)(const char *, const char *, RAI_Error *))( unsigned long)dlsym(handle, "RAI_ScriptCreateTorch"); - if (!_ValidateAPICreated(ctx, backend.script_create, "RAI_ScriptCreateTorch")) { + if (!_ValidateFuncExists(ctx, backend.script_create, "RAI_ScriptCreateTorch", "TORCH", path)) { goto error; } backend.script_free = (void (*)(RAI_Script *, RAI_Error *))(unsigned long)dlsym(handle, "RAI_ScriptFreeTorch"); - if (!_ValidateAPICreated(ctx, backend.script_free, "RAI_ScriptFreeTorch")) { + if (!_ValidateFuncExists(ctx, backend.script_free, "RAI_ScriptFreeTorch", "TORCH", path)) { goto error; } 
backend.script_run = (int (*)(RAI_Script *, const char *, RAI_ExecutionCtx *, RAI_Error *))( unsigned long)dlsym(handle, "RAI_ScriptRunTorch"); - if (!_ValidateAPICreated(ctx, backend.script_run, "RAI_ScriptRunTorch")) { + if (!_ValidateFuncExists(ctx, backend.script_run, "RAI_ScriptRunTorch", "TORCH", path)) { goto error; } backend.get_version = (const char *(*)(void))(unsigned long)dlsym(handle, "RAI_GetBackendVersionTorch"); - if (!_ValidateAPICreated(ctx, backend.get_version, "RAI_GetBackendVersionTorch")) { + if (!_ValidateFuncExists(ctx, backend.get_version, "RAI_GetBackendVersionTorch", "TORCH", + path)) { goto error; } @@ -290,7 +311,6 @@ int RAI_LoadBackend_Torch(RedisModuleCtx *ctx, const char *path) { error: dlclose(handle); - RedisModule_Log(ctx, "warning", "TORCH backend not loaded from %s", path); return REDISMODULE_ERR; } @@ -311,83 +331,87 @@ int RAI_LoadBackend_ONNXRuntime(RedisModuleCtx *ctx, const char *path) { int (*init_backend)(int (*)(const char *, void **)); init_backend = (int (*)(int (*)(const char *, void **)))(unsigned long)dlsym(handle, "RAI_InitBackendORT"); - if (!_ValidateAPICreated(ctx, init_backend, "RAI_InitBackendORT")) { + if (!_ValidateFuncExists(ctx, init_backend, "RAI_InitBackendORT", "ONNX", path)) { goto error; } - init_backend(RAI_GetApi); + // Here we use the input callback to export functions from Redis and RedisAI + // to the backend, by setting the backend's function pointers to the + // corresponding functions in Redis/RedisAI. 
+ init_backend(RAI_ExportFunc); backend.model_create = (RAI_Model * (*)(RAI_Backend, const char *, RAI_ModelOpts, const char *, size_t, RAI_Error *))(unsigned long)dlsym(handle, "RAI_ModelCreateORT"); - if (!_ValidateAPICreated(ctx, backend.model_create, "RAI_ModelCreateORT")) { + if (!_ValidateFuncExists(ctx, backend.model_create, "RAI_ModelCreateORT", "ONNX", path)) { goto error; } backend.model_free = (void (*)(RAI_Model *, RAI_Error *))(unsigned long)dlsym(handle, "RAI_ModelFreeORT"); - if (!_ValidateAPICreated(ctx, backend.model_free, "RAI_ModelFreeORT")) { + if (!_ValidateFuncExists(ctx, backend.model_free, "RAI_ModelFreeORT", "ONNX", path)) { goto error; } backend.model_run = (int (*)(RAI_Model * model, RAI_ExecutionCtx * *ectxs, RAI_Error * error))( unsigned long)dlsym(handle, "RAI_ModelRunORT"); - if (!_ValidateAPICreated(ctx, backend.model_run, "RAI_ModelRunORT")) { + if (!_ValidateFuncExists(ctx, backend.model_run, "RAI_ModelRunORT", "ONNX", path)) { goto error; } backend.model_serialize = (int (*)(RAI_Model *, char **, size_t *, RAI_Error *))( unsigned long)dlsym(handle, "RAI_ModelSerializeORT"); - if (!_ValidateAPICreated(ctx, backend.model_serialize, "RAI_ModelSerializeORT")) { + if (!_ValidateFuncExists(ctx, backend.model_serialize, "RAI_ModelSerializeORT", "ONNX", path)) { goto error; } backend.get_version = (const char *(*)(void))(unsigned long)dlsym(handle, "RAI_GetBackendVersionORT"); - if (!_ValidateAPICreated(ctx, backend.get_version, "RAI_GetBackendVersionORT")) { + if (!_ValidateFuncExists(ctx, backend.get_version, "RAI_GetBackendVersionORT", "ONNX", path)) { goto error; } backend.get_memory_info = (unsigned long long (*)(void))(unsigned long)dlsym(handle, "RAI_GetMemoryInfoORT"); - if (!_ValidateAPICreated(ctx, backend.get_memory_info, "RAI_GetMemoryInfoORT")) { + if (!_ValidateFuncExists(ctx, backend.get_memory_info, "RAI_GetMemoryInfoORT", "ONNX", path)) { goto error; } backend.get_memory_access_num = (unsigned long long 
(*)(void))(unsigned long)dlsym(handle, "RAI_GetMemoryAccessORT"); - if (!_ValidateAPICreated(ctx, backend.get_memory_access_num, "RAI_GetMemoryAccessORT")) { + if (!_ValidateFuncExists(ctx, backend.get_memory_access_num, "RAI_GetMemoryAccessORT", "ONNX", + path)) { goto error; } - backend.enforce_runtime_duration = + backend.stop_long_running_sessions_cb = (void (*)(RedisModuleCtx *, RedisModuleEvent, uint64_t, void *))(unsigned long)dlsym( handle, "RAI_EnforceTimeoutORT"); - if (!_ValidateAPICreated(ctx, backend.enforce_runtime_duration, "RAI_EnforceTimeoutORT")) { + if (!_ValidateFuncExists(ctx, backend.stop_long_running_sessions_cb, "RAI_EnforceTimeoutORT", + "ONNX", path)) { goto error; } - backend.add_new_device = + backend.add_new_device_cb = (int (*)(const char *))(unsigned long)dlsym(handle, "RAI_AddNewDeviceORT"); - if (!_ValidateAPICreated(ctx, backend.add_new_device, "RAI_AddNewDeviceORT")) { + if (!_ValidateFuncExists(ctx, backend.add_new_device_cb, "RAI_AddNewDeviceORT", "ONNX", path)) { goto error; } - backend.get_global_run_sessions_len = + backend.get_max_run_sessions = (size_t(*)(void))(unsigned long)dlsym(handle, "RAI_GetGlobalRunSessionsLenORT"); - if (!_ValidateAPICreated(ctx, backend.get_global_run_sessions_len, - "RAI_GetGlobalRunSessionsLenORT")) { + if (!_ValidateFuncExists(ctx, backend.get_max_run_sessions, "RAI_GetGlobalRunSessionsLenORT", + "ONNX", path)) { goto error; } RedisModule_SubscribeToServerEvent(ctx, RedisModuleEvent_CronLoop, - backend.enforce_runtime_duration); + backend.stop_long_running_sessions_cb); RAI_backends.onnx = backend; RedisModule_Log(ctx, "notice", "ONNX backend loaded from %s", path); return REDISMODULE_OK; error: dlclose(handle); - RedisModule_Log(ctx, "warning", "ONNX backend not loaded from %s", path); return REDISMODULE_ERR; } diff --git a/src/backends/backends.h b/src/backends/backends.h index eb9807156..0345f8c04 100644 --- a/src/backends/backends.h +++ b/src/backends/backends.h @@ -84,13 +84,13 @@ typedef 
struct RAI_LoadedBackend { unsigned long long (*get_memory_access_num)(void); // A callback for to use whenever a new device is introduced. - int (*add_new_device)(const char *); + int (*add_new_device_cb)(const char *); // Kill run session callback (for stopping long runs). - void (*enforce_runtime_duration)(RedisModuleCtx *, RedisModuleEvent, uint64_t, void *); + void (*stop_long_running_sessions_cb)(RedisModuleCtx *, RedisModuleEvent, uint64_t, void *); - // Get the length of the global run sessions array (with entry per working thread). - size_t (*get_global_run_sessions_len)(void); + // Get the number of maximum run sessions that can run. + size_t (*get_max_run_sessions)(void); } RAI_LoadedBackend; typedef struct RAI_LoadedBackends { diff --git a/src/backends/onnx_timeout.c b/src/backends/onnx_timeout.c index 7e2b218ce..2c70cbdd9 100644 --- a/src/backends/onnx_timeout.c +++ b/src/backends/onnx_timeout.c @@ -1,11 +1,11 @@ #include "onnx_timeout.h" #include "util/arr.h" -#include "util.h" #include "execution/utils.h" #include "config/config.h" #include #include "util/string_utils.h" #include "redis_ai_objects/stats.h" +#include "backedns_api.h" int RAI_InitGlobalRunSessionsORT() { onnx_global_run_sessions = RedisModule_Alloc(sizeof(OnnxGlobalRunSessions)); @@ -17,6 +17,8 @@ int RAI_InitGlobalRunSessionsORT() { array_new(OnnxRunSessionCtx *, RAI_working_threads_num); for (size_t i = 0; i < RAI_working_threads_num; i++) { OnnxRunSessionCtx *entry = RedisModule_Calloc(1, sizeof(OnnxRunSessionCtx)); + entry->runState = RedisModule_Calloc(1, sizeof(entry->runState)); + *entry->runState = RUN_SESSION_AVAILABLE; run_sessions_array = array_append(run_sessions_array, entry); } onnx_global_run_sessions->OnnxRunSessions = run_sessions_array; @@ -43,6 +45,8 @@ int RAI_AddNewDeviceORT(const char *device_str) { size_t size = RedisAI_GetNumThreadsPerQueue(); for (size_t i = 0; i < size; i++) { OnnxRunSessionCtx *entry = RedisModule_Calloc(1, sizeof(OnnxRunSessionCtx)); + 
entry->runState = RedisModule_Calloc(1, sizeof(entry->runState)); + *entry->runState = RUN_SESSION_AVAILABLE; run_sessions_array = array_append(run_sessions_array, entry); } onnx_global_run_sessions->OnnxRunSessions = run_sessions_array; @@ -57,15 +61,19 @@ void RAI_EnforceTimeoutORT(RedisModuleCtx *ctx, RedisModuleEvent eid, uint64_t s pthread_rwlock_rdlock(&(onnx_global_run_sessions->rwlock)); OnnxRunSessionCtx **run_sessions_ctx = onnx_global_run_sessions->OnnxRunSessions; size_t len = array_len(run_sessions_ctx); + long long curr_time = mstime(); + long long timeout = RedisAI_GetModelExecutionTimeout(); for (size_t i = 0; i < len; i++) { - if (run_sessions_ctx[i]->runOptions == NULL) { - continue; - } - long long curr_time = mstime(); - long long timeout = RedisAI_GetModelExecutionTimeout(); - // Check if a sessions is running for too long, and kill it if so. + // Check if a sessions is running for too long, and kill it if is still active. if (curr_time - run_sessions_ctx[i]->queuingTime > timeout) { - ort->RunOptionsSetTerminate(run_sessions_ctx[i]->runOptions); + if (__sync_bool_compare_and_swap(run_sessions_ctx[i]->runState, RUN_SESSION_ACTIVE, + RUN_SESSION_INVALID)) { + // Set termination flag, validate that ONNX API succeeded (returns NULL) + RedisModule_Assert(ort->RunOptionsSetTerminate(run_sessions_ctx[i]->runOptions) == + NULL); + __atomic_store_n(run_sessions_ctx[i]->runState, RUN_SESSION_TERMINATED, + __ATOMIC_RELAXED); + } } } pthread_rwlock_unlock(&(onnx_global_run_sessions->rwlock)); @@ -77,11 +85,12 @@ void RAI_SetRunSessionCtxORT(OrtRunOptions *new_run_options, size_t *run_session // Get the thread index (which is the correspondent index in the global sessions array). 
*run_session_index = RedisAI_GetThreadId(); OnnxRunSessionCtx *entry = onnx_global_run_sessions->OnnxRunSessions[*run_session_index]; - RedisModule_Assert(entry->runOptions == NULL); + RedisModule_Assert(*entry->runState == RUN_SESSION_AVAILABLE); // Update the entry with the current session data. entry->runOptions = new_run_options; entry->queuingTime = mstime(); + __atomic_store_n(entry->runState, RUN_SESSION_ACTIVE, __ATOMIC_RELAXED); pthread_rwlock_unlock(&(onnx_global_run_sessions->rwlock)); } @@ -89,7 +98,15 @@ void RAI_InvalidateRunSessionCtxORT(size_t run_session_index) { const OrtApi *ort = OrtGetApiBase()->GetApi(1); pthread_rwlock_rdlock(&(onnx_global_run_sessions->rwlock)); OnnxRunSessionCtx *entry = onnx_global_run_sessions->OnnxRunSessions[run_session_index]; + // Busy wait until we get a valid state, as we might access this entry from + // the main thread callback and call ONNX API to terminate the run session. + while (true) { + runSessionState state = __atomic_load_n(entry->runState, __ATOMIC_RELAXED); + if (state == RUN_SESSION_ACTIVE || state == RUN_SESSION_TERMINATED) { + break; + } + } ort->ReleaseRunOptions(entry->runOptions); - entry->runOptions = NULL; + __atomic_store_n(entry->runState, RUN_SESSION_AVAILABLE, __ATOMIC_RELAXED); pthread_rwlock_unlock(&(onnx_global_run_sessions->rwlock)); } diff --git a/src/backends/onnx_timeout.h b/src/backends/onnx_timeout.h index 1bc43e638..d57ea448e 100644 --- a/src/backends/onnx_timeout.h +++ b/src/backends/onnx_timeout.h @@ -3,9 +3,17 @@ #include "backends/onnxruntime.h" #include "onnxruntime_c_api.h" +typedef enum { + RUN_SESSION_AVAILABLE, + RUN_SESSION_ACTIVE, + RUN_SESSION_TERMINATED, + RUN_SESSION_INVALID +} runSessionState; + typedef struct OnnxRunSessionCtx { long long queuingTime; OrtRunOptions *runOptions; + runSessionState *runState; } OnnxRunSessionCtx; // This is a global array of OnnxRunSessionCtx. 
Contains an entry for every thread diff --git a/src/backends/onnxruntime.c b/src/backends/onnxruntime.c index 6d979ccb5..e2ebdac26 100644 --- a/src/backends/onnxruntime.c +++ b/src/backends/onnxruntime.c @@ -3,14 +3,13 @@ #include "backends/util.h" #include #include -#include "execution/background_workers.h" #include -#include #include "util/arr.h" #include "backends/onnxruntime.h" #include "redis_ai_objects/tensor.h" #include "onnxruntime_c_api.h" +#include "backedns_api.h" // Use as a wrapper for ORT api call. If ORT api hasn't returned null, it has failed. // A label "error" must exist in every function that uses this macro. @@ -98,6 +97,7 @@ int RAI_InitBackendORT(int (*get_api_fn)(const char *, void **)) { get_api_fn("GetNumThreadsPerQueue", ((void **)&RedisAI_GetNumThreadsPerQueue)); get_api_fn("GetModelExecutionTimeout", ((void **)&RedisAI_GetModelExecutionTimeout)); get_api_fn("GetThreadsCount", ((void **)&RedisAI_GetThreadsCount)); + // Create a global array of onnx runSessions, with an entry for every working thread. 
RAI_InitGlobalRunSessionsORT(); diff --git a/src/backends/onnxruntime.h b/src/backends/onnxruntime.h index 2f7da40fc..d165af32c 100644 --- a/src/backends/onnxruntime.h +++ b/src/backends/onnxruntime.h @@ -9,11 +9,6 @@ unsigned long long RAI_GetMemoryInfoORT(void); unsigned long long RAI_GetMemoryAccessORT(void); -uintptr_t (*RedisAI_GetThreadId)(void); -uintptr_t (*RedisAI_GetThreadsCount)(void); -long long (*RedisAI_GetNumThreadsPerQueue)(void); -long long (*RedisAI_GetModelExecutionTimeout)(void); - int RAI_InitBackendORT(int (*get_api_fn)(const char *, void **)); RAI_Model *RAI_ModelCreateORT(RAI_Backend backend, const char *devicestr, RAI_ModelOpts opts, diff --git a/src/execution/parsing/deprecated.c b/src/execution/parsing/deprecated.c index e1d77bbb8..3d95bc42c 100644 --- a/src/execution/parsing/deprecated.c +++ b/src/execution/parsing/deprecated.c @@ -1,5 +1,3 @@ - -#include #include "deprecated.h" #include "rmutil/args.h" #include "backends/backends.h" @@ -7,11 +5,12 @@ #include "redis_ai_objects/stats.h" #include "execution/utils.h" +#include #include "execution/DAG/dag_builder.h" #include "execution/DAG/dag_execute.h" -#include "execution/background_workers.h" #include "execution/parsing/dag_parser.h" #include "execution/parsing/parse_utils.h" + #include "execution/execution_contexts/modelRun_ctx.h" #include "execution/execution_contexts/scriptRun_ctx.h" diff --git a/src/execution/run_queue_info.c b/src/execution/run_queue_info.c index 9e98dcb34..7ef559da6 100644 --- a/src/execution/run_queue_info.c +++ b/src/execution/run_queue_info.c @@ -15,9 +15,7 @@ RunQueueInfo *RunQueue_Create(const char *device_str) { run_queue_info->device_str = RedisModule_Strdup(upper_device_str); pthread_cond_init(&(run_queue_info->queue_condition_var), NULL); pthread_mutex_init(&(run_queue_info->run_queue_mutex), NULL); - run_queue_info->threads = - (pthread_t *)RedisModule_Alloc(sizeof(pthread_t) * Config_GetNumThreadsPerQueue()); - + run_queue_info->threads = 
array_new(pthread_t, Config_GetNumThreadsPerQueue()); // Save device with its associate run queue info in the dictionary. if (AI_dictAdd(RunQueues, upper_device_str, run_queue_info) != DICT_OK) { RunQueue_Free(run_queue_info); @@ -26,17 +24,18 @@ RunQueueInfo *RunQueue_Create(const char *device_str) { // Create worker threads. for (int i = 0; i < Config_GetNumThreadsPerQueue(); i++) { - if (pthread_create(&(run_queue_info->threads[i]), NULL, BGWorker_ThreadMain, - run_queue_info) != 0) { + pthread_t thread; + if (pthread_create(&thread, NULL, BGWorker_ThreadMain, run_queue_info) != 0) { AI_dictDelete(RunQueues, upper_device_str); RunQueue_Free(run_queue_info); return NULL; } + run_queue_info->threads = array_append(run_queue_info->threads, thread); } // Add the new device worker threads to onnx run sessions tracking. - if (RAI_backends.onnx.add_new_device) { - RAI_backends.onnx.add_new_device(device_str); + if (RAI_backends.onnx.add_new_device_cb) { + RAI_backends.onnx.add_new_device_cb(device_str); } return run_queue_info; } @@ -54,10 +53,7 @@ bool RunQueue_IsExists(const char *device_str) { size_t device_str_len = strlen(device_str); char upper_device_str[device_str_len + 1]; RAI_StringToUpper(device_str, upper_device_str, device_str_len + 1); - if (AI_dictFind(RunQueues, upper_device_str) == NULL) { - return false; - } - return true; + return AI_dictFind(RunQueues, upper_device_str) != NULL; } void RunQueue_Free(RunQueueInfo *run_queue_info) { @@ -66,7 +62,7 @@ void RunQueue_Free(RunQueueInfo *run_queue_info) { RedisModule_Free(run_queue_info->device_str); // Wait for workers to exit and free the pool. 
- for (int i = 0; i < Config_GetNumThreadsPerQueue(); i++) { + for (int i = 0; i < array_len(run_queue_info->threads); i++) { RedisModule_Assert(pthread_join(run_queue_info->threads[i], NULL) == 0); RedisModule_Free(run_queue_info->threads); } diff --git a/src/execution/run_queue_info.h b/src/execution/run_queue_info.h index abe1331cb..da898a267 100644 --- a/src/execution/run_queue_info.h +++ b/src/execution/run_queue_info.h @@ -9,6 +9,7 @@ #include "utils.h" #include "queue.h" +#include "dictionaries.h" AI_dict *RunQueues; diff --git a/src/execution/utils.c b/src/execution/utils.c index ea0fd501d..07b0704eb 100644 --- a/src/execution/utils.c +++ b/src/execution/utils.c @@ -1,5 +1,4 @@ #include "utils.h" -#include "background_workers.h" #include "redis_ai_objects/model.h" int redisMajorVersion; diff --git a/src/execution/utils.h b/src/execution/utils.h index fbe8ded3f..afeeb1663 100644 --- a/src/execution/utils.h +++ b/src/execution/utils.h @@ -1,10 +1,7 @@ #pragma once -#include "redismodule.h" -#include "redis_ai_objects/tensor_struct.h" -#include "redis_ai_objects/model_struct.h" -#include "redis_ai_objects/err.h" + #include -#include +#include "redismodule.h" /** Use this to check if a command is given a key whose hash slot is not on the current * shard, when using enterprise cluster. 
diff --git a/src/redisai.c b/src/redisai.c index 78cdfc35b..589566bbd 100644 --- a/src/redisai.c +++ b/src/redisai.c @@ -8,16 +8,14 @@ #include "redis_ai_objects/tensor.h" #include "execution/command_parser.h" #include "backends/backends.h" -#include "backends/util.h" -#include "execution/utils.h" -#include "execution/background_workers.h" #include "execution/DAG/dag.h" #include "execution/DAG/dag_builder.h" #include "execution/DAG/dag_execute.h" +#include "execution/utils.h" #include "execution/parsing/deprecated.h" -#include "redis_ai_objects/model.h" #include "execution/execution_contexts/modelRun_ctx.h" #include "execution/execution_contexts/scriptRun_ctx.h" +#include "redis_ai_objects/model.h" #include "redis_ai_objects/script.h" #include "redis_ai_objects/stats.h" #include @@ -26,7 +24,6 @@ #include #include #include -#include #include #include "rmutil/alloc.h" @@ -847,81 +844,13 @@ int RedisAI_ScriptScan_RedisCommand(RedisModuleCtx *ctx, RedisModuleString **arg return REDISMODULE_OK; } -void _RedisAI_Info(RedisModuleCtx *ctx) { - RedisModuleString *rai_version = RedisModule_CreateStringPrintf( - ctx, "%d.%d.%d", REDISAI_VERSION_MAJOR, REDISAI_VERSION_MINOR, REDISAI_VERSION_PATCH); - RedisModuleString *llapi_version = - RedisModule_CreateStringPrintf(ctx, "%d", REDISAI_LLAPI_VERSION); - RedisModuleString *rdb_version = RedisModule_CreateStringPrintf(ctx, "%llu", REDISAI_ENC_VER); - - int response_len = 6; - - if (RAI_backends.tf.get_version) { - response_len += 2; - } - - if (RAI_backends.torch.get_version) { - response_len += 2; - } - - if (RAI_backends.tflite.get_version) { - response_len += 2; - } - - if (RAI_backends.onnx.get_version) { - response_len += 2; - } - - RedisModule_ReplyWithArray(ctx, response_len); - - RedisModule_ReplyWithSimpleString(ctx, "Version"); - RedisModule_ReplyWithString(ctx, rai_version); - - // TODO: Add Git SHA - - RedisModule_ReplyWithSimpleString(ctx, "Low Level API Version"); - RedisModule_ReplyWithString(ctx, llapi_version); 
- - RedisModule_ReplyWithSimpleString(ctx, "RDB Encoding version"); - RedisModule_ReplyWithString(ctx, llapi_version); - - if (RAI_backends.tf.get_version) { - RedisModule_ReplyWithSimpleString(ctx, "TensorFlow version"); - RedisModule_ReplyWithSimpleString(ctx, RAI_backends.tf.get_version()); - } - - if (RAI_backends.torch.get_version) { - RedisModule_ReplyWithSimpleString(ctx, "Torch version"); - RedisModule_ReplyWithSimpleString(ctx, RAI_backends.torch.get_version()); - } - - if (RAI_backends.tflite.get_version) { - RedisModule_ReplyWithSimpleString(ctx, "TFLite version"); - RedisModule_ReplyWithSimpleString(ctx, RAI_backends.tflite.get_version()); - } - - if (RAI_backends.onnx.get_version) { - RedisModule_ReplyWithSimpleString(ctx, "ONNX version"); - RedisModule_ReplyWithSimpleString(ctx, RAI_backends.onnx.get_version()); - } - - RedisModule_FreeString(ctx, rai_version); - RedisModule_FreeString(ctx, llapi_version); - RedisModule_FreeString(ctx, rdb_version); -} - /** * AI.INFO [RESETSTAT] */ int RedisAI_Info_RedisCommand(RedisModuleCtx *ctx, RedisModuleString **argv, int argc) { - if (argc > 3) + if (argc < 2 || argc > 3) return RedisModule_WrongArity(ctx); - if (argc == 1) { - _RedisAI_Info(ctx); - return REDISMODULE_OK; - } - RedisModuleString *runkey = argv[1]; struct RedisAI_RunStats *rstats = NULL; if (RAI_GetRunStats(runkey, &rstats) == REDISMODULE_ERR) { @@ -1211,7 +1140,48 @@ static int RedisAI_RegisterApi(RedisModuleCtx *ctx) { return REDISMODULE_OK; } +static void _moduleInfo_getBackendsInfo(RedisModuleInfoCtx *ctx) { + RedisModule_InfoAddSection(ctx, "backends_info"); + if (RAI_backends.tf.get_version) { + RedisModule_InfoAddFieldCString(ctx, "TensorFlow_version", + (char *)RAI_backends.tf.get_version()); + } + if (RAI_backends.tflite.get_version) { + RedisModule_InfoAddFieldCString(ctx, "TensorFlowLite_version", + (char *)RAI_backends.tflite.get_version()); + } + if (RAI_backends.torch.get_version) { + RedisModule_InfoAddFieldCString(ctx, 
"Torch_version", + (char *)RAI_backends.torch.get_version()); + } + if (RAI_backends.onnx.get_version) { + RedisModule_InfoAddFieldCString(ctx, "onnxruntime_version", + (char *)RAI_backends.onnx.get_version()); + RedisModule_InfoAddFieldULongLong(ctx, "onnxruntime_memory", + RAI_backends.onnx.get_memory_info()); + RedisModule_InfoAddFieldULongLong(ctx, "onnxruntime_memory_access_num", + RAI_backends.onnx.get_memory_access_num()); + RedisModule_InfoAddFieldULongLong(ctx, "onnxruntime_maximum_run_sessions_number", + RAI_backends.onnx.get_max_run_sessions()); + } +} + void RAI_moduleInfoFunc(RedisModuleInfoCtx *ctx, int for_crash_report) { + RedisModule_InfoAddSection(ctx, "versions"); + RedisModuleString *rai_version = RedisModule_CreateStringPrintf( + NULL, "%d.%d.%d", REDISAI_VERSION_MAJOR, REDISAI_VERSION_MINOR, REDISAI_VERSION_PATCH); + RedisModuleString *llapi_version = + RedisModule_CreateStringPrintf(NULL, "%d", REDISAI_LLAPI_VERSION); + RedisModuleString *rdb_version = RedisModule_CreateStringPrintf(NULL, "%llu", REDISAI_ENC_VER); + + RedisModule_InfoAddFieldString(ctx, "RedisAI_version", rai_version); + RedisModule_InfoAddFieldString(ctx, "low_level_API_version", llapi_version); + RedisModule_InfoAddFieldString(ctx, "rdb_version", rdb_version); + + RedisModule_FreeString(NULL, rai_version); + RedisModule_FreeString(NULL, llapi_version); + RedisModule_FreeString(NULL, rdb_version); + RedisModule_InfoAddSection(ctx, "git"); RedisModule_InfoAddFieldCString(ctx, "git_sha", REDISAI_GIT_SHA); RedisModule_InfoAddSection(ctx, "load_time_configs"); @@ -1222,19 +1192,7 @@ void RAI_moduleInfoFunc(RedisModuleInfoCtx *ctx, int for_crash_report) { Config_GetBackendsIntraOpParallelism()); RedisModule_InfoAddFieldLongLong(ctx, "model_execution_timeout", Config_GetModelExecutionTimeout()); - RedisModule_InfoAddSection(ctx, "memory_usage"); - if (RAI_backends.onnx.get_memory_info) { - RedisModule_InfoAddFieldULongLong(ctx, "onnxruntime_memory", - 
RAI_backends.onnx.get_memory_info()); - RedisModule_InfoAddFieldULongLong(ctx, "onnxruntime_memory_access_num", - RAI_backends.onnx.get_memory_access_num()); - RedisModule_InfoAddFieldULongLong(ctx, "onnxruntime_global_run_sessions_length", - RAI_backends.onnx.get_global_run_sessions_len()); - } else { - RedisModule_InfoAddFieldULongLong(ctx, "onnxruntime_memory", 0); - RedisModule_InfoAddFieldULongLong(ctx, "onnxruntime_memory_access_num", 0); - RedisModule_InfoAddFieldULongLong(ctx, "onnxruntime_global_run_sessions_length", 0); - } + _moduleInfo_getBackendsInfo(ctx); struct rusage self_ru, c_ru; // Return resource usage statistics for the calling process, diff --git a/tests/flow/includes.py b/tests/flow/includes.py index bcbdf3219..6ca9bfe0a 100755 --- a/tests/flow/includes.py +++ b/tests/flow/includes.py @@ -202,6 +202,7 @@ def check_error_message(env, con, error_msg, *command): env.assertEqual(type(exception), redis.exceptions.ResponseError) env.assertEqual(error_msg, str(exception)) + def check_error(env, con, *command): try: con.execute_command(*command) @@ -209,3 +210,11 @@ def check_error(env, con, *command): except Exception as e: exception = e env.assertEqual(type(exception), redis.exceptions.ResponseError) + + +# Returns a dict with all the fields of a certain section from INFO MODULES command +def get_info_section(con, section): + sections = ['ai_versions', 'ai_git', 'ai_load_time_configs', 'ai_backends_info', 'ai_cpu'] + section_ind = [i for i in range(len(sections)) if sections[i] == 'ai_'+section][0] + return {k.split(":")[0]: k.split(":")[1] + for k in con.execute_command("INFO MODULES").decode().split("#")[section_ind+2].split()[1:]} diff --git a/tests/flow/tests_common.py b/tests/flow/tests_common.py index 9ace58f08..52d79b7c3 100644 --- a/tests/flow/tests_common.py +++ b/tests/flow/tests_common.py @@ -360,7 +360,3 @@ def test_lua_multi(env): env.assertEqual(type(exception), redis.exceptions.ResponseError) env.assertEqual("Cannot run RedisAI 
command within a transaction or a LUA script", exception.__str__()) -def test_info(env): - con = env.getConnection() - ret = con.execute_command('AI.INFO') - env.assertEqual(6, len(ret)) diff --git a/tests/flow/tests_onnx.py b/tests/flow/tests_onnx.py index ba555f50d..dbd748d56 100644 --- a/tests/flow/tests_onnx.py +++ b/tests/flow/tests_onnx.py @@ -296,16 +296,14 @@ def tests_onnx_info(env): return con = env.getConnection() - ret = con.execute_command('AI.INFO') - env.assertEqual(6, len(ret)) + backends_info = get_info_section(con, 'backends_info') + env.assertFalse('ai_onnxruntime_version' in backends_info) linear_model = load_file_content('linear_iris.onnx') - con.execute_command('AI.MODELSTORE', 'linear{1}', 'ONNX', DEVICE, 'BLOB', linear_model) - - ret = con.execute_command('AI.INFO') - env.assertEqual(8, len(ret)) - env.assertEqual(b'ONNX version', ret[6]) + + backends_info = get_info_section(con, 'backends_info') + env.assertTrue('ai_onnxruntime_version' in backends_info) def test_parallelism(): @@ -328,14 +326,12 @@ def test_parallelism(): argmax = max(range(len(values)), key=lambda i: values[i]) env.assertEqual(argmax, 1) - load_time_config = {k.split(":")[0]: k.split(":")[1] - for k in con.execute_command("INFO MODULES").decode().split("#")[3].split()[1:]} + load_time_config = get_info_section(con, 'load_time_configs') env.assertEqual(load_time_config["ai_inter_op_parallelism"], "1") env.assertEqual(load_time_config["ai_intra_op_parallelism"], "1") env = Env(moduleArgs='INTRA_OP_PARALLELISM 2 INTER_OP_PARALLELISM 2') - load_time_config = {k.split(":")[0]: k.split(":")[1] - for k in con.execute_command("INFO MODULES").decode().split("#")[3].split()[1:]} + load_time_config = get_info_section(con, 'load_time_configs') env.assertEqual(load_time_config["ai_inter_op_parallelism"], "2") env.assertEqual(load_time_config["ai_intra_op_parallelism"], "2") @@ -348,10 +344,6 @@ def test_onnx_use_custom_allocator(env): con = env.getConnection() model_pb = 
load_file_content('mul_1.onnx') - ai_memory_config = {k.split(":")[0]: k.split(":")[1] - for k in con.execute_command("INFO MODULES").decode().split("#")[4].split()[1:]} - env.assertEqual(int(ai_memory_config["ai_onnxruntime_memory"]), 0) - # Expect using the allocator during model set for allocating the model, its input name and output name: # overall 3 allocations. The model raw size is 130B ,and the names are 2B each. In practice we allocate # more than 134B as Redis allocator will use additional memory for its internal management and for the @@ -359,13 +351,12 @@ def test_onnx_use_custom_allocator(env): # (hence will not use additional memory). ret = con.execute_command('AI.MODELSTORE', 'm{1}', 'ONNX', 'CPU', 'BLOB', model_pb) env.assertEqual(ret, b'OK') - ai_memory_config = {k.split(":")[0]: k.split(":")[1] - for k in con.execute_command("INFO MODULES").decode().split("#")[4].split()[1:]} + backends_info = get_info_section(con, 'backends_info') # Expect using at least 130+63+(size of an address) + 2*(2+63+(size of an address)) bytes. 
- model_allocation_bytes_used = int(ai_memory_config["ai_onnxruntime_memory"]) + model_allocation_bytes_used = int(backends_info["ai_onnxruntime_memory"]) env.assertTrue(model_allocation_bytes_used > 334) - env.assertEqual(int(ai_memory_config["ai_onnxruntime_memory_access_num"]), 3) + env.assertEqual(int(backends_info["ai_onnxruntime_memory_access_num"]), 3) con.execute_command('AI.TENSORSET', 'a_mul{1}', 'FLOAT', 3, 2, 'VALUES', 1.0, 2.0, 3.0, 4.0, 5.0, 6.0) # Running the model should access the allocator 6 times: allocating+freeing input+output names, @@ -373,18 +364,16 @@ def test_onnx_use_custom_allocator(env): con.execute_command('AI.MODELEXECUTE', 'm{1}', 'INPUTS', 1, 'a_mul{1}', 'OUTPUTS', 1, 'b{1}') values = con.execute_command('AI.TENSORGET', 'b{1}', 'VALUES') env.assertEqual(values, [b'1', b'4', b'9', b'16', b'25', b'36']) - ai_memory_config = {k.split(":")[0]: k.split(":")[1] - for k in con.execute_command("INFO MODULES").decode().split("#")[4].split()[1:]} - env.assertEqual(int(ai_memory_config["ai_onnxruntime_memory_access_num"]), 9) - env.assertEqual(int(ai_memory_config["ai_onnxruntime_memory"]), model_allocation_bytes_used) + backends_info = get_info_section(con, 'backends_info') + env.assertEqual(int(backends_info["ai_onnxruntime_memory_access_num"]), 9) + env.assertEqual(int(backends_info["ai_onnxruntime_memory"]), model_allocation_bytes_used) # Expect using the allocator free function 3 times: when releasing the model, input name and output name. 
con.execute_command('AI.MODELDEL', 'm{1}') env.assertFalse(con.execute_command('EXISTS', 'm{1}')) - ai_memory_config = {k.split(":")[0]: k.split(":")[1] - for k in con.execute_command("INFO MODULES").decode().split("#")[4].split()[1:]} - env.assertEqual(int(ai_memory_config["ai_onnxruntime_memory"]), 0) - env.assertEqual(int(ai_memory_config["ai_onnxruntime_memory_access_num"]), 12) + backends_info = get_info_section(con, 'backends_info') + env.assertEqual(int(backends_info["ai_onnxruntime_memory"]), 0) + env.assertEqual(int(backends_info["ai_onnxruntime_memory_access_num"]), 12) # test the use of Redis allocator in model run op. model_pb = load_file_content('mnist.onnx') @@ -396,13 +385,11 @@ def test_onnx_use_custom_allocator(env): # Expect 18 allocator's access from onnx during the run (in addition to the allocations that were made while # creating the model). - ai_memory_config = {k.split(":")[0]: k.split(":")[1] - for k in con.execute_command("INFO MODULES").decode().split("#")[4].split()[1:]} - allocator_access_num_before = ai_memory_config["ai_onnxruntime_memory_access_num"] + backends_info = get_info_section(con, 'backends_info') + allocator_access_num_before = backends_info["ai_onnxruntime_memory_access_num"] con.execute_command('AI.MODELEXECUTE', 'm{1}', 'INPUTS', 1, 'a{1}', 'OUTPUTS', 1, 'b{1}') - ai_memory_config = {k.split(":")[0]: k.split(":")[1] - for k in con.execute_command("INFO MODULES").decode().split("#")[4].split()[1:]} - allocator_access_num_after = ai_memory_config["ai_onnxruntime_memory_access_num"] + backends_info = get_info_section(con, 'backends_info') + allocator_access_num_after = backends_info["ai_onnxruntime_memory_access_num"] env.assertEqual(int(allocator_access_num_after) - int(allocator_access_num_before), 18) values = con.execute_command('AI.TENSORGET', 'b{1}', 'VALUES') @@ -420,9 +407,6 @@ def test_onnx_use_custom_allocator_with_GPU(env): con = env.getConnection() model_pb = load_file_content('mul_1.onnx') - ai_memory_config = 
{k.split(":")[0]: k.split(":")[1] - for k in con.execute_command("INFO MODULES").decode().split("#")[4].split()[1:]} - env.assertEqual(int(ai_memory_config["ai_onnxruntime_memory"]), 0) # Expect using the allocator during model set for allocating the model, its input name and output name: # overall 3 allocations. The model raw size is 130B ,and the names are 2B each. In practice we allocate @@ -434,14 +418,13 @@ def test_onnx_use_custom_allocator_with_GPU(env): # but for GPU, expect using the allocator only for allocating input and output names (not the model itself). ret = con.execute_command('AI.MODELSTORE', 'm_cpu{1}', 'ONNX', 'CPU', 'BLOB', model_pb) env.assertEqual(ret, b'OK') - ai_memory_config = {k.split(":")[0]: k.split(":")[1] - for k in con.execute_command("INFO MODULES").decode().split("#")[4].split()[1:]} + backends_info = get_info_section(con, 'backends_info') # Expect using at least 130+63+(size of an address) + 4*(2+63+(size of an address)) bytes. - model_allocation_bytes_used = int(ai_memory_config["ai_onnxruntime_memory"]) + model_allocation_bytes_used = int(backends_info["ai_onnxruntime_memory"]) env.assertTrue(model_allocation_bytes_used > 472) env.assertTrue(model_allocation_bytes_used < 705) - env.assertEqual(int(ai_memory_config["ai_onnxruntime_memory_access_num"]), 5) + env.assertEqual(int(backends_info["ai_onnxruntime_memory_access_num"]), 5) # Make sure that allocator is not used for running and freeing the GPU model, except for # the input and output names allocations (and deallocations). 
@@ -451,17 +434,15 @@ def test_onnx_use_custom_allocator_with_GPU(env): env.assertEqual(values, [b'1', b'4', b'9', b'16', b'25', b'36']) # Expect that memory usage didn't change, and for another 4 accesses to the allocator (input and output names # allocation and free) - ai_memory_config = {k.split(":")[0]: k.split(":")[1] - for k in con.execute_command("INFO MODULES").decode().split("#")[4].split()[1:]} - env.assertEqual(int(ai_memory_config["ai_onnxruntime_memory"]), model_allocation_bytes_used) - env.assertEqual(int(ai_memory_config["ai_onnxruntime_memory_access_num"]), 9) + backends_info = get_info_section(con, 'backends_info') + env.assertEqual(int(backends_info["ai_onnxruntime_memory"]), model_allocation_bytes_used) + env.assertEqual(int(backends_info["ai_onnxruntime_memory_access_num"]), 9) # Expect only 2 more accesses in delete - for deallocating input and output names con.execute_command('AI.MODELDEL', 'm_gpu{1}') env.assertFalse(con.execute_command('EXISTS', 'm_gpu{1}')) - ai_memory_config = {k.split(":")[0]: k.split(":")[1] - for k in con.execute_command("INFO MODULES").decode().split("#")[4].split()[1:]} - env.assertEqual(int(ai_memory_config["ai_onnxruntime_memory_access_num"]), 11) + backends_info = get_info_section(con, 'backends_info') + env.assertEqual(int(backends_info["ai_onnxruntime_memory_access_num"]), 11) class TestOnnxKillSwitch: @@ -516,15 +497,13 @@ def run_parallel_onnx_sessions(con): def test_multiple_devices(self): con = self.env.getConnection() - memory_config = {k.split(":")[0]: k.split(":")[1] - for k in con.execute_command("INFO MODULES").decode().split("#")[4].split()[1:]} - self.env.assertEqual(memory_config['ai_onnxruntime_global_run_sessions_length'], '8') + backends_info = get_info_section(con, 'backends_info') + self.env.assertEqual(backends_info['ai_onnxruntime_maximum_run_sessions_number'], '8') # Load another onnx model as if it runs on a different device (to test existence of multiple queues, and # the extension of the 
global onnx run sessions array as a consequence.) model_pb = load_file_content('mnist.onnx') ret = con.execute_command('AI.MODELSTORE', 'mnist_{1}', 'ONNX', 'CPU:1', 'BLOB', model_pb) self.env.assertEqual(ret, b'OK') - memory_config = {k.split(":")[0]: k.split(":")[1] - for k in con.execute_command("INFO MODULES").decode().split("#")[4].split()[1:]} - self.env.assertEqual(memory_config['ai_onnxruntime_global_run_sessions_length'], '16') + backends_info = get_info_section(con, 'backends_info') + self.env.assertEqual(backends_info['ai_onnxruntime_maximum_run_sessions_number'], '16') diff --git a/tests/flow/tests_pytorch.py b/tests/flow/tests_pytorch.py index 92118b89a..d19347204 100644 --- a/tests/flow/tests_pytorch.py +++ b/tests/flow/tests_pytorch.py @@ -739,14 +739,12 @@ def test_parallelism(): ensureSlaveSynced(con, env) values = con.execute_command('AI.TENSORGET', 'c{1}', 'VALUES') env.assertEqual(values, [b'4', b'6', b'4', b'6']) - load_time_config = {k.split(":")[0]: k.split(":")[1] - for k in con.execute_command("INFO MODULES").decode().split("#")[3].split()[1:]} + load_time_config = get_info_section(con, 'load_time_configs') env.assertEqual(load_time_config["ai_inter_op_parallelism"], "1") env.assertEqual(load_time_config["ai_intra_op_parallelism"], "1") env = Env(moduleArgs='INTRA_OP_PARALLELISM 2 INTER_OP_PARALLELISM 2') - load_time_config = {k.split(":")[0]: k.split(":")[1] - for k in con.execute_command("INFO MODULES").decode().split("#")[3].split()[1:]} + load_time_config = get_info_section(con, 'load_time_configs') env.assertEqual(load_time_config["ai_inter_op_parallelism"], "2") env.assertEqual(load_time_config["ai_intra_op_parallelism"], "2") @@ -777,12 +775,11 @@ def test_torch_info(env): return con = env.getConnection() - ret = con.execute_command('AI.INFO') - env.assertEqual(6, len(ret)) + backends_info = get_info_section(con, 'backends_info') + env.assertFalse('ai_Torch_version' in backends_info) model_pb = load_file_content('pt-minimal-bb.pt') 
ret = con.execute_command('AI.MODELSTORE', 'm{1}', 'TORCH', DEVICE, 'BLOB', model_pb) - ret = con.execute_command('AI.INFO') - env.assertEqual(8, len(ret)) - env.assertEqual(b'Torch version', ret[6]) + backends_info = get_info_section(con, 'backends_info') + env.assertTrue('ai_Torch_version' in backends_info) diff --git a/tests/flow/tests_tensorflow.py b/tests/flow/tests_tensorflow.py index d9d6581bd..ed3d268d4 100644 --- a/tests/flow/tests_tensorflow.py +++ b/tests/flow/tests_tensorflow.py @@ -691,14 +691,11 @@ def test_tensorflow_modelexecute_script_execute_resnet(env): def test_tf_info(env): con = env.getConnection() - ret = con.execute_command('AI.INFO') - env.assertEqual(6, len(ret)) - + backends_info = get_info_section(con, 'backends_info') + env.assertFalse('ai_TensorFlow_version' in backends_info) model_pb = load_file_content('graph.pb') - con.execute_command('AI.MODELSTORE', 'm{1}', 'TF', DEVICE, 'INPUTS', 2, 'a', 'b', 'OUTPUTS', 1, 'mul', 'BLOB', model_pb) - - ret = con.execute_command('AI.INFO') - env.assertEqual(8, len(ret)) - env.assertEqual(b'TensorFlow version', ret[6]) + + backends_info = get_info_section(con, 'backends_info') + env.assertTrue('ai_TensorFlow_version' in backends_info) diff --git a/tests/flow/tests_tflite.py b/tests/flow/tests_tflite.py index 9db8be2e9..67a78ab7a 100644 --- a/tests/flow/tests_tflite.py +++ b/tests/flow/tests_tflite.py @@ -191,13 +191,11 @@ def test_tflite_info(env): return con = env.getConnection() - ret = con.execute_command('AI.INFO') - env.assertEqual(6, len(ret)) + backends_info = get_info_section(con, 'backends_info') + env.assertFalse('ai_TensorFlowLite_version' in backends_info) model_pb = load_file_content('mnist_model_quant.tflite') - con.execute_command('AI.MODELSTORE', 'mnist{1}', 'TFLITE', 'CPU', 'BLOB', model_pb) - ret = con.execute_command('AI.INFO') - env.assertEqual(8, len(ret)) - env.assertEqual(b'TFLite version', ret[6]) + backends_info = get_info_section(con, 'backends_info') + 
env.assertTrue('ai_TensorFlowLite_version' in backends_info) From 349653c98cc87a537ec041682401b07a5e4749e4 Mon Sep 17 00:00:00 2001 From: alonre24 Date: Sun, 13 Jun 2021 18:27:13 +0300 Subject: [PATCH 23/27] Fix tests for the case that we run on GPU - since CPU queue always created from start, number of threads in this case should be larger. --- opt/readies | 2 +- tests/flow/tests_onnx.py | 17 +++++++++++++---- 2 files changed, 14 insertions(+), 5 deletions(-) diff --git a/opt/readies b/opt/readies index 75459c614..951975156 160000 --- a/opt/readies +++ b/opt/readies @@ -1 +1 @@ -Subproject commit 75459c6142ac01ff82fa7b4646d9d574d177fa3d +Subproject commit 9519751566c8a335221265f2cdfed915edea954f diff --git a/tests/flow/tests_onnx.py b/tests/flow/tests_onnx.py index dbd748d56..34b5ad60d 100644 --- a/tests/flow/tests_onnx.py +++ b/tests/flow/tests_onnx.py @@ -455,8 +455,8 @@ def __init__(self): self.env.assertEqual(ret, b'OK') # Set tensors according to the model inputs. This model consists of two operations to type 'Identity' - # (i.e., just output the input), where the second op is wrapped with another op of type 'Loop'. Overall, this model - # runs a very large number of iterations without doing anything, until it is caught with the kill switch. + # (i.e., just output the input), where the second op is wrapped with another op of type 'Loop'. Overall, this + # model runs a very large number of iterations without doing anything, until it is caught with the kill switch. 
con.execute_command('AI.TENSORSET', 'iterations{1}', 'INT64', 1, 'VALUES', 9223372036854775807) con.execute_command('AI.TENSORSET', 'loop_cond{1}', 'BOOL', 1, 'VALUES', 1) con.execute_command('AI.TENSORSET', 'loop_input{1}', 'FLOAT', 1, 'VALUES', 42) @@ -498,7 +498,13 @@ def run_parallel_onnx_sessions(con): def test_multiple_devices(self): con = self.env.getConnection() backends_info = get_info_section(con, 'backends_info') - self.env.assertEqual(backends_info['ai_onnxruntime_maximum_run_sessions_number'], '8') + + # CPU run queue is created from the start, so if we used a device different than CPU, we should + # have maximum of 2*THREADS_PER_QUEUE run sessions, and otherwise we should have THREADS_PER_QUEUE. + if DEVICE == 'CPU': + self.env.assertEqual(backends_info['ai_onnxruntime_maximum_run_sessions_number'], '8') + else: + self.env.assertEqual(backends_info['ai_onnxruntime_maximum_run_sessions_number'], '16') # Load another onnx model as if it runs on a different device (to test existence of multiple queues, and # the extension of the global onnx run sessions array as a consequence.) 
@@ -506,4 +512,7 @@ def test_multiple_devices(self): ret = con.execute_command('AI.MODELSTORE', 'mnist_{1}', 'ONNX', 'CPU:1', 'BLOB', model_pb) self.env.assertEqual(ret, b'OK') backends_info = get_info_section(con, 'backends_info') - self.env.assertEqual(backends_info['ai_onnxruntime_maximum_run_sessions_number'], '16') + if DEVICE == 'CPU': + self.env.assertEqual(backends_info['ai_onnxruntime_maximum_run_sessions_number'], '16') + else: + self.env.assertEqual(backends_info['ai_onnxruntime_maximum_run_sessions_number'], '24') From af423a7986ecf4aea17d3f007e58e19c3bf93cb9 Mon Sep 17 00:00:00 2001 From: alonre24 Date: Sun, 13 Jun 2021 18:48:20 +0300 Subject: [PATCH 24/27] Update readies --- opt/readies | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/opt/readies b/opt/readies index 951975156..75459c614 160000 --- a/opt/readies +++ b/opt/readies @@ -1 +1 @@ -Subproject commit 9519751566c8a335221265f2cdfed915edea954f +Subproject commit 75459c6142ac01ff82fa7b4646d9d574d177fa3d From fd8c67265219f0f65b81503b9516cbe48e06ea1a Mon Sep 17 00:00:00 2001 From: alonre24 Date: Mon, 14 Jun 2021 11:14:00 +0300 Subject: [PATCH 25/27] PR fixes --- src/backends/backedns_api.h | 18 ++++++++++++++++++ src/backends/onnx_timeout.c | 14 +++++++------- src/backends/onnx_timeout.h | 18 +++++++++++++++--- src/backends/onnxruntime.c | 4 ++-- tests/flow/tests_common.py | 1 - tests/flow/tests_onnx.py | 21 ++++++++++----------- 6 files changed, 52 insertions(+), 24 deletions(-) diff --git a/src/backends/backedns_api.h b/src/backends/backedns_api.h index 13ed409b7..b0dbcff09 100644 --- a/src/backends/backedns_api.h +++ b/src/backends/backedns_api.h @@ -2,10 +2,28 @@ #include +/** + * @return The internal id of RedisAI current working thread. + * id range is {0, ..., -1} + */ uintptr_t (*RedisAI_GetThreadId)(void); +/** + * @return The number of working threads in RedisAI. 
This number should be + * equal to the number of threads per queue (load time config) * number of devices + * registered in RedisAI (a new device is registered if a model is set to run on + * this device in AI.MODELSTORE command). + */ uintptr_t (*RedisAI_GetThreadsCount)(void); +/** + * @return The number of working threads per device queue (load time config). + */ long long (*RedisAI_GetNumThreadsPerQueue)(void); +/** + * @return The maximal number of milliseconds that a model run session should run + * before it is terminated forcefully (load time config). + * Currently supported only for onnxruntime backend. + */ long long (*RedisAI_GetModelExecutionTimeout)(void); diff --git a/src/backends/onnx_timeout.c b/src/backends/onnx_timeout.c index 2c70cbdd9..bf677275f 100644 --- a/src/backends/onnx_timeout.c +++ b/src/backends/onnx_timeout.c @@ -94,18 +94,18 @@ void RAI_SetRunSessionCtxORT(OrtRunOptions *new_run_options, size_t *run_session pthread_rwlock_unlock(&(onnx_global_run_sessions->rwlock)); } -void RAI_InvalidateRunSessionCtxORT(size_t run_session_index) { +void RAI_ResetRunSessionCtxORT(size_t run_session_index) { const OrtApi *ort = OrtGetApiBase()->GetApi(1); pthread_rwlock_rdlock(&(onnx_global_run_sessions->rwlock)); OnnxRunSessionCtx *entry = onnx_global_run_sessions->OnnxRunSessions[run_session_index]; + // Busy wait until we get a valid state, as we might access this entry from // the main thread callback and call ONNX API to terminate the run session. 
- while (true) { - runSessionState state = __atomic_load_n(entry->runState, __ATOMIC_RELAXED); - if (state == RUN_SESSION_ACTIVE || state == RUN_SESSION_TERMINATED) { - break; - } - } + RunSessionState state; + do { + state = __atomic_load_n(entry->runState, __ATOMIC_RELAXED); + } while (state != RUN_SESSION_ACTIVE && state != RUN_SESSION_TERMINATED); + ort->ReleaseRunOptions(entry->runOptions); __atomic_store_n(entry->runState, RUN_SESSION_AVAILABLE, __ATOMIC_RELAXED); pthread_rwlock_unlock(&(onnx_global_run_sessions->rwlock)); diff --git a/src/backends/onnx_timeout.h b/src/backends/onnx_timeout.h index d57ea448e..a8d934b13 100644 --- a/src/backends/onnx_timeout.h +++ b/src/backends/onnx_timeout.h @@ -3,17 +3,29 @@ #include "backends/onnxruntime.h" #include "onnxruntime_c_api.h" +/** + * The possible states for every run session entry in the array (entry per BG thread): + * Every entry is initialized as AVAILABLE, which means that it is ready to get a new run session. + * BG thread can perform a transition from AVAILABLE to ACTIVE upon starting a new run session. + * In the cron callback, Redis main thread can perform a transition from ACTIVE to + * INVALID if a timeout has been reached, set the run session as terminated, and then make + * another transition to TERMINATED. + * At the end of a run session, the state is ACTIVE/TERMINATED, and then the BG thread + * resets the entry and makes a transition back to AVAILABLE. + * Transitions are done atomically to ensure correct synchronization (BG thread cannot reset + * run session while main thread is setting it as terminated). + */ typedef enum { RUN_SESSION_AVAILABLE, RUN_SESSION_ACTIVE, RUN_SESSION_TERMINATED, RUN_SESSION_INVALID -} runSessionState; +} RunSessionState; typedef struct OnnxRunSessionCtx { long long queuingTime; OrtRunOptions *runOptions; - runSessionState *runState; + RunSessionState *runState; } OnnxRunSessionCtx; // This is a global array of OnnxRunSessionCtx. 
Contains an entry for every thread @@ -68,4 +80,4 @@ void RAI_SetRunSessionCtxORT(OrtRunOptions *new_run_options, size_t *run_session * reset the corresponding entry in the global structure. * @param run_session_index - The entry index where OrtRunOptions was stored. */ -void RAI_InvalidateRunSessionCtxORT(size_t run_session_index); +void RAI_ResetRunSessionCtxORT(size_t run_session_index); diff --git a/src/backends/onnxruntime.c b/src/backends/onnxruntime.c index e2ebdac26..eaad2598e 100644 --- a/src/backends/onnxruntime.c +++ b/src/backends/onnxruntime.c @@ -586,7 +586,7 @@ int RAI_ModelRunORT(RAI_Model *model, RAI_ExecutionCtx **ectxs, RAI_Error *error ONNX_VALIDATE_STATUS(ort->Run(session, run_options, input_names, (const OrtValue *const *)inputs, n_input_nodes, output_names, n_output_nodes, outputs)); - RAI_InvalidateRunSessionCtxORT(run_session_index); + RAI_ResetRunSessionCtxORT(run_session_index); run_options = NULL; for (uint32_t i = 0; i < ninputs; i++) { @@ -674,7 +674,7 @@ int RAI_ModelRunORT(RAI_Model *model, RAI_ExecutionCtx **ectxs, RAI_Error *error ort->ReleaseTensorTypeAndShapeInfo(info); } if (run_options) { - RAI_InvalidateRunSessionCtxORT(run_session_index); + RAI_ResetRunSessionCtxORT(run_session_index); } return REDISMODULE_ERR; } diff --git a/tests/flow/tests_common.py b/tests/flow/tests_common.py index 52d79b7c3..892bec11d 100644 --- a/tests/flow/tests_common.py +++ b/tests/flow/tests_common.py @@ -359,4 +359,3 @@ def test_lua_multi(env): exception = e env.assertEqual(type(exception), redis.exceptions.ResponseError) env.assertEqual("Cannot run RedisAI command within a transaction or a LUA script", exception.__str__()) - diff --git a/tests/flow/tests_onnx.py b/tests/flow/tests_onnx.py index 34b5ad60d..027b1cae1 100644 --- a/tests/flow/tests_onnx.py +++ b/tests/flow/tests_onnx.py @@ -448,7 +448,8 @@ def test_onnx_use_custom_allocator_with_GPU(env): class TestOnnxKillSwitch: def __init__(self): - self.env = Env(moduleArgs='THREADS_PER_QUEUE 8 
MODEL_EXECUTION_TIMEOUT 1000') + self.threads_per_queue = 8 + self.env = Env(moduleArgs='THREADS_PER_QUEUE '+str(self.threads_per_queue)+' MODEL_EXECUTION_TIMEOUT 1000') con = self.env.getConnection() model_with_inf_loop = load_file_content("model_with_infinite_loop.onnx") ret = con.execute_command('AI.MODELSTORE', 'inf_loop_model{1}', 'ONNX', DEVICE, 'BLOB', model_with_inf_loop) @@ -497,22 +498,20 @@ def run_parallel_onnx_sessions(con): def test_multiple_devices(self): con = self.env.getConnection() - backends_info = get_info_section(con, 'backends_info') - # CPU run queue is created from the start, so if we used a device different than CPU, we should # have maximum of 2*THREADS_PER_QUEUE run sessions, and otherwise we should have THREADS_PER_QUEUE. - if DEVICE == 'CPU': - self.env.assertEqual(backends_info['ai_onnxruntime_maximum_run_sessions_number'], '8') - else: - self.env.assertEqual(backends_info['ai_onnxruntime_maximum_run_sessions_number'], '16') + devices = {'CPU', DEVICE} + backends_info = get_info_section(con, 'backends_info') + self.env.assertEqual(backends_info['ai_onnxruntime_maximum_run_sessions_number'], + str(len(devices)*self.threads_per_queue)) # Load another onnx model as if it runs on a different device (to test existence of multiple queues, and # the extension of the global onnx run sessions array as a consequence.) 
model_pb = load_file_content('mnist.onnx') ret = con.execute_command('AI.MODELSTORE', 'mnist_{1}', 'ONNX', 'CPU:1', 'BLOB', model_pb) self.env.assertEqual(ret, b'OK') + devices.add('CPU:1') + backends_info = get_info_section(con, 'backends_info') - if DEVICE == 'CPU': - self.env.assertEqual(backends_info['ai_onnxruntime_maximum_run_sessions_number'], '16') - else: - self.env.assertEqual(backends_info['ai_onnxruntime_maximum_run_sessions_number'], '24') + self.env.assertEqual(backends_info['ai_onnxruntime_maximum_run_sessions_number'], + str(len(devices)*self.threads_per_queue)) From 78da23ea536ce424d69a8158876f1949ff495bf6 Mon Sep 17 00:00:00 2001 From: alonre24 Date: Mon, 14 Jun 2021 15:48:19 +0300 Subject: [PATCH 26/27] Return error if onnx is executed in a non async manner (via gears for instance). test info command with AI fields --- src/backends/backedns_api.h | 5 ++- src/backends/onnx_timeout.c | 11 +++-- src/backends/onnx_timeout.h | 4 +- src/backends/onnxruntime.c | 13 +++++- src/execution/background_workers.c | 12 +++++- src/execution/background_workers.h | 5 ++- tests/flow/tests_common.py | 29 +++++++++----- tests/flow/tests_gears_llapi.py | 64 ++++++++++++++++++++++++++++++ 8 files changed, 119 insertions(+), 24 deletions(-) diff --git a/src/backends/backedns_api.h b/src/backends/backedns_api.h index b0dbcff09..875fed7d5 100644 --- a/src/backends/backedns_api.h +++ b/src/backends/backedns_api.h @@ -4,9 +4,10 @@ /** * @return The internal id of RedisAI current working thread. - * id range is {0, ..., -1} + * id range is {0, ..., -1}. If this is called from a non + * RedisAI BG thread, return -1. */ -uintptr_t (*RedisAI_GetThreadId)(void); +long (*RedisAI_GetThreadId)(void); /** * @return The number of working threads in RedisAI. 
This number should be diff --git a/src/backends/onnx_timeout.c b/src/backends/onnx_timeout.c index bf677275f..6e3a715ad 100644 --- a/src/backends/onnx_timeout.c +++ b/src/backends/onnx_timeout.c @@ -79,11 +79,16 @@ void RAI_EnforceTimeoutORT(RedisModuleCtx *ctx, RedisModuleEvent eid, uint64_t s pthread_rwlock_unlock(&(onnx_global_run_sessions->rwlock)); } -void RAI_SetRunSessionCtxORT(OrtRunOptions *new_run_options, size_t *run_session_index) { +void RAI_ActivateRunSessionCtxORT(OrtRunOptions *new_run_options, long *run_session_index) { pthread_rwlock_rdlock(&(onnx_global_run_sessions->rwlock)); - // Get the thread index (which is the correspondent index in the global sessions array). + // Get the thread id (which is the corresponding index in the global sessions array + 1). + // If the thread id is -1, we are not running from a RedisAI thread (not allowed). *run_session_index = RedisAI_GetThreadId(); + if (*run_session_index == -1) { + pthread_rwlock_unlock(&(onnx_global_run_sessions->rwlock)); + return; + } OnnxRunSessionCtx *entry = onnx_global_run_sessions->OnnxRunSessions[*run_session_index]; RedisModule_Assert(*entry->runState == RUN_SESSION_AVAILABLE); @@ -94,7 +99,7 @@ void RAI_SetRunSessionCtxORT(OrtRunOptions *new_run_options, size_t *run_session pthread_rwlock_unlock(&(onnx_global_run_sessions->rwlock)); } -void RAI_ResetRunSessionCtxORT(size_t run_session_index) { +void RAI_ResetRunSessionCtxORT(long run_session_index) { const OrtApi *ort = OrtGetApiBase()->GetApi(1); pthread_rwlock_rdlock(&(onnx_global_run_sessions->rwlock)); OnnxRunSessionCtx *entry = onnx_global_run_sessions->OnnxRunSessions[run_session_index]; diff --git a/src/backends/onnx_timeout.h b/src/backends/onnx_timeout.h index a8d934b13..63935fe83 100644 --- a/src/backends/onnx_timeout.h +++ b/src/backends/onnx_timeout.h @@ -73,11 +73,11 @@ void RAI_EnforceTimeoutORT(RedisModuleCtx *ctx, RedisModuleEvent eid, uint64_t s * @param run_session_index - placeholder for the index of the running thread * 
in the global array, to have a quick access later to clean this entry. */ -void RAI_SetRunSessionCtxORT(OrtRunOptions *new_run_options, size_t *run_session_index); +void RAI_ActivateRunSessionCtxORT(OrtRunOptions *new_run_options, long *run_session_index); /** * @brief Release the OrtRunOptions of a session that finished its run and * reset the corresponding entry in the global structure. * @param run_session_index - The entry index where OrtRunOptions was stored. */ -void RAI_ResetRunSessionCtxORT(size_t run_session_index); +void RAI_ResetRunSessionCtxORT(long run_session_index); diff --git a/src/backends/onnxruntime.c b/src/backends/onnxruntime.c index eaad2598e..c6aef1310 100644 --- a/src/backends/onnxruntime.c +++ b/src/backends/onnxruntime.c @@ -532,7 +532,7 @@ int RAI_ModelRunORT(RAI_Model *model, RAI_ExecutionCtx **ectxs, RAI_Error *error array_new_on_stack(OrtValue *, 5, inputs); array_new_on_stack(OrtValue *, 5, outputs); OrtRunOptions *run_options = NULL; - size_t run_session_index; + long run_session_index; OrtTensorTypeAndShapeInfo *info = NULL; { size_t n_input_nodes; @@ -582,7 +582,16 @@ int RAI_ModelRunORT(RAI_Model *model, RAI_ExecutionCtx **ectxs, RAI_Error *error ONNX_VALIDATE_STATUS(ort->CreateRunOptions(&run_options)); // Set the created run option in the global RunSessions and save its index. 
- RAI_SetRunSessionCtxORT(run_options, &run_session_index); + RAI_ActivateRunSessionCtxORT(run_options, &run_session_index); + if (run_session_index == -1) { + RAI_SetError( + error, RAI_EMODELRUN, + "Cannot execute onnxruntime model synchronously, use async execution instead"); + ort->ReleaseRunOptions(run_options); + run_options = NULL; + goto error; + } + ONNX_VALIDATE_STATUS(ort->Run(session, run_options, input_names, (const OrtValue *const *)inputs, n_input_nodes, output_names, n_output_nodes, outputs)); diff --git a/src/execution/background_workers.c b/src/execution/background_workers.c index f318027f4..ce059212b 100644 --- a/src/execution/background_workers.c +++ b/src/execution/background_workers.c @@ -29,8 +29,9 @@ pthread_key_t ThreadIdKey; // Key to hold thread id in its local storage. */ static void _BGWorker_SaveThreadId() { // Let the current thread have the next available id, and increase the counter. - uintptr_t id_value = __atomic_fetch_add(&BGWorkersCounter, 1, __ATOMIC_RELAXED); + long id_value = __atomic_add_fetch(&BGWorkersCounter, 1, __ATOMIC_RELAXED); // Convert the id value to a pointer and store it the thread local storage. + // First id is 1, so we won't confuse with NULL (which is the error return value) pthread_setspecific(ThreadIdKey, (const void *)id_value); } @@ -252,7 +253,14 @@ static bool _BGThread_PrepareExecution(RunQueueInfo *run_queue_info, RedisAI_Run return true; } -uintptr_t BGWorker_GetThreadId() { return (uintptr_t)pthread_getspecific(ThreadIdKey); } +long BGWorker_GetThreadId() { + void *thread_id = pthread_getspecific(ThreadIdKey); + if (thread_id == NULL) { + return -1; + } + // Return the as 0 based id. 
+ return (long)pthread_getspecific(ThreadIdKey) - 1; +} uintptr_t BGWorker_GetThreadsCount() { return BGWorkersCounter; } diff --git a/src/execution/background_workers.h b/src/execution/background_workers.h index f3dbf4ce6..b6e8eacb6 100644 --- a/src/execution/background_workers.h +++ b/src/execution/background_workers.h @@ -37,9 +37,10 @@ void *BGWorker_ThreadMain(void *arg); /** - * @brief Returns the thread id (among RedisAI working threads). + * @brief Returns the thread id (among RedisAI working threads). If this is called + * form a non RedisAI working thread, return -1 */ -uintptr_t BGWorker_GetThreadId(void); +long BGWorker_GetThreadId(void); /** * @brief Returns the total number of RedisAI working threads (for all devices). diff --git a/tests/flow/tests_common.py b/tests/flow/tests_common.py index 892bec11d..e1a7b6481 100644 --- a/tests/flow/tests_common.py +++ b/tests/flow/tests_common.py @@ -328,17 +328,6 @@ def test_tensorget_disconnect(env): ret = send_and_disconnect(('AI.TENSORGET', 't_FLOAT', 'META'), red) env.assertEqual(ret, None) -def test_info_modules(env): - red = env.getConnection() - ret = red.execute_command('INFO','MODULES') - env.assertEqual( ret['ai_threads_per_queue'], 1 ) - # minimum cpu properties - env.assertEqual( 'ai_self_used_cpu_sys' in ret, True ) - env.assertEqual( 'ai_self_used_cpu_user' in ret, True ) - env.assertEqual( 'ai_children_used_cpu_sys' in ret, True ) - env.assertEqual( 'ai_children_used_cpu_user' in ret, True ) - env.assertEqual( 'ai_queue_CPU_bthread_n1_used_cpu_total' in ret, True ) - def test_lua_multi(env): con = env.getConnection() @@ -359,3 +348,21 @@ def test_lua_multi(env): exception = e env.assertEqual(type(exception), redis.exceptions.ResponseError) env.assertEqual("Cannot run RedisAI command within a transaction or a LUA script", exception.__str__()) + + +def test_info_command(env): + con = env.getConnection() + versions = get_info_section(con, 'versions') + env.assertEqual(list(versions.keys()), 
['ai_RedisAI_version', 'ai_low_level_API_version', 'ai_rdb_version']) + git = get_info_section(con, 'git') + env.assertEqual(list(git.keys()), ['ai_git_sha']) + load_time_configs = get_info_section(con, 'load_time_configs') + env.assertEqual(list(load_time_configs.keys()), ['ai_threads_per_queue', 'ai_inter_op_parallelism', + 'ai_intra_op_parallelism', 'ai_model_execution_timeout']) + # minimum cpu properties + cpu = get_info_section(con, 'cpu') + env.assertTrue('ai_self_used_cpu_sys' in cpu.keys()) + env.assertTrue('ai_self_used_cpu_user' in cpu.keys()) + env.assertTrue('ai_children_used_cpu_sys' in cpu.keys()) + env.assertTrue('ai_children_used_cpu_user' in cpu.keys()) + env.assertTrue('ai_queue_CPU_bthread_n1_used_cpu_total' in cpu.keys()) diff --git a/tests/flow/tests_gears_llapi.py b/tests/flow/tests_gears_llapi.py index fd08dab59..7db9c9bc0 100644 --- a/tests/flow/tests_gears_llapi.py +++ b/tests/flow/tests_gears_llapi.py @@ -421,3 +421,67 @@ def FlattenTensor(record): env.assertEqual(ret, b'OK') ret = con.execute_command('rg.trigger', 'FlattenTensor_test') env.assertEqual(ret[0], b'test_OK') + + +class TestExecuteOnnxModel: + + def __init__(self): + self.env = Env() + if not verify_gears_loaded(self.env): + self.env.skip() + return + script = ''' + +import redisAI + +def OnnxModelRunSync(record): + input_tensor = redisAI.getTensorFromKey('mnist_input{1}') + modelRunner = redisAI.createModelRunner('mnist{1}') + redisAI.modelRunnerAddInput(modelRunner, 'input_name', input_tensor) + redisAI.modelRunnerAddOutput(modelRunner, 'output_name') + try: + res = redisAI.modelRunnerRun(modelRunner) + except Exception as e: + raise e + +async def OnnxModelRunAsync(record): + input_tensor = redisAI.getTensorFromKey('mnist_input{1}') + modelRunner = redisAI.createModelRunner('mnist{1}') + redisAI.modelRunnerAddInput(modelRunner, 'input_name', input_tensor) + redisAI.modelRunnerAddOutput(modelRunner, 'output_name') + res = await redisAI.modelRunnerRunAsync(modelRunner) + 
redisAI.setTensorInKey('mnist_output{1}', res[0]) + return "OnnxModelRun_OK" + +GB("CommandReader").map(OnnxModelRunSync).register(trigger="OnnxModelRunSync_test1") +GB("CommandReader").map(OnnxModelRunAsync).register(trigger="OnnxModelRunAsync_test2") + ''' + + con = self.env.getConnection() + ret = con.execute_command('rg.pyexecute', script) + self.env.assertEqual(ret, b'OK') + + # Load onnx model and its input. + model_pb = load_file_content('mnist.onnx') + sample_raw = load_file_content('one.raw') + ret = con.execute_command('AI.MODELSTORE', 'mnist{1}', 'ONNX', DEVICE, 'BLOB', model_pb) + self.env.assertEqual(ret, b'OK') + con.execute_command('AI.TENSORSET', 'mnist_input{1}', 'FLOAT', 1, 1, 28, 28, 'BLOB', sample_raw) + + def test_sync_run_error(self): + con = self.env.getConnection() + try: + con.execute_command('rg.trigger', 'OnnxModelRunSync_test1') + self.env.assertFalse(True) + except Exception as exception: + self.env.assertEqual(type(exception), redis.exceptions.ResponseError) + self.env.assertTrue(str(exception).find("Cannot execute onnxruntime model synchronously, " + "use async execution instead") >= 0) + + def test_async_run(self): + con = self.env.getConnection() + ret = con.execute_command('rg.trigger', 'OnnxModelRunAsync_test2') + self.env.assertEqual(ret[0], b'OnnxModelRun_OK') + values = con.execute_command('AI.TENSORGET', 'mnist_output{1}', 'VALUES') + argmax = max(range(len(values)), key=lambda i: values[i]) + self.env.assertEqual(argmax, 1) From 285b5bee2049b6ded47efadaa061092218063ec8 Mon Sep 17 00:00:00 2001 From: alonre24 Date: Tue, 15 Jun 2021 18:13:07 +0300 Subject: [PATCH 27/27] Small refactor in get_thread_id function. 
--- src/execution/background_workers.c | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/src/execution/background_workers.c b/src/execution/background_workers.c index ce059212b..1c7c9d3e1 100644 --- a/src/execution/background_workers.c +++ b/src/execution/background_workers.c @@ -255,11 +255,10 @@ static bool _BGThread_PrepareExecution(RunQueueInfo *run_queue_info, RedisAI_Run long BGWorker_GetThreadId() { void *thread_id = pthread_getspecific(ThreadIdKey); - if (thread_id == NULL) { - return -1; - } - // Return the as 0 based id. - return (long)pthread_getspecific(ThreadIdKey) - 1; + + // Return the 0-based id; if thread_id was NULL, we return -1 to indicate that + // the caller is not a RedisAI thread. + return (long)(thread_id)-1; } uintptr_t BGWorker_GetThreadsCount() { return BGWorkersCounter; }