Common Pass Idioms
Please make sure you've read the following pages before continuing:
Suppose we want to write a pass that splits nested DoPrim expressions, thus transforming this:
```
circuit Top:
  module Top :
    input x: UInt<3>
    input y: UInt<3>
    input z: UInt<3>
    output o: UInt<3>
    o <= add(x, add(y, z))
```
into this:
```
circuit Top:
  module Top :
    input x: UInt<3>
    input y: UInt<3>
    input z: UInt<3>
    output o: UInt<3>
    node GEN_1 = add(y, z)
    o <= add(x, GEN_1)
```
We first need to visit every Statement and Expression in the AST. Then, whenever we see a DoPrim, we add a new DefNode to the module's body and replace the DoPrim with a reference to that DefNode. The code below implements this (and preserves the Info token). Note that Namespace is a utility class, defined in [Namespace.scala](https://github.com/ucb-bar/firrtl/blob/master/src/main/scala/firrtl/Namespace.scala), that generates fresh names which do not collide with any name already used in the module.
```scala
import firrtl._
import firrtl.ir._
import firrtl.Mappers._
import firrtl.passes.Pass
import scala.collection.mutable

object Splitter extends Pass {
  def name = "Splitter!"

  /** Run splitM on every module */
  def run(c: Circuit): Circuit = c.copy(modules = c.modules map (splitM(_)))

  /** Run splitS on the body of every module */
  def splitM(m: DefModule): DefModule = m map splitS(Namespace(m))

  /** Run splitE on the children of each expression in s, then recurse into
    * child statements (e.g. the branches of a when). If this produced new
    * nodes, return a Block containing them followed by the new statement;
    * otherwise, return the new statement. */
  def splitS(namespace: Namespace)(s: Statement): Statement = s match {
    case sx: Statement with HasInfo =>
      val block = mutable.ArrayBuffer[Statement]()
      // Split only the children of each top-level expression, so the root
      // DoPrim stays in place: o <= add(x, GEN_1), not o <= GEN_2.
      val withSplitExprs =
        sx map ((e: Expression) => e map splitE(block, namespace, sx.info))
      val newStmt = withSplitExprs map splitS(namespace)
      if (block.isEmpty) newStmt else Block(block.toSeq :+ newStmt)
    case sx => sx map splitS(namespace)
  }

  /** Run splitE on all children expressions first (post-order).
    * If the result is a DoPrim, add a new DefNode to block and return a
    * reference to that DefNode; otherwise return it unchanged. */
  def splitE(block: mutable.ArrayBuffer[Statement], namespace: Namespace,
             info: Info)(e: Expression): Expression =
    e map splitE(block, namespace, info) match {
      case ex: DoPrim =>
        val newName = namespace.newTemp
        block += DefNode(info, newName, ex)
        Ref(newName, ex.tpe)
      case ex => ex
    }
}
```
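To try this idiom without depending on a particular firrtl version, here is a self-contained sketch over a hypothetical miniature AST. `Expr`, `Ref`, `Prim`, `Stmt`, `Node`, `Connect`, `Block`, `ToyNamespace` and `ToySplitter` are illustrative stand-ins, not firrtl classes, and the `_GEN_<n>` naming scheme is an assumption of the sketch. `mapExpr` plays the role of firrtl's `e map f` by applying a function to each direct child, and, as in the example output above, only nested operations are hoisted.

```scala
import scala.collection.mutable

// Hypothetical miniature AST for illustration only (not firrtl.ir).
sealed trait Expr {
  // Apply f to each direct child, like firrtl's `e map f`.
  def mapExpr(f: Expr => Expr): Expr = this match {
    case Prim(op, args) => Prim(op, args.map(f))
    case leaf           => leaf
  }
}
case class Ref(name: String) extends Expr
case class Prim(op: String, args: Seq[Expr]) extends Expr

sealed trait Stmt
case class Node(name: String, value: Expr) extends Stmt
case class Connect(loc: String, expr: Expr) extends Stmt
case class Block(stmts: Seq[Stmt]) extends Stmt

// Hypothetical stand-in for firrtl's Namespace: returns "_GEN_<n>" names
// that clash neither with the seeded names nor with earlier temporaries.
class ToyNamespace(taken: Set[String]) {
  private val used = mutable.Set.empty[String]
  used ++= taken
  private var n = 0
  def newTemp: String = {
    var name = s"_GEN_$n"
    while (used(name)) { n += 1; name = s"_GEN_$n" }
    used += name
    n += 1
    name
  }
}

object ToySplitter {
  // Post-order: split the children first, then hoist this expression into a
  // fresh node if it is itself a primitive operation.
  def splitE(block: mutable.ArrayBuffer[Stmt], ns: ToyNamespace)(e: Expr): Expr =
    e.mapExpr(splitE(block, ns)) match {
      case p: Prim =>
        val name = ns.newTemp
        block += Node(name, p)
        Ref(name)
      case other => other
    }

  // Only the children of each top-level expression are split, so the root
  // operation stays in place. Hoisted nodes precede the rewritten statement.
  def splitS(ns: ToyNamespace)(s: Stmt): Stmt = {
    val block = mutable.ArrayBuffer[Stmt]()
    val newStmt = s match {
      case Node(n, v)    => Node(n, v.mapExpr(splitE(block, ns)))
      case Connect(l, e) => Connect(l, e.mapExpr(splitE(block, ns)))
      case Block(ss)     => Block(ss.map(splitS(ns)))
    }
    if (block.isEmpty) newStmt else Block(block.toSeq :+ newStmt)
  }
}
```

On the connect from the example, with the namespace seeded with `x`, `y`, `z` and `o`, `ToySplitter.splitS` returns a `Block` holding `Node("_GEN_0", add(y, z))` followed by `Connect("o", add(x, _GEN_0))`. Because `splitE` rewrites the children before looking at its own argument, the innermost operations are hoisted first, so the new nodes come out in dependency order.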
Suppose we want to write a pass that inlines all DefNodes whose value is a literal, thus transforming this:
```
circuit Top:
  module Top :
    input x: UInt<3>
    output o: UInt<4>
    node y = UInt(1)
    o <= add(x, y)
```
into this:
```
circuit Top:
  module Top :
    input x: UInt<3>
    output o: UInt<4>
    o <= add(x, UInt(1))
```
As before, we need to visit every Statement and Expression in the AST. When we see a DefNode whose value is a Literal, we record the literal in a hash map keyed by the node's name and return an EmptyStmt, which deletes that DefNode. Then, whenever we see a reference to a deleted DefNode, we substitute the corresponding Literal.
```scala
import firrtl._
import firrtl.ir._
import firrtl.Mappers._
import firrtl.passes.Pass
import scala.collection.mutable

object Inliner extends Pass {
  def name = "Inliner!"

  /** Run inlineM on every module */
  def run(c: Circuit): Circuit = c.copy(modules = c.modules map (inlineM(_)))

  /** Run inlineS on the body of every module, with a fresh map per module */
  def inlineM(m: DefModule): DefModule = m map inlineS(mutable.HashMap[String, Expression]())

  /** Run inlineE on all child expressions, then run inlineS on child statements.
    * If the statement is a DefNode whose value is a literal, record it in values
    * and return EmptyStmt; otherwise return the statement. */
  def inlineS(values: mutable.HashMap[String, Expression])(s: Statement): Statement =
    s map inlineE(values) map inlineS(values) match {
      case d: DefNode => d.value match {
        case l: Literal =>
          values(d.name) = l
          EmptyStmt
        case _ => d
      }
      case sx => sx
    }

  /** If e is a reference whose name is in values, return values(e.name);
    * otherwise run inlineE on all child expressions. */
  def inlineE(values: mutable.HashMap[String, Expression])(e: Expression): Expression = e match {
    case ref: Ref if values.contains(ref.name) => values(ref.name)
    case _ => e map inlineE(values)
  }
}
```
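The same substitution can be sketched without firrtl over a hypothetical miniature AST. `Lit`, `Empty` and `ToyInliner` stand in for firrtl's `Literal`, `EmptyStmt` and the `Inliner` pass; they are illustrative assumptions, not the real definitions. Note that `inlineE` checks the current expression before recursing (pre-order), while `splitE` in the Splitter recurses first (post-order): a matching reference is replaced wholesale, so there is nothing beneath it left to visit.

```scala
import scala.collection.mutable

// Hypothetical miniature AST for illustration only (not firrtl.ir).
sealed trait Expr
case class Ref(name: String) extends Expr
case class Lit(value: Int) extends Expr
case class Prim(op: String, args: Seq[Expr]) extends Expr

sealed trait Stmt
case class Node(name: String, value: Expr) extends Stmt
case class Connect(loc: String, expr: Expr) extends Stmt
case class Block(stmts: Seq[Stmt]) extends Stmt
case object Empty extends Stmt // stand-in for firrtl's EmptyStmt

object ToyInliner {
  type Env = mutable.HashMap[String, Expr]

  // Pre-order: a Ref bound in values is replaced outright; any other
  // expression has inlineE applied to its children.
  def inlineE(values: Env)(e: Expr): Expr = e match {
    case Ref(n) if values.contains(n) => values(n)
    case Prim(op, args)               => Prim(op, args.map(inlineE(values)))
    case other                        => other
  }

  // Substitute into the statement's expressions first, then recurse into
  // child statements; a node whose value is now a literal is recorded
  // and deleted.
  def inlineS(values: Env)(s: Stmt): Stmt = s match {
    case Node(n, v) =>
      inlineE(values)(v) match {
        case l: Lit =>
          values(n) = l
          Empty
        case v2 => Node(n, v2)
      }
    case Connect(l, e) => Connect(l, inlineE(values)(e))
    case Block(ss)     => Block(ss.map(inlineS(values)))
    case Empty         => Empty
  }
}
```

Because substitution runs before the DefNode check, a node whose value becomes a literal after substitution is recorded and deleted too, so a chain such as `node a = UInt(1)` followed by `node b = a` collapses entirely.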
Would this be useful? Let @azidar know by submitting an issue!