-
Notifications
You must be signed in to change notification settings - Fork 5
/
Copy pathtests.py
123 lines (90 loc) · 3.39 KB
/
tests.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
import unittest, copy, time
import torch
from problem import generators_v2 as gen
from dnc_arity_list import DNC
import problem.my_air_cargo_problems as mac
from problem.lp_utils import (
decode_state,
)
import run
from visualize import wuddido as viz
def test_solve_with_logic(data):
print('\nWITH LOGICAL HEURISTICS')
print('GOAL', data.current_problem.goal, '\n')
start = time.time()
f = 0
for n in range(20):
log_actions = data.best_logic(data.get_raw_actions(mode='all'))
if log_actions != []:
print('chosen', log_actions[0])
data.send_action(data.expr_to_vec(log_actions[0]))
else:
f = n
break
end = time.time()
print('DONE in {} steps, {:0.4f} s'.format(f, end - start))
def test_solve_with_algo(data):
print('\nWITH ASTAR')
print('GOAL', data.current_problem.goal, '\n')
f = 0
start2 = time.time()
for n in range(20):
actions = data.get_raw_actions(mode='best')
# print(actions)
if actions != []:
print('chosen', actions[0])
data.send_action(data.expr_to_vec(actions[0]))
else:
f = n
break
end2 = time.time()
print('DONE in {} steps, {:0.4f} s'.format(f, end2 - start2 ))
# Keyword arguments for gen.AirCargoData used by TestVis: a problem with
# 2 planes, 2 cargo and 2 airports, CPU only (cuda=False), batch size 1.
# NOTE(review): presumably these must match the configuration the loaded
# checkpoint was trained with — confirm before changing any value.
sample_args = dict(
    num_plane=2,
    num_cargo=2,
    num_airport=2,
    one_hot_size=[4, 6, 4, 6, 4, 6],
    plan_phase=6,
    cuda=False,
    batch_size=1,
    encoding=2,
    solve=True,
    mapping=None,
)
class Misc(unittest.TestCase):
    """Ad-hoc experiments comparing the logical-heuristic and A* solvers.

    The experiment methods are prefixed ``est_`` instead of ``test_``, so
    unittest's default loader (which collects methods starting with
    ``test``) does not run them; rename one to ``test_...`` to enable it.
    """

    def setUp(self):
        def make_spec(size):
            # A fresh dict (and fresh one_hot_size list) per call, so the two
            # specs never share mutable state.
            return {'solve': True, 'mapping': None,
                    'num_plane': size, 'one_hot_size': [4, 6, 4, 6, 4, 6],
                    'num_airport': size, 'plan_phase': 6,
                    'encoding': 2, 'batch_size': 1, 'num_cargo': size}

        self.dataspec = make_spec(2)
        self.dataspec2 = make_spec(3)

    def est_cache(self):
        """Solve two fresh problems with both solvers, each on its own deep copy."""
        data = gen.AirCargoData(**self.dataspec)
        for round_no in (1, 2):
            if round_no == 2:
                print('\n\n ROUND 2')
            data.make_new_problem()
            for solver in (test_solve_with_logic, test_solve_with_algo):
                # Deep copy so each solver starts from the same untouched state.
                solver(copy.deepcopy(data))

    def est_searches(self):
        """Decode the initial state of air-cargo problem 1 (no assertions)."""
        cargo_problem = mac.air_cargo_p1()
        decode_state(cargo_problem.initial, cargo_problem.state_map)
class TestVis(unittest.TestCase):
    """Visualisation smoke tests driven by a pretrained DNC checkpoint.

    ``setUp`` loads a saved state dict from
    ``./models/<folder>/checkpts/<iter>/dnc_model.pkl`` (relative to the
    current working directory) into a freshly built ``DNC``. If that file
    cannot be opened, the tests are skipped rather than erroring.
    """

    def setUp(self):
        self.folder = '1512692566_clip_cont2_40'
        checkpoint_iter = 109  # was `iter`, which shadowed the builtin
        self.base = './models/{}/checkpts/{}/dnc_model.pkl'
        checkpoint_path = self.base.format(self.folder, checkpoint_iter)
        try:
            # Remap every storage to the CPU. sample_args sets cuda=False,
            # and a checkpoint saved from a GPU run would otherwise fail to
            # load on a CPU-only machine. The lambda form works on old and
            # new torch; load_state_dict copies the values into the model's
            # own parameters anyway.
            state_dict = torch.load(checkpoint_path,
                                    map_location=lambda storage, loc: storage)
        except IOError:  # FileNotFoundError on py3 is a subclass of IOError/OSError
            self.skipTest('DNC checkpoint not found: {}'.format(checkpoint_path))
        self.data = gen.AirCargoData(**sample_args)
        args = run.dnc_args.copy()
        # NOTE(review): output_size <- nn_in_size and word_len <- nn_out_size
        # look crossed over. They are kept as-is because they must match the
        # shapes the checkpoint was trained with — confirm against run.py.
        args['output_size'] = self.data.nn_in_size
        args['word_len'] = self.data.nn_out_size
        self.Dnc = DNC(**args)
        self.Dnc.load_state_dict(state_dict)

    def test_show_state(self):
        # Smoke test: ix_to_color must accept a random (39, 1) tensor.
        # NOTE(review): 39 presumably matches some state-vector size — confirm.
        rand_vec = torch.randn(39, 1)
        viz.ix_to_color(rand_vec)

    def test_run(self):
        # Smoke test: record one step of the loaded DNC on self.data and
        # render its usage visualisation; no assertions on the output.
        record = viz.recorded_step(self.data, self.Dnc)
        viz.make_usage_viz(record)
if __name__ == '__main__':
    # Run every unittest.TestCase in this module; module='__main__' is the
    # documented default, spelled out for clarity.
    unittest.main(module='__main__')