-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathagent.h
165 lines (150 loc) · 4.63 KB
/
agent.h
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
/**
* Framework for 2048 & 2048-Like Games (C++ 11)
* agent.h: Define the behavior of variants of agents including players and environments
*
* Author: Hung Guei
* Computer Games and Intelligence (CGI) Lab, NYCU, Taiwan
* https://cgilab.nctu.edu.tw/
*/
#pragma once
#include <algorithm>
#include <array>
#include <cctype>
#include <cstdint>
#include <cstdlib>
#include <fstream>
#include <map>
#include <random>
#include <sstream>
#include <string>
#include <type_traits>
#include <vector>
#include "board.h"
#include "action.h"
#include "weight.h"
/**
 * base class of every agent
 *
 * An agent owns a string-keyed property table (meta) that is filled from a
 * whitespace-separated list of "key=value" tokens. The defaults
 * "name=unknown role=unknown" are parsed before args, so tokens in args
 * overwrite them.
 */
class agent {
public:
    /**
     * parse "name=unknown role=unknown " followed by args into the property table
     * a token without '=' is stored with the whole token as both key and value,
     * because find() returns npos and npos + 1 wraps around to 0
     */
    agent(const std::string& args = "") {
        std::stringstream tokens("name=unknown role=unknown " + args);
        std::string token;
        while (tokens >> token) {
            const auto eq = token.find('=');
            meta[token.substr(0, eq)] = { token.substr(eq + 1) };
        }
    }

    virtual ~agent() {}

    /** episode hooks; the base implementations do nothing */
    virtual void open_episode(const std::string& flag = "") {}
    virtual void close_episode(const std::string& flag = "") {}

    /** the base implementation ignores the board and returns a default-constructed action */
    virtual action take_action(const board& b) {
        return action();
    }

    /** the base implementation always reports false */
    virtual bool check_for_win(const board& b) {
        return false;
    }

public:
    /** fetch a property as a string; std::map::at throws std::out_of_range for an unknown key */
    virtual std::string property(const std::string& key) const {
        return meta.at(key);
    }

    /** store a single "key=value" message in the property table (same splitting rule as the constructor) */
    virtual void notify(const std::string& msg) {
        const auto eq = msg.find('=');
        meta[msg.substr(0, eq)] = { msg.substr(eq + 1) };
    }

    /** shortcuts for the "name" and "role" properties */
    virtual std::string name() const {
        return property("name");
    }
    virtual std::string role() const {
        return property("role");
    }

protected:
    using key = std::string;

    /**
     * property value: kept as text, convertible to std::string or to any arithmetic type
     * the arithmetic conversion parses the text with std::stod and casts the result
     */
    struct value {
        std::string value;

        operator std::string() const {
            return value;
        }

        template<typename numeric, typename = typename std::enable_if<std::is_arithmetic<numeric>::value, numeric>::type>
        operator numeric() const {
            return static_cast<numeric>(std::stod(value));
        }
    };

    std::map<key, value> meta;
};
/**
 * base agent for agents with randomness
 * holds a std::default_random_engine; if the property "seed" is present,
 * the engine is seeded with that value converted to int
 */
class random_agent : public agent {
public:
    random_agent(const std::string& args = "") : agent(args) {
        auto seed = meta.find("seed");
        if (seed == meta.end()) return; // no seed supplied: keep the default-constructed engine
        engine.seed(static_cast<int>(seed->second));
    }

    virtual ~random_agent() {}

protected:
    std::default_random_engine engine;
};
/**
 * base agent for agents with weight tables and a learning rate
 *
 * optional properties, handled in this order by the constructor:
 *   init=<sizes>  append weight tables of the given sizes, see init_weights()
 *   load=<path>   replace the weight tables with those read from a file, see load_weights()
 *   alpha=<num>   learning rate (stays 0 when absent)
 * the destructor handles save=<path>, see save_weights()
 *
 * NOTE: init_weights/load_weights/save_weights are called from the constructor
 * and the destructor, where virtual dispatch resolves to the versions defined
 * in this class; overrides in derived classes are not reached from there.
 */
class weight_agent : public agent {
public:
    weight_agent(const std::string& args = "") : agent(args), alpha(0) {
        auto found = meta.find("init");
        if (found != meta.end())
            init_weights(found->second);
        found = meta.find("load");
        if (found != meta.end())
            load_weights(found->second);
        found = meta.find("alpha");
        if (found != meta.end())
            alpha = float(found->second);
    }

