# The solver type
"""
    ValueIterationSolver <: Solver

The value iteration solver type. The following parameters can be passed as keyword arguments to the constructor:
- `max_iterations::Int64`: the maximum number of iterations value iteration runs for (default: 100)
- `belres::Float64`: the Bellman residual; iteration stops once the residual drops below this value (default: 1e-3)
- `verbose::Bool`: if set to true, the Bellman residual and the time per iteration are printed to STDOUT (default: false)
- `include_Q::Bool`: if set to true, the solver outputs the Q values in addition to the utility and the policy (default: true)
- `init_util::Vector{Float64}`: provides a custom initialization of the utility vector (initialized to zero by default)
"""
mutable struct ValueIterationSolver <: Solver
    max_iterations::Int64 # max number of iterations
    belres::Float64 # the Bellman residual
    verbose::Bool
    include_Q::Bool
    init_util::Vector{Float64}
end
# Default constructor
function ValueIterationSolver(;max_iterations::Int64 = 100,
                               belres::Float64 = 1e-3,
                               verbose::Bool = false,
                               include_Q::Bool = true,
                               init_util::Vector{Float64} = Vector{Float64}(undef, 0))
    return ValueIterationSolver(max_iterations, belres, verbose, include_Q, init_util)
end
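# As a usage sketch (the keyword names match the solver fields above; the
# specific values are illustrative), a solver tuned for tighter convergence
# could be constructed as:
#     solver = ValueIterationSolver(max_iterations=500, belres=1e-6, verbose=true)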
@POMDP_require solve(solver::ValueIterationSolver, mdp::Union{MDP,POMDP}) begin
    P = typeof(mdp)
    S = statetype(P)
    A = actiontype(P)
    @req discount(::P)
    @subreq ordered_states(mdp)
    @subreq ordered_actions(mdp)
    @req transition(::P, ::S, ::A)
    @req reward(::P, ::S, ::A, ::S)
    @req stateindex(::P, ::S)
    @req actionindex(::P, ::A)
    @req actions(::P, ::S)
    as = actions(mdp)
    ss = states(mdp)
    @req length(::typeof(ss))
    @req length(::typeof(as))
    a = first(as)
    s = first(ss)
    dist = transition(mdp, s, a)
    D = typeof(dist)
    @req support(::D)
    @req pdf(::D, ::S)
end
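# The requirements above can be checked against a concrete model before
# solving using the POMDPs.jl requirements machinery. A minimal sketch,
# assuming `mdp` is an already-constructed MDP:
#     @requirements_info ValueIterationSolver() mdp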
#####################################################################
# solve runs the value iteration algorithm and returns a
# ValueIterationPolicy. Verbose text output to the command line is
# controlled by the solver's `verbose` option.
# Example code for running the function:
# mdp = GridWorld(10, 10) # initialize a 10x10 grid world MDP (user-written code)
# solver = ValueIterationSolver(max_iterations=40, belres=1e-3, verbose=true)
# policy = solve(solver, mdp)
#####################################################################
function solve(solver::ValueIterationSolver, mdp::MDP; kwargs...)
    # deprecation warning - can be removed when Julia 1.0 is adopted
    if !isempty(kwargs)
        @warn("Keyword args for solve(::ValueIterationSolver, ::MDP) are no longer supported. For verbose output, use the verbose option in the ValueIterationSolver")
    end

    @warn_requirements solve(solver, mdp)

    # solver parameters
    max_iterations = solver.max_iterations
    belres = solver.belres
    discount_factor = discount(mdp)
    ns = length(states(mdp))
    na = length(actions(mdp))

    # initialize the utility and Q-matrix
    if !isempty(solver.init_util)
        @assert length(solver.init_util) == ns "Input utility dimension mismatch"
        util = solver.init_util
    else
        util = zeros(ns)
    end
    include_Q = solver.include_Q
    if include_Q
        qmat = zeros(ns, na)
    end
    pol = zeros(Int64, ns)

    total_time = 0.0
    iter_time = 0.0

    # create an ordered list of states for fast iteration
    state_space = ordered_states(mdp)

    # main loop
    for i = 1:max_iterations
        residual = 0.0
        iter_time = @elapsed begin
            # state loop
            for (istate, s) in enumerate(state_space)
                sub_aspace = actions(mdp, s)
                if isterminal(mdp, s)
                    util[istate] = 0.0
                    pol[istate] = 1
                else
                    old_util = util[istate] # for residual
                    max_util = -Inf
                    # action loop
                    # util(s) = max_a(R(s,a) + discount_factor * sum(T(s'|s,a) * util(s')))
                    for a in sub_aspace
                        iaction = actionindex(mdp, a)
                        dist = transition(mdp, s, a) # creates distribution over neighbors
                        u = 0.0
                        for (sp, p) in weighted_iterator(dist)
                            p == 0.0 && continue # skip if zero prob
                            r = reward(mdp, s, a, sp)
                            isp = stateindex(mdp, sp)
                            u += p * (r + discount_factor * util[isp])
                        end
                        new_util = u
                        if new_util > max_util
                            max_util = new_util
                            pol[istate] = iaction
                        end
                        include_Q && (qmat[istate, iaction] = new_util)
                    end # action
                    # update the value array
                    util[istate] = max_util
                    diff = abs(max_util - old_util)
                    diff > residual && (residual = diff)
                end
            end # state
        end # time
        total_time += iter_time
        solver.verbose && @printf("[Iteration %-4d] residual: %10.3G | iteration runtime: %10.3f ms, (%10.3G s total)\n", i, residual, iter_time*1000.0, total_time)
        residual < belres && break
    end # main
    if include_Q
        return ValueIterationPolicy(mdp, qmat, util, pol)
    else
        return ValueIterationPolicy(mdp, utility=util, policy=pol, include_Q=false)
    end
end
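# A minimal sketch of how the returned policy is typically queried through the
# POMDPs.jl interface (assuming `mdp` satisfies the requirements above and `s`
# is one of its states):
#     policy = solve(ValueIterationSolver(verbose=true), mdp)
#     a = action(policy, s) # best action at state s
#     v = value(policy, s)  # estimated optimal value at state s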
function solve(::ValueIterationSolver, ::POMDP)
    throw("""
          ValueIterationError: `solve(::ValueIterationSolver, ::POMDP)` is not supported.
          `ValueIterationSolver` supports MDP models only; see QMDP.jl for a POMDP solver that assumes full observability.
          If you still wish to use the transition and reward from your POMDP model, you can use the `UnderlyingMDP` wrapper from POMDPModelTools.jl as follows:
          ```
          solver = ValueIterationSolver()
          mdp = UnderlyingMDP(pomdp)
          solve(solver, mdp)
          ```
          """)
end