wishbone.bus: add Arbiter.
authorJean-François Nguyen <jf@lambdaconcept.com>
Wed, 22 Jan 2020 14:07:51 +0000 (15:07 +0100)
committerJean-François Nguyen <jf@lambdaconcept.com>
Wed, 22 Jan 2020 15:18:50 +0000 (16:18 +0100)
nmigen_soc/test/test_wishbone_bus.py
nmigen_soc/wishbone/bus.py

index 8d7064d8dd7e4dcb6db3b1f769219b309bb7c89c..88dbaa2d69a6932be89fca4e5cda4954f3f6037f 100644 (file)
@@ -296,3 +296,265 @@ class DecoderSimulationTestCase(unittest.TestCase):
         with Simulator(m, vcd_file=open("test.vcd", "w")) as sim:
             sim.add_process(sim_test())
             sim.run()
+
+
+class ArbiterTestCase(unittest.TestCase):
+    def setUp(self):
+        self.dut = Arbiter(addr_width=31, data_width=32, granularity=16)
+
+    def test_add_wrong(self):
+        with self.assertRaisesRegex(TypeError,
+                r"Initiator bus must be an instance of wishbone\.Interface, not 'foo'"):
+            self.dut.add("foo")
+
+    def test_add_wrong_addr_width(self):
+        with self.assertRaisesRegex(ValueError,
+                r"Initiator bus has address width 15, which is not the same as arbiter "
+                r"address width 31"):
+            self.dut.add(Interface(addr_width=15, data_width=32, granularity=16))
+
+    def test_add_wrong_granularity(self):
+        with self.assertRaisesRegex(ValueError,
+                r"Initiator bus has granularity 8, which is lesser than "
+                r"the arbiter granularity 16"):
+            self.dut.add(Interface(addr_width=31, data_width=32, granularity=8))
+
+    def test_add_wrong_data_width(self):
+        with self.assertRaisesRegex(ValueError,
+                r"Initiator bus has data width 16, which is not the same as arbiter "
+                r"data width 32"):
+            self.dut.add(Interface(addr_width=31, data_width=16, granularity=16))
+
+    def test_add_wrong_optional_output(self):
+        with self.assertRaisesRegex(ValueError,
+                r"Initiator bus has optional output 'lock', but the arbiter does "
+                r"not have a corresponding input"):
+            self.dut.add(Interface(addr_width=31, data_width=32, granularity=16,
+                                   features={"lock"}))
+
+
+class ArbiterSimulationTestCase(unittest.TestCase):
+    def test_simple(self):
+        dut = Arbiter(addr_width=30, data_width=32, granularity=8,
+                      features={"err", "rty", "stall", "lock", "cti", "bte"})
+        itor_1 = Interface(addr_width=30, data_width=32, granularity=8)
+        dut.add(itor_1)
+        itor_2 = Interface(addr_width=30, data_width=32, granularity=16,
+                      features={"err", "rty", "stall", "lock", "cti", "bte"})
+        dut.add(itor_2)
+
+        def sim_test():
+            yield itor_1.adr.eq(0x7ffffffc >> 2)
+            yield itor_1.cyc.eq(1)
+            yield itor_1.stb.eq(1)
+            yield itor_1.sel.eq(0b1111)
+            yield itor_1.we.eq(1)
+            yield itor_1.dat_w.eq(0x12345678)
+            yield dut.bus.dat_r.eq(0xabcdef01)
+            yield dut.bus.ack.eq(1)
+            yield Delay(1e-7)
+            self.assertEqual((yield dut.bus.adr), 0x7ffffffc >> 2)
+            self.assertEqual((yield dut.bus.cyc), 1)
+            self.assertEqual((yield dut.bus.stb), 1)
+            self.assertEqual((yield dut.bus.sel), 0b1111)
+            self.assertEqual((yield dut.bus.we), 1)
+            self.assertEqual((yield dut.bus.dat_w), 0x12345678)
+            self.assertEqual((yield dut.bus.lock), 1)
+            self.assertEqual((yield dut.bus.cti), CycleType.CLASSIC.value)
+            self.assertEqual((yield dut.bus.bte), BurstTypeExt.LINEAR.value)
+            self.assertEqual((yield itor_1.dat_r), 0xabcdef01)
+            self.assertEqual((yield itor_1.ack), 1)
+
+            yield itor_1.cyc.eq(0)
+            yield itor_2.adr.eq(0xe0000000 >> 2)
+            yield itor_2.cyc.eq(1)
+            yield itor_2.stb.eq(1)
+            yield itor_2.sel.eq(0b10)
+            yield itor_2.we.eq(1)
+            yield itor_2.dat_w.eq(0x43218765)
+            yield itor_2.lock.eq(0)
+            yield itor_2.cti.eq(CycleType.INCR_BURST)
+            yield itor_2.bte.eq(BurstTypeExt.WRAP_4)
+            yield Tick()
+
+            yield dut.bus.err.eq(1)
+            yield dut.bus.rty.eq(1)
+            yield dut.bus.stall.eq(0)
+            yield Delay(1e-7)
+            self.assertEqual((yield dut.bus.adr), 0xe0000000 >> 2)
+            self.assertEqual((yield dut.bus.cyc), 1)
+            self.assertEqual((yield dut.bus.stb), 1)
+            self.assertEqual((yield dut.bus.sel), 0b1100)
+            self.assertEqual((yield dut.bus.we), 1)
+            self.assertEqual((yield dut.bus.dat_w), 0x43218765)
+            self.assertEqual((yield dut.bus.lock), 0)
+            self.assertEqual((yield dut.bus.cti), CycleType.INCR_BURST.value)
+            self.assertEqual((yield dut.bus.bte), BurstTypeExt.WRAP_4.value)
+            self.assertEqual((yield itor_2.dat_r), 0xabcdef01)
+            self.assertEqual((yield itor_2.ack), 1)
+            self.assertEqual((yield itor_2.err), 1)
+            self.assertEqual((yield itor_2.rty), 1)
+            self.assertEqual((yield itor_2.stall), 0)
+
+        with Simulator(dut, vcd_file=open("test.vcd", "w")) as sim:
+            sim.add_clock(1e-6)
+            sim.add_sync_process(sim_test())
+            sim.run()
+
+    def test_lock(self):
+        dut = Arbiter(addr_width=30, data_width=32, features={"lock"})
+        itor_1 = Interface(addr_width=30, data_width=32, features={"lock"})
+        dut.add(itor_1)
+        itor_2 = Interface(addr_width=30, data_width=32, features={"lock"})
+        dut.add(itor_2)
+
+        def sim_test():
+            yield itor_1.cyc.eq(1)
+            yield itor_1.lock.eq(1)
+            yield itor_2.cyc.eq(1)
+            yield dut.bus.ack.eq(1)
+            yield Delay(1e-7)
+            self.assertEqual((yield itor_1.ack), 1)
+            self.assertEqual((yield itor_2.ack), 0)
+
+            yield Tick()
+            yield Delay(1e-7)
+            self.assertEqual((yield itor_1.ack), 1)
+            self.assertEqual((yield itor_2.ack), 0)
+
+            yield itor_1.lock.eq(0)
+            yield Tick()
+            yield Delay(1e-7)
+            self.assertEqual((yield itor_1.ack), 0)
+            self.assertEqual((yield itor_2.ack), 1)
+
+            yield itor_2.cyc.eq(0)
+            yield Tick()
+            yield Delay(1e-7)
+            self.assertEqual((yield itor_1.ack), 1)
+            self.assertEqual((yield itor_2.ack), 0)
+
+            yield itor_1.stb.eq(1)
+            yield Tick()
+            yield Delay(1e-7)
+            self.assertEqual((yield itor_1.ack), 1)
+            self.assertEqual((yield itor_2.ack), 0)
+
+            yield itor_1.stb.eq(0)
+            yield itor_2.cyc.eq(1)
+            yield Tick()
+            yield Delay(1e-7)
+            self.assertEqual((yield itor_1.ack), 0)
+            self.assertEqual((yield itor_2.ack), 1)
+
+        with Simulator(dut, vcd_file=open("test.vcd", "w")) as sim:
+            sim.add_clock(1e-6)
+            sim.add_sync_process(sim_test())
+            sim.run()
+
+    def test_stall(self):
+        dut = Arbiter(addr_width=30, data_width=32, features={"stall"})
+        itor_1 = Interface(addr_width=30, data_width=32, features={"stall"})
+        dut.add(itor_1)
+        itor_2 = Interface(addr_width=30, data_width=32, features={"stall"})
+        dut.add(itor_2)
+
+        def sim_test():
+            yield itor_1.cyc.eq(1)
+            yield itor_2.cyc.eq(1)
+            yield dut.bus.stall.eq(0)
+            yield Delay(1e-6)
+            self.assertEqual((yield itor_1.stall), 0)
+            self.assertEqual((yield itor_2.stall), 1)
+
+            yield dut.bus.stall.eq(1)
+            yield Delay(1e-6)
+            self.assertEqual((yield itor_1.stall), 1)
+            self.assertEqual((yield itor_2.stall), 1)
+
+        with Simulator(dut, vcd_file=open("test.vcd", "w")) as sim:
+            sim.add_process(sim_test())
+            sim.run()
+
+    def test_stall_compat(self):
+        dut = Arbiter(addr_width=30, data_width=32)
+        itor_1 = Interface(addr_width=30, data_width=32, features={"stall"})
+        dut.add(itor_1)
+        itor_2 = Interface(addr_width=30, data_width=32, features={"stall"})
+        dut.add(itor_2)
+
+        def sim_test():
+            yield itor_1.cyc.eq(1)
+            yield itor_2.cyc.eq(1)
+            yield Delay(1e-6)
+            self.assertEqual((yield itor_1.stall), 1)
+            self.assertEqual((yield itor_2.stall), 1)
+
+            yield dut.bus.ack.eq(1)
+            yield Delay(1e-6)
+            self.assertEqual((yield itor_1.stall), 0)
+            self.assertEqual((yield itor_2.stall), 1)
+
+        with Simulator(dut, vcd_file=open("test.vcd", "w")) as sim:
+            sim.add_process(sim_test())
+            sim.run()
+
+    def test_roundrobin(self):
+        dut = Arbiter(addr_width=30, data_width=32)
+        itor_1 = Interface(addr_width=30, data_width=32)
+        dut.add(itor_1)
+        itor_2 = Interface(addr_width=30, data_width=32)
+        dut.add(itor_2)
+        itor_3 = Interface(addr_width=30, data_width=32)
+        dut.add(itor_3)
+
+        def sim_test():
+            yield itor_1.cyc.eq(1)
+            yield itor_2.cyc.eq(0)
+            yield itor_3.cyc.eq(1)
+            yield dut.bus.ack.eq(1)
+            yield Delay(1e-7)
+            self.assertEqual((yield itor_1.ack), 1)
+            self.assertEqual((yield itor_2.ack), 0)
+            self.assertEqual((yield itor_3.ack), 0)
+
+            yield itor_1.cyc.eq(0)
+            yield itor_2.cyc.eq(0)
+            yield itor_3.cyc.eq(1)
+            yield Tick()
+            yield Delay(1e-7)
+            self.assertEqual((yield itor_1.ack), 0)
+            self.assertEqual((yield itor_2.ack), 0)
+            self.assertEqual((yield itor_3.ack), 1)
+
+            yield itor_1.cyc.eq(1)
+            yield itor_2.cyc.eq(1)
+            yield itor_3.cyc.eq(0)
+            yield Tick()
+            yield Delay(1e-7)
+            self.assertEqual((yield itor_1.ack), 1)
+            self.assertEqual((yield itor_2.ack), 0)
+            self.assertEqual((yield itor_3.ack), 0)
+
+            yield itor_1.cyc.eq(0)
+            yield itor_2.cyc.eq(1)
+            yield itor_3.cyc.eq(1)
+            yield Tick()
+            yield Delay(1e-7)
+            self.assertEqual((yield itor_1.ack), 0)
+            self.assertEqual((yield itor_2.ack), 1)
+            self.assertEqual((yield itor_3.ack), 0)
+
+            yield itor_1.cyc.eq(1)
+            yield itor_2.cyc.eq(0)
+            yield itor_3.cyc.eq(1)
+            yield Tick()
+            yield Delay(1e-7)
+            self.assertEqual((yield itor_1.ack), 0)
+            self.assertEqual((yield itor_2.ack), 0)
+            self.assertEqual((yield itor_3.ack), 1)
+
+        with Simulator(dut, vcd_file=open("test.vcd", "w")) as sim:
+            sim.add_clock(1e-6)
+            sim.add_sync_process(sim_test())
+            sim.run()
index 1899ff587b91f2871b28aa3e553764cb53427015..31a4a2b8e0a74ea64ac4ae696ade1ef59676f483 100644 (file)
@@ -6,7 +6,7 @@ from nmigen.utils import log2_int
 from ..memory import MemoryMap
 
 
