speed up ==, hash, <, >, <=, and >= for plain_data
[nmutil.git] / src / nmutil / picker.py
index ab56741d0d5c5ff56324d8b818e746729f6df5e7..7aab175d8d60c031020b7b208727156123b8c09f 100644 (file)
 """
 
 from nmigen import Module, Signal, Cat, Elaboratable, Array, Const, Mux
+from nmigen.utils import bits_for
 from nmigen.cli import rtlil
 import math
+from nmutil.prefix_sum import prefix_sum
 
 
 class PriorityPicker(Elaboratable):
@@ -211,6 +213,64 @@ class MultiPriorityPicker(Elaboratable):
         return list(self)
 
 
+class BetterMultiPriorityPicker(Elaboratable):
+    """A better replacement for MultiPriorityPicker that has O(log levels)
+        latency, rather than > O(levels) latency.
+    """
+
+    def __init__(self, width, levels, *, work_efficient=False):
+        assert isinstance(width, int) and width >= 1
+        assert isinstance(levels, int) and 1 <= levels <= width
+        assert isinstance(work_efficient, bool)
+        self.width = width
+        self.levels = levels
+        self.work_efficient = work_efficient
+        assert self.__index_sat > self.levels - 1
+        self.i = Signal(width)
+        self.o = [Signal(width, name=f"o_{i}") for i in range(levels)]
+        self.en_o = Signal(levels)
+
+    @property
+    def __index_width(self):
+        return bits_for(self.levels)
+
+    @property
+    def __index_sat(self):
+        return (1 << self.__index_width) - 1
+
+    def elaborate(self, platform):
+        m = Module()
+
+        def sat_add(a, b):
+            sum = Signal(self.__index_width + 1)
+            m.d.comb += sum.eq(a + b)
+            retval = Signal(self.__index_width)
+            m.d.comb += retval.eq(Mux(sum[-1], self.__index_sat, sum))
+            return retval
+        indexes = prefix_sum((self.i[i] for i in range(self.width - 1)),
+                             sat_add, work_efficient=self.work_efficient)
+        indexes.insert(0, 0)
+        for i in range(self.width):
+            sig = Signal(self.__index_width, name=f"index_{i}")
+            m.d.comb += sig.eq(indexes[i])
+            indexes[i] = sig
+        for level in range(self.levels):
+            m.d.comb += self.en_o[level].eq(self.o[level].bool())
+            for i in range(self.width):
+                index_matches = indexes[i] == level
+                m.d.comb += self.o[level][i].eq(index_matches & self.i[i])
+
+        return m
+
+    def __iter__(self):
+        yield self.i
+        yield from self.o
+        yield self.en_o
+
+    def ports(self):
+        return list(self)
+
+
 if __name__ == '__main__':
     dut = PriorityPicker(16)
     vl = rtlil.convert(dut, ports=dut.ports())