add BetterMultiPriorityPicker and formal proof
authorJacob Lifshay <programmerjake@gmail.com>
Thu, 4 Aug 2022 05:41:28 +0000 (22:41 -0700)
committerJacob Lifshay <programmerjake@gmail.com>
Thu, 4 Aug 2022 05:41:28 +0000 (22:41 -0700)
src/nmutil/formal/test_picker.py
src/nmutil/picker.py

index 1abf56ff101101e444367c000b12ce64bdc488db..caaf0071b51b3971f35f26f01560d1c8381c2a28 100644 (file)
@@ -7,7 +7,9 @@ import unittest
 from nmigen.hdl.ast import AnyConst, Assert, Signal, Const, Array, Shape, Mux
 from nmigen.hdl.dsl import Module
 from nmutil.formaltest import FHDLTestCase
-from nmutil.picker import PriorityPicker, MultiPriorityPicker
+from nmutil.picker import (BetterMultiPriorityPicker, PriorityPicker,
+                           MultiPriorityPicker)
+from nmutil.sim_util import write_il
 
 
 class TestPriorityPicker(FHDLTestCase):
@@ -227,61 +229,65 @@ class TestPriorityPicker(FHDLTestCase):
 
 
 class TestMultiPriorityPicker(FHDLTestCase):
-    def tst(self, *, cls=MultiPriorityPicker, wid, levels, indices, multi_in):
-        assert isinstance(wid, int) and wid >= 1
-        assert isinstance(levels, int) and 1 <= levels <= wid
-        assert isinstance(indices, bool)
-        assert isinstance(multi_in, bool)
-        dut = cls(wid=wid, levels=levels, indices=indices, multi_in=multi_in)
-        self.assertEqual(wid, dut.wid)
+    def make_dut(self, width, levels, indices, multi_in):
+        dut = MultiPriorityPicker(wid=width, levels=levels, indices=indices,
+                                  multi_in=multi_in)
+        self.assertEqual(width, dut.wid)
         self.assertEqual(levels, dut.levels)
         self.assertEqual(indices, dut.indices)
         self.assertEqual(multi_in, dut.multi_in)
+        return dut
+
+    def tst(self, *, width, levels, indices, multi_in):
+        assert isinstance(width, int) and width >= 1
+        assert isinstance(levels, int) and 1 <= levels <= width
+        assert isinstance(indices, bool)
+        assert isinstance(multi_in, bool)
+        dut = self.make_dut(width=width, levels=levels, indices=indices,
+                            multi_in=multi_in)
         expected_ports = []
         if multi_in:
-            self.assertIsInstance(dut.i, Array)
+            self.assertIsInstance(dut.i, (Array, list))
             self.assertEqual(len(dut.i), levels)
             for i in dut.i:
                 self.assertIsInstance(i, Signal)
-                self.assertEqual(len(i), wid)
+                self.assertEqual(len(i), width)
                 expected_ports.append(i)
         else:
             self.assertIsInstance(dut.i, Signal)
-            self.assertEqual(len(dut.i), wid)
+            self.assertEqual(len(dut.i), width)
             expected_ports.append(dut.i)
 
-        self.assertIsInstance(dut.o, Array)
+        self.assertIsInstance(dut.o, (Array, list))
         self.assertEqual(len(dut.o), levels)
         for o in dut.o:
             self.assertIsInstance(o, Signal)
-            self.assertEqual(len(o), wid)
+            self.assertEqual(len(o), width)
             expected_ports.append(o)
 
         self.assertEqual(len(dut.en_o), levels)
         expected_ports.append(dut.en_o)
 
         if indices:
-            expected_idx_o_shape = Shape.cast(range(levels))
-            if levels <= 1:
-                expected_idx_o_shape = Shape(0, False)
-            self.assertIsInstance(dut.idx_o, Array)
+            self.assertIsInstance(dut.idx_o, (Array, list))
             self.assertEqual(len(dut.idx_o), levels)
             for idx_o in dut.idx_o:
                 self.assertIsInstance(idx_o, Signal)
