diff --git a/crates/burn-import/onnx-tests/tests/conv_transpose2d/conv_transpose2d.onnx b/crates/burn-import/onnx-tests/tests/conv_transpose2d/conv_transpose2d.onnx index 11b973ff4b..4012cd94a1 100644 Binary files a/crates/burn-import/onnx-tests/tests/conv_transpose2d/conv_transpose2d.onnx and b/crates/burn-import/onnx-tests/tests/conv_transpose2d/conv_transpose2d.onnx differ diff --git a/crates/burn-import/onnx-tests/tests/conv_transpose2d/conv_transpose2d.py b/crates/burn-import/onnx-tests/tests/conv_transpose2d/conv_transpose2d.py index e35301af82..fe4dbd97f5 100755 --- a/crates/burn-import/onnx-tests/tests/conv_transpose2d/conv_transpose2d.py +++ b/crates/burn-import/onnx-tests/tests/conv_transpose2d/conv_transpose2d.py @@ -10,7 +10,7 @@ class Model(nn.Module): def __init__(self): super(Model, self).__init__() self.transposed_conv = nn.ConvTranspose2d( - 4, 6, (3, 5), groups=2, stride=(2, 1), padding=(4, 2), dilation=(3, 1) + 4, 6, (3, 5), groups=2, stride=(2, 1), padding=(4, 2), dilation=(3, 1), output_padding=(1, 0), ) def forward(self, x): diff --git a/crates/burn-import/onnx-tests/tests/conv_transpose3d/conv_transpose3d.onnx b/crates/burn-import/onnx-tests/tests/conv_transpose3d/conv_transpose3d.onnx index 7415482826..67e326d91a 100644 Binary files a/crates/burn-import/onnx-tests/tests/conv_transpose3d/conv_transpose3d.onnx and b/crates/burn-import/onnx-tests/tests/conv_transpose3d/conv_transpose3d.onnx differ diff --git a/crates/burn-import/onnx-tests/tests/conv_transpose3d/conv_transpose3d.py b/crates/burn-import/onnx-tests/tests/conv_transpose3d/conv_transpose3d.py index 82d483cb25..064c74b29f 100755 --- a/crates/burn-import/onnx-tests/tests/conv_transpose3d/conv_transpose3d.py +++ b/crates/burn-import/onnx-tests/tests/conv_transpose3d/conv_transpose3d.py @@ -10,7 +10,7 @@ class Model(nn.Module): def __init__(self): super(Model, self).__init__() self.transposed_conv = nn.ConvTranspose3d( - 4, 6, (3, 5, 5), groups=2, stride=(2, 1, 1), padding=(4, 2, 1), dilation=(3, 1, 1) + 4, 6, (3, 5, 5), groups=2, stride=(2, 1, 1), padding=(4, 2, 1), dilation=(3, 1, 1), output_padding=(1, 0, 0), ) def forward(self, x): diff --git a/crates/burn-import/onnx-tests/tests/test_onnx.rs b/crates/burn-import/onnx-tests/tests/test_onnx.rs index 83b15dba84..234232a762 100644 --- a/crates/burn-import/onnx-tests/tests/test_onnx.rs +++ b/crates/burn-import/onnx-tests/tests/test_onnx.rs @@ -1512,14 +1512,14 @@ mod tests { let output = model.forward(input); - let expected_shape = Shape::from([2, 6, 17, 15]); + let expected_shape = Shape::from([2, 6, 18, 15]); assert_eq!(output.shape(), expected_shape); // We are using the sum of the output tensor to test the correctness of the conv_transpose2d node // because the output tensor is too large to compare with the expected tensor. let output_sum = output.sum().into_scalar(); - let expected_sum = -120.070_15; // result running pytorch model (conv_transpose2d.py) + let expected_sum = -134.96603; // result running pytorch model (conv_transpose2d.py) assert!(expected_sum.approx_eq(output_sum, (1.0e-4, 2))); } @@ -1534,14 +1534,14 @@ mod tests { let output = model.forward(input); - let expected_shape = Shape::from([2, 6, 5, 5, 9]); + let expected_shape = Shape::from([2, 6, 6, 5, 9]); assert_eq!(output.shape(), expected_shape); // We are using the sum of the output tensor to test the correctness of the conv_transpose3d node // because the output tensor is too large to compare with the expected tensor. let output_sum = output.sum().into_scalar(); - let expected_sum = -67.267_15; // result running pytorch model (conv_transpose3d.py) + let expected_sum = -105.69771; // result running pytorch model (conv_transpose3d.py) assert!(expected_sum.approx_eq(output_sum, (1.0e-4, 2))); } diff --git a/crates/burn-import/src/burn/node/conv_transpose_2d.rs b/crates/burn-import/src/burn/node/conv_transpose_2d.rs index 30f26ef567..0cce170125 100644 --- a/crates/burn-import/src/burn/node/conv_transpose_2d.rs +++ b/crates/burn-import/src/burn/node/conv_transpose_2d.rs @@ -64,12 +64,14 @@ impl NodeCodegen for ConvTranspose2dNode { let dilation = self.config.dilation.to_tokens(); let groups = self.config.groups.to_tokens(); let padding = self.config.padding.to_tokens(); + let padding_out = self.config.padding_out.to_tokens(); let bias = self.config.bias; let tokens = quote! { let #name = ConvTranspose2dConfig::new(#channels, #kernel_size) .with_stride(#stride) .with_padding(#padding) + .with_padding_out(#padding_out) .with_dilation(#dilation) .with_groups(#groups) .with_bias(#bias) @@ -173,6 +175,7 @@ mod tests { let conv_transpose_2d = ConvTranspose2dConfig::new([3, 3], [3, 3]) .with_stride([1, 1]) .with_padding([0, 0]) + .with_padding_out([0, 0]) .with_dilation([1, 1]) .with_groups(1) .with_bias(true) diff --git a/crates/burn-import/src/burn/node/conv_transpose_3d.rs b/crates/burn-import/src/burn/node/conv_transpose_3d.rs index 5ce39ac3e1..a3ecd1fccc 100644 --- a/crates/burn-import/src/burn/node/conv_transpose_3d.rs +++ b/crates/burn-import/src/burn/node/conv_transpose_3d.rs @@ -64,12 +64,14 @@ impl NodeCodegen for ConvTranspose3dNode { let dilation = self.config.dilation.to_tokens(); let groups = self.config.groups.to_tokens(); let padding = self.config.padding.to_tokens(); + let padding_out = self.config.padding_out.to_tokens(); let bias = self.config.bias; let tokens = quote! { let #name = ConvTranspose3dConfig::new(#channels, #kernel_size) .with_stride(#stride) .with_padding(#padding) + .with_padding_out(#padding_out) .with_dilation(#dilation) .with_groups(#groups) .with_bias(#bias) @@ -173,6 +175,7 @@ mod tests { let conv_transpose_3d = ConvTranspose3dConfig::new([3, 3], [3, 3, 3]) .with_stride([1, 1, 1]) .with_padding([0, 0, 0]) + .with_padding_out([0, 0, 0]) .with_dilation([1, 1, 1]) .with_groups(1) .with_bias(true) diff --git a/crates/burn-import/src/onnx/op_configuration.rs b/crates/burn-import/src/onnx/op_configuration.rs index a2f43b4e94..4621b129d6 100644 --- a/crates/burn-import/src/onnx/op_configuration.rs +++ b/crates/burn-import/src/onnx/op_configuration.rs @@ -229,6 +229,10 @@ pub fn conv_transpose2d_config(curr: &Node) -> ConvTranspose2dConfig { .remove("group") .map(AttributeValue::into_i64) .unwrap_or(1) as usize; + let output_padding = attrs + .remove("output_padding") + .map(AttributeValue::into_i64s) + .unwrap_or_else(|| vec![0, 0]); // Trick with remove + empty check is simplest way to not forget some attribute for runtime: if !attrs.is_empty() { @@ -256,6 +260,7 @@ pub fn conv_transpose2d_config(curr: &Node) -> ConvTranspose2dConfig { .with_stride([stride[0] as usize, stride[1] as usize]) .with_padding([pads[0] as usize, pads[1] as usize]) .with_dilation([dilations[0] as usize, dilations[1] as usize]) + .with_padding_out([output_padding[0] as usize, output_padding[1] as usize]) .with_groups(group) .with_bias(bias) } @@ -281,6 +286,10 @@ pub fn conv_transpose3d_config(curr: &Node) -> ConvTranspose3dConfig { .remove("group") .map(AttributeValue::into_i64) .unwrap_or(1) as usize; + let output_padding = attrs + .remove("output_padding") + .map(AttributeValue::into_i64s) + .unwrap_or_else(|| vec![0, 0, 0]); // Trick with remove + empty check is simplest way to not forget some attribute for runtime: if !attrs.is_empty() { @@ -316,6 +325,11 @@ pub fn conv_transpose3d_config(curr: &Node) -> ConvTranspose3dConfig { dilations[1] as usize, dilations[2] as usize, ]) + .with_padding_out([ + output_padding[0] as usize, + output_padding[1] as usize, + output_padding[2] as usize, + ]) .with_groups(group) .with_bias(bias) }