diff --git a/crates/burn-core/src/nn/mod.rs b/crates/burn-core/src/nn/mod.rs
index ac428c3063..b3897760ee 100644
--- a/crates/burn-core/src/nn/mod.rs
+++ b/crates/burn-core/src/nn/mod.rs
@@ -33,6 +33,7 @@ mod prelu;
 mod relu;
 mod rnn;
 mod rope_encoding;
+mod sequential;
 mod sigmoid;
 mod swiglu;
 mod tanh;
@@ -52,6 +53,7 @@ pub use prelu::*;
 pub use relu::*;
 pub use rnn::*;
 pub use rope_encoding::*;
+pub use sequential::*;
 pub use sigmoid::*;
 pub use swiglu::*;
 pub use tanh::*;
diff --git a/crates/burn-core/src/nn/sequential.rs b/crates/burn-core/src/nn/sequential.rs
new file mode 100644
index 0000000000..64a3dc68d4
--- /dev/null
+++ b/crates/burn-core/src/nn/sequential.rs
@@ -0,0 +1,92 @@
+/// Create a sequential neural network, similar to PyTorch's nn.Sequential.
+///
+/// To use this macro, separate your modules into three categories:
+/// - Unit modules: modules that don't take any parameters (e.g. Relu, Sigmoid)
+/// - Config modules: modules that take a config, but no backend parameter (e.g. Dropout, LeakyRelu)
+/// - Backend modules: modules that take a backend parameter (e.g. Linear)
+///
+/// List the modules of each category comma-separated, with semicolons between the categories, like so:
+/// ```ignore
+/// gen_sequential! {
+///     // No config
+///     Relu,
+///     Sigmoid;
+///     // Has config
+///     DropoutConfig => Dropout,
+///     LeakyReluConfig => LeakyRelu;
+///     // Requires a backend (B)
+///     LinearConfig => Linear
+/// }
+/// ```
+///
+/// If there aren't any members of a particular category, the semicolon is still needed:
+/// ```ignore
+/// gen_sequential! {
+///     Relu,
+///     Sigmoid;
+///     // No config-only modules
+///     ;
+///     LinearConfig => Linear
+/// }
+/// ```
+///
+/// After invoking this macro, use the generated `SequentialConfig` and `Sequential` types in your code.
+#[macro_export]
+macro_rules! gen_sequential {
+    ($($unit:tt),*; $($cfg:ty => $module:tt),*; $($bcfg:ty => $bmodule:tt),*) => {
+        #[derive(Debug, burn::config::Config)]
+        pub enum SequentialLayerConfig {
+            $($unit,)*
+            $($module($cfg),)*
+            $($bmodule($bcfg),)*
+        }
+
+        #[derive(Debug, burn::config::Config)]
+        pub struct SequentialConfig {
+            pub layers: Vec<SequentialLayerConfig>
+        }
+
+        impl SequentialConfig {
+            pub fn init<B: burn::tensor::backend::Backend>(&self, device: &B::Device) -> Sequential<B> {
+                Sequential {
+                    layers: self.layers.iter().map(|l| match l {
+                        $(SequentialLayerConfig::$unit => SequentialLayer::$unit($unit),)*
+                        $(SequentialLayerConfig::$module(c) => SequentialLayer::$module(c.init()),)*
+                        $(SequentialLayerConfig::$bmodule(c) => SequentialLayer::$bmodule(c.init(device)),)*
+                    }).collect()
+                }
+            }
+        }
+
+        #[derive(Debug, burn::module::Module)]
+        pub enum SequentialLayer<B: burn::tensor::backend::Backend> {
+            /// In case the expansion doesn't use any backend-based layers. This should never be used.
+            _PhantomData(::core::marker::PhantomData<B>),
+            $($unit($unit),)*
+            $($module($module),)*
+            $($bmodule($bmodule<B>),)*
+        }
+
+        #[derive(Debug, burn::module::Module)]
+        pub struct Sequential<B: burn::tensor::backend::Backend> {
+            pub layers: Vec<SequentialLayer<B>>
+        }
+
+        impl<B: burn::tensor::backend::Backend> Sequential<B> {
+            pub fn forward<const D: usize>(&self, mut input: burn::tensor::Tensor<B, D>) -> burn::tensor::Tensor<B, D> {
+                for layer in &self.layers {
+                    input = match layer {
+                        SequentialLayer::_PhantomData(_) => unreachable!("PhantomData should never be instantiated"),
+                        $(SequentialLayer::$unit(u) => u.forward(input),)*
+                        $(SequentialLayer::$module(m) => m.forward(input),)*
+                        $(SequentialLayer::$bmodule(b) => b.forward(input),)*
+                    };
+                }
+
+                input
+            }
+        }
+    }
+}
+
+pub use gen_sequential;
diff --git a/crates/burn-core/tests/test_gen_sequential.rs b/crates/burn-core/tests/test_gen_sequential.rs
new file mode 100644
index 0000000000..11b729f725
--- /dev/null
+++ b/crates/burn-core/tests/test_gen_sequential.rs
@@ -0,0 +1,33 @@
+use burn_core::nn::{
+    gen_sequential, Dropout, DropoutConfig, LeakyRelu, LeakyReluConfig, Linear, LinearConfig, Relu,
+};
+
+use burn_core as burn;
+
+gen_sequential! {
+    Relu;
+    DropoutConfig => Dropout,
+    LeakyReluConfig => LeakyRelu;
+    LinearConfig => Linear
+}
+
+type TestBackend = burn_ndarray::NdArray<f32>;
+
+#[test]
+fn sequential_should_construct() {
+    let cfg = SequentialConfig {
+        layers: vec![
+            SequentialLayerConfig::Relu,
+            SequentialLayerConfig::Dropout(DropoutConfig { prob: 0.3 }),
+            SequentialLayerConfig::LeakyRelu(LeakyReluConfig {
+                negative_slope: 0.01,
+            }),
+            SequentialLayerConfig::Linear(LinearConfig::new(10, 10)),
+        ],
+    };
+
+    let device = Default::default();
+
+    let module: Sequential<TestBackend> = cfg.init(&device);
+    assert_eq!(module.layers.len(), 4);
+}
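For context, a minimal sketch of exercising the generated `forward` method is shown below. It assumes the `gen_sequential!` invocation and `TestBackend` alias from the test file above are in scope; `run_forward` is a hypothetical helper, not part of this diff.

```rust
// Sketch only: relies on SequentialConfig, SequentialLayerConfig, Sequential,
// and TestBackend generated/declared in the test above.
use burn_core::tensor::{Distribution, Tensor};

fn run_forward(device: &<TestBackend as burn_core::tensor::backend::Backend>::Device) {
    let cfg = SequentialConfig {
        layers: vec![
            SequentialLayerConfig::Linear(LinearConfig::new(10, 10)),
            SequentialLayerConfig::Relu,
        ],
    };
    let model: Sequential<TestBackend> = cfg.init(device);

    // A [batch, features] tensor; forward is generic over the tensor rank D.
    let input = Tensor::<TestBackend, 2>::random([2, 10], Distribution::Default, device);
    let output = model.forward(input);
    assert_eq!(output.dims(), [2, 10]);
}
```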