-                self.assertEqual(idx_o.shape(), expected_idx_o_shape)
                 expected_ports.append(idx_o)
         else:
             self.assertFalse(hasattr(dut, "idx_o"))
 
         self.assertListEqual(expected_ports, dut.ports())
 
+        write_il(self, dut, ports=dut.ports())
+
         m = Module()
         m.submodules.dut = dut
         if multi_in:
             for i in dut.i:
-                m.d.comb += i.eq(AnyConst(wid))
+                m.d.comb += i.eq(AnyConst(width))
         else:
-            m.d.comb += dut.i.eq(AnyConst(wid))
+            m.d.comb += dut.i.eq(AnyConst(width))
 
         prev_set = 0
         for o, en_o in zip(dut.o, dut.en_o):
@@ -293,14 +299,14 @@ class TestMultiPriorityPicker(FHDLTestCase):
 
             m.d.comb += Assert((o != 0) == en_o)
 
-        prev_set = Const(0, wid)
-        priority_pickers = [PriorityPicker(wid) for _ in range(levels)]
+        prev_set = Const(0, width)
+        priority_pickers = [PriorityPicker(width) for _ in range(levels)]
         for level in range(levels):
             pp = priority_pickers[level]
             setattr(m.submodules, f"pp_{level}", pp)
             inp = dut.i[level] if multi_in else dut.i
             m.d.comb += pp.i.eq(inp & ~prev_set)
-            cur_set = Signal(wid, name=f"cur_set_{level}")
+            cur_set = Signal(width, name=f"cur_set_{level}")
             m.d.comb += cur_set.eq(prev_set | pp.o)
             prev_set = cur_set
             m.d.comb += Assert(pp.o == dut.o[level])
@@ -314,148 +320,165 @@ class TestMultiPriorityPicker(FHDLTestCase):
         self.assertFormal(m)
 
     def test_4_levels_1_idxs_f_mi_f(self):
-        self.tst(wid=4, levels=1, indices=False, multi_in=False)
+        self.tst(width=4, levels=1, indices=False, multi_in=False)
 
     def test_4_levels_1_idxs_f_mi_t(self):
-        self.tst(wid=4, levels=1, indices=False, multi_in=True)
+        self.tst(width=4, levels=1, indices=False, multi_in=True)
 
     def test_4_levels_1_idxs_t_mi_f(self):
-        self.tst(wid=4, levels=1, indices=True, multi_in=False)
+        self.tst(width=4, levels=1, indices=True, multi_in=False)
 
     def test_4_levels_1_idxs_t_mi_t(self):
-        self.tst(wid=4, levels=1, indices=True, multi_in=True)
+        self.tst(width=4, levels=1, indices=True, multi_in=True)
 
     def test_4_levels_2_idxs_f_mi_f(self):
-        self.tst(wid=4, levels=2, indices=False, multi_in=False)
+        self.tst(width=4, levels=2, indices=False, multi_in=False)
 
     def test_4_levels_2_idxs_f_mi_t(self):
-        self.tst(wid=4, levels=2, indices=False, multi_in=True)
+        self.tst(width=4, levels=2, indices=False, multi_in=True)
 
     def test_4_levels_2_idxs_t_mi_f(self):
-        self.tst(wid=4, levels=2, indices=True, multi_in=False)
+        self.tst(width=4, levels=2, indices=True, multi_in=False)
 
     def test_4_levels_2_idxs_t_mi_t(self):
-        self.tst(wid=4, levels=2, indices=True, multi_in=True)
+        self.tst(width=4, levels=2, indices=True, multi_in=True)
 
     def test_4_levels_3_idxs_f_mi_f(self):
-        self.tst(wid=4, levels=3, indices=False, multi_in=False)
+        self.tst(width=4, levels=3, indices=False, multi_in=False)
 
     def test_4_levels_3_idxs_f_mi_t(self):
