diff --git a/src/finn/custom_op/general/im2col.py b/src/finn/custom_op/general/im2col.py index e76c613..3f8d118 100644 --- a/src/finn/custom_op/general/im2col.py +++ b/src/finn/custom_op/general/im2col.py @@ -1,5 +1,5 @@ import numpy as np -from onnx import TensorProto, helper +from onnx import helper import finn.util.basic as util from finn.core.datatype import DataType @@ -184,20 +184,11 @@ def make_shape_compatible_op(self, model): ofm_dim_h = compute_conv_output_dim(ifm_dim_h, k_h, stride_h, pad_h, dilation_h) ofm_dim_w = compute_conv_output_dim(ifm_dim_w, k_w, stride_w, pad_w, dilation_w) - # implement tensor with correct shape - values = np.random.randn(1, ofm_dim_h, ofm_dim_w, k_h * k_w * ifm_ch).astype( - np.float32 - ) return helper.make_node( - "Constant", + "RandomNormal", inputs=[], outputs=[self.onnx_node.output[0]], - value=helper.make_tensor( - name="const_tensor", - data_type=TensorProto.FLOAT, - dims=values.shape, - vals=values.flatten().astype(float), - ), + shape=[1, ofm_dim_h, ofm_dim_w, k_h * k_w * ifm_ch], ) def infer_node_datatype(self, model): diff --git a/src/finn/custom_op/general/maxpoolnhwc.py b/src/finn/custom_op/general/maxpoolnhwc.py index ad40f38..23ffdce 100644 --- a/src/finn/custom_op/general/maxpoolnhwc.py +++ b/src/finn/custom_op/general/maxpoolnhwc.py @@ -63,18 +63,11 @@ def make_shape_compatible_op(self, model): ho = compute_pool_output_dim(hi, kernel_shape[0], strides[0], pads[0]) wo = compute_pool_output_dim(wi, kernel_shape[1], strides[1], pads[2]) oshape = (n, ho, wo, c) - # implement tensor with correct shape - values = np.random.randn(*oshape).astype(np.float32) return helper.make_node( - "Constant", + "RandomNormal", inputs=[], outputs=[self.onnx_node.output[0]], - value=helper.make_tensor( - name="const_tensor", - data_type=TensorProto.FLOAT, - dims=values.shape, - vals=values.flatten().astype(float), - ), + shape=list(oshape), ) def infer_node_datatype(self, model): diff --git a/src/finn/custom_op/general/quantavgpool2d.py b/src/finn/custom_op/general/quantavgpool2d.py index 99a9d43..014f7a2 100644 --- a/src/finn/custom_op/general/quantavgpool2d.py +++ b/src/finn/custom_op/general/quantavgpool2d.py @@ -71,18 +71,11 @@ def make_shape_compatible_op(self, model): ho = compute_pool_output_dim(hi, k, s) wo = compute_pool_output_dim(wi, k, s) oshape = (n, ho, wo, c) - # implement tensor with correct shape - values = np.random.randn(*oshape).astype(np.float32) return helper.make_node( - "Constant", + "RandomNormal", inputs=[], outputs=[node.output[0]], - value=helper.make_tensor( - name="const_tensor", - data_type=TensorProto.FLOAT, - dims=values.shape, - vals=values.flatten().astype(float), - ), + shape=list(oshape), ) else: