speed up ==, hash, <, >, <=, and >= for plain_data
[nmutil.git] / src / nmutil / picker.py
index 844e7b0a4d80089082708e1427ff890d584b7c29..7aab175d8d60c031020b7b208727156123b8c09f 100644 (file)
@@ -1,44 +1,94 @@
-""" Priority Picker: optimised back-to-back PriorityEncoder and Decoder
+# SPDX-License-Identifier: LGPL-3-or-later
+""" Priority Picker: optimized back-to-back PriorityEncoder and Decoder
     and MultiPriorityPicker: cascading mutually-exclusive pickers
 
-    The input is N bits, the output is N bits wide and only one is
-    enabled.
+    This work is funded through NLnet under Grant 2019-02-012
+
+    License: LGPLv3+
+
+
+    PriorityPicker: the input is N bits, the output is N bits wide and
+    only one is enabled.
+
+    MultiPriorityPicker: likewise except that there are M pickers and
+    each output is guaranteed mutually exclusive.  Optionally:
+    an "index" (and enable line) is also outputted.
+
+    MultiPriorityPicker is designed for port-selection, when there are
+    multiple "things" (of width N) contending for access to M "ports".
+    When the M=0 "thing" requests a port, it gets allocated port 0
+    (always).  However if the M=0 "thing" does *not* request a port,
+    this gives the M=1 "thing" the opportunity to gain access to port 0.
+
+    Given that N may potentially be much greater than M (16 bits wide
+    where M may be e.g. only 4) we can't just ok, "ok so M=N therefore
+    M=0 gets access to port 0, M=1 gets access to port 1" etc.
 """
 
-from nmigen import Module, Signal, Cat, Elaboratable, Array, Const
-from nmigen.cli import verilog, rtlil
+from nmigen import Module, Signal, Cat, Elaboratable, Array, Const, Mux
+from nmigen.utils import bits_for
+from nmigen.cli import rtlil
+import math
+from nmutil.prefix_sum import prefix_sum
+
 
 class PriorityPicker(Elaboratable):
     """ implements a priority-picker.  input: N bits, output: N bits
+
+        * msb_mode is for a MSB-priority picker
+        * reverse_i=True is for convenient reversal of the input bits
+        * reverse_o=True is for convenient reversal of the output bits
+        * `msb_mode=True` is redundant with `reverse_i=True, reverse_o=True`
+            but is allowed for backwards compatibility.
     """
-    def __init__(self, wid):
+
+    def __init__(self, wid, msb_mode=False, reverse_i=False, reverse_o=False):
         self.wid = wid
         # inputs
+        self.msb_mode = msb_mode
+        self.reverse_i = reverse_i
+        self.reverse_o = reverse_o
         self.i = Signal(wid, reset_less=True)
         self.o = Signal(wid, reset_less=True)
 
+        self.en_o = Signal(reset_less=True)
+        "true if any output is true"
+
     def elaborate(self, platform):
         m = Module()
 
+        # works by saying, "if all previous bits were zero, we get a chance"
         res = []
-        ni = Signal(self.wid, reset_less = True)
-        m.d.comb += ni.eq(~self.i)
-        for i in range(0, self.wid):
-            t = Signal(reset_less = True)
+        ni = Signal(self.wid, reset_less=True)
+        i = list(self.i)
+        if self.reverse_i:
+            i.reverse()
+        if self.msb_mode:
+            i.reverse()
+        m.d.comb += ni.eq(~Cat(*i))
+        prange = list(range(0, self.wid))
+        if self.msb_mode:
+            prange.reverse()
+        for n in prange:
+            t = Signal(name="t%d" % n, reset_less=True)
             res.append(t)
-            if i == 0:
-                m.d.comb += t.eq(self.i[i])
+            if n == 0:
+                m.d.comb += t.eq(i[n])
             else:
-                m.d.comb += t.eq(~Cat(ni[i], *self.i[:i]).bool())
-
+                m.d.comb += t.eq(~Cat(ni[n], *i[:n]).bool())
+        if self.reverse_o:
+            res.reverse()
         # we like Cat(*xxx).  turn lists into concatenated bits
         m.d.comb += self.o.eq(Cat(*res))
+        # useful "is any output enabled" signal
+        m.d.comb += self.en_o.eq(self.o.bool())  # true if 1 input is true
 
         return m
 
     def __iter__(self):
         yield self.i
         yield self.o
+        yield self.en_o
 
     def ports(self):
         return list(self)
@@ -52,31 +102,67 @@ class MultiPriorityPicker(Elaboratable):
         gets top priority, the second cannot have the same bit that
         the first has set, and so on.  To do this, a "mask" accumulates
         the output from the chain, masking the input to the next chain.
+
+        Also outputted (optional): an index for each picked "thing".
     """
-    def __init__(self, wid, levels):
+
+    def __init__(self, wid, levels, indices=False, multi_in=False):
         self.levels = levels
         self.wid = wid
+        self.indices = indices
+        self.multi_in = multi_in
 
-        self.i = [] # store the array of picker inputs
-        self.o = [] # store the array of picker outputs
+        if multi_in:
+            # multiple inputs, multiple outputs.
+            i_l = []  # array of picker outputs
+            for j in range(self.levels):
+                i = Signal(self.wid, name="i_%d" % j, reset_less=True)
+                i_l.append(i)
+            self.i = Array(i_l)
+        else:
+            # only the one input, but multiple (single) bit outputs
+            self.i = Signal(self.wid, reset_less=True)
 
+        # create array of (single-bit) outputs (unary)
+        o_l = []  # array of picker outputs
         for j in range(self.levels):
-            i = Signal(self.wid, name="i_%d" % j, reset_less=True)
             o = Signal(self.wid, name="o_%d" % j, reset_less=True)
