csr.bus: split CSRMultiplexer to CSRInterface+CSRDecoder.
authorwhitequark <whitequark@whitequark.org>
Fri, 25 Oct 2019 10:37:03 +0000 (10:37 +0000)
committerwhitequark <whitequark@whitequark.org>
Fri, 25 Oct 2019 10:37:03 +0000 (10:37 +0000)
nmigen_soc/csr/bus.py
nmigen_soc/test/test_csr_bus.py

index f0a4f3bdb0023782a9f9dc1ca94207e8dec4d820..eb18a3c7527021542e02c6541e7c83bcdfb7b8a3 100644 (file)
@@ -3,7 +3,7 @@ from nmigen import *
 from nmigen import tracer
 
 
-__all__ = ["CSRElement", "CSRMultiplexer"]
+__all__ = ["CSRElement", "CSRInterface", "CSRDecoder"]
 
 
 class CSRElement(Record):
@@ -40,8 +40,7 @@ class CSRElement(Record):
         if access not in ("r", "w", "rw"):
             raise ValueError("Access mode must be one of \"r\", \"w\", or \"rw\", not {!r}"
                              .format(access))
-
-        self.width  = int(width)
+        self.width  = width
         self.access = access
 
         layout = []
@@ -58,18 +57,17 @@ class CSRElement(Record):
         super().__init__(layout, name=name, src_loc_at=1)
 
 
-class CSRMultiplexer(Elaboratable):
+class CSRInterface(Record):
     """CPU-side CSR interface.
 
-    A low-level interface to a set of peripheral CSR registers that implements address-based
-    multiplexing and atomic updates of wide registers.
+    A low-level interface to a set of atomically readable and writable peripheral CSR registers.
 
     Operation
     ---------
 
-    The CSR multiplexer splits each CSR register into chunks according to its data width. Each
-    chunk is assigned an address, and the first chunk of each register always has the provided
-    minimum alignment. This allows accessing CSRs of any size using any datapath width.
+    CSR registers mapped to the CSR bus are split into chunks according to the bus data width.
+    Each chunk is assigned a consecutive address on the bus. This allows accessing CSRs of any
+    size using any datapath width.
 
     When the first chunk of a register is read, the value of a register is captured, and reads
     from subsequent chunks of the same register return the captured values. When any chunk except
@@ -77,12 +75,66 @@ class CSRMultiplexer(Elaboratable):
     chunk writes the captured value to the register. This allows atomically accessing CSRs larger
     than datapath width.
 
-    Reads to padding bytes return zeroes, and writes to padding bytes are ignored.
+    Parameters
+    ----------
+    addr_width : int
+        Address width. At most ``(2 ** addr_width) * data_width`` register bits will be available.
+    data_width : int
+        Data width. Registers are accessed in ``data_width`` sized chunks.
+    name : str
+        Name of the underlying record.
 
-    Writes are registered, and add 1 cycle of latency.
+    Attributes
+    ----------
+    addr : Signal(addr_width)
+        Address for reads and writes.
+    r_data : Signal(data_width)
+        Read data. Valid on the next cycle after ``r_stb`` is asserted.
+    r_stb : Signal()
+        Read strobe. If ``addr`` points to the first chunk of a register, captures register value
+        and causes read side effects to be performed (if any). If ``addr`` points to any chunk
+        of a register, latches the captured value to ``r_data``. Otherwise, latches zero
+        to ``r_data``.
+    w_data : Signal(data_width)
+        Write data. Must be valid when ``w_stb`` is asserted.
+    w_stb : Signal()
+        Write strobe. If ``addr`` points to the last chunk of a register, writes captured value
+        to the register and causes write side effects to be performed (if any). If ``addr`` points
+        to any chunk of a register, latches ``w_data`` to the captured value. Otherwise, does
+        nothing.
+    """
+
+    def __init__(self, *, addr_width, data_width, name=None):
+        if not isinstance(addr_width, int) or addr_width <= 0:
+            raise ValueError("Address width must be a positive integer, not {!r}"
+                             .format(addr_width))
+        if not isinstance(data_width, int) or data_width <= 0:
+            raise ValueError("Data width must be a positive integer, not {!r}"
+                             .format(data_width))
+        self.addr_width = addr_width
+        self.data_width = data_width
 
