diff --git a/src/DAG/dag.c b/src/DAG/dag.c index ad6f84d95..0205f9c8d 100644 --- a/src/DAG/dag.c +++ b/src/DAG/dag.c @@ -55,16 +55,7 @@ static void Dag_LoadInputsToModelRunCtx(RedisAI_RunInfo *rinfo, RAI_DagOp *curre RAI_Tensor *inputTensors[n_inkeys]; for (uint i = 0; i < n_inkeys; i++) { - RAI_Tensor *inputTensor; - const int get_result = RAI_getTensorFromLocalContext( - rinfo->dagTensorsContext, currentOp->inkeys[i], &inputTensor, currentOp->err); - if (get_result == REDISMODULE_ERR) { - // We check for this outside the function - // this check cannot be covered by tests - currentOp->result = REDISMODULE_ERR; - RAI_ContextUnlock(rinfo); - return; - } + RAI_Tensor *inputTensor = Dag_GetTensorFromGlobalCtx(rinfo, currentOp->inkeys_indices[i]); inputTensors[i] = inputTensor; } @@ -95,78 +86,58 @@ static void Dag_StoreOutputsFromModelRunCtx(RedisAI_RunInfo *rinfo, RAI_DagOp *c const size_t noutputs = RAI_ModelRunCtxNumOutputs(currentOp->mctx); for (size_t outputNumber = 0; outputNumber < noutputs; outputNumber++) { RAI_Tensor *tensor = RAI_ModelRunCtxOutputTensor(currentOp->mctx, outputNumber); - tensor = tensor ? RAI_TensorGetShallowCopy(tensor) : NULL; - AI_dictReplace(rinfo->dagTensorsContext, (void *)currentOp->outkeys[outputNumber], tensor); + Dag_SetTensorInGlobalCtx(rinfo, currentOp->outkeys_indices[outputNumber], tensor); } RAI_ContextUnlock(rinfo); } static int _StoreTensorInKeySpace(RedisModuleCtx *ctx, RAI_Tensor *tensor, - RedisModuleString *persist_key_name, bool mangled_name) { + RedisModuleString *persist_key_name, RAI_Error *err) { - int ret = REDISMODULE_ERR; RedisModuleKey *key; - size_t persist_key_len; - const char *persist_key_str = RedisModule_StringPtrLen(persist_key_name, &persist_key_len); - - RedisModuleString *demangled_key_name; - if (mangled_name) { - demangled_key_name = RedisModule_CreateString(NULL, persist_key_str, persist_key_len - 4); - } else { - demangled_key_name = RedisModule_CreateString(NULL, persist_key_str, persist_key_len); - } - const int status = - RAI_OpenKey_Tensor(ctx, demangled_key_name, &key, REDISMODULE_READ | REDISMODULE_WRITE); + RAI_OpenKey_Tensor(ctx, persist_key_name, &key, REDISMODULE_READ | REDISMODULE_WRITE, err); if (status == REDISMODULE_ERR) { - RedisModule_ReplyWithError(ctx, "ERR could not save tensor"); - goto clean_up; + return REDISMODULE_ERR; } if (RedisModule_ModuleTypeSetValue(key, RedisAI_TensorType, tensor) != REDISMODULE_OK) { - RedisModule_ReplyWithError(ctx, "ERR could not save tensor"); + RAI_SetError(err, RAI_EDAGRUN, "ERR could not save tensor"); RedisModule_CloseKey(key); - goto clean_up; + return REDISMODULE_ERR; } // Only if we got until here, tensor is saved in keyspace. - RedisAI_ReplicateTensorSet(ctx, demangled_key_name, tensor); + RedisAI_ReplicateTensorSet(ctx, persist_key_name, tensor); RedisModule_CloseKey(key); - ret = REDISMODULE_OK; - -clean_up: - RedisModule_FreeString(NULL, demangled_key_name); - return ret; + return REDISMODULE_OK; } -static void _DAG_PersistTensors(RedisModuleCtx *ctx, RedisAI_RunInfo *rinfo) { +static int _DAG_PersistTensors(RedisModuleCtx *ctx, RedisAI_RunInfo *rinfo) { - AI_dictIterator *persist_iter = AI_dictGetSafeIterator(rinfo->dagTensorsPersistedContext); - AI_dictEntry *persist_entry = AI_dictNext(persist_iter); + AI_dictIterator *persist_iter = AI_dictGetSafeIterator(rinfo->persistTensors); + AI_dictEntry *persist_entry; - while (persist_entry) { + while ((persist_entry = AI_dictNext(persist_iter))) { RedisModuleString *persist_key_name = AI_dictGetKey(persist_entry); - AI_dictEntry *tensor_entry = AI_dictFind(rinfo->dagTensorsContext, persist_key_name); - RedisModule_Assert(tensor_entry); - RAI_Tensor *tensor = AI_dictGetVal(tensor_entry); - if (tensor == NULL) { - persist_entry = AI_dictNext(persist_iter); - continue; - } + size_t index = (size_t)AI_dictGetVal(persist_entry); + RAI_Tensor *tensor = Dag_GetTensorFromGlobalCtx(rinfo, index); tensor = RAI_TensorGetShallowCopy(tensor); - if (_StoreTensorInKeySpace(ctx, tensor, persist_key_name, true) == REDISMODULE_ERR) { + + if (_StoreTensorInKeySpace(ctx, tensor, persist_key_name, rinfo->err) == REDISMODULE_ERR) { *rinfo->dagError = 1; RedisModule_Log(ctx, "warning", "Could not persist tensor under the key (%s) after executing DAGRUN " "command, persist stopped", RedisModule_StringPtrLen(persist_key_name, NULL)); AI_dictReleaseIterator(persist_iter); - return; + rinfo->dagReplyLength++; + return REDISMODULE_ERR; } - persist_entry = AI_dictNext(persist_iter); } AI_dictReleaseIterator(persist_iter); + return REDISMODULE_OK; } -static void _ModelSingleOp_PersistTensors(RedisModuleCtx *ctx, RAI_DagOp *op) { +static int _ModelSingleOp_PersistTensors(RedisModuleCtx *ctx, RAI_DagOp *op, RAI_Error *err) { const size_t noutputs = RAI_ModelRunCtxNumOutputs(op->mctx); for (size_t outputNumber = 0; outputNumber < noutputs; outputNumber++) { @@ -176,18 +147,19 @@ static void _ModelSingleOp_PersistTensors(RedisModuleCtx *ctx, RAI_DagOp *op) { if (!tensor) continue; - if (_StoreTensorInKeySpace(ctx, tensor, persist_key_name, false) == REDISMODULE_ERR) { + if (_StoreTensorInKeySpace(ctx, tensor, persist_key_name, err) == REDISMODULE_ERR) { RedisModule_Log(ctx, "warning", "Could not persist tensor under the key (%s) after executing DAGRUN " "command, persist stopped", RedisModule_StringPtrLen(persist_key_name, NULL)); op->result = REDISMODULE_ERR; - return; + return REDISMODULE_ERR; } } + return REDISMODULE_OK; } -static void _ScriptSingleOp_PersistTensors(RedisModuleCtx *ctx, RAI_DagOp *op) { +static int _ScriptSingleOp_PersistTensors(RedisModuleCtx *ctx, RAI_DagOp *op, RAI_Error *err) { const size_t noutputs = RAI_ScriptRunCtxNumOutputs(op->sctx); for (size_t outputNumber = 0; outputNumber < noutputs; outputNumber++) { @@ -197,15 +169,16 @@ static void _ScriptSingleOp_PersistTensors(RedisModuleCtx *ctx, RAI_DagOp *op) { if (!tensor) continue; - if (_StoreTensorInKeySpace(ctx, tensor, persist_key_name, false) == REDISMODULE_ERR) { + if (_StoreTensorInKeySpace(ctx, tensor, persist_key_name, err) == REDISMODULE_ERR) { RedisModule_Log(ctx, "warning", "Could not persist tensor under the key (%s) after executing DAGRUN " "command, persist stopped", RedisModule_StringPtrLen(persist_key_name, NULL)); op->result = REDISMODULE_ERR; - return; + return REDISMODULE_ERR; } } + return REDISMODULE_OK; } /** @@ -272,15 +245,13 @@ void RedisAI_BatchedDagRunSession_ModelRun_Step(RedisAI_RunInfo **batched_rinfo, for (int i = 0; i < n_rinfo; i++) { RedisAI_RunInfo *rinfo = batched_rinfo[i]; RAI_DagOp *currentOp = currentOps[i]; + currentOp->duration_us = duration; + currentOp->result = result; if (result == REDISMODULE_ERR) { - currentOp->result = result; RAI_SetError(currentOp->err, err.code, err.detail); continue; } - - currentOp->duration_us = duration; - currentOp->result = result; if (rinfo->single_op_dag == 0) Dag_StoreOutputsFromModelRunCtx(rinfo, currentOp); } @@ -302,24 +273,14 @@ void RedisAI_DagRunSession_ScriptRun_Step(RedisAI_RunInfo *rinfo, RAI_DagOp *cur uint n_outkeys = array_len(currentOp->outkeys); if (!rinfo->single_op_dag) { - RAI_ContextReadLock(rinfo); RAI_Tensor *inputTensors[n_inkeys]; for (uint i = 0; i < n_inkeys; i++) { - RAI_Tensor *inputTensor; - const int get_result = RAI_getTensorFromLocalContext( - rinfo->dagTensorsContext, currentOp->inkeys[i], &inputTensor, currentOp->err); - if (get_result == REDISMODULE_ERR) { - // We check for this outside the function - // this check cannot be covered by tests - currentOp->result = REDISMODULE_ERR; - RAI_ContextUnlock(rinfo); - return; - } + RAI_Tensor *inputTensor = + Dag_GetTensorFromGlobalCtx(rinfo, currentOp->inkeys_indices[i]); inputTensors[i] = inputTensor; } RAI_ContextUnlock(rinfo); - for (uint i = 0; i < n_inkeys; i++) { RAI_ScriptRunCtxAddInput(currentOp->sctx, inputTensors[i], currentOp->err); } @@ -334,16 +295,16 @@ void RedisAI_DagRunSession_ScriptRun_Step(RedisAI_RunInfo *rinfo, RAI_DagOp *cur currentOp->result = result; currentOp->duration_us = end - start; + if (result != REDISMODULE_OK) { + return; + } if (!rinfo->single_op_dag) { - RAI_ContextWriteLock(rinfo); const size_t noutputs = RAI_ScriptRunCtxNumOutputs(currentOp->sctx); for (size_t outputNumber = 0; outputNumber < noutputs; outputNumber++) { RAI_Tensor *tensor = RAI_ScriptRunCtxOutputTensor(currentOp->sctx, outputNumber); - RedisModuleString *key_string = currentOp->outkeys[outputNumber]; - tensor = tensor ? RAI_TensorGetShallowCopy(tensor) : NULL; - AI_dictReplace(rinfo->dagTensorsContext, (void *)key_string, tensor); + Dag_SetTensorInGlobalCtx(rinfo, currentOp->outkeys_indices[outputNumber], tensor); } RAI_ContextUnlock(rinfo); } @@ -357,20 +318,14 @@ size_t RAI_DagOpBatchSize(RAI_DagOp *op, RedisAI_RunInfo *rinfo) { size_t ninputs = array_len(op->inkeys); int batchsize = 0; - if (!rinfo->single_device_dag) { - RAI_ContextReadLock(rinfo); - } + RAI_ContextReadLock(rinfo); for (size_t i = 0; i < ninputs; i++) { RAI_Tensor *input; if (rinfo->single_op_dag) { input = op->mctx->inputs[i].tensor; } else { - RAI_getTensorFromLocalContext(rinfo->dagTensorsContext, op->inkeys[i], &input, op->err); + input = Dag_GetTensorFromGlobalCtx(rinfo, op->inkeys_indices[i]); } - // We are expecting input != NULL, because we only reach this function if all inputs - // are available in context for the current dagOp. We could be more defensive - // eventually. - assert(input != NULL); if (i == 0) { batchsize = RAI_TensorDim(input, 0); @@ -381,78 +336,59 @@ size_t RAI_DagOpBatchSize(RAI_DagOp *op, RedisAI_RunInfo *rinfo) { break; } } - if (!rinfo->single_device_dag) { - RAI_ContextUnlock(rinfo); - } + RAI_ContextUnlock(rinfo); + return batchsize; } -int RAI_DagOpBatchable(RAI_DagOp *op1, RedisAI_RunInfo *rinfo1, RAI_DagOp *op2, - RedisAI_RunInfo *rinfo2) { +bool RAI_DagOpBatchable(RAI_DagOp *op1, RedisAI_RunInfo *rinfo1, RAI_DagOp *op2, + RedisAI_RunInfo *rinfo2) { - if (op1->mctx == NULL || op2->mctx == NULL) { - return 0; - } + RedisModule_Assert(op1->mctx && op2->mctx); if (op1->mctx->model != op2->mctx->model) { - return 0; + return false; } const int ninputs1 = array_len(op1->inkeys); const int ninputs2 = array_len(op2->inkeys); - if (ninputs1 != ninputs2) { - return 0; - } - if (!rinfo1->single_device_dag) { - RAI_ContextReadLock(rinfo1); - } - if (!rinfo2->single_device_dag) { - RAI_ContextReadLock(rinfo2); + return false; } + + RAI_ContextReadLock(rinfo1); + RAI_ContextReadLock(rinfo2); for (int i = 0; i < ninputs1; i++) { RAI_Tensor *input1; - if (rinfo1->single_op_dag == 1) { + if (rinfo1->single_op_dag) { input1 = op1->mctx->inputs[i].tensor; } else { - RAI_getTensorFromLocalContext(rinfo1->dagTensorsContext, op1->inkeys[i], &input1, - op1->err); + input1 = Dag_GetTensorFromGlobalCtx(rinfo1, op1->inkeys_indices[i]); } RAI_Tensor *input2; - if (rinfo2->single_op_dag == 1) { + if (rinfo2->single_op_dag) { input2 = op2->mctx->inputs[i].tensor; } else { - RAI_getTensorFromLocalContext(rinfo2->dagTensorsContext, op2->inkeys[i], &input2, - op2->err); - } - if (input1 == NULL || input2 == NULL) { - return 0; + input2 = Dag_GetTensorFromGlobalCtx(rinfo2, op2->inkeys_indices[i]); } - int ndims1 = RAI_TensorNumDims(input1); int ndims2 = RAI_TensorNumDims(input2); - if (ndims1 != ndims2) { - return 0; + return false; } - if (ndims1 == 0) { continue; } - for (int j = 1; j < ndims1; j++) { long long dim1 = RAI_TensorDim(input1, j); long long dim2 = RAI_TensorDim(input2, j); if (dim1 != dim2) { - return 0; + return false; } } } - if (!rinfo1->single_device_dag) { - RAI_ContextUnlock(rinfo1); - } - if (!rinfo2->single_device_dag) { - RAI_ContextUnlock(rinfo2); - } - return 1; + RAI_ContextUnlock(rinfo1); + RAI_ContextUnlock(rinfo2); + + return true; } bool RedisAI_DagDeviceComplete(RedisAI_RunInfo *rinfo) { @@ -479,28 +415,25 @@ RAI_DagOp *RedisAI_DagCurrentOp(RedisAI_RunInfo *rinfo) { void RedisAI_DagCurrentOpInfo(RedisAI_RunInfo *rinfo, bool *currentOpReady, bool *currentOpBatchable) { - RAI_DagOp *currentOp_ = RedisAI_DagCurrentOp(rinfo); + RAI_DagOp *currentOp_ = RedisAI_DagCurrentOp(rinfo); *currentOpReady = false; *currentOpBatchable = false; - - if (currentOp_ == NULL) { - return; - } + RedisModule_Assert(currentOp_); if (currentOp_->mctx && currentOp_->mctx->model->opts.batchsize > 0) { *currentOpBatchable = true; } *currentOpReady = true; // If this is a single op dag, the op is definitely ready. - if (rinfo->single_op_dag == 1) + if (rinfo->single_op_dag) return; uint n_inkeys = array_len(currentOp_->inkeys); RAI_ContextReadLock(rinfo); for (int i = 0; i < n_inkeys; i++) { - if (AI_dictFind(rinfo->dagTensorsContext, currentOp_->inkeys[i]) == NULL) { + if (Dag_GetTensorFromGlobalCtx(rinfo, currentOp_->inkeys_indices[i]) == NULL) { RAI_ContextUnlock(rinfo); *currentOpReady = false; return; @@ -530,8 +463,7 @@ void RedisAI_DagOpBatchingMatch(RedisAI_RunInfo *rinfo1, RAI_DagOp *op1, RedisAI *inbatchsize = 0; if (op2->mctx) { - int match = RAI_DagOpBatchable(op1, rinfo1, op2, rinfo2); - + bool match = RAI_DagOpBatchable(op1, rinfo1, op2, rinfo2); if (match) { *batched = 1; *inbatchsize = RAI_DagOpBatchSize(op2, rinfo2); @@ -539,6 +471,17 @@ void RedisAI_DagOpBatchingMatch(RedisAI_RunInfo *rinfo1, RAI_DagOp *op1, RedisAI } } +RAI_Tensor *Dag_GetTensorFromGlobalCtx(RedisAI_RunInfo *rinfo, size_t index) { + RedisModule_Assert(index < array_len(rinfo->dagSharedTensors)); + return rinfo->dagSharedTensors[index]; +} + +void Dag_SetTensorInGlobalCtx(RedisAI_RunInfo *rinfo, size_t index, RAI_Tensor *t) { + RedisModule_Assert(index < array_len(rinfo->dagSharedTensors)); + RedisModule_Assert(rinfo->dagSharedTensors[index] == NULL); + rinfo->dagSharedTensors[index] = RAI_TensorGetShallowCopy(t); +} + void RedisAI_DagRunSessionStep(RedisAI_RunInfo *rinfo, const char *devicestr) { RAI_DagOp *currentOp = RedisAI_DagCurrentOp(rinfo); @@ -646,15 +589,8 @@ int RedisAI_DagRun_Reply(RedisModuleCtx *ctx, RedisModuleString **argv, int argc case REDISAI_DAG_CMD_TENSORGET: { rinfo->dagReplyLength++; - RAI_Tensor *t; - int res = RAI_getTensorFromLocalContext(rinfo->dagTensorsContext, currentOp->inkeys[0], - &t, currentOp->err); - if (res != REDISMODULE_OK) { - RedisModule_ReplyWithError(ctx, currentOp->err->detail_oneline); - dag_error = 1; - } else { - ReplyWithTensor(ctx, currentOp->fmt, t); - } + RAI_Tensor *t = Dag_GetTensorFromGlobalCtx(rinfo, currentOp->inkeys_indices[0]); + ReplyWithTensor(ctx, currentOp->fmt, t); break; } @@ -700,26 +636,29 @@ int RedisAI_DagRun_Reply(RedisModuleCtx *ctx, RedisModuleString **argv, int argc } break; } - default: /* no-op */ break; } } - if (dag_error) { goto cleanup; } + + int persist_status; if (!rinfo->single_op_dag) { - _DAG_PersistTensors(ctx, rinfo); + persist_status = _DAG_PersistTensors(ctx, rinfo); } else { if (rinfo->dagOps[0]->commandType == REDISAI_DAG_CMD_MODELRUN) { - _ModelSingleOp_PersistTensors(ctx, rinfo->dagOps[0]); + persist_status = _ModelSingleOp_PersistTensors(ctx, rinfo->dagOps[0], rinfo->err); } else { RedisModule_Assert(rinfo->dagOps[0]->commandType == REDISAI_DAG_CMD_SCRIPTRUN); - _ScriptSingleOp_PersistTensors(ctx, rinfo->dagOps[0]); + persist_status = _ScriptSingleOp_PersistTensors(ctx, rinfo->dagOps[0], rinfo->err); } } + if (persist_status != REDISMODULE_OK) { + RedisModule_ReplyWithError(ctx, RAI_GetErrorOneLine(rinfo->err)); + } cleanup: if (!rinfo->single_op_dag) { diff --git a/src/DAG/dag.h b/src/DAG/dag.h index 76eecb14a..f980183cd 100644 --- a/src/DAG/dag.h +++ b/src/DAG/dag.h @@ -88,6 +88,24 @@ void RedisAI_DagOpBatchInfo(RedisAI_RunInfo *rinfo, RAI_DagOp *op, size_t *batch void RedisAI_DagOpBatchingMatch(RedisAI_RunInfo *rinfo1, RAI_DagOp *op1, RedisAI_RunInfo *rinfo2, RAI_DagOp *op2, int *batched, size_t *inbatchsize); +/** + * @brief Get a tensor from the dag local context in a given index + * (this access to a shared array, require read lock) + * @param rinfo The DAG runInfo. + * @param index The index of the tensor in the Dag shared array to return + * @return The tensor of the given index (NULL is returned if this tensor hasn't been realized yet) + */ +RAI_Tensor *Dag_GetTensorFromGlobalCtx(RedisAI_RunInfo *rinfo, size_t index); + +/** + * @brief Shallow copy and set a tensor in the dag local context in a given index. + * (this access to a shared array, require write lock) + * @param rinfo The DAG runInfo. + * @param index The index to put in the given tensor in the Dag shared array. + * @param t The tensor to shallow copy and store in the given index. + */ +void Dag_SetTensorInGlobalCtx(RedisAI_RunInfo *rinfo, size_t index, RAI_Tensor *t); + /** * Run the first unrealized DAG operation in rinfo for the given device. * @param rinfo context in which RedisAI blocking commands operate. diff --git a/src/DAG/dag_builder.c b/src/DAG/dag_builder.c index 1e314baa5..d1194c422 100644 --- a/src/DAG/dag_builder.c +++ b/src/DAG/dag_builder.c @@ -29,12 +29,18 @@ int RAI_DAGLoadTensor(RAI_DAGRunCtx *run_info, const char *t_name, RAI_Tensor *t RedisAI_RunInfo *rinfo = (RedisAI_RunInfo *)run_info; RedisModuleString *key_name = RedisModule_CreateString(NULL, t_name, strlen(t_name)); - // Add the tensor under its "mangled" key name to the DAG local context dict. - char buf[16]; - sprintf(buf, "%04d", 1); - RedisModule_StringAppendBuffer(NULL, key_name, buf, strlen(buf)); - AI_dictAdd(rinfo->dagTensorsContext, (void *)key_name, - (void *)RAI_TensorGetShallowCopy(tensor)); + + // Cannot load more than one tensor under the same name + if (AI_dictFind(rinfo->tensorsNamesToIndices, key_name) != NULL) { + RedisModule_FreeString(NULL, key_name); + return REDISMODULE_ERR; + } + + // Add the tensor to the DAG shared tensors and map its name to the relevant index. + size_t index = array_len(rinfo->dagSharedTensors); + AI_dictAdd(rinfo->tensorsNamesToIndices, (void *)key_name, (void *)index); + RAI_TensorGetShallowCopy(tensor); + rinfo->dagSharedTensors = array_append(rinfo->dagSharedTensors, (void *)tensor); RedisModule_FreeString(NULL, key_name); return REDISMODULE_OK; diff --git a/src/DAG/dag_execute.c b/src/DAG/dag_execute.c index 71072ce25..c5e60c40f 100644 --- a/src/DAG/dag_execute.c +++ b/src/DAG/dag_execute.c @@ -3,155 +3,61 @@ #include "background_workers.h" #include "util/string_utils.h" -void _DAG_SetTensorsInLocalContext(RedisAI_RunInfo *rinfo) { - for (size_t i = 0; i < rinfo->dagOpCount; i++) { - RAI_DagOp *op = rinfo->dagOps[i]; - if (op->commandType == REDISAI_DAG_CMD_TENSORSET) { - // Insert the tensor with its mangled (unique) name. - void *t = (void *)RAI_TensorGetShallowCopy(op->outTensor); - AI_dictReplace(rinfo->dagTensorsContext, (void *)op->outkeys[0], t); - } - } -} - -int MangleTensorsNames(RedisAI_RunInfo *rinfo) { - - int res = REDISMODULE_ERR; - AI_dict *mangled_tensors = AI_dictCreate(&AI_dictTypeHeapRStrings, NULL); +int ValidatePersistKeys(RedisAI_RunInfo *rinfo, AI_dict *tensorsNamesToInd, + AI_dict *persistTensorsNames) { { - AI_dictIterator *iter = AI_dictGetSafeIterator(rinfo->dagTensorsContext); - AI_dictEntry *entry = AI_dictNext(iter); - while (entry) { - RedisModuleString *key = (RedisModuleString *)AI_dictGetKey(entry); - size_t key_len; - const char *key_str = RedisModule_StringPtrLen(key, &key_len); - RedisModuleString *demangled_key = RedisModule_CreateString(NULL, key_str, key_len - 4); - int *instance = RedisModule_Alloc(sizeof(int)); - *instance = 1; - AI_dictAdd(mangled_tensors, (void *)demangled_key, (void *)instance); - RedisModule_FreeString(NULL, demangled_key); - entry = AI_dictNext(iter); + AI_dictIterator *iter = AI_dictGetSafeIterator(persistTensorsNames); + AI_dictEntry *persist_entry; + while ((persist_entry = AI_dictNext(iter))) { + RedisModuleString *persist_key = (RedisModuleString *)AI_dictGetKey(persist_entry); + AI_dictEntry *entry = AI_dictFind(tensorsNamesToInd, persist_key); + if (!entry) { + RAI_SetError(rinfo->err, RAI_EDAGRUN, "ERR PERSIST key cannot be found in DAG"); + AI_dictReleaseIterator(iter); + return REDISMODULE_ERR; + } + size_t index = (size_t)AI_dictGetVal(entry); + AI_dictReplace(persistTensorsNames, (void *)persist_key, (void *)index); } AI_dictReleaseIterator(iter); } + return REDISMODULE_OK; +} + +int MapTensorsKeysToIndices(RedisAI_RunInfo *rinfo, AI_dict *tensorsNamesToInd) { for (long long i = 0; i < array_len(rinfo->dagOps); i++) { RAI_DagOp *currentOp = rinfo->dagOps[i]; - RedisModuleString **mangled_inkeys = - array_new(RedisModuleString *, array_len(currentOp->inkeys)); for (long long j = 0; j < array_len(currentOp->inkeys); j++) { RedisModuleString *key = currentOp->inkeys[j]; - AI_dictEntry *entry = AI_dictFind(mangled_tensors, key); + AI_dictEntry *entry = AI_dictFind(tensorsNamesToInd, key); if (!entry) { - array_free(mangled_inkeys); RAI_SetError(rinfo->err, RAI_EDAGRUN, "ERR INPUT key cannot be found in DAG"); - goto cleanup; + return REDISMODULE_ERR; } - int *instance = AI_dictGetVal(entry); - char buf[16]; - sprintf(buf, "%04d", *instance); - RedisModuleString *mangled_key = RedisModule_CreateStringFromString(NULL, key); - RedisModule_StringAppendBuffer(NULL, mangled_key, buf, strlen(buf)); - mangled_inkeys = array_append(mangled_inkeys, mangled_key); + size_t ind = (size_t)AI_dictGetVal(entry); + currentOp->inkeys_indices = array_append(currentOp->inkeys_indices, ind); } - RedisModuleString **mangled_outkeys = - array_new(RedisModuleString *, array_len(currentOp->outkeys)); for (long long j = 0; j < array_len(currentOp->outkeys); j++) { RedisModuleString *key = currentOp->outkeys[j]; - AI_dictEntry *entry = AI_dictFind(mangled_tensors, key); - int *instance = NULL; - if (entry) { - instance = AI_dictGetVal(entry); - *instance += 1; - } else { - instance = RedisModule_Alloc(sizeof(int)); - *instance = 1; - AI_dictAdd(mangled_tensors, (void *)key, (void *)instance); - } - char buf[16]; - sprintf(buf, "%04d", *instance); - RedisModuleString *mangled_key = RedisModule_CreateStringFromString(NULL, key); - RedisModule_StringAppendBuffer(NULL, mangled_key, buf, strlen(buf)); - mangled_outkeys = array_append(mangled_outkeys, mangled_key); - } - - if (currentOp->inkeys) { - for (size_t j = 0; j < array_len(currentOp->inkeys); j++) { - RedisModule_FreeString(NULL, currentOp->inkeys[j]); - } - array_free(currentOp->inkeys); - } - - if (currentOp->outkeys) { - for (size_t j = 0; j < array_len(currentOp->outkeys); j++) { - RedisModule_FreeString(NULL, currentOp->outkeys[j]); - } - array_free(currentOp->outkeys); - } - - currentOp->inkeys = mangled_inkeys; - currentOp->outkeys = mangled_outkeys; - } + size_t ind = array_len(rinfo->dagSharedTensors); - AI_dict *mangled_persisted = AI_dictCreate(&AI_dictTypeHeapRStrings, NULL); - { - AI_dictIterator *iter = AI_dictGetSafeIterator(rinfo->dagTensorsPersistedContext); - AI_dictEntry *entry = AI_dictNext(iter); - while (entry) { - RedisModuleString *key = (RedisModuleString *)AI_dictGetKey(entry); - AI_dictEntry *mangled_entry = AI_dictFind(mangled_tensors, key); - if (!mangled_entry) { - AI_dictRelease(mangled_persisted); - AI_dictReleaseIterator(iter); - RAI_SetError(rinfo->err, RAI_EDAGRUN, "ERR PERSIST key cannot be found in DAG"); - goto cleanup; - } - if (AI_dictFind(mangled_persisted, key) != NULL) { - AI_dictRelease(mangled_persisted); - AI_dictReleaseIterator(iter); - RAI_SetError(rinfo->err, RAI_EDAGRUN, "ERR PERSIST keys must be unique"); - goto cleanup; + // Add a new empty place holder in the array for an output tensor. + // If this is a TENSORSET op, the tensor is already realized. + if (currentOp->commandType == REDISAI_DAG_CMD_TENSORSET) { + RAI_Tensor *t = RAI_TensorGetShallowCopy(currentOp->outTensor); + rinfo->dagSharedTensors = array_append(rinfo->dagSharedTensors, t); + } else { + rinfo->dagSharedTensors = array_append(rinfo->dagSharedTensors, NULL); } - int *instance = AI_dictGetVal(mangled_entry); - char buf[16]; - sprintf(buf, "%04d", *instance); - RedisModuleString *mangled_key = RedisModule_CreateStringFromString(NULL, key); - RedisModule_StringAppendBuffer(NULL, mangled_key, buf, strlen(buf)); - AI_dictAdd(mangled_persisted, (void *)mangled_key, (void *)1); - RedisModule_FreeString(NULL, mangled_key); - entry = AI_dictNext(iter); + currentOp->outkeys_indices = array_append(currentOp->outkeys_indices, ind); + AI_dictReplace(tensorsNamesToInd, (void *)key, (void *)ind); } - AI_dictReleaseIterator(iter); } - - AI_dictRelease(rinfo->dagTensorsPersistedContext); - rinfo->dagTensorsPersistedContext = mangled_persisted; - - for (long long i = 0; i < array_len(rinfo->dagOps); i++) { - if (rinfo->dagOps[i]->devicestr == NULL) { - rinfo->dagOps[i]->devicestr = "CPU"; - } - } - // Tensors from TENSORSET ops are ready to be put in DAG local context under their mangled - // names. - _DAG_SetTensorsInLocalContext(rinfo); - res = REDISMODULE_OK; - -cleanup : { - AI_dictIterator *iter = AI_dictGetSafeIterator(mangled_tensors); - AI_dictEntry *entry = AI_dictNext(iter); - while (entry) { - int *val = (int *)AI_dictGetVal(entry); - RedisModule_Free(val); - entry = AI_dictNext(iter); - } - AI_dictReleaseIterator(iter); -} - AI_dictRelease(mangled_tensors); - return res; + return REDISMODULE_OK; } // Add Shallow copies of the DAG run info to the devices' queues. @@ -242,7 +148,7 @@ int RAI_DAGRun(RAI_DAGRunCtx *run_info, RAI_OnFinishCB DAGAsyncFinish, void *pri } // Make the inkeys and outkeys of the DAG ops unique, to ensure that the operations // will be execute in the right order. - if (MangleTensorsNames(rinfo) != REDISMODULE_OK) { + if (MapTensorsKeysToIndices(rinfo, rinfo->tensorsNamesToIndices) != REDISMODULE_OK) { RAI_SetError(err, rinfo->err->code, rinfo->err->detail); return REDISMODULE_ERR; } @@ -269,16 +175,13 @@ size_t RAI_DAGNumOutputs(RAI_OnFinishCtx *finish_ctx) { const RAI_Tensor *RAI_DAGOutputTensor(RAI_OnFinishCtx *finish_ctx, size_t index) { size_t tensor_get_op_ind = -1; RedisAI_RunInfo *rinfo = (RedisAI_RunInfo *)finish_ctx; + for (size_t i = 0; i < rinfo->dagOpCount; i++) { RAI_DagOp *op = rinfo->dagOps[i]; if (op->commandType == REDISAI_DAG_CMD_TENSORGET) { tensor_get_op_ind++; if (tensor_get_op_ind == index) { - RAI_Tensor *t; - int res = RAI_getTensorFromLocalContext(rinfo->dagTensorsContext, op->inkeys[0], &t, - op->err); - RedisModule_Assert(res == REDISMODULE_OK); - return t; + return Dag_GetTensorFromGlobalCtx(rinfo, op->inkeys_indices[0]); } } } diff --git a/src/DAG/dag_execute.h b/src/DAG/dag_execute.h index c4b5b1b99..22b794877 100644 --- a/src/DAG/dag_execute.h +++ b/src/DAG/dag_execute.h @@ -4,17 +4,39 @@ #include "run_info.h" /** - * @brief We are given a DAG runInfo of a sequence of operations, each with its own + @brief We are given a DAG runInfo of a sequence of operations, each with its own input and output keys. The names of the keys will be used to look whether the inputs to a DAG operation have all been realized by previous operations (or if they are available as part of LOADed keys from keyspace). This strategy is fine if keys are not aliased, that is, if a command's output overwrites the key of a previous command. This would trick DAG operations into thinking that their input is ready when it's not. - To overcome this, we make key names unique, so that names are not aliased. We - mangle the names by appending a numerical suffix ":0001". After computing, we - demangle the keys in order to persist them.*/ -int MangleTensorsNames(RedisAI_RunInfo *rinfo); + To overcome this, we map the input and output tensors for every operation to indices, + in the following way. For every input of an operation having the key "x", we map the index + for which "x" was last mapped to when, it was an output of a previous operation. + For every output of an operation "y", we map the next available index in the array. + Every entry in the DAG array contains NULL (except for tensors that where loaded + before the DAG run starts). + @param rinfo The DAG runInfo. + @param tensorsNamesToInd A dict mapping every key name of a tensor that appeared + in DAG operation, to the maximal index of the DAG shared array for which they were mapped to. + @returns REDISMODULE_ERR if there exists an operation for which one of the input + tensors didn't appear as an output of a previous operation, REDISMODULE_OK otherwise + */ +int MapTensorsKeysToIndices(RedisAI_RunInfo *rinfo, AI_dict *tensorsNamesToInd); + +/** + * @brief Validates that tensors key names to persist appeared in the DAG operations. + * @param rinfo The DAG runInfo. + * @param tensorsNamesToInd A dict mapping every key name of a tensor that appeared + * in DAG operation, to the maximal index of the DAG shared array for which they were mapped to. + * @param persistTensorsNames A hash table the contains the names of the tensors + * to persist when the DAG run is finished. + * @return REDISMODULE_ERR if there exists a tensor key to persist that didn't + * appear in DAG operation, REDISMODULE_OK otherwise + */ +int ValidatePersistKeys(RedisAI_RunInfo *rinfo, AI_dict *tensorsNamesToInd, + AI_dict *persistTensorsNames); /** * @brief Run asynchronously a DAG. This will validate that the sequence of DAG ops diff --git a/src/DAG/dag_parser.c b/src/DAG/dag_parser.c index ad7e53890..5cab18c21 100644 --- a/src/DAG/dag_parser.c +++ b/src/DAG/dag_parser.c @@ -15,17 +15,16 @@ * @param ctx Context in which Redis modules operate * @param argv Redis command arguments, as an array of strings * @param argc Redis command number of arguments - * @param loadedContextDict local non-blocking hash table containing key names - * loaded from the keyspace tensors - * @param localContextDict local non-blocking hash table containing DAG's - * tensors + * @param tensorsToInd Hash table that maps tensor key name to its index in the + * shared tensors array of the DAG. + * @param sharedTensors An array that use to store intermideate tensors in the DAG * @param chaining_operator operator used to split operations. Any command * argument after the chaining operator is not considered * @return processed number of arguments on success, or -1 if the parsing failed */ static int _ParseDAGLoadArgs(RedisModuleCtx *ctx, RedisModuleString **argv, int argc, - AI_dict **localContextDict, const char *chaining_operator, - RAI_Error *err) { + AI_dict *tensorsToInd, RAI_Tensor ***sharedTensors, + const char *chaining_operator, RAI_Error *err) { if (argc < 3) { RAI_SetError(err, RAI_EDAGBUILDER, "ERR wrong number of arguments for LOAD in 'AI.DAGRUN' command"); @@ -59,11 +58,10 @@ static int _ParseDAGLoadArgs(RedisModuleCtx *ctx, RedisModuleString **argv, int return -1; } - // Add the tensor under its "mangled" key name to the DAG local context dict. - char buf[16]; - sprintf(buf, "%04d", 1); - RedisModule_StringAppendBuffer(NULL, key_name, buf, strlen(buf)); - AI_dictAdd(*localContextDict, (void *)key_name, (void *)RAI_TensorGetShallowCopy(t)); + // Add the tensor to the DAG shared tensors and map its name to the relevant index. + size_t index = array_len(*sharedTensors); + AI_dictAdd(tensorsToInd, (void *)key_name, (void *)index); + *sharedTensors = array_append(*sharedTensors, (void *)RAI_TensorGetShallowCopy(t)); number_loaded_keys++; } @@ -79,18 +77,16 @@ static int _ParseDAGLoadArgs(RedisModuleCtx *ctx, RedisModuleString **argv, int /** * DAGRUN Building Block to parse [PERSIST key1 key2... ] * - * @param ctx Context in which Redis modules operate * @param argv Redis command arguments, as an array of strings * @param argc Redis command number of arguments - * @param localContextDict local non-blocking hash table containing DAG's + * @param persistTensorsNames local hash table containing DAG's * keynames marked as persistent * @param chaining_operator operator used to split operations. Any command * argument after the chaining operator is not considered * @return processed number of arguments on success, or -1 if the parsing failed */ -static int _ParseDAGPersistArgs(RedisModuleCtx *ctx, RedisModuleString **argv, int argc, - AI_dict **persistContextDict, const char *chaining_operator, - RAI_Error *err) { +static int _ParseDAGPersistArgs(RedisModuleString **argv, int argc, AI_dict *persistTensorsNames, + const char *chaining_operator, RAI_Error *err) { if (argc < 3) { RAI_SetError(err, RAI_EDAGBUILDER, "ERR wrong number of arguments for PERSIST in 'AI.DAGRUN' command"); @@ -111,10 +107,13 @@ static int _ParseDAGPersistArgs(RedisModuleCtx *ctx, RedisModuleString **argv, i const char *arg_string = RedisModule_StringPtrLen(argv[argpos], NULL); if (!strcasecmp(arg_string, chaining_operator)) { break; - } else { - AI_dictAdd(*persistContextDict, (void *)argv[argpos], (void *)1); - number_keys_to_persist++; } + if (AI_dictFind(persistTensorsNames, (void *)argv[argpos]) != NULL) { + RAI_SetError(err, RAI_EDAGRUN, "ERR PERSIST keys must be unique"); + return -1; + } + AI_dictAdd(persistTensorsNames, (void *)argv[argpos], NULL); + number_keys_to_persist++; } if (number_keys_to_persist != n_keys) { RAI_SetError(err, RAI_EDAGBUILDER, @@ -247,10 +246,11 @@ int ParseDAGRunCommand(RedisAI_RunInfo *rinfo, RedisModuleCtx *ctx, RedisModuleS const char *arg_string = RedisModule_StringPtrLen(argv[arg_pos], NULL); if (!strcasecmp(arg_string, "LOAD") && !load_complete && chainingOpCount == 0) { - /* Load the required tensors from key space and store them in both - dagTensorsLoadedContext and dagTensorsContext dicts. */ - const int parse_result = _ParseDAGLoadArgs( - ctx, &argv[arg_pos], argc - arg_pos, &(rinfo->dagTensorsContext), "|>", rinfo->err); + /* Load the required tensors from key space to the dag shared tensors + * array, and save a mapping of their names to the corresponding indices. */ + const int parse_result = + _ParseDAGLoadArgs(ctx, &argv[arg_pos], argc - arg_pos, rinfo->tensorsNamesToIndices, + &rinfo->dagSharedTensors, "|>", rinfo->err); if (parse_result <= 0) goto cleanup; arg_pos += parse_result; @@ -263,12 +263,11 @@ int ParseDAGRunCommand(RedisAI_RunInfo *rinfo, RedisModuleCtx *ctx, RedisModuleS "ERR PERSIST cannot be specified in a read-only DAG"); goto cleanup; } - /* Store the keys to persist in dagTensorsPersistedContext dict. - These keys will be populated later on with actual tensors. */ - const int parse_result = - _ParseDAGPersistArgs(ctx, &argv[arg_pos], argc - arg_pos, - &(rinfo->dagTensorsPersistedContext), "|>", rinfo->err); - + /* Store the keys to persist in persistTensors dict, these keys will + * be mapped later to the indices in the dagSharedTensors array in which the + * tensors to persist will be found by the end of the DAG run. */ + const int parse_result = _ParseDAGPersistArgs(&argv[arg_pos], argc - arg_pos, + rinfo->persistTensors, "|>", rinfo->err); if (parse_result <= 0) goto cleanup; arg_pos += parse_result; @@ -307,9 +306,15 @@ int ParseDAGRunCommand(RedisAI_RunInfo *rinfo, RedisModuleCtx *ctx, RedisModuleS goto cleanup; } - if (MangleTensorsNames(rinfo) != REDISMODULE_OK) { + if (MapTensorsKeysToIndices(rinfo, rinfo->tensorsNamesToIndices) != REDISMODULE_OK) { + goto cleanup; + } + if (ValidatePersistKeys(rinfo, rinfo->tensorsNamesToIndices, rinfo->persistTensors) != + REDISMODULE_OK) { goto cleanup; } + AI_dictRelease(rinfo->tensorsNamesToIndices); + rinfo->tensorsNamesToIndices = NULL; res = REDISMODULE_OK; cleanup: diff --git a/src/redisai.c b/src/redisai.c index 2691641db..8298b2be8 100644 --- a/src/redisai.c +++ b/src/redisai.c @@ -101,13 +101,16 @@ int RedisAI_TensorSet_RedisCommand(RedisModuleCtx *ctx, RedisModuleString **argv return RedisModule_WrongArity(ctx); RedisModuleKey *key; - const int status = RAI_OpenKey_Tensor(ctx, argv[1], &key, REDISMODULE_READ | REDISMODULE_WRITE); + RAI_Error err = {0}; + const int status = + RAI_OpenKey_Tensor(ctx, argv[1], &key, REDISMODULE_READ | REDISMODULE_WRITE, &err); if (status == REDISMODULE_ERR) { + RedisModule_ReplyWithError(ctx, RAI_GetErrorOneLine(&err)); + RAI_ClearError(&err); return REDISMODULE_ERR; } RAI_Tensor *t = NULL; - RAI_Error err = {0}; const int parse_result = RAI_parseTensorSetArgs(argv, argc, &t, 1, &err); // if the number of parsed args is negative something went wrong diff --git a/src/run_info.c b/src/run_info.c index ac6f4ddf3..bcbbdef5c 100644 --- a/src/run_info.c +++ b/src/run_info.c @@ -46,6 +46,8 @@ int RAI_InitDagOp(RAI_DagOp **result) { dagOp->runkey = NULL; dagOp->inkeys = (RedisModuleString **)array_new(RedisModuleString *, 1); dagOp->outkeys = (RedisModuleString **)array_new(RedisModuleString *, 1); + dagOp->inkeys_indices = array_new(size_t, 1); + dagOp->outkeys_indices = array_new(size_t, 1); dagOp->outTensor = NULL; dagOp->mctx = NULL; dagOp->sctx = NULL; @@ -70,9 +72,9 @@ int RAI_InitRunInfo(RedisAI_RunInfo **result) { RedisAI_RunInfo *rinfo; rinfo = (RedisAI_RunInfo *)RedisModule_Calloc(1, sizeof(RedisAI_RunInfo)); - rinfo->dagTensorsContext = AI_dictCreate(&AI_dictTypeTensorVals, NULL); - rinfo->dagTensorsPersistedContext = AI_dictCreate(&AI_dictTypeHeapRStrings, NULL); - + rinfo->dagSharedTensors = array_new(RAI_Tensor *, 1); + rinfo->persistTensors = AI_dictCreate(&AI_dictTypeHeapRStrings, NULL); + rinfo->tensorsNamesToIndices = AI_dictCreate(&AI_dictTypeHeapRStrings, NULL); rinfo->dagOps = (RAI_DagOp **)array_new(RAI_DagOp *, 1); rinfo->dagError = RedisModule_Calloc(1, sizeof(int)); RAI_InitError(&rinfo->err); @@ -125,6 +127,7 @@ void RAI_FreeDagOp(RAI_DagOp *dagOp) { } array_free(dagOp->inkeys); } + array_free(dagOp->inkeys_indices); if (dagOp->outkeys) { for (size_t i = 0; i < array_len(dagOp->outkeys); i++) { @@ -132,6 +135,7 @@ void RAI_FreeDagOp(RAI_DagOp *dagOp) { } array_free(dagOp->outkeys); } + array_free(dagOp->outkeys_indices); RedisModule_Free(dagOp); } @@ -152,9 +156,14 @@ void RAI_FreeRunInfo(struct RedisAI_RunInfo *rinfo) { pthread_rwlock_destroy(rinfo->dagLock); RedisModule_Free(rinfo->dagLock); - if (rinfo->dagTensorsContext) { - AI_dictRelease(rinfo->dagTensorsContext); - AI_dictRelease(rinfo->dagTensorsPersistedContext); + size_t dag_tensors_num = array_len(rinfo->dagSharedTensors); + for (size_t i = 0; i < dag_tensors_num; i++) { + RAI_TensorFree(rinfo->dagSharedTensors[i]); + } + array_free(rinfo->dagSharedTensors); + AI_dictRelease(rinfo->persistTensors); + if (rinfo->tensorsNamesToIndices) { + AI_dictRelease(rinfo->tensorsNamesToIndices); } if (rinfo->dagOps) { @@ -179,8 +188,7 @@ void RAI_ContextReadLock(RedisAI_RunInfo *rinfo) { if (rinfo->single_op_dag || rinfo->single_device_dag) { return; } - // This is a temporary solution - pthread_rwlock_wrlock(rinfo->dagLock); + pthread_rwlock_rdlock(rinfo->dagLock); } void RAI_ContextWriteLock(RedisAI_RunInfo *rinfo) { @@ -279,7 +287,7 @@ RAI_ModelRunCtx *RAI_GetAsModelRunCtx(RedisAI_RunInfo *rinfo, RAI_Error *err) { RAI_DagOp *op = rinfo->dagOps[0]; if (!rinfo->single_op_dag || !op->mctx) { - RAI_SetError(err, RedisAI_ErrorCode_EFINISHCTX, "Finish ctx is not a model run ctx"); + RAI_SetError(err, RAI_EFINISHCTX, "Finish ctx is not a model run ctx"); return NULL; } RAI_SetError(err, RAI_GetErrorCode(op->err), RAI_GetError(op->err)); @@ -293,7 +301,7 @@ RAI_ScriptRunCtx *RAI_GetAsScriptRunCtx(RedisAI_RunInfo *rinfo, RAI_Error *err) RAI_DagOp *op = rinfo->dagOps[0]; if (!rinfo->single_op_dag || !op->sctx) { - RAI_SetError(err, RedisAI_ErrorCode_EFINISHCTX, "Finish ctx is not a script run ctx"); + RAI_SetError(err, RAI_EFINISHCTX, "Finish ctx is not a script run ctx"); return NULL; } RAI_SetError(err, RAI_GetErrorCode(op->err), RAI_GetError(op->err)); diff --git a/src/run_info.h b/src/run_info.h index efba7d2ec..5816602ed 100644 --- a/src/run_info.h +++ b/src/run_info.h @@ -33,6 +33,8 @@ typedef struct RAI_DagOp { RedisModuleString *runkey; RedisModuleString **inkeys; RedisModuleString **outkeys; + size_t *inkeys_indices; + size_t *outkeys_indices; RAI_Tensor *outTensor; // The tensor to upload in TENSORSET op. RAI_ModelRunCtx *mctx; RAI_ScriptRunCtx *sctx; @@ -91,10 +93,11 @@ struct RedisAI_RunInfo { RedisModuleBlockedClient *client; int single_op_dag; int single_device_dag; - AI_dict *dagTensorsContext; - AI_dict *dagTensorsPersistedContext; // dict to flag tensors to persist - RAI_DagOp **dagOps; // all ops in DAG - RAI_DagOp **dagDeviceOps; // all ops in DAG for device + RAI_Tensor **dagSharedTensors; // Shared array of tensors that dag ops use. + AI_dict *persistTensors; // Associates the tensors to persist with their indices . + AI_dict *tensorsNamesToIndices; // Maps tensor key name to its (maximal) index. + RAI_DagOp **dagOps; // all ops in DAG + RAI_DagOp **dagDeviceOps; // all ops in DAG for device int dagReplyLength; int dagOpCount; // number of ops in DAG int *dagCompleteOpCount; // number of completed ops in DAG diff --git a/src/tensor.c b/src/tensor.c index 0310f33ae..7f087501f 100644 --- a/src/tensor.c +++ b/src/tensor.c @@ -562,17 +562,15 @@ size_t RAI_TensorByteSize(RAI_Tensor *t) { char *RAI_TensorData(RAI_Tensor *t) { return t->tensor.dl_tensor.data; } -/* Return REDISMODULE_ERR if is the key not associated with a tensor type. - * Return REDISMODULE_OK otherwise. */ int RAI_OpenKey_Tensor(RedisModuleCtx *ctx, RedisModuleString *keyName, RedisModuleKey **key, - int mode) { + int mode, RAI_Error *err) { *key = RedisModule_OpenKey(ctx, keyName, mode); if (RedisModule_KeyType(*key) == REDISMODULE_KEYTYPE_EMPTY) { return REDISMODULE_OK; } if (RedisModule_ModuleTypeGetType(*key) != RedisAI_TensorType) { RedisModule_CloseKey(*key); - RedisModule_ReplyWithError(ctx, REDISMODULE_ERRORMSG_WRONGTYPE); + RAI_SetError(err, RAI_ETENSORSET, REDISMODULE_ERRORMSG_WRONGTYPE); return REDISMODULE_ERR; } return REDISMODULE_OK; @@ -596,22 +594,6 @@ int RAI_GetTensorFromKeyspace(RedisModuleCtx *ctx, RedisModuleString *keyName, R return REDISMODULE_OK; } -/* Return REDISMODULE_ERR if there was an error getting the Tensor. - * Return REDISMODULE_OK if the tensor value is present at the localContextDict. - */ -int RAI_getTensorFromLocalContext(AI_dict *localContextDict, RedisModuleString *localContextKey, - RAI_Tensor **tensor, RAI_Error *error) { - int result = REDISMODULE_ERR; - AI_dictEntry *tensor_entry = AI_dictFind(localContextDict, localContextKey); - if (tensor_entry) { - *tensor = AI_dictGetVal(tensor_entry); - result = REDISMODULE_OK; - } else { - RAI_SetError(error, RAI_ETENSORGET, "ERR tensor key is empty"); - } - return result; -} - void RedisAI_ReplicateTensorSet(RedisModuleCtx *ctx, RedisModuleString *key, RAI_Tensor *t) { long long ndims = RAI_TensorNumDims(t); diff --git a/src/tensor.h b/src/tensor.h index e5587dc40..5768a7fba 100644 --- a/src/tensor.h +++ b/src/tensor.h @@ -312,7 +312,7 @@ char *RAI_TensorData(RAI_Tensor *t); * tensor type. */ int RAI_OpenKey_Tensor(RedisModuleCtx *ctx, RedisModuleString *keyName, RedisModuleKey **key, - int mode); + int mode, RAI_Error *err); /** * Helper method to get Tensor from keyspace. In case of a failure an @@ -331,20 +331,6 @@ int RAI_OpenKey_Tensor(RedisModuleCtx *ctx, RedisModuleString *keyName, RedisMod int RAI_GetTensorFromKeyspace(RedisModuleCtx *ctx, RedisModuleString *keyName, RedisModuleKey **key, RAI_Tensor **tensor, int mode, RAI_Error *err); -/** - * Helper method to get Tensor from local context ( no keyspace access ) - * - * @param localContextDict local non-blocking hash table containing DAG's - * tensors - * @param localContextKey key name - * @param tensor destination tensor - * @param error error data structure to store error message in the case of - * failure - * @return REDISMODULE_OK on success, or REDISMODULE_ERR if failed - */ -int RAI_getTensorFromLocalContext(AI_dict *localContextDict, RedisModuleString *localContextKey, - RAI_Tensor **tensor, RAI_Error *error); - /** * Helper method to replicate a tensor via an AI.TENSORSET command to the * replicas. This is used on MODELRUN, SCRIPTRUN, DAGRUN as a way to ensure that diff --git a/tests/module/DAG_utils.c b/tests/module/DAG_utils.c index d145fb372..08142d332 100644 --- a/tests/module/DAG_utils.c +++ b/tests/module/DAG_utils.c @@ -72,6 +72,26 @@ static void _DAGFinishFunc(RAI_OnFinishCtx *onFinishCtx, void *private_data) { pthread_cond_signal(&global_cond); } +int testLoadTensor(RedisModuleCtx *ctx) { + RAI_DAGRunCtx *run_info = RedisAI_DAGRunCtxCreate(); + int res = LLAPIMODULE_ERR; + RAI_Tensor *t = (RAI_Tensor *)_getFromKeySpace(ctx, "a{1}"); + if (RedisAI_DAGLoadTensor(run_info, "input", t) != REDISMODULE_OK) { + goto cleanup; + } + t = (RAI_Tensor *)_getFromKeySpace(ctx, "b{1}"); + + // cannot load more than one tensor under the same name. + if (RedisAI_DAGLoadTensor(run_info, "input", t) != REDISMODULE_ERR) { + goto cleanup; + } + res = LLAPIMODULE_OK; + + cleanup: + RedisAI_DAGFree(run_info); + return res; +} + int testModelRunOpError(RedisModuleCtx *ctx) { RAI_DAGRunCtx *run_info = RedisAI_DAGRunCtxCreate(); diff --git a/tests/module/DAG_utils.h b/tests/module/DAG_utils.h index fbaa40c3b..f68a94736 100644 --- a/tests/module/DAG_utils.h +++ b/tests/module/DAG_utils.h @@ -11,6 +11,8 @@ typedef struct RAI_RunResults { RAI_Error *error; } RAI_RunResults; +int testLoadTensor(RedisModuleCtx *ctx); + int testModelRunOpError(RedisModuleCtx *ctx); int testEmptyDAGError(RedisModuleCtx *ctx); diff --git a/tests/module/LLAPI.c b/tests/module/LLAPI.c index 16e4563ac..aed94fd55 100644 --- a/tests/module/LLAPI.c +++ b/tests/module/LLAPI.c @@ -243,6 +243,10 @@ int RAI_llapi_DAGRun(RedisModuleCtx *ctx, RedisModuleString **argv, int argc) { return REDISMODULE_OK; } + // Test the case a successful and failure tensor load input to DAG. + if(testLoadTensor(ctx) != LLAPIMODULE_OK) { + return RedisModule_ReplyWithSimpleString(ctx, "LOAD tensor test failed"); + } // Test the case of a failure due to addition of a non compatible MODELRUN op. if(testModelRunOpError(ctx) != LLAPIMODULE_OK) { return RedisModule_ReplyWithSimpleString(ctx, "MODELRUN op error test failed");