-
Notifications
You must be signed in to change notification settings - Fork 27
/
Copy pathinference.py
53 lines (48 loc) · 1.57 KB
/
inference.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
"""
Run a trained agent and get generated maps
"""
import model
from stable_baselines import PPO2
import time
from utils import make_vec_envs
def infer(game, representation, model_path, **kwargs):
    """
    Load a trained PPO2 agent and run it in the requested environment.

    Arguments:
     - game: problem name. 'binary', 'zelda' and 'sokoban' select a policy
       class on the `model` module and a 'cropped_size'; any other value
       leaves both untouched.
     - representation: combined with `game` to build the environment id
       '<game>-<representation>-v0', and also passed to make_vec_envs().
     - model_path: path handed to PPO2.load().
     - kwargs: forwarded to make_vec_envs() after 'cropped_size' (for the
       known games) and 'render' have been set. Two keys are also read here:
         - trials: number of episodes to run (default 1).
         - verbose: when truthy, print info[0] after every step
           (default False).
    """
    env_name = '{}-{}-v0'.format(game, representation)
    # Swap the module-level policy class before the agent is loaded.
    if game == "binary":
        model.FullyConvPolicy = model.FullyConvPolicyBigMap
        kwargs['cropped_size'] = 28
    elif game == "zelda":
        model.FullyConvPolicy = model.FullyConvPolicyBigMap
        kwargs['cropped_size'] = 22
    elif game == "sokoban":
        model.FullyConvPolicy = model.FullyConvPolicySmallMap
        kwargs['cropped_size'] = 10
    kwargs['render'] = True

    agent = PPO2.load(model_path)
    env = make_vec_envs(env_name, representation, None, 1, **kwargs)
    # NOTE(review): reset() is deliberately left as two consecutive calls, as
    # in the original; it is unclear whether the first call is needed (e.g.
    # for renderer set-up) -- confirm before collapsing into one.
    obs = env.reset()
    obs = env.reset()

    trials = kwargs.get('trials', 1)
    verbose = kwargs.get('verbose', False)
    for _trial in range(trials):
        # The done flag must be cleared for every trial; otherwise only the
        # first trial steps the environment and later ones are skipped.
        # Assumes make_vec_envs returns a stable-baselines VecEnv, which
        # resets a finished sub-environment itself so `obs` already holds the
        # first observation of the next episode -- TODO confirm.
        dones = False
        while not dones:
            action, _ = agent.predict(obs)
            obs, _, dones, info = env.step(action)
            if verbose:
                print(info[0])
        # Short pause between episodes.
        time.sleep(0.2)
################################## MAIN ########################################
# Default run configuration: the problem / representation pair to load.
game, representation = 'binary', 'narrow'

# Checkpoint location follows models/<game>/<representation>/model_1.pkl.
model_path = 'models/{}/{}/model_1.pkl'.format(game, representation)

# Extra options passed to infer() as keyword arguments.
kwargs = dict(change_percentage=0.4, trials=1, verbose=True)

if __name__ == '__main__':
    # Run inference only when this file is executed as a script.
    infer(game, representation, model_path, **kwargs)