Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor GatherNode to support scalar outputs. #2828

Merged
merged 5 commits into from
Feb 24, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions crates/burn-import/onnx-tests/build.rs
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ fn main() {
.input("tests/gather/gather_2d_idx.onnx")
.input("tests/gather/gather_scalar.onnx")
.input("tests/gather/gather_shape.onnx")
.input("tests/gather/gather_scalar_out.onnx")
.input("tests/gather_elements/gather_elements.onnx")
.input("tests/gelu/gelu.onnx")
.input("tests/global_avr_pool/global_avr_pool.onnx")
Expand Down
Binary file not shown.
62 changes: 62 additions & 0 deletions crates/burn-import/onnx-tests/tests/gather/gather_scalar_out.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
#!/usr/bin/env python3

# used to generate model: onnx-tests/tests/gather/gather_scalar_out.onnx

# There is no current support for `Split`, and the `for` loop over the indices
# results in a `Split` node in the ONNX model.
# Therefore, this model is built and exported using ONNX directly.

import onnx


def build_model():
    """Assemble a minimal ONNX model that contains a single `Gather` node.

    The node gathers along axis 0 from `input1` (FLOAT, declared shape [1])
    using `input2` (INT64, declared shape [], i.e. a rank-0 index) and writes
    `output1` (FLOAT, declared shape [], i.e. a rank-0 value).

    Returns the model built with IR version 8 and default-domain opset 16.
    """

    def tensor_value_info(name, elem_type, shape):
        # Shorthand for a named tensor-typed graph input/output declaration.
        return onnx.helper.make_value_info(
            name=name,
            type_proto=onnx.helper.make_tensor_type_proto(
                elem_type=elem_type, shape=shape
            ),
        )

    gather_node = onnx.helper.make_node(
        "Gather",
        inputs=["input1", "input2"],
        outputs=["output1"],
        name="/Gather",
        axis=0,
    )

    graph = onnx.helper.make_graph(
        name="main_graph",
        nodes=[gather_node],
        inputs=[
            tensor_value_info("input1", onnx.TensorProto.FLOAT, [1]),
            tensor_value_info("input2", onnx.TensorProto.INT64, []),
        ],
        outputs=[
            tensor_value_info("output1", onnx.TensorProto.FLOAT, []),
        ],
    )

    return onnx.helper.make_model(
        ir_version=8,
        opset_imports=[onnx.helper.make_operatorsetid("", 16)],
        graph=graph,
    )


def main():
    """Build the Gather model, validate it, and save it as `gather_scalar_out.onnx`."""
    output_path = "gather_scalar_out.onnx"
    model = build_model()

    # Run the ONNX checker before saving so an invalid model is never written out.
    onnx.checker.check_model(model)

    onnx.save(model, output_path)


if __name__ == '__main__':
main()
14 changes: 14 additions & 0 deletions crates/burn-import/onnx-tests/tests/test_onnx.rs
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@ include_models!(
gather_1d_idx,
gather_2d_idx,
gather_scalar,
gather_scalar_out,
gather_shape,
gather_elements,
gelu,
Expand Down Expand Up @@ -523,6 +524,19 @@ mod tests {
assert_eq!(output.to_data(), expected);
}

#[test]
fn gather_scalar_out() {
    // Gathering index 1 from [1., 2., 3.] with a scalar index must yield the
    // plain scalar 2.0 rather than a tensor.
    let device = Default::default();
    let model = gather_scalar_out::Model::<Backend>::default();

    let values = Tensor::<Backend, 1>::from_floats([1., 2., 3.], &device);
    let picked = model.forward(values, 1);

    assert_eq!(picked, 2f32);
}

#[test]
fn gather_elements() {
// Initialize the model with weights (loaded from the exported file)
Expand Down
207 changes: 157 additions & 50 deletions crates/burn-import/src/burn/node/gather.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
use super::{Node, NodeCodegen};
use crate::burn::{TensorType, ToTokens, Type};
use crate::burn::{BurnImports, ScalarKind, ToTokens, Type};

use burn::record::PrecisionSettings;
use quote::quote;
Expand All @@ -8,13 +8,13 @@ use quote::quote;
pub struct GatherNode {
pub input: Type,
pub index: Type,
pub output: TensorType,
pub output: Type,
pub dim: usize,
}

impl<PS: PrecisionSettings> NodeCodegen<PS> for GatherNode {
fn output_types(&self) -> Vec<Type> {
vec![Type::Tensor(self.output.clone())]
vec![self.output.clone()]
}

fn input_types(&self) -> Vec<crate::burn::Type> {
Expand All @@ -39,63 +39,115 @@ impl<PS: PrecisionSettings> NodeCodegen<PS> for GatherNode {
_ => panic!("Gather needs Scalar or Shape input, got {:?}!", self.input),
};

let output = &self.output.name;

match &self.index {
Type::Scalar(idx_scalar) => {
// To do a scalar select (select just a single index in one dim),
// convert the 0-D index to a 1-D Tensor with len 1 to use burn's select,
// then squeeze the dimension to reduce the rank
let index = &idx_scalar.name;
let output_rank = input_rank - 1;
quote! {
let indices = Tensor::<B, 1, _>::from_data([#index], &*self.device);
let slice = Tensor::select(#input, #dim, indices);
let #output = slice.squeeze::<#output_rank>(#dim);
}
}
Type::Tensor(idx_tensor) => {
let index = scope.tensor_use_owned(idx_tensor, node_position);
let index_rank = idx_tensor.dim;
let output_rank = index_rank + input_rank - 1;
match index_rank {
1 => quote! {
let indices = #index;
let #output = Tensor::select(#input, #dim, indices);
let output = &self.output.name();

match &self.output {
Type::Scalar(sc) => {
assert_eq!(input_rank, 1);
let index = match &self.index {
Type::Scalar(idx) => idx.name.clone(),
_ => panic!("Gather needs Scalar index, got {:?}!", self.index),
};
let scalar_kind = &sc.kind;
match scalar_kind {
ScalarKind::Int32 => quote! {
let indices = Tensor::<B, 1, _>::from_data([#index], &*self.device);
let gathered = Tensor::select(#input, #dim, indices);
let #output = gathered.into_scalar().to_i32();
#output
},
ScalarKind::Int64 => quote! {
let indices = Tensor::<B, 1, _>::from_data([#index], &*self.device);
let gathered = Tensor::select(#input, #dim, indices);
let #output = gathered.into_scalar().to_i64();
},
ScalarKind::Float32 => quote! {
let indices = Tensor::<B, 1, _>::from_data([#index], &*self.device);
let gathered = Tensor::select(#input, #dim, indices);
let #output = gathered.into_scalar().to_f32();
},
_ => quote! {
let indices = #index;

let n_dims = indices.dims().len();
let index_flat = match n_dims {
1 => indices.reshape([1, -1]),
n if n >= 2 => indices.flatten::<2>(0, n - 2),
_ => panic!("Number of dimensions must be greater than 0"),
};

let out = index_flat
.iter_dim(0)
.map(|idxs| {
let idxs = idxs.squeeze::<1>(0);
Tensor::select(#input.clone(), #dim, idxs)
})
.collect();
let #output = Tensor::stack::<#output_rank>(out, #dim);
ScalarKind::Float64 => quote! {
let indices = Tensor::<B, 1, _>::from_data([#index], &*self.device);
let gathered = Tensor::select(#input, #dim, indices);
let #output = gathered.into_scalar().to_f64();
},
ScalarKind::Bool => quote! {
let indices = Tensor::<B, 1, _>::from_data([#index], &*self.device);
let gathered = Tensor::select(#input, #dim, indices);
let #output = gathered.into_scalar().to_bool();
},
}
}
_ => panic!("Gather needs Scalar or Tensor index, got {:?}!", self.index),
Type::Tensor(_) => {
match &self.index {
Type::Scalar(idx_scalar) => {
// To do a scalar select (select just a single index in one dim),
// convert the 0-D index to a 1-D Tensor with len 1 to use burn's select,
// then squeeze the dimension to reduce the rank
let index = &idx_scalar.name;
let output_rank = input_rank - 1;
quote! {
let indices = Tensor::<B, 1, _>::from_data([#index], &*self.device);
let slice = Tensor::select(#input, #dim, indices);
let #output = slice.squeeze::<#output_rank>(#dim);
}
}
Type::Tensor(idx_tensor) => {
let index = scope.tensor_use_owned(idx_tensor, node_position);
let index_rank = idx_tensor.dim;
let output_rank = index_rank + input_rank - 1;
match index_rank {
1 => quote! {
let indices = #index;
let #output = Tensor::select(#input, #dim, indices);
},
_ => quote! {
let indices = #index;

let n_dims = indices.dims().len();
let index_flat = match n_dims {
1 => indices.reshape([1, -1]),
n if n >= 2 => indices.flatten::<2>(0, n - 2),
_ => panic!("Number of dimensions must be greater than 0"),
};

let out = index_flat
.iter_dim(0)
.map(|idxs| {
let idxs = idxs.squeeze::<1>(0);
Tensor::select(#input.clone(), #dim, idxs)
})
.collect();
let #output = Tensor::stack::<#output_rank>(out, #dim);
},
}
}
_ => panic!("Gather needs Scalar or Tensor index, got {:?}!", self.index),
}
}
_ => panic!(
"Gather needs Scalar or Tensor output, got {:?}!",
self.output
),
}
}

/// Consumes this node and wraps it in the `Node::Gather` variant.
fn into_node(self) -> super::Node<PS> {
Node::Gather(self)
}

/// Registers the imports the generated code for this node depends on.
///
/// When the output is a `Type::Scalar`, `burn::tensor::cast::ToElement` is
/// registered; for every other output type nothing is registered.
fn register_imports(&self, imports: &mut BurnImports) {
    if matches!(&self.output, Type::Scalar(_)) {
        imports.register("burn::tensor::cast::ToElement");
    }
}
}

#[cfg(test)]
mod tests {

use burn::record::FullPrecisionSettings;

use super::*;
Expand All @@ -112,7 +164,7 @@ mod tests {
graph.register(GatherNode::new(
Type::Tensor(TensorType::new_float("tensor1", 2)),
Type::Tensor(TensorType::new_int("tensor2", 1)),
TensorType::new_float("tensor3", 2),
Type::Tensor(TensorType::new_float("tensor3", 2)),
0,
));

Expand Down Expand Up @@ -166,7 +218,7 @@ mod tests {
graph.register(GatherNode::new(
Type::Tensor(TensorType::new_float("tensor1", 2)),
Type::Tensor(TensorType::new_int("tensor2", 2)),
TensorType::new_float("tensor3", 3),
Type::Tensor(TensorType::new_float("tensor3", 3)),
0,
));

Expand Down Expand Up @@ -235,7 +287,7 @@ mod tests {
graph.register(GatherNode::new(
Type::Shape(ShapeType::new("shape1", 3)),
Type::Tensor(TensorType::new_int("tensor1", 1)),
TensorType::new_int("tensor2", 1),
Type::Tensor(TensorType::new_int("tensor2", 1)),
0,
));

Expand Down Expand Up @@ -295,7 +347,7 @@ mod tests {
graph.register(GatherNode::new(
Type::Tensor(TensorType::new_float("tensor1", 2)),
Type::Scalar(ScalarType::new("scalar1", ScalarKind::Int64)),
TensorType::new_float("tensor2", 1),
Type::Tensor(TensorType::new_float("tensor2", 1)),
0,
));

Expand Down Expand Up @@ -343,4 +395,59 @@ mod tests {

assert_tokens(graph.codegen(), expected);
}

// Codegen test for a Gather node whose output is a scalar: a rank-1 float
// tensor input, an Int64 scalar index, and an Int64 scalar output.
#[test]
fn test_codegen_gather_scalar_output() {
let mut graph = BurnGraph::<FullPrecisionSettings>::default();

// Gather along dim 0 from `tensor1` using `scalar1`, producing `scalar2`.
graph.register(GatherNode::new(
Type::Tensor(TensorType::new_float("tensor1", 1)),
Type::Scalar(ScalarType::new("scalar1", ScalarKind::Int64)),
Type::Scalar(ScalarType::new("scalar2", ScalarKind::Int64)),
0,
));

graph.register_input_output(
vec!["tensor1".to_string(), "scalar1".to_string()],
vec!["scalar2".to_string()],
);

// Expected output: the `ToElement` import is present, and `forward` wraps the
// scalar index in a one-element tensor, selects with it, then converts the
// gathered element to `i64` and returns it.
let expected = quote! {
use burn::tensor::cast::ToElement;
use burn::{
module::Module,
tensor::{backend::Backend, Tensor},
};

#[derive(Module, Debug)]
pub struct Model<B: Backend> {
phantom: core::marker::PhantomData<B>,
device: burn::module::Ignored<B::Device>,
}

impl<B: Backend> Model <B> {
#[allow(unused_variables)]
pub fn new(device: &B::Device) -> Self {
Self {
phantom: core::marker::PhantomData,
device: burn::module::Ignored(device.clone()),
}
}

#[allow(clippy::let_and_return, clippy::approx_constant)]
pub fn forward(
&self,
tensor1: Tensor<B, 1>,
scalar1: i64
) -> i64 {
let indices = Tensor::<B, 1, _>::from_data([scalar1], &*self.device);
let gathered = Tensor::select(tensor1, 0, indices);
let scalar2 = gathered.into_scalar().to_i64();
scalar2
}
}
};

assert_tokens(graph.codegen(), expected);
}
}
4 changes: 2 additions & 2 deletions crates/burn-import/src/burn/ty.rs
Original file line number Diff line number Diff line change
@@ -1,10 +1,9 @@
use crate::burn::ToTokens;
use proc_macro2::Ident;
use proc_macro2::Span;
use proc_macro2::TokenStream;
use quote::quote;

use crate::burn::ToTokens;

#[derive(Debug, Clone)]
pub struct TensorType {
pub name: Ident,
Expand Down Expand Up @@ -199,6 +198,7 @@ impl TensorType {
);
}
let formatted_name = Type::format_name(name.as_ref());

assert_ne!(
dim, 0,
"Trying to create TensorType with dim = 0 - should be a Scalar instead!"
Expand Down
Loading