# Phase2c_eval.py
#from http.client import NOT_IMPLEMENTED
#import random
#from matplotlib.pyplot import get
#import torch.nn as nn
#import torch.nn.functional as F
#from modules.ppo.ppo_wrappers import VarTargetWrapper
#import modules.gnn.nfm_gen
import torch
from modules.ppo.helpfuncs import get_super_env, CreateEnv, eval_simple, evaluate_ppo, check_custom_position_probs
from modules.rl.environments import SuperEnv
from modules.rl.rl_policy import ActionMaskedPolicySB3_PPO
from modules.ppo.models_sb3_s2v import s2v_ActorCriticPolicy, Struc2VecExtractor, DeployablePPOPolicy
from modules.ppo.models_sb3_gat2 import Gat2_ActorCriticPolicy, DeployablePPOPolicy_gat2
from sb3_contrib import MaskablePPO
from modules.gnn.construct_trainsets import ConstructTrainSet, get_train_configs
from modules.sim.simdata_utils import SimulateInteractiveMode_PPO, SimulateAutomaticMode_PPO
device = 'cuda' if torch.cuda.is_available() else 'cpu'
def ManualEval(config):
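    """Manually evaluate a trained MaskablePPO escape policy.

    Loads the best saved model for the given config, wraps it in a
    deployable action-masked policy, and runs one of the evaluation modes
    below. World loading and test options are selected by (un)commenting
    the corresponding blocks.
    """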
    # Path to the best model from the PPO experiment.
    # Alternatively, derive the config from a registered run name:
    # runname = 'SuperSet_noU'  # or 'test'
    # train_configs = get_train_configs(runname, load_trainset=False)
    # seed = 0
    # config = train_configs[runname]
    logdir = config['logdir']
    modeldir = logdir + '/SEED' + str(config['seed0']) + '/saved_models'
    print(logdir)
    # OPTIONS TO LOAD WORLDS:
    ## 1. 3x3 graph permutations
    # config['solve_select'] = 'solvable'
    # env, _ = get_super_env(Uselected=[2,3], Eselected=[4], config=config)
    ## 2. Set of specific worlds
    # global_env = []
    # world_names = [
    #     'Manhattan5x5_FixedEscapeInit',
    #     'Manhattan5x5_VariableEscapeInit',
    #     'MetroU3_e17tborder_FixedEscapeInit',
    #     'MetroU3_e1t31_FixedEscapeInit',
    #     'SparseManhattan5x5']
    # for w in world_names:
    #     env = CreateEnv(w, max_nodes=config['max_nodes'], var_targets=[4,4])
    #     global_env.append(env)
    # env = SuperEnv(global_env, hashint2env=None, max_possible_num_nodes=33)  # probs=[1,10,1,1,1,1,1,1]
    ## 3. Individual environment
    # env = CreateEnv('MetroU3_e17tborder_FixedEscapeInit', max_nodes=config['max_nodes'], var_targets=None, remove_world_pool=False)  # var_targets=[4,4]
    # env = CreateEnv('MetroU3_e1t31_FixedEscapeInit', max_nodes=33, nfm_func_name=config['nfm_func'], var_targets=[1,1], remove_world_pool=True)
    maxnodes = 25
    env = CreateEnv('Manhattan5x5_FixedEscapeInit', max_nodes=maxnodes, nfm_func_name=config['nfm_func'], var_targets=None, remove_world_pool=False, apply_wrappers=True)
    # env = CreateEnv('Manhattan3x3_WalkAround', max_nodes=maxnodes, nfm_func_name=config['nfm_func'], var_targets=None, remove_world_pool=False, apply_wrappers=True)
    # env = CreateEnv('NWB_test_VariableEscapeInit', max_nodes=975, nfm_func_name=config['nfm_func'], var_targets=None, remove_world_pool=False, apply_wrappers=True)
    # env = CreateEnv('NWB_test_VariableEscapeInit', max_nodes=975, nfm_func_name=config['nfm_func'], var_targets=None, remove_world_pool=True)
    ## 4. Pre-defined training set for PPO experiments
    # env, _ = ConstructTrainSet(config)
    ## Load the pre-saved model and wrap it in a deployable action-masked policy
    saved_model = MaskablePPO.load(modeldir + "/model_best")
    if config['qnet'] == 's2v':
        # saved_policy = s2v_ActorCriticPolicy.load(modeldir + "/policy_last")
        saved_policy_deployable = DeployablePPOPolicy(env, saved_model.policy)
        ppo_policy = ActionMaskedPolicySB3_PPO(saved_policy_deployable, deterministic=True)
    elif config['qnet'] == 'gat2':
        # saved_policy = Gat2_ActorCriticPolicy.load(modeldir + "/policy_last")
        saved_policy_deployable = DeployablePPOPolicy_gat2(env, saved_model.policy, max_num_nodes=maxnodes)
        ppo_policy = ActionMaskedPolicySB3_PPO(saved_policy_deployable, deterministic=True)
    else:
        raise ValueError('Unknown qnet type: ' + str(config['qnet']))  # avoids ppo_policy being unbound below
    # OPTIONS TO PERFORM TESTS
    ## 1. Evaluate a specific constellation on the graph
    ## Metro example: left turn or right turn
    # check_custom_position_probs(env, saved_model.policy, hashint=None, entry=None, targetnodes=[13,22,23,29], epath=[17], upaths=[[23,22],[30,27],[32,7]], max_nodes=33, logdir=logdir)
    # check_custom_position_probs(env, saved_model.policy, hashint=None, entry=None, targetnodes=[13,22,23,29], epath=[17], upaths=[[12,13],[30,27],[32,7]], max_nodes=33, logdir=logdir)
    ## Metro example: long-range shortest path
    # check_custom_position_probs(env, saved_model.policy, hashint=None, entry=None, targetnodes=[31], epath=[1], upaths=[[14]], max_nodes=33, logdir=logdir)
    # check_custom_position_probs(env, saved_model.policy, hashint=None, entry=None, targetnodes=[31], epath=[1], upaths=[[14,17]], max_nodes=33, logdir=logdir)
    ## Metro example: long-range shortest path with one pursuer
    # epath = [1,5,6,7,14,18,19,25,29,31]
    # upaths = []  # [[9,8,7]]
    # check_custom_position_probs(env, saved_model.policy, hashint=None, entry=None, targetnodes=[31], epath=epath, upaths=upaths, max_nodes=33, logdir=logdir)
    ## 2. Run interactive simulation (plots are updated in the results folder)
    # while True:
    #     a = SimulateInteractiveMode_PPO(env, model=saved_model, t_suffix=True)
    #     if a == 'Q':
    #         break
    ## 3. Run automated simulation (stepping)
    while True:
        entries = None  # e.g. [5012,218,3903]
        a = SimulateAutomaticMode_PPO(env, ppo_policy, t_suffix=False, entries=entries)
        if a == 'Q':
            break
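
# Minimal usage sketch, assuming 'SuperSet_noU' (taken from the commented-out
# example above) is a run name registered in get_train_configs; substitute a
# run name from your own experiments.
if __name__ == '__main__':
    runname = 'SuperSet_noU'
    train_configs = get_train_configs(runname, load_trainset=False)
    ManualEval(train_configs[runname])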