rewrite TreeBitwiseLut to actually use a tree rather than a dict, hopefully making...
authorJacob Lifshay <programmerjake@gmail.com>
Wed, 22 Dec 2021 04:02:12 +0000 (20:02 -0800)
committerJacob Lifshay <programmerjake@gmail.com>
Wed, 22 Dec 2021 04:02:12 +0000 (20:02 -0800)
src/nmutil/lut.py
src/nmutil/test/test_lut.py

index 4eb790d4e11c843aec4e4678735b15fd9f777e05..49c19a876983f9f150251c2c2254b8874fb56aac 100644 (file)
@@ -16,6 +16,7 @@ from nmigen.hdl.ast import Array, Cat, Repl, Signal
 from nmigen.hdl.dsl import Module
 from nmigen.hdl.ir import Elaboratable
 from nmigen.cli import rtlil
+from dataclasses import dataclass
 
 
 class BitwiseMux(Elaboratable):
@@ -82,11 +83,59 @@ class BitwiseLut(Elaboratable):
         return list(self.inputs) + [self.lut, self.output]
 
 
+@dataclass
+class _TreeMuxNode:
+    """Mux in tree for `TreeBitwiseLut`."""
+    out: Signal
+    container: "TreeBitwiseLut"
+    parent: "_TreeMuxNode | None"
+    child0: "_TreeMuxNode | None"
+    child1: "_TreeMuxNode | None"
+    depth: int
+
+    @property
+    def child_index(self):
+        """index of this node, when looked up in this node's parent's children.
+        """
+        if self.parent is None:
+            return None
+        return int(self.parent.child1 is self)
+
+    def add_child(self, child_index):
+        node = _TreeMuxNode(
+            out=Signal(self.container.width),
+            container=self.container, parent=self,
+            child0=None, child1=None, depth=1 + self.depth)
+        if child_index:
+            assert self.child1 is None
+            self.child1 = node
+        else:
+            assert self.child0 is None
+            self.child0 = node
+        node.out.name = "node_out_" + node.key_str
+        return node
+
+    @property
+    def key(self):
+        retval = []
+        node = self
+        while node.parent is not None:
+            retval.append(node.child_index)
+            node = node.parent
+        retval.reverse()
+        return retval
+
+    @property
+    def key_str(self):
+        k = ['x'] * self.container.input_count
+        for i, v in enumerate(self.key):
+            k[i] = '1' if v else '0'
+        return '0b' + ''.join(reversed(k))
+
+
 class TreeBitwiseLut(Elaboratable):
