Add missing output padding to conv transpose ONNX (#2216)
* Add output_padding support for ONNX ConvTranspose

* Add missing codegen

* Fix output padding codegen test
laggui authored Aug 29, 2024
1 parent 28c2d4e commit a9abd8f
Showing 8 changed files with 26 additions and 6 deletions.
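Note: PyTorch's output_padding maps to the optional output_padding attribute on the ONNX ConvTranspose operator, which the importer now reads. Below is a minimal sketch of how the attribute appears in an exported graph (assuming torch and onnx are installed; the input shape and file name are illustrative):

# Sketch: export a ConvTranspose2d with output_padding and inspect the
# resulting ONNX attribute. Input shape and file name are illustrative.
import torch
import torch.nn as nn
import onnx

model = nn.ConvTranspose2d(
    4, 6, (3, 5), groups=2, stride=(2, 1), padding=(4, 2),
    dilation=(3, 1), output_padding=(1, 0),
)
torch.onnx.export(model, torch.randn(2, 4, 10, 15), "conv_transpose2d.onnx")

node = onnx.load("conv_transpose2d.onnx").graph.node[0]
print(node.op_type)  # expected: ConvTranspose (assuming it is the first graph node)
print({a.name: list(a.ints) for a in node.attribute if a.name == "output_padding"})
# {'output_padding': [1, 0]}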
Binary file not shown (regenerated ONNX test model).
@@ -10,7 +10,7 @@ class Model(nn.Module):
     def __init__(self):
         super(Model, self).__init__()
         self.transposed_conv = nn.ConvTranspose2d(
-            4, 6, (3, 5), groups=2, stride=(2, 1), padding=(4, 2), dilation=(3, 1)
+            4, 6, (3, 5), groups=2, stride=(2, 1), padding=(4, 2), dilation=(3, 1), output_padding=(1, 0),
         )

     def forward(self, x):
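With output_padding=(1, 0), only the output height changes. For a transposed convolution, each spatial dimension follows out = (in - 1)*stride - 2*padding + dilation*(kernel - 1) + output_padding + 1. A quick sanity check (the 10x15 input size is inferred from the expected shapes in test_onnx.rs below, not stated in this diff):

# Sketch: verify the new expected shape [2, 6, 18, 15] using the standard
# transposed-convolution output-size formula. The 10x15 input size is
# inferred from the test's expected shapes, not stated in this diff.
def conv_transpose_out(size, kernel, stride, padding, dilation, output_padding):
    return (size - 1) * stride - 2 * padding + dilation * (kernel - 1) + output_padding + 1

# Height: stride=2, padding=4, dilation=3, kernel=3, output_padding=1 -> 17 becomes 18
print(conv_transpose_out(10, 3, 2, 4, 3, 1))  # 18
# Width: stride=1, padding=2, dilation=1, kernel=5, output_padding=0 -> unchanged
print(conv_transpose_out(15, 5, 1, 2, 1, 0))  # 15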
Binary file not shown (regenerated ONNX test model).
@@ -10,7 +10,7 @@ class Model(nn.Module):
     def __init__(self):
         super(Model, self).__init__()
         self.transposed_conv = nn.ConvTranspose3d(
-            4, 6, (3, 5, 5), groups=2, stride=(2, 1, 1), padding=(4, 2, 1), dilation=(3, 1, 1)
+            4, 6, (3, 5, 5), groups=2, stride=(2, 1, 1), padding=(4, 2, 1), dilation=(3, 1, 1), output_padding=(1, 0, 0),
         )

     def forward(self, x):
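The same formula covers the 3d case: output_padding=(1, 0, 0) grows only the depth dimension. Assuming a depth-4 input, inferred from the old expected shape [2, 6, 5, 5, 9]:

# Depth: stride=2, padding=4, dilation=3, kernel=3, output_padding=1 -> 5 becomes 6
print((4 - 1) * 2 - 2 * 4 + 3 * (3 - 1) + 1 + 1)  # 6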
8 changes: 4 additions & 4 deletions crates/burn-import/onnx-tests/tests/test_onnx.rs
@@ -1512,14 +1512,14 @@ mod tests {

         let output = model.forward(input);

-        let expected_shape = Shape::from([2, 6, 17, 15]);
+        let expected_shape = Shape::from([2, 6, 18, 15]);
         assert_eq!(output.shape(), expected_shape);

         // We are using the sum of the output tensor to test the correctness of the conv_transpose2d node
         // because the output tensor is too large to compare with the expected tensor.
         let output_sum = output.sum().into_scalar();

-        let expected_sum = -120.070_15; // result running pytorch model (conv_transpose2d.py)
+        let expected_sum = -134.96603; // result running pytorch model (conv_transpose2d.py)

         assert!(expected_sum.approx_eq(output_sum, (1.0e-4, 2)));
     }
@@ -1534,14 +1534,14 @@ mod tests {

         let output = model.forward(input);

-        let expected_shape = Shape::from([2, 6, 5, 5, 9]);
+        let expected_shape = Shape::from([2, 6, 6, 5, 9]);
         assert_eq!(output.shape(), expected_shape);

         // We are using the sum of the output tensor to test the correctness of the conv_transpose3d node
         // because the output tensor is too large to compare with the expected tensor.
         let output_sum = output.sum().into_scalar();

-        let expected_sum = -67.267_15; // result running pytorch model (conv_transpose3d.py)
+        let expected_sum = -105.69771; // result running pytorch model (conv_transpose3d.py)

         assert!(expected_sum.approx_eq(output_sum, (1.0e-4, 2)));
     }
3 changes: 3 additions & 0 deletions crates/burn-import/src/burn/node/conv_transpose_2d.rs
@@ -64,12 +64,14 @@ impl<PS: PrecisionSettings> NodeCodegen<PS> for ConvTranspose2dNode {
         let dilation = self.config.dilation.to_tokens();
         let groups = self.config.groups.to_tokens();
         let padding = self.config.padding.to_tokens();
+        let padding_out = self.config.padding_out.to_tokens();
         let bias = self.config.bias;

         let tokens = quote! {
             let #name = ConvTranspose2dConfig::new(#channels, #kernel_size)
                 .with_stride(#stride)
                 .with_padding(#padding)
+                .with_padding_out(#padding_out)
                 .with_dilation(#dilation)
                 .with_groups(#groups)
                 .with_bias(#bias)
@@ -173,6 +175,7 @@ mod tests {
         let conv_transpose_2d = ConvTranspose2dConfig::new([3, 3], [3, 3])
             .with_stride([1, 1])
             .with_padding([0, 0])
+            .with_padding_out([0, 0])
             .with_dilation([1, 1])
             .with_groups(1)
             .with_bias(true)
3 changes: 3 additions & 0 deletions crates/burn-import/src/burn/node/conv_transpose_3d.rs
@@ -64,12 +64,14 @@ impl<PS: PrecisionSettings> NodeCodegen<PS> for ConvTranspose3dNode {
         let dilation = self.config.dilation.to_tokens();
         let groups = self.config.groups.to_tokens();
         let padding = self.config.padding.to_tokens();
+        let padding_out = self.config.padding_out.to_tokens();
         let bias = self.config.bias;

         let tokens = quote! {
             let #name = ConvTranspose3dConfig::new(#channels, #kernel_size)
                 .with_stride(#stride)
                 .with_padding(#padding)
+                .with_padding_out(#padding_out)
                 .with_dilation(#dilation)
                 .with_groups(#groups)
                 .with_bias(#bias)
@@ -173,6 +175,7 @@ mod tests {
         let conv_transpose_3d = ConvTranspose3dConfig::new([3, 3], [3, 3, 3])
             .with_stride([1, 1, 1])
             .with_padding([0, 0, 0])
+            .with_padding_out([0, 0, 0])
             .with_dilation([1, 1, 1])
             .with_groups(1)
             .with_bias(true)
14 changes: 14 additions & 0 deletions crates/burn-import/src/onnx/op_configuration.rs
@@ -229,6 +229,10 @@ pub fn conv_transpose2d_config(curr: &Node) -> ConvTranspose2dConfig {
         .remove("group")
         .map(AttributeValue::into_i64)
         .unwrap_or(1) as usize;
+    let output_padding = attrs
+        .remove("output_padding")
+        .map(AttributeValue::into_i64s)
+        .unwrap_or_else(|| vec![0, 0]);

     // Trick with remove + empty check is simplest way to not forget some attribute for runtime:
     if !attrs.is_empty() {
@@ -256,6 +260,7 @@ pub fn conv_transpose2d_config(curr: &Node) -> ConvTranspose2dConfig {
         .with_stride([stride[0] as usize, stride[1] as usize])
         .with_padding([pads[0] as usize, pads[1] as usize])
         .with_dilation([dilations[0] as usize, dilations[1] as usize])
+        .with_padding_out([output_padding[0] as usize, output_padding[1] as usize])
         .with_groups(group)
         .with_bias(bias)
 }
@@ -281,6 +286,10 @@ pub fn conv_transpose3d_config(curr: &Node) -> ConvTranspose3dConfig {
         .remove("group")
         .map(AttributeValue::into_i64)
         .unwrap_or(1) as usize;
+    let output_padding = attrs
+        .remove("output_padding")
+        .map(AttributeValue::into_i64s)
+        .unwrap_or_else(|| vec![0, 0, 0]);

     // Trick with remove + empty check is simplest way to not forget some attribute for runtime:
     if !attrs.is_empty() {
@@ -316,6 +325,11 @@ pub fn conv_transpose3d_config(curr: &Node) -> ConvTranspose3dConfig {
             dilations[1] as usize,
             dilations[2] as usize,
         ])
+        .with_padding_out([
+            output_padding[0] as usize,
+            output_padding[1] as usize,
+            output_padding[2] as usize,
+        ])
         .with_groups(group)
         .with_bias(bias)
 }
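In the ONNX spec, output_padding on ConvTranspose is optional and defaults to zeros, which the unwrap_or_else fallbacks above mirror. A sketch of why the default matters (assuming torch and onnx are installed; whether the exporter emits an all-zero output_padding attribute can vary by version):

# Sketch: with the default output_padding of 0, the exporter may omit the
# attribute entirely, so the importer must fall back to zeros (vec![0, 0]).
import torch
import onnx

model = torch.nn.ConvTranspose2d(4, 6, 3)  # output_padding defaults to 0
torch.onnx.export(model, torch.randn(1, 4, 8, 8), "ct_default.onnx")

node = onnx.load("ct_default.onnx").graph.node[0]
print([a.name for a in node.attribute])  # "output_padding" may be absent here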