-            self.i.append(i)
-            self.o.append(o)
-        self.i = Array(self.i)
-        self.o = Array(self.o)
+            o_l.append(o)
+        self.o = Array(o_l)
+
+        # add an array of "enables"
+        self.en_o = Signal(self.levels, name="en_o", reset_less=True)
+
+        if not self.indices:
+            return
+
+        # add an array of indices
+        lidx = math.ceil(math.log2(self.levels))
+        idx_o = []  # store the array of indices
+        for j in range(self.levels):
+            i = Signal(lidx, name="idxo_%d" % j, reset_less=True)
+            idx_o.append(i)
+        self.idx_o = Array(idx_o)
 
     def elaborate(self, platform):
         m = Module()
         comb = m.d.comb
 
+        # create Priority Pickers, accumulate their outputs and prevent
+        # the next one in the chain from selecting that output bit.
+        # the input from the current picker will be "masked" and connected
+        # to the *next* picker on the next loop
         prev_pp = None
         p_mask = None
+        pp_l = []
         for j in range(self.levels):
-            o, i = self.o[j], self.i[j]
+            if self.multi_in:
+                i = self.i[j]
+            else:
+                i = self.i
+            o = self.o[j]
             pp = PriorityPicker(self.wid)
+            pp_l.append(pp)
             setattr(m.submodules, "pp%d" % j, pp)
             comb += o.eq(pp.o)
             if prev_pp is None:
@@ -84,23 +170,117 @@ class MultiPriorityPicker(Elaboratable):
                 p_mask = Const(0, self.wid)
             else:
                 mask = Signal(self.wid, name="m_%d" % j, reset_less=True)
-                comb += mask.eq(prev_pp.o | p_mask) # accumulate output bits
+                comb += mask.eq(prev_pp.o | p_mask)  # accumulate output bits
                 comb += pp.i.eq(i & ~mask)          # mask out input
                 p_mask = mask
+            i = pp.i  # for input to next round
             prev_pp = pp
 
+        # accumulate the enables
+        en_l = []
+        for j in range(self.levels):
+            en_l.append(pp_l[j].en_o)
+        # concat accumulated enable bits
+        comb += self.en_o.eq(Cat(*en_l))
+
+        if not self.indices:
+            return m
+
+        # for each picker enabled, pass that out and set a cascading index
+        lidx = math.ceil(math.log2(self.levels))
+        prev_count = 0
+        for j in range(self.levels):
+            en_o = pp_l[j].en_o
+            count1 = Signal(lidx, name="count_%d" % j, reset_less=True)
+            comb += count1.eq(prev_count + Const(1, lidx))
+            comb += self.idx_o[j].eq(prev_count)
+            prev_count = Mux(en_o, count1, prev_count)
+
         return m
 
     def __iter__(self):
-        yield from self.i
+        if self.multi_in:
+            yield from self.i
+        else:
+            yield self.i
         yield from self.o
+        yield self.en_o
+        if not self.indices:
+            return
+        yield from self.idx_o
+
+    def ports(self):
+        return list(self)
+
+
+class BetterMultiPriorityPicker(Elaboratable):
+    """A better replacement for MultiPriorityPicker that has O(log levels)
+        latency, rather than > O(levels) latency.
+    """
+
+    def __init__(self, width, levels, *, work_efficient=False):
+        assert isinstance(width, int) and width >= 1
+        assert isinstance(levels, int) and 1 <= levels <= width
+        assert isinstance(work_efficient, bool)
+        self.width = width
+        self.levels = levels
+        self.work_efficient = work_efficient
+        assert self.__index_sat > self.levels - 1
+        self.i = Signal(width)
+        self.o = [Signal(width, name=f"o_{i}") for i in range(levels)]
+        self.en_o = Signal(levels)
+
+    @property
+    def __index_width(self):
+        return bits_for(self.levels)
+
+    @property
+    def __index_sat(self):
+        return (1 << self.__index_width) - 1
+
+    def elaborate(self, platform):
+        m = Module()
+
+        def sat_add(a, b):
+            sum = Signal(self.__index_width + 1)
+            m.d.comb += sum.eq(a + b)
+            retval = Signal(self.__index_width)
+            m.d.comb += retval.eq(Mux(sum[-1], self.__index_sat, sum))
+            return retval
+        indexes = prefix_sum((self.i[i] for i in range(self.width - 1)),
+                             sat_add, work_efficient=self.work_efficient)
+        indexes.insert(0, 0)
+        for i in range(self.width):
+            sig = Signal(self.__index_width, name=f"index_{i}")
+            m.d.comb += sig.eq(indexes[i])
+            indexes[i] = sig
+        for level in range(self.levels):
+            m.d.comb += self.en_o[level].eq(self.o[level].bool())
+            for i in range(self.width):
+                index_matches = indexes[i] == level
+                m.d.comb += self.o[level][i].eq(index_matches & self.i[i])
+
+        return m
+
+    def __iter__(self):
+        yield self.i
+        yield from self.o
+        yield self.en_o
 
     def ports(self):
         return list(self)
 
 
 if __name__ == '__main__':
-    dut = MultiPriorityPicker(5, 4)
+    dut = PriorityPicker(16)
+    vl = rtlil.convert(dut, ports=dut.ports())
+    with open("test_picker.il", "w") as f:
+        f.write(vl)
+    dut = MultiPriorityPicker(5, 4, True)
     vl = rtlil.convert(dut, ports=dut.ports())
     with open("test_multi_picker.il", "w") as f:
         f.write(vl)
+    dut = MultiPriorityPicker(5, 4, False, True)
+    vl = rtlil.convert(dut, ports=dut.ports())
+    with open("test_multi_picker_noidx.il", "w") as f:
+        f.write(vl)