-
Notifications
You must be signed in to change notification settings - Fork 615
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Split Passes.scala into separate files (#1496)
* Split Passes.scala into separate files * Add imports of implicit things Co-authored-by: mergify[bot] <37929162+mergify[bot]@users.noreply.github.com>
- Loading branch information
1 parent
6329307
commit 54ff945
Showing
7 changed files
with
395 additions
and
367 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,86 @@ | ||
package firrtl.passes | ||
|
||
import firrtl.Utils.{create_exps, flow, get_field, get_valid_points, times, to_flip, to_flow} | ||
import firrtl.ir._ | ||
import firrtl.options.{PreservesAll, Dependency} | ||
import firrtl.{DuplexFlow, Flow, SinkFlow, SourceFlow, Transform, WDefInstance, WRef, WSubAccess, WSubField, WSubIndex} | ||
import firrtl.Mappers._ | ||
|
||
object ExpandConnects extends Pass with PreservesAll[Transform] { | ||
|
||
override val prerequisites = | ||
Seq( Dependency(PullMuxes), | ||
Dependency(ReplaceAccesses) ) ++ firrtl.stage.Forms.Deduped | ||
|
||
def run(c: Circuit): Circuit = { | ||
def expand_connects(m: Module): Module = { | ||
val flows = collection.mutable.LinkedHashMap[String,Flow]() | ||
def expand_s(s: Statement): Statement = { | ||
def set_flow(e: Expression): Expression = e map set_flow match { | ||
case ex: WRef => WRef(ex.name, ex.tpe, ex.kind, flows(ex.name)) | ||
case ex: WSubField => | ||
val f = get_field(ex.expr.tpe, ex.name) | ||
val flowx = times(flow(ex.expr), f.flip) | ||
WSubField(ex.expr, ex.name, ex.tpe, flowx) | ||
case ex: WSubIndex => WSubIndex(ex.expr, ex.value, ex.tpe, flow(ex.expr)) | ||
case ex: WSubAccess => WSubAccess(ex.expr, ex.index, ex.tpe, flow(ex.expr)) | ||
case ex => ex | ||
} | ||
s match { | ||
case sx: DefWire => flows(sx.name) = DuplexFlow; sx | ||
case sx: DefRegister => flows(sx.name) = DuplexFlow; sx | ||
case sx: WDefInstance => flows(sx.name) = SourceFlow; sx | ||
case sx: DefMemory => flows(sx.name) = SourceFlow; sx | ||
case sx: DefNode => flows(sx.name) = SourceFlow; sx | ||
case sx: IsInvalid => | ||
val invalids = create_exps(sx.expr).flatMap { case expx => | ||
flow(set_flow(expx)) match { | ||
case DuplexFlow => Some(IsInvalid(sx.info, expx)) | ||
case SinkFlow => Some(IsInvalid(sx.info, expx)) | ||
case _ => None | ||
} | ||
} | ||
invalids.size match { | ||
case 0 => EmptyStmt | ||
case 1 => invalids.head | ||
case _ => Block(invalids) | ||
} | ||
case sx: Connect => | ||
val locs = create_exps(sx.loc) | ||
val exps = create_exps(sx.expr) | ||
Block(locs.zip(exps).map { case (locx, expx) => | ||
to_flip(flow(locx)) match { | ||
case Default => Connect(sx.info, locx, expx) | ||
case Flip => Connect(sx.info, expx, locx) | ||
} | ||
}) | ||
case sx: PartialConnect => | ||
val ls = get_valid_points(sx.loc.tpe, sx.expr.tpe, Default, Default) | ||
val locs = create_exps(sx.loc) | ||
val exps = create_exps(sx.expr) | ||
val stmts = ls map { case (x, y) => | ||
locs(x).tpe match { | ||
case AnalogType(_) => Attach(sx.info, Seq(locs(x), exps(y))) | ||
case _ => | ||
to_flip(flow(locs(x))) match { | ||
case Default => Connect(sx.info, locs(x), exps(y)) | ||
case Flip => Connect(sx.info, exps(y), locs(x)) | ||
} | ||
} | ||
} | ||
Block(stmts) | ||
case sx => sx map expand_s | ||
} | ||
} | ||
|
||
m.ports.foreach { p => flows(p.name) = to_flow(p.direction) } | ||
Module(m.info, m.name, m.ports, expand_s(m.body)) | ||
} | ||
|
||
val modulesx = c.modules.map { | ||
case (m: ExtModule) => m | ||
case (m: Module) => expand_connects(m) | ||
} | ||
Circuit(c.info, modulesx, c.main) | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,89 @@ | ||
package firrtl.passes | ||
|
||
import firrtl.PrimOps._ | ||
import firrtl.Utils.{BoolType, error, zero} | ||
import firrtl.ir._ | ||
import firrtl.options.{PreservesAll, Dependency} | ||
import firrtl.transforms.ConstantPropagation | ||
import firrtl.{Transform, bitWidth} | ||
import firrtl.Mappers._ | ||
|
||
// Replace shr by amount >= arg width with 0 for UInts and MSB for SInts | ||
// TODO replace UInt with zero-width wire instead | ||
object Legalize extends Pass with PreservesAll[Transform] { | ||
|
||
override val prerequisites = firrtl.stage.Forms.MidForm :+ Dependency(LowerTypes) | ||
|
||
override val optionalPrerequisites = Seq.empty | ||
|
||
override val dependents = Seq.empty | ||
|
||
private def legalizeShiftRight(e: DoPrim): Expression = { | ||
require(e.op == Shr) | ||
e.args.head match { | ||
case _: UIntLiteral | _: SIntLiteral => ConstantPropagation.foldShiftRight(e) | ||
case _ => | ||
val amount = e.consts.head.toInt | ||
val width = bitWidth(e.args.head.tpe) | ||
lazy val msb = width - 1 | ||
if (amount >= width) { | ||
e.tpe match { | ||
case UIntType(_) => zero | ||
case SIntType(_) => | ||
val bits = DoPrim(Bits, e.args, Seq(msb, msb), BoolType) | ||
DoPrim(AsSInt, Seq(bits), Seq.empty, SIntType(IntWidth(1))) | ||
case t => error(s"Unsupported type $t for Primop Shift Right") | ||
} | ||
} else { | ||
e | ||
} | ||
} | ||
} | ||
private def legalizeBitExtract(expr: DoPrim): Expression = { | ||
expr.args.head match { | ||
case _: UIntLiteral | _: SIntLiteral => ConstantPropagation.constPropBitExtract(expr) | ||
case _ => expr | ||
} | ||
} | ||
private def legalizePad(expr: DoPrim): Expression = expr.args.head match { | ||
case UIntLiteral(value, IntWidth(width)) if width < expr.consts.head => | ||
UIntLiteral(value, IntWidth(expr.consts.head)) | ||
case SIntLiteral(value, IntWidth(width)) if width < expr.consts.head => | ||
SIntLiteral(value, IntWidth(expr.consts.head)) | ||
case _ => expr | ||
} | ||
private def legalizeConnect(c: Connect): Statement = { | ||
val t = c.loc.tpe | ||
val w = bitWidth(t) | ||
if (w >= bitWidth(c.expr.tpe)) { | ||
c | ||
} else { | ||
val bits = DoPrim(Bits, Seq(c.expr), Seq(w - 1, 0), UIntType(IntWidth(w))) | ||
val expr = t match { | ||
case UIntType(_) => bits | ||
case SIntType(_) => DoPrim(AsSInt, Seq(bits), Seq(), SIntType(IntWidth(w))) | ||
case FixedType(_, IntWidth(p)) => DoPrim(AsFixedPoint, Seq(bits), Seq(p), t) | ||
} | ||
Connect(c.info, c.loc, expr) | ||
} | ||
} | ||
def run (c: Circuit): Circuit = { | ||
def legalizeE(expr: Expression): Expression = expr map legalizeE match { | ||
case prim: DoPrim => prim.op match { | ||
case Shr => legalizeShiftRight(prim) | ||
case Pad => legalizePad(prim) | ||
case Bits | Head | Tail => legalizeBitExtract(prim) | ||
case _ => prim | ||
} | ||
case e => e // respect pre-order traversal | ||
} | ||
def legalizeS (s: Statement): Statement = { | ||
val legalizedStmt = s match { | ||
case c: Connect => legalizeConnect(c) | ||
case _ => s | ||
} | ||
legalizedStmt map legalizeS map legalizeE | ||
} | ||
c copy (modules = c.modules map (_ map legalizeS)) | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,39 @@ | ||
package firrtl.passes | ||
|
||
import firrtl.Utils.error | ||
import firrtl.ir.Circuit | ||
import firrtl.{CircuitForm, CircuitState, FirrtlUserException, Transform, UnknownForm} | ||
|
||
/** [[Pass]] is simple transform that is generally part of a larger [[Transform]] | ||
* Has an [[UnknownForm]], because larger [[Transform]] should specify form | ||
*/ | ||
trait Pass extends Transform { | ||
def inputForm: CircuitForm = UnknownForm | ||
def outputForm: CircuitForm = UnknownForm | ||
def run(c: Circuit): Circuit | ||
def execute(state: CircuitState): CircuitState = { | ||
val result = (state.form, inputForm) match { | ||
case (_, UnknownForm) => run(state.circuit) | ||
case (UnknownForm, _) => run(state.circuit) | ||
case (x, y) if x > y => | ||
error(s"[$name]: Input form must be lower or equal to $inputForm. Got ${state.form}") | ||
case _ => run(state.circuit) | ||
} | ||
CircuitState(result, outputForm, state.annotations, state.renames) | ||
} | ||
} | ||
|
||
// Error handling | ||
class PassException(message: String) extends FirrtlUserException(message) | ||
class PassExceptions(val exceptions: Seq[PassException]) extends FirrtlUserException("\n" + exceptions.mkString("\n")) | ||
class Errors { | ||
val errors = collection.mutable.ArrayBuffer[PassException]() | ||
def append(pe: PassException) = errors.append(pe) | ||
def trigger() = errors.size match { | ||
case 0 => | ||
case 1 => throw errors.head | ||
case _ => | ||
append(new PassException(s"${errors.length} errors detected!")) | ||
throw new PassExceptions(errors) | ||
} | ||
} |
Oops, something went wrong.