-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy path run_model.py
39 lines (27 loc) · 934 Bytes
/
run_model.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
from customized_environments.envs.my_agent import CustomAgent
import gym
from stable_baselines.deepq.policies import CnnPolicy
from stable_baselines.common.vec_env import DummyVecEnv, SubprocVecEnv
from stable_baselines import PPO2
from stable_baselines import DQN
from absl import flags
# absl defines command-line flags at import time; parse an empty argv so the
# script can run outside an absl.app entry point (pysc2-style setup).
FLAGS = flags.FLAGS
FLAGS([''])

# Which saved model to run: output directory name, algorithm, and checkpoint.
name = "dqn_mlp_std_simple"
learn_type = 'DQN'
model_iteration = 1

# create vectorized environment (stable-baselines models require a VecEnv)
env = DummyVecEnv([lambda: CustomAgent(learn_type=learn_type)])

# NOTE(review): "gym_ouput" looks like a typo for "gym_output", but the path
# must match wherever the training script saved checkpoints — confirm before
# renaming it here.
if model_iteration > 0:
    checkpoint_path = "gym_ouput/" + name + "/it" + str(model_iteration)
    if learn_type == "DQN":
        model = DQN.load(checkpoint_path, env=env)
    elif learn_type == "PPO2":
        model = PPO2.load(checkpoint_path, env=env)
    else:
        # Fix: previously this case fell through with `model` unassigned,
        # crashing later with an unrelated NameError.
        raise SystemExit("invalid learn_type: " + learn_type)
else:
    # Fix: the original used bare `exit` / `quit` (name expressions, not
    # calls), which did nothing — execution continued and crashed on the
    # undefined `model`. SystemExit stops the script with a nonzero status.
    print("invalid model_iteration")
    raise SystemExit(1)

# Run the loaded policy indefinitely; the VecEnv auto-resets finished episodes.
obs = env.reset()
while True:
    action, _states = model.predict(obs)
    obs, rewards, dones, info = env.step(action)