From b8ed5d4df72ff705398c51b6c5a242c95f41e9a6 Mon Sep 17 00:00:00 2001
From: Carlo Lucibello <carlo.lucibello@gmail.com>
Date: Thu, 12 Dec 2024 08:04:54 +0100
Subject: [PATCH 1/6] Recurrence

---
 src/layers/recurrent.jl | 43 +++++++++++++++++++++++++++++++++++++++--
 1 file changed, 41 insertions(+), 2 deletions(-)

diff --git a/src/layers/recurrent.jl b/src/layers/recurrent.jl
index 9386a3fc2d..adab5f415b 100644
--- a/src/layers/recurrent.jl
+++ b/src/layers/recurrent.jl
@@ -1,8 +1,7 @@
 out_from_state(state) = state
 out_from_state(state::Tuple) = state[1]
 
-function scan(cell, x, state0)
-  state = state0
+function scan(cell, x, state)
   y = []
   for x_t in eachslice(x, dims = 2)
     state = cell(x_t, state)
@@ -12,7 +11,47 @@ function scan(cell, x, state0)
   return stack(y, dims = 2)
 end
 
+"""
+    Recurrent(cell)
+
+Create a recurrent layer that processes entire sequences out
+of a recurrent `cell`, such as an [`RNNCell`](@ref), [`LSTMCell`](@ref), or [`GRUCell`](@ref),
+similarly to how [`RNN`](@ref), [`LSTM`](@ref), and [`GRU`](@ref) process sequences.
+
+The `cell` should be a callable object that takes an input `x` and a hidden state `state` and returns
+a new hidden state `state'`. The `cell` should also implement the `initialstates` method that returns
+the initial hidden state. The output of the `cell` is considered to be:
+1. The first element of the `state` tuple if `state` is a tuple (e.g. `(h, c)` for LSTM).
+2. The `state` itself if `state` is not a tuple, e.g. an array `h` for RNN and GRU.
+
+# Forward
+
+    rnn(x, [state])
+
+The input `x` should be a matrix of size `in x len` or an array of size `in x len x batch_size`, 
+where `in` is the input dimension, `len` is the sequence length, and `batch_size` is the batch size.
+
+The operation performed is semantically equivalent to the following code:
+```julia
+state = Flux.initialstates(cell)
+out = []
+for x_t in eachslice(x, dims = 2)
+  state = cell(x_t, state)
+  out = [out..., get_output(state)]
+end
+stack(out, dims = 2)
+```
+"""
+struct Recurrent{M}
+  cell::M
+end
+
+@layer Recurrent
+
+initialstates(rnn::Recurrent) = initialstates(rnn.cell)
 
+(rnn::Recurrent)(x::AbstractArray) = rnn(x, initialstates(rnn))
+(rnn::Recurrent)(x::AbstractArray, state) = scan(rnn.cell, x, state)
 
 # Vanilla RNN
 @doc raw"""

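The docstring above defines the layer's forward pass in terms of the `scan` loop. Below is a minimal standalone sketch of those semantics, using a toy cell in place of a real Flux cell; the toy cell, the shapes, and the `push!`-based collection are illustrative assumptions, not Flux source.

```julia
# Sketch of the scan semantics described in the docstring (not Flux source code).
out_from_state(state) = state
out_from_state(state::Tuple) = state[1]    # e.g. (h, c) for an LSTM-style cell

function scan(cell, x, state)
    y = []
    for x_t in eachslice(x, dims = 2)      # step through the time dimension
        state = cell(x_t, state)           # the cell returns its new state
        push!(y, out_from_state(state))    # collect the per-step output
    end
    return stack(y, dims = 2)              # out × len
end

# Hypothetical toy cell: mixes the summed input into the hidden state.
toy_cell(x_t, h) = tanh.(h .+ sum(x_t))

x  = rand(Float32, 2, 5)      # in × len
h0 = zeros(Float32, 3)        # hidden state of length 3
y  = scan(toy_cell, x, h0)    # 3 × 5
```
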
From 098c0641da95c7e214f07623df5ff6152f3fedf1 Mon Sep 17 00:00:00 2001
From: Carlo Lucibello <carlo.lucibello@gmail.com>
Date: Fri, 13 Dec 2024 15:27:46 +0100
Subject: [PATCH 2/6] recurrence

---
 docs/src/guide/models/recurrence.md | 60 +++++++++++++++++++++++++++++
 src/Flux.jl                         |  2 +-
 src/layers/recurrent.jl             | 59 ++++++++++++++++++++++------
 test/layers/recurrent.jl            |  9 +++++
 4 files changed, 117 insertions(+), 13 deletions(-)

diff --git a/docs/src/guide/models/recurrence.md b/docs/src/guide/models/recurrence.md
index 5b2e70f095..3079ab7311 100644
--- a/docs/src/guide/models/recurrence.md
+++ b/docs/src/guide/models/recurrence.md
@@ -166,3 +166,63 @@ opt_state = Flux.setup(AdamW(1e-3), model)
 g = gradient(m -> Flux.mse(m(x), y), model)[1]
 Flux.update!(opt_state, model, g)
 ```
+
+Finally, the [`Recurrence`](@ref) layer can be used to wrap any recurrent cell to process the entire sequence at once. For instance, a type behaving the same as the `LSTM` layer can be defined as follows:
+
+```jldoctest
+julia> rnn = Recurrence(LSTMCell(2 => 3))   # similar to LSTM(2 => 3)
+Recurrence(
+  LSTMCell(2 => 3),                     # 72 parameters
+)                   # Total: 3 arrays, 72 parameters, 448 bytes.
+
+julia> y = rnn(rand(Float32, 2, 4, 3));
+```
+
+## Stacking recurrent layers
+
+Recurrent layers can be stacked to form a deeper model by simply chaining them together using the [`Chain`](@ref) layer. The output of a layer is fed as input to the next layer in the chain.
+For instance, a model with two LSTM layers can be defined as follows:
+
+```jldoctest
+julia> stacked_rnn = Chain(LSTM(3 => 5), Dropout(0.5), LSTM(5 => 5))
+Chain(
+  LSTM(3 => 5),                         # 180 parameters
+  Dropout(0.5),
+  LSTM(5 => 5),                         # 220 parameters
+)                   # Total: 6 arrays, 400 parameters, 1.898 KiB.
+
+julia> x = rand(Float32, 3, 4);
+
+julia> y = stacked_rnn(x);
+
+julia> size(y)
+(5, 4)
+```
+
+If more fine-grained control is needed, for instance to have a trainable initial hidden state, one can define a custom model as follows:
+
+```julia
+struct StackedRNN{L,S}
+    layers::L
+    states0::S
+end
+
+Flux.@layer StackedRNN
+
+function StackedRNN(d::Int; num_layers::Int)
+    layers = [LSTM(d => d) for _ in 1:num_layers]
+    states0 = [Flux.initialstates(l) for l in layers]
+    return StackedRNN(layers, states0)
+end
+
+function (m::StackedRNN)(x)
+   for (layer, state0) in zip(m.layers, m.states0)
+       x = layer(x, state0)
+   end
+   return x
+end
+
+rnn = StackedRNN(3; num_layers=2)
+x = rand(Float32, 3, 2)
+y = rnn(x)
+```
diff --git a/src/Flux.jl b/src/Flux.jl
index 8fb2351aa2..3a598e88f5 100644
--- a/src/Flux.jl
+++ b/src/Flux.jl
@@ -38,7 +38,7 @@ using EnzymeCore: EnzymeCore
 export Chain, Dense, Embedding, EmbeddingBag,
        Maxout, SkipConnection, Parallel, PairwiseFusion,
        RNNCell, LSTMCell, GRUCell, GRUv3Cell,
-       RNN, LSTM, GRU, GRUv3,
+       RNN, LSTM, GRU, GRUv3, Recurrence,
        SamePad, Conv, CrossCor, ConvTranspose, DepthwiseConv,
        AdaptiveMaxPool, AdaptiveMeanPool, GlobalMaxPool, GlobalMeanPool, MaxPool, MeanPool,
        Dropout, AlphaDropout,
diff --git a/src/layers/recurrent.jl b/src/layers/recurrent.jl
index adab5f415b..e63dc5927e 100644
--- a/src/layers/recurrent.jl
+++ b/src/layers/recurrent.jl
@@ -28,30 +28,47 @@ the initial hidden state. The output of the `cell` is considered to be:
 
     rnn(x, [state])
 
-The input `x` should be a matrix of size `in x len` or an array of size `in x len x batch_size`, 
-where `in` is the input dimension, `len` is the sequence length, and `batch_size` is the batch size.
+The input `x` should be an array of size `in x len` or `in x len x batch_size`, 
+where `in` is the input dimension of the cell, `len` is the sequence length, and `batch_size` is the batch size.
+The `state` should be a valid state for the recurrent cell. If not provided, it is obtained by calling
+`Flux.initialstates(cell)`.
+
+The output is an array of size `out x len` or `out x len x batch_size`, where `out` is the output dimension of the cell.
 
 The operation performed is semantically equivalent to the following code:
 ```julia
+out_from_state(state) = state
+out_from_state(state::Tuple) = state[1]
+
 state = Flux.initialstates(cell)
 out = []
 for x_t in eachslice(x, dims = 2)
   state = cell(x_t, state)
-  out = [out..., get_output(state)]
+  out = [out..., out_from_state(state)]
 end
 stack(out, dims = 2)
 ```
+
+# Examples
+
+```jldoctest
+julia> rnn = Recurrent(RNNCell(2 => 3))
+
+julia> x = rand(Float32, 2, 3, 4); # in x len x batch_size
+
+julia> y = rnn(x); # out x len x batch_size
+```
 """
-struct Recurrent{M}
+struct Recurrence{M}
   cell::M
 end
 
-@layer Recurrent
+@layer Recurrence
 
-initialstates(rnn::Recurrent) = initialstates(rnn.cell)
+initialstates(rnn::Recurrence) = initialstates(rnn.cell)
 
-(rnn::Recurrent)(x::AbstractArray) = rnn(x, initialstates(rnn))
-(rnn::Recurrent)(x::AbstractArray, state) = scan(rnn.cell, x, state)
+(rnn::Recurrence)(x::AbstractArray) = rnn(x, initialstates(rnn))
+(rnn::Recurrence)(x::AbstractArray, state) = scan(rnn.cell, x, state)
 
 # Vanilla RNN
 @doc raw"""
@@ -250,7 +267,7 @@ struct RNN{M}
   cell::M
 end
 
-@layer RNN
+@layer :noexpand RNN
 
 initialstates(rnn::RNN) = initialstates(rnn.cell)
 
@@ -271,6 +288,12 @@ function (m::RNN)(x::AbstractArray, h)
   return scan(m.cell, x, h)
 end
 
+function Base.show(io::IO, m::RNN)
+  print(io, "RNN(", size(m.cell.Wi, 2), " => ", size(m.cell.Wi, 1))
+  print(io, ", ", m.cell.σ)
+  print(io, ")")
+end
+
 
 # LSTM
 @doc raw"""
@@ -439,7 +462,7 @@ struct LSTM{M}
   cell::M
 end
 
-@layer LSTM
+@layer :noexpand LSTM
 
 initialstates(lstm::LSTM) = initialstates(lstm.cell)
 
@@ -455,6 +478,10 @@ function (m::LSTM)(x::AbstractArray, state0)
   return scan(m.cell, x, state0)
 end
 
+function Base.show(io::IO, m::LSTM)
+  print(io, "LSTM(", size(m.cell.Wi, 2), " => ", size(m.cell.Wi, 1) ÷ 4, ")")
+end
+
 # GRU
 
 @doc raw"""
@@ -607,7 +634,7 @@ struct GRU{M}
   cell::M
 end
 
-@layer GRU
+@layer :noexpand GRU
 
 initialstates(gru::GRU) = initialstates(gru.cell)
 
@@ -623,6 +650,10 @@ function (m::GRU)(x::AbstractArray, h)
   return scan(m.cell, x, h)
 end
 
+function Base.show(io::IO, m::GRU)
+  print(io, "GRU(", size(m.cell.Wi, 2), " => ", size(m.cell.Wi, 1) ÷ 3, ")")
+end
+
 # GRU v3
 @doc raw"""
     GRUv3Cell(in => out; init_kernel = glorot_uniform,
@@ -767,7 +798,7 @@ struct GRUv3{M}
   cell::M
 end
 
-@layer GRUv3
+@layer :noexpand GRUv3
 
 initialstates(gru::GRUv3) = initialstates(gru.cell)
 
@@ -782,3 +813,7 @@ function (m::GRUv3)(x::AbstractArray, h)
   @assert ndims(x) == 2 || ndims(x) == 3
   return scan(m.cell, x, h)
 end
+
+function Base.show(io::IO, m::GRUv3)
+  print(io, "GRUv3(", size(m.cell.Wi, 2), " => ", size(m.cell.Wi, 1) ÷ 3, ")")
+end
\ No newline at end of file
diff --git a/test/layers/recurrent.jl b/test/layers/recurrent.jl
index 864e5dad8e..3d7d53a486 100644
--- a/test/layers/recurrent.jl
+++ b/test/layers/recurrent.jl
@@ -305,3 +305,12 @@ end
     # no initial state same as zero initial state
     @test gru(x) ≈ gru(x, zeros(Float32, 4))
 end
+
+@testset "Recurrence" begin
+    for rnn in [RNN(2 => 3), LSTM(2 => 3), GRU(2 => 3)]
+        cell = rnn.cell
+        rec = Recurrence(cell)
+        x = rand(Float32, 2, 3, 4)
+        @test rec(x) ≈ rnn(x)
+    end
+end

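The testset added here asserts that wrapping a layer's own cell in `Recurrence` reproduces the full layer's output. A rough standalone version of that check, using the same shapes as the test (both sides fall back to their default initial states):

```julia
using Flux, Test

x = rand(Float32, 2, 3, 4)                  # in × len × batch_size
for rnn in (RNN(2 => 3), LSTM(2 => 3), GRU(2 => 3))
    rec = Recurrence(rnn.cell)              # reuse the layer's own cell
    @test rec(x) ≈ rnn(x)                   # both scan the whole sequence
end
```
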
From 48e91bc57ddb6952fb6de633d33e21998843f425 Mon Sep 17 00:00:00 2001
From: Carlo Lucibello <carlo.lucibello@gmail.com>
Date: Fri, 13 Dec 2024 15:28:52 +0100
Subject: [PATCH 3/6] cleanup

---
 test/layers/recurrent.jl | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/test/layers/recurrent.jl b/test/layers/recurrent.jl
index 3d7d53a486..3ad7428601 100644
--- a/test/layers/recurrent.jl
+++ b/test/layers/recurrent.jl
@@ -307,10 +307,10 @@ end
 end
 
 @testset "Recurrence" begin
+    x = rand(Float32, 2, 3, 4)
     for rnn in [RNN(2 => 3), LSTM(2 => 3), GRU(2 => 3)]
         cell = rnn.cell
         rec = Recurrence(cell)
-        x = rand(Float32, 2, 3, 4)
         @test rec(x) ≈ rnn(x)
     end
 end

From 47b570c037becca7f5ecceccc47ea13f14284551 Mon Sep 17 00:00:00 2001
From: Carlo Lucibello <carlo.lucibello@gmail.com>
Date: Fri, 13 Dec 2024 17:24:28 +0100
Subject: [PATCH 4/6] fix

---
 src/layers/recurrent.jl | 4 +---
 1 file changed, 1 insertion(+), 3 deletions(-)

diff --git a/src/layers/recurrent.jl b/src/layers/recurrent.jl
index e63dc5927e..616dfc83b3 100644
--- a/src/layers/recurrent.jl
+++ b/src/layers/recurrent.jl
@@ -240,9 +240,7 @@ julia> x = rand(Float32, (d_in, len, batch_size));
 julia> h = zeros(Float32, (d_out, batch_size));
 
 julia> rnn = RNN(d_in => d_out)
-RNN(
-  RNNCell(4 => 6, tanh),                # 66 parameters
-)                   # Total: 3 arrays, 66 parameters, 424 bytes.
+RNN(4 => 6, tanh)   # 66 parameters
 
 julia> y = rnn(x, h);   # [y] = [d_out, len, batch_size]
 ```

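This doctest change follows from patch 2, where `RNN`, `LSTM`, `GRU`, and `GRUv3` switched to `@layer :noexpand` together with custom `Base.show` methods. A rough sketch of the compact display one would expect at the REPL after these patches; the exact printed strings are assumptions based on those `show` methods, not verified doctest output.

```julia
using Flux

# With :noexpand the wrapper layers print on one line instead of expanding the cell:
RNN(4 => 6)    # expected to show roughly as: RNN(4 => 6, tanh)
LSTM(2 => 3)   # expected to show roughly as: LSTM(2 => 3)
GRU(2 => 3)    # expected to show roughly as: GRU(2 => 3)
```
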
From 89fc89fe60d93963d0e271fbf7b01902a5a5abc7 Mon Sep 17 00:00:00 2001
From: Carlo Lucibello <carlo.lucibello@gmail.com>
Date: Fri, 13 Dec 2024 17:32:41 +0100
Subject: [PATCH 5/6] fix doctest

---
 .github/workflows/ci.yml            | 2 +-
 docs/make.jl                        | 1 +
 docs/src/guide/models/basics.md     | 4 ++--
 docs/src/reference/models/layers.md | 1 +
 src/layers/recurrent.jl             | 4 ++--
 5 files changed, 7 insertions(+), 5 deletions(-)

diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml
index acd279d709..dfa4a442f2 100644
--- a/.github/workflows/ci.yml
+++ b/.github/workflows/ci.yml
@@ -75,4 +75,4 @@ jobs:
         env:
           GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
           DOCUMENTER_KEY: ${{ secrets.DOCUMENTER_KEY }}
-          DATADEPS_ALWAYS_ACCEPT: true
+
diff --git a/docs/make.jl b/docs/make.jl
index 4367639d8e..d74486936b 100644
--- a/docs/make.jl
+++ b/docs/make.jl
@@ -2,6 +2,7 @@ using Documenter, Flux, NNlib, Functors, MLUtils, BSON, Optimisers,
       OneHotArrays, Zygote, ChainRulesCore, Plots, MLDatasets, Statistics, 
       DataFrames, JLD2, MLDataDevices
 
+ENV["DATADEPS_ALWAYS_ACCEPT"] = true
 
 DocMeta.setdocmeta!(Flux, :DocTestSetup, :(using Flux); recursive = true)
 
diff --git a/docs/src/guide/models/basics.md b/docs/src/guide/models/basics.md
index 7ad62ee207..7141fb0b28 100644
--- a/docs/src/guide/models/basics.md
+++ b/docs/src/guide/models/basics.md
@@ -185,7 +185,7 @@ These matching nested structures are at the core of how Flux works.
 <h3><img src="../../../assets/zygote-crop.png" width="40px"/>&nbsp;<a href="https://github.com/FluxML/Zygote.jl">Zygote.jl</a></h3>
 ```
 
-Flux's [`gradient`](@ref) function by default calls a companion packages called [Zygote](https://github.com/FluxML/Zygote.jl).
+Flux's [`gradient`](@ref Flux.gradient) function by default calls a companion package called [Zygote](https://github.com/FluxML/Zygote.jl).
 Zygote performs source-to-source automatic differentiation, meaning that `gradient(f, x)`
 hooks into Julia's compiler to find out what operations `f` contains, and transforms this
 to produce code for computing `∂f/∂x`.
@@ -372,7 +372,7 @@ How does this `model3` differ from the `model1` we had before?
   Its contents is stored in a tuple, thus `model3.layers[1].weight` is an array.
 * Flux's layer [`Dense`](@ref Flux.Dense) has only minor differences from our `struct Layer`:
   - Like `struct Poly3{T}` above, it has type parameters for its fields -- the compiler does not know exactly what type `layer3s.W` will be, which costs speed.
-  - Its initialisation uses not `randn` (normal distribution) but [`glorot_uniform`](@ref) by default.
+  - Its initialisation uses not `randn` (normal distribution) but [`glorot_uniform`](@ref Flux.glorot_uniform) by default.
   - It reshapes some inputs (to allow several batch dimensions), and produces more friendly errors on wrong-size input.
   - And it has some performance tricks: making sure element types match, and re-using some memory.
 * The function [`σ`](@ref NNlib.sigmoid) is calculated in a slightly better way,
diff --git a/docs/src/reference/models/layers.md b/docs/src/reference/models/layers.md
index 355d3e7833..562304de70 100644
--- a/docs/src/reference/models/layers.md
+++ b/docs/src/reference/models/layers.md
@@ -104,6 +104,7 @@ PairwiseFusion
 Much like the core layers above, but can be used to process sequence data (as well as other kinds of structured data).
 
 ```@docs
+Recurrence
 RNNCell
 RNN
 LSTMCell
diff --git a/src/layers/recurrent.jl b/src/layers/recurrent.jl
index 616dfc83b3..5b88768ad1 100644
--- a/src/layers/recurrent.jl
+++ b/src/layers/recurrent.jl
@@ -12,7 +12,7 @@ function scan(cell, x, state)
 end
 
 """
-    Recurrent(cell)
+    Recurrence(cell)
 
 Create a recurrent layer that processes entire sequences out
 of a recurrent `cell`, such as an [`RNNCell`](@ref), [`LSTMCell`](@ref), or [`GRUCell`](@ref),
@@ -52,7 +52,7 @@ stack(out, dims = 2)
 # Examples
 
 ```jldoctest
-julia> rnn = Recurrent(RNNCell(2 => 3))
+julia> rnn = Recurrence(RNNCell(2 => 3))
 
 julia> x = rand(Float32, 2, 3, 4); # in x len x batch_size
 

From c4c582d1a68fb26cc23b4c736ba5ffb0b3faefea Mon Sep 17 00:00:00 2001
From: Carlo Lucibello <carlo.lucibello@gmail.com>
Date: Fri, 13 Dec 2024 17:40:19 +0100
Subject: [PATCH 6/6] fix doctests

---
 docs/src/guide/models/recurrence.md | 31 +++++++++--------------------
 src/layers/recurrent.jl             |  3 +++
 2 files changed, 12 insertions(+), 22 deletions(-)

diff --git a/docs/src/guide/models/recurrence.md b/docs/src/guide/models/recurrence.md
index 3079ab7311..446ff86f82 100644
--- a/docs/src/guide/models/recurrence.md
+++ b/docs/src/guide/models/recurrence.md
@@ -169,13 +169,10 @@ Flux.update!(opt_state, model, g)
 
 Finally, the [`Recurrence`](@ref) layer can be used to wrap any recurrent cell to process the entire sequence at once. For instance, a type behaving the same as the `LSTM` layer can be defined as follows:
 
-```jldoctest
-julia> rnn = Recurrence(LSTMCell(2 => 3))   # similar to LSTM(2 => 3)
-Recurrence(
-  LSTMCell(2 => 3),                     # 72 parameters
-)                   # Total: 3 arrays, 72 parameters, 448 bytes.
-
-julia> y = rnn(rand(Float32, 2, 4, 3));
+```julia
+rnn = Recurrence(LSTMCell(2 => 3))   # similar to LSTM(2 => 3)
+x = rand(Float32, 2, 4, 3)
+y = rnn(x)
 ```
 
 ## Stacking recurrent layers
@@ -183,20 +180,10 @@ julia> y = rnn(rand(Float32, 2, 4, 3));
 Recurrent layers can be stacked to form a deeper model by simply chaining them together using the [`Chain`](@ref) layer. The output of a layer is fed as input to the next layer in the chain.
 For instance, a model with two LSTM layers can be defined as follows:
 
-```jldoctest
-julia> stacked_rnn = Chain(LSTM(3 => 5), Dropout(0.5), LSTM(5 => 5))
-Chain(
-  LSTM(3 => 5),                         # 180 parameters
-  Dropout(0.5),
-  LSTM(5 => 5),                         # 220 parameters
-)                   # Total: 6 arrays, 400 parameters, 1.898 KiB.
-
-julia> x = rand(Float32, 3, 4);
-
-julia> y = stacked_rnn(x);
-
-julia> size(y)
-(5, 4)
+```julia
+stacked_rnn = Chain(LSTM(3 => 5), Dropout(0.5), LSTM(5 => 5))
+x = rand(Float32, 3, 4)
+y = stacked_rnn(x)
 ```
 
 If more fine-grained control is needed, for instance to have a trainable initial hidden state, one can define a custom model as follows:
@@ -223,6 +210,6 @@ function (m::StackedRNN)(x)
 end
 
 rnn = StackedRNN(3; num_layers=2)
-x = rand(Float32, 3, 2)
+x = rand(Float32, 3, 10)
 y = rnn(x)
 ```
diff --git a/src/layers/recurrent.jl b/src/layers/recurrent.jl
index 5b88768ad1..a6a82b0e72 100644
--- a/src/layers/recurrent.jl
+++ b/src/layers/recurrent.jl
@@ -53,6 +53,9 @@ stack(out, dims = 2)
 
 ```jldoctest
 julia> rnn = Recurrence(RNNCell(2 => 3))
+Recurrence(
+  RNNCell(2 => 3, tanh),                # 18 parameters
+)                   # Total: 3 arrays, 18 parameters, 232 bytes.
 
 julia> x = rand(Float32, 2, 3, 4); # in x len x batch_size