combine_bc_mb.py
import os
import os.path as osp
import random

import numpy as np
import pyspiel
import torch
from absl import app
from absl import flags
from open_spiel.python import policy
from open_spiel.python.algorithms import exploitability

from policy_wrapper import deep_cfr_policy
from network.ensemble_mb_bc import mix_policy

FLAGS = flags.FLAGS
flags.DEFINE_integer("seed", 1, "Seed.")
flags.DEFINE_integer("proportion", 10, "dataset proportion")
flags.DEFINE_string("game_name", "kuhn_poker", "Game name.")
flags.DEFINE_integer("n_players", 4, "The number of players.")
# Behavior Clone Strategy Location
flags.DEFINE_string("bc_policy_location", "mix_offline_dataset_behavior_clone_policy/", "offline data location")
flags.DEFINE_string("bc_policy_file_name",
"/seed_1_game_kuhn_poker_players_4_hidden_layer_64_buffer_10000_lr_0.05_train_epoch_5000_batch_size_128_policy.pkl",
"Behavior Clone Strategy Location")
flags.DEFINE_string("mb_policy_location", "mb_deep_cfr/mb_method_results/mb_deep_cfr_train_policy/", "offline data location")
flags.DEFINE_string("mb_policy_file_name",
"/policy_train_data_10000_train_epoch_5000_proportion_10_conv_0.6931734493672149.pkl",
"Behavior Clone Strategy Location")
flags.DEFINE_integer("replay_buffer", 500, "env model replay buffer")
flags.DEFINE_string("device", "cpu", "device type")
FLAGS.device = "cuda" if torch.cuda.is_available() else "cpu"
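# Example invocation (the paths above are the defaults; adjust them to your own
# checkpoint layout):
#   python combine_bc_mb.py --game_name=kuhn_poker --n_players=4 --proportion=10 --seed=1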
def tabular_policy_from_callable(game, behavior_policy, players=None):
    """Converts a list of per-player callable policies into one TabularPolicy."""
    tabular_policy = policy.TabularPolicy(game, players)
    for state_index, state in enumerate(tabular_policy.states):
        cur_player = state.current_player()
        legal_actions = state.legal_actions(cur_player)
        info_state_vector = np.array(state.information_state_tensor())
        if len(info_state_vector.shape) == 1:
            info_state_vector = np.expand_dims(info_state_vector, axis=0)
        info_state_vector = torch.FloatTensor(info_state_vector).to(FLAGS.device)
        # Query the acting player's policy and keep probabilities of legal actions only.
        strategy = behavior_policy[cur_player].step(info_state_vector).squeeze(0).tolist()
        action_probabilities = {action: strategy[action] for action in legal_actions}
        infostate_policy = [action_probabilities.get(action, 0.)
                            for action in range(game.num_distinct_actions())]
        tabular_policy.action_probability_array[state_index, :] = infostate_policy
    return tabular_policy
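# For reference, mix_policy (from network.ensemble_mb_bc) is assumed here to blend the
# two strategies linearly per information state, roughly:
#   pi_mix(a|s) = bc_weight * pi_bc(a|s) + (1 - bc_weight) * pi_mb(a|s)
# so bc_weight = 1.0 recovers pure behavior cloning and 0.0 the pure model-based policy.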
def get_result_dir():
    result_dir = "results_bc/" + FLAGS.game_name + "_" + str(FLAGS.n_players) + "_players"
    if not osp.exists(result_dir):
        os.makedirs(result_dir)
    result_name = "game_{}_players_{}_replay_buffer_{}.txt".format(
        FLAGS.game_name, FLAGS.n_players, FLAGS.replay_buffer)
    return osp.join(result_dir, result_name)
def setup_seed(seed):
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    np.random.seed(seed)
    random.seed(seed)
    torch.backends.cudnn.deterministic = True
def main(argv):
    if len(argv) > 1:
        raise app.UsageError("Too many command-line arguments.")
    # Pick the device after flags are parsed.
    FLAGS.device = "cuda" if torch.cuda.is_available() else "cpu"
    # Set seed.
    setup_seed(FLAGS.seed)
    # Load liar's dice game (requires a "numdice" flag, which is not defined above):
    # game = pyspiel.load_game(FLAGS.game_name, {"players": FLAGS.n_players, "numdice": FLAGS.numdice, "dice_sides": 6})
    # Load poker game.
    game = pyspiel.load_game(FLAGS.game_name, {"players": FLAGS.n_players})
    # Load phantom tic-tac-toe game:
    # game = pyspiel.load_game(FLAGS.game_name, {"obstype": "reveal-nothing"})
    results = []
    min_weights = []
    # Sweep over BC/MB mixing weights and record the NashConv of each mixture.
    weights_list = [0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0]
    conv_list = []
    for w in weights_list:
        policy_list = []
        for index in range(FLAGS.n_players):
            bc_location = (FLAGS.bc_policy_location + FLAGS.game_name + "_"
                           + str(FLAGS.n_players) + "_players/" + str(index)
                           + "_player_policy_proportion_" + str(FLAGS.proportion)
                           + FLAGS.bc_policy_file_name)
            mb_location = (FLAGS.mb_policy_location + FLAGS.game_name + "_"
                           + str(FLAGS.n_players) + "_players" + FLAGS.mb_policy_file_name)
            deep_cfr_mb_model = deep_cfr_policy(torch.load(mb_location).to(FLAGS.device),
                                                device=FLAGS.device)
            policy_list.append(mix_policy(bc_model=torch.load(bc_location).to(FLAGS.device),
                                          mb_model=deep_cfr_mb_model, bc_weight=w))
        # Compute NashConv for this mixing weight.
        average_policy = tabular_policy_from_callable(game, policy_list)
        conv = exploitability.nash_conv(game, average_policy)
        conv_list.append(conv)
        print(conv)
    # Record the best NashConv over the sweep and the weight that achieved it.
    results.append(min(conv_list))
    min_weights.append(weights_list[conv_list.index(min(conv_list))])
    print(results)
    torch.save([results, min_weights], get_result_dir())
if __name__ == "__main__":
    app.run(main)
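# To inspect the saved sweep afterwards (a minimal sketch; the path below matches the
# default flag values):
#   results, min_weights = torch.load(
#       "results_bc/kuhn_poker_4_players/game_kuhn_poker_players_4_replay_buffer_500.txt")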