From 29e57f2b3e3a2bca9765a13a69127e78945fa363 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=B9=B3=E5=B1=B1=E4=BE=91=E6=A8=B9?= Date: Fri, 12 May 2023 15:24:57 +0900 Subject: [PATCH] Support onnx gemm when the src_node is none. --- nngen/onnx/gemm.py | 27 +++++++++++++++------------ 1 file changed, 15 insertions(+), 12 deletions(-) diff --git a/nngen/onnx/gemm.py b/nngen/onnx/gemm.py index 27ea05b8..f9d726aa 100644 --- a/nngen/onnx/gemm.py +++ b/nngen/onnx/gemm.py @@ -22,22 +22,25 @@ def Gemm(visitor, node, for i, src in enumerate(node.input): src_node = util.search_node_from_model(visitor.model, src) - if (i == 0 and src_node.op_type == 'Flatten' and - len(visitor.consumers[src]) == 1): + if src_node is None: + pass + else: + if (i == 0 and src_node.op_type == 'Flatten' and + len(visitor.consumers[src]) == 1): - src_obj = flatten.Flatten(visitor, src_node, no_transpose=True) - srcs.append(src_obj) - continue - - if (i == 0 and src_node.op_type == 'Reshape' and - len(visitor.consumers[src]) == 1): - - shape = visitor.visit(src_node.input[1]) - if len(shape) == 2: - src_obj = reshape.Reshape(visitor, src_node, no_transpose=True) + src_obj = flatten.Flatten(visitor, src_node, no_transpose=True) srcs.append(src_obj) continue + if (i == 0 and src_node.op_type == 'Reshape' and + len(visitor.consumers[src]) == 1): + + shape = visitor.visit(src_node.input[1]) + if len(shape) == 2: + src_obj = reshape.Reshape(visitor, src_node, no_transpose=True) + srcs.append(src_obj) + continue + src_obj = visitor.visit(src) srcs.append(src_obj)