From b8ed5d4df72ff705398c51b6c5a242c95f41e9a6 Mon Sep 17 00:00:00 2001
From: Carlo Lucibello <>
Date: Thu, 12 Dec 2024 08:04:54 +0100
Subject: [PATCH 1/6] Recurrence

 src/layers/recurrent.jl | 43 +++++++++++++++++++++++++++++++++++++++--
 1 file changed, 41 insertions(+), 2 deletions(-)

diff --git a/src/layers/recurrent.jl b/src/layers/recurrent.jl
index 9386a3fc2d..adab5f415b 100644
--- a/src/layers/recurrent.jl
+++ b/src/layers/recurrent.jl
@@ -1,8 +1,7 @@
 out_from_state(state) = state
 out_from_state(state::Tuple) = state[1]
-function scan(cell, x, state0)
-  state = state0
+function scan(cell, x, state)
   y = []
   for x_t in eachslice(x, dims = 2)
     state = cell(x_t, state)
@@ -12,7 +11,47 @@ function scan(cell, x, state0)
   return stack(y, dims = 2)
+    Recurrent(cell)
+Create a recurrent layer that processes entire sequences out
+of a recurrent `cell`, such as an [`RNNCell`](@ref), [`LSTMCell`](@ref), or [`GRUCell`](@ref),
+similarly to how [`RNN`](@ref), [`LSTM`](@ref), and [`GRU`](@ref) process sequences.
+The `cell` should be a callable object that takes an input `x` and a hidden state `state` and returns
+a new hidden state `state'`. The `cell` should also implement the `initialstates` method that returns
+the initial hidden state. The output of the `cell` is considered to be:
+1. The first element of the `state` tuple if `state` is a tuple (e.g. `(h, c)` for LSTM).
+2. The `state` itself if `state` is not a tuple, e.g. an array `h` for RNN and GRU.
+# Forward
+    rnn(x, [state])
+The input `x` should be a matrix of size `in x len` or an array of size `in x len x batch_size`, 
+where `in` is the input dimension, `len` is the sequence length, and `batch_size` is the batch size.
+The operation performed is semantically equivalent to the following code:
+state = Flux.initialstates(cell)
+out = []
+for x_t in eachslice(x, dims = 2)
+  state = cell(x_t, state)
+  out = [out..., get_output(state)]
+stack(out, dims = 2)
+struct Recurrent{M}
+  cell::M
+@layer Recurrent
+initialstates(rnn::Recurrent) = initialstates(rnn.cell)
+(rnn::Recurrent)(x::AbstractArray) = rnn(x, initialstates(rnn))
+(rnn::Recurrent)(x::AbstractArray, state) = scan(rnn.cell, x, state)
 # Vanilla RNN
 @doc raw"""

From 098c0641da95c7e214f07623df5ff6152f3fedf1 Mon Sep 17 00:00:00 2001
From: Carlo Lucibello <>
Date: Fri, 13 Dec 2024 15:27:46 +0100
Subject: [PATCH 2/6] recurrence

 docs/src/guide/models/ | 60 +++++++++++++++++++++++++++++
 src/Flux.jl                         |  2 +-
 src/layers/recurrent.jl             | 59 ++++++++++++++++++++++------
 test/layers/recurrent.jl            |  9 +++++
 4 files changed, 117 insertions(+), 13 deletions(-)

diff --git a/docs/src/guide/models/ b/docs/src/guide/models/
index 5b2e70f095..3079ab7311 100644
--- a/docs/src/guide/models/
+++ b/docs/src/guide/models/
@@ -166,3 +166,63 @@ opt_state = Flux.setup(AdamW(1e-3), model)
 g = gradient(m -> Flux.mse(m(x), y), model)[1]
 Flux.update!(opt_state, model, g)
+Finally, the [`Recurrence`](@ref) layer can be used to wrap any recurrent cell to process the entire sequence at once. For instance, a type behaving the same as the `LSTM` layer can be defined as follows:
+julia> rnn = Recurrence(LSTMCell(2 => 3))   # similar to LSTM(2 => 3)
+  LSTMCell(2 => 3),                     # 72 parameters
+)                   # Total: 3 arrays, 72 parameters, 448 bytes.
+julia> y = rnn(rand(Float32, 2, 4, 3));
+## Stacking recurrent layers
+Recurrent layers can be stacked to form a deeper model by simply chaining them together using the [`Chain`](@ref) layer. The output of a layer is fed as input to the next layer in the chain.
+For instance, a model with two LSTM layers can be defined as follows:
+julia> stacked_rnn = Chain(LSTM(3 => 5), Dropout(0.5), LSTM(5 => 5))
+  LSTM(3 => 5),                         # 180 parameters
+  Dropout(0.5),
+  LSTM(5 => 5),                         # 220 parameters
+)                   # Total: 6 arrays, 400 parameters, 1.898 KiB.
+julia> x = rand(Float32, 3, 4);
+julia> y = stacked_rnn(x);
+julia> size(y)
+(5, 4)
+If more fine grained control is needed, for instance to have a trainable initial hidden state, one can define a custom model as follows: 
+struct StackedRNN{L,S}
+    layers::L
+    states0::S
+Flux.@layer StackedRNN
+function StackedRNN(d::Int; num_layers::Int)
+    layers = [LSTM(d => d) for _ in 1:num_layers]
+    states0 = [Flux.initialstates(l) for l in layers]
+    return StackedRNN(layers, states0)
+function (m::StackedRNN)(x)
+   for (layer, state0) in zip(m.layers, m.states0)
+       x = layer(x, state0) 
+   end
+   return x
+rnn = StackedRNN(3; num_layers=2)
+x = rand(Float32, 3, 2)
+y = rnn(x)
diff --git a/src/Flux.jl b/src/Flux.jl
index 8fb2351aa2..3a598e88f5 100644
--- a/src/Flux.jl
+++ b/src/Flux.jl
@@ -38,7 +38,7 @@ using EnzymeCore: EnzymeCore
 export Chain, Dense, Embedding, EmbeddingBag,
        Maxout, SkipConnection, Parallel, PairwiseFusion,
        RNNCell, LSTMCell, GRUCell, GRUv3Cell,
-       RNN, LSTM, GRU, GRUv3,
+       RNN, LSTM, GRU, GRUv3, Recurrence,
        SamePad, Conv, CrossCor, ConvTranspose, DepthwiseConv,
        AdaptiveMaxPool, AdaptiveMeanPool, GlobalMaxPool, GlobalMeanPool, MaxPool, MeanPool,
        Dropout, AlphaDropout,
diff --git a/src/layers/recurrent.jl b/src/layers/recurrent.jl
index adab5f415b..e63dc5927e 100644
--- a/src/layers/recurrent.jl
+++ b/src/layers/recurrent.jl
@@ -28,30 +28,47 @@ the initial hidden state. The output of the `cell` is considered to be:
     rnn(x, [state])
-The input `x` should be a matrix of size `in x len` or an array of size `in x len x batch_size`, 
-where `in` is the input dimension, `len` is the sequence length, and `batch_size` is the batch size.
+The input `x` should be an array of size `in x len` or `in x len x batch_size`, 
+where `in` is the input dimension of the cell, `len` is the sequence length, and `batch_size` is the batch size.
+The `state` should be a valid state for the recurrent cell. If not provided, it is obtained by calling `Flux.initialstates(cell)`.
+The output is an array of size `out x len x batch_size`, where `out` is the output dimension of the cell.
 The operation performed is semantically equivalent to the following code:
+out_from_state(state) = state
+out_from_state(state::Tuple) = state[1]
 state = Flux.initialstates(cell)
 out = []
 for x_t in eachslice(x, dims = 2)
   state = cell(x_t, state)
-  out = [out..., get_output(state)]
+  out = [out..., out_from_state(state)]
 stack(out, dims = 2)
+# Examples
+julia> rnn = Recurrent(RNNCell(2 => 3))
+julia> x = rand(Float32, 2, 3, 4); # in x len x batch_size
+julia> y = rnn(x); # out x len x batch_size
-struct Recurrent{M}
+struct Recurrence{M}
-@layer Recurrent
+@layer Recurrence
-initialstates(rnn::Recurrent) = initialstates(rnn.cell)
+initialstates(rnn::Recurrence) = initialstates(rnn.cell)
-(rnn::Recurrent)(x::AbstractArray) = rnn(x, initialstates(rnn))
-(rnn::Recurrent)(x::AbstractArray, state) = scan(rnn.cell, x, state)
+(rnn::Recurrence)(x::AbstractArray) = rnn(x, initialstates(rnn))
+(rnn::Recurrence)(x::AbstractArray, state) = scan(rnn.cell, x, state)
 # Vanilla RNN
 @doc raw"""
