Skip to content

Commit

Permalink
Fix rewrite hanging when it does not apply
Browse files Browse the repository at this point in the history
  • Loading branch information
ymherklotz committed Jan 28, 2025
1 parent 7eb5f49 commit a0a62e4
Show file tree
Hide file tree
Showing 3 changed files with 16 additions and 2 deletions.
12 changes: 12 additions & 0 deletions DataflowRewriter/ExprLow.lean
Original file line number Diff line number Diff line change
Expand Up @@ -185,6 +185,18 @@ fully specified and therefore symmetric in both expressions.
| .product e_sub₁ e_sub₂ =>
.product (e_sub₁.replace e_sub e_new) (e_sub₂.replace e_sub e_new)

@[drunfold] def force_replace (e e_sub e_new : ExprLow Ident) : (ExprLow Ident × Bool) :=
if e.check_eq e_sub then (e_new, true) else
match e with
| .base inst typ => (e, false)
| .connect x y e_sub' =>
let rep := e_sub'.force_replace e_sub e_new
(.connect x y rep.1, rep.2)
| .product e_sub₁ e_sub₂ =>
let e_sub₁_rep := e_sub₁.force_replace e_sub e_new
let e_sub₂_rep := e_sub₂.force_replace e_sub e_new
(.product e_sub₁_rep.1 e_sub₂_rep.1, e_sub₁_rep.2 || e_sub₂_rep.2)

@[drunfold]
def abstract (e e_sub : ExprLow Ident) (i_inst : PortMapping Ident) (i_typ : Ident) : ExprLow Ident :=
.base i_inst i_typ |> e.replace e_sub
Expand Down
3 changes: 2 additions & 1 deletion DataflowRewriter/Rewriter.lean
Original file line number Diff line number Diff line change
Expand Up @@ -205,7 +205,8 @@ however, currently the low-level expression language does not remember any names
-- `norm` is a function that canonicalises the connections of the input expression given a list of connections as the
-- ordering guide.
let canon := comm_connections g₁.connections
let rewritten := (canon g_lower).replace (canon e_renamed_input_sub) e_renamed_output_sub
let (rewritten, b) := (canon g_lower).force_replace (canon e_sub_input) e_sub_output
EStateM.guard (.error s!"subexpression not found in the graph") b

let norm := rewritten.normalisedNamesMap fresh_prefix
let out ← rewritten.renamePorts norm |>.higherSS |> ofOption (.error "could not lift expression to graph")
Expand Down
3 changes: 2 additions & 1 deletion Main.lean
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ import DataflowRewriter.DynamaticPrinter
import DataflowRewriter.Rewrites.LoopRewrite
import DataflowRewriter.Rewrites.CombineBranch
import DataflowRewriter.Rewrites.CombineMux
import DataflowRewriter.Rewrites.JoinSplitLoopCond

open Batteries (AssocList)

Expand Down Expand Up @@ -62,7 +63,7 @@ OPTIONS
dot that is easier for debugging purposes.
"
def topLevel (e : ExprHigh String) : RewriteResult (ExprHigh String) :=
rewrite_loop e [CombineMux.rewrite, CombineBranch.rewrite]
rewrite_loop e [CombineMux.rewrite, CombineBranch.rewrite, JoinSplitLoopCond.rewrite] (depth := 10000)

def renameAssoc (assoc : AssocList String (AssocList String String)) (r : RewriteInfo) : AssocList String (AssocList String String) :=
assoc.mapKey (λ x => match r.renamed_input_nodes.find? x with
Expand Down

0 comments on commit a0a62e4

Please # to comment.