@@ -389,6 +389,53 @@ def test_pytorch_scriptexecute_list_input(env):
389
389
env .assertEqual (values2 , values )
390
390
391
391
392
+ def test_pytorch_scriptexecute_multiple_list_input (env ):
393
+ if not TEST_PT :
394
+ env .debugPrint ("skipping {} since TEST_PT=0" .format (sys ._getframe ().f_code .co_name ), force = True )
395
+ return
396
+
397
+ con = env .getConnection ()
398
+
399
+ script = load_file_content ('script.txt' )
400
+
401
+ ret = con .execute_command ('AI.SCRIPTSET' , 'myscript{$}' , DEVICE , 'TAG' , 'version1' , 'SOURCE' , script )
402
+ env .assertEqual (ret , b'OK' )
403
+
404
+ ret = con .execute_command ('AI.TENSORSET' , 'a{$}' , 'FLOAT' , 2 , 2 , 'VALUES' , 2 , 3 , 2 , 3 )
405
+ env .assertEqual (ret , b'OK' )
406
+ ret = con .execute_command ('AI.TENSORSET' , 'b{$}' , 'FLOAT' , 2 , 2 , 'VALUES' , 2 , 3 , 2 , 3 )
407
+ env .assertEqual (ret , b'OK' )
408
+
409
+ ensureSlaveSynced (con , env )
410
+
411
+ for _ in range ( 0 ,100 ):
412
+ ret = con .execute_command ('AI.SCRIPTEXECUTE' , 'myscript{$}' , 'bar_two_lists' , 'KEYS' , 1 , '{$}' , 'LIST_INPUTS' , 1 , 'a{$}' , 'LIST_INPUTS' , 1 , 'b{$}' , 'OUTPUTS' , 1 , 'c{$}' )
413
+ env .assertEqual (ret , b'OK' )
414
+
415
+ ensureSlaveSynced (con , env )
416
+
417
+ info = con .execute_command ('AI.INFO' , 'myscript{$}' )
418
+ info_dict_0 = info_to_dict (info )
419
+
420
+ env .assertEqual (info_dict_0 ['key' ], 'myscript{$}' )
421
+ env .assertEqual (info_dict_0 ['type' ], 'SCRIPT' )
422
+ env .assertEqual (info_dict_0 ['backend' ], 'TORCH' )
423
+ env .assertEqual (info_dict_0 ['tag' ], 'version1' )
424
+ env .assertTrue (info_dict_0 ['duration' ] > 0 )
425
+ env .assertEqual (info_dict_0 ['samples' ], - 1 )
426
+ env .assertEqual (info_dict_0 ['calls' ], 100 )
427
+ env .assertEqual (info_dict_0 ['errors' ], 0 )
428
+
429
+ values = con .execute_command ('AI.TENSORGET' , 'c{$}' , 'VALUES' )
430
+ env .assertEqual (values , [b'4' , b'6' , b'4' , b'6' ])
431
+
432
+ ensureSlaveSynced (con , env )
433
+
434
+ if env .useSlaves :
435
+ con2 = env .getSlaveConnection ()
436
+ values2 = con2 .execute_command ('AI.TENSORGET' , 'c{$}' , 'VALUES' )
437
+ env .assertEqual (values2 , values )
438
+
392
439
def test_pytorch_scriptexecute_errors (env ):
393
440
if not TEST_PT :
394
441
env .debugPrint ("skipping {} since TEST_PT=0" .format (sys ._getframe ().f_code .co_name ), force = True )
@@ -484,7 +531,7 @@ def test_pytorch_scriptexecute_variadic_errors(env):
484
531
check_error (env , con , 'AI.SCRIPTEXECUTE' , 'ket{$}' , 'bar_variadic' , 'KEYS' , 1 , '{$}' , 'INPUTS' , 'LIST_INPUTS' , 'OUTPUTS' )
485
532
486
533
check_error (env , con , 'AI.SCRIPTEXECUTE' , 'ket{$}' , 'bar_variadic' , 'KEYS' , 1 , '{$}' , 'LIST_INPUTS' , 1 , 'a{$}' , 'LIST_INPUTS' , 1 , 'b{$}' 'OUTPUTS' )
487
-
534
+
488
535
489
536
def test_pytorch_scriptinfo (env ):
490
537
if not TEST_PT :
0 commit comments