Skip to content

Commit

Permalink
Also using dictionaries for clauses, caching conjunctions of clauses
Browse files Browse the repository at this point in the history
  • Loading branch information
ckirsch committed Feb 10, 2025
1 parent a014236 commit 681a80e
Showing 1 changed file with 61 additions and 35 deletions.
96 changes: 61 additions & 35 deletions tools/bitme.py
Original file line number Diff line number Diff line change
Expand Up @@ -393,8 +393,7 @@ def create_literal(var_line, value):
return Literal.literals[var_line][value]

def is_conflicting(self, literal):
assert self.var_line is not literal.var_line or self.value != literal.value
return self.var_line is literal.var_line
return self.var_line is literal.var_line and self.value != literal.value

def get_expression(self):
return Comparison(next_nid(), OP_EQ, Bool.boolean,
Expand All @@ -406,12 +405,14 @@ def get_expression(self):
class Clause:
clauses = {}

conjunctions = {}

def __init__(self, literal):
self.literals = [literal]
self.literals = {literal.var_line:literal}

def __str__(self):
string = ""
for literal in self.literals:
for literal in self.literals.values():
if string:
string += ", "
string += f"{literal}"
Expand All @@ -423,36 +424,62 @@ def __hash__(self):
def __eq__(self, clause):
return self.literals == clause.literals

def cache_clause(self):
self.literals.sort(key=id)
for clause in Clause.clauses:
if self == clause:
return clause
Clause.clauses[self] = None
return self
def cache_clause(clause):
if isinstance(clause, Clause):
for cached_clause in Clause.clauses:
if clause == cached_clause:
return cached_clause
Clause.clauses[clause] = None
return clause

def create_clause(var_line, value):
return Clause(Literal.create_literal(var_line, value)).cache_clause()
return Clause.cache_clause(Clause(Literal.create_literal(var_line, value)))

def is_conjunction_cached(self, clause):
return (self in Clause.conjunctions and clause in Clause.conjunctions[self]) or (clause in Clause.conjunctions and self in Clause.conjunctions[clause])

def get_cached_conjunction(self, clause):
assert self.is_conjunction_cached(clause)
if self in Clause.conjunctions and clause in Clause.conjunctions[self]:
return Clause.conjunctions[self][clause]
elif clause in Clause.conjunctions and self in Clause.conjunctions[clause]:
return Clause.conjunctions[clause][self]

def cache_conjunction(clause, clause1, clause2):
clause = Clause.cache_clause(clause)
if clause1 not in Clause.conjunctions:
Clause.conjunctions[clause1] = {clause2:clause}
else:
Clause.conjunctions[clause1] |= {clause2:clause}
if clause2 not in Clause.conjunctions:
Clause.conjunctions[clause2] = {clause1:clause}
else:
Clause.conjunctions[clause2] |= {clause1:clause}
return clause

def and_clause(self, clause):
and_clause = Constant.false
for literal in self.literals:
if and_clause is Constant.false:
and_clause = Clause(literal)
else:
and_clause.literals.append(literal)
for literal in clause.literals:
if literal not in and_clause.literals:
for clause_literal in and_clause.literals:
if clause_literal.is_conflicting(literal):
return Constant.false
else:
and_clause.literals.append(literal)
return and_clause.cache_clause()
def conjunction(self, clause):
if self is clause:
return self
elif self.is_conjunction_cached(clause):
return self.get_cached_conjunction(clause)
else:
conjunction = Constant.false
for literal in self.literals.values():
if conjunction is Constant.false:
conjunction = Clause(literal)
else:
conjunction.literals[literal.var_line] = literal
for literal in clause.literals.values():
if literal.var_line in conjunction.literals:
if conjunction.literals[literal.var_line].is_conflicting(literal):
return Clause.cache_conjunction(Constant.false, self, clause)
else:
conjunction.literals[literal.var_line] = literal
return Clause.cache_conjunction(conjunction, self, clause)

def get_expression(self):
clause_line = Constant.false
for literal in self.literals:
for literal in self.literals.values():
literal_line = literal.get_expression()
if clause_line is Constant.false:
clause_line = literal_line
Expand All @@ -470,7 +497,7 @@ class DNF:
disjunctions = {}

def __init__(self, clause):
self.clauses = [clause]
self.clauses = {clause:None}

def __str__(self):
string = ""
Expand All @@ -486,7 +513,6 @@ def __eq__(self, dnf):

def cache_dnf(dnf):
if isinstance(dnf, DNF):
dnf.clauses.sort(key=id)
for cached_dnf in DNF.dnfs:
if dnf == cached_dnf:
return cached_dnf
Expand Down Expand Up @@ -536,12 +562,12 @@ def conjunction(dnf1, dnf2):
dnf = Constant.false
for clause1 in dnf1.clauses:
for clause2 in dnf2.clauses:
clause = clause1.and_clause(clause2)
clause = clause1.conjunction(clause2)
if clause is not Constant.false:
if dnf is Constant.false:
dnf = DNF(clause)
elif clause not in dnf.clauses:
dnf.clauses.append(clause)
else:
dnf.clauses[clause] = None
return DNF.cache_conjunction(dnf, dnf1, dnf2)

def is_disjunction_cached(dnf1, dnf2):
Expand Down Expand Up @@ -586,9 +612,9 @@ def disjunction(dnf1, dnf2):
if dnf is Constant.false:
dnf = DNF(clause)
else:
dnf.clauses.append(clause)
dnf.clauses[clause] = None
for clause in dnf2.clauses:
dnf.clauses.append(clause)
dnf.clauses[clause] = None
return DNF.cache_disjunction(dnf, dnf1, dnf2)

def get_expression(self):
Expand Down

0 comments on commit 681a80e

Please # to comment.