@@ -250,7 +267,7 @@ struct RNN{M}
-@layer RNN
+@layer :noexpand RNN
 initialstates(rnn::RNN) = initialstates(rnn.cell)
@@ -271,6 +288,12 @@ function (m::RNN)(x::AbstractArray, h)
   return scan(m.cell, x, h)
+function Base.show(io::IO, m::RNN)
+  print(io, "RNN(", size(m.cell.Wi, 2), " => ", size(m.cell.Wi, 1))
+  print(io, ", ", m.cell.σ)
+  print(io, ")")
 @doc raw"""
@@ -439,7 +462,7 @@ struct LSTM{M}
-@layer LSTM
+@layer :noexpand LSTM
 initialstates(lstm::LSTM) = initialstates(lstm.cell)
@@ -455,6 +478,10 @@ function (m::LSTM)(x::AbstractArray, state0)
   return scan(m.cell, x, state0)
+function Base.show(io::IO, m::LSTM)
+  print(io, "LSTM(", size(m.cell.Wi, 2), " => ", size(m.cell.Wi, 1) ÷ 4, ")")
 # GRU
 @doc raw"""
@@ -607,7 +634,7 @@ struct GRU{M}
-@layer GRU
+@layer :noexpand GRU
 initialstates(gru::GRU) = initialstates(gru.cell)
