diff --git a/src/finn/transformation/create_generic_partitions.py b/src/finn/transformation/create_generic_partitions.py index 5e4bb80..67da854 100755 --- a/src/finn/transformation/create_generic_partitions.py +++ b/src/finn/transformation/create_generic_partitions.py @@ -124,7 +124,7 @@ def apply(self, model): assert ( self.partitioning(node) != partition_id ), """cycle-free graph violated: partition depends on itself""" - print(node) + # print(node) predecessors = model.find_direct_predecessors(node) if predecessors is not None: next_to_check.extend(predecessors) @@ -141,11 +141,25 @@ def apply(self, model): for o in p_out_vi: p_model.graph.output.append(o) - # remove redundant input value_info entries + # remove redundant input and output value_info entries for i in p_in_vi: - if i in p_model.graph.value_info: + # the tensor can be both an input and value_info, so we also have to + # ensure that the tensor is not a relevant value_info before removing + if ( + i in p_model.graph.value_info + and p_model.find_producer(i.name) is None + ): p_model.graph.value_info.remove(i) + for o in p_out_vi: + # the tensor can both an output and value_info, so we also have to + # ensure that the tensor is not a relevant value_info before removing + if ( + o in p_model.graph.value_info + and p_model.find_consumers(o.name) is None + ): + p_model.graph.value_info.remove(o) + # save partition model p_model_filename = ( self.partition_dir + "/partition_" + str(partition_id) + ".onnx"