diff --git a/tools/bitme.py b/tools/bitme.py index 575144ba..bba1389a 100755 --- a/tools/bitme.py +++ b/tools/bitme.py @@ -353,19 +353,30 @@ def get_bitwuzla(self, tm): class Values: def __init__(self, sid_line): + self.sid_line = sid_line self.number_of_values = 0 self.values = {} - self.number_of_constraints = 0 - self.constraints = {} + + def OR(arg1, arg2): + assert arg1 != Constant.false or arg2 != Constant.false + if arg1 == Constant.true or arg2 == Constant.true: + return Constant.true + else: + return Logical(next_nid(), OP_OR, Bool.boolean, arg1, arg2, arg1.comment, arg1.line_no) + + def get_value(self): + assert self.number_of_values == 1 + return list(self.values)[0] def set_value(self, constraint, value): assert self.sid_line == value.sid_line + assert constraint != Constant.false if value not in self.values: self.number_of_values += 1 - self.values[value] = constraint - if constraint not in self.constraints: - self.number_of_constraints += 1 - self.constraints[constraint] = value + self.values[value] = constraint + else: + self.values[value] = Values.OR(constraint, self.values[value]) + return self class Expression(Line): def __init__(self, nid, sid_line, domain, comment, line_no): @@ -382,6 +393,10 @@ def get_domain(self): # filter out uninitialized states return [state for state in self.domain if state.init_line is not None] + def get_value(self): + # TODO: remove when done with domain propagation + return self + def get_z3_lambda(self): if self.z3_lambda is None: domain = self.get_domain() @@ -432,7 +447,9 @@ def get_mapped_array_expression_for(self, index): return self def get_values(self, step): - return self + if 0 not in self.cache_values: + self.cache_values[0] = Values(self.sid_line).set_value(Constant.true, self) + return self.cache_values[0] def get_z3(self): if self.z3 is None: @@ -619,7 +636,7 @@ def get_instance(self, step): def set_instance(self, instance, step): self.cache_instance[step] = instance if Instance.PROPAGATE: - self.cache_instance[step] = self.cache_instance[step].get_values(step) + self.cache_instance[step] = self.cache_instance[step].get_values(step).get_value() def get_z3_select(self, step): instance = self.get_instance(step) @@ -805,7 +822,7 @@ def copy(self, arg1_line): def get_values(self, step): if step not in self.cache_values: - arg1_value = self.arg1_line.get_values(step) + arg1_value = self.arg1_line.get_values(step).get_value() if isinstance(arg1_value, Constant): if self.op == 'sext': self.cache_values[step] = type(arg1_value)(next_nid(), self.sid_line, arg1_value.signed_value, self.comment, self.line_no) @@ -861,7 +878,7 @@ def copy(self, arg1_line): def get_values(self, step): if step not in self.cache_values: - arg1_value = self.arg1_line.get_values(step) + arg1_value = self.arg1_line.get_values(step).get_value() if isinstance(arg1_value, Constant): self.cache_values[step] = type(arg1_value)(next_nid(), self.sid_line, (arg1_value.value & 2**(self.u + 1) - 1) >> self.l, self.comment, self.line_no) @@ -910,30 +927,38 @@ def get_mapped_array_expression_for(self, index): arg1_line = self.arg1_line.get_mapped_array_expression_for(None) return self.copy(arg1_line) + def get_unaries(self, values, op): + results = Values(self.sid_line) + for value in values.values: + constraint = values.values[value] + if op == (lambda x: not x): + if value == Constant.false: + results.set_value(constraint, Constant.true) + else: + assert value == Constant.true + results.set_value(constraint, Constant.false) + else: + results.set_value(constraint, + type(value)(next_nid(), self.sid_line, + op(value.value) % 2**self.sid_line.size, self.comment, self.line_no)) + return results + def get_values(self, step): if step not in self.cache_values: arg1_value = self.arg1_line.get_values(step) - if isinstance(arg1_value, Constant): - value = arg1_value.value + if isinstance(arg1_value, Values): if self.op == 'not': if isinstance(self.sid_line, Bool): - if arg1_value == Constant.false: - self.cache_values[step] = Constant.true - else: - assert arg1_value == Constant.true - self.cache_values[step] = Constant.false - return self.cache_values[step] + assert arg1_value.number_of_values <= 2 + self.cache_values[step] = self.get_unaries(arg1_value, lambda x: not x) else: - value = ~value + self.cache_values[step] = self.get_unaries(arg1_value, lambda x: ~x) elif self.op == 'inc': - value = value + 1 + self.cache_values[step] = self.get_unaries(arg1_value, lambda x: x + 1) elif self.op == 'dec': - value = value - 1 + self.cache_values[step] = self.get_unaries(arg1_value, lambda x: x - 1) elif self.op == 'neg': - value = -value - - self.cache_values[step] = type(arg1_value)(next_nid(), self.sid_line, - value % 2**self.sid_line.size, self.comment, self.line_no) + self.cache_values[step] = self.get_unaries(arg1_value, lambda x: -x) else: self.cache_values[step] = self.copy(arg1_value) return self.cache_values[step] @@ -999,8 +1024,8 @@ def get_mapped_array_expression_for(self, index): def get_values(self, step): if step not in self.cache_values: - arg1_value = self.arg1_line.get_values(step) - arg2_value = self.arg2_line.get_values(step) + arg1_value = self.arg1_line.get_values(step).get_value() + arg2_value = self.arg2_line.get_values(step).get_value() self.cache_values[step] = self.copy(arg1_value, arg2_value) return self.cache_values[step] @@ -1335,9 +1360,9 @@ def __str__(self): def get_values(self, step): if step not in self.cache_values: - arg1_value = self.arg1_line.get_values(step) - arg2_value = self.arg2_line.get_values(step) - arg3_value = self.arg3_line.get_values(step) + arg1_value = self.arg1_line.get_values(step).get_value() + arg2_value = self.arg2_line.get_values(step).get_value() + arg3_value = self.arg3_line.get_values(step).get_value() self.cache_values[step] = self.copy(arg1_value, arg2_value, arg3_value) return self.cache_values[step]