# Extract the per-step output from a cell's returned state:
# for tuple states (e.g. an LSTM's `(h, c)`) the output is the first
# element; for any other state (e.g. the array `h` of RNN/GRU cells)
# the state itself is the output.
function out_from_state(state::Tuple)
    return state[1]
end

function out_from_state(state)
    return state
end
4
# Apply `cell` step by step over the time dimension (dims = 2) of `x`,
# threading the hidden `state` through each step and stacking the
# per-step outputs along dims = 2.
#
# NOTE(review): the middle of this loop body was elided in the diff view;
# the accumulation below is reconstructed from the semantics documented
# on `Recurrent` — confirm against the full file.
function scan(cell, x, state)
    y = []
    for x_t in eachslice(x, dims = 2)
        state = cell(x_t, state)
        out = out_from_state(state)
        # Non-mutating accumulation: rebinding `y` (rather than `push!`)
        # keeps the loop friendly to reverse-mode AD.
        y = vcat(y, [out])
    end
    return stack(y, dims = 2)
end
"""
    Recurrent(cell)

Create a recurrent layer that processes an entire sequence with the given
recurrent `cell`, such as an [`RNNCell`](@ref), [`LSTMCell`](@ref), or
[`GRUCell`](@ref), in the same way that [`RNN`](@ref), [`LSTM`](@ref), and
[`GRU`](@ref) process sequences.

The `cell` must be a callable object taking an input `x` and a hidden state
`state` and returning a new hidden state `state'`. It should also implement
the `initialstates` method returning the initial hidden state. The output of
the `cell` is considered to be:
1. The first element of the `state` tuple if `state` is a tuple
   (e.g. `(h, c)` for LSTM).
2. The `state` itself if `state` is not a tuple, e.g. an array `h` for RNN
   and GRU.

# Forward

    rnn(x, [state])

The input `x` should be a matrix of size `in x len` or an array of size
`in x len x batch_size`, where `in` is the input dimension, `len` is the
sequence length, and `batch_size` is the batch size.

The operation performed is semantically equivalent to the following code:
```julia
state = Flux.initialstates(cell)
out = []
for x_t in eachslice(x, dims = 2)
    state = cell(x_t, state)
    out = [out..., get_output(state)]
end
stack(out, dims = 2)
```
"""
struct Recurrent{M}
    cell::M
end
@layer Recurrent

# The layer's initial state is whatever its wrapped cell provides.
function initialstates(rnn::Recurrent)
    return initialstates(rnn.cell)
end

# Forward without an explicit state: start from the cell's initial state.
function (rnn::Recurrent)(x::AbstractArray)
    return rnn(x, initialstates(rnn))
end

# Forward with an explicit state: scan the cell over the time dimension.
function (rnn::Recurrent)(x::AbstractArray, state)
    return scan(rnn.cell, x, state)
end
16
55
17
56
# Vanilla RNN
18
57
@doc raw """
87
126
initialstates(rnn) -> AbstractVector
88
127
89
128
Return the initial hidden state for the given recurrent cell or recurrent layer.
129
+ Should be implemented for all recurrent cells and layers.
90
130
91
131
# Example
92
132
```julia
93
- using Flux
94
-
95
133
# Create an RNNCell from input dimension 10 to output dimension 20
96
134
rnn = RNNCell(10 => 20)
97
135
98
136
# Get the initial hidden state
99
- h0 = initialstates(rnn)
137
+ h0 = Flux. initialstates(rnn)
100
138
101
139
# Get some input data
102
140
x = rand(Float32, 10)
@@ -107,22 +145,20 @@ res = rnn(x, h0)
# Initial hidden state for an RNNCell: zeros shaped like one column's worth
# of the recurrent kernel `Wh` (i.e. the hidden dimension).
function initialstates(rnn::RNNCell)
    return zeros_like(rnn.Wh, size(rnn.Wh, 2))
end
# Construct an RNNCell mapping inputs of dimension `in` to hidden states of
# dimension `out`, with activation `σ` and optional bias.
function RNNCell(
    (in, out)::Pair,
    σ = tanh;
    init_kernel = glorot_uniform,
    init_recurrent_kernel = glorot_uniform,
    bias = true,
)
    # Input-to-hidden (out × in) and hidden-to-hidden (out × out) weights.
    input_weights = init_kernel(out, in)
    recurrent_weights = init_recurrent_kernel(out, out)
    b = create_bias(input_weights, bias, size(input_weights, 1))
    return RNNCell(σ, input_weights, recurrent_weights, b)
end
# Forward pass without an explicit state: seed with the cell's initial state.
function (rnn::RNNCell)(x::AbstractVecOrMat)
    state = initialstates(rnn)
    return rnn(x, state)
end
127
163
function (m:: RNNCell )(x:: AbstractVecOrMat , h:: AbstractVecOrMat )
128
164
_size_check (m, x, 1 => size (m. Wi, 2 ))
300
336
301
337
@layer LSTMCell

# LSTM state is a `(h, c)` pair of zero vectors, each sized by the hidden
# dimension (the number of columns of the recurrent kernel `Wh`).
function initialstates(lstm::LSTMCell)
    h = zeros_like(lstm.Wh, size(lstm.Wh, 2))
    c = zeros_like(lstm.Wh, size(lstm.Wh, 2))
    return h, c
end
306
342