From a092b47c59e6007d4488e076f5e0422fcf68bef5 Mon Sep 17 00:00:00 2001 From: Christoph Kirsch Date: Wed, 26 Feb 2025 19:54:44 +0100 Subject: [PATCH] Introducing Algebraic Normal Forms --- tools/bitme.py | 116 +++++++++++++++++++++++++++++++++++++++++++++++-- 1 file changed, 113 insertions(+), 3 deletions(-) diff --git a/tools/bitme.py b/tools/bitme.py index 1918dbf4..aa815779 100755 --- a/tools/bitme.py +++ b/tools/bitme.py @@ -758,6 +758,112 @@ def get_expression(self): arg_line.comment, arg_line.line_no) return disjunction_line +class ANF(Inputs): + total_number_of_values = 0 + + def __init__(self, var_line): + self.var_line = var_line + self.values = {} + + def __str__(self): + string = "" + for value in self.values: + if string: + string += " X " + string += f"{value}" + if self.values[value] is not Constant.true: + string += f" & {self.values[value]}" + return f"{{{string}}}" + + def __hash__(self): + return id(self) + + def __eq__(self, inputs): + return type(self) is type(inputs) and self.var_line is inputs.var_line and self.values == inputs.values + + def set_value(self, value, constraint): + assert value not in self.values + if constraint is not Constant.false: + self.values[value] = constraint + ANF.total_number_of_values += 1 + return self + + def reduce(inputs): + if isinstance(inputs, ANF): + if not inputs.values: + return Constant.false + elif len(inputs.values) == 2**inputs.var_line.sid_line.size: + for constraint in inputs.values.values(): + if constraint is not Constant.true: + return inputs + return Constant.true + return inputs + + def conjunction(self, inputs): + assert isinstance(inputs, ANF) + if self.var_line > inputs.var_line: + return inputs.conjunction(self) + else: + result = ANF(self.var_line) + if self.var_line < inputs.var_line: + for value in self.values: + result.set_value(value, Inputs.conjunction(self.values[value], inputs)) + else: + assert self.var_line is inputs.var_line + for value in self.values: + if value in inputs.values: + result.set_value(value, + Inputs.conjunction(self.values[value], inputs.values[value])) + return ANF.reduce(result) + + def disjunction(self, inputs): + assert isinstance(inputs, ANF) + if self.var_line > inputs.var_line: + return inputs.disjunction(self) + else: + result = ANF(self.var_line) + if self.var_line < inputs.var_line: + for value in range(2**self.var_line.sid_line.size): + if value in self.values: + result.set_value(value, + Inputs.disjunction(self.values[value], inputs)) + else: + result.set_value(value, inputs) + else: + assert self.var_line is inputs.var_line + for value in self.values: + if value in inputs.values: + result.set_value(value, + Inputs.disjunction(self.values[value], inputs.values[value])) + else: + result.set_value(value, self.values[value]) + for value in inputs.values: + if value not in self.values: + result.set_value(value, inputs.values[value]) + return ANF.reduce(result) + + def get_expression(self): + ANF_line = Constant.true + for value in self.values: + constraint_line = Comparison(next_nid(), OP_EQ, Bool.boolean, + self.var_line, + Constd(next_nid(), self.var_line.sid_line, value, + self.var_line.comment, self.var_line.line_no), + self.var_line.comment, self.var_line.line_no) + if self.values[value] is not Constant.true: + constraint_line = Logical(next_nid(), OP_AND, Bool.boolean, + constraint_line, + self.values[value].get_expression(), + constraint_line.comment, constraint_line.line_no) + if ANF_line is Constant.true: + ANF_line = constraint_line + else: + ANF_line = Logical(next_nid(), OP_XOR, Bool.boolean, + constraint_line, + ANF_line, + constraint_line.comment, constraint_line.line_no) + return ANF_line + class Values: total_number_of_values = 0 @@ -1311,6 +1417,10 @@ def __init__(self, nid, sid_line, domain, symbol, comment, line_no, index): Array.number_of_variable_arrays += 1 self.new_mapped_array(index) + def __lt__(self, variable): + # ordering variables for constructing model input + return self.nid < variable.nid + def new_mapped_array(self, index): self.index = index if index is not None: @@ -1341,12 +1451,12 @@ def get_values(self, step): if isinstance(self.sid_line, Bitvector) and self.sid_line.size <= Instance.PROPAGATE: self.cache_values[0] = Values(self.sid_line) if isinstance(self.sid_line, Bool): - self.cache_values[0].set_value(self.sid_line, False, Literal(self, 0, 0)) - self.cache_values[0].set_value(self.sid_line, True, Literal(self, 1, 1)) + self.cache_values[0].set_value(self.sid_line, False, ANF(self).set_value(0, Constant.true)) + self.cache_values[0].set_value(self.sid_line, True, ANF(self).set_value(1, Constant.true)) else: for value in range(2**self.sid_line.size): self.cache_values[0].set_value(self.sid_line, value, - Literal(self, value, value)) + ANF(self).set_value(value, Constant.true)) else: self.cache_values[0] = self return self.cache_values[0]