@@ -99,10 +99,11 @@ RAI_Tensor *RAI_TensorNew(void) {
99
99
RAI_Tensor * ret = RedisModule_Calloc (1 , sizeof (* ret ));
100
100
ret -> refCount = 1 ;
101
101
ret -> len = LEN_UNKOWN ;
102
+ return ret ;
102
103
}
103
104
104
105
RAI_Tensor * RAI_TensorCreateWithDLDataType (DLDataType dtype , long long * dims , int ndims ,
105
- int tensorAllocMode ) {
106
+ bool empty ) {
106
107
107
108
size_t dtypeSize = Tensor_DataTypeSize (dtype );
108
109
if (dtypeSize == 0 ) {
@@ -124,20 +125,14 @@ RAI_Tensor *RAI_TensorCreateWithDLDataType(DLDataType dtype, long long *dims, in
124
125
}
125
126
126
127
DLDevice device = (DLDevice ){.device_type = kDLCPU , .device_id = 0 };
127
- void * data = NULL ;
128
- switch (tensorAllocMode ) {
129
- case TENSORALLOC_ALLOC :
130
- data = RedisModule_Alloc (len * dtypeSize );
131
- break ;
132
- case TENSORALLOC_CALLOC :
128
+
129
+ // If we return an empty tensor, we initialize the data with zeros to avoid security
130
+ // issues. Otherwise, we only allocate without initializing (for better performance)
131
+ void * data ;
132
+ if (empty ) {
133
133
data = RedisModule_Calloc (len , dtypeSize );
134
- break ;
135
- case TENSORALLOC_NONE :
136
- /* shallow copy no alloc */
137
- default :
138
- /* assume TENSORALLOC_NONE
139
- shallow copy no alloc */
140
- break ;
134
+ } else {
135
+ data = RedisModule_Alloc (len * dtypeSize );
141
136
}
142
137
143
138
ret -> tensor = (DLManagedTensor ){.dl_tensor = (DLTensor ){.device = device ,
@@ -214,27 +209,11 @@ RAI_Tensor *_TensorCreateWithDLDataTypeAndRString(DLDataType dtype, size_t dtype
214
209
return ret ;
215
210
}
216
211
217
- RAI_Tensor * RAI_TensorCreate (const char * dataType , long long * dims , int ndims , int hasdata ) {
212
+ // Important note: the tensor data must be initialized after the creation.
213
+ RAI_Tensor * RAI_TensorCreate (const char * dataType , long long * dims , int ndims ) {
218
214
DLDataType dtype = RAI_TensorDataTypeFromString (dataType );
219
- return RAI_TensorCreateWithDLDataType (dtype , dims , ndims , TENSORALLOC_ALLOC );
220
- }
221
-
222
- #if 0
223
- void RAI_TensorMoveFrom (RAI_Tensor * dst , RAI_Tensor * src ) {
224
- if (-- dst -> refCount <= 0 ){
225
- RedisModule_Free (t -> tensor .shape );
226
- if (t -> tensor .strides ) {
227
- RedisModule_Free (t -> tensor .strides );
228
- }
229
- RedisModule_Free (t -> tensor .data );
230
- RedisModule_Free (t );
231
- }
232
- dst -> tensor .ctx = src -> tensor .ctx ;
233
- dst -> tensor .data = src -> tensor .data ;
234
-
235
- dst -> refCount = 1 ;
215
+ return RAI_TensorCreateWithDLDataType (dtype , dims , ndims , false);
236
216
}
237
- #endif
238
217
239
218
RAI_Tensor * RAI_TensorCreateByConcatenatingTensors (RAI_Tensor * * ts , long long n ) {
240
219
@@ -273,7 +252,7 @@ RAI_Tensor *RAI_TensorCreateByConcatenatingTensors(RAI_Tensor **ts, long long n)
273
252
274
253
DLDataType dtype = RAI_TensorDataType (ts [0 ]);
275
254
276
- RAI_Tensor * ret = RAI_TensorCreateWithDLDataType (dtype , dims , ndims , TENSORALLOC_ALLOC );
255
+ RAI_Tensor * ret = RAI_TensorCreateWithDLDataType (dtype , dims , ndims , false );
277
256
278
257
for (long long i = 0 ; i < n ; i ++ ) {
279
258
memcpy (RAI_TensorData (ret ) + batch_offsets [i ] * sample_size * dtype_size ,
@@ -300,7 +279,7 @@ RAI_Tensor *RAI_TensorCreateBySlicingTensor(RAI_Tensor *t, long long offset, lon
300
279
301
280
DLDataType dtype = RAI_TensorDataType (t );
302
281
303
- RAI_Tensor * ret = RAI_TensorCreateWithDLDataType (dtype , dims , ndims , TENSORALLOC_ALLOC );
282
+ RAI_Tensor * ret = RAI_TensorCreateWithDLDataType (dtype , dims , ndims , false );
304
283
305
284
memcpy (RAI_TensorData (ret ), RAI_TensorData (t ) + offset * sample_size * dtype_size ,
306
285
len * sample_size * dtype_size );
@@ -329,14 +308,14 @@ int RAI_TensorDeepCopy(RAI_Tensor *t, RAI_Tensor **dest) {
329
308
330
309
DLDataType dtype = RAI_TensorDataType (t );
331
310
332
- RAI_Tensor * ret = RAI_TensorCreateWithDLDataType (dtype , dims , ndims , TENSORALLOC_ALLOC );
311
+ RAI_Tensor * ret = RAI_TensorCreateWithDLDataType (dtype , dims , ndims , false );
333
312
334
313
memcpy (RAI_TensorData (ret ), RAI_TensorData (t ), sample_size * dtype_size );
335
314
* dest = ret ;
336
315
return 0 ;
337
316
}
338
317
339
- // Beware: this will take ownership of dltensor
318
+ // Beware: this will take ownership of dltensor.
340
319
RAI_Tensor * RAI_TensorCreateFromDLTensor (DLManagedTensor * dl_tensor ) {
341
320
342
321
RAI_Tensor * ret = RAI_TensorNew ();
@@ -419,19 +398,15 @@ int RAI_TensorSetValueFromLongLong(RAI_Tensor *t, long long i, long long val) {
419
398
case 8 :
420
399
((int8_t * )data )[i ] = val ;
421
400
break ;
422
- break ;
423
401
case 16 :
424
402
((int16_t * )data )[i ] = val ;
425
403
break ;
426
- break ;
427
404
case 32 :
428
405
((int32_t * )data )[i ] = val ;
429
406
break ;
430
- break ;
431
407
case 64 :
432
408
((int64_t * )data )[i ] = val ;
433
409
break ;
434
- break ;
435
410
default :
436
411
return 0 ;
437
412
}
@@ -440,19 +415,15 @@ int RAI_TensorSetValueFromLongLong(RAI_Tensor *t, long long i, long long val) {
440
415
case 8 :
441
416
((uint8_t * )data )[i ] = val ;
442
417
break ;
443
- break ;
444
418
case 16 :
445
419
((uint16_t * )data )[i ] = val ;
446
420
break ;
447
- break ;
448
421
case 32 :
449
422
((uint32_t * )data )[i ] = val ;
450
423
break ;
451
- break ;
452
424
case 64 :
453
425
((uint64_t * )data )[i ] = val ;
454
426
break ;
455
- break ;
456
427
default :
457
428
return 0 ;
458
429
}
@@ -642,7 +613,6 @@ int RAI_parseTensorSetArgs(RedisModuleString **argv, int argc, RAI_Tensor **t, i
642
613
643
614
const char * fmtstr ;
644
615
int datafmt = TENSOR_NONE ;
645
- int tensorAllocMode = TENSORALLOC_CALLOC ;
646
616
size_t ndims = 0 ;
647
617
long long len = 1 ;
648
618
long long * dims = (long long * )array_new (long long , 1 );
@@ -656,7 +626,6 @@ int RAI_parseTensorSetArgs(RedisModuleString **argv, int argc, RAI_Tensor **t, i
656
626
remaining_args = argc - 1 - argpos ;
657
627
if (!strcasecmp (opt , "BLOB" )) {
658
628
datafmt = TENSOR_BLOB ;
659
- tensorAllocMode = TENSORALLOC_CALLOC ;
660
629
// if we've found the dataformat there are no more dimensions
661
630
// check right away if the arity is correct
662
631
if (remaining_args != 1 && enforceArity == 1 ) {
@@ -669,7 +638,6 @@ int RAI_parseTensorSetArgs(RedisModuleString **argv, int argc, RAI_Tensor **t, i
669
638
break ;
670
639
} else if (!strcasecmp (opt , "VALUES" )) {
671
640
datafmt = TENSOR_VALUES ;
672
- tensorAllocMode = TENSORALLOC_CALLOC ;
673
641
// if we've found the dataformat there are no more dimensions
674
642
// check right away if the arity is correct
675
643
if (remaining_args != len && enforceArity == 1 ) {
@@ -699,7 +667,8 @@ int RAI_parseTensorSetArgs(RedisModuleString **argv, int argc, RAI_Tensor **t, i
699
667
RedisModuleString * rstr = argv [argpos ];
700
668
* t = _TensorCreateWithDLDataTypeAndRString (datatype , datasize , dims , ndims , rstr , error );
701
669
} else {
702
- * t = RAI_TensorCreateWithDLDataType (datatype , dims , ndims , tensorAllocMode );
670
+ bool is_empty = (datafmt == TENSOR_NONE );
671
+ * t = RAI_TensorCreateWithDLDataType (datatype , dims , ndims , is_empty );
703
672
}
704
673
if (!(* t )) {
705
674
array_free (dims );
0 commit comments