    virtual ~weight_agent() {
        auto found = meta.find("save");
        if (found != meta.end())
            save_weights(found->second);
    }

protected:
    /**
     * append one weight table per number found in info
     * info: sizes separated by any non-digit characters, e.g., "65536,65536"
     */
    virtual void init_weights(const std::string& info) {
        std::string res = info;
        for (char& ch : res) {
            // std::isdigit requires a value representable as unsigned char;
            // passing a negative char (non-ASCII byte) directly is undefined behavior
            if (!std::isdigit(static_cast<unsigned char>(ch))) ch = ' ';
        }
        std::stringstream in(res);
        for (size_t size; in >> size; net.emplace_back(size));
    }

    /**
     * read the weight tables from a binary file: a 32-bit table count followed by each table
     * terminates the program with exit code -1 if the file cannot be opened
     * or the table count cannot be read
     */
    virtual void load_weights(const std::string& path) {
        std::ifstream in(path, std::ios::in | std::ios::binary);
        if (!in.is_open()) std::exit(-1);
        uint32_t size = 0;
        in.read(reinterpret_cast<char*>(&size), sizeof(size));
        // an empty or truncated file leaves the count unread; stop here instead of
        // resizing the table list with a meaningless value
        if (!in) std::exit(-1);
        net.resize(size);
        for (weight& w : net) in >> w;
        in.close();
    }

    /**
     * write the weight tables to a binary file (truncating it): a 32-bit table count followed by each table
     * terminates the program with exit code -1 if the file cannot be opened
     */
    virtual void save_weights(const std::string& path) {
        std::ofstream out(path, std::ios::out | std::ios::binary | std::ios::trunc);
        if (!out.is_open()) std::exit(-1);
        uint32_t size = static_cast<uint32_t>(net.size()); // the file format stores the count in 32 bits
        out.write(reinterpret_cast<char*>(&size), sizeof(size));
        for (weight& w : net) out << w;
        out.close();
    }

protected:
    std::vector<weight> net;
    float alpha;
};
/**
 * default random environment
 * put a new random tile on a randomly chosen empty cell
 * 2-tile: 90%
 * 4-tile: 10%
 */
class random_placer : public random_agent {
public:
    random_placer(const std::string& args = "")
        : random_agent("name=place role=placer " + args),
          space({ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15 }),
          popup(0, 9) {}

    /**
     * shuffle the 16 positions and place a tile on the first one whose cell is 0
     * the tile code is 1 when the draw from [0, 9] is non-zero, otherwise 2
     * returns a default-constructed action when no cell is 0
     */
    virtual action take_action(const board& after) {
        std::shuffle(space.begin(), space.end(), engine);
        for (auto it = space.begin(); it != space.end(); ++it) {
            int pos = *it;
            const bool occupied = (after(pos) != 0);
            if (occupied) continue;
            // draw only once a free cell is found, so the engine is advanced exactly once per placement
            board::cell tile = popup(engine) ? 1 : 2;
            return action::place(pos, tile);
        }
        return action();
    }

private:
    std::array<int, 16> space;                 // candidate positions, reshuffled on every call
    std::uniform_int_distribution<int> popup;  // uniform over [0, 9]
};
/**
 * random player, i.e., slider
 * pick a legal slide uniformly at random
 */
class random_slider : public random_agent {
public:
    random_slider(const std::string& args = "")
        : random_agent("name=slide role=slider " + args),
          opcode({ 0, 1, 2, 3 }) {}

    /**
     * shuffle the four opcodes and return the first one whose slide, tried on
     * a copy of the board, does not yield a reward of -1
     * returns a default-constructed action when every opcode yields -1
     */
    virtual action take_action(const board& before) {
        std::shuffle(opcode.begin(), opcode.end(), engine);
        for (auto it = opcode.begin(); it != opcode.end(); ++it) {
            int op = *it;
            board trial(before); // slide a scratch copy so the caller's board stays untouched
            board::reward gained = trial.slide(op);
            if (gained != -1) return action::slide(op);
        }
        return action();
    }

private:
    std::array<int, 4> opcode; // candidate slide opcodes, reshuffled on every call
};