Skip to content

Commit 1707686

Browse files
committed
added multiple lists test
1 parent 52d8ff8 commit 1707686

File tree

2 files changed

+51
-1
lines changed

2 files changed

+51
-1
lines changed

tests/flow/test_data/script.txt

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,3 +3,6 @@ def bar(a, b):
33

44
def bar_variadic(a, args : List[Tensor]):
55
return args[0] + args[1]
6+
7+
def bar_two_lists(a: List[Tensor], b:List[Tensor]):
8+
return a[0] + b[0]

tests/flow/tests_pytorch.py

Lines changed: 48 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -389,6 +389,53 @@ def test_pytorch_scriptexecute_list_input(env):
389389
env.assertEqual(values2, values)
390390

391391

392+
def test_pytorch_scriptexecute_multiple_list_input(env):
393+
if not TEST_PT:
394+
env.debugPrint("skipping {} since TEST_PT=0".format(sys._getframe().f_code.co_name), force=True)
395+
return
396+
397+
con = env.getConnection()
398+
399+
script = load_file_content('script.txt')
400+
401+
ret = con.execute_command('AI.SCRIPTSET', 'myscript{$}', DEVICE, 'TAG', 'version1', 'SOURCE', script)
402+
env.assertEqual(ret, b'OK')
403+
404+
ret = con.execute_command('AI.TENSORSET', 'a{$}', 'FLOAT', 2, 2, 'VALUES', 2, 3, 2, 3)
405+
env.assertEqual(ret, b'OK')
406+
ret = con.execute_command('AI.TENSORSET', 'b{$}', 'FLOAT', 2, 2, 'VALUES', 2, 3, 2, 3)
407+
env.assertEqual(ret, b'OK')
408+
409+
ensureSlaveSynced(con, env)
410+
411+
for _ in range( 0,100):
412+
ret = con.execute_command('AI.SCRIPTEXECUTE', 'myscript{$}', 'bar_two_lists', 'KEYS', 1, '{$}', 'LIST_INPUTS', 1, 'a{$}', 'LIST_INPUTS', 1, 'b{$}', 'OUTPUTS', 1, 'c{$}')
413+
env.assertEqual(ret, b'OK')
414+
415+
ensureSlaveSynced(con, env)
416+
417+
info = con.execute_command('AI.INFO', 'myscript{$}')
418+
info_dict_0 = info_to_dict(info)
419+
420+
env.assertEqual(info_dict_0['key'], 'myscript{$}')
421+
env.assertEqual(info_dict_0['type'], 'SCRIPT')
422+
env.assertEqual(info_dict_0['backend'], 'TORCH')
423+
env.assertEqual(info_dict_0['tag'], 'version1')
424+
env.assertTrue(info_dict_0['duration'] > 0)
425+
env.assertEqual(info_dict_0['samples'], -1)
426+
env.assertEqual(info_dict_0['calls'], 100)
427+
env.assertEqual(info_dict_0['errors'], 0)
428+
429+
values = con.execute_command('AI.TENSORGET', 'c{$}', 'VALUES')
430+
env.assertEqual(values, [b'4', b'6', b'4', b'6'])
431+
432+
ensureSlaveSynced(con, env)
433+
434+
if env.useSlaves:
435+
con2 = env.getSlaveConnection()
436+
values2 = con2.execute_command('AI.TENSORGET', 'c{$}', 'VALUES')
437+
env.assertEqual(values2, values)
438+
392439
def test_pytorch_scriptexecute_errors(env):
393440
if not TEST_PT:
394441
env.debugPrint("skipping {} since TEST_PT=0".format(sys._getframe().f_code.co_name), force=True)
@@ -484,7 +531,7 @@ def test_pytorch_scriptexecute_variadic_errors(env):
484531
check_error(env, con, 'AI.SCRIPTEXECUTE', 'ket{$}', 'bar_variadic', 'KEYS', 1 , '{$}', 'INPUTS', 'LIST_INPUTS', 'OUTPUTS')
485532

486533
check_error(env, con, 'AI.SCRIPTEXECUTE', 'ket{$}', 'bar_variadic', 'KEYS', 1 , '{$}', 'LIST_INPUTS', 1, 'a{$}', 'LIST_INPUTS', 1, 'b{$}' 'OUTPUTS')
487-
534+
488535

489536
def test_pytorch_scriptinfo(env):
490537
if not TEST_PT:

0 commit comments

Comments
 (0)