Skip to content

Commit

Permalink
hdl.mem: Switch to first-class IR representation for memories.
Browse files Browse the repository at this point in the history
Fixes #611.
  • Loading branch information
wanda-phi authored and whitequark committed Jan 17, 2024
1 parent 2fecd1c commit ae36b59
Show file tree
Hide file tree
Showing 4 changed files with 170 additions and 92 deletions.
92 changes: 79 additions & 13 deletions amaranth/back/rtlil.py
Original file line number Diff line number Diff line change
Expand Up @@ -822,10 +822,79 @@ def _convert_fragment(builder, fragment, name_map, hierarchy):
for port_name, (value, dir) in fragment.named_ports.items():
port_map[f"\\{port_name}"] = value

params = OrderedDict(fragment.parameters)

if fragment.type[0] == "$":
return fragment.type, port_map
return fragment.type, port_map, params
else:
return f"\\{fragment.type}", port_map
return f"\\{fragment.type}", port_map, params

if isinstance(fragment, mem.MemoryInstance):
memory = fragment.memory
init = "".join(format(ast.Const(elem, ast.unsigned(memory.width)).value, f"0{memory.width}b") for elem in reversed(memory.init))
init = ast.Const(int(init or "0", 2), memory.depth * memory.width)
rd_clk = []
rd_clk_enable = 0
rd_clk_polarity = 0
rd_transparency_mask = 0
for index, port in enumerate(fragment.read_ports):
if port.domain != "comb":
cd = fragment.domains[port.domain]
rd_clk.append(cd.clk)
if cd.clk_edge == "pos":
rd_clk_polarity |= 1 << index
rd_clk_enable |= 1 << index
if port.transparent:
for write_index, write_port in enumerate(fragment.write_ports):
if port.domain == write_port.domain:
rd_transparency_mask |= 1 << (index * len(fragment.write_ports) + write_index)
else:
rd_clk.append(ast.Const(0, 1))
wr_clk = []
wr_clk_enable = 0
wr_clk_polarity = 0
for index, port in enumerate(fragment.write_ports):
cd = fragment.domains[port.domain]
wr_clk.append(cd.clk)
wr_clk_enable |= 1 << index
if cd.clk_edge == "pos":
wr_clk_polarity |= 1 << index
params = {
"MEMID": builder._make_name(hierarchy[-1], local=False),
"SIZE": memory.depth,
"OFFSET": 0,
"ABITS": ast.Shape.cast(range(memory.depth)).width,
"WIDTH": memory.width,
"INIT": init,
"RD_PORTS": len(fragment.read_ports),
"RD_CLK_ENABLE": ast.Const(rd_clk_enable, max(1, len(fragment.read_ports))),
"RD_CLK_POLARITY": ast.Const(rd_clk_polarity, max(1, len(fragment.read_ports))),
"RD_TRANSPARENCY_MASK": ast.Const(rd_transparency_mask, max(1, len(fragment.read_ports) * len(fragment.write_ports))),
"RD_COLLISION_X_MASK": ast.Const(0, max(1, len(fragment.read_ports) * len(fragment.write_ports))),
"RD_WIDE_CONTINUATION": ast.Const(0, max(1, len(fragment.read_ports))),
"RD_CE_OVER_SRST": ast.Const(0, max(1, len(fragment.read_ports))),
"RD_ARST_VALUE": ast.Const(0, len(fragment.read_ports) * memory.width),
"RD_SRST_VALUE": ast.Const(0, len(fragment.read_ports) * memory.width),
"RD_INIT_VALUE": ast.Const(0, len(fragment.read_ports) * memory.width),
"WR_PORTS": len(fragment.write_ports),
"WR_CLK_ENABLE": ast.Const(wr_clk_enable, max(1, len(fragment.write_ports))),
"WR_CLK_POLARITY": ast.Const(wr_clk_polarity, max(1, len(fragment.write_ports))),
"WR_PRIORITY_MASK": ast.Const(0, max(1, len(fragment.write_ports) * len(fragment.write_ports))),
"WR_WIDE_CONTINUATION": ast.Const(0, max(1, len(fragment.write_ports))),
}
port_map = {
"\\RD_CLK": ast.Cat(rd_clk),
"\\RD_EN": ast.Cat(port.en for port in fragment.read_ports),
"\\RD_ARST": ast.Const(0, len(fragment.read_ports)),
"\\RD_SRST": ast.Const(0, len(fragment.read_ports)),
"\\RD_ADDR": ast.Cat(port.addr for port in fragment.read_ports),
"\\RD_DATA": ast.Cat(port.data for port in fragment.read_ports),
"\\WR_CLK": ast.Cat(wr_clk),
"\\WR_EN": ast.Cat(ast.Cat(en_bit.replicate(port.granularity) for en_bit in port.en) for port in fragment.write_ports),
"\\WR_ADDR": ast.Cat(port.addr for port in fragment.write_ports),
"\\WR_DATA": ast.Cat(port.data for port in fragment.write_ports),
}
return "$mem_v2", port_map, params

