Skip to content

Commit

Permalink
Revert "Consolidate tests"
Browse files Browse the repository at this point in the history
This reverts commit b6631ba.
  • Loading branch information
jameshiew committed Feb 9, 2025
1 parent b6631ba commit a1f87b8
Showing 1 changed file with 40 additions and 15 deletions.
55 changes: 40 additions & 15 deletions crates/burn-import/onnx-tests/tests/test_onnx.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2188,26 +2188,51 @@ mod tests {
}

#[test]
fn add_constant() {
fn add_constant_f32() {
let device = Default::default();
const CONST_VALUE: i32 = 2;
let model = constant_f32::Model::<Backend>::new(&device);
let input = Tensor::<Backend, 3>::zeros(Shape::from([2, 3, 4]), &device);
let expected = Tensor::<Backend, 3>::full([2, 3, 4], 2, &device).to_data();

let float_input = Tensor::<Backend, 3>::zeros(Shape::from([2, 3, 4]), &device);
let float_expected = Tensor::<Backend, 3>::full([2, 3, 4], CONST_VALUE, &device).to_data();
let output = model.forward(input);

output.to_data().assert_eq(&expected, true);
}

#[test]
fn add_constant_f64() {
let device = Default::default();
let model = constant_f64::Model::<Backend>::new(&device);
let input = Tensor::<Backend, 3>::zeros(Shape::from([2, 3, 4]), &device);
let expected = Tensor::<Backend, 3>::full([2, 3, 4], 2, &device).to_data();

let output = model.forward(input);

output.to_data().assert_eq(&expected, true);
}

let f32_output = constant_f32::Model::<Backend>::new(&device).forward(float_input.clone());
f32_output.to_data().assert_eq(&float_expected, true);
let f64_output = constant_f64::Model::<Backend>::new(&device).forward(float_input.clone());
f64_output.to_data().assert_eq(&float_expected, true);
#[test]
fn add_constant_i32() {
let device = Default::default();
let model = constant_i32::Model::<Backend>::new(&device);
let input = Tensor::<Backend, 3, Int>::zeros(Shape::from([2, 3, 4]), &device);
let expected = Tensor::<Backend, 3, Int>::full([2, 3, 4], 2, &device).to_data();

let output = model.forward(input);

output.to_data().assert_eq(&expected, true);
}

#[test]
fn add_constant_i64() {
let device = Default::default();
let model = constant_i64::Model::<Backend>::new(&device);
let input = Tensor::<Backend, 3, Int>::zeros(Shape::from([2, 3, 4]), &device);
let expected = Tensor::<Backend, 3, Int>::full([2, 3, 4], 2, &device).to_data();

let int_input = Tensor::<Backend, 3, Int>::zeros(Shape::from([2, 3, 4]), &device);
let int_expected =
Tensor::<Backend, 3, Int>::full([2, 3, 4], CONST_VALUE, &device).to_data();
let output = model.forward(input);

let i32_output = constant_i32::Model::<Backend>::new(&device).forward(int_input.clone());
i32_output.to_data().assert_eq(&int_expected, true);
let i64_output = constant_i64::Model::<Backend>::new(&device).forward(int_input.clone());
i64_output.to_data().assert_eq(&int_expected, true);
output.to_data().assert_eq(&expected, true);
}

#[test]
Expand Down

0 comments on commit a1f87b8

Please # to comment.