-    Wide registers
-    --------------
+        super().__init__([
+            ("addr",    addr_width),
+            ("r_data",  data_width),
+            ("r_stb",   1),
+            ("w_data",  data_width),
+            ("w_stb",   1),
+        ], name=name, src_loc_at=1)
+
+
+class CSRDecoder(Elaboratable):
+    """CSR bus decoder.
+
+    An address-based multiplexer for CSR registers implementing atomic updates.
+
+    Latency
+    -------
+
+    Writes are registered, and are performed 1 cycle after ``w_stb`` is asserted.
+
+    Alignment
+    ---------
 
     Because the CSR bus conserves logic and routing resources, it is common to e.g. access
     a CSR bus with an *n*-bit data path from a CPU with a *k*-bit datapath (*k>n*) in cases
@@ -107,56 +159,29 @@ class CSRMultiplexer(Elaboratable):
     Parameters
     ----------
     addr_width : int
-        Address width. At most ``(2 ** addr_width) * data_width`` register bits will be available.
+        Address width. See :class:`CSRInterface`.
     data_width : int
-        Data width. Registers are accessed in ``data_width`` sized chunks.
+        Data width. See :class:`CSRInterface`.
     alignment : int
         Register alignment. The address assigned to each register will be a multiple of
         ``2 ** alignment``.
 
     Attributes
     ----------
-    addr : Signal(addr_width)
-        Address for reads and writes.
-    r_data : Signal(data_width)
-        Read data. Valid on the next cycle after ``r_stb`` is asserted.
-    r_stb : Signal()
-        Read strobe. If ``addr`` points to the first chunk of a register, captures register value
-        and causes read side effects to be performed (if any). If ``addr`` points to any chunk
-        of a register, latches the captured value to ``r_data``. Otherwise, latches zero
-        to ``r_data``.
-    w_data : Signal(data_width)
-        Write data. Must be valid when ``w_stb`` is asserted.
-    w_stb : Signal()
-        Write strobe. If ``addr`` points to the last chunk of a register, writes captured value
-        to the register and causes write side effects to be performed (if any). If ``addr`` points
-        to any chunk of a register, latches ``w_data`` to the captured value. Otherwise, does
-        nothing.
+    bus : :class:`CSRInterface`
+        CSR bus providing access to registers.
     """
     def __init__(self, *, addr_width, data_width, alignment=0):
-        if not isinstance(addr_width, int) or addr_width <= 0:
-            raise ValueError("Address width must be a positive integer, not {!r}"
-                             .format(addr_width))
-        if not isinstance(data_width, int) or data_width <= 0:
-            raise ValueError("Data width must be a positive integer, not {!r}"
-                             .format(data_width))
+        self.bus = CSRInterface(addr_width=addr_width, data_width=data_width)
+
         if not isinstance(alignment, int) or alignment < 0:
             raise ValueError("Alignment must be a non-negative integer, not {!r}"
                              .format(alignment))
-
-        self.addr_width = int(addr_width)
-        self.data_width = int(data_width)
-        self.alignment  = alignment
+        self.alignment = alignment
 
         self._next_addr = 0
         self._elements  = dict()
 
