diff --git a/crates/burn-import/SUPPORTED-ONNX-OPS.md b/crates/burn-import/SUPPORTED-ONNX-OPS.md index 15eb22db5a..27930aac97 100644 --- a/crates/burn-import/SUPPORTED-ONNX-OPS.md +++ b/crates/burn-import/SUPPORTED-ONNX-OPS.md @@ -133,7 +133,7 @@ represent the corresponding Burn Op. | [QLinearMatMul][124] | ❌ | ❌ | | [QuantizeLinear][125] | ❌ | ❌ | | [RandomNormal][126] | ✅ | ✅ | -| [RandomNormalLike][127] | ❌ | ✅ | +| [RandomNormalLike][127] | ✅ | ✅ | | [RandomUniform][128] | ✅ | ✅ | | [RandomUniformLike][129] | ❌ | ✅ | | [Range][130] | ✅ | ✅ | diff --git a/crates/burn-import/onnx-tests/build.rs b/crates/burn-import/onnx-tests/build.rs index 198f20eb9c..2229626c00 100644 --- a/crates/burn-import/onnx-tests/build.rs +++ b/crates/burn-import/onnx-tests/build.rs @@ -75,6 +75,7 @@ fn main() { .input("tests/pow/pow_int.onnx") .input("tests/prelu/prelu.onnx") .input("tests/random_normal/random_normal.onnx") + .input("tests/random_normal_like/random_normal_like.onnx") .input("tests/random_uniform/random_uniform.onnx") .input("tests/range/range.onnx") .input("tests/recip/recip.onnx") diff --git a/crates/burn-import/onnx-tests/tests/random_normal_like/random_normal_like.onnx b/crates/burn-import/onnx-tests/tests/random_normal_like/random_normal_like.onnx new file mode 100644 index 0000000000..6e4b6f97c6 --- /dev/null +++ b/crates/burn-import/onnx-tests/tests/random_normal_like/random_normal_like.onnx @@ -0,0 +1,15 @@ +pytorch2.2.0: +P +onnx::RandomNormalLike_01/RandomNormalLike"RandomNormalLike* +dtype +main_graphZ. 
+onnx::RandomNormalLike_0 + + + +b +1 + + + +B \ No newline at end of file diff --git a/crates/burn-import/onnx-tests/tests/random_normal_like/random_normal_like.py b/crates/burn-import/onnx-tests/tests/random_normal_like/random_normal_like.py new file mode 100644 index 0000000000..64aed94e3b --- /dev/null +++ b/crates/burn-import/onnx-tests/tests/random_normal_like/random_normal_like.py @@ -0,0 +1,49 @@ +#!/usr/bin/env python3 + +# used to generate model: random_normal_like.onnx + +import torch +import torch.nn as nn + + +class RandomNormalLikeModel(nn.Module): + def __init__(self): + super(RandomNormalLikeModel, self).__init__() + + def forward(self, x): + return torch.randn_like(x) + + +def main(): + # Set seed for reproducibility + torch.manual_seed(42) + + # Set print options for better precision output + torch.set_printoptions(precision=8) + + # Export Random NormalLike Model + model = RandomNormalLikeModel() + model.eval() + device = torch.device("cpu") + + # Generate test input: a 2D matrix or batch of 2D matrices + file_name = "random_normal_like.onnx" + test_input = torch.randn(2, 4, 4, device=device) # 2 batches of 4x4 matrices + torch.onnx.export(model, + test_input, + file_name, + verbose=False, + opset_version=16) + + print("Finished exporting model to {}".format(file_name)) + + # Output some test data for use in the test + print("Test input data: {}".format(test_input)) + print("Test input data shape: {}".format(test_input.shape)) + output = model.forward(test_input) + print("Test output data shape: {}".format(output.shape)) + print("Test output: {}".format(output)) + + +if __name__ == '__main__': + main() \ No newline at end of file diff --git a/crates/burn-import/onnx-tests/tests/test_onnx.rs b/crates/burn-import/onnx-tests/tests/test_onnx.rs index 24ddaeca45..a4acd51477 100644 --- a/crates/burn-import/onnx-tests/tests/test_onnx.rs +++ b/crates/burn-import/onnx-tests/tests/test_onnx.rs @@ -84,6 +84,7 @@ include_models!( pow_int, prelu, 
random_normal,
+    random_normal_like,
     random_uniform,
     range,
     recip,
@@ -2157,6 +2158,19 @@ mod tests {
         assert_eq!(expected_shape, output.shape());
     }
 
+    #[test]
+    fn random_normal_like() {
+        let device = Default::default();
+        let model = random_normal_like::Model::<Backend>::new(&device);
+        let input = TensorData::zeros::<f64, _>(Shape::from([2, 4, 4]));
+        let expected_shape = Shape::from([2, 4, 4]);
+
+        let output = model.forward(input.into());
+
+        // The generated values are random, so only the output shape is checked.
+        assert_eq!(expected_shape, output.shape());
+    }
+
     #[test]
     fn constant_of_shape() {
         // This tests shape is being passed directly to the model
diff --git a/crates/burn-import/src/burn/node/base.rs b/crates/burn-import/src/burn/node/base.rs
index 427bbdcad7..123a29a223 100644
--- a/crates/burn-import/src/burn/node/base.rs
+++ b/crates/burn-import/src/burn/node/base.rs
@@ -10,10 +10,10 @@ use super::{
     gather_elements::GatherElementsNode, global_avg_pool::GlobalAvgPoolNode,
     layer_norm::LayerNormNode, linear::LinearNode, mask_where::WhereNode, matmul::MatmulNode,
     max_pool1d::MaxPool1dNode, max_pool2d::MaxPool2dNode, mean::MeanNode, pad::PadNode,
-    prelu::PReluNode, random_normal::RandomNormalNode, random_uniform::RandomUniformNode,
-    range::RangeNode, reshape::ReshapeNode, resize::ResizeNode, slice::SliceNode,
-    squeeze::SqueezeNode, sum::SumNode, tile::TileNode, trilu::TriluNode, unary::UnaryNode,
-    unsqueeze::UnsqueezeNode,
+    prelu::PReluNode, random_normal::RandomNormalNode, random_normal_like::RandomNormalLikeNode,
+    random_uniform::RandomUniformNode, range::RangeNode, reshape::ReshapeNode, resize::ResizeNode,
+    slice::SliceNode, squeeze::SqueezeNode, sum::SumNode, tile::TileNode, trilu::TriluNode,
+    unary::UnaryNode, unsqueeze::UnsqueezeNode,
 };
 use crate::burn::{BurnImports, Scope, Type};
 use burn::backend::NdArray;
@@ -121,8 +121,9 @@ pub enum Node {
     Unary(UnaryNode),
     Unsqueeze(UnsqueezeNode),
     Where(WhereNode),
-    RandomUniform(RandomUniformNode),
     RandomNormal(RandomNormalNode),
+    RandomNormalLike(RandomNormalLikeNode),
+    RandomUniform(RandomUniformNode),
     ConstantOfShape(ConstantOfShapeNode),
     // For now, we have to keep the precision settings in order to correctly serialize the fields
     // into the right data types.
@@ -172,6 +173,7 @@ macro_rules! match_all {
             Node::Unsqueeze(node) => $func(node),
             Node::Where(node) => $func(node),
             Node::RandomNormal(node) => $func(node),
+            Node::RandomNormalLike(node) => $func(node),
             Node::RandomUniform(node) => $func(node),
             Node::ConstantOfShape(node) => $func(node),
             _ => unimplemented!(),
@@ -230,6 +232,7 @@ impl Node {
             Node::Unsqueeze(_) => "unsqueeze",
             Node::Where(_) => "where",
             Node::RandomNormal(_) => "random_normal",
+            Node::RandomNormalLike(_) => "random_normal_like",
             Node::RandomUniform(_) => "random_uniform",
             Node::ConstantOfShape(_) => "constant_of_shape",
             _ => unimplemented!(),
diff --git a/crates/burn-import/src/burn/node/mod.rs b/crates/burn-import/src/burn/node/mod.rs
index 6313d25f63..57dcae1c6f 100644
--- a/crates/burn-import/src/burn/node/mod.rs
+++ b/crates/burn-import/src/burn/node/mod.rs
@@ -30,6 +30,7 @@ pub(crate) mod mean;
 pub(crate) mod pad;
 pub(crate) mod prelu;
 pub(crate) mod random_normal;
+pub(crate) mod random_normal_like;
 pub(crate) mod random_uniform;
 pub(crate) mod range;
 pub(crate) mod reshape;
diff --git a/crates/burn-import/src/burn/node/random_normal_like.rs b/crates/burn-import/src/burn/node/random_normal_like.rs
new file mode 100644
index 0000000000..c502f9c76a
--- /dev/null
+++ b/crates/burn-import/src/burn/node/random_normal_like.rs
@@ -0,0 +1,111 @@
+use super::{Node, NodeCodegen};
+use crate::burn::{Scope, TensorType, Type};
+use burn::record::PrecisionSettings;
+use proc_macro2::TokenStream;
+use quote::quote;
+
+/// Code-generation node for the ONNX `RandomNormalLike` operator.
+///
+/// The generated code calls `random_like` on the input tensor with a `Distribution::Normal`
+/// built from `mean` and `scale`.
+#[derive(Debug, Clone, new)]
+pub struct RandomNormalLikeNode {
+    /// Mean passed to `Distribution::Normal`.
+    pub mean: f64,
+    /// Standard deviation passed to `Distribution::Normal`.
+    pub scale: f64,
+    /// Tensor that `random_like` is called on.
+    pub input: TensorType,
+    /// Tensor bound to the result of `random_like`.
+    pub output: TensorType,
+}
+
+impl RandomNormalLikeNode {
+    /// Builds the `Distribution::Normal(mean, scale)` tokens used by `forward`.
+    fn get_distribution(&self) -> TokenStream {
+        let mean = self.mean;
+        let std_deviation = self.scale;
+        quote! { Distribution::Normal(#mean, #std_deviation) }
+    }
+}
+
+impl<PS: PrecisionSettings> NodeCodegen<PS> for RandomNormalLikeNode {
+    fn input_types(&self) -> Vec<Type> {
+        vec![Type::Tensor(self.input.clone())]
+    }
+
+    fn output_types(&self) -> Vec<Type> {
+        vec![Type::Tensor(self.output.clone())]
+    }
+
+    /// Emits `let <output> = <input>.random_like(<distribution>);`.
+    fn forward(&self, _scope: &mut Scope, _node_position: usize) -> TokenStream {
+        let output = &self.output.name;
+        let input = &self.input.name;
+        let dist = self.get_distribution();
+        quote! {
+            let #output = #input.random_like(#dist);
+        }
+    }
+
+    fn into_node(self) -> Node<PS> {
+        Node::RandomNormalLike(self)
+    }
+
+    /// Registers `burn::tensor::Distribution`, which the generated code refers to unqualified.
+    fn register_imports(&self, imports: &mut crate::burn::BurnImports) {
+        imports.register("burn::tensor::Distribution");
+    }
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+    use crate::burn::{graph::BurnGraph, node::test::assert_tokens, TensorKind, TensorType};
+    use burn::record::FullPrecisionSettings;
+
+    #[test]
+    fn test_random_normal_like_codegen() {
+        let mut graph = BurnGraph::<FullPrecisionSettings>::default();
+
+        graph.register(RandomNormalLikeNode::new(
+            0.0f64,
+            1.0f64,
+            TensorType::new("input", 2, TensorKind::Float, Some(vec![2, 3])),
+            TensorType::new("output", 2, TensorKind::Float, Some(vec![2, 3])),
+        ));
+
+        graph.register_input_output(vec!["input".to_string()], vec!["output".to_string()]);
+
+        let expected = quote! {
+            use burn::tensor::Distribution;
+            use burn::{
+                module::Module,
+                tensor::{backend::Backend, Tensor},
+            };
+
+            #[derive(Module, Debug)]
+            pub struct Model<B: Backend> {
+                phantom: core::marker::PhantomData<B>,
+                device: burn::module::Ignored<B::Device>,
+            }
+
+            impl<B: Backend> Model<B> {
+                #[allow(unused_variables)]
+                pub fn new(device: &B::Device) -> Self {
+                    Self {
+                        phantom: core::marker::PhantomData,
+                        device: burn::module::Ignored(device.clone()),
+                    }
+                }
+                #[allow(clippy::let_and_return, clippy::approx_constant)]
+                pub fn forward(&self, input: Tensor<B, 2>) -> Tensor<B, 2> {
+                    let output = input.random_like(Distribution::Normal(0f64, 1f64));
+                    output
+                }
+            }
+        };
+
+        assert_tokens(graph.codegen(), expected);
+    }
+}
diff --git a/crates/burn-import/src/onnx/to_burn.rs b/crates/burn-import/src/onnx/to_burn.rs
index 4b081f6e43..16d56284cc 100644
--- a/crates/burn-import/src/onnx/to_burn.rs
+++ b/crates/burn-import/src/onnx/to_burn.rs
@@ -44,6 +44,7 @@ use crate::{
             pad::PadNode,
             prelu::PReluNode,
             random_normal::RandomNormalNode,
+            random_normal_like::RandomNormalLikeNode,
             random_uniform::RandomUniformNode,
             range::RangeNode,
             reshape::ReshapeNode,
@@ -345,6 +346,9 @@ impl ParsedOnnxGraph {
                 NodeType::Tile => graph.register(Self::tile_conversion(node)),
                 NodeType::Trilu => graph.register(Self::trilu_conversion(node)),
                 NodeType::RandomNormal => graph.register(Self::random_normal_conversion(node)),
+                NodeType::RandomNormalLike => {
+                    graph.register(Self::random_normal_like_conversion(node))
+                }
                 NodeType::ConstantOfShape => {
                     graph.register(Self::constant_of_shape_conversion(node))
                 }
@@ -472,6 +476,31 @@ impl ParsedOnnxGraph {
         RandomNormalNode::new(output_type, mean, scale)
     }
 
+    /// Converts an ONNX `RandomNormalLike` node into a [`RandomNormalLikeNode`].
+    ///
+    /// `mean` defaults to 0.0 and `scale` to 1.0 when the attributes are absent; a `seed`
+    /// attribute is not supported and only triggers a warning.
+    fn random_normal_like_conversion(node: Node) -> RandomNormalLikeNode {
+        let input = TensorType::from(node.inputs.first().unwrap());
+        let output = TensorType::from(node.outputs.first().unwrap());
+        let mean = node
+            .attrs
+            .get("mean")
+            .map(|val| val.clone().into_f32() as f64)
+            .unwrap_or(0.0f64);
+        let scale = node
+            .attrs
+            .get("scale")
+            .map(|val| val.clone().into_f32() as f64)
+            .unwrap_or(1.0f64);
+
+        if node.attrs.contains_key("seed") {
+            warn!("seed attribute is not supported!");
+        }
+
+        RandomNormalLikeNode::new(mean, scale, input, output)
+    }
+
     pub(crate) fn constant_of_shape_conversion(node: Node) -> ConstantOfShapeNode {
         // Additional types needed for ConstantOfShape:
         use crate::burn::node::constant_of_shape::ConstantValue;
diff --git a/crates/onnx-ir/src/dim_inference.rs b/crates/onnx-ir/src/dim_inference.rs
index cd22a1403a..eabbb781f7 100644
--- a/crates/onnx-ir/src/dim_inference.rs
+++ b/crates/onnx-ir/src/dim_inference.rs
@@ -61,6 +61,7 @@ pub fn dim_inference(node: &mut Node) {
         NodeType::PRelu => same_as_input_broadcast(node),
         NodeType::Pow => same_as_input_broadcast(node),
         NodeType::RandomNormal => random_update_output(node),
+        NodeType::RandomNormalLike => random_normal_like_update_output(node),
         NodeType::RandomUniform => random_update_output(node),
         NodeType::Range => range_update_outputs(node),
         NodeType::Reciprocal => same_as_input(node),
@@ -204,6 +205,47 @@ fn random_update_output(node: &mut Node) {
     })
 }
 
+/// Infers the output type of a `RandomNormalLike` node from its input tensor.
+///
+/// The element type comes from the optional `dtype` attribute (`FLOAT` when absent); the rank
+/// and shape are copied from the first input.
+///
+/// # Panics
+///
+/// Panics if `dtype` is not `FLOAT`, `FLOAT16` or `DOUBLE`, or if the first input is not a tensor.
+fn random_normal_like_update_output(node: &mut Node) {
+    // NOTE(review): the ONNX spec defaults a missing `dtype` to the input's element type; the
+    // `FLOAT` default is kept here as-is — confirm whether that is intended.
+    let dtype = node
+        .attrs
+        .get("dtype")
+        .map(|val| DataType::from_i32(val.clone().into_i32()).unwrap())
+        .unwrap_or(DataType::FLOAT);
+
+    let elem_type = match dtype {
+        DataType::FLOAT => ElementType::Float32,
+        DataType::FLOAT16 => ElementType::Float16,
+        DataType::DOUBLE => ElementType::Float64,
+        _ => panic!("Tensor with type {dtype:?} not supported for random output"),
+    };
+
+    let (dim, shape) = match &node.inputs[0].ty {
+        ArgType::Tensor(tensor) => {
+            // Fall back to the input's rank when its static shape is unknown, so the output
+            // type is always set instead of being silently left untouched.
+            let shape = tensor.shape.clone();
+            (shape.as_ref().map_or(tensor.dim, |dims| dims.len()), shape)
+        }
+        _ => panic!("Only tensor input is valid"),
+    };
+
+    node.outputs[0].ty = ArgType::Tensor(TensorType {
+        elem_type,
+        dim,
+        shape,
+    });
+}
+
 /// Infer the shape of the output tensor of a Conv2d node
 fn linear_update_outputs(node:
&mut Node) { // Extract the configuration of the linear layer (inputs are known)