-__all__ = ["CycleType", "BurstTypeExt", "Interface", "Decoder"]
+__all__ = ["CycleType", "BurstTypeExt", "Interface", "Decoder", "Arbiter"]
 
 
 class CycleType(Enum):
@@ -266,3 +266,120 @@ class Decoder(Elaboratable):
             m.d.comb += self.bus.stall.eq(stall_fanin)
 
         return m
+
+
+class Arbiter(Elaboratable):
+    """Wishbone bus arbiter.
+
+    A round-robin arbiter for initiators accessing a shared Wishbone bus.
+
+    Parameters
+    ----------
+    addr_width : int
+        Address width. See :class:`Interface`.
+    data_width : int
+        Data width. See :class:`Interface`.
+    granularity : int
+        Granularity. See :class:`Interface`
+    features : iter(str)
+        Optional signal set. See :class:`Interface`.
+
+    Attributes
+    ----------
+    bus : :class:`Interface`
+        Shared Wishbone bus.
+    """
+    def __init__(self, *, addr_width, data_width, granularity=None, features=frozenset()):
+        self.bus    = Interface(addr_width=addr_width, data_width=data_width,
+                                granularity=granularity, features=features)
+        self._itors = []
+
+    def add(self, itor_bus):
+        """Add an initiator bus to the arbiter.
+
+        The initiator bus must have the same address width and data width as the arbiter. The
+        granularity of the initiator bus must be greater than or equal to the granularity of
+        the arbiter.
+        """
+        if not isinstance(itor_bus, Interface):
+            raise TypeError("Initiator bus must be an instance of wishbone.Interface, not {!r}"
+                            .format(itor_bus))
+        if itor_bus.addr_width != self.bus.addr_width:
+            raise ValueError("Initiator bus has address width {}, which is not the same as "
+                             "arbiter address width {}"
+                             .format(itor_bus.addr_width, self.bus.addr_width))
+        if itor_bus.granularity < self.bus.granularity:
+            raise ValueError("Initiator bus has granularity {}, which is lesser than the "
+                             "arbiter granularity {}"
+                             .format(itor_bus.granularity, self.bus.granularity))
+        if itor_bus.data_width != self.bus.data_width:
+            raise ValueError("Initiator bus has data width {}, which is not the same as "
+                             "arbiter data width {}"
+                             .format(itor_bus.data_width, self.bus.data_width))
+        for opt_output in {"lock", "cti", "bte"}:
+            if hasattr(itor_bus, opt_output) and not hasattr(self.bus, opt_output):
+                raise ValueError("Initiator bus has optional output {!r}, but the arbiter "
+                                 "does not have a corresponding input"
+                                 .format(opt_output))
+
+        self._itors.append(itor_bus)
+
+    def elaborate(self, platform):
+        m = Module()
+
+        requests = Signal(len(self._itors))
+        grant    = Signal(range(len(self._itors)))
+        m.d.comb += requests.eq(Cat(itor_bus.cyc for itor_bus in self._itors))
+
+        bus_busy = self.bus.cyc
+        if hasattr(self.bus, "lock"):
+            # If LOCK is not asserted, we also wait for STB to be deasserted before granting bus
+            # ownership to the next initiator. If we didn't, the next bus owner could receive
+            # an ACK (or ERR, RTY) from the previous transaction when targeting the same
+            # peripheral.
+            bus_busy &= self.bus.lock | self.bus.stb
+
+        with m.If(~bus_busy):
+            with m.Switch(grant):
+                for i in range(len(requests)):
+                    with m.Case(i):
+                        for pred in reversed(range(i)):
+                            with m.If(requests[pred]):
+                                m.d.sync += grant.eq(pred)
+                        for succ in reversed(range(i + 1, len(requests))):
+                            with m.If(requests[succ]):
+                                m.d.sync += grant.eq(succ)
+
+        with m.Switch(grant):
+            for i, itor_bus in enumerate(self._itors):
+                m.d.comb += itor_bus.dat_r.eq(self.bus.dat_r)
+                if hasattr(itor_bus, "stall"):
+                    itor_bus_stall = Signal(reset=1)
+                    m.d.comb += itor_bus.stall.eq(itor_bus_stall)
+
+                with m.Case(i):
+                    ratio = itor_bus.granularity // self.bus.granularity
+                    m.d.comb += [
+                        self.bus.adr.eq(itor_bus.adr),
+                        self.bus.dat_w.eq(itor_bus.dat_w),
+                        self.bus.sel.eq(Cat(Repl(sel, ratio) for sel in itor_bus.sel)),
+                        self.bus.we.eq(itor_bus.we),
+                        self.bus.stb.eq(itor_bus.stb),
+                    ]
+                    m.d.comb += self.bus.cyc.eq(itor_bus.cyc)
+                    if hasattr(self.bus, "lock"):
+                        m.d.comb += self.bus.lock.eq(getattr(itor_bus, "lock", 1))
+                    if hasattr(self.bus, "cti"):
+                        m.d.comb += self.bus.cti.eq(getattr(itor_bus, "cti", CycleType.CLASSIC))
+                    if hasattr(self.bus, "bte"):
+                        m.d.comb += self.bus.bte.eq(getattr(itor_bus, "bte", BurstTypeExt.LINEAR))
+
+                    m.d.comb += itor_bus.ack.eq(self.bus.ack)
+                    if hasattr(itor_bus, "err"):
+                        m.d.comb += itor_bus.err.eq(getattr(self.bus, "err", 0))
+                    if hasattr(itor_bus, "rty"):
+                        m.d.comb += itor_bus.rty.eq(getattr(self.bus, "rty", 0))
+                    if hasattr(itor_bus, "stall"):
+                        m.d.comb += itor_bus_stall.eq(getattr(self.bus, "stall", ~self.bus.ack))
+
+        return m