@@ -32,7 +32,7 @@ The arguments of the forward pass are:
32
32
33
33
- `x`: The input to the RNN. It should be a vector of size `in` or a matrix of size `in x batch_size`.
34
34
- `h`: The hidden state of the RNN. It should be a vector of size `out` or a matrix of size `out x batch_size`.
35
- If not provided, it is assumed to be a vector of zeros.
35
+ If not provided, it is assumed to be a vector of zeros, initialized by [`initialstates`](@ref) .
36
36
37
37
# Examples
38
38
69
69
70
70
@layer RNNCell
71
71
72
+ """
73
+ initialstates(rnn) -> AbstractVector
74
+
75
+ Return the initial hidden state for the given recurrent cell or recurrent layer.
76
+
77
+ # Example
78
+ ```julia
79
+ using Flux
80
+
81
+ # Create an RNNCell from input dimension 10 to output dimension 20
82
+ rnn = RNNCell(10 => 20)
83
+
84
+ # Get the initial hidden state
85
+ h0 = initialstates(rnn)
86
+
87
+ # Get some input data
88
+ x = rand(Float32, 10)
89
+
90
+ # Run forward
91
+ res = rnn(x, h0)
92
+ """
93
+ initialstates (rnn:: RNNCell ) = zeros_like (rnn. Wh, size (rnn. Wh, 2 ))
94
+
72
95
function RNNCell (
73
96
(in, out):: Pair ,
74
97
σ = tanh;
@@ -82,7 +105,10 @@ function RNNCell(
82
105
return RNNCell (σ, Wi, Wh, b)
83
106
end
84
107
85
- (m:: RNNCell )(x:: AbstractVecOrMat ) = m (x, zeros_like (x, size (m. Wh, 1 )))
108
+ function (rnn:: RNNCell )(x:: AbstractVecOrMat )
109
+ state = initialstates (rnn)
110
+ return rnn (x, state)
111
+ end
86
112
87
113
function (m:: RNNCell )(x:: AbstractVecOrMat , h:: AbstractVecOrMat )
88
114
_size_check (m, x, 1 => size (m. Wi, 2 ))
@@ -130,7 +156,7 @@ The arguments of the forward pass are:
130
156
- `x`: The input to the RNN. It should be a matrix size `in x len` or an array of size `in x len x batch_size`.
131
157
- `h`: The initial hidden state of the RNN.
132
158
If given, it is a vector of size `out` or a matrix of size `out x batch_size`.
133
- If not provided, it is assumed to be a vector of zeros.
159
+ If not provided, it is assumed to be a vector of zeros, initialized by [`initialstates`](@ref) .
134
160
135
161
Returns all new hidden states `h_t` as an array of size `out x len x batch_size`.
136
162
@@ -173,12 +199,17 @@ end
173
199
174
200
@layer RNN
175
201
202
+ initialstates (rnn:: RNN ) = initialstates (rnn. cell)
203
+
176
204
function RNN ((in, out):: Pair , σ = tanh; cell_kwargs... )
177
205
cell = RNNCell (in => out, σ; cell_kwargs... )
178
206
return RNN (cell)
179
207
end
180
208
181
- (m:: RNN )(x:: AbstractArray ) = m (x, zeros_like (x, size (m. cell. Wh, 1 )))
209
+ function (rnn:: RNN )(x:: AbstractArray )
210
+ state = initialstates (rnn)
211
+ return rnn (x, state)
212
+ end
182
213
183
214
function (m:: RNN )(x:: AbstractArray , h)
184
215
@assert ndims (x) == 2 || ndims (x) == 3
@@ -231,7 +262,7 @@ The arguments of the forward pass are:
231
262
- `x`: The input to the LSTM. It should be a matrix of size `in` or an array of size `in x batch_size`.
232
263
- `(h, c)`: A tuple containing the hidden and cell states of the LSTM.
233
264
They should be vectors of size `out` or matrices of size `out x batch_size`.
234
- If not provided, they are assumed to be vectors of zeros.
265
+ If not provided, they are assumed to be vectors of zeros, initialized by [`initialstates`](@ref) .
235
266
236
267
Returns a tuple `(h′, c′)` containing the new hidden state and cell state in tensors of size `out` or `out x batch_size`.
237
268
261
292
262
293
@layer LSTMCell
263
294
295
+ function initialstates (lstm:: LSTMCell )
296
+ return zeros_like (lstm. Wh, size (lstm. Wh, 2 )), zeros_like (lstm. Wh, size (lstm. Wh, 2 ))
297
+ end
298
+
264
299
function LSTMCell (
265
300
(in, out):: Pair ;
266
301
init_kernel = glorot_uniform,
@@ -274,10 +309,9 @@ function LSTMCell(
274
309
return cell
275
310
end
276
311
277
- function (m:: LSTMCell )(x:: AbstractVecOrMat )
278
- h = zeros_like (x, size (m. Wh, 2 ))
279
- c = zeros_like (h)
280
- return m (x, (h, c))
312
+ function (lstm:: LSTMCell )(x:: AbstractVecOrMat )
313
+ state, cstate = initialstates (lstm)
314
+ return lstm (x, (state, cstate))
281
315
end
282
316
283
317
function (m:: LSTMCell )(x:: AbstractVecOrMat , (h, c))
@@ -332,7 +366,7 @@ The arguments of the forward pass are:
332
366
- `x`: The input to the LSTM. It should be a matrix of size `in x len` or an array of size `in x len x batch_size`.
333
367
- `(h, c)`: A tuple containing the hidden and cell states of the LSTM.
334
368
They should be vectors of size `out` or matrices of size `out x batch_size`.
335
- If not provided, they are assumed to be vectors of zeros.
369
+ If not provided, they are assumed to be vectors of zeros, initialized by [`initialstates`](@ref) .
336
370
337
371
Returns a tuple `(h′, c′)` containing all new hidden states `h_t` and cell states `c_t`
338
372
in tensors of size `out x len` or `out x len x batch_size`.
@@ -363,15 +397,16 @@ end
363
397
364
398
@layer LSTM
365
399
400
+ initialstates (lstm:: LSTM ) = initialstates (lstm. cell)
401
+
366
402
function LSTM ((in, out):: Pair ; cell_kwargs... )
367
403
cell = LSTMCell (in => out; cell_kwargs... )
368
404
return LSTM (cell)
369
405
end
370
406
371
- function (m:: LSTM )(x:: AbstractArray )
372
- h = zeros_like (x, size (m. cell. Wh, 2 ))
373
- c = zeros_like (h)
374
- return m (x, (h, c))
407
+ function (lstm:: LSTM )(x:: AbstractArray )
408
+ state, cstate = initialstates (lstm)
409
+ return lstm (x, (state, cstate))
375
410
end
376
411
377
412
function (m:: LSTM )(x:: AbstractArray , (h, c))
@@ -422,7 +457,7 @@ See also [`GRU`](@ref) for a layer that processes entire sequences.
422
457
The arguments of the forward pass are:
423
458
- `x`: The input to the GRU. It should be a vector of size `in` or a matrix of size `in x batch_size`.
424
459
- `h`: The hidden state of the GRU. It should be a vector of size `out` or a matrix of size `out x batch_size`.
425
- If not provided, it is assumed to be a vector of zeros.
460
+ If not provided, it is assumed to be a vector of zeros, initialized by [`initialstates`](@ref) .
426
461
427
462
Returns the new hidden state `h'` as an array of size `out` or `out x batch_size`.
428
463
447
482
448
483
@layer GRUCell
449
484
485
+ initialstates (gru:: GRUCell ) = zeros_like (gru. Wh, size (gru. Wh, 2 ))
486
+
450
487
function GRUCell (
451
488
(in, out):: Pair ;
452
489
init_kernel = glorot_uniform,
@@ -459,7 +496,10 @@ function GRUCell(
459
496
return GRUCell (Wi, Wh, b)
460
497
end
461
498
462
- (m:: GRUCell )(x:: AbstractVecOrMat ) = m (x, zeros_like (x, size (m. Wh, 2 )))
499
+ function (gru:: GRUCell )(x:: AbstractVecOrMat )
500
+ state = initialstates (gru)
501
+ return gru (x, state)
502
+ end
463
503
464
504
function (m:: GRUCell )(x:: AbstractVecOrMat , h)
465
505
_size_check (m, x, 1 => size (m. Wi, 2 ))
@@ -514,7 +554,7 @@ The arguments of the forward pass are:
514
554
515
555
- `x`: The input to the GRU. It should be a matrix of size `in x len` or an array of size `in x len x batch_size`.
516
556
- `h`: The initial hidden state of the GRU. It should be a vector of size `out` or a matrix of size `out x batch_size`.
517
- If not provided, it is assumed to be a vector of zeros.
557
+ If not provided, it is assumed to be a vector of zeros, initialized by [`initialstates`](@ref).
518
558
519
559
Returns all new hidden states `h_t` as an array of size `out x len x batch_size`.
520
560
@@ -534,14 +574,16 @@ end
534
574
535
575
@layer GRU
536
576
577
+ initialstates (gru:: GRU ) = initialstates (gru. cell)
578
+
537
579
function GRU ((in, out):: Pair ; cell_kwargs... )
538
580
cell = GRUCell (in => out; cell_kwargs... )
539
581
return GRU (cell)
540
582
end
541
583
542
- function (m :: GRU )(x:: AbstractArray )
543
- h = zeros_like (x, size (m . cell . Wh, 2 ) )
544
- return m (x, h )
584
+ function (gru :: GRU )(x:: AbstractArray )
585
+ state = initialstates (gru )
586
+ return gru (x, state )
545
587
end
546
588
547
589
function (m:: GRU )(x:: AbstractArray , h)
@@ -590,7 +632,7 @@ See [`GRU`](@ref) and [`GRUCell`](@ref) for variants of this layer.
590
632
The arguments of the forward pass are:
591
633
- `x`: The input to the GRU. It should be a vector of size `in` or a matrix of size `in x batch_size`.
592
634
- `h`: The hidden state of the GRU. It should be a vector of size `out` or a matrix of size `out x batch_size`.
593
- If not provided, it is assumed to be a vector of zeros.
635
+ If not provided, it is assumed to be a vector of zeros, initialized by [`initialstates`](@ref) .
594
636
595
637
Returns the new hidden state `h'` as an array of size `out` or `out x batch_size`.
596
638
"""
603
645
604
646
@layer GRUv3Cell
605
647
648
+ initialstates (gru:: GRUv3Cell ) = zeros_like (gru. Wh, size (gru. Wh, 2 ))
649
+
606
650
function GRUv3Cell (
607
651
(in, out):: Pair ;
608
652
init_kernel = glorot_uniform,
@@ -616,7 +660,10 @@ function GRUv3Cell(
616
660
return GRUv3Cell (Wi, Wh, b, Wh_h̃)
617
661
end
618
662
619
- (m:: GRUv3Cell )(x:: AbstractVecOrMat ) = m (x, zeros_like (x, size (m. Wh, 2 )))
663
+ function (gru:: GRUv3Cell )(x:: AbstractVecOrMat )
664
+ state = initialstates (gru)
665
+ return gru (x, state)
666
+ end
620
667
621
668
function (m:: GRUv3Cell )(x:: AbstractVecOrMat , h)
622
669
_size_check (m, x, 1 => size (m. Wi, 2 ))
@@ -667,21 +714,45 @@ but only a less popular variant.
667
714
- `init_kernel`: The initialization function to use for the input to hidden connection weights. Default is `glorot_uniform`.
668
715
- `init_recurrent_kernel`: The initialization function to use for the hidden to hidden connection weights. Default is `glorot_uniform`.
669
716
- `bias`: Whether to include a bias term initialized to zero. Default is `true`.
717
+
718
+ # Forward
719
+
720
+ gruv3(x, [h])
721
+
722
+ The arguments of the forward pass are:
723
+
724
+ - `x`: The input to the GRU. It should be a matrix of size `in x len` or an array of size `in x len x batch_size`.
725
+ - `h`: The initial hidden state of the GRU. It should be a vector of size `out` or a matrix of size `out x batch_size`.
726
+ If not provided, it is assumed to be a vector of zeros, initialized by [`initialstates`](@ref).
727
+
728
+ Returns all new hidden states `h_t` as an array of size `out x len x batch_size`.
729
+
730
+ # Examples
731
+
732
+ ```julia
733
+ d_in, d_out, len, batch_size = 2, 3, 4, 5
734
+ gruv3 = GRUv3(d_in => d_out)
735
+ x = rand(Float32, (d_in, len, batch_size))
736
+ h0 = zeros(Float32, d_out)
737
+ h = gruv3(x, h0) # out x len x batch_size
738
+ ```
670
739
"""
671
740
struct GRUv3{M}
672
741
cell:: M
673
742
end
674
743
675
744
@layer GRUv3
676
745
746
+ initialstates (gru:: GRUv3 ) = initialstates (gru. cell)
747
+
677
748
function GRUv3 ((in, out):: Pair ; cell_kwargs... )
678
749
cell = GRUv3Cell (in => out; cell_kwargs... )
679
750
return GRUv3 (cell)
680
751
end
681
752
682
- function (m :: GRUv3 )(x:: AbstractArray )
683
- h = zeros_like (x, size (m . cell . Wh, 2 ) )
684
- return m (x, h )
753
+ function (gru :: GRUv3 )(x:: AbstractArray )
754
+ state = initialstates (gru )
755
+ return gru (x, state )
685
756
end
686
757
687
758
function (m:: GRUv3 )(x:: AbstractArray , h)
0 commit comments