diff --git a/amaranth/back/rtlil.py b/amaranth/back/rtlil.py index 8179c3cca..bd8d601b1 100644 --- a/amaranth/back/rtlil.py +++ b/amaranth/back/rtlil.py @@ -822,10 +822,79 @@ def _convert_fragment(builder, fragment, name_map, hierarchy): for port_name, (value, dir) in fragment.named_ports.items(): port_map[f"\\{port_name}"] = value + params = OrderedDict(fragment.parameters) + if fragment.type[0] == "$": - return fragment.type, port_map + return fragment.type, port_map, params else: - return f"\\{fragment.type}", port_map + return f"\\{fragment.type}", port_map, params + + if isinstance(fragment, mem.MemoryInstance): + memory = fragment.memory + init = "".join(format(ast.Const(elem, ast.unsigned(memory.width)).value, f"0{memory.width}b") for elem in reversed(memory.init)) + init = ast.Const(int(init or "0", 2), memory.depth * memory.width) + rd_clk = [] + rd_clk_enable = 0 + rd_clk_polarity = 0 + rd_transparency_mask = 0 + for index, port in enumerate(fragment.read_ports): + if port.domain != "comb": + cd = fragment.domains[port.domain] + rd_clk.append(cd.clk) + if cd.clk_edge == "pos": + rd_clk_polarity |= 1 << index + rd_clk_enable |= 1 << index + if port.transparent: + for write_index, write_port in enumerate(fragment.write_ports): + if port.domain == write_port.domain: + rd_transparency_mask |= 1 << (index * len(fragment.write_ports) + write_index) + else: + rd_clk.append(ast.Const(0, 1)) + wr_clk = [] + wr_clk_enable = 0 + wr_clk_polarity = 0 + for index, port in enumerate(fragment.write_ports): + cd = fragment.domains[port.domain] + wr_clk.append(cd.clk) + wr_clk_enable |= 1 << index + if cd.clk_edge == "pos": + wr_clk_polarity |= 1 << index + params = { + "MEMID": builder._make_name(hierarchy[-1], local=False), + "SIZE": memory.depth, + "OFFSET": 0, + "ABITS": ast.Shape.cast(range(memory.depth)).width, + "WIDTH": memory.width, + "INIT": init, + "RD_PORTS": len(fragment.read_ports), + "RD_CLK_ENABLE": ast.Const(rd_clk_enable, max(1, len(fragment.read_ports))), + "RD_CLK_POLARITY": ast.Const(rd_clk_polarity, max(1, len(fragment.read_ports))), + "RD_TRANSPARENCY_MASK": ast.Const(rd_transparency_mask, max(1, len(fragment.read_ports) * len(fragment.write_ports))), + "RD_COLLISION_X_MASK": ast.Const(0, max(1, len(fragment.read_ports) * len(fragment.write_ports))), + "RD_WIDE_CONTINUATION": ast.Const(0, max(1, len(fragment.read_ports))), + "RD_CE_OVER_SRST": ast.Const(0, max(1, len(fragment.read_ports))), + "RD_ARST_VALUE": ast.Const(0, len(fragment.read_ports) * memory.width), + "RD_SRST_VALUE": ast.Const(0, len(fragment.read_ports) * memory.width), + "RD_INIT_VALUE": ast.Const(0, len(fragment.read_ports) * memory.width), + "WR_PORTS": len(fragment.write_ports), + "WR_CLK_ENABLE": ast.Const(wr_clk_enable, max(1, len(fragment.write_ports))), + "WR_CLK_POLARITY": ast.Const(wr_clk_polarity, max(1, len(fragment.write_ports))), + "WR_PRIORITY_MASK": ast.Const(0, max(1, len(fragment.write_ports) * len(fragment.write_ports))), + "WR_WIDE_CONTINUATION": ast.Const(0, max(1, len(fragment.write_ports))), + } + port_map = { + "\\RD_CLK": ast.Cat(rd_clk), + "\\RD_EN": ast.Cat(port.en for port in fragment.read_ports), + "\\RD_ARST": ast.Const(0, len(fragment.read_ports)), + "\\RD_SRST": ast.Const(0, len(fragment.read_ports)), + "\\RD_ADDR": ast.Cat(port.addr for port in fragment.read_ports), + "\\RD_DATA": ast.Cat(port.data for port in fragment.read_ports), + "\\WR_CLK": ast.Cat(wr_clk), + "\\WR_EN": ast.Cat(ast.Cat(en_bit.replicate(port.granularity) for en_bit in port.en) for port in fragment.write_ports), + "\\WR_ADDR": ast.Cat(port.addr for port in fragment.write_ports), + "\\WR_DATA": ast.Cat(port.data for port in fragment.write_ports), + } + return "$mem_v2", port_map, params module_name = ".".join(name or "anonymous" for name in hierarchy) module_attrs = OrderedDict() @@ -860,9 +929,9 @@ def _convert_fragment(builder, fragment, name_map, hierarchy): # Transform all subfragments to their respective cells. Transforming signals connected # to their ports into wires eagerly makes sure they get sensible (prefixed with submodule # name) names. - memories = OrderedDict() for subfragment, sub_name in fragment.subfragments: - if not (subfragment.ports or subfragment.statements or subfragment.subfragments): + if not (subfragment.ports or subfragment.statements or subfragment.subfragments or + isinstance(subfragment, (ir.Instance, mem.MemoryInstance))): # If the fragment is completely empty, skip translating it, otherwise synthesis # tools (including Yosys and Vivado) will treat it as a black box when it is # loaded after conversion to Verilog. @@ -871,25 +940,22 @@ def _convert_fragment(builder, fragment, name_map, hierarchy): if sub_name is None: sub_name = module.anonymous() - sub_params = OrderedDict(getattr(subfragment, "parameters", {})) - - sub_type, sub_port_map = \ + sub_type, sub_port_map, sub_params = \ _convert_fragment(builder, subfragment, name_map, hierarchy=hierarchy + (sub_name,)) - if sub_type == "$mem_v2" and "MEMID" not in sub_params: - sub_params["MEMID"] = builder._make_name(sub_name, local=False) - sub_ports = OrderedDict() for port, value in sub_port_map.items(): - if not isinstance(subfragment, ir.Instance): + if not isinstance(subfragment, (ir.Instance, mem.MemoryInstance)): for signal in value._rhs_signals(): compiler_state.resolve_curr(signal, prefix=sub_name) - if len(value) > 0 or sub_type == "$mem_v2": + if len(value) > 0: sub_ports[port] = rhs_compiler(value) if isinstance(subfragment, ir.Instance): src = _src(subfragment.src_loc) + elif isinstance(subfragment, mem.MemoryInstance): + src = _src(subfragment.memory.src_loc) else: src = "" @@ -1005,7 +1071,7 @@ def _convert_fragment(builder, fragment, name_map, hierarchy): wire_name = wire_name[1:] name_map[signal] = hierarchy + (wire_name,) - return module.name, port_map + return module.name, port_map, {} def convert_fragment(fragment, name="top", *, emit_src=True): diff --git a/amaranth/hdl/mem.py b/amaranth/hdl/mem.py index 94ffe8fe6..61b8bc0f6 100644 --- a/amaranth/hdl/mem.py +++ b/amaranth/hdl/mem.py @@ -119,55 +119,7 @@ def __getitem__(self, index): return self._array[index] def elaborate(self, platform): - init = "".join(format(Const(elem, unsigned(self.width)).value, f"0{self.width}b") for elem in reversed(self.init)) - init = Const(int(init or "0", 2), self.depth * self.width) - rd_clk = [] - rd_clk_enable = 0 - rd_transparency_mask = 0 - for index, port in enumerate(self._read_ports): - if port.domain != "comb": - rd_clk.append(ClockSignal(port.domain)) - rd_clk_enable |= 1 << index - if port.transparent: - for write_index, write_port in enumerate(self._write_ports): - if port.domain == write_port.domain: - rd_transparency_mask |= 1 << (index * len(self._write_ports) + write_index) - else: - rd_clk.append(Const(0, 1)) - f = Instance("$mem_v2", - *(("a", attr, value) for attr, value in self.attrs.items()), - p_SIZE=self.depth, - p_OFFSET=0, - p_ABITS=Shape.cast(range(self.depth)).width, - p_WIDTH=self.width, - p_INIT=init, - p_RD_PORTS=len(self._read_ports), - p_RD_CLK_ENABLE=Const(rd_clk_enable, len(self._read_ports)) if self._read_ports else Const(0, 1), - p_RD_CLK_POLARITY=Const(-1, unsigned(len(self._read_ports))) if self._read_ports else Const(0, 1), - p_RD_TRANSPARENCY_MASK=Const(rd_transparency_mask, max(1, len(self._read_ports) * len(self._write_ports))), - p_RD_COLLISION_X_MASK=Const(0, max(1, len(self._read_ports) * len(self._write_ports))), - p_RD_WIDE_CONTINUATION=Const(0, len(self._read_ports)) if self._read_ports else Const(0, 1), - p_RD_CE_OVER_SRST=Const(0, len(self._read_ports)) if self._read_ports else Const(0, 1), - p_RD_ARST_VALUE=Const(0, len(self._read_ports) * self.width), - p_RD_SRST_VALUE=Const(0, len(self._read_ports) * self.width), - p_RD_INIT_VALUE=Const(0, len(self._read_ports) * self.width), - p_WR_PORTS=len(self._write_ports), - p_WR_CLK_ENABLE=Const(-1, unsigned(len(self._write_ports))) if self._write_ports else Const(0, 1), - p_WR_CLK_POLARITY=Const(-1, unsigned(len(self._write_ports))) if self._write_ports else Const(0, 1), - p_WR_PRIORITY_MASK=Const(0, len(self._write_ports) * len(self._write_ports)) if self._write_ports else Const(0, 1), - p_WR_WIDE_CONTINUATION=Const(0, len(self._write_ports)) if self._write_ports else Const(0, 1), - i_RD_CLK=Cat(rd_clk), - i_RD_EN=Cat(port.en for port in self._read_ports), - i_RD_ARST=Const(0, len(self._read_ports)), - i_RD_SRST=Const(0, len(self._read_ports)), - i_RD_ADDR=Cat(port.addr for port in self._read_ports), - o_RD_DATA=Cat(port.data for port in self._read_ports), - i_WR_CLK=Cat(ClockSignal(port.domain) for port in self._write_ports), - i_WR_EN=Cat(Cat(en_bit.replicate(port.granularity) for en_bit in port.en) for port in self._write_ports), - i_WR_ADDR=Cat(port.addr for port in self._write_ports), - i_WR_DATA=Cat(port.data for port in self._write_ports), - src_loc=self.src_loc, - ) + f = MemoryInstance(self, self._read_ports, self._write_ports) for port in self._read_ports: port._MustUse__used = True if port.domain == "comb": @@ -211,6 +163,7 @@ def elaborate(self, platform): f.add_driver(signal, port.domain) return f + class ReadPort(Elaboratable): """A memory read port. @@ -354,3 +307,12 @@ def __init__(self, *, data_width, addr_width, domain="sync", name=None, granular name=f"{name}_data", src_loc_at=1) self.en = Signal(data_width // granularity, name=f"{name}_en", src_loc_at=1) + + +class MemoryInstance(Fragment): + def __init__(self, memory, read_ports, write_ports): + super().__init__() + self.memory = memory + self.read_ports = read_ports + self.write_ports = write_ports + self.attrs = memory.attrs diff --git a/amaranth/hdl/xfrm.py b/amaranth/hdl/xfrm.py index c56dbe0ac..02c640380 100644 --- a/amaranth/hdl/xfrm.py +++ b/amaranth/hdl/xfrm.py @@ -1,6 +1,7 @@ from abc import ABCMeta, abstractmethod from collections import OrderedDict from collections.abc import Iterable +from copy import copy from .._utils import flatten, _ignore_deprecated from .. import tracer @@ -8,6 +9,7 @@ from .ast import _StatementList from .cd import * from .ir import * +from .mem import MemoryInstance __all__ = ["ValueVisitor", "ValueTransformer", @@ -261,8 +263,30 @@ def map_drivers(self, fragment, new_fragment): for domain, signal in fragment.iter_drivers(): new_fragment.add_driver(signal, domain) + def map_memory_ports(self, fragment, new_fragment): + new_fragment.read_ports = [ + copy(port) + for port in fragment.read_ports + ] + new_fragment.write_ports = [ + copy(port) + for port in fragment.write_ports + ] + if hasattr(self, "on_value"): + for port in new_fragment.read_ports: + port.en = self.on_value(port.en) + port.addr = self.on_value(port.addr) + port.data = self.on_value(port.data) + for port in new_fragment.write_ports: + port.en = self.on_value(port.en) + port.addr = self.on_value(port.addr) + port.data = self.on_value(port.data) + def on_fragment(self, fragment): - if isinstance(fragment, Instance): + if isinstance(fragment, MemoryInstance): + new_fragment = MemoryInstance(fragment.memory, [], []) + self.map_memory_ports(fragment, new_fragment) + elif isinstance(fragment, Instance): new_fragment = Instance(fragment.type, src_loc=fragment.src_loc) new_fragment.parameters = OrderedDict(fragment.parameters) self.map_named_ports(fragment, new_fragment) @@ -381,6 +405,19 @@ def on_statements(self, stmts): self.on_statement(stmt) def on_fragment(self, fragment): + if isinstance(fragment, MemoryInstance): + for port in fragment.read_ports: + self.on_value(port.addr) + self.on_value(port.data) + self.on_value(port.en) + if port.domain != "comb": + self._add_used_domain(port.domain) + for port in fragment.write_ports: + self.on_value(port.addr) + self.on_value(port.data) + self.on_value(port.en) + self._add_used_domain(port.domain) + if isinstance(fragment, Instance): for name, (value, dir) in fragment.named_ports.items(): self.on_value(value) @@ -444,6 +481,15 @@ def map_drivers(self, fragment, new_fragment): for signal in signals: new_fragment.add_driver(self.on_value(signal), domain) + def map_memory_ports(self, fragment, new_fragment): + super().map_memory_ports(fragment, new_fragment) + for port in new_fragment.read_ports: + if port.domain in self.domain_map: + port.domain = self.domain_map[port.domain] + for port in new_fragment.write_ports: + if port.domain in self.domain_map: + port.domain = self.domain_map[port.domain] + class DomainLowerer(FragmentTransformer, ValueTransformer, StatementTransformer): def __init__(self, domains=None): @@ -630,14 +676,11 @@ def _insert_control(self, fragment, domain, signals): def on_fragment(self, fragment): new_fragment = super().on_fragment(fragment) - if isinstance(new_fragment, Instance) and new_fragment.type == "$mem_v2": - for kind in ["RD", "WR"]: - clk_parts = new_fragment.named_ports[kind + "_CLK"][0].parts - en_parts = new_fragment.named_ports[kind + "_EN"][0].parts - new_en = [] - for clk, en in zip(clk_parts, en_parts): - if isinstance(clk, ClockSignal) and clk.domain in self.controls: - en = Mux(self.controls[clk.domain], en, Const(0, len(en))) - new_en.append(en) - new_fragment.named_ports[kind + "_EN"] = Cat(new_en), "i" + if isinstance(new_fragment, MemoryInstance): + for port in new_fragment.read_ports: + if port.domain in self.controls: + port.en = port.en & self.controls[port.domain] + for port in new_fragment.write_ports: + if port.domain in self.controls: + port.en = Mux(self.controls[port.domain], port.en, Const(0, len(port.en))) return new_fragment diff --git a/tests/test_hdl_xfrm.py b/tests/test_hdl_xfrm.py index e3ff03b51..8341d2c61 100644 --- a/tests/test_hdl_xfrm.py +++ b/tests/test_hdl_xfrm.py @@ -4,9 +4,11 @@ from amaranth.hdl.ast import * from amaranth.hdl.cd import * +from amaranth.hdl.dsl import * from amaranth.hdl.ir import * from amaranth.hdl.xfrm import * from amaranth.hdl.mem import * +from amaranth.hdl.mem import MemoryInstance from .utils import * from amaranth._utils import _ignore_deprecated @@ -113,6 +115,22 @@ def test_rename_cd_subfragment(self): "pix": cd_pix, }) + def test_rename_mem_ports(self): + m = Module() + mem = Memory(depth=4, width=16) + m.submodules.mem = mem + mem.read_port(domain="a") + mem.read_port(domain="b") + mem.write_port(domain="c") + + f = Fragment.get(m, None) + f = DomainRenamer({"a": "d", "c": "e"})(f) + mem = f.subfragments[0][0] + self.assertIsInstance(mem, MemoryInstance) + self.assertEqual(mem.read_ports[0].domain, "d") + self.assertEqual(mem.read_ports[1].domain, "b") + self.assertEqual(mem.write_ports[0].domain, "e") + def test_rename_wrong_to_comb(self): with self.assertRaisesRegex(ValueError, r"^Domain 'sync' may not be renamed to 'comb'$"): @@ -501,31 +519,20 @@ def test_enable_read_port(self): mem = Memory(width=8, depth=4) mem.read_port(transparent=False) f = EnableInserter(self.c1)(mem).elaborate(platform=None) - self.assertRepr(f.named_ports["RD_EN"][0], """ - (cat (m (sig c1) (sig mem_r_en) (const 1'd0))) + self.assertRepr(f.read_ports[0].en, """ + (& (sig mem_r_en) (sig c1)) """) def test_enable_write_port(self): mem = Memory(width=8, depth=4) - mem.write_port() + mem.write_port(granularity=2) f = EnableInserter(self.c1)(mem).elaborate(platform=None) - self.assertRepr(f.named_ports["WR_EN"][0], """ - (cat (m + self.assertRepr(f.write_ports[0].en, """ + (m (sig c1) - (cat - (cat - (slice (sig mem_w_en) 0:1) - (slice (sig mem_w_en) 0:1) - (slice (sig mem_w_en) 0:1) - (slice (sig mem_w_en) 0:1) - (slice (sig mem_w_en) 0:1) - (slice (sig mem_w_en) 0:1) - (slice (sig mem_w_en) 0:1) - (slice (sig mem_w_en) 0:1) - ) - ) - (const 8'd0) - )) + (sig mem_w_en) + (const 4'd0) + ) """)