Skip to content

Commit 58e9ea0

Browse files
committed
fix: Fix fuse addmm pass
Signed-off-by: Dheeraj Peri <peri.dheeraj@gmail.com>
1 parent 4d2cb14 commit 58e9ea0

File tree

1 file changed

+1
-1
lines changed

1 file changed

+1
-1
lines changed

core/lowering/passes/fuse_addmm_branches.cpp

+1-1
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,7 @@ struct AddMMBranchFusion {
4949
if ((*arm1_start)->kind().toQualString() == std::string("aten::addmm") &&
5050
(*(++arm1_start))->kind() == prim::Return &&
5151
(*arm2_start)->kind().toQualString() == std::string("aten::matmul") &&
52-
(*(++arm2_start))->kind().toQualString() != std::string("aten::add") &&
52+
(*(++arm2_start))->kind().toQualString() == std::string("aten::add") &&
5353
(*(++arm2_start))->kind() == prim::Return) {
5454
// Make sure that block0 is solely just the aten::addmm op and block1 is matmul + add
5555
return true;

0 commit comments

Comments
 (0)