@@ -623,6 +650,10 @@ function (m::GRU)(x::AbstractArray, h)
   return scan(m.cell, x, h)
+function Base.show(io::IO, m::GRU)
+  print(io, "GRU(", size(m.cell.Wi, 2), " => ", size(m.cell.Wi, 1) ÷ 3, ")")
 # GRU v3
 @doc raw"""
     GRUv3Cell(in => out; init_kernel = glorot_uniform,
@@ -767,7 +798,7 @@ struct GRUv3{M}
-@layer GRUv3
+@layer :noexpand GRUv3
 initialstates(gru::GRUv3) = initialstates(gru.cell)
@@ -782,3 +813,7 @@ function (m::GRUv3)(x::AbstractArray, h)
   @assert ndims(x) == 2 || ndims(x) == 3
   return scan(m.cell, x, h)
+function Base.show(io::IO, m::GRUv3)
+  print(io, "GRUv3(", size(m.cell.Wi, 2), " => ", size(m.cell.Wi, 1) ÷ 3, ")")
\ No newline at end of file
diff --git a/test/layers/recurrent.jl b/test/layers/recurrent.jl
index 864e5dad8e..3d7d53a486 100644
--- a/test/layers/recurrent.jl
+++ b/test/layers/recurrent.jl
@@ -305,3 +305,12 @@ end
     # no initial state same as zero initial state
     @test gru(x) ≈ gru(x, zeros(Float32, 4))
+@testset "Recurrence" begin
+    for rnn in [RNN(2 => 3), LSTM(2 => 3), GRU(2 => 3)]
+        cell = rnn.cell
+        rec = Recurrence(cell)
+        x = rand(Float32, 2, 3, 4)
+        @test rec(x) ≈ rnn(x)
+    end

From 48e91bc57ddb6952fb6de633d33e21998843f425 Mon Sep 17 00:00:00 2001
From: Carlo Lucibello <>
Date: Fri, 13 Dec 2024 15:28:52 +0100
Subject: [PATCH 3/6] cleanup

 test/layers/recurrent.jl | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/test/layers/recurrent.jl b/test/layers/recurrent.jl
index 3d7d53a486..3ad7428601 100644
--- a/test/layers/recurrent.jl
+++ b/test/layers/recurrent.jl
@@ -307,10 +307,10 @@ end
 @testset "Recurrence" begin
+    x = rand(Float32, 2, 3, 4)
     for rnn in [RNN(2 => 3), LSTM(2 => 3), GRU(2 => 3)]
         cell = rnn.cell
         rec = Recurrence(cell)
-        x = rand(Float32, 2, 3, 4)
         @test rec(x) ≈ rnn(x)

From 47b570c037becca7f5ecceccc47ea13f14284551 Mon Sep 17 00:00:00 2001
From: Carlo Lucibello <>
Date: Fri, 13 Dec 2024 17:24:28 +0100
Subject: [PATCH 4/6] fix

 src/layers/recurrent.jl | 4 +---
 1 file changed, 1 insertion(+), 3 deletions(-)

diff --git a/src/layers/recurrent.jl b/src/layers/recurrent.jl
index e63dc5927e..616dfc83b3 100644
--- a/src/layers/recurrent.jl
+++ b/src/layers/recurrent.jl
@@ -240,9 +240,7 @@ julia> x = rand(Float32, (d_in, len, batch_size));
 julia> h = zeros(Float32, (d_out, batch_size));
 julia> rnn = RNN(d_in => d_out)
-  RNNCell(4 => 6, tanh),                # 66 parameters
-)                   # Total: 3 arrays, 66 parameters, 424 bytes.
+RNN(4 => 6, tanh)   # 66 parameters
 julia> y = rnn(x, h);   # [y] = [d_out, len, batch_size]

From 89fc89fe60d93963d0e271fbf7b01902a5a5abc7 Mon Sep 17 00:00:00 2001
From: Carlo Lucibello <>
Date: Fri, 13 Dec 2024 17:32:41 +0100
Subject: [PATCH 5/6] fix doctest

 .github/workflows/ci.yml            | 2 +-
 docs/make.jl                        | 1 +
 docs/src/guide/models/     | 4 ++--
 docs/src/reference/models/ | 1 +
 src/layers/recurrent.jl             | 4 ++--
 5 files changed, 7 insertions(+), 5 deletions(-)

diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml
index acd279d709..dfa4a442f2 100644
--- a/.github/workflows/ci.yml
+++ b/.github/workflows/ci.yml
@@ -75,4 +75,4 @@ jobs:
           GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
           DOCUMENTER_KEY: ${{ secrets.DOCUMENTER_KEY }}
diff --git a/docs/make.jl b/docs/make.jl
index 4367639d8e..d74486936b 100644
--- a/docs/make.jl
+++ b/docs/make.jl
@@ -2,6 +2,7 @@ using Documenter, Flux, NNlib, Functors, MLUtils, BSON, Optimisers,
       OneHotArrays, Zygote, ChainRulesCore, Plots, MLDatasets, Statistics, 
       DataFrames, JLD2, MLDataDevices
 DocMeta.setdocmeta!(Flux, :DocTestSetup, :(using Flux); recursive = true)
diff --git a/docs/src/guide/models/ b/docs/src/guide/models/
index 7ad62ee207..7141fb0b28 100644
--- a/docs/src/guide/models/
+++ b/docs/src/guide/models/
@@ -185,7 +185,7 @@ These matching nested structures are at the core of how Flux works.
 <h3><img src="../../../assets/zygote-crop.png" width="40px"/>&nbsp;<a href="">Zygote.jl</a></h3>
-Flux's [`gradient`](@ref) function by default calls a companion packages called [Zygote](
+Flux's [`gradient`](@ref Flux.gradient) function by default calls a companion package called [Zygote](
 Zygote performs source-to-source automatic differentiation, meaning that `gradient(f, x)`
 hooks into Julia's compiler to find out what operations `f` contains, and transforms this
 to produce code for computing `∂f/∂x`.
@@ -372,7 +372,7 @@ How does this `model3` differ from the `model1` we had before?
   Its contents is stored in a tuple, thus `model3.layers[1].weight` is an array.
 * Flux's layer [`Dense`](@ref Flux.Dense) has only minor differences from our `struct Layer`:
   - Like `struct Poly3{T}` above, it has type parameters for its fields -- the compiler does not know exactly what type `layer3s.W` will be, which costs speed.
-  - Its initialisation uses not `randn` (normal distribution) but [`glorot_uniform`](@ref) by default.
+  - Its initialisation uses not `randn` (normal distribution) but [`glorot_uniform`](@ref Flux.glorot_uniform) by default.
   - It reshapes some inputs (to allow several batch dimensions), and produces more friendly errors on wrong-size input.
   - And it has some performance tricks: making sure element types match, and re-using some memory.
 * The function [`σ`](@ref NNlib.sigmoid) is calculated in a slightly better way,
diff --git a/docs/src/reference/models/ b/docs/src/reference/models/
index 355d3e7833..562304de70 100644
--- a/docs/src/reference/models/
+++ b/docs/src/reference/models/
@@ -104,6 +104,7 @@ PairwiseFusion
 Much like the core layers above, but can be used to process sequence data (as well as other kinds of structured data).
diff --git a/src/layers/recurrent.jl b/src/layers/recurrent.jl
index 616dfc83b3..5b88768ad1 100644
--- a/src/layers/recurrent.jl
+++ b/src/layers/recurrent.jl
@@ -12,7 +12,7 @@ function scan(cell, x, state)
-    Recurrent(cell)
+    Recurrence(cell)
 Create a recurrent layer that processes entire sequences out
 of a recurrent `cell`, such as an [`RNNCell`](@ref), [`LSTMCell`](@ref), or [`GRUCell`](@ref),
@@ -52,7 +52,7 @@ stack(out, dims = 2)
 # Examples
-julia> rnn = Recurrent(RNNCell(2 => 3))
+julia> rnn = Recurrence(RNNCell(2 => 3))
 julia> x = rand(Float32, 2, 3, 4); # in x len x batch_size

From c4c582d1a68fb26cc23b4c736ba5ffb0b3faefea Mon Sep 17 00:00:00 2001
From: Carlo Lucibello <>
Date: Fri, 13 Dec 2024 17:40:19 +0100
Subject: [PATCH 6/6] fix doctests

 docs/src/guide/models/ | 31 +++++++++--------------------
 src/layers/recurrent.jl             |  3 +++
 2 files changed, 12 insertions(+), 22 deletions(-)

diff --git a/docs/src/guide/models/ b/docs/src/guide/models/
index 3079ab7311..446ff86f82 100644
--- a/docs/src/guide/models/
+++ b/docs/src/guide/models/
@@ -169,13 +169,10 @@ Flux.update!(opt_state, model, g)
 Finally, the [`Recurrence`](@ref) layer can be used to wrap any recurrent cell to process the entire sequence at once. For instance, a type behaving the same as the `LSTM` layer can be defined as follows:
-julia> rnn = Recurrence(LSTMCell(2 => 3))   # similar to LSTM(2 => 3)
-  LSTMCell(2 => 3),                     # 72 parameters
-)                   # Total: 3 arrays, 72 parameters, 448 bytes.
-julia> y = rnn(rand(Float32, 2, 4, 3));
+rnn = Recurrence(LSTMCell(2 => 3))   # similar to LSTM(2 => 3)
+x = rand(Float32, 2, 4, 3)
+y = rnn(x)
 ## Stacking recurrent layers
@@ -183,20 +180,10 @@ julia> y = rnn(rand(Float32, 2, 4, 3));
 Recurrent layers can be stacked to form a deeper model by simply chaining them together using the [`Chain`](@ref) layer. The output of a layer is fed as input to the next layer in the chain.
 For instance, a model with two LSTM layers can be defined as follows:
-julia> stacked_rnn = Chain(LSTM(3 => 5), Dropout(0.5), LSTM(5 => 5))
-  LSTM(3 => 5),                         # 180 parameters
-  Dropout(0.5),
-  LSTM(5 => 5),                         # 220 parameters
-)                   # Total: 6 arrays, 400 parameters, 1.898 KiB.
-julia> x = rand(Float32, 3, 4);
-julia> y = stacked_rnn(x);
-julia> size(y)
-(5, 4)
+stacked_rnn = Chain(LSTM(3 => 5), Dropout(0.5), LSTM(5 => 5))
+x = rand(Float32, 3, 4)
+y = stacked_rnn(x)
 If more fine grained control is needed, for instance to have a trainable initial hidden state, one can define a custom model as follows: 
@@ -223,6 +210,6 @@ function (m::StackedRNN)(x)
 rnn = StackedRNN(3; num_layers=2)
-x = rand(Float32, 3, 2)
+x = rand(Float32, 3, 10)
 y = rnn(x)
diff --git a/src/layers/recurrent.jl b/src/layers/recurrent.jl
index 5b88768ad1..a6a82b0e72 100644
--- a/src/layers/recurrent.jl
+++ b/src/layers/recurrent.jl
@@ -53,6 +53,9 @@ stack(out, dims = 2)
 julia> rnn = Recurrence(RNNCell(2 => 3))
+  RNNCell(2 => 3, tanh),                # 18 parameters
+)                   # Total: 3 arrays, 18 parameters, 232 bytes.
 julia> x = rand(Float32, 2, 3, 4); # in x len x batch_size