rewrite TreeBitwiseLut to actually use a tree rather than a dict, hopefully making...
authorJacob Lifshay <programmerjake@gmail.com>
Wed, 22 Dec 2021 04:02:12 +0000 (20:02 -0800)
committerJacob Lifshay <programmerjake@gmail.com>
Wed, 22 Dec 2021 04:02:12 +0000 (20:02 -0800)
src/nmutil/lut.py
src/nmutil/test/test_lut.py

index 4eb790d4e11c843aec4e4678735b15fd9f777e05..49c19a876983f9f150251c2c2254b8874fb56aac 100644 (file)
@@ -16,6 +16,7 @@ from nmigen.hdl.ast import Array, Cat, Repl, Signal
 from nmigen.hdl.dsl import Module
 from nmigen.hdl.ir import Elaboratable
 from nmigen.cli import rtlil
 from nmigen.hdl.dsl import Module
 from nmigen.hdl.ir import Elaboratable
 from nmigen.cli import rtlil
+from dataclasses import dataclass
 
 
 class BitwiseMux(Elaboratable):
 
 
 class BitwiseMux(Elaboratable):
@@ -82,11 +83,59 @@ class BitwiseLut(Elaboratable):
         return list(self.inputs) + [self.lut, self.output]
 
 
         return list(self.inputs) + [self.lut, self.output]
 
 
+@dataclass
+class _TreeMuxNode:
+    """Mux in tree for `TreeBitwiseLut`."""
+    out: Signal
+    container: "TreeBitwiseLut"
+    parent: "_TreeMuxNode | None"
+    child0: "_TreeMuxNode | None"
+    child1: "_TreeMuxNode | None"
+    depth: int
+
+    @property
+    def child_index(self):
+        """index of this node, when looked up in this node's parent's children.
+        """
+        if self.parent is None:
+            return None
+        return int(self.parent.child1 is self)
+
+    def add_child(self, child_index):
+        node = _TreeMuxNode(
+            out=Signal(self.container.width),
+            container=self.container, parent=self,
+            child0=None, child1=None, depth=1 + self.depth)
+        if child_index:
+            assert self.child1 is None
+            self.child1 = node
+        else:
+            assert self.child0 is None
+            self.child0 = node
+        node.out.name = "node_out_" + node.key_str
+        return node
+
+    @property
+    def key(self):
+        retval = []
+        node = self
+        while node.parent is not None:
+            retval.append(node.child_index)
+            node = node.parent
+        retval.reverse()
+        return retval
+
+    @property
+    def key_str(self):
+        k = ['x'] * self.container.input_count
+        for i, v in enumerate(self.key):
+            k[i] = '1' if v else '0'
+        return '0b' + ''.join(reversed(k))
+
+
 class TreeBitwiseLut(Elaboratable):
 class TreeBitwiseLut(Elaboratable):
-    """Tree-based version of BitwiseLut. See BitwiseLut for API documentation.
-    (good enough reason to say "see bitwiselut", but mention that
-    the API is identical and explain why the second implementation
-    exists, despite it being identical)
+    """Tree-based version of BitwiseLut. Has identical API, so see `BitwiseLut`
+    for API documentation. This version may produce more efficient hardware.
     """
 
     def __init__(self, input_count, width):
     """
 
     def __init__(self, input_count, width):
@@ -98,50 +147,39 @@ class TreeBitwiseLut(Elaboratable):
         self.inputs = tuple(inp(i) for i in range(input_count))
         self.output = Signal(width)
         self.lut = Signal(2 ** input_count)
         self.inputs = tuple(inp(i) for i in range(input_count))
         self.output = Signal(width)
         self.lut = Signal(2 ** input_count)
