Stacked RNN in Flux.jl? #2452

Closed
NeroBlackstone opened this issue Jun 3, 2024 · 1 comment · Fixed by #2549

@NeroBlackstone

Motivation and description

How do I build a stacked RNN in Flux.jl?

Is the following code the correct way?

using Flux
model = Chain(GRUv3(27 => 32),GRUv3(32 => 32),Dense(32 => 27))
Chain(
  Recur(
    GRUv3Cell(27 => 32),                # 5_792 parameters
  ),
  Recur(
    GRUv3Cell(32 => 32),                # 6_272 parameters
  ),
  Dense(32 => 27),                      # 891 parameters
)         # Total: 12 trainable arrays, 12_955 parameters,
          # plus 2 non-trainable, 64 parameters, summarysize 1.938 KiB.

There is no documentation mentioning this.

@CarloLucibello
Member

CarloLucibello commented Dec 12, 2024

Starting with Flux v0.15, a stacked RNN can be defined as follows:

stacked_rnn = Chain(LSTM(3 => 3), Dropout(0.5), LSTM(3 => 3))
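
As a minimal sketch of how this chain might be called, assuming Flux v0.15 or later, where recurrent layers consume a whole sequence at once: the input sizes below are hypothetical and chosen only for illustration.

```julia
using Flux  # assumes Flux v0.15 or later

stacked_rnn = Chain(LSTM(3 => 3), Dropout(0.5), LSTM(3 => 3))

# Hypothetical input: 3 features, 10 time steps, batch of 4,
# laid out as (features, time steps, batch).
x = rand(Float32, 3, 10, 4)

# Each LSTM returns the hidden state for every time step, so the
# output keeps the (features, time steps, batch) layout.
y = stacked_rnn(x)
@show size(y)
```

Outside of a gradient computation, `Dropout` is inactive by default, so this forward pass is deterministic for fixed weights and input.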

If control of the initial states is also needed, define a custom struct:

struct StackedRNN{L,S}
    layers::L
    states0::S
end

function StackedRNN(d, num_layers)
    layers = [LSTM(d => d) for _ in 1:num_layers]
    states0 = [Flux.initialstates(l) for l in layers]
    return StackedRNN(layers, states0)
end

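function (m::StackedRNN)(x)
    # Pass the output sequence of each layer to the next one,
    # starting every layer from its own stored initial state.
    for (layer, state0) in zip(m.layers, m.states0)
        x = layer(x, state0)
    end
    return x
end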
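
A self-contained sketch of how such a struct might be constructed and called, assuming Flux v0.15 or later. The definitions are repeated so the snippet runs on its own; the layer count and input sizes are hypothetical.

```julia
using Flux  # assumes Flux v0.15 or later

struct StackedRNN{L,S}
    layers::L
    states0::S
end

function StackedRNN(d, num_layers)
    layers = [LSTM(d => d) for _ in 1:num_layers]
    states0 = [Flux.initialstates(l) for l in layers]
    return StackedRNN(layers, states0)
end

function (m::StackedRNN)(x)
    for (layer, state0) in zip(m.layers, m.states0)
        x = layer(x, state0)
    end
    return x
end

model = StackedRNN(3, 2)        # two stacked LSTM(3 => 3) layers
x = rand(Float32, 3, 10, 4)     # hypothetical (features, time steps, batch)
y = model(x)
@show size(y)
```

Because the initial states are stored in the struct, they can be replaced with custom values when constructing `StackedRNN(layers, states0)` directly.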

This should be documented in the recurrence guide:
https://github.com/FluxML/Flux.jl/blob/master/docs/src/guide/models/recurrence.md
