We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
1 parent 4d2cb14 commit 58e9ea0Copy full SHA for 58e9ea0
core/lowering/passes/fuse_addmm_branches.cpp
@@ -49,7 +49,7 @@ struct AddMMBranchFusion {
49
if ((*arm1_start)->kind().toQualString() == std::string("aten::addmm") &&
50
(*(++arm1_start))->kind() == prim::Return &&
51
(*arm2_start)->kind().toQualString() == std::string("aten::matmul") &&
52
- (*(++arm2_start))->kind().toQualString() != std::string("aten::add") &&
+ (*(++arm2_start))->kind().toQualString() == std::string("aten::add") &&
53
(*(++arm2_start))->kind() == prim::Return) {
54
// Make sure that block0 is solely just the aten::addmm op and block1 is matmul + add
55
return true;
0 commit comments