-        self.tst(wid=4, levels=3, indices=False, multi_in=True)
+        self.tst(width=4, levels=3, indices=False, multi_in=True)
 
     def test_4_levels_3_idxs_t_mi_f(self):
-        self.tst(wid=4, levels=3, indices=True, multi_in=False)
+        self.tst(width=4, levels=3, indices=True, multi_in=False)
 
     def test_4_levels_3_idxs_t_mi_t(self):
-        self.tst(wid=4, levels=3, indices=True, multi_in=True)
+        self.tst(width=4, levels=3, indices=True, multi_in=True)
 
     def test_4_levels_4_idxs_f_mi_f(self):
-        self.tst(wid=4, levels=4, indices=False, multi_in=False)
+        self.tst(width=4, levels=4, indices=False, multi_in=False)
 
     def test_4_levels_4_idxs_f_mi_t(self):
-        self.tst(wid=4, levels=4, indices=False, multi_in=True)
+        self.tst(width=4, levels=4, indices=False, multi_in=True)
 
     def test_4_levels_4_idxs_t_mi_f(self):
-        self.tst(wid=4, levels=4, indices=True, multi_in=False)
+        self.tst(width=4, levels=4, indices=True, multi_in=False)
 
     def test_4_levels_4_idxs_t_mi_t(self):
-        self.tst(wid=4, levels=4, indices=True, multi_in=True)
+        self.tst(width=4, levels=4, indices=True, multi_in=True)
 
     def test_8_levels_1_idxs_f_mi_f(self):
-        self.tst(wid=8, levels=1, indices=False, multi_in=False)
+        self.tst(width=8, levels=1, indices=False, multi_in=False)
 
     def test_8_levels_1_idxs_f_mi_t(self):
-        self.tst(wid=8, levels=1, indices=False, multi_in=True)
+        self.tst(width=8, levels=1, indices=False, multi_in=True)
 
     def test_8_levels_1_idxs_t_mi_f(self):
-        self.tst(wid=8, levels=1, indices=True, multi_in=False)
+        self.tst(width=8, levels=1, indices=True, multi_in=False)
 
     def test_8_levels_1_idxs_t_mi_t(self):
-        self.tst(wid=8, levels=1, indices=True, multi_in=True)
+        self.tst(width=8, levels=1, indices=True, multi_in=True)
 
     def test_8_levels_2_idxs_f_mi_f(self):
-        self.tst(wid=8, levels=2, indices=False, multi_in=False)
+        self.tst(width=8, levels=2, indices=False, multi_in=False)
 
     def test_8_levels_2_idxs_f_mi_t(self):
-        self.tst(wid=8, levels=2, indices=False, multi_in=True)
+        self.tst(width=8, levels=2, indices=False, multi_in=True)
 
     def test_8_levels_2_idxs_t_mi_f(self):
-        self.tst(wid=8, levels=2, indices=True, multi_in=False)
+        self.tst(width=8, levels=2, indices=True, multi_in=False)
 
     def test_8_levels_2_idxs_t_mi_t(self):
-        self.tst(wid=8, levels=2, indices=True, multi_in=True)
+        self.tst(width=8, levels=2, indices=True, multi_in=True)
 
     def test_8_levels_3_idxs_f_mi_f(self):
-        self.tst(wid=8, levels=3, indices=False, multi_in=False)
+        self.tst(width=8, levels=3, indices=False, multi_in=False)
 
     def test_8_levels_3_idxs_f_mi_t(self):
-        self.tst(wid=8, levels=3, indices=False, multi_in=True)
+        self.tst(width=8, levels=3, indices=False, multi_in=True)
 
     def test_8_levels_3_idxs_t_mi_f(self):
-        self.tst(wid=8, levels=3, indices=True, multi_in=False)
+        self.tst(width=8, levels=3, indices=True, multi_in=False)
 
     def test_8_levels_3_idxs_t_mi_t(self):
