diff --git a/docker/requirements.txt b/docker/requirements.txt
index a5c1f93..214809b 100644
--- a/docker/requirements.txt
+++ b/docker/requirements.txt
@@ -1,4 +1,5 @@
bitstring>=3.1.7
+clize==4.1.1
numpy
onnx==1.7.0
onnxruntime==1.4.0
diff --git a/docs/QONNX/bipolar_quant_op.md b/docs/QONNX/bipolar_quant_op.md
new file mode 100644
index 0000000..3a70458
--- /dev/null
+++ b/docs/QONNX/bipolar_quant_op.md
@@ -0,0 +1,92 @@
+### **BipolarQuant**
+
+Calculates the binary quantized values of one input data (Tensor) and produces one output data (Tensor).
+Additionally, takes one float as input, which defines the scale factor.
+
+#### Version
+
+This operator is not part of the ONNX standard and is not currently versioned.
+
+#### Attributes
+
+This operator has no attributes.
+
+#### Inputs
+
+
+- X (differentiable) : tensor(float32)
+  - Input tensor to quantize
+- scale : float32
+  - The scale factor
+
+
+
+#### Outputs
+
+
+- Y (differentiable) : tensor(float32)
+  - Output tensor
+
+
+
+#### Examples
+
+BipolarQuant
+
+```python
+from onnx import helper
+import numpy as np
+
+# Define node settings and input
+x = np.random.randn(100).astype(np.float32)*10.
+scale = np.array(1.)
+
+# Create node
+node = helper.make_node(
+ 'BipolarQuant',
+ domain='finn.custom_op.general',
+ inputs=['x', 'scale'],
+ outputs=['y'],
+)
+
+# Execute the same settings with the reference implementation (binary_quant)
+# See the sample implementation for more details on binary_quant.
+output_ref = binary_quant(x, scale)
+
+# Execute node and compare
+expect(node, inputs=[x, scale], outputs=[output_ref], name='test_binary_quant')
+
+```
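+
+A small worked example, assuming the `binary_quant` reference function from the
+sample implementation below is in scope:
+
+```python
+import numpy as np
+
+x = np.array([-3.2, 0.0, 5.1], dtype=np.float32)
+# values >= 0 map to +1, values < 0 map to -1, then the result is scaled
+y = binary_quant(x, np.array(1.0))
+# y is now [-1., 1., 1.]
+```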
+
+
+
+
+#### Sample Implementation
+
+
+BipolarQuant
+
+```python
+# SPDX-License-Identifier: Apache-2.0
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+from __future__ import unicode_literals
+
+import numpy as np
+
+def binary_quant(inp_tensor, scale):
+ # Quantizing
+ y_int = inp_tensor
+ y_ones = np.ones(y_int.shape, dtype=y_int.dtype)
+ y_int = np.where(y_int >= 0.0, y_ones, -y_ones)
+ # Scaling
+ out_tensor = y_int * scale
+
+ return out_tensor
+
+```
+
+
diff --git a/docs/QONNX/quant_op.md b/docs/QONNX/quant_op.md
new file mode 100644
index 0000000..aa7e94a
--- /dev/null
+++ b/docs/QONNX/quant_op.md
@@ -0,0 +1,193 @@
+### **Quant**
+
+Calculates the quantized values of one input data (Tensor) and produces one output data (Tensor).
+Additionally, takes three scalar inputs, which define the scale, zero-point and bit-width of the quantization.
+The attributes narrow and signed define how the bits of the quantization are interpreted, while the attribute
+rounding_mode defines how quantized values are rounded.
+
+Note: This operator does not work for binary or bipolar quantization; for that purpose, the simpler BipolarQuant node exists.
+
+#### Version
+
+This operator is not part of the ONNX standard and is not currently versioned.
+
+#### Attributes
+
+
+- signed : int (default is 1)
+  - Defines if the quantization includes a signed bit. E.g. at 8b unsigned=[0, 255] vs signed=[-128, 127].
+- narrow : int (default is 0)
+  - Defines if the value range should be interpreted as narrow, when signed=1. E.g. at 8b regular=[-128, 127] vs narrow=[-127, 127].
+- rounding_mode : string (default is "ROUND")
+  - Defines how rounding should be applied during quantization. Currently available modes are: "ROUND", "CEIL" and "FLOOR". Here "ROUND" implies a round-to-even operation.
+
+
+#### Inputs
+
+
+- X (differentiable) : tensor(float32)
+  - Input tensor to quantize
+- scale : float32
+  - The scale factor
+- zeropt : float32
+  - The zero-point
+- bitwidth : int32
+  - The number of bits used by the quantization
+
+
+
+#### Outputs
+
+
+- Y (differentiable) : tensor(float32)
+  - Output tensor
+
+
+
+#### Examples
+
+Quant
+
+```python
+from onnx import helper
+import numpy as np
+
+# Define node settings and input
+x = np.random.randn(100).astype(np.float32)*10.
+scale = np.array(1.)
+zeropt = np.array(0.)
+bitwidth = np.array(4)
+signed = 1
+narrow = 0
+rounding_mode = "ROUND"
+
+# Create node
+node = helper.make_node(
+ 'Quant',
+ domain='finn.custom_op.general',
+ inputs=['x', 'scale', 'zeropt', 'bitwidth'],
+ outputs=['y'],
+ narrow=narrow,
+ signed=signed,
+ rounding_mode=rounding_mode,
+)
+
+# Execute the same settings with the reference implementation (quant)
+# See the sample implementation for more details on quant.
+output_ref = quant(x, scale, zeropt, bitwidth, signed, narrow, rounding_mode)
+
+# Execute node and compare
+expect(node, inputs=[x, scale, zeropt, bitwidth], outputs=[output_ref], name='test_quant')
+
+```
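+
+A small worked example, assuming the `quant` reference function from the
+sample implementation below is in scope:
+
+```python
+import numpy as np
+
+x = np.array([-1.6, 0.3, 7.9], dtype=np.float32)
+# signed 4-bit quantization covers the integer range [-8, 7]:
+# 7.9 is clamped to 7, round-to-even then maps -1.6 to -2 and 0.3 to 0
+y = quant(x, np.array(1.0), np.array(0.0), np.array(4), signed=1, narrow=0, rounding_mode="ROUND")
+# y is now [-2., 0., 7.]
+```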
+
+
+
+
+#### Sample Implementation
+
+
+Quant
+
+```python
+# SPDX-License-Identifier: Apache-2.0
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+from __future__ import unicode_literals
+
+import numpy as np
+
+def quant(inp_tensor, scale, zeropt, bitwidth, signed, narrow, rounding_mode):
+ # Port of IntQuant class from Brevitas: https://bit.ly/2S6qvZJ
+ # Scaling
+ y_int = inp_tensor / scale
+ y_int = y_int + zeropt
+ # Clamping
+ min_int_val = min_int(signed, narrow, bitwidth)
+ max_int_val = max_int(signed, narrow, bitwidth)
+ y_int = np.where(y_int > max_int_val, max_int_val.astype(y_int.dtype), y_int)
+ y_int = np.where(y_int < min_int_val, min_int_val.astype(y_int.dtype), y_int)
+ # Rounding
+ rounding_fx = resolve_rounding_mode(rounding_mode)
+ y_int = rounding_fx(y_int)
+
+ # Re-scaling
+ out_tensor = y_int - zeropt
+ out_tensor = out_tensor * scale
+
+ return out_tensor
+
+def min_int(signed: bool, narrow_range: bool, bit_width: int) -> int:
+ """Compute the minimum integer representable by a given number of bits.
+ Args:
+ signed (bool): Indicates whether the represented integer is signed or not.
+ narrow_range (bool): Indicates whether to narrow the minimum
+ representable value by 1.
+ bit_width (int): Number of bits available for the representation.
+ Returns:
+ int: Minimum integer that can be represented according to
+ the input arguments.
+ Examples:
+ >>> min_int(signed=True, narrow_range=True, bit_width=8)
+ int(-127)
+ >>> min_int(signed=False, narrow_range=True, bit_width=8)
+ int(0)
+ >>> min_int(signed=True, narrow_range=False, bit_width=8)
+ int(-128)
+ >>> min_int(signed=False, narrow_range=False, bit_width=8)
+ int(0)
+ """
+ if signed and narrow_range:
+ value = -(2 ** (bit_width - 1)) + 1
+ elif signed and not narrow_range:
+ value = -(2 ** (bit_width - 1))
+ else:
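+ # unsigned: the minimum representable value is always 0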
+ value = 0 * bit_width
+ return value
+
+
+def max_int(signed: bool, narrow_range: bool, bit_width: int) -> int:
+ """Compute the maximum integer representable by a given number of bits.
+ Args:
+ signed (bool): Indicates whether the represented integer is signed or not.
+ narrow_range (bool): Indicates whether to narrow the maximum
+ unsigned representable value by 1.
+ bit_width (int): Number of bits available for the representation.
+ Returns:
+ int: Maximum integer that can be represented according to
+ the input arguments.
+ Examples:
+ >>> max_int(signed=True, narrow_range=True, bit_width=8)
+ int(127)
+ >>> max_int(signed=False, narrow_range=True, bit_width=8)
+ int(254)
+ >>> max_int(signed=True, narrow_range=False, bit_width=8)
+ int(127)
+ >>> max_int(signed=False, narrow_range=False, bit_width=8)
+ int(255)
+ """
+ if not signed and not narrow_range:
+ value = (2 ** bit_width) - 1
+ elif not signed and narrow_range:
+ value = (2 ** bit_width) - 2
+ else:
+ value = (2 ** (bit_width - 1)) - 1
+ return value
+
+def resolve_rounding_mode(mode_string):
+ """Resolve the rounding mode string of Quant and Trunc ops
+ to the corresponding numpy functions."""
+ if mode_string == "ROUND":
+ return np.round
+ elif mode_string == "CEIL":
+ return np.ceil
+ elif mode_string == "FLOOR":
+ return np.floor
+ else:
+ raise ValueError(f"Could not resolve rounding mode called: {mode_string}")
+
+```
+
+
diff --git a/docs/QONNX/trunc_op.md b/docs/QONNX/trunc_op.md
new file mode 100644
index 0000000..80bfc4c
--- /dev/null
+++ b/docs/QONNX/trunc_op.md
@@ -0,0 +1,131 @@
+### **Trunc**
+
+Truncates the values of one input data (Tensor) to a specified bit-width and produces one output data (Tensor).
+Additionally, takes four scalar inputs, which define the scale, zero-point, input bit-width and output bit-width of the truncation.
+The attribute rounding_mode defines how truncated values are rounded.
+
+#### Version
+
+This operator is not part of the ONNX standard and is not currently versioned.
+
+#### Attributes
+
+
+- rounding_mode : string (default is "FLOOR")
+  - Defines how rounding should be applied during truncation. Currently available modes are: "ROUND", "CEIL" and "FLOOR". Here "ROUND" implies a round-to-even operation.
+
+
+#### Inputs
+
+
+- X (differentiable) : tensor(float32)
+  - Input tensor to truncate
+- scale : float32
+  - The scale factor
+- zeropt : float32
+  - The zero-point
+- in_bitwidth : int32
+  - The number of bits used at the input of the truncation
+- out_bitwidth : int32
+  - The number of bits used at the output of the truncation
+
+
+
+#### Outputs
+
+
+- Y (differentiable) : tensor(float32)
+  - Output tensor
+
+
+
+#### Examples
+
+Trunc
+
+```python
+from onnx import helper
+import numpy as np
+
+# Define node settings and input
+x = np.random.randn(100).astype(np.float32)*10.
+scale = np.array(1.)
+zeropt = np.array(0.)
+in_bitwidth = np.array(10)
+out_bitwidth = np.array(4)
+rounding_mode = "ROUND"
+
+# Create node
+node = helper.make_node(
+ 'Trunc',
+ domain='finn.custom_op.general',
+ inputs=['x', 'scale', 'zeropt', 'in_bitwidth', 'out_bitwidth'],
+ outputs=['y'],
+ rounding_mode=rounding_mode,
+)
+
+# Execute the same settings with the reference implementation (trunc)
+# See the sample implementation for more details on trunc.
+output_ref = trunc(x, scale, zeropt, in_bitwidth, out_bitwidth, rounding_mode)
+
+# Execute node and compare
+expect(node, inputs=[x, scale, zeropt, in_bitwidth, out_bitwidth], outputs=[output_ref], name='test_trunc')
+
+```
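+
+A small worked example, assuming the `trunc` reference function from the
+sample implementation below is in scope:
+
+```python
+import numpy as np
+
+x = np.array([300.7], dtype=np.float32)
+# truncating from 10 to 4 bits drops 6 LSBs, i.e. divides by 2**6 = 64:
+# round(300.7) = 301, 301 / 64 = 4.703..., "FLOOR" then yields 4
+y = trunc(x, np.array(1.0), np.array(0.0), np.array(10), np.array(4), "FLOOR")
+# y is now [4.]
+```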
+
+
+
+
+#### Sample Implementation
+
+
+Trunc
+
+```python
+# SPDX-License-Identifier: Apache-2.0
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+from __future__ import unicode_literals
+
+import numpy as np
+
+def trunc(inp_tensor, scale, zeropt, input_bit_width, output_bit_width, rounding_mode):
+ # Port of TruncIntQuant class from Brevitas: https://bit.ly/3wzIpTR
+
+ # Scaling
+ y = inp_tensor / scale
+ y = y + zeropt
+ # Rounding
+ y = np.round(y)
+ # Truncate
+ trunc_bit_width = input_bit_width - output_bit_width
+ trunc_scale = 2.0 ** trunc_bit_width
+ y = y / trunc_scale
+
+ # To int
+ rounding_fx = resolve_rounding_mode(rounding_mode)
+ y = rounding_fx(y)
+
+ # Rescale
+ y = y - zeropt
+ y = y * scale
+
+ return y
+
+def resolve_rounding_mode(mode_string):
+ """Resolve the rounding mode string of Quant and Trunc ops
+ to the corresponding numpy functions."""
+ if mode_string == "ROUND":
+ return np.round
+ elif mode_string == "CEIL":
+ return np.ceil
+ elif mode_string == "FLOOR":
+ return np.floor
+ else:
+ raise ValueError(f"Could not resolve rounding mode called: {mode_string}")
+
+```
+
+
diff --git a/setup.cfg b/setup.cfg
index e0a13f1..c437139 100644
--- a/setup.cfg
+++ b/setup.cfg
@@ -55,6 +55,7 @@ docs =
sphinx>=3.2.1
sphinx_rtd_theme>=0.5.0
onnx =
+ clize==4.1.1
onnx==1.7.0
onnxruntime==1.4.0
toposort>=1.5.0
@@ -63,6 +64,8 @@ testing =
pytest-cov
[options.entry_points]
+console_scripts =
+ inference_cost = finn.util.inference_cost:main
# Add here console scripts like:
# console_scripts =
# script_name = finn.finn_base.module:function
diff --git a/src/finn/analysis/inference_cost.py b/src/finn/analysis/inference_cost.py
new file mode 100644
index 0000000..1dc6a17
--- /dev/null
+++ b/src/finn/analysis/inference_cost.py
@@ -0,0 +1,249 @@
+# Copyright (c) 2021 Xilinx, Inc.
+# All rights reserved.
+#
+# Redistribution and use in source and binary forms, with or without
+# modification, are permitted provided that the following conditions are met:
+#
+# * Redistributions of source code must retain the above copyright notice, this
+# list of conditions and the following disclaimer.
+#
+# * Redistributions in binary form must reproduce the above copyright notice,
+# this list of conditions and the following disclaimer in the documentation
+# and/or other materials provided with the distribution.
+#
+# * Neither the name of Xilinx nor the names of its
+# contributors may be used to endorse or promote products derived from
+# this software without specific prior written permission.
+#
+# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
+# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
+# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
+# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
+# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
+# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
+# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+
+import numpy as np
+
+from finn.util.basic import get_by_name
+
+
+def get_node_tensor_dtypes(model, node):
+ # input tensor (input 0)
+ i_name = node.input[0]
+ i_dtype = model.get_tensor_datatype(i_name)
+ # weight tensor (input 1)
+ w_name = node.input[1]
+ w_dtype = model.get_tensor_datatype(w_name)
+ # output tensor (output 0)
+ o_name = node.output[0]
+ o_dtype = model.get_tensor_datatype(o_name)
+ return (i_dtype, w_dtype, o_dtype)
+
+
+def get_node_tensor_shapes(model, node):
+ # input tensor (input 0)
+ i_name = node.input[0]
+ i_shape = model.get_tensor_shape(i_name)
+ assert i_shape is not None, "Input has undefined shape: " + str(node)
+ # weight tensor (input 1)
+ w_name = node.input[1]
+ w_shape = model.get_tensor_shape(w_name)
+ assert w_shape is not None, "Weight has undefined shape: " + str(node)
+ # output tensor (output 0)
+ o_name = node.output[0]
+ o_shape = model.get_tensor_shape(o_name)
+ assert o_shape is not None, "Output has undefined shape: " + str(node)
+ return (i_shape, w_shape, o_shape)
+
+
+def get_node_weight_density(model, w_name):
+ w_tensor = model.get_initializer(w_name)
+ if w_tensor is None:
+ return 1.0
+ w_total = np.prod(w_tensor.shape)
+ w_density = np.count_nonzero(w_tensor) / w_total
+ return w_density
+
+
+def aggregate_dict_keys(res_dict):
+ total_dict = {}
+ for layer in res_dict:
+ layer_res_dict = res_dict[layer]
+ for r_type in layer_res_dict.keys():
+ if "efficiency" in r_type:
+ continue
+ r_amount = layer_res_dict[r_type]
+ r_amount = float(r_amount)
+ if r_type in total_dict.keys():
+ total_dict[r_type] += r_amount
+ else:
+ total_dict[r_type] = r_amount
+ return total_dict
+
+
+def inference_cost_conv(model, node, discount_sparsity):
+ # extract info about the conv kernel attributes
+ k = get_by_name(node.attribute, "kernel_shape").ints
+ k_prod = np.prod(k)
+ group = get_by_name(node.attribute, "group")
+ if group is None:
+ group = 1
+ else:
+ group = group.i
+ # extract info from tensor shapes and datatypes
+ (i_dtype, w_dtype, o_dtype) = get_node_tensor_dtypes(model, node)
+ (i_shape, w_shape, o_shape) = get_node_tensor_shapes(model, node)
+ bsize = i_shape[0]
+ ifm_ch = i_shape[1]
+ ofm_ch = o_shape[1]
+ assert ofm_ch == w_shape[0], "Mismatch in output channels"
+ assert ofm_ch % group == 0, "Invalid group setting: " + str(node)
+ ofm_pix_total = np.prod(o_shape[2:])
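+ # one MAC per kernel position, output pixel and output channel; with
+ # groups, each output channel sees only ifm_ch/group input channels,
+ # so (ofm_ch // group) * ifm_ch == ofm_ch * (ifm_ch // group)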
+ n_macs = bsize * (ofm_ch // group) * ifm_ch * k_prod * ofm_pix_total
+ w_mem = np.prod(w_shape)
+ o_mem = np.prod(o_shape)
+ if discount_sparsity:
+ wname = node.input[1]
+ density = get_node_weight_density(model, wname)
+ n_macs *= density
+ w_mem *= density
+ idt_name = i_dtype.name
+ wdt_name = w_dtype.name
+ odt_name = o_dtype.name
+ mac_op_type_str = "op_mac_%s_%s" % (idt_name, wdt_name)
+ w_mem_type_str = "mem_w_%s" % (wdt_name)
+ o_mem_type_str = "mem_o_%s" % (odt_name)
+ ret = {mac_op_type_str: n_macs, w_mem_type_str: w_mem, o_mem_type_str: o_mem}
+ return ret
+
+
+def inference_cost_matmul(model, node, discount_sparsity):
+ # extract info from tensor shapes and datatypes
+ (i_dtype, w_dtype, o_dtype) = get_node_tensor_dtypes(model, node)
+ (i_shape, w_shape, o_shape) = get_node_tensor_shapes(model, node)
+ if node.op_type == "Gemm":
+ assert len(i_shape) == 2 and len(w_shape) == 2
+ tA = get_by_name(node.attribute, "transA")
+ tB = get_by_name(node.attribute, "transB")
+ if tA is not None and tA.i == 1:
+ i_shape = i_shape[::-1]
+ if tB is not None and tB.i == 1:
+ w_shape = w_shape[::-1]
+ # exclude common dim (last axis) from one side to avoid duplication
+ n_macs = np.prod(i_shape[:-1]) * np.prod(w_shape)
+ # deal with both dyn,param and dyn,dyn cases for weight memory
+ inp0_is_const = model.get_initializer(node.input[0]) is not None
+ inp1_is_const = model.get_initializer(node.input[1]) is not None
+ if inp0_is_const and (not inp1_is_const):
+ # inp 0 is static
+ w_mem = np.prod(i_shape)
+ wname = node.input[0]
+ elif (not inp0_is_const) and inp1_is_const:
+ # inp 1 is static
+ w_mem = np.prod(w_shape)
+ wname = node.input[1]
+ elif (not inp0_is_const) and (not inp1_is_const):
+ # both inputs dynamic
+ w_mem = 0
+ wname = None
+ if discount_sparsity and wname is not None:
+ density = get_node_weight_density(model, wname)
+ n_macs *= density
+ w_mem *= density
+ o_mem = np.prod(o_shape)
+ idt_name = i_dtype.name
+ wdt_name = w_dtype.name
+ odt_name = o_dtype.name
+ mac_op_type_str = "op_mac_%s_%s" % (idt_name, wdt_name)
+ w_mem_type_str = "mem_w_%s" % (wdt_name)
+ o_mem_type_str = "mem_o_%s" % (odt_name)
+ ret = {mac_op_type_str: n_macs, w_mem_type_str: w_mem, o_mem_type_str: o_mem}
+ return ret
+
+
+def inference_cost_upsample(model, node, discount_sparsity):
+ # extract info about the upsampling kernel attributes
+ mode = get_by_name(node.attribute, "mode").s.decode("utf-8")
+ scales_tensor = node.input[1]
+ scales_initializer = model.get_initializer(scales_tensor)
+
+ # extract info from tensor shapes and datatypes
+ (i_dtype, scale_dtype, o_dtype) = get_node_tensor_dtypes(model, node)
+ (i_shape, scale_shape, o_shape) = get_node_tensor_shapes(model, node)
+ bsize = i_shape[0]
+ ifm_ch = i_shape[1]
+ ofm_pix_total = np.prod(o_shape[2:])
+
+ # MAC calculation
+ if mode == "nearest":
+ # No calculation involved, since data is just copied over multiple times
+ n_macs = 0
+ elif mode == "linear":
+ # Data gets linearly interpolated in each dimension
+ # Two MACs per dimension and output pixel assumed
+ n_dim_scaling = np.sum(scales_initializer > 1)
+ n_macs = 2 * n_dim_scaling * ofm_pix_total * ifm_ch * bsize
+ else:
+ raise ValueError(f"Upsampling mode {mode} not supported for estimation.")
+
+ # Mem calculation
+ o_mem = np.prod(o_shape)
+ idt_name = i_dtype.name
+ odt_name = o_dtype.name
+ mac_op_type_str = "op_mac_%s_%s" % (idt_name, idt_name)
+ o_mem_type_str = "mem_o_%s" % (odt_name)
+
+ ret = {mac_op_type_str: n_macs, o_mem_type_str: o_mem}
+ return ret
+
+
+def inference_cost(model, discount_sparsity=True):
+ "Ensure all nodes have unique names prior to calling this analysis pass."
+
+ node_costs = {}
+ zero_cost_ops = [
+ "MaxPool",
+ "AveragePool",
+ "Quant",
+ "Reshape",
+ "Concat",
+ "Transpose",
+ "Div",
+ "Mul",
+ "Add",
+ "Sub",
+ "BatchNormalization",
+ "Relu",
+ "Elu",
+ "Selu",
+ "Sigmoid",
+ "Identity",
+ "Flatten",
+ ]
+ unsupported_ops = set()
+ inference_cost_fxn_map = {
+ "Conv": inference_cost_conv,
+ "MatMul": inference_cost_matmul,
+ "Gemm": inference_cost_matmul,
+ "Upsample": inference_cost_upsample,
+ }
+ for node in model.graph.node:
+ if node.op_type in inference_cost_fxn_map.keys():
+ node_cost = inference_cost_fxn_map[node.op_type](
+ model, node, discount_sparsity
+ )
+ node_costs[node.name] = node_cost
+ elif node.op_type in zero_cost_ops:
+ continue
+ else:
+ unsupported_ops.add(node.op_type)
+
+ ret = aggregate_dict_keys(node_costs)
+ ret["unsupported"] = unsupported_ops
+ ret["discount_sparsity"] = discount_sparsity
+
+ return ret
diff --git a/src/finn/custom_op/general/__init__.py b/src/finn/custom_op/general/__init__.py
index 3bb8bef..102a31b 100644
--- a/src/finn/custom_op/general/__init__.py
+++ b/src/finn/custom_op/general/__init__.py
@@ -26,12 +26,15 @@
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+from finn.custom_op.general.bipolar_quant import BipolarQuant
from finn.custom_op.general.debugmarker import DebugMarker
from finn.custom_op.general.genericpartition import GenericPartition
from finn.custom_op.general.im2col import Im2Col
from finn.custom_op.general.maxpoolnhwc import MaxPoolNHWC
from finn.custom_op.general.multithreshold import MultiThreshold
+from finn.custom_op.general.quant import Quant
from finn.custom_op.general.quantavgpool2d import QuantAvgPool2d
+from finn.custom_op.general.trunc import Trunc
from finn.custom_op.general.xnorpopcount import XnorPopcountMatMul
custom_op = dict()
@@ -43,3 +46,6 @@
custom_op["MultiThreshold"] = MultiThreshold
custom_op["XnorPopcountMatMul"] = XnorPopcountMatMul
custom_op["Im2Col"] = Im2Col
+custom_op["Quant"] = Quant
+custom_op["Trunc"] = Trunc
+custom_op["BipolarQuant"] = BipolarQuant
diff --git a/src/finn/custom_op/general/bipolar_quant.py b/src/finn/custom_op/general/bipolar_quant.py
new file mode 100644
index 0000000..fa4c02f
--- /dev/null
+++ b/src/finn/custom_op/general/bipolar_quant.py
@@ -0,0 +1,106 @@
+# Copyright (c) 2021 Xilinx, Inc.
+# All rights reserved.
+#
+# Redistribution and use in source and binary forms, with or without
+# modification, are permitted provided that the following conditions are met:
+#
+# * Redistributions of source code must retain the above copyright notice, this
+# list of conditions and the following disclaimer.
+#
+# * Redistributions in binary form must reproduce the above copyright notice,
+# this list of conditions and the following disclaimer in the documentation
+# and/or other materials provided with the distribution.
+#
+# * Neither the name of Xilinx nor the names of its
+# contributors may be used to endorse or promote products derived from
+# this software without specific prior written permission.
+#
+# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
+# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
+# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
+# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
+# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
+# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
+# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+
+import numpy as np
+import onnx.helper as helper
+
+from finn.core.datatype import DataType
+from finn.custom_op.base import CustomOp
+
+
+def binary_quant(inp_tensor, scale):
+ # ToDo: update this link when the PR gets merged
+ # Port of IntQuant class from Brevitas: https://bit.ly/2S6qvZJ
+
+ # Quantizing
+ y_int = inp_tensor
+ y_ones = np.ones(y_int.shape, dtype=y_int.dtype)
+ y_int = np.where(y_int >= 0.0, y_ones, -y_ones)
+ # Scaling
+ out_tensor = y_int * scale
+
+ return out_tensor
+
+
+class BipolarQuant(CustomOp):
+ """Bipolar quantization operation for QONNX. Takes four inputs:
+ - input tensor to quantize
+ - the scale
+
+ The output is a tensor of the same shape as the input tensor, with quantized
+ values.
+ """
+
+ def get_nodeattr_types(self):
+ return dict()
+
+ def make_shape_compatible_op(self, model):
+ node = self.onnx_node
+ return helper.make_node("Identity", [node.input[0]], [node.output[0]])
+
+ def get_integer_datatype(self, model):
+ return DataType["BIPOLAR"]
+
+ def get_output_dtype(self, model):
+ node = self.onnx_node
+ # scale must be read from initializers
+ scale = model.get_initializer(node.input[1])
+ # determine the FINN DataType
+ unit_scale = np.all(scale == 1.0)
+ if unit_scale:
+ finn_dt = self.get_integer_datatype(model)
+ else:
+ finn_dt = DataType["FLOAT32"]
+
+ return finn_dt
+
+ def infer_node_datatype(self, model):
+ try:
+ finn_dt = self.get_output_dtype(model)
+ except AssertionError:
+ finn_dt = DataType["FLOAT32"]
+ node = self.onnx_node
+ model.set_tensor_datatype(node.output[0], finn_dt)
+
+ def execute_node(self, context, graph):
+ node = self.onnx_node
+ # save inputs
+ inp_tensor = context[node.input[0]]
+ scale = context[node.input[1]]
+ # calculate output
+ ret = binary_quant(inp_tensor, scale)
+ # ensure output is ndarray (even if 0d)
+ # since numpy silently flattens 0d arrays to scalars
+ # more: https://github.com/numpy/numpy/issues/13105
+ if not isinstance(ret, np.ndarray):
+ ret = np.asarray(ret)
+ # set context according to output name
+ context[node.output[0]] = ret
+
+ def verify_node(self):
+ pass
diff --git a/src/finn/custom_op/general/multithreshold.py b/src/finn/custom_op/general/multithreshold.py
index 35a84d8..56a05d4 100644
--- a/src/finn/custom_op/general/multithreshold.py
+++ b/src/finn/custom_op/general/multithreshold.py
@@ -102,7 +102,17 @@ def make_shape_compatible_op(self, model):
def infer_node_datatype(self, model):
node = self.onnx_node
odt = self.get_nodeattr("out_dtype")
- model.set_tensor_datatype(node.output[0], DataType[odt])
+ is_float = False
+ scale = self.get_nodeattr("out_scale")
+ bias = self.get_nodeattr("out_bias")
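+ # a fractional out_scale or out_bias makes the outputs non-integer,
+ # so the declared out_dtype no longer applies and float32 is used instead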
+ if scale is not None and (int(scale) != scale):
+ is_float = True
+ if bias is not None and (int(bias) != bias):
+ is_float = True
+ if is_float:
+ model.set_tensor_datatype(node.output[0], DataType["FLOAT32"])
+ else:
+ model.set_tensor_datatype(node.output[0], DataType[odt])
def execute_node(self, context, graph):
node = self.onnx_node
diff --git a/src/finn/custom_op/general/quant.py b/src/finn/custom_op/general/quant.py
new file mode 100644
index 0000000..8eb3887
--- /dev/null
+++ b/src/finn/custom_op/general/quant.py
@@ -0,0 +1,242 @@
+# Copyright (c) 2021 Xilinx, Inc.
+# All rights reserved.
+#
+# Redistribution and use in source and binary forms, with or without
+# modification, are permitted provided that the following conditions are met:
+#
+# * Redistributions of source code must retain the above copyright notice, this
+# list of conditions and the following disclaimer.
+#
+# * Redistributions in binary form must reproduce the above copyright notice,
+# this list of conditions and the following disclaimer in the documentation
+# and/or other materials provided with the distribution.
+#
+# * Neither the name of Xilinx nor the names of its
+# contributors may be used to endorse or promote products derived from
+# this software without specific prior written permission.
+#
+# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
+# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
+# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
+# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
+# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
+# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
+# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+
+import numpy as np
+import onnx.helper as helper
+
+from finn.core.datatype import DataType
+from finn.custom_op.base import CustomOp
+
+
+def min_int(signed: bool, narrow_range: bool, bit_width: int) -> int:
+ """Compute the minimum integer representable by a given number of bits.
+ Args:
+ signed (bool): Indicates whether the represented integer is signed or not.
+ narrow_range (bool): Indicates whether to narrow the minimum
+ representable value by 1.
+ bit_width (int): Number of bits available for the representation.
+ Returns:
+ int: Minimum integer that can be represented according to
+ the input arguments.
+ Examples:
+ >>> min_int(signed=True, narrow_range=True, bit_width=8)
+ int(-127)
+ >>> min_int(signed=False, narrow_range=True, bit_width=8)
+ int(0)
+ >>> min_int(signed=True, narrow_range=False, bit_width=8)
+ int(-128)
+ >>> min_int(signed=False, narrow_range=False, bit_width=8)
+ int(0)
+ """
+ if signed and narrow_range:
+ value = -(2 ** (bit_width - 1)) + 1
+ elif signed and not narrow_range:
+ value = -(2 ** (bit_width - 1))
+ else:
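+ # unsigned: the minimum representable value is always 0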
+ value = 0 * bit_width
+ return value
+
+
+def max_int(signed: bool, narrow_range: bool, bit_width: int) -> int:
+ """Compute the maximum integer representable by a given number of bits.
+ Args:
+ signed (bool): Indicates whether the represented integer is signed or not.
+ narrow_range (bool): Indicates whether to narrow the maximum
+ unsigned representable value by 1.
+ bit_width (int): Number of bits available for the representation.
+ Returns:
+ int: Maximum integer that can be represented according to
+ the input arguments.
+ Examples:
+ >>> max_int(signed=True, narrow_range=True, bit_width=8)
+ int(127)
+ >>> max_int(signed=False, narrow_range=True, bit_width=8)
+ int(254)
+ >>> max_int(signed=True, narrow_range=False, bit_width=8)
+ int(127)
+ >>> max_int(signed=False, narrow_range=False, bit_width=8)
+ int(255)
+ """
+ if not signed and not narrow_range:
+ value = (2 ** bit_width) - 1
+ elif not signed and narrow_range:
+ value = (2 ** bit_width) - 2
+ else:
+ value = (2 ** (bit_width - 1)) - 1
+ return value
+
+
+def quant(inp_tensor, scale, zeropt, bitwidth, signed, narrow, rounding_mode):
+ # ToDo: update this link when the PR gets merged
+ # Port of IntQuant class from Brevitas: https://bit.ly/2S6qvZJ
+
+ # Scaling
+ y_int = inp_tensor / scale
+ y_int = y_int + zeropt
+ if bitwidth == 1 and signed:
+ # BUG: 1-bit Quant ops currently not exported correctly
+ # manually convert to bipolar values
+ y_ones = np.ones(y_int.shape, dtype=y_int.dtype)
+ y_int = np.where(y_int >= 0.0, y_ones, -y_ones)
+ else:
+ # Clamping
+ min_int_val = min_int(signed, narrow, bitwidth)
+ max_int_val = max_int(signed, narrow, bitwidth)
+ y_int = np.where(y_int > max_int_val, max_int_val.astype(y_int.dtype), y_int)
+ y_int = np.where(y_int < min_int_val, min_int_val.astype(y_int.dtype), y_int)
+ # Rounding
+ rounding_fx = resolve_rounding_mode(rounding_mode)
+ y_int = rounding_fx(y_int)
+ # Re-scaling
+ out_tensor = y_int - zeropt
+ out_tensor = out_tensor * scale
+
+ return out_tensor
+
+
+def resolve_rounding_mode(mode_string):
+ """Resolve the rounding mode string of Quant and Trunc ops
+ to the corresponding numpy functions."""
+ if mode_string == "ROUND":
+ return np.round
+ elif mode_string == "CEIL":
+ return np.ceil
+ elif mode_string == "FLOOR":
+ return np.floor
+ else:
+ raise ValueError(f"Could not resolve rounding mode called: {mode_string}")
+
+
+class Quant(CustomOp):
+ """Generic quantization operation for QONNX. Takes four inputs:
+ - input tensor to quantize
+ - the scale
+ - the zero-point
+ - the bit-width
+
+ The output is a tensor of the same shape as the input tensor, with quantized
+ values.
+ """
+
+ def get_nodeattr_types(self):
+ return {
+ # whether the quantization interval should be signed or not
+ # (e.g. at 8b unsigned=[0, 255] vs signed=[-128, 127])
+ "signed": ("i", True, 1),
+ # when signed=1, whether to use narrow range or not
+ # (e.g. at 8b regular=[-128, 127] vs narrow=[-127, 127])
+ "narrow": ("i", True, 1),
+ # The rounding mode, which is used for the quant function
+ # ToDo: This should be required (True) instead of optional (False)
+ "rounding_mode": ("s", False, "ROUND"),
+ }
+
+ def make_shape_compatible_op(self, model):
+ node = self.onnx_node
+ return helper.make_node("Identity", [node.input[0]], [node.output[0]])
+
+ def get_integer_datatype(self, model):
+ signed = self.get_nodeattr("signed")
+ bit_width = model.get_initializer(self.onnx_node.input[3])
+ bit_width = int(bit_width)
+ if bit_width == 1:
+ if signed:
+ finn_dt = DataType["BIPOLAR"]
+ else:
+ finn_dt = DataType["BINARY"]
+ else:
+ if signed:
+ finn_dt = DataType["INT" + str(bit_width)]
+ else:
+ finn_dt = DataType["UINT" + str(bit_width)]
+ return finn_dt
+
+ def get_output_dtype(self, model):
+ node = self.onnx_node
+ # scale, zero-point and bitwidth must be read from initializers
+ scale = model.get_initializer(node.input[1])
+ zeropt = model.get_initializer(node.input[2])
+ bitwidth = model.get_initializer(node.input[3])
+ assert scale is not None, "Found unspecified scale for Quant node: " + str(node)
+ assert (
+ zeropt is not None
+ ), "Found unspecified zero point for Quant node: " + str(node)
+ assert (
+ bitwidth is not None
+ ), "Found unspecified bitwidth for Quant node: " + str(node)
+ # extract the bitwidth (assume scalar)
+ assert bitwidth.ndim == 0, "Bitwidth must be scalar for Quant node: " + str(
+ node
+ )
+ bitwidth = bitwidth.item()
+ assert (
+ int(bitwidth) == bitwidth
+ ), "Bitwidth must be integer for Quant node: " + str(node)
+ bitwidth = int(bitwidth)
+ # determine the FINN DataType
+ unit_scale = np.all(scale == 1.0)
+ zero_zeropt = np.all(zeropt == 0.0)
+ assert zero_zeropt, "Only zero_point=0 Quant nodes supported for now"
+ if unit_scale and zero_zeropt:
+ finn_dt = self.get_integer_datatype(model)
+ else:
+ finn_dt = DataType["FLOAT32"]
+
+ return finn_dt
+
+ def infer_node_datatype(self, model):
+ try:
+ finn_dt = self.get_output_dtype(model)
+ except AssertionError:
+ finn_dt = DataType["FLOAT32"]
+ node = self.onnx_node
+ model.set_tensor_datatype(node.output[0], finn_dt)
+
+ def execute_node(self, context, graph):
+ node = self.onnx_node
+ # save inputs
+ inp_tensor = context[node.input[0]]
+ scale = context[node.input[1]]
+ zeropt = context[node.input[2]]
+ bitwidth = context[node.input[3]]
+ # save attributes
+ signed = self.get_nodeattr("signed")
+ narrow = self.get_nodeattr("narrow")
+ rounding_mode = self.get_nodeattr("rounding_mode")
+ # calculate output
+ ret = quant(inp_tensor, scale, zeropt, bitwidth, signed, narrow, rounding_mode)
+ # ensure output is ndarray (even if 0d)
+ # since numpy silently flattens 0d arrays to scalars
+ # more: https://github.com/numpy/numpy/issues/13105
+ if not isinstance(ret, np.ndarray):
+ ret = np.asarray(ret)
+ # set context according to output name
+ context[node.output[0]] = ret
+
+ def verify_node(self):
+ pass
diff --git a/src/finn/custom_op/general/trunc.py b/src/finn/custom_op/general/trunc.py
new file mode 100644
index 0000000..7faaeda
--- /dev/null
+++ b/src/finn/custom_op/general/trunc.py
@@ -0,0 +1,104 @@
+# Copyright (c) 2021 Xilinx, Inc.
+# All rights reserved.
+#
+# Redistribution and use in source and binary forms, with or without
+# modification, are permitted provided that the following conditions are met:
+#
+# * Redistributions of source code must retain the above copyright notice, this
+# list of conditions and the following disclaimer.
+#
+# * Redistributions in binary form must reproduce the above copyright notice,
+# this list of conditions and the following disclaimer in the documentation
+# and/or other materials provided with the distribution.
+#
+# * Neither the name of Xilinx nor the names of its
+# contributors may be used to endorse or promote products derived from
+# this software without specific prior written permission.
+#
+# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
+# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
+# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
+# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
+# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
+# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
+# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+
+import numpy as np
+import onnx.helper as helper
+
+from finn.core.datatype import DataType
+from finn.custom_op.base import CustomOp
+from finn.custom_op.general.quant import resolve_rounding_mode
+
+
+def trunc(inp_tensor, scale, zeropt, input_bit_width, output_bit_width, rounding_mode):
+ # Port of TruncIntQuant class from Brevitas: https://bit.ly/3wzIpTR
+
+ # Scaling
+ y = inp_tensor / scale
+ y = y + zeropt
+ # Rounding
+ y = np.round(y)
+ # Truncate
+ trunc_bit_width = input_bit_width - output_bit_width
+ trunc_scale = 2.0 ** trunc_bit_width
+ y = y / trunc_scale
+
+ # To int
+ rounding_fx = resolve_rounding_mode(rounding_mode)
+ y = rounding_fx(y)
+
+ # Rescale
+ y = y - zeropt
+ y = y * scale
+
+ return y
+
+
+class Trunc(CustomOp):
+ """Generic truncation operation for QONNX. Takes four inputs:
+ - input tensor to truncate
+ - the scale
+ - the zero-point
+ - the truncation bit-width
+
+ The output is a tensor of the same shape as the input tensor, with truncated
+ values.
+ """
+
+ def get_nodeattr_types(self):
+ return {
+ # The rounding mode, which is used for the trunc function
+ "rounding_mode": ("s", True, "FLOOR"),
+ }
+
+ def make_shape_compatible_op(self, model):
+ node = self.onnx_node
+ return helper.make_node("Identity", [node.input[0]], [node.output[0]])
+
+ def infer_node_datatype(self, model):
+ node = self.onnx_node
+ model.set_tensor_datatype(node.output[0], DataType["FLOAT32"])
+
+ def execute_node(self, context, graph):
+ node = self.onnx_node
+ # save inputs
+ inp_tensor = context[node.input[0]]
+ scale = context[node.input[1]]
+ zeropt = context[node.input[2]]
+ input_bit_width = context[node.input[3]]
+ output_bit_width = context[node.input[4]]
+ # save attributes
+ rounding_mode = self.get_nodeattr("rounding_mode")
+ # calculate output
+ ret = trunc(
+ inp_tensor, scale, zeropt, input_bit_width, output_bit_width, rounding_mode
+ )
+ # set context according to output name
+ context[node.output[0]] = ret
+
+ def verify_node(self):
+ pass
diff --git a/src/finn/transformation/extract_conv_bias.py b/src/finn/transformation/extract_conv_bias.py
new file mode 100644
index 0000000..9dc0f58
--- /dev/null
+++ b/src/finn/transformation/extract_conv_bias.py
@@ -0,0 +1,87 @@
+# Copyright (c) 2021, Xilinx
+# All rights reserved.
+#
+# Redistribution and use in source and binary forms, with or without
+# modification, are permitted provided that the following conditions are met:
+#
+# * Redistributions of source code must retain the above copyright notice, this
+# list of conditions and the following disclaimer.
+#
+# * Redistributions in binary form must reproduce the above copyright notice,
+# this list of conditions and the following disclaimer in the documentation
+# and/or other materials provided with the distribution.
+#
+# * Neither the name of FINN nor the names of its
+# contributors may be used to endorse or promote products derived from
+# this software without specific prior written permission.
+#
+# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
+# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
+# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
+# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
+# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
+# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
+# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+
+import warnings
+from onnx import TensorProto, helper
+
+from finn.transformation.base import Transformation
+
+
+class ExtractBiasFromConv(Transformation):
+ """
+ Extracts the (optional) Bias from a Conv node and inserts it behind the
+ Conv node as an Add node.
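+
+ Example (a minimal sketch, assuming `model` is a ModelWrapper instance):
+
+ model = model.transform(ExtractBiasFromConv())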
+ """
+
+ def apply(self, model):
+ graph = model.graph
+ node_ind = 0
+ for n in graph.node:
+ node_ind += 1
+ if n.op_type == "Conv":
+ # Check if the node has a bias input
+ if len(n.input) > 2:
+ # Extract bias
+ bias = model.get_initializer(n.input[2])
+ if bias is None:
+ warnings.warn(
+ f"Could not extract bias from Conv node {n}, "
+ f"due to missing static initialization."
+ )
+ continue
+
+ # Insert bias as Add node behind the Conv node
+ out_shape = model.get_tensor_shape(n.output[0])
+ # Reshape bias tensor
+ add_shape = [1] * len(out_shape)
+ # ToDo: this must change to "add_shape[-1] = bias.shape[0]" when
+ # the channels last layout comes around.
+ add_shape[1] = bias.shape[0]
+ model.set_initializer(n.input[2], bias.reshape(add_shape))
+
+ act_add_tensor = helper.make_tensor_value_info(
+ model.make_new_valueinfo_name(),
+ TensorProto.FLOAT,
+ out_shape,
+ )
+ graph.value_info.append(act_add_tensor)
+
+ add_node = helper.make_node(
+ "Add",
+ [act_add_tensor.name, n.input[2]],
+ [n.output[0]],
+ )
+ graph.node.insert(node_ind, add_node)
+
+ # Repoint Conv output and remove bias tensor
+ n.output[0] = act_add_tensor.name
+ n.input.remove(n.input[2])
+
+ return model, True
+
+ return model, False
diff --git a/src/finn/transformation/fold_constants.py b/src/finn/transformation/fold_constants.py
index 814eca4..f4dbf01 100644
--- a/src/finn/transformation/fold_constants.py
+++ b/src/finn/transformation/fold_constants.py
@@ -33,7 +33,11 @@
class FoldConstants(Transformation):
"""Replace the output of a node with const-only inputs with a precomputed
- result."""
+ result. Skip any op types given in exclude_op_types."""
+
+ def __init__(self, exclude_op_types=["Quant", "BipolarQuant"]):
+ super().__init__()
+ self.exclude_op_types = exclude_op_types
def apply(self, model):
graph = model.graph
@@ -48,7 +52,8 @@ def apply(self, model):
is_all_constant_inputs = len(node_inp_dyn) == 0
ishape = model.get_tensor_shape(n.input[0])
is_const_shape = (n.op_type == "Shape") and (ishape is not None)
- if is_all_constant_inputs or is_const_shape:
+ exclude = n.op_type in self.exclude_op_types
+ if (is_all_constant_inputs or is_const_shape) and not exclude:
# this node has no dynamic inputs, only constant ones -- so we can
# do constant folding.
oxe.execute_node(n, execution_context, graph)
diff --git a/src/finn/transformation/gemm_to_matmul.py b/src/finn/transformation/gemm_to_matmul.py
new file mode 100644
index 0000000..309108c
--- /dev/null
+++ b/src/finn/transformation/gemm_to_matmul.py
@@ -0,0 +1,217 @@
+# Copyright (c) 2021, Xilinx
+# All rights reserved.
+#
+# Redistribution and use in source and binary forms, with or without
+# modification, are permitted provided that the following conditions are met:
+#
+# * Redistributions of source code must retain the above copyright notice, this
+# list of conditions and the following disclaimer.
+#
+# * Redistributions in binary form must reproduce the above copyright notice,
+# this list of conditions and the following disclaimer in the documentation
+# and/or other materials provided with the distribution.
+#
+# * Neither the name of FINN nor the names of its
+# contributors may be used to endorse or promote products derived from
+# this software without specific prior written permission.
+#
+# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
+# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
+# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
+# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
+# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
+# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
+# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+import numpy as np
+import warnings
+from onnx import TensorProto, helper
+
+from finn.core.datatype import DataType
+from finn.transformation.base import Transformation
+from finn.transformation.remove import RemoveIdentityOps
+from finn.util.basic import get_by_name
+
+
+class GemmToMatMul(Transformation):
+ """
+ Converts Gemm nodes into a MatMul and an Add node.
+ This transformation is built to support version 9 of the Gemm node, as
+ documented here: https://github.com/onnx/onnx/blob/master/docs/Changelog.md#Gemm-9
+ However, earlier and later versions of the node are likely to work as well.
+ Explicitly not supported is the optionality of input C in versions >=11 and
+ the broadcast attribute of versions <=6.
+ """
+
+ def apply(self, model):
+ graph = model.graph
+ node_ind = 0
+ for n in graph.node:
+ node_ind += 1
+ if n.op_type == "Gemm":
+ # Check for correct ONNX version
+ model_onnx_version = model.model.opset_import[0].version
+ if model_onnx_version != 9:
+ warnings.warn(
+ f"The GemmToMatMul transformation only offers explicit support "
+ f"for version 9 of the Gemm node, but the ONNX version of the "
+ f"supplied model is {model_onnx_version}. "
+ f"Thus the transformation may fail or "
+ f"return incomplete results."
+ )
+ running_node_index = node_ind
+ # Transpose A?
+ transA = get_by_name(n.attribute, "transA")
+ if transA is not None and transA.i:
+ # Insert transpose node
+ shape = model.get_tensor_shape(n.input[0])
+ if shape is not None:
+ shape = tuple(reversed(shape))
+ inp_trans_out = helper.make_tensor_value_info(
+ model.make_new_valueinfo_name(),
+ TensorProto.FLOAT,
+ shape,
+ )
+ graph.value_info.append(inp_trans_out)
+ inp_trans_node = helper.make_node(
+ "Transpose", [n.input[0]], [inp_trans_out.name]
+ )
+ graph.node.insert(running_node_index, inp_trans_node)
+ running_node_index += 1
+ dt = model.get_tensor_datatype(n.input[0])
+ if dt != DataType["FLOAT32"]:
+ model.set_tensor_datatype(inp_trans_out.name, dt)
+
+ n.input[0] = inp_trans_out.name
+
+ # Transpose B?
+ transB = get_by_name(n.attribute, "transB")
+ if transB is not None and transB.i:
+ # Insert transpose node
+ shape = model.get_tensor_shape(n.input[1])
+ if shape is not None:
+ shape = tuple(reversed(shape))
+ inp_trans_out = helper.make_tensor_value_info(
+ model.make_new_valueinfo_name(),
+ TensorProto.FLOAT,
+ shape,
+ )
+ graph.value_info.append(inp_trans_out)
+ inp_trans_node = helper.make_node(
+ "Transpose", [n.input[1]], [inp_trans_out.name]
+ )
+ graph.node.insert(running_node_index, inp_trans_node)
+ running_node_index += 1
+ # Copy over the datatype
+ dt = model.get_tensor_datatype(n.input[1])
+ if dt != DataType["FLOAT32"]:
+ model.set_tensor_datatype(inp_trans_out.name, dt)
+
+ n.input[1] = inp_trans_out.name
+
+ # Insert MatMul: A * B
+ matMul_node = helper.make_node(
+ "MatMul", [n.input[0], n.input[1]], [n.output[0]]
+ )
+ graph.node.insert(running_node_index, matMul_node)
+ matMul_node = graph.node[running_node_index]
+ running_node_index += 1
+
+ # Insert Mul: (A*B) * alpha
+ alpha = get_by_name(n.attribute, "alpha")
+ if alpha is None:
+ alpha = np.array(1.0)
+ else:
+ alpha = np.array(alpha.f)
+ mul_tensor = helper.make_tensor_value_info(
+ model.make_new_valueinfo_name(),
+ TensorProto.FLOAT,
+ None,
+ )
+ graph.value_info.append(mul_tensor)
+ model.set_initializer(mul_tensor.name, alpha)
+
+ A_shape = model.get_tensor_shape(n.input[0])
+ B_shape = model.get_tensor_shape(n.input[1])
+ if A_shape is not None and B_shape is not None:
+ shape = [A_shape[0], B_shape[1]]
+ else:
+ shape = None
+ act_mul_tensor = helper.make_tensor_value_info(
+ model.make_new_valueinfo_name(),
+ TensorProto.FLOAT,
+ shape,
+ )
+ graph.value_info.append(act_mul_tensor)
+ mul_node = helper.make_node(
+ "Mul",
+ [act_mul_tensor.name, mul_tensor.name],
+ [n.output[0]],
+ )
+ graph.node.insert(running_node_index, mul_node)
+ mul_node_main_branch = graph.node[running_node_index]
+ running_node_index += 1
+ matMul_node.output[0] = act_mul_tensor.name
+
+ # Other branch: Insert Mul: beta * C
+ beta = get_by_name(n.attribute, "beta")
+ if beta is None:
+ beta = np.array(1.0)
+ else:
+ beta = np.array(beta.f)
+ mul_tensor = helper.make_tensor_value_info(
+ model.make_new_valueinfo_name(),
+ TensorProto.FLOAT,
+ None,
+ )
+ graph.value_info.append(mul_tensor)
+ model.set_initializer(mul_tensor.name, beta)
+
+ C_shape = model.get_tensor_shape(n.input[2])
+ act_mul_tensor = helper.make_tensor_value_info(
+ model.make_new_valueinfo_name(),
+ TensorProto.FLOAT,
+ C_shape,
+ )
+ graph.value_info.append(act_mul_tensor)
+ mul_node = helper.make_node(
+ "Mul",
+ [n.input[2], mul_tensor.name],
+ [act_mul_tensor.name],
+ )
+ graph.node.insert(running_node_index, mul_node)
+ running_node_index += 1
+ dt = model.get_tensor_datatype(n.input[2])
+ if dt != DataType["FLOAT32"]:
+ model.set_tensor_datatype(act_mul_tensor.name, dt)
+ n.input[2] = act_mul_tensor.name
+
+ # Insert Add: ((A*B) * alpha) + (beta * C)
+ shape = model.get_tensor_shape(mul_node_main_branch.input[0])
+ act_add_tensor = helper.make_tensor_value_info(
+ model.make_new_valueinfo_name(),
+ TensorProto.FLOAT,
+ shape,
+ )
+ graph.value_info.append(act_add_tensor)
+ mul_node_main_branch.output[0] = act_add_tensor.name
+ add_node = helper.make_node(
+ "Add",
+ [act_add_tensor.name, n.input[2]],
+ [n.output[0]],
+ )
+
+ graph.node.insert(running_node_index, add_node)
+ running_node_index += 1
+
+ # Delete Gemm node
+ graph.node.remove(n)
+
+ # Remove potential unity multiplications from alpha and beta attributes
+ model = model.transform(RemoveIdentityOps())
+
+ return model, True
+
+ return model, False
diff --git a/src/finn/transformation/remove.py b/src/finn/transformation/remove.py
new file mode 100644
index 0000000..e976cae
--- /dev/null
+++ b/src/finn/transformation/remove.py
@@ -0,0 +1,110 @@
+# Copyright (c) 2020, Xilinx
+# All rights reserved.
+#
+# Redistribution and use in source and binary forms, with or without
+# modification, are permitted provided that the following conditions are met:
+#
+# * Redistributions of source code must retain the above copyright notice, this
+# list of conditions and the following disclaimer.
+#
+# * Redistributions in binary form must reproduce the above copyright notice,
+# this list of conditions and the following disclaimer in the documentation
+# and/or other materials provided with the distribution.
+#
+# * Neither the name of FINN nor the names of its
+# contributors may be used to endorse or promote products derived from
+# this software without specific prior written permission.
+#
+# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
+# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
+# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
+# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
+# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
+# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
+# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+
+
+import numpy as np
+
+from finn.transformation.base import Transformation
+from finn.transformation.infer_shapes import InferShapes
+from finn.util.basic import get_by_name
+
+
+def remove_node_and_rewire(model, node):
+ producer = model.find_producer(node.input[0])
+ if producer is not None:
+ # wire output tensor to
+ # output of producer node
+ producer.output[0] = node.output[0]
+ else:
+ # node is first in graph
+ successors = model.find_direct_successors(node)
+ assert successors is not None, "Whole graph is one node."
+ for succ in successors:
+ for i, s_inp in enumerate(succ.input):
+ if s_inp == node.output[0]:
+ # rewire successor's input directly to graph input
+ succ.input[i] = node.input[0]
+ # remove node
+ model.graph.node.remove(node)
+
+
+class RemoveIdentityOps(Transformation):
+ """Remove identity ops like Add/Sub with zero or Mul/Div with one. A tolerance
+ value (defaults to 1e-05) can be specified during init for the comparison
+ to zero/one."""
+
+ def __init__(self, atol=1e-05):
+ super().__init__()
+ self.atol = atol
+
+ def apply(self, model):
+ graph = model.graph
+ node_ind = 0
+ graph_modified = False
+ for n in graph.node:
+ node_ind += 1
+ if (
+ n.op_type in ["Add", "Sub"]
+ and not model.is_fork_node(n)
+ and not model.is_join_node(n)
+ ):
+ A = model.get_initializer(n.input[1])
+ if (
+ A is not None
+ and np.isclose(A, np.zeros_like(A), atol=self.atol).all()
+ ):
+ remove_node_and_rewire(model, n)
+ graph_modified = True
+ break
+
+ elif (
+ n.op_type in ["Mul", "Div"]
+ and not model.is_fork_node(n)
+ and not model.is_join_node(n)
+ ):
+ A = model.get_initializer(n.input[1])
+ if (
+ A is not None
+ and np.isclose(A, np.ones_like(A), atol=self.atol).all()
+ ):
+ remove_node_and_rewire(model, n)
+ graph_modified = True
+ break
+ elif (
+ n.op_type == "Pad"
+ and not model.is_fork_node(n)
+ and not model.is_join_node(n)
+ ):
+ pads = get_by_name(n.attribute, "pads")
+ pads = np.asarray(pads.ints)
+ if (pads == 0).all():
+ remove_node_and_rewire(model, n)
+ graph_modified = True
+ break
+ model = model.transform(InferShapes())
+ return (model, graph_modified)
diff --git a/src/finn/util/basic.py b/src/finn/util/basic.py
index 68ac290..9f8b9ff 100644
--- a/src/finn/util/basic.py
+++ b/src/finn/util/basic.py
@@ -71,7 +71,7 @@
def is_finn_op(op_type):
"Return whether given op_type string is a FINN custom op"
- return op_type.startswith("finn")
+ return op_type.startswith("finn") or op_type.startswith("qonnx.custom_op")
def get_rtlsim_trace_depth():
diff --git a/src/finn/util/inference_cost.py b/src/finn/util/inference_cost.py
new file mode 100644
index 0000000..0ebcec9
--- /dev/null
+++ b/src/finn/util/inference_cost.py
@@ -0,0 +1,127 @@
+# Copyright (c) 2021 Xilinx, Inc.
+# All rights reserved.
+#
+# Redistribution and use in source and binary forms, with or without
+# modification, are permitted provided that the following conditions are met:
+#
+# * Redistributions of source code must retain the above copyright notice, this
+# list of conditions and the following disclaimer.
+#
+# * Redistributions in binary form must reproduce the above copyright notice,
+# this list of conditions and the following disclaimer in the documentation
+# and/or other materials provided with the distribution.
+#
+# * Neither the name of Xilinx nor the names of its
+# contributors may be used to endorse or promote products derived from
+# this software without specific prior written permission.
+#
+# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
+# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
+# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
+# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
+# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
+# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
+# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+
+import clize
+import json
+
+import finn.analysis.inference_cost as infca
+from finn.core.datatype import DataType
+from finn.core.modelwrapper import ModelWrapper
+from finn.transformation.fold_constants import FoldConstants
+from finn.transformation.general import (
+ GiveReadableTensorNames,
+ GiveUniqueNodeNames,
+ GiveUniqueParameterTensors,
+ RemoveStaticGraphInputs,
+ RemoveUnusedTensors,
+)
+from finn.transformation.infer_datatypes import InferDataTypes
+from finn.transformation.infer_shapes import InferShapes
+
+
+def compute_bops(inf_cost_dict):
+ total_bops = 0.0
+ for (k, v) in inf_cost_dict.items():
+ if k.startswith("op_mac"):
+ comps = k.split("_")
+ dt1 = DataType[comps[2]]
+ dt2 = DataType[comps[3]]
+ total_bops += dt1.bitwidth() * dt2.bitwidth() * v
+ return total_bops
+
+
+def compute_mem_bits(inf_cost_dict, filter_string="mem_w"):
+ total_mem_bits = 0.0
+ for (k, v) in inf_cost_dict.items():
+ if k.startswith(filter_string):
+ comps = k.split("_")
+ dt = DataType[comps[2]]
+ total_mem_bits += dt.bitwidth() * v
+ return total_mem_bits
+
+
+def inference_cost(
+ model_filename,
+ *,
+ output_json=None,
+ output_onnx=None,
+ preprocess=True,
+ discount_sparsity=True
+):
+ """Print the inference cost estimate metric for given ONNX model.
+ Supports the Quant op for weight/activation quantization.
+
+ :param model_filename: Filename for ONNX model
+ :param output_json: Optional JSON filename to save the inference cost dict
+ :param output_onnx: Optional ONNX filename to save the final model after any
+ preprocessing
+ :param preprocess: If set, run preprocessing steps such as shape inference,
+ datatype inference and constant folding. Strongly recommended.
+ :param discount_sparsity: If set, will discount op cost of MAC ops with a
+ constant zero weight, and the mem cost of constant zero weights.
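+
+ Example usage from Python (a minimal sketch; the filenames are placeholders):
+
+ inference_cost("model.onnx", output_json="cost.json")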
+ """
+ print("Inference cost for " + model_filename)
+ model = ModelWrapper(model_filename)
+ if preprocess:
+ qnt_nodes = model.get_nodes_by_op_type("Quant")
+ for qnt_node in qnt_nodes:
+ qnt_node.domain = "finn.custom_op.general"
+ model = model.transform(InferShapes())
+ model = model.transform(GiveUniqueParameterTensors())
+ model = model.transform(InferDataTypes())
+ model = model.transform(FoldConstants())
+ model = model.transform(RemoveUnusedTensors())
+ model = model.transform(RemoveStaticGraphInputs())
+ model = model.transform(InferDataTypes())
+ model = model.transform(GiveUniqueNodeNames())
+ model = model.transform(GiveReadableTensorNames())
+ if output_onnx is not None:
+ model.save(output_onnx)
+ ret = model.analysis(lambda x: infca.inference_cost(x, discount_sparsity))
+ bops = compute_bops(ret)
+ mem_w_bits = compute_mem_bits(ret, "mem_w")
+ mem_o_bits = compute_mem_bits(ret, "mem_o")
+ ret["total_bops"] = bops
+ ret["total_mem_w_bits"] = mem_w_bits
+ ret["total_mem_o_bits"] = mem_o_bits
+
+ if "unsupported" in ret:
+ ret["unsupported"] = str(ret["unsupported"])
+ print(json.dumps(ret, sort_keys=True, indent=2))
+
+ if output_json is not None:
+ with open(output_json, "w") as f:
+ json.dump(ret, f, sort_keys=True, indent=2)
+
+
+def main():
+ clize.run(inference_cost)
+
+
+if __name__ == "__main__":
+ main()
diff --git a/tests/transformation/test_remove_identity_ops.py b/tests/transformation/test_remove_identity_ops.py
new file mode 100644
index 0000000..b7584be
--- /dev/null
+++ b/tests/transformation/test_remove_identity_ops.py
@@ -0,0 +1,95 @@
+import pytest
+
+import numpy as np
+from onnx import TensorProto, helper
+
+import finn.core.onnx_exec as oxe
+from finn.core.datatype import DataType
+from finn.core.modelwrapper import ModelWrapper
+from finn.transformation.infer_datatypes import InferDataTypes
+from finn.transformation.infer_shapes import InferShapes
+from finn.transformation.remove import RemoveIdentityOps
+from finn.util.basic import gen_finn_dt_tensor
+
+
+def insert_identity_op(model, op, as_first_node, approx):
+ if approx:
+ zero_val = 0.000001
+ one_val = 0.999999
+ else:
+ zero_val = 0.0
+ one_val = 1.0
+ if op in ["Add", "Sub"]:
+ val = np.asarray([zero_val], dtype=np.float32)
+ elif op in ["Mul", "Div"]:
+ val = np.asarray([one_val], dtype=np.float32)
+ else:
+ return
+
+ graph = model.graph
+ if as_first_node:
+ identity_node = helper.make_node(op, ["inp", "value"], ["ident_out"])
+ graph.node.insert(0, identity_node)
+ graph.node[1].input[0] = "ident_out"
+ else:
+ identity_node = helper.make_node(op, ["div_out", "value"], ["ident_out"])
+ graph.node.insert(3, identity_node)
+ graph.node[-1].input[0] = "ident_out"
+ model.set_initializer("value", val)
+
+ return model
+
+
+# identity operations to be inserted
+@pytest.mark.parametrize("op", ["Add", "Sub", "Mul", "Div"])
+@pytest.mark.parametrize("approx", [False, True])
+@pytest.mark.parametrize("as_first_node", [False, True])
+def test_remove_identity_ops(op, as_first_node, approx):
+
+ # set up onnx model
+ inp = helper.make_tensor_value_info("inp", TensorProto.FLOAT, [1, 4, 1, 1])
+ mul = helper.make_tensor_value_info("mul", TensorProto.FLOAT, [])
+ shape = helper.make_tensor_value_info("shape", TensorProto.FLOAT, [2])
+ div = helper.make_tensor_value_info("div", TensorProto.FLOAT, [])
+ matmul = helper.make_tensor_value_info("matmul", TensorProto.FLOAT, [4, 2])
+ outp = helper.make_tensor_value_info("outp", TensorProto.FLOAT, [1, 2])
+
+ mul_node = helper.make_node("Mul", ["inp", "mul"], ["mul_out"])
+ reshape_node = helper.make_node("Reshape", ["mul_out", "shape"], ["reshape_out"])
+ div_node = helper.make_node("Div", ["reshape_out", "div"], ["div_out"])
+ matmul_node = helper.make_node("MatMul", ["div_out", "matmul"], ["outp"])
+
+ graph = helper.make_graph(
+ nodes=[mul_node, reshape_node, div_node, matmul_node],
+ name="identity-graph",
+ inputs=[inp],
+ outputs=[outp],
+ value_info=[mul, shape, div, matmul],
+ )
+
+ model = helper.make_model(graph, producer_name="mulpastconv-model")
+ model = ModelWrapper(model)
+ inp_values = gen_finn_dt_tensor(DataType["INT2"], [1, 4, 1, 1])
+ mul_values = np.random.uniform(low=0.1, high=0.99, size=(1)).astype(np.float32)
+ shape_values = np.asarray([1, -1], dtype=np.int64)
+ div_values = np.random.uniform(low=0.1, high=0.99, size=(1)).astype(np.float32)
+ matmul_values = gen_finn_dt_tensor(DataType["INT2"], [4, 2])
+ model.set_initializer("mul", mul_values)
+ model.set_initializer("shape", shape_values)
+ model.set_initializer("div", div_values)
+ model.set_initializer("matmul", matmul_values)
+ insert_identity_op(model, op, as_first_node, approx)
+ model = model.transform(InferShapes())
+ model = model.transform(InferDataTypes())
+ idict = {"inp": inp_values}
+ odict = oxe.execute_onnx(model, idict)
+ out_before = odict["outp"]
+ num_of_nodes_before = len(model.graph.node)
+
+ model = model.transform(RemoveIdentityOps())
+ num_of_nodes_after = len(model.graph.node)
+ assert num_of_nodes_before - 1 == num_of_nodes_after
+
+ odict = oxe.execute_onnx(model, idict)
+ out_after = odict["outp"]
+ assert np.isclose(out_before, out_after, atol=1e-3).all()