Skip to content

Commit ca2b5f9

Browse files
committed
fix(//core/lowering): Conv2D -> _convolution pass was triggering conv
transpose instead of conv Signed-off-by: Naren Dasan <naren@narendasan.com> Signed-off-by: Naren Dasan <narens@nvidia.com>
1 parent c83447e commit ca2b5f9

File tree

1 file changed

+2
-3
lines changed

1 file changed

+2
-3
lines changed

core/lowering/passes/conv2d_to_convolution.cpp

+2-3
Original file line numberDiff line numberDiff line change
@@ -14,10 +14,9 @@ void Conv2DToConvolution(std::shared_ptr<torch::jit::Graph>& graph) {
1414
return (%4))IR";
1515
std::string convolution_pattern = R"IR(
1616
graph(%x, %w, %b, %s, %p, %d, %g):
17-
%1 : bool = prim::Constant[value=1]()
17+
%1 : bool = prim::Constant[value=0]()
1818
%2 : int[] = prim::Constant[value=[0, 0]]()
19-
%3 : bool = prim::Constant[value=0]()
20-
%4 : Tensor = aten::_convolution(%x, %w, %b, %s, %p, %d, %1, %2, %g, %1, %1, %3)
19+
%4 : Tensor = aten::_convolution(%x, %w, %b, %s, %p, %d, %1, %2, %g, %1, %1, %1)
2120
return (%4))IR";;
2221

2322
// replace matmul + add pattern to linear

0 commit comments

Comments
 (0)