diff --git a/.github/workflows/pypi-publish.yml b/.github/workflows/pypi-publish.yml
index d29e56c7..eab3e434 100644
--- a/.github/workflows/pypi-publish.yml
+++ b/.github/workflows/pypi-publish.yml
@@ -11,13 +11,16 @@ jobs:
   deploy:
     runs-on: ubuntu-latest
+    strategy:
+      matrix:
+        python-version: ["3.7", "3.8", "3.9"]
     steps:
     - uses: actions/checkout@v2
     - name: Set up Python
      uses: actions/setup-python@v2
      with:
-        python-version: '3.x'
+        python-version: ${{ matrix.python-version }}
    - name: Install dependencies
      run: |
        python -m pip install --upgrade pip
diff --git a/.gitignore b/.gitignore
index a81c8ee1..e4835320 100644
--- a/.gitignore
+++ b/.gitignore
@@ -136,3 +136,6 @@ dmypy.json
 
 # Cython debug symbols
 cython_debug/
+
+# IDE stuff
+.vscode
diff --git a/setup.cfg b/setup.cfg
index 50d574eb..95ba8daf 100644
--- a/setup.cfg
+++ b/setup.cfg
@@ -52,6 +52,10 @@ install_requires =
     onnx==1.11.0
     onnxruntime==1.11.1
     toposort>=1.5.0
+    tf2onnx==1.9.2
+    QKeras==0.9.0
+    tensorflow==2.7.0
+    numpy
 
 [options.packages.find]
diff --git a/src/qonnx/__init__.py b/src/qonnx/__init__.py
index e69de29b..01d009c8 100644
--- a/src/qonnx/__init__.py
+++ b/src/qonnx/__init__.py
@@ -0,0 +1 @@
+from qonnx import converters  # noqa: F401
diff --git a/src/qonnx/converters/__init__.py b/src/qonnx/converters/__init__.py
new file mode 100644
index 00000000..8b113eef
--- /dev/null
+++ b/src/qonnx/converters/__init__.py
@@ -0,0 +1 @@
+from .keras import from_keras  # noqa: F401
diff --git a/src/qonnx/converters/keras.py b/src/qonnx/converters/keras.py
new file mode 100644
index 00000000..3c4adc49
--- /dev/null
+++ b/src/qonnx/converters/keras.py
@@ -0,0 +1,235 @@
+import onnx
+import tensorflow as tf
+import tf2onnx
+from collections import OrderedDict
+from qkeras.utils import REGISTERED_LAYERS as QKERAS_LAYERS
+
+from qonnx.core.modelwrapper import ModelWrapper
+from qonnx.util.cleanup import cleanup_model
+
+from .qkeras.onnx import get_qkeras_onnx_handlers
+from .qkeras.qlayers import extract_quantizers_from_layer
+
+_unsupported_layers = [
+    # These require some extra work
+    "QBatchNormalization",
+    "QConv2DBatchnorm",
+    "QDepthwiseConv2DBatchnorm",
+]
+
+# Skip remove_identity optimizer
+del tf2onnx.optimizer._optimizers["remove_identity"]
+
+
+def add_value_info_for_constants(model: onnx.ModelProto):
+    """
+    Currently onnx.shape_inference doesn't use the shape of initializers, so add
+    that info explicitly as ValueInfoProtos.
+    Mutates the model.
+    Args:
+        model: The ModelProto to update.
+ """ + # All (top-level) constants will have ValueInfos before IRv4 as they are all inputs + if model.ir_version < 4: + return model + + def add_const_value_infos_to_graph(graph: onnx.GraphProto): + inputs = {i.name for i in graph.input} + existing_info = {vi.name: vi for vi in graph.value_info} + for init in graph.initializer: + # Check it really is a constant, not an input + if init.name in inputs: + continue + + # The details we want to add + elem_type = init.data_type + shape = init.dims + + # Get existing or create new value info for this constant + vi = existing_info.get(init.name) + if vi is None: + vi = graph.value_info.add() + vi.name = init.name + + # Even though it would be weird, we will not overwrite info even if it doesn't match + tt = vi.type.tensor_type + if tt.elem_type == onnx.TensorProto.UNDEFINED: + tt.elem_type = elem_type + if not tt.HasField("shape"): + # Ensure we set an empty list if the const is scalar (zero dims) + tt.shape.dim.extend([]) + for dim in shape: + tt.shape.dim.add().dim_value = dim + + # Handle subgraphs + for node in graph.node: + for attr in node.attribute: + # Ref attrs refer to other attrs, so we don't need to do anything + if attr.ref_attr_name != "": + continue + + if attr.type == onnx.AttributeProto.GRAPH: + add_const_value_infos_to_graph(attr.g) + if attr.type == onnx.AttributeProto.GRAPHS: + for g in attr.graphs: + add_const_value_infos_to_graph(g) + + add_const_value_infos_to_graph(model.graph) + return model + + +def _is_qkeras_model(model): + def iterate_model(model): + for layer in model.layers: + if isinstance(layer, tf.keras.Model): + found_qkeras = iterate_model(layer) + if found_qkeras: + return True + elif layer.__class__.__name__ in QKERAS_LAYERS: + return True + + return False + + return iterate_model(model) + + +def _check_supported_layers(model): + def iterate_model(model): + for layer in model.layers: + if isinstance(layer, tf.keras.Model): + iterate_model(layer) + elif layer.__class__.__name__ in _unsupported_layers: + raise Exception("Currently unsupported layer found in QKeras model: {}".format(layer.__class__.__name__)) + + iterate_model(model) + + +def _strip_qkeras_model(model): + quantizers = OrderedDict() + + def extract_quantizers(layer): + keras_cls_name, layer_cfg, layer_quantizers = extract_quantizers_from_layer(layer) + if layer_quantizers: + layer_quantizers = { + k: None if v == "None" else v for k, v in layer_quantizers.items() + } # Get rid of 'None' strings + layer_quantizers["input"] = layer.input.name + quantizers[layer.name] = layer_quantizers + + layer_class = tf.keras.layers.__dict__.get(keras_cls_name, None) + if layer_class is None: + raise Exception("Cannot create Keras layer from QKeras class {}".format(keras_cls_name)) + + return layer_class.from_config(layer_cfg) + + stripped_model = tf.keras.models.clone_model(model, clone_function=extract_quantizers) + stripped_model.set_weights(model.get_weights()) + return stripped_model, quantizers + + +def _convert_quantizers_to_nodes(onnx_model, quantizers_dict): + + for node_name, quantizers in quantizers_dict.items(): + print(node_name, quantizers) + + for n in onnx_model.graph.node: + print(n) + + return onnx_model.model + + +def from_keras( + model, + name="qkeras_to_qonnx_converted", + input_signature=None, + opset=None, + custom_ops=None, + custom_op_handlers=None, + custom_rewriter=None, + inputs_as_nchw=None, + extra_opset=None, + shape_override=None, + target=None, + large_model=False, + output_path=None, +): + """Convert a keras model to QONNX. 
+    The API follows the `from_keras` function of tf2onnx.
+
+    Args:
+        model: the tf.keras model we want to convert
+        name: name to give the converted model; also used to name the temporary .onnx file written during conversion
+        input_signature: a tf.TensorSpec or a numpy array defining the shape/dtype of the input
+        opset: the opset to be used for the ONNX model, default is the latest
+        custom_ops: if a model contains ops not recognized by onnx runtime,
+            you can tag these ops with a custom op domain so that the
+            runtime can still open the model. Type is a dictionary `{op name: domain}`.
+        target: list of workarounds applied to help certain platforms
+        custom_op_handlers: dictionary of custom ops handlers
+        custom_rewriter: list of custom graph rewriters
+        extra_opset: list of extra opsets, for example the opsets used by custom ops
+        shape_override: dict with inputs that override the shapes given by tensorflow
+        inputs_as_nchw: transpose inputs in list from nhwc to nchw
+        large_model: use the ONNX external tensor storage format
+        output_path: save model to output_path
+
+    Returns:
+        An ONNX model_proto and an external_tensor_storage dict.
+    """
+
+    assert not large_model  # TODO for now, let's focus only on models that don't store tensors externally
+
+    if _is_qkeras_model(model):
+        _check_supported_layers(model)
+        keras_model, quantizers = _strip_qkeras_model(model)
+    else:
+        keras_model, quantizers = model, {}
+
+    qkeras_op_handlers = get_qkeras_onnx_handlers(quantizers)
+
+    if custom_op_handlers is not None:
+        qkeras_op_handlers.update(custom_op_handlers)
+
+    model_proto, external_storage = tf2onnx.convert.from_keras(
+        keras_model,
+        input_signature=input_signature,
+        opset=opset,
+        custom_ops=qkeras_op_handlers,
+        custom_op_handlers=qkeras_op_handlers,
+        custom_rewriter=custom_rewriter,
+        inputs_as_nchw=inputs_as_nchw,
+        extra_opset=extra_opset,
+        shape_override=shape_override,
+        target=target,
+        large_model=large_model,
+        output_path=None,
+    )
+
+    onnx_model = ModelWrapper(model_proto)
+    # Set the first value of input/output shape to 1, currently this is set to unknown,
+    # because it is technically the batch size
+    if not (len(onnx_model.graph.input) == 1 and len(onnx_model.graph.output) == 1):
+        raise ValueError("QKeras to QONNX conversion only supports models with exactly one input and output.")
+    inp_shape = onnx_model.get_tensor_shape(onnx_model.graph.input[0].name)
+    out_shape = onnx_model.get_tensor_shape(onnx_model.graph.output[0].name)
+    inp_shape[0] = 1
+    out_shape[0] = 1
+    onnx_model.set_tensor_shape(onnx_model.graph.input[0].name, inp_shape)
+    onnx_model.set_tensor_shape(onnx_model.graph.output[0].name, out_shape)
+
+    # Set all Quant output tensors to float32 datatype, otherwise they are undefined and crash ONNX execution
+    qonnx_domain_ops = ["Quant", "Trunc", "BipolarQuant"]
+    for q_op_type in qonnx_domain_ops:
+        quant_nodes = onnx_model.get_nodes_by_op_type(q_op_type)
+        q_node_outputs = [qn.output[0] for qn in quant_nodes]
+        for tensor in onnx_model.graph.value_info:
+            if tensor.name in q_node_outputs:
+                tensor.type.tensor_type.elem_type = 1
+
+    onnx_model.save(f"tmp_{name}.onnx")
+
+    onnx_model = cleanup_model(onnx_model)
+    onnx_model.model = add_value_info_for_constants(onnx_model.model)
+
+    if output_path is not None:
+        onnx_model.save(output_path)
+
+    return onnx_model.model, external_storage
diff --git a/src/qonnx/converters/qkeras/__init__.py b/src/qonnx/converters/qkeras/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/src/qonnx/converters/qkeras/onnx.py b/src/qonnx/converters/qkeras/onnx.py
new file mode 100644
index 00000000..1ac05d16
--- /dev/null
+++ b/src/qonnx/converters/qkeras/onnx.py
@@ -0,0 +1,155 @@
+import numpy as np
+from tf2onnx.onnx_opset.math import DirectOp, MatMul
+from tf2onnx.onnx_opset.nn import BiasAdd, ConvOp
+
+from .quantizers import get_quant_params
+
+
+def get_qkeras_onnx_handlers(all_quantizers):
+    return {
+        "Conv2D": (conv2d_handler, ["Conv2D", all_quantizers]),
+        "MatMul": (dense_handler, ["MatMul", all_quantizers]),
+        "BiasAdd": (bias_handler, ["BiasAdd", all_quantizers]),
+        "Relu": (relu_handler, ["Relu", all_quantizers]),
+        "Identity": (identity_handler, ["Identity", all_quantizers]),
+    }
+
+
+def _extract_node_name(onnx_node, keras_quantizers):
+    onnx_name = onnx_node.name
+    keras_names = keras_quantizers.keys()
+    for keras_name in keras_names:
+        match = "/" + keras_name + "/"
+        if match in onnx_name:
+            return keras_name
+        elif "Identity" in onnx_name:
+            onnx_input = onnx_node.input[0]
+            keras_input = keras_quantizers[keras_name]["input"]
+            if keras_input in onnx_input:
+                return keras_name
+
+    return None
+
+
+def qlayer_handler(ctx, node, name, args):
+    all_quantizers = args[0]
+    keras_name = _extract_node_name(node, all_quantizers)
+    if not keras_name:
+        return  # Not found in quantizers, nothing to do
+    quantizers = all_quantizers[keras_name]
+
+    if quantizers.get("kernel_quantizer"):
+        weights = node.inputs[1].get_tensor_value(as_list=True)
+        quant_params = get_quant_params(weights, quantizers["kernel_quantizer"])
+        attr = quant_params["attributes"]
+        input_nodes = [node.input[1]]
+        for key in quant_params["inputs"].keys():
+            name = f"{node.name}_kernel_quantizer_{key}"
+            np_val = np.asarray(quant_params["inputs"][key])
+            ctx.make_const(name, np_val)
+            input_nodes.append(name)
+        ctx.insert_new_node_on_input(
+            node, "Quant", input_nodes, name=node.name + "_kernel_quantizer", **attr, domain="qonnx"
+        )
+
+    if quantizers.get("bias_quantizer") and len(node.input) == 3:
+        bias = node.inputs[2].get_tensor_value(as_list=True)
+        quant_params = get_quant_params(bias, quantizers["bias_quantizer"])
+        attr = quant_params["attributes"]
+        input_nodes = [node.input[2]]
+        for key in quant_params["inputs"].keys():
+            name = f"{node.name}_bias_quantizer_{key}"
+            np_val = np.asarray(quant_params["inputs"][key])
+            ctx.make_const(name, np_val)
+            input_nodes.append(name)
+        ctx.insert_new_node_on_input(node, "Quant", input_nodes, name=node.name + "_bias_quantizer", **attr, domain="qonnx")
+
+    if quantizers.get("activation"):
+        dtypes = [ctx.get_dtype(node.output[0])]
+        quant_params = get_quant_params(None, quantizers["activation"])
+        attr = quant_params["attributes"]
+        input_nodes = [node.output[0]]
+        for key in quant_params["inputs"].keys():
+            name = f"{node.name}_activation_quantizer_{key}"
+            np_val = np.asarray(quant_params["inputs"][key])
+            ctx.make_const(name, np_val)
+            input_nodes.append(name)
+        quant_act_node = ctx.make_node(
+            "Quant",
+            input_nodes,
+            dtypes=dtypes,
+            name=node.name + "_activation_quantizer",
+            attr=attr,
+            domain="qonnx",
+        )
+        ctx.insert_node_on_output(quant_act_node, node.output[0])
+
+
+def qact_handler(ctx, node, name, args):
+    all_quantizers = args[0]
+    keras_name = _extract_node_name(node, all_quantizers)
+    if not keras_name:
+        return  # Not found in quantizers, nothing to do
+    quantizers = all_quantizers[keras_name]
+
+    if quantizers.get("activation"):
+        dtypes = [ctx.get_dtype(node.output[0])]
+        quant_params = get_quant_params(None, quantizers["activation"])
+        attr = quant_params["attributes"]
+        input_nodes = [node.output[0]]
+        for key in quant_params["inputs"].keys():
f"{node.name}_activation_quantizer_{key}" + np_val = np.asarray(quant_params["inputs"][key]) + ctx.make_const(name, np_val) + input_nodes.append(name) + quant_act_node = ctx.make_node( + "Quant", + input_nodes, + dtypes=dtypes, + name=node.name + "_activation_quantizer", + attr=attr, + domain="qonnx", + ) + ctx.insert_node_on_output(quant_act_node, node.output[0]) + + +def conv2d_handler(ctx, node, name, args): + ConvOp.any_version(11, ctx, node) + qlayer_handler(ctx, node, name, args) + + +def dense_handler(ctx, node, name, args): + MatMul.version_1(ctx, node) + qlayer_handler(ctx, node, name, args) + + +def bias_handler(ctx, node, name, args): + BiasAdd.version_1(ctx, node) + + all_quantizers = args[0] + keras_name = _extract_node_name(node, all_quantizers) + if not keras_name: + return # Not found in quantizers, nothing to do + quantizers = all_quantizers[keras_name] + + if quantizers.get("bias_quantizer"): + bias = node.inputs[1].get_tensor_value(as_list=True) + quant_params = get_quant_params(bias, quantizers["bias_quantizer"]) + attr = quant_params["attributes"] + input_nodes = [node.input[1]] + for key in quant_params["inputs"].keys(): + name = f"{node.name}_activation_quantizer_{key}" + np_val = np.asarray(quant_params["inputs"][key]) + ctx.make_const(name, np_val) + input_nodes.append(name) + ctx.insert_new_node_on_input(node, "Quant", input_nodes, name=node.name + "_bias_quantizer", **attr, domain="qonnx") + + +def relu_handler(ctx, node, name, args): + DirectOp.version_1(ctx, node) + qact_handler(ctx, node, name, args) + + +def identity_handler(ctx, node, name, args): + DirectOp.version_1(ctx, node) + qact_handler(ctx, node, name, args) diff --git a/src/qonnx/converters/qkeras/qlayers.py b/src/qonnx/converters/qkeras/qlayers.py new file mode 100644 index 00000000..e3dc2da5 --- /dev/null +++ b/src/qonnx/converters/qkeras/qlayers.py @@ -0,0 +1,161 @@ +import qkeras + +# import tensorflow as tf +from qkeras.quantizers import BaseQuantizer +from qkeras.utils import REGISTERED_LAYERS as QKERAS_LAYERS + + +def extract_quantizers_from_layer(layer): + layer_class = layer.__class__.__name__ + if layer_class in QKERAS_LAYERS: + handler = handler_map.get(layer_class, None) + if handler: + return handler_map[layer_class](layer) + else: + return extract_generic(layer) + else: + return layer_class, layer.get_config(), None + + +def _is_keras_quantizer(quant): + try: + # If we can deserialize the quantizer, it means it belongs to qkeras + # TODO Since quantizer can be any callable, this should be more robust + quant_obj = qkeras.get_quantizer(quant) + return isinstance(quant_obj, BaseQuantizer) + except ValueError: + return False + + +def _extract_initializers(layer_cfg): + initializers = {} + for key, value in layer_cfg.items(): + if value is None: + continue + if "initializer" in key: + if value["class_name"] == "QInitializer": + # Save the quantized initializer and it's replacement + initializers[key] = (value, value["config"]["initializer"]) + + return initializers + + +def _extract_constraints(layer_cfg): + constraints = {} + for key, value in layer_cfg.items(): + if value is None: + continue + if "constraint" in key: + if value["class_name"] == "Clip": + # QKeras doesn't keep the original constraint in the config (bug?) 
+                constraints[key] = (value, None)
+
+    return constraints
+
+
+_activation_map = {
+    "quantized_bits": "linear",
+    "quantized_relu": "relu",
+    "binary": "linear",
+    "ternary": "linear",
+    "quantized_tanh": "tanh",
+    "quantized_sigmoid": "sigmoid",
+    "quantized_po2": "linear",
+    "quantized_relu_po2": "relu",
+}
+
+
+def _replace_activation(quant_act):
+    if quant_act is not None:
+        for act_key, act_val in _activation_map.items():
+            if act_key in quant_act:
+                return act_val
+
+    return "linear"
+
+
+def extract_qlayer(layer):
+    quantizers = layer.get_quantization_config()
+
+    keras_config = layer.get_config()
+
+    keras_config.pop("kernel_quantizer", None)
+    keras_config.pop("bias_quantizer", None)
+    keras_config.pop("kernel_range", None)
+    keras_config.pop("bias_range", None)
+
+    # Check if activation is quantized
+    if _is_keras_quantizer(keras_config["activation"]):
+        keras_config["activation"] = _replace_activation(quantizers["activation"])
+    else:
+        quantizers["activation"] = None
+
+    quant_init = _extract_initializers(keras_config)
+    for key, (quant, non_quant) in quant_init.items():
+        quantizers[key] = quant
+        keras_config[key] = non_quant
+
+    quant_const = _extract_constraints(keras_config)
+    for key, (quant, non_quant) in quant_const.items():
+        quantizers[key] = quant
+        keras_config[key] = non_quant
+
+    return layer.__class__.__name__[1:], keras_config, quantizers
+
+
+def extract_qact(layer):
+    # As of version 0.9.0, QKeras activations store quantizer config as a plain string, not a dict
+    # TODO complain to Hao about this inconsistency
+    quantizers = {"activation": layer.get_quantization_config()}
+
+    keras_config = layer.get_config()
+    keras_config["activation"] = _replace_activation(quantizers["activation"])
+
+    return "Activation", keras_config, quantizers
+
+
+# This is a remnant of a first attempt at creating a universal parser;
+# however, it got too complicated, so portions were extracted to separate
+# parsers. In the future, this should be removed.
+def extract_generic(layer):
+    get_quant_op = getattr(layer, "get_quantization_config", None)
+    if callable(get_quant_op):
+        quantizers = get_quant_op()
+    else:
+        quantizers = {}
+
+    keras_cls_name = layer.__class__.__name__[1:]  # Drop the 'Q' from the layer name
+
+    layer_cfg = layer.get_config()
+    # Try to remove quantizers from the config
+    non_quant = []
+    if layer.name in quantizers:
+        for quant_key, quant in quantizers.items():
+            try:
+                # If we can deserialize the quantizer, it means it belongs to qkeras
+                qkeras.get_quantizer(quant)
+                layer_cfg.pop(quant_key)
+            except ValueError:
+                # Otherwise it is not a quantizer (an activation, filter config, etc)
+                non_quant.append(quant_key)
+
+        for quant_key in non_quant:
+            quantizers.pop(quant_key)
+
+    # TODO Put proper activation in layer config
+
+    # TODO extract initializers and constraints
+
+    # Also remove deprecated 'kernel_range' and 'bias_range' from QConv1D/2D
+    layer_cfg.pop("kernel_range", None)
+    layer_cfg.pop("bias_range", None)
+
+    return keras_cls_name, layer_cfg, quantizers
+
+
+handler_map = {
+    "QConv1D": extract_qlayer,
+    "QConv2D": extract_qlayer,
+    "QDense": extract_qlayer,
+    "QActivation": extract_qact,
+}
diff --git a/src/qonnx/converters/qkeras/quantizers.py b/src/qonnx/converters/qkeras/quantizers.py
new file mode 100644
index 00000000..983cc997
--- /dev/null
+++ b/src/qonnx/converters/qkeras/quantizers.py
@@ -0,0 +1,116 @@
+import qkeras
+import six
+
+
+def get_quant_params(tensor, qkeras_quantizer):
+    if isinstance(qkeras_quantizer, str):
+        qkeras_quantizer = qkeras.get_quantizer(qkeras_quantizer)
+
+    return handler_map[qkeras_quantizer.__class__.__name__](tensor, qkeras_quantizer)
+
+
+def _get_scale_from_alpha(tensor, quantizer):
+    alpha = quantizer.get_config()["alpha"]
+
+    if alpha is None:
+        return 1
+    elif isinstance(alpha, six.string_types):
+        raise Exception(f"Cannot parse alpha = {alpha}.")
+    else:
+        return alpha
+
+
+def _get_quantizer_scale(tensor, quantizer):
+    # call the quantizer on the tensor to get its scale
+    import numpy as np
+
+    quantizer(np.array(tensor).astype(np.float32))
+    return quantizer.scale
+
+
+def convert_quantized_bits(tensor, quantizer):
+    config = quantizer.get_config()
+    signed = int(config["keep_negative"])
+    narrow = int(config["symmetric"])
+    qscale = _get_quantizer_scale(tensor, quantizer)
+    assert qscale == 1, "Non-unity alpha is not yet supported"
+    scale = 1.0 / 2 ** (int(config["bits"]) - int(config["integer"] + signed))
+    zero_point = 0
+    bit_width = int(config["bits"])
+    rounding_mode = "ROUND"
+
+    settings = {
+        "attributes": {"signed": signed, "narrow": narrow, "rounding_mode": rounding_mode},
+        "inputs": {"scale": scale, "zero_point": zero_point, "bit_width": bit_width},
+    }
+    return settings
+
+
+def convert_quantized_relu(tensor, quantizer):
+    config = quantizer.get_config()
+
+    signed = int(config["negative_slope"] != 0.0)
+    narrow = int(False)
+    scale = 1.0 / 2 ** (int(config["bits"]) - int(config["integer"] + signed))
+    zero_point = 0
+    bit_width = int(config["bits"])
+    rounding_mode = "ROUND"
+
+    settings = {
+        "attributes": {"signed": signed, "narrow": narrow, "rounding_mode": rounding_mode},
+        "inputs": {"scale": scale, "zero_point": zero_point, "bit_width": bit_width},
+    }
+    return settings
+
+
+def convert_binary(tensor, quantizer):
+    signed = 1
+    narrow = 1
+    qscale = _get_quantizer_scale(tensor, quantizer)
+    assert qscale == 1, "binary - non-unity alpha is not yet supported"
+    scale = 1
+    zero_point = 0
+    bit_width = 1
+    rounding_mode = "ROUND"
+
+    settings = {
+        "attributes": {"signed": signed, "narrow": narrow, "rounding_mode": rounding_mode},
+        "inputs": {"scale": scale, "zero_point": zero_point, "bit_width": bit_width},
+    }
+    return settings
+
+
+def convert_ternary(tensor, quantizer):
+    config = quantizer.get_config()
+    signed = 1
+    narrow = 1
+    qscale = _get_quantizer_scale(tensor, quantizer)
+    assert qscale == 1, "ternary - non-unity alpha is not yet supported"
+    # qkeras ternary quantizer has threshold parameter to change rounding point
+    # here we could scale such that normal 'ROUND' op gives the same result, but doesn't work with re-scaling
+    t = config["threshold"]
+    if t is None:
+        ternary = qkeras.ternary()
+        t = ternary.default_threshold
+    assert t == 0.5, "ternary - only threshold 0.5 is supported"
+    # note that if assertions fail, Quant node is not inserted, but model is still converted
+    # this seems to be unexpected behavior
+    scale = 1.0
+    zero_point = 0
+    bit_width = 2
+    rounding_mode = "ROUND"
+
+    settings = {
+        "attributes": {"signed": signed, "narrow": narrow, "rounding_mode": rounding_mode},
+        "inputs": {"scale": scale, "zero_point": zero_point, "bit_width": bit_width},
+    }
+    return settings
+
+
+handler_map = {
+    "quantized_bits": convert_quantized_bits,
+    "quantized_relu": convert_quantized_relu,
+    "binary": convert_binary,
+    "ternary": convert_ternary,
+}
diff --git a/tests/test_keras_convert.py b/tests/test_keras_convert.py
new file mode 100644
index 00000000..d0799f05
--- /dev/null
+++ b/tests/test_keras_convert.py
@@ -0,0 +1,290 @@
+import pytest
+
+import numpy as np
+import onnx
+import tensorflow as tf
+
+# from numpy.testing import assert_allclose
+from qkeras import QActivation, QConv2D, QDense, binary, quantized_bits, quantized_relu, ternary
+
+# from tensorflow.keras import backend as K
+from tensorflow.keras.layers import Activation, Conv2D, Dense, Flatten, Input
+from tensorflow.keras.models import Model
+
+import qonnx
+import qonnx.core.onnx_exec as oxe
+from qonnx.core.modelwrapper import ModelWrapper
+from qonnx.transformation.infer_shapes import InferShapes
+
+act_quantizers = [
+    quantized_relu(8),
+    quantized_relu(8, 4),
+    quantized_relu(4),
+    quantized_relu(4, 4),
+    quantized_bits(8, 4, 0, alpha=1),
+    quantized_bits(8, 4, 1, alpha=1),
+    quantized_bits(8, 8, 0, alpha=1),
+    quantized_bits(8, 8, 1, alpha=1),
+    quantized_bits(4, 4, 0, alpha=1),
+    quantized_bits(4, 4, 1, alpha=1),
+    quantized_bits(4, 0, 0, alpha=1),
+    quantized_bits(4, 0, 1, alpha=1),
+    quantized_bits(4, 2, 0, alpha=1),
+    quantized_bits(2, 2, 1, alpha=1),
+    quantized_bits(2, 1, 1, alpha=1),
+    ternary(alpha=1, threshold=0.5),
+    binary(alpha=1),
+]
+act_quantizers_ids = list(range(len(act_quantizers)))
+
+
+@pytest.mark.parametrize("quantizer", act_quantizers, ids=act_quantizers_ids)
+def test_qkeras_qactivation(quantizer, request):
+    x = x_in = Input((16), name="input")
+    x = QActivation(activation=quantizer, name="act_0")(x)
+    model = Model(inputs=[x_in], outputs=[x])
+    x_test = np.random.uniform(low=-1.0, high=1.0, size=(1, 16)).astype(dtype=np.float32)
+    y_qkeras = model.predict(x_test)
+
+    onnx_model, external_storage = qonnx.converters.from_keras(model, "test_qkeras_conversion", opset=9)
+    assert external_storage is None
+    model_path = f"model_test_qkeras_qactivation_{request.node.callspec.id}.onnx"
+    onnx.save(onnx_model, model_path)
+
+    onnx_model = ModelWrapper(model_path)
+    onnx_model = onnx_model.transform(InferShapes())
+
+    idict = {onnx_model.graph.input[0].name: x_test}
+    odict = oxe.execute_onnx(onnx_model, idict, True)
+    y_qonnx = odict[onnx_model.graph.output[0].name]
+
+    np.testing.assert_allclose(y_qkeras, y_qonnx, rtol=1e-5, atol=1e-5)
+
+
+# pairs of quantizers for kernel and bias
+kb_quantizers = [
+    (quantized_bits(8, 4, 0, alpha=1), quantized_bits(8, 4, 0, alpha=1)),
+    (quantized_bits(8, 4, 1, alpha=1), quantized_bits(8, 4, 1, alpha=1)),
+    (quantized_bits(8, 8, 0, alpha=1), quantized_bits(8, 8, 0, alpha=1)),
+    (quantized_bits(8, 8, 1, alpha=1), quantized_bits(8, 8, 1, alpha=1)),
+    (quantized_bits(4, 4, 0, alpha=1), quantized_bits(8, 8, 0, alpha=1)),
+    (quantized_bits(4, 4, 1, alpha=1), quantized_bits(8, 8, 1, alpha=1)),
+    (quantized_bits(4, 0, 0, alpha=1), quantized_bits(8, 0, 0, alpha=1)),
+    (quantized_bits(4, 0, 1, alpha=1), quantized_bits(8, 0, 1, alpha=1)),
+    (quantized_bits(4, 2, 0, alpha=1), quantized_bits(8, 2, 0, alpha=1)),
+    (quantized_bits(2, 2, 1, alpha=1), quantized_bits(2, 2, 1, alpha=1)),
+    (quantized_bits(2, 1, 1, alpha=1), quantized_bits(2, 1, 1, alpha=1)),
+    (ternary(alpha=1, threshold=0.5), quantized_bits(4, 4)),
+    (binary(alpha=1), quantized_bits(4, 4)),
+]
+kb_quantizers_ids = list(range(len(kb_quantizers)))
+
+
+@pytest.mark.parametrize("quantizers", kb_quantizers, ids=kb_quantizers_ids)
+def test_qkeras_qconv2d(quantizers, request):
+    kq, bq = quantizers
+    k_ini = tf.keras.initializers.RandomUniform(minval=kq.min(), maxval=kq.max())
+    b_ini = tf.keras.initializers.RandomUniform(minval=bq.min(), maxval=bq.max())
+    x = x_in = Input((28, 28, 3), name="input")
+    x = QConv2D(
+        32,
+        (2, 2),
+        strides=(2, 2),
+        kernel_quantizer=kq,
+        bias_quantizer=bq,
+        kernel_initializer=k_ini,
+        bias_initializer=b_ini,
+        name="conv2d_0",
+    )(x)
+    model = Model(inputs=[x_in], outputs=[x])
+
+    x_test = np.random.uniform(low=-1.0, high=1.0, size=(1, 28, 28, 3)).astype(dtype=np.float32)
+    y_qkeras = model.predict(x_test)
+
+    onnx_model, external_storage = qonnx.converters.from_keras(model, "test_qkeras_conversion", opset=9)
+    assert external_storage is None
+    model_path = f"model_test_qkeras_qconv2d_{request.node.callspec.id}.onnx"
+    onnx.save(onnx_model, model_path)
+
+    onnx_model = ModelWrapper(model_path)
+    onnx_model = onnx_model.transform(InferShapes())
+
+    idict = {onnx_model.graph.input[0].name: x_test}
+    odict = oxe.execute_onnx(onnx_model, idict, True)
+    y_qonnx = odict[onnx_model.graph.output[0].name]
+
+    np.testing.assert_allclose(y_qkeras, y_qonnx, rtol=1e-4, atol=1e-4)
+
+
+@pytest.mark.parametrize("quantizers", kb_quantizers, ids=kb_quantizers_ids)
+def test_qkeras_qdense(quantizers, request):
+    kq, bq = quantizers
+    # Initialize the kernel & bias to RandomUniform within the range of the quantizers
+    k_ini = tf.keras.initializers.RandomUniform(minval=kq.min(), maxval=kq.max())
+    b_ini = tf.keras.initializers.RandomUniform(minval=bq.min(), maxval=bq.max())
+    x = x_in = Input((16), name="input")
+    x = QDense(
+        32,
+        kernel_quantizer=kq,
+        bias_quantizer=bq,
+        kernel_initializer=k_ini,
+        bias_initializer=b_ini,
+        name="dense_0",
+    )(x)
+    model = Model(inputs=[x_in], outputs=[x])
+    x_test = np.random.uniform(low=-1.0, high=1.0, size=(1, 16)).astype(dtype=np.float32)
+    y_qkeras = model.predict(x_test)
+
+    onnx_model, external_storage = qonnx.converters.from_keras(model, "test_qkeras_conversion", opset=9)
+    assert external_storage is None
+    model_path = f"model_test_qkeras_qdense_{request.node.callspec.id}.onnx"
+    onnx.save(onnx_model, model_path)
+
+    onnx_model = ModelWrapper(model_path)
+    onnx_model = onnx_model.transform(InferShapes())
+
+    idict = {onnx_model.graph.input[0].name: x_test}
+    odict = oxe.execute_onnx(onnx_model, idict, True)
+    y_qonnx = odict[onnx_model.graph.output[0].name]
+
+    np.testing.assert_allclose(y_qkeras, y_qonnx, rtol=1e-4, atol=1e-4)
+
+
+def test_keras_conv2d_conversion():
+    x = x_in = Input((28, 28, 1), name="input")
+    x = Conv2D(32, (2, 2), strides=(2, 2), name="conv2d_0_m")(x)
+    x = Activation("relu", name="act0_m")(x)
+    x = Conv2D(64, (3, 3), strides=(2, 2), name="conv2d_1_m", activation="relu")(x)
+    x = Conv2D(64, (2, 2), strides=(2, 2), name="conv2d_2_m")(x)
+    x = Activation("relu", name="act2_m")(x)
+    x = Flatten(name="flatten")(x)
+    x = Dense(10, bias_initializer="ones", name="dense")(x)
+    x = Activation("softmax", name="softmax")(x)
+
+    model = Model(inputs=[x_in], outputs=[x])
+
+    x_test = np.random.uniform(low=-1.0, high=1.0, size=(1, 28, 28, 1)).astype(dtype=np.float32)
+    y_qkeras = model.predict(x_test)
+
+    onnx_model, external_storage = qonnx.converters.from_keras(model, "test_keras_conv2d_conversion")
+    onnx_model = onnx.shape_inference.infer_shapes(onnx_model)
+
+    assert external_storage is None
+    onnx.save(onnx_model, "model_test_keras_conv2d_conversion.onnx")
+
+    onnx_model = ModelWrapper("model_test_keras_conv2d_conversion.onnx")
+    onnx_model = onnx_model.transform(InferShapes())
+
+    idict = {onnx_model.graph.input[0].name: x_test}
+    odict = oxe.execute_onnx(onnx_model, idict, True)
+    y_qonnx = odict[onnx_model.graph.output[0].name]
+
+    np.testing.assert_allclose(y_qkeras, y_qonnx, rtol=1e-5, atol=1e-5)
+
+
+def test_keras_dense_conversion():
+    ini = tf.keras.initializers.RandomUniform(minval=-1.0, maxval=1.0)
+    x = x_in = Input((15), name="input")
+    x = Dense(10, kernel_initializer=ini, bias_initializer=ini, name="dense1")(x)
+    x = Activation("relu", name="act0_m")(x)
+    x = Dense(10, kernel_initializer=ini, bias_initializer=ini, activation="relu", name="dense2")(x)
+    x = Dense(10, kernel_initializer=ini, bias_initializer=ini, name="dense3")(x)
+    x = Activation("softmax", name="softmax")(x)
+
+    model = Model(inputs=[x_in], outputs=[x])
+
+    x_test = np.random.uniform(low=-1.0, high=1.0, size=(1, 15)).astype(dtype=np.float32)
+    y_qkeras = model.predict(x_test)
+
+    onnx_model, external_storage = qonnx.converters.from_keras(model, "test_keras_dense_conversion")
+    assert external_storage is None
+    onnx.save(onnx_model, "model_test_keras_dense_conversion.onnx")
+
+    onnx_model = ModelWrapper("model_test_keras_dense_conversion.onnx")
+    onnx_model = onnx_model.transform(InferShapes())
+
+    idict = {onnx_model.graph.input[0].name: x_test}
+    odict = oxe.execute_onnx(onnx_model, idict, True)
+    y_qonnx = odict[onnx_model.graph.output[0].name]
+
+    np.testing.assert_allclose(y_qkeras, y_qonnx, rtol=1e-5, atol=1e-5)
+
+
+def test_qkeras_qconv2d_conversion():
+    ini = tf.keras.initializers.RandomUniform(minval=-1.0, maxval=1.0)
+    x = x_in = Input((28, 28, 1), name="input")
+    x = QConv2D(
+        32,
+        (2, 2),
+        strides=(2, 2),
+        kernel_quantizer=binary(alpha=1.0),
+        bias_quantizer=quantized_bits(4, 0, 1),
+        activation=quantized_bits(6, 2, 1, alpha=1.0),
+        kernel_initializer=ini,
+        bias_initializer=ini,
+        name="conv2d_0_m",
+    )(x)
+    x = QActivation("quantized_relu(6,2,1)", name="act0_m")(x)
+    x = QConv2D(
+        64,
+        (3, 3),
+        strides=(2, 2),
+        kernel_quantizer=ternary(alpha=1.0),
+        bias_quantizer=quantized_bits(4, 0, 1),
+        kernel_initializer=ini,
+        bias_initializer=ini,
+        name="conv2d_1_m",
+        activation=quantized_relu(6, 3, 1),
+    )(x)
+    x = QConv2D(
+        64,
+        (2, 2),
+        strides=(2, 2),
+        kernel_quantizer=quantized_bits(6, 2, 1, alpha=1.0),
+        kernel_initializer=ini,
+        use_bias=False,  # Let's try this one without bias to see if that trips up the converter
+        name="conv2d_2_m",
+    )(x)
+    x = QActivation("quantized_relu(6,4,1)", name="act2_m")(x)
+    x = QConv2D(
+        64,
+        (2, 2),
+        strides=(2, 2),
+        kernel_quantizer=quantized_bits(6, 2, 1, alpha=1.0),
+        kernel_initializer=ini,
+        use_bias=False,  # Let's try this one without bias to see if that trips up the converter
+        name="conv2d_3_m",
+    )(x)
+    x = QActivation("quantized_bits(4,4,0,alpha=1)", name="act3_m")(x)
+    x = Flatten(name="flatten")(x)
+    x = QDense(
+        10,
+        kernel_quantizer=quantized_bits(6, 2, 1, alpha=1.0),
+        bias_quantizer=quantized_bits(4, 0, 1),
+        kernel_initializer=ini,
+        bias_initializer=ini,
+        name="dense",
+    )(x)
+    x = Activation("softmax", name="softmax")(x)
+
+    model = Model(inputs=[x_in], outputs=[x])
+
+    x_test = np.random.uniform(low=-1.0, high=1.0, size=(1, 28, 28, 1)).astype(dtype=np.float32)
+    y_qkeras = model.predict(x_test)
+
+    onnx_model, external_storage = qonnx.converters.from_keras(model, "test_qkeras_qconv2d_conversion", opset=9)
+    assert external_storage is None
+    onnx.save(onnx_model, "model_test_qkeras_qconv2d_conversion.onnx")
+
+    onnx_model = ModelWrapper("model_test_qkeras_qconv2d_conversion.onnx")
+    onnx_model = onnx_model.transform(InferShapes())
+
+    idict = {onnx_model.graph.input[0].name: x_test}
+    odict = oxe.execute_onnx(onnx_model, idict, True)
+    y_qonnx = odict[onnx_model.graph.output[0].name]
+
+    np.testing.assert_array_equal(y_qkeras, y_qonnx)
+
+
+if __name__ == "__main__":
+    pytest.main([__file__])
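
For reference, a minimal usage sketch of the converter this diff adds, condensed from the tests in tests/test_keras_convert.py. The layer choice, quantizer settings, and output file name are illustrative assumptions; only the `qonnx.converters.from_keras` call and its return values (a model proto plus an external-tensor-storage dict, which is None since `large_model` is not supported yet) follow the code above. Note that the converter also writes a temporary `tmp_<name>.onnx` as a side effect.

import onnx
from qkeras import QDense, quantized_bits
from tensorflow.keras.layers import Input
from tensorflow.keras.models import Model

import qonnx

# A small QKeras model (illustrative; any supported QKeras or plain Keras model works)
x = x_in = Input((16), name="input")
x = QDense(
    8,
    kernel_quantizer=quantized_bits(4, 0, 1, alpha=1),
    bias_quantizer=quantized_bits(4, 0, 1, alpha=1),
    name="dense_0",
)(x)
model = Model(inputs=[x_in], outputs=[x])

# Strips the QKeras quantizers, exports the plain Keras model through tf2onnx,
# and inserts the quantizers back as Quant nodes in the "qonnx" domain.
onnx_model, external_storage = qonnx.converters.from_keras(model, name="example")
assert external_storage is None  # always None here, since large_model is rejected
onnx.save(onnx_model, "example.onnx")  # assumed output path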