Skip to content

Commit 4ef5d11

Browse files
Recurrence
1 parent f96bd58 commit 4ef5d11

File tree

1 file changed

+52
-16
lines changed

1 file changed

+52
-16
lines changed

src/layers/recurrent.jl

+52-16
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,7 @@
# Extract the cell's output from its state: a tuple state (e.g. `(h, c)` for
# LSTM) yields its first element; any non-tuple state is the output itself.
out_from_state(state::Tuple) = first(state)
out_from_state(state) = state
33

4-
function scan(cell, x, state0)
5-
state = state0
4+
function scan(cell, x, state)
65
y = []
76
for x_t in eachslice(x, dims = 2)
87
state = cell(x_t, state)
@@ -12,7 +11,47 @@ function scan(cell, x, state0)
1211
return stack(y, dims = 2)
1312
end
1413

"""
    Recurrent(cell)

Create a recurrent layer that processes entire sequences out
of a recurrent `cell`, such as an [`RNNCell`](@ref), [`LSTMCell`](@ref), or [`GRUCell`](@ref),
similarly to how [`RNN`](@ref), [`LSTM`](@ref), and [`GRU`](@ref) process sequences.

The `cell` should be a callable object that takes an input `x` and a hidden state `state` and returns
a new hidden state `state'`. The `cell` should also implement the `initialstates` method that returns
the initial hidden state. The output of the `cell` is considered to be:
1. The first element of the `state` tuple if `state` is a tuple (e.g. `(h, c)` for LSTM).
2. The `state` itself if `state` is not a tuple, e.g. an array `h` for RNN and GRU.

# Forward

    rnn(x, [state])

The input `x` should be a matrix of size `in x len` or an array of size `in x len x batch_size`,
where `in` is the input dimension, `len` is the sequence length, and `batch_size` is the batch size.

The operation performed is semantically equivalent to the following code:
```julia
state = Flux.initialstates(cell)
out = []
for x_t in eachslice(x, dims = 2)
    state = cell(x_t, state)
    out = [out..., out_from_state(state)]
end
stack(out, dims = 2)
```
"""
struct Recurrent{M}
    cell::M
end

@layer Recurrent

# A `Recurrent` has no state of its own; defer to the wrapped cell.
initialstates(rnn::Recurrent) = initialstates(rnn.cell)

# Single-argument forward uses the cell's default initial state.
(rnn::Recurrent)(x::AbstractArray) = rnn(x, initialstates(rnn))
(rnn::Recurrent)(x::AbstractArray, state) = scan(rnn.cell, x, state)
1655

1756
# Vanilla RNN
1857
@doc raw"""
@@ -87,16 +126,15 @@ end
87126
initialstates(rnn) -> AbstractVector
88127
89128
Return the initial hidden state for the given recurrent cell or recurrent layer.
129+
Should be implemented for all recurrent cells and layers.
90130
91131
# Example
92132
```julia
93-
using Flux
94-
95133
# Create an RNNCell from input dimension 10 to output dimension 20
96134
rnn = RNNCell(10 => 20)
97135
98136
# Get the initial hidden state
99-
h0 = initialstates(rnn)
137+
h0 = Flux.initialstates(rnn)
100138
101139
# Get some input data
102140
x = rand(Float32, 10)
@@ -107,22 +145,20 @@ res = rnn(x, h0)
# Initial hidden state for an `RNNCell`: a zero vector of length
# `size(rnn.Wh, 2)` built with `zeros_like` from the recurrent kernel.
function initialstates(rnn::RNNCell)
    return zeros_like(rnn.Wh, size(rnn.Wh, 2))
end
108146

# Keyword constructor for `RNNCell`: builds the input kernel (out × in), the
# recurrent kernel (out × out), and the bias, then forwards to the positional
# constructor.
function RNNCell(
    (in, out)::Pair,
    σ = tanh;
    init_kernel = glorot_uniform,
    init_recurrent_kernel = glorot_uniform,
    bias = true,
)
    kernel_in = init_kernel(out, in)
    kernel_rec = init_recurrent_kernel(out, out)
    # Bias has one entry per output row of the input kernel.
    b = create_bias(kernel_in, bias, out)
    return RNNCell(σ, kernel_in, kernel_rec, b)
end
121160

# Single-argument forward: fall back to the cell's default initial state.
function (rnn::RNNCell)(x::AbstractVecOrMat)
    return rnn(x, initialstates(rnn))
end
126162

127163
function (m::RNNCell)(x::AbstractVecOrMat, h::AbstractVecOrMat)
128164
_size_check(m, x, 1 => size(m.Wi, 2))
@@ -300,7 +336,7 @@ end
300336

@layer LSTMCell

# Initial `(h, c)` state for an `LSTMCell`: two separately-allocated zero
# vectors of length `size(lstm.Wh, 2)` — two distinct calls so the hidden and
# cell states do not alias the same array.
function initialstates(lstm::LSTMCell)
    n = size(lstm.Wh, 2)
    return zeros_like(lstm.Wh, n), zeros_like(lstm.Wh, n)
end
306342

0 commit comments

Comments
 (0)