-        self.addr   = Signal(addr_width)
-        self.r_data = Signal(data_width)
-        self.r_stb  = Signal()
-        self.w_data = Signal(data_width)
-        self.w_stb  = Signal()
-
     def add(self, element):
         """Add a register.
 
@@ -176,7 +201,7 @@ class CSRMultiplexer(Elaboratable):
                             .format(element))
 
         addr = self.align_to(self.alignment)
-        self._next_addr += (element.width + self.data_width - 1) // self.data_width
+        self._next_addr += (element.width + self.bus.data_width - 1) // self.bus.data_width
         size = self.align_to(self.alignment) - addr
         self._elements[addr] = element, size
         return addr, size
@@ -222,33 +247,33 @@ class CSRMultiplexer(Elaboratable):
             # arithmetic comparisons, since some toolchains (e.g. Yosys) are too eager to infer
             # carry chains for comparisons, even with a constant. (Register sizes don't have
             # to be powers of 2.)
-            with m.Switch(self.addr):
+            with m.Switch(self.bus.addr):
                 for chunk_offset in range(elem_size):
-                    chunk_slice = slice(chunk_offset * self.data_width,
-                                        (chunk_offset + 1) * self.data_width)
+                    chunk_slice = slice(chunk_offset * self.bus.data_width,
+                                        (chunk_offset + 1) * self.bus.data_width)
                     with m.Case(elem_addr + chunk_offset):
                         if "r" in elem.access:
-                            chunk_r_stb = Signal(self.data_width,
+                            chunk_r_stb = Signal(self.bus.data_width,
                                 name="{}__r_stb_{}".format(elem.name, chunk_offset))
                             r_data_fanin |= Mux(chunk_r_stb, shadow[chunk_slice], 0)
                             if chunk_offset == 0:
-                                m.d.comb += elem.r_stb.eq(self.r_stb)
-                                with m.If(self.r_stb):
+                                m.d.comb += elem.r_stb.eq(self.bus.r_stb)
+                                with m.If(self.bus.r_stb):
                                     m.d.sync += shadow.eq(elem.r_data)
                             # Delay by 1 cycle, allowing reads to be pipelined.
-                            m.d.sync += chunk_r_stb.eq(self.r_stb)
+                            m.d.sync += chunk_r_stb.eq(self.bus.r_stb)
 
                         if "w" in elem.access:
                             if chunk_offset == elem_size - 1:
                                 # Delay by 1 cycle, avoiding combinatorial paths through
                                 # the CSR bus and into CSR registers.
-                                m.d.sync += elem.w_stb.eq(self.w_stb)
-                            with m.If(self.w_stb):
-                                m.d.sync += shadow[chunk_slice].eq(self.w_data)
+                                m.d.sync += elem.w_stb.eq(self.bus.w_stb)
+                            with m.If(self.bus.w_stb):
+                                m.d.sync += shadow[chunk_slice].eq(self.bus.w_data)
 
                 with m.Default():
                     m.d.sync += shadow.eq(0)
 
-        m.d.comb += self.r_data.eq(r_data_fanin)
+        m.d.comb += self.bus.r_data.eq(r_data_fanin)
 
         return m
index ac427b9806f5c78c9a654cc4d0228ace0d158a2c..f54e9350b6741d5d6626389be2225925cbbbdf4d 100644 (file)
@@ -7,7 +7,7 @@ from ..csr.bus import *
 
 
 class CSRElementTestCase(unittest.TestCase):
-    def test_1_ro(self):
+    def test_layout_1_ro(self):
         elem = CSRElement(1, "r")
         self.assertEqual(elem.width, 1)
         self.assertEqual(elem.access, "r")
@@ -16,7 +16,7 @@ class CSRElementTestCase(unittest.TestCase):
             ("r_stb", 1),
         ]))
 
-    def test_8_rw(self):
+    def test_layout_8_rw(self):
         elem = CSRElement(8, access="rw")
         self.assertEqual(elem.width, 8)
         self.assertEqual(elem.access, "rw")
@@ -27,7 +27,7 @@ class CSRElementTestCase(unittest.TestCase):
             ("w_stb", 1),
         ]))
 
-    def test_10_wo(self):
+    def test_layout_10_wo(self):
         elem = CSRElement(10, "w")
         self.assertEqual(elem.width, 10)
         self.assertEqual(elem.access, "w")
@@ -36,7 +36,7 @@ class CSRElementTestCase(unittest.TestCase):
             ("w_stb", 1),
         ]))
 
-    def test_0_rw(self): # degenerate but legal case
+    def test_layout_0_rw(self): # degenerate but legal case
         elem = CSRElement(0, access="rw")
         self.assertEqual(elem.width, 0)
         self.assertEqual(elem.access, "rw")
@@ -58,28 +58,40 @@ class CSRElementTestCase(unittest.TestCase):
             CSRElement(1, "wo")
 
 
-class CSRMultiplexerTestCase(unittest.TestCase):
-    def setUp(self):
-        self.dut = CSRMultiplexer(addr_width=16, data_width=8)
+class CSRInterfaceTestCase(unittest.TestCase):
+    def test_layout(self):
+        iface = CSRInterface(addr_width=12, data_width=8)
+        self.assertEqual(iface.addr_width, 12)
+        self.assertEqual(iface.data_width, 8)
+        self.assertEqual(iface.layout, Layout.cast([
+            ("addr",    12),
+            ("r_data",  8),
+            ("r_stb",   1),
+            ("w_data",  8),
+            ("w_stb",   1),
+        ]))
 
     def test_addr_width_wrong(self):
         with self.assertRaisesRegex(ValueError,
                 r"Address width must be a positive integer, not -1"):
-            CSRMultiplexer(addr_width=-1, data_width=8)
+            CSRInterface(addr_width=-1, data_width=8)
 
     def test_data_width_wrong(self):
         with self.assertRaisesRegex(ValueError,
                 r"Data width must be a positive integer, not -1"):
-            CSRMultiplexer(addr_width=16, data_width=-1)
+            CSRInterface(addr_width=16, data_width=-1)
+
+
+class CSRDecoderTestCase(unittest.TestCase):
+    def setUp(self):
+        self.dut = CSRDecoder(addr_width=16, data_width=8)
 
     def test_alignment_wrong(self):
         with self.assertRaisesRegex(ValueError,
                 r"Alignment must be a non-negative integer, not -1"):
-            CSRMultiplexer(addr_width=16, data_width=8, alignment=-1)
+            CSRDecoder(addr_width=16, data_width=8, alignment=-1)
 
     def test_attrs(self):
-        self.assertEqual(self.dut.addr_width, 16)
-        self.assertEqual(self.dut.data_width, 8)
         self.assertEqual(self.dut.alignment, 0)
 
     def test_add_4b(self):
@@ -117,6 +129,8 @@ class CSRMultiplexerTestCase(unittest.TestCase):
                          (4, 1))
 
     def test_sim(self):
+        bus = self.dut.bus
+
         elem_4_r = CSRElement(4, "r")
         self.dut.add(elem_4_r)
         elem_8_w = CSRElement(8, "w")
@@ -128,55 +142,55 @@ class CSRMultiplexerTestCase(unittest.TestCase):
             yield elem_4_r.r_data.eq(0xa)
             yield elem_16_rw.r_data.eq(0x5aa5)
 
-            yield self.dut.addr.eq(0)
-            yield self.dut.r_stb.eq(1)
+            yield bus.addr.eq(0)
+            yield bus.r_stb.eq(1)
             yield
-            yield self.dut.r_stb.eq(0)
+            yield bus.r_stb.eq(0)
             self.assertEqual((yield elem_4_r.r_stb), 1)
             self.assertEqual((yield elem_16_rw.r_stb), 0)
             yield
-            self.assertEqual((yield self.dut.r_data), 0xa)
+            self.assertEqual((yield bus.r_data), 0xa)
 
-            yield self.dut.addr.eq(2)
-            yield self.dut.r_stb.eq(1)
+            yield bus.addr.eq(2)
+            yield bus.r_stb.eq(1)
             yield
-            yield self.dut.r_stb.eq(0)
+            yield bus.r_stb.eq(0)
             self.assertEqual((yield elem_4_r.r_stb), 0)
             self.assertEqual((yield elem_16_rw.r_stb), 1)
             yield
-            yield self.dut.addr.eq(3) # pipeline a read
-            self.assertEqual((yield self.dut.r_data), 0xa5)
+            yield bus.addr.eq(3) # pipeline a read
+            self.assertEqual((yield bus.r_data), 0xa5)
 
-            yield self.dut.r_stb.eq(1)
+            yield bus.r_stb.eq(1)
             yield
-            yield self.dut.r_stb.eq(0)
+            yield bus.r_stb.eq(0)
             self.assertEqual((yield elem_4_r.r_stb), 0)
             self.assertEqual((yield elem_16_rw.r_stb), 0)
             yield
-            self.assertEqual((yield self.dut.r_data), 0x5a)
+            self.assertEqual((yield bus.r_data), 0x5a)
 
-            yield self.dut.addr.eq(1)
-            yield self.dut.w_data.eq(0x3d)
-            yield self.dut.w_stb.eq(1)
+            yield bus.addr.eq(1)
+            yield bus.w_data.eq(0x3d)
+            yield bus.w_stb.eq(1)
             yield
-            yield self.dut.w_stb.eq(0)
+            yield bus.w_stb.eq(0)
             yield
             self.assertEqual((yield elem_8_w.w_stb), 1)
             self.assertEqual((yield elem_8_w.w_data), 0x3d)
             self.assertEqual((yield elem_16_rw.w_stb), 0)
 
-            yield self.dut.addr.eq(2)
-            yield self.dut.w_data.eq(0x55)
-            yield self.dut.w_stb.eq(1)
+            yield bus.addr.eq(2)
+            yield bus.w_data.eq(0x55)
+            yield bus.w_stb.eq(1)
             yield
             self.assertEqual((yield elem_8_w.w_stb), 0)
             self.assertEqual((yield elem_16_rw.w_stb), 0)
-            yield self.dut.addr.eq(3) # pipeline a write
-            yield self.dut.w_data.eq(0xaa)
+            yield bus.addr.eq(3) # pipeline a write
+            yield bus.w_data.eq(0xaa)
             yield
             self.assertEqual((yield elem_8_w.w_stb), 0)
             self.assertEqual((yield elem_16_rw.w_stb), 0)
-            yield self.dut.w_stb.eq(0)
+            yield bus.w_stb.eq(0)
             yield
             self.assertEqual((yield elem_8_w.w_stb), 0)
             self.assertEqual((yield elem_16_rw.w_stb), 1)
@@ -188,9 +202,9 @@ class CSRMultiplexerTestCase(unittest.TestCase):
             sim.run()
 
 
-class CSRAlignedMultiplexerTestCase(unittest.TestCase):
+class CSRDecoderAlignedTestCase(unittest.TestCase):
     def setUp(self):
-        self.dut = CSRMultiplexer(addr_width=16, data_width=8, alignment=2)
+        self.dut = CSRDecoder(addr_width=16, data_width=8, alignment=2)
 
     def test_attrs(self):
         self.assertEqual(self.dut.alignment, 2)
@@ -216,28 +230,30 @@ class CSRAlignedMultiplexerTestCase(unittest.TestCase):
                          (4, 4))
 
     def test_sim(self):
+        bus = self.dut.bus
+
         elem_20_rw = CSRElement(20, "rw")
         self.dut.add(elem_20_rw)
 
         def sim_test():
-            yield self.dut.w_stb.eq(1)
-            yield self.dut.addr.eq(0)
-            yield self.dut.w_data.eq(0x55)
+            yield bus.w_stb.eq(1)
+            yield bus.addr.eq(0)
+            yield bus.w_data.eq(0x55)
             yield
             self.assertEqual((yield elem_20_rw.w_stb), 0)
-            yield self.dut.addr.eq(1)
-            yield self.dut.w_data.eq(0xaa)
+            yield bus.addr.eq(1)
+            yield bus.w_data.eq(0xaa)
             yield
             self.assertEqual((yield elem_20_rw.w_stb), 0)
-            yield self.dut.addr.eq(2)
-            yield self.dut.w_data.eq(0x33)
+            yield bus.addr.eq(2)
+            yield bus.w_data.eq(0x33)
             yield
             self.assertEqual((yield elem_20_rw.w_stb), 0)
-            yield self.dut.addr.eq(3)
-            yield self.dut.w_data.eq(0xdd)
+            yield bus.addr.eq(3)
+            yield bus.w_data.eq(0xdd)
             yield
             self.assertEqual((yield elem_20_rw.w_stb), 0)
-            yield self.dut.w_stb.eq(0)
+            yield bus.w_stb.eq(0)
             yield
             self.assertEqual((yield elem_20_rw.w_stb), 1)
             self.assertEqual((yield elem_20_rw.w_data), 0x3aa55)