-        self.tst(wid=8, levels=3, indices=True, multi_in=True)
+        self.tst(width=8, levels=3, indices=True, multi_in=True)
 
     def test_8_levels_4_idxs_f_mi_f(self):
-        self.tst(wid=8, levels=4, indices=False, multi_in=False)
+        self.tst(width=8, levels=4, indices=False, multi_in=False)
 
     def test_8_levels_4_idxs_f_mi_t(self):
-        self.tst(wid=8, levels=4, indices=False, multi_in=True)
+        self.tst(width=8, levels=4, indices=False, multi_in=True)
 
     def test_8_levels_4_idxs_t_mi_f(self):
-        self.tst(wid=8, levels=4, indices=True, multi_in=False)
+        self.tst(width=8, levels=4, indices=True, multi_in=False)
 
     def test_8_levels_4_idxs_t_mi_t(self):
-        self.tst(wid=8, levels=4, indices=True, multi_in=True)
+        self.tst(width=8, levels=4, indices=True, multi_in=True)
 
     def test_8_levels_5_idxs_f_mi_f(self):
-        self.tst(wid=8, levels=5, indices=False, multi_in=False)
+        self.tst(width=8, levels=5, indices=False, multi_in=False)
 
     def test_8_levels_5_idxs_f_mi_t(self):
-        self.tst(wid=8, levels=5, indices=False, multi_in=True)
+        self.tst(width=8, levels=5, indices=False, multi_in=True)
 
     def test_8_levels_5_idxs_t_mi_f(self):
-        self.tst(wid=8, levels=5, indices=True, multi_in=False)
+        self.tst(width=8, levels=5, indices=True, multi_in=False)
 
     def test_8_levels_5_idxs_t_mi_t(self):
-        self.tst(wid=8, levels=5, indices=True, multi_in=True)
+        self.tst(width=8, levels=5, indices=True, multi_in=True)
 
     def test_8_levels_6_idxs_f_mi_f(self):
-        self.tst(wid=8, levels=6, indices=False, multi_in=False)
+        self.tst(width=8, levels=6, indices=False, multi_in=False)
 
     def test_8_levels_6_idxs_f_mi_t(self):
-        self.tst(wid=8, levels=6, indices=False, multi_in=True)
+        self.tst(width=8, levels=6, indices=False, multi_in=True)
 
     def test_8_levels_6_idxs_t_mi_f(self):
-        self.tst(wid=8, levels=6, indices=True, multi_in=False)
+        self.tst(width=8, levels=6, indices=True, multi_in=False)
 
     def test_8_levels_6_idxs_t_mi_t(self):
-        self.tst(wid=8, levels=6, indices=True, multi_in=True)
+        self.tst(width=8, levels=6, indices=True, multi_in=True)
 
     def test_8_levels_7_idxs_f_mi_f(self):
-        self.tst(wid=8, levels=7, indices=False, multi_in=False)
+        self.tst(width=8, levels=7, indices=False, multi_in=False)
 
     def test_8_levels_7_idxs_f_mi_t(self):
-        self.tst(wid=8, levels=7, indices=False, multi_in=True)
+        self.tst(width=8, levels=7, indices=False, multi_in=True)
 
     def test_8_levels_7_idxs_t_mi_f(self):
-        self.tst(wid=8, levels=7, indices=True, multi_in=False)
+        self.tst(width=8, levels=7, indices=True, multi_in=False)
 
     def test_8_levels_7_idxs_t_mi_t(self):
-        self.tst(wid=8, levels=7, indices=True, multi_in=True)
+        self.tst(width=8, levels=7, indices=True, multi_in=True)
 
     def test_8_levels_8_idxs_f_mi_f(self):
-        self.tst(wid=8, levels=8, indices=False, multi_in=False)
+        self.tst(width=8, levels=8, indices=False, multi_in=False)
 
     def test_8_levels_8_idxs_f_mi_t(self):