module_name = ".".join(name or "anonymous" for name in hierarchy)
module_attrs = OrderedDict()
Expand Down Expand Up @@ -860,9 +929,9 @@ def _convert_fragment(builder, fragment, name_map, hierarchy):
# Transform all subfragments to their respective cells. Transforming signals connected
# to their ports into wires eagerly makes sure they get sensible (prefixed with submodule
# name) names.
memories = OrderedDict()
for subfragment, sub_name in fragment.subfragments:
if not (subfragment.ports or subfragment.statements or subfragment.subfragments):
if not (subfragment.ports or subfragment.statements or subfragment.subfragments or
isinstance(subfragment, (ir.Instance, mem.MemoryInstance))):
# If the fragment is completely empty, skip translating it, otherwise synthesis
# tools (including Yosys and Vivado) will treat it as a black box when it is
# loaded after conversion to Verilog.
Expand All @@ -871,25 +940,22 @@ def _convert_fragment(builder, fragment, name_map, hierarchy):
if sub_name is None:
sub_name = module.anonymous()

sub_params = OrderedDict(getattr(subfragment, "parameters", {}))

sub_type, sub_port_map = \
sub_type, sub_port_map, sub_params = \
_convert_fragment(builder, subfragment, name_map,
hierarchy=hierarchy + (sub_name,))

if sub_type == "$mem_v2" and "MEMID" not in sub_params:
sub_params["MEMID"] = builder._make_name(sub_name, local=False)

sub_ports = OrderedDict()
for port, value in sub_port_map.items():
if not isinstance(subfragment, ir.Instance):
if not isinstance(subfragment, (ir.Instance, mem.MemoryInstance)):
for signal in value._rhs_signals():
compiler_state.resolve_curr(signal, prefix=sub_name)
if len(value) > 0 or sub_type == "$mem_v2":
if len(value) > 0:

This comment has been minimized.

Copy link
@kivikakk

kivikakk Jan 18, 2024

Contributor

I think this check is still required, unless a change to Yosys is incoming. The generated RTLIL for a ROM now looks like:

  cell $mem_v2 \rom_rd
    parameter \WR_WIDE_CONTINUATION 1'0
    parameter \WR_PRIORITY_MASK 1'0
    parameter \WR_CLK_POLARITY 1'0
[…]
    parameter \OFFSET 0
    parameter \SIZE 71
    parameter \MEMID "\\rom_rd"
    connect \RD_DATA $memory_r_data
    connect \RD_ADDR $memory_r_addr
    connect \RD_SRST 1'0
    connect \RD_ARST 1'0
    connect \RD_EN $memory_r_en
    connect \RD_CLK \clk
  end

Note the lack of empty connections for the WR_ ports. This fails to synthesise with simply:

ERROR: Found error in internal cell \top.\rom_rd ($mem_v2) at kernel/rtlil.cc:1075:

This is because all the ports are expected in the $mem_v2 check:

https://github.com/YosysHQ/yosys/blob/e1f4c5c9cbbeafb5e43db6d58cff99377b6e5803/kernel/rtlil.cc#L1621-L1630