-        self._mux_inputs = {}
-        self._build_mux_inputs()
-
-    def _make_key_str(self, *sel_values):
-        k = ['x'] * self.input_count
-        for i, v in enumerate(sel_values):
-            k[i] = '1' if v else '0'
-        return '0b' + ''.join(reversed(k))
-
-    def _build_mux_inputs(self, *sel_values):
-        # XXX yyyeah using PHP-style functions-in-text... blech :)
-        # XXX replace with name = mux_input_%s" % self._make_etcetc
-        name = f"mux_input_{self._make_key_str(*sel_values)}"
-        self._mux_inputs[sel_values] = Signal(self.width, name=name)
-        if len(sel_values) < self.input_count:
-            self._build_mux_inputs(*sel_values, False)
-            self._build_mux_inputs(*sel_values, True)
+        self._tree_root = _TreeMuxNode(
+            out=self.output, container=self, parent=None,
+            child0=None, child1=None, depth=0)
+        self._build_tree(self._tree_root)
+
+    def _build_tree(self, node):
+        if node.depth < self.input_count:
+            self._build_tree(node.add_child(0))
+            self._build_tree(node.add_child(1))
+
+    def _elaborate_tree(self, m, node):
+        if node.depth < self.input_count:
+            mux = BitwiseMux(self.width)
+            setattr(m.submodules, "mux_" + node.key_str, mux)
+            m.d.comb += [
+                mux.f.eq(node.child0.out),
+                mux.t.eq(node.child1.out),
+                mux.sel.eq(self.inputs[node.depth]),
+                node.out.eq(mux.output),
+            ]
+            self._elaborate_tree(m, node.child0)
+            self._elaborate_tree(m, node.child1)
+        else:
+            index = int(node.key_str, base=2)
+            m.d.comb += node.out.eq(Repl(self.lut[index], self.width))
 
     def elaborate(self, platform):
         m = Module()
 
     def elaborate(self, platform):
         m = Module()
-        m.d.comb += self.output.eq(self._mux_inputs[()])
-        for sel_values, v in self._mux_inputs.items():
-            if len(sel_values) < self.input_count:
-                # XXX yyyeah using PHP-style functions-in-text... blech :)
-                # XXX replace with name = mux_input_%s" % self._make_etcetc
-                mux_name = f"mux_{self._make_key_str(*sel_values)}"
-                mux = BitwiseMux(self.width)
-                setattr(m.submodules, mux_name, mux)
-                m.d.comb += [
-                    mux.f.eq(self._mux_inputs[(*sel_values, False)]),
-                    mux.t.eq(self._mux_inputs[(*sel_values, True)]),
-                    mux.sel.eq(self.inputs[len(sel_values)]),
-                    v.eq(mux.output),
-                ]
-            else:
-                lut_index = 0
-                for i in range(self.input_count):
-                    if sel_values[i]:
-                        lut_index |= 2 ** i
-                m.d.comb += v.eq(Repl(self.lut[lut_index], self.width))
+        self._elaborate_tree(m, self._tree_root)
         return m
 
     def ports(self):
         return m
 
     def ports(self):
-        return [self.input, self.chunk_sizes, self.output]
+        return [*self.inputs, self.lut, self.output]
 
 
 # useful to see what is going on:
 
 
 # useful to see what is going on:
index cdfaa2db1a0bf2d063688bff14af90672ac2ca6e..e0a98099460ded8912299b05c513dc0f924005d7 100644 (file)
@@ -65,25 +65,6 @@ class TestBitwiseLut(FHDLTestCase):
         dut = cls(3, 16)
         mask = 2 ** dut.width - 1
         lut_mask = 2 ** dut.lut.width - 1
         dut = cls(3, 16)
         mask = 2 ** dut.width - 1
         lut_mask = 2 ** dut.lut.width - 1
-        if cls is TreeBitwiseLut:
-            mux_inputs = {k: s.name for k, s in dut._mux_inputs.items()}
-            self.assertEqual(mux_inputs, {
-                (): 'mux_input_0bxxx',
-                (False,): 'mux_input_0bxx0',
-                (False, False): 'mux_input_0bx00',
-                (False, False, False): 'mux_input_0b000',
-                (False, False, True): 'mux_input_0b100',
-                (False, True): 'mux_input_0bx10',
-                (False, True, False): 'mux_input_0b010',
-                (False, True, True): 'mux_input_0b110',
-                (True,): 'mux_input_0bxx1',
-                (True, False): 'mux_input_0bx01',
-                (True, False, False): 'mux_input_0b001',
-                (True, False, True): 'mux_input_0b101',
-                (True, True): 'mux_input_0bx11',
-                (True, True, False): 'mux_input_0b011',
-                (True, True, True): 'mux_input_0b111'
-            })
 
         def case(in0, in1, in2, lut):
             expected = 0
 
         def case(in0, in1, in2, lut):
             expected = 0
@@ -109,6 +90,10 @@ class TestBitwiseLut(FHDLTestCase):
                     self.assertEqual(expected, output)
 
         def process():
                     self.assertEqual(expected, output)
 
         def process():
+            for shift in range(dut.lut.width):
+                with self.subTest(shift=shift):
+                    yield from case(in0=0xAAAA, in1=0xCCCC, in2=0xF0F0,
+                                    lut=1 << shift)
             for case_index in range(100):
                 with self.subTest(case_index=case_index):
                     in0 = hash_256(f"{case_index} in0") & mask
             for case_index in range(100):
                 with self.subTest(case_index=case_index):
                     in0 = hash_256(f"{case_index} in0") & mask