-    """Tree-based version of BitwiseLut. See BitwiseLut for API documentation.
-    (good enough reason to say "see bitwiselut", but mention that
-    the API is identical and explain why the second implementation
-    exists, despite it being identical)
+    """Tree-based version of BitwiseLut. Has identical API, so see `BitwiseLut`
+    for API documentation. This version may produce more efficient hardware.
     """
 
     def __init__(self, input_count, width):
@@ -98,50 +147,39 @@ class TreeBitwiseLut(Elaboratable):
         self.inputs = tuple(inp(i) for i in range(input_count))
         self.output = Signal(width)
         self.lut = Signal(2 ** input_count)
-        self._mux_inputs = {}
-        self._build_mux_inputs()
-
-    def _make_key_str(self, *sel_values):
-        k = ['x'] * self.input_count
-        for i, v in enumerate(sel_values):
-            k[i] = '1' if v else '0'
-        return '0b' + ''.join(reversed(k))
-
-    def _build_mux_inputs(self, *sel_values):
-        # XXX yyyeah using PHP-style functions-in-text... blech :)
-        # XXX replace with name = mux_input_%s" % self._make_etcetc
-        name = f"mux_input_{self._make_key_str(*sel_values)}"
-        self._mux_inputs[sel_values] = Signal(self.width, name=name)
-        if len(sel_values) < self.input_count:
-            self._build_mux_inputs(*sel_values, False)
-            self._build_mux_inputs(*sel_values, True)
+        self._tree_root = _TreeMuxNode(
+            out=self.output, container=self, parent=None,
+            child0=None, child1=None, depth=0)
+        self._build_tree(self._tree_root)
+
+    def _build_tree(self, node):
+        if node.depth < self.input_count:
+            self._build_tree(node.add_child(0))
+            self._build_tree(node.add_child(1))
+
+    def _elaborate_tree(self, m, node):
+        if node.depth < self.input_count:
+            mux = BitwiseMux(self.width)
+            setattr(m.submodules, "mux_" + node.key_str, mux)
+            m.d.comb += [
+                mux.f.eq(node.child0.out),
+                mux.t.eq(node.child1.out),
+                mux.sel.eq(self.inputs[node.depth]),
+                node.out.eq(mux.output),
+            ]
+            self._elaborate_tree(m, node.child0)
+            self._elaborate_tree(m, node.child1)
+        else:
+            index = int(node.key_str, base=2)
+            m.d.comb += node.out.eq(Repl(self.lut[index], self.width))
 
     def elaborate(self, platform):
         m = Module()
-        m.d.comb += self.output.eq(self._mux_inputs[()])
-        for sel_values, v in self._mux_inputs.items():
-            if len(sel_values) < self.input_count:
-                # XXX yyyeah using PHP-style functions-in-text... blech :)
-                # XXX replace with name = mux_input_%s" % self._make_etcetc
-                mux_name = f"mux_{self._make_key_str(*sel_values)}"
-                mux = BitwiseMux(self.width)
-                setattr(m.submodules, mux_name, mux)
-                m.d.comb += [
-                    mux.f.eq(self._mux_inputs[(*sel_values, False)]),
-                    mux.t.eq(self._mux_inputs[(*sel_values, True)]),
-                    mux.sel.eq(self.inputs[len(sel_values)]),
-                    v.eq(mux.output),
-                ]
-            else:
-                lut_index = 0
-                for i in range(self.input_count):
-                    if sel_values[i]:
-                        lut_index |= 2 ** i
-                m.d.comb += v.eq(Repl(self.lut[lut_index], self.width))
+        self._elaborate_tree(m, self._tree_root)
         return m
 
     def ports(self):
-        return [self.input, self.chunk_sizes, self.output]
+        return [*self.inputs, self.lut, self.output]
 
 
 # useful to see what is going on:
index cdfaa2db1a0bf2d063688bff14af90672ac2ca6e..e0a98099460ded8912299b05c513dc0f924005d7 100644 (file)
@@ -65,25 +65,6 @@ class TestBitwiseLut(FHDLTestCase):
         dut = cls(3, 16)
         mask = 2 ** dut.width - 1
         lut_mask = 2 ** dut.lut.width - 1
-        if cls is TreeBitwiseLut:
-            mux_inputs = {k: s.name for k, s in dut._mux_inputs.items()}
-            self.assertEqual(mux_inputs, {
-                (): 'mux_input_0bxxx',
-                (False,): 'mux_input_0bxx0',
-                (False, False): 'mux_input_0bx00',
-                (False, False, False): 'mux_input_0b000',
-                (False, False, True): 'mux_input_0b100',
-                (False, True): 'mux_input_0bx10',
-                (False, True, False): 'mux_input_0b010',
-                (False, True, True): 'mux_input_0b110',
-                (True,): 'mux_input_0bxx1',
-                (True, False): 'mux_input_0bx01',
-                (True, False, False): 'mux_input_0b001',
-                (True, False, True): 'mux_input_0b101',
-                (True, True): 'mux_input_0bx11',
-                (True, True, False): 'mux_input_0b011',
-                (True, True, True): 'mux_input_0b111'
-            })
 
         def case(in0, in1, in2, lut):
             expected = 0
@@ -109,6 +90,10 @@ class TestBitwiseLut(FHDLTestCase):
                     self.assertEqual(expected, output)
 
         def process():
+            for shift in range(dut.lut.width):
+                with self.subTest(shift=shift):
+                    yield from case(in0=0xAAAA, in1=0xCCCC, in2=0xF0F0,
+                                    lut=1 << shift)
             for case_index in range(100):
                 with self.subTest(case_index=case_index):
                     in0 = hash_256(f"{case_index} in0") & mask