-        self.tst(wid=8, levels=8, indices=False, multi_in=True)
+        self.tst(width=8, levels=8, indices=False, multi_in=True)
 
     def test_8_levels_8_idxs_t_mi_f(self):
-        self.tst(wid=8, levels=8, indices=True, multi_in=False)
+        self.tst(width=8, levels=8, indices=True, multi_in=False)
 
     def test_8_levels_8_idxs_t_mi_t(self):
-        self.tst(wid=8, levels=8, indices=True, multi_in=True)
+        self.tst(width=8, levels=8, indices=True, multi_in=True)
+
+    def test_16_levels_16_idxs_f_mi_f(self):
+        self.tst(width=16, levels=16, indices=False, multi_in=False)
+
+
+class TestBetterMultiPriorityPicker(TestMultiPriorityPicker):
+    def make_dut(self, width, levels, indices, multi_in):
+        if multi_in:
+            self.skipTest(
+                "multi_in are not supported by BetterMultiPriorityPicker")
+        if indices:
+            self.skipTest(
+                "indices are not supported by BetterMultiPriorityPicker")
+        dut = BetterMultiPriorityPicker(width=width, levels=levels)
+        self.assertEqual(width, dut.width)
+        self.assertEqual(levels, dut.levels)
+        return dut
 
 
 if __name__ == "__main__":
index ab56741d0d5c5ff56324d8b818e746729f6df5e7..7aab175d8d60c031020b7b208727156123b8c09f 100644 (file)
 """
 
 from nmigen import Module, Signal, Cat, Elaboratable, Array, Const, Mux
+from nmigen.utils import bits_for
 from nmigen.cli import rtlil
 import math
+from nmutil.prefix_sum import prefix_sum
 
 
 class PriorityPicker(Elaboratable):
@@ -211,6 +213,64 @@ class MultiPriorityPicker(Elaboratable):
         return list(self)
 
 
+class BetterMultiPriorityPicker(Elaboratable):
+    """A better replacement for MultiPriorityPicker that has O(log levels)
+        latency, rather than > O(levels) latency.
+    """
+
+    def __init__(self, width, levels, *, work_efficient=False):
+        assert isinstance(width, int) and width >= 1
+        assert isinstance(levels, int) and 1 <= levels <= width
+        assert isinstance(work_efficient, bool)
+        self.width = width
+        self.levels = levels
+        self.work_efficient = work_efficient
+        assert self.__index_sat > self.levels - 1
+        self.i = Signal(width)
+        self.o = [Signal(width, name=f"o_{i}") for i in range(levels)]
+        self.en_o = Signal(levels)
+
+    @property
+    def __index_width(self):
+        return bits_for(self.levels)
+
+    @property
+    def __index_sat(self):
+        return (1 << self.__index_width) - 1
+
+    def elaborate(self, platform):
+        m = Module()
+
+        def sat_add(a, b):
+            sum = Signal(self.__index_width + 1)
+            m.d.comb += sum.eq(a + b)
+            retval = Signal(self.__index_width)
+            m.d.comb += retval.eq(Mux(sum[-1], self.__index_sat, sum))
+            return retval
+        indexes = prefix_sum((self.i[i] for i in range(self.width - 1)),
+                             sat_add, work_efficient=self.work_efficient)
+        indexes.insert(0, 0)
+        for i in range(self.width):
+            sig = Signal(self.__index_width, name=f"index_{i}")
+            m.d.comb += sig.eq(indexes[i])
+            indexes[i] = sig
+        for level in range(self.levels):
+            m.d.comb += self.en_o[level].eq(self.o[level].bool())
+            for i in range(self.width):
+                index_matches = indexes[i] == level
+                m.d.comb += self.o[level][i].eq(index_matches & self.i[i])
+
+        return m
+
+    def __iter__(self):
+        yield self.i
+        yield from self.o
+        yield self.en_o
+
+    def ports(self):
+        return list(self)
+
+
 if __name__ == '__main__':
     dut = PriorityPicker(16)
     vl = rtlil.convert(dut, ports=dut.ports())