Skip to content

Commit ffbf441

Browse files
authored
Merge branch 'master' into lite_build
2 parents 88a59fa + 1f0b9b9 commit ffbf441

File tree

1 file changed

+10
-5
lines changed

1 file changed

+10
-5
lines changed

src/execution/parsing/dag_parser.c

Lines changed: 10 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -78,7 +78,7 @@ static int _ParseDAGLoadArgs(RedisModuleCtx *ctx, RedisModuleString **argv, int
7878

7979
/**
8080
* DAGRUN Building Block to parse [PERSIST <nkeys> key1 key2... ]
81-
*
81+
* @param ctx Context in which Redis modules operate
8282
* @param argv Redis command arguments, as an array of strings
8383
* @param argc Redis command number of arguments
8484
* @param persistTensorsNames local hash table containing DAG's
@@ -87,8 +87,8 @@ static int _ParseDAGLoadArgs(RedisModuleCtx *ctx, RedisModuleString **argv, int
8787
* argument after the chaining operator is not considered
8888
* @return processed number of arguments on success, or -1 if the parsing failed
8989
*/
90-
static int _ParseDAGPersistArgs(RedisModuleString **argv, int argc, AI_dict *persistTensorsNames,
91-
RAI_Error *err) {
90+
static int _ParseDAGPersistArgs(RedisModuleCtx *ctx, RedisModuleString **argv, int argc,
91+
AI_dict *persistTensorsNames, RAI_Error *err) {
9292
if (argc < 3) {
9393
RAI_SetError(err, RAI_EDAGBUILDER,
9494
"ERR missing arguments after PERSIST keyword in DAG command");
@@ -106,11 +106,16 @@ static int _ParseDAGPersistArgs(RedisModuleString **argv, int argc, AI_dict *per
106106
// Go over the given args and save the tensor key names to persist.
107107
int number_keys_to_persist = 0;
108108
for (size_t argpos = 2; (argpos < argc) && (number_keys_to_persist < n_keys); argpos++) {
109-
const char *arg_string = RedisModule_StringPtrLen(argv[argpos], NULL);
110109
if (AI_dictFind(persistTensorsNames, (void *)argv[argpos]) != NULL) {
111110
RAI_SetError(err, RAI_EDAGBUILDER, "ERR PERSIST keys must be unique");
112111
return -1;
113112
}
113+
if (!VerifyKeyInThisShard(ctx, argv[argpos])) { // Relevant for enterprise cluster.
114+
RAI_SetError(
115+
err, RAI_EDAGBUILDER,
116+
"ERR Found keys to persist in DAG command that don't hash to the local shard");
117+
return -1;
118+
}
114119
AI_dictAdd(persistTensorsNames, (void *)argv[argpos], NULL);
115120
number_keys_to_persist++;
116121
}
@@ -267,7 +272,7 @@ int DAGInitialParsing(RedisAI_RunInfo *rinfo, RedisModuleCtx *ctx, RedisModuleSt
267272
/* Store the keys to persist in persistTensors dict, these keys will
268273
* be mapped later to the indices in the dagSharedTensors array in which the
269274
* tensors to persist will be found by the end of the DAG run. */
270-
const int parse_result = _ParseDAGPersistArgs(&argv[arg_pos], argc - arg_pos,
275+
const int parse_result = _ParseDAGPersistArgs(ctx, &argv[arg_pos], argc - arg_pos,
271276
rinfo->persistTensors, rinfo->err);
272277
if (parse_result <= 0)
273278
return REDISMODULE_ERR;

0 commit comments

Comments
 (0)