(the port() check always asserts the target's existence, even if the width is zero.)

This comment has been minimized.

Copy link
@whitequark

whitequark Jan 18, 2024

Member

Thanks for reporting. @wanda-phi, could you take a look please?

This comment has been minimized.

Copy link
@wanda-phi

wanda-phi Jan 18, 2024

Author Member

Ah yes, my bad, the code was supposed to look ... somewhat different when I started writing it.

sub_ports[port] = rhs_compiler(value)

if isinstance(subfragment, ir.Instance):
src = _src(subfragment.src_loc)
elif isinstance(subfragment, mem.MemoryInstance):
src = _src(subfragment.memory.src_loc)
else:
src = ""

Expand Down Expand Up @@ -1005,7 +1071,7 @@ def _convert_fragment(builder, fragment, name_map, hierarchy):
wire_name = wire_name[1:]
name_map[signal] = hierarchy + (wire_name,)

return module.name, port_map
return module.name, port_map, {}


def convert_fragment(fragment, name="top", *, emit_src=True):
Expand Down
60 changes: 11 additions & 49 deletions amaranth/hdl/mem.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,55 +119,7 @@ def __getitem__(self, index):
return self._array[index]

def elaborate(self, platform):
init = "".join(format(Const(elem, unsigned(self.width)).value, f"0{self.width}b") for elem in reversed(self.init))
init = Const(int(init or "0", 2), self.depth * self.width)
rd_clk = []
rd_clk_enable = 0
rd_transparency_mask = 0
for index, port in enumerate(self._read_ports):
if port.domain != "comb":
rd_clk.append(ClockSignal(port.domain))
rd_clk_enable |= 1 << index
if port.transparent:
for write_index, write_port in enumerate(self._write_ports):
if port.domain == write_port.domain:
rd_transparency_mask |= 1 << (index * len(self._write_ports) + write_index)
else:
rd_clk.append(Const(0, 1))
f = Instance("$mem_v2",
*(("a", attr, value) for attr, value in self.attrs.items()),
p_SIZE=self.depth,
p_OFFSET=0,
p_ABITS=Shape.cast(range(self.depth)).width,
p_WIDTH=self.width,
p_INIT=init,
p_RD_PORTS=len(self._read_ports),
p_RD_CLK_ENABLE=Const(rd_clk_enable, len(self._read_ports)) if self._read_ports else Const(0, 1),
p_RD_CLK_POLARITY=Const(-1, unsigned(len(self._read_ports))) if self._read_ports else Const(0, 1),
p_RD_TRANSPARENCY_MASK=Const(rd_transparency_mask, max(1, len(self._read_ports) * len(self._write_ports))),
p_RD_COLLISION_X_MASK=Const(0, max(1, len(self._read_ports) * len(self._write_ports))),
p_RD_WIDE_CONTINUATION=Const(0, len(self._read_ports)) if self._read_ports else Const(0, 1),
p_RD_CE_OVER_SRST=Const(0, len(self._read_ports)) if self._read_ports else Const(0, 1),
p_RD_ARST_VALUE=Const(0, len(self._read_ports) * self.width),
p_RD_SRST_VALUE=Const(0, len(self._read_ports) * self.width),
p_RD_INIT_VALUE=Const(0, len(self._read_ports) * self.width),
p_WR_PORTS=len(self._write_ports),
p_WR_CLK_ENABLE=Const(-1, unsigned(len(self._write_ports))) if self._write_ports else Const(0, 1),
p_WR_CLK_POLARITY=Const(-1, unsigned(len(self._write_ports))) if self._write_ports else Const(0, 1),
p_WR_PRIORITY_MASK=Const(0, len(self._write_ports) * len(self._write_ports)) if self._write_ports else Const(0, 1),
p_WR_WIDE_CONTINUATION=Const(0, len(self._write_ports)) if self._write_ports else Const(0, 1),
i_RD_CLK=Cat(rd_clk),
i_RD_EN=Cat(port.en for port in self._read_ports),
i_RD_ARST=Const(0, len(self._read_ports)),
i_RD_SRST=Const(0, len(self._read_ports)),
i_RD_ADDR=Cat(port.addr for port in self._read_ports),
o_RD_DATA=Cat(port.data for port in self._read_ports),
i_WR_CLK=Cat(ClockSignal(port.domain) for port in self._write_ports),
i_WR_EN=Cat(Cat(en_bit.replicate(port.granularity) for en_bit in port.en) for port in self._write_ports),
i_WR_ADDR=Cat(port.addr for port in self._write_ports),
i_WR_DATA=Cat(port.data for port in self._write_ports),
src_loc=self.src_loc,
)
f = MemoryInstance(self, self._read_ports, self._write_ports)
for port in self._read_ports:
port._MustUse__used = True
if port.domain == "comb":
Expand Down Expand Up @@ -211,6 +163,7 @@ def elaborate(self, platform):
f.add_driver(signal, port.domain)
return f


class ReadPort(Elaboratable):
"""A memory read port.
Expand Down Expand Up @@ -354,3 +307,12 @@ def __init__(self, *, data_width, addr_width, domain="sync", name=None, granular
name=f"{name}_data", src_loc_at=1)
self.en = Signal(data_width // granularity,
name=f"{name}_en", src_loc_at=1)


class MemoryInstance(Fragment):
def __init__(self, memory, read_ports, write_ports):
super().__init__()
self.memory = memory
self.read_ports = read_ports
self.write_ports = write_ports
self.attrs = memory.attrs
65 changes: 54 additions & 11 deletions amaranth/hdl/xfrm.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,15 @@
from abc import ABCMeta, abstractmethod
from collections import OrderedDict
from collections.abc import Iterable
from copy import copy

from .._utils import flatten, _ignore_deprecated
from .. import tracer
from .ast import *
from .ast import _StatementList
from .cd import *
from .ir import *
from .mem import MemoryInstance


__all__ = ["ValueVisitor", "ValueTransformer",
Expand Down Expand Up @@ -261,8 +263,30 @@ def map_drivers(self, fragment, new_fragment):
for domain, signal in fragment.iter_drivers():
new_fragment.add_driver(signal, domain)

def map_memory_ports(self, fragment, new_fragment):
new_fragment.read_ports = [
copy(port)
for port in fragment.read_ports
]
new_fragment.write_ports = [
copy(port)
for port in fragment.write_ports
]
if hasattr(self, "on_value"):
for port in new_fragment.read_ports:
port.en = self.on_value(port.en)
port.addr = self.on_value(port.addr)
port.data = self.on_value(port.data)
for port in new_fragment.write_ports:
port.en = self.on_value(port.en)
port.addr = self.on_value(port.addr)
port.data = self.on_value(port.data)

def on_fragment(self, fragment):
if isinstance(fragment, Instance):
if isinstance(fragment, MemoryInstance):
new_fragment = MemoryInstance(fragment.memory, [], [])
self.map_memory_ports(fragment, new_fragment)
elif isinstance(fragment, Instance):
new_fragment = Instance(fragment.type, src_loc=fragment.src_loc)
new_fragment.parameters = OrderedDict(fragment.parameters)
self.map_named_ports(fragment, new_fragment)
Expand Down Expand Up @@ -381,6 +405,19 @@ def on_statements(self, stmts):
self.on_statement(stmt)

def on_fragment(self, fragment):
if isinstance(fragment, MemoryInstance):
for port in fragment.read_ports:
self.on_value(port.addr)
self.on_value(port.data)
self.on_value(port.en)
if port.domain != "comb":
self._add_used_domain(port.domain)
for port in fragment.write_ports:
self.on_value(port.addr)
self.on_value(port.data)
self.on_value(port.en)
self._add_used_domain(port.domain)

if isinstance(fragment, Instance):
for name, (value, dir) in fragment.named_ports.items():
self.on_value(value)
Expand Down Expand Up @@ -444,6 +481,15 @@ def map_drivers(self, fragment, new_fragment):
for signal in signals:
new_fragment.add_driver(self.on_value(signal), domain)

def map_memory_ports(self, fragment, new_fragment):
super().map_memory_ports(fragment, new_fragment)
for port in new_fragment.read_ports:
if port.domain in self.domain_map:
port.domain = self.domain_map[port.domain]
for port in new_fragment.write_ports:
if port.domain in self.domain_map:
port.domain = self.domain_map[port.domain]


class DomainLowerer(FragmentTransformer, ValueTransformer, StatementTransformer):
def __init__(self, domains=None):
Expand Down Expand Up @@ -630,14 +676,11 @@ def _insert_control(self, fragment, domain, signals):

def on_fragment(self, fragment):
new_fragment = super().on_fragment(fragment)
if isinstance(new_fragment, Instance) and new_fragment.type == "$mem_v2":
for kind in ["RD", "WR"]:
clk_parts = new_fragment.named_ports[kind + "_CLK"][0].parts
en_parts = new_fragment.named_ports[kind + "_EN"][0].parts
new_en = []
for clk, en in zip(clk_parts, en_parts):
if isinstance(clk, ClockSignal) and clk.domain in self.controls:
en = Mux(self.controls[clk.domain], en, Const(0, len(en)))
new_en.append(en)
new_fragment.named_ports[kind + "_EN"] = Cat(new_en), "i"
if isinstance(new_fragment, MemoryInstance):
for port in new_fragment.read_ports:
if port.domain in self.controls:
port.en = port.en & self.controls[port.domain]
for port in new_fragment.write_ports:
if port.domain in self.controls:
port.en = Mux(self.controls[port.domain], port.en, Const(0, len(port.en)))
return new_fragment
45 changes: 26 additions & 19 deletions tests/test_hdl_xfrm.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,11 @@

from amaranth.hdl.ast import *
from amaranth.hdl.cd import *
from amaranth.hdl.dsl import *
from amaranth.hdl.ir import *
from amaranth.hdl.xfrm import *
from amaranth.hdl.mem import *
from amaranth.hdl.mem import MemoryInstance

from .utils import *
from amaranth._utils import _ignore_deprecated
Expand Down Expand Up @@ -113,6 +115,22 @@ def test_rename_cd_subfragment(self):
"pix": cd_pix,
})

def test_rename_mem_ports(self):
m = Module()
mem = Memory(depth=4, width=16)
m.submodules.mem = mem
mem.read_port(domain="a")
mem.read_port(domain="b")
mem.write_port(domain="c")

f = Fragment.get(m, None)
f = DomainRenamer({"a": "d", "c": "e"})(f)
mem = f.subfragments[0][0]
self.assertIsInstance(mem, MemoryInstance)
self.assertEqual(mem.read_ports[0].domain, "d")
self.assertEqual(mem.read_ports[1].domain, "b")
self.assertEqual(mem.write_ports[0].domain, "e")

def test_rename_wrong_to_comb(self):
with self.assertRaisesRegex(ValueError,
r"^Domain 'sync' may not be renamed to 'comb'$"):
Expand Down Expand Up @@ -501,31 +519,20 @@ def test_enable_read_port(self):
mem = Memory(width=8, depth=4)
mem.read_port(transparent=False)
f = EnableInserter(self.c1)(mem).elaborate(platform=None)
self.assertRepr(f.named_ports["RD_EN"][0], """
(cat (m (sig c1) (sig mem_r_en) (const 1'd0)))
self.assertRepr(f.read_ports[0].en, """
(& (sig mem_r_en) (sig c1))
""")

def test_enable_write_port(self):
mem = Memory(width=8, depth=4)
mem.write_port()
mem.write_port(granularity=2)
f = EnableInserter(self.c1)(mem).elaborate(platform=None)
self.assertRepr(f.named_ports["WR_EN"][0], """
(cat (m
self.assertRepr(f.write_ports[0].en, """
(m
(sig c1)
(cat
(cat
(slice (sig mem_w_en) 0:1)
(slice (sig mem_w_en) 0:1)
(slice (sig mem_w_en) 0:1)
(slice (sig mem_w_en) 0:1)
(slice (sig mem_w_en) 0:1)
(slice (sig mem_w_en) 0:1)
(slice (sig mem_w_en) 0:1)
(slice (sig mem_w_en) 0:1)
)
)
(const 8'd0)
))
(sig mem_w_en)
(const 4'd0)
)
""")


Expand Down

0 comments on commit ae36b59

Please # to comment.