
Commit 40b7f70

MartinuzziFrancesco and CarloLucibello authored Dec 10, 2024

Adding initialstates function to RNNs (#2541)

* added initialstates
* added initialstates to recurrent layers, added docstrings
* fixed small errors
* streamlined implementation, added tests
* Update docs/src/reference/models/layers.md

Co-authored-by: Carlo Lucibello <carlo.lucibello@gmail.com>
1 parent 8c60006 commit 40b7f70
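
In short: every recurrent cell and layer now exposes an `initialstates` method returning its all-zero starting state, and the no-state forward pass is routed through it. A minimal usage sketch against this version of Flux (the `3 => 5` cell size is only illustrative):

```julia
using Flux

# `initialstates` returns the all-zero starting state of a cell or layer.
cell = LSTMCell(3 => 5)                # illustrative dimensions
h0, c0 = Flux.initialstates(cell)      # two length-5 zero vectors

x = rand(Float32, 3)
h1, c1 = cell(x, (h0, c0))             # same result as cell(x)
```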

File tree

- docs/src/reference/models/layers.md
- src/Flux.jl
- src/layers/recurrent.jl
- test/layers/recurrent.jl

4 files changed: +131 -29 lines changed

docs/src/reference/models/layers.md (+1)

@@ -112,6 +112,7 @@ GRUCell
 GRU
 GRUv3Cell
 GRUv3
+Flux.initialstates
 ```

 ## Normalisation & Regularisation

src/Flux.jl (+1 -1)

@@ -54,7 +54,7 @@ export Chain, Dense, Embedding, EmbeddingBag,
   # layers
   Bilinear, Scale,
   # utils
-  outputsize, state, create_bias, @layer,
+  outputsize, state, create_bias, @layer, initialstates,
   # from OneHotArrays.jl
   onehot, onehotbatch, onecold,
   # from Train
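
Since `initialstates` joins the export list above, it can be called without the `Flux.` prefix after `using Flux`. A small sketch (cell dimensions are illustrative):

```julia
using Flux

rnn = RNNCell(2 => 4)           # illustrative dimensions
h0 = initialstates(rnn)         # unqualified call works thanks to the export
@assert h0 == zeros(Float32, 4)
```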

src/layers/recurrent.jl (+96 -25)

@@ -32,7 +32,7 @@ The arguments of the forward pass are:
 
 - `x`: The input to the RNN. It should be a vector of size `in` or a matrix of size `in x batch_size`.
 - `h`: The hidden state of the RNN. It should be a vector of size `out` or a matrix of size `out x batch_size`.
-  If not provided, it is assumed to be a vector of zeros.
+  If not provided, it is assumed to be a vector of zeros, initialized by [`initialstates`](@ref).
 
 # Examples
 
@@ -69,6 +69,29 @@ end
 
 @layer RNNCell
 
+"""
+    initialstates(rnn) -> AbstractVector
+
+Return the initial hidden state for the given recurrent cell or recurrent layer.
+
+# Example
+```julia
+using Flux
+
+# Create an RNNCell from input dimension 10 to output dimension 20
+rnn = RNNCell(10 => 20)
+
+# Get the initial hidden state
+h0 = initialstates(rnn)
+
+# Get some input data
+x = rand(Float32, 10)
+
+# Run forward
+res = rnn(x, h0)
+```
+"""
+initialstates(rnn::RNNCell) = zeros_like(rnn.Wh, size(rnn.Wh, 2))
+
 function RNNCell(
   (in, out)::Pair,
   σ = tanh;
@@ -82,7 +105,10 @@ function RNNCell(
   return RNNCell(σ, Wi, Wh, b)
 end
 
-(m::RNNCell)(x::AbstractVecOrMat) = m(x, zeros_like(x, size(m.Wh, 1)))
+function (rnn::RNNCell)(x::AbstractVecOrMat)
+  state = initialstates(rnn)
+  return rnn(x, state)
+end
 
 function (m::RNNCell)(x::AbstractVecOrMat, h::AbstractVecOrMat)
   _size_check(m, x, 1 => size(m.Wi, 2))
@@ -130,7 +156,7 @@ The arguments of the forward pass are:
 - `x`: The input to the RNN. It should be a matrix size `in x len` or an array of size `in x len x batch_size`.
 - `h`: The initial hidden state of the RNN.
   If given, it is a vector of size `out` or a matrix of size `out x batch_size`.
-  If not provided, it is assumed to be a vector of zeros.
+  If not provided, it is assumed to be a vector of zeros, initialized by [`initialstates`](@ref).
 
 Returns all new hidden states `h_t` as an array of size `out x len x batch_size`.
 
@@ -173,12 +199,17 @@ end
 
 @layer RNN
 
+initialstates(rnn::RNN) = initialstates(rnn.cell)
+
 function RNN((in, out)::Pair, σ = tanh; cell_kwargs...)
   cell = RNNCell(in => out, σ; cell_kwargs...)
   return RNN(cell)
 end
 
-(m::RNN)(x::AbstractArray) = m(x, zeros_like(x, size(m.cell.Wh, 1)))
+function (rnn::RNN)(x::AbstractArray)
+  state = initialstates(rnn)
+  return rnn(x, state)
+end
 
 function (m::RNN)(x::AbstractArray, h)
   @assert ndims(x) == 2 || ndims(x) == 3
@@ -231,7 +262,7 @@ The arguments of the forward pass are:
 - `x`: The input to the LSTM. It should be a matrix of size `in` or an array of size `in x batch_size`.
 - `(h, c)`: A tuple containing the hidden and cell states of the LSTM.
   They should be vectors of size `out` or matrices of size `out x batch_size`.
-  If not provided, they are assumed to be vectors of zeros.
+  If not provided, they are assumed to be vectors of zeros, initialized by [`initialstates`](@ref).
 
 Returns a tuple `(h′, c′)` containing the new hidden state and cell state in tensors of size `out` or `out x batch_size`.
 
@@ -261,6 +292,10 @@ end
 
 @layer LSTMCell
 
+function initialstates(lstm::LSTMCell)
+  return zeros_like(lstm.Wh, size(lstm.Wh, 2)), zeros_like(lstm.Wh, size(lstm.Wh, 2))
+end
+
 function LSTMCell(
   (in, out)::Pair;
   init_kernel = glorot_uniform,
@@ -274,10 +309,9 @@ function LSTMCell(
   return cell
 end
 
-function (m::LSTMCell)(x::AbstractVecOrMat)
-  h = zeros_like(x, size(m.Wh, 2))
-  c = zeros_like(h)
-  return m(x, (h, c))
+function (lstm::LSTMCell)(x::AbstractVecOrMat)
+  state, cstate = initialstates(lstm)
+  return lstm(x, (state, cstate))
 end
 
 function (m::LSTMCell)(x::AbstractVecOrMat, (h, c))
@@ -332,7 +366,7 @@ The arguments of the forward pass are:
 - `x`: The input to the LSTM. It should be a matrix of size `in x len` or an array of size `in x len x batch_size`.
 - `(h, c)`: A tuple containing the hidden and cell states of the LSTM.
   They should be vectors of size `out` or matrices of size `out x batch_size`.
-  If not provided, they are assumed to be vectors of zeros.
+  If not provided, they are assumed to be vectors of zeros, initialized by [`initialstates`](@ref).
 
 Returns a tuple `(h′, c′)` containing all new hidden states `h_t` and cell states `c_t`
   in tensors of size `out x len` or `out x len x batch_size`.
@@ -363,15 +397,16 @@ end
 
 @layer LSTM
 
+initialstates(lstm::LSTM) = initialstates(lstm.cell)
+
 function LSTM((in, out)::Pair; cell_kwargs...)
   cell = LSTMCell(in => out; cell_kwargs...)
   return LSTM(cell)
 end
 
-function (m::LSTM)(x::AbstractArray)
-  h = zeros_like(x, size(m.cell.Wh, 2))
-  c = zeros_like(h)
-  return m(x, (h, c))
+function (lstm::LSTM)(x::AbstractArray)
+  state, cstate = initialstates(lstm)
+  return lstm(x, (state, cstate))
 end
 
 function (m::LSTM)(x::AbstractArray, (h, c))
@@ -422,7 +457,7 @@ See also [`GRU`](@ref) for a layer that processes entire sequences.
 The arguments of the forward pass are:
 - `x`: The input to the GRU. It should be a vector of size `in` or a matrix of size `in x batch_size`.
 - `h`: The hidden state of the GRU. It should be a vector of size `out` or a matrix of size `out x batch_size`.
-  If not provided, it is assumed to be a vector of zeros.
+  If not provided, it is assumed to be a vector of zeros, initialized by [`initialstates`](@ref).
 
 Returns the new hidden state `h'` as an array of size `out` or `out x batch_size`.
 
@@ -447,6 +482,8 @@ end
 
 @layer GRUCell
 
+initialstates(gru::GRUCell) = zeros_like(gru.Wh, size(gru.Wh, 2))
+
 function GRUCell(
   (in, out)::Pair;
   init_kernel = glorot_uniform,
@@ -459,7 +496,10 @@ function GRUCell(
   return GRUCell(Wi, Wh, b)
 end
 
-(m::GRUCell)(x::AbstractVecOrMat) = m(x, zeros_like(x, size(m.Wh, 2)))
+function (gru::GRUCell)(x::AbstractVecOrMat)
+  state = initialstates(gru)
+  return gru(x, state)
+end
 
 function (m::GRUCell)(x::AbstractVecOrMat, h)
   _size_check(m, x, 1 => size(m.Wi, 2))
@@ -514,7 +554,7 @@ The arguments of the forward pass are:
 
 - `x`: The input to the GRU. It should be a matrix of size `in x len` or an array of size `in x len x batch_size`.
 - `h`: The initial hidden state of the GRU. It should be a vector of size `out` or a matrix of size `out x batch_size`.
-  If not provided, it is assumed to be a vector of zeros.
+  If not provided, it is assumed to be a vector of zeros, initialized by [`initialstates`](@ref).
 
 Returns all new hidden states `h_t` as an array of size `out x len x batch_size`.
 
@@ -534,14 +574,16 @@ end
 
 @layer GRU
 
+initialstates(gru::GRU) = initialstates(gru.cell)
+
 function GRU((in, out)::Pair; cell_kwargs...)
   cell = GRUCell(in => out; cell_kwargs...)
   return GRU(cell)
 end
 
-function (m::GRU)(x::AbstractArray)
-  h = zeros_like(x, size(m.cell.Wh, 2))
-  return m(x, h)
+function (gru::GRU)(x::AbstractArray)
+  state = initialstates(gru)
+  return gru(x, state)
 end
 
 function (m::GRU)(x::AbstractArray, h)
@@ -590,7 +632,7 @@ See [`GRU`](@ref) and [`GRUCell`](@ref) for variants of this layer.
 The arguments of the forward pass are:
 - `x`: The input to the GRU. It should be a vector of size `in` or a matrix of size `in x batch_size`.
 - `h`: The hidden state of the GRU. It should be a vector of size `out` or a matrix of size `out x batch_size`.
-  If not provided, it is assumed to be a vector of zeros.
+  If not provided, it is assumed to be a vector of zeros, initialized by [`initialstates`](@ref).
 
 Returns the new hidden state `h'` as an array of size `out` or `out x batch_size`.
 """
@@ -603,6 +645,8 @@ end
 
 @layer GRUv3Cell
 
+initialstates(gru::GRUv3Cell) = zeros_like(gru.Wh, size(gru.Wh, 2))
+
 function GRUv3Cell(
   (in, out)::Pair;
   init_kernel = glorot_uniform,
@@ -616,7 +660,10 @@ function GRUv3Cell(
   return GRUv3Cell(Wi, Wh, b, Wh_h̃)
 end
 
-(m::GRUv3Cell)(x::AbstractVecOrMat) = m(x, zeros_like(x, size(m.Wh, 2)))
+function (gru::GRUv3Cell)(x::AbstractVecOrMat)
+  state = initialstates(gru)
+  return gru(x, state)
+end
 
 function (m::GRUv3Cell)(x::AbstractVecOrMat, h)
   _size_check(m, x, 1 => size(m.Wi, 2))
@@ -667,21 +714,45 @@ but only a less popular variant.
 - `init_kernel`: The initialization function to use for the input to hidden connection weights. Default is `glorot_uniform`.
 - `init_recurrent_kernel`: The initialization function to use for the hidden to hidden connection weights. Default is `glorot_uniform`.
 - `bias`: Whether to include a bias term initialized to zero. Default is `true`.
+
+# Forward
+
+    gruv3(x, [h])
+
+The arguments of the forward pass are:
+
+- `x`: The input to the GRU. It should be a matrix of size `in x len` or an array of size `in x len x batch_size`.
+- `h`: The initial hidden state of the GRU. It should be a vector of size `out` or a matrix of size `out x batch_size`.
+  If not provided, it is assumed to be a vector of zeros, initialized by [`initialstates`](@ref).
+
+Returns all new hidden states `h_t` as an array of size `out x len x batch_size`.
+
+# Examples
+
+```julia
+d_in, d_out, len, batch_size = 2, 3, 4, 5
+gruv3 = GRUv3(d_in => d_out)
+x = rand(Float32, (d_in, len, batch_size))
+h0 = zeros(Float32, d_out)
+h = gruv3(x, h0) # out x len x batch_size
+```
 """
 struct GRUv3{M}
   cell::M
 end
 
 @layer GRUv3
 
+initialstates(gru::GRUv3) = initialstates(gru.cell)
+
 function GRUv3((in, out)::Pair; cell_kwargs...)
   cell = GRUv3Cell(in => out; cell_kwargs...)
   return GRUv3(cell)
 end
 
-function (m::GRUv3)(x::AbstractArray)
-  h = zeros_like(x, size(m.cell.Wh, 2))
-  return m(x, h)
+function (gru::GRUv3)(x::AbstractArray)
+  state = initialstates(gru)
+  return gru(x, state)
end
 
 function (m::GRUv3)(x::AbstractArray, h)
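
Taken together, the layer changes above mean that calling a recurrent cell or layer without an explicit state is equivalent to passing the zero state returned by `initialstates`. A sketch of that equivalence for `GRU` (dimensions chosen arbitrarily, mirroring the docstring example):

```julia
using Flux

d_in, d_out, len, batch_size = 2, 3, 4, 5
gru = GRU(d_in => d_out)
x = rand(Float32, d_in, len, batch_size)

h_default  = gru(x)                       # zero state supplied by initialstates(gru)
h_explicit = gru(x, initialstates(gru))   # same computation, state passed explicitly
@assert h_default ≈ h_explicit            # out x len x batch_size in both cases
```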

test/layers/recurrent.jl (+33 -3)

@@ -43,6 +43,9 @@
 test_gradients(r, x, h, loss=loss3) # splat
 test_gradients(r, x, h, loss=loss4) # vcat and stack
 
+# initial states are zero
+@test Flux.initialstates(r) ≈ zeros(Float32, 5)
+
 # no initial state same as zero initial state
 @test r(x[1]) ≈ r(x[1], zeros(Float32, 5))
 
@@ -80,8 +83,11 @@ end
 @test size(y) == (4, 3, 1)
 test_gradients(model, x)
 
+rnn = model.rnn
+# initial states are zero
+@test Flux.initialstates(rnn) ≈ zeros(Float32, 4)
+
 # no initial state same as zero initial state
-rnn = model.rnn
 @test rnn(x) ≈ rnn(x, zeros(Float32, 4))
 
 x = rand(Float32, 2, 3)
@@ -120,6 +126,11 @@ end
 test_gradients(cell, x[1], (h, c), loss = (m, x, hc) -> mean(m(x, hc)[1]))
 test_gradients(cell, x, (h, c), loss = loss)
 
+# initial states are zero
+h0, c0 = Flux.initialstates(cell)
+@test h0 ≈ zeros(Float32, 5)
+@test c0 ≈ zeros(Float32, 5)
+
 # no initial state same as zero initial state
 hnew1, cnew1 = cell(x[1])
 hnew2, cnew2 = cell(x[1], (zeros(Float32, 5), zeros(Float32, 5)))
@@ -166,6 +177,12 @@ end
 @test size(h) == (4, 3)
 @test c isa Array{Float32, 2}
 @test size(c) == (4, 3)
+
+# initial states are zero
+h0, c0 = Flux.initialstates(lstm)
+@test h0 ≈ zeros(Float32, 4)
+@test c0 ≈ zeros(Float32, 4)
+
 # no initial state same as zero initial state
 h1, c1 = lstm(x, (zeros(Float32, 4), zeros(Float32, 4)))
 @test h ≈ h1
@@ -192,6 +209,9 @@ end
 h = randn(Float32, 5)
 test_gradients(r, x, h; loss)
 
+# initial states are zero
+@test Flux.initialstates(r) ≈ zeros(Float32, 5)
+
 # no initial state same as zero initial state
 @test r(x[1]) ≈ r(x[1], zeros(Float32, 5))
 
@@ -227,8 +247,12 @@ end
 @test size(y) == (4, 3, 1)
 test_gradients(model, x)
 
-# no initial state same as zero initial state
+
 gru = model.gru
+# initial states are zero
+@test Flux.initialstates(gru) ≈ zeros(Float32, 4)
+
+# no initial state same as zero initial state
 @test gru(x) ≈ gru(x, zeros(Float32, 4))
 
 # No Bias
@@ -246,6 +270,9 @@ end
 h = randn(Float32, 5)
 test_gradients(r, x, h)
 
+# initial states are zero
+@test Flux.initialstates(r) ≈ zeros(Float32, 5)
+
 # no initial state same as zero initial state
 @test r(x) ≈ r(x, zeros(Float32, 5))
 
@@ -277,7 +304,10 @@ end
 @test size(y) == (4, 3, 1)
 test_gradients(model, x)
 
-# no initial state same as zero initial state
 gru = model.gru
+# initial states are zero
+@test Flux.initialstates(gru) ≈ zeros(Float32, 4)
+
+# no initial state same as zero initial state
 @test gru(x) ≈ gru(x, zeros(Float32, 4))
 end
