speed up ==, hash, <, >, <=, and >= for plain_data
[nmutil.git] / src / nmutil / lut.py
index 261a1d1bcbfc1a7fab2e9c103667b38e530c5bc0..755747ab2073dbf1a7620f9ac31e592b2bf63a44 100644 (file)
@@ -1,7 +1,9 @@
 # SPDX-License-Identifier: LGPL-3-or-later
-# TODO: Copyright notice (standard style, plenty of examples)
+# Copyright 2021 Jacob Lifshay
 # Copyright (C) 2021 Luke Kenneth Casson Leighton <lkcl@lkcl.net>
-# TODO: credits to NLnet for funding
+
+# Funded by NLnet Assure Programme 2021-02-052, https://nlnet.nl/assure part
+# of Horizon 2020 EU Programme 957073.
 
 """Bitwise logic operators implemented using a look-up table, like LUTs in
 FPGAs. Inspired by x86's `vpternlog[dq]` instructions.
@@ -14,6 +16,7 @@ from nmigen.hdl.ast import Array, Cat, Repl, Signal
 from nmigen.hdl.dsl import Module
 from nmigen.hdl.ir import Elaboratable
 from nmigen.cli import rtlil
+from nmutil.plain_data import plain_data
 
 
 class BitwiseMux(Elaboratable):
@@ -80,11 +83,79 @@ class BitwiseLut(Elaboratable):
         return list(self.inputs) + [self.lut, self.output]
 
 
+@plain_data()
+class _TreeMuxNode:
+    """One mux node in the binary mux tree built by `TreeBitwiseLut`.
+
+    Attributes:
+    out: Signal
+    container: TreeBitwiseLut
+    parent: _TreeMuxNode | None
+    child0: _TreeMuxNode | None
+    child1: _TreeMuxNode | None
+    depth: int
+    """
+    __slots__ = "out", "container", "parent", "child0", "child1", "depth"
+
+    def __init__(self, out, container, parent, child0, child1, depth):
+        """ Arguments:
+        out: Signal
+        container: TreeBitwiseLut
+        parent: _TreeMuxNode | None
+        child0: _TreeMuxNode | None
+        child1: _TreeMuxNode | None
+        depth: int
+        """
+        self.out = out
+        self.container = container
+        self.parent = parent
+        self.child0 = child0
+        self.child1 = child1
+        self.depth = depth
+
+    @property
+    def child_index(self):
+        """Index (0 or 1) of this node in its parent's children; None at root.
+        """
+        if self.parent is None:
+            return None
+        return int(self.parent.child1 is self)
+
+    def add_child(self, child_index):
+        node = _TreeMuxNode(
+            out=Signal(self.container.width),
+            container=self.container, parent=self,
+            child0=None, child1=None, depth=1 + self.depth)
+        if child_index:
+            assert self.child1 is None
+            self.child1 = node
+        else:
+            assert self.child0 is None
+            self.child0 = node
+        node.out.name = "node_out_" + node.key_str
+        return node
+
+    @property
+    def key(self):
+        retval = []
+        node = self
+        while node.parent is not None:
+            retval.append(node.child_index)
+            node = node.parent
+        retval.reverse()
+        return retval
+
+    @property
+    def key_str(self):
+        k = ['x'] * self.container.input_count
+        for i, v in enumerate(self.key):
+            k[i] = '1' if v else '0'
+        return '0b' + ''.join(reversed(k))
+
+
 class TreeBitwiseLut(Elaboratable):
-    """Tree-based version of BitwiseLut. See BitwiseLut for API documentation.
-    (good enough reason to say "see bitwiselut", but mention that
-    the API is identical and explain why the second implementation
-    exists, despite it being identical)
+    """Tree-based version of BitwiseLut. Has identical API, so see `BitwiseLut`
+    for API documentation. This version may produce more efficient hardware.
     """
 
     def __init__(self, input_count, width):
@@ -96,55 +167,47 @@ class TreeBitwiseLut(Elaboratable):
         self.inputs = tuple(inp(i) for i in range(input_count))
         self.output = Signal(width)
         self.lut = Signal(2 ** input_count)
-        self._mux_inputs = {}
-        self._build_mux_inputs()
-
-    def _make_key_str(self, *sel_values):
-        k = ['x'] * self.input_count
-        for i, v in enumerate(sel_values):
-            k[i] = '1' if v else '0'
-        return '0b' + ''.join(reversed(k))
-
-    def _build_mux_inputs(self, *sel_values):
-        # XXX yyyeah using PHP-style functions-in-text... blech :)
-        # XXX replace with name = mux_input_%s" % self._make_etcetc
-        name = f"mux_input_{self._make_key_str(*sel_values)}"
-        self._mux_inputs[sel_values] = Signal(self.width, name=name)
-        if len(sel_values) < self.input_count:
-            self._build_mux_inputs(*sel_values, False)
-            self._build_mux_inputs(*sel_values, True)
+        self._tree_root = _TreeMuxNode(
+            out=self.output, container=self, parent=None,
+            child0=None, child1=None, depth=0)
+        self._build_tree(self._tree_root)
+
+    def _build_tree(self, node):
+        if node.depth < self.input_count:
+            self._build_tree(node.add_child(0))
+            self._build_tree(node.add_child(1))
+
+    def _elaborate_tree(self, m, node):
+        if node.depth < self.input_count:
+            mux = BitwiseMux(self.width)
+            setattr(m.submodules, "mux_" + node.key_str, mux)
+            m.d.comb += [
+                mux.f.eq(node.child0.out),
+                mux.t.eq(node.child1.out),
+                mux.sel.eq(self.inputs[node.depth]),
+                node.out.eq(mux.output),
+            ]
+            self._elaborate_tree(m, node.child0)
+            self._elaborate_tree(m, node.child1)
+        else:
+            index = int(node.key_str, base=2)
+            m.d.comb += node.out.eq(Repl(self.lut[index], self.width))
 
     def elaborate(self, platform):
         m = Module()
-        m.d.comb += self.output.eq(self._mux_inputs[()])
-        for sel_values, v in self._mux_inputs.items():
-            if len(sel_values) < self.input_count:
-                # XXX yyyeah using PHP-style functions-in-text... blech :)
-                # XXX replace with name = mux_input_%s" % self._make_etcetc
-                mux_name = f"mux_{self._make_key_str(*sel_values)}"
-                mux = BitwiseMux(self.width)
-                setattr(m.submodules, mux_name, mux)
-                m.d.comb += [
-                    mux.f.eq(self._mux_inputs[(*sel_values, False)]),
-                    mux.t.eq(self._mux_inputs[(*sel_values, True)]),
-                    mux.sel.eq(self.inputs[len(sel_values)]),
-                    v.eq(mux.output),
-                ]
-            else:
-                lut_index = 0
-                for i in range(self.input_count):
-                    if sel_values[i]:
-                        lut_index |= 2 ** i
-                m.d.comb += v.eq(Repl(self.lut[lut_index], self.width))
+        self._elaborate_tree(m, self._tree_root)
         return m
 
     def ports(self):
-        return [self.input, self.chunk_sizes, self.output]
+        return [*self.inputs, self.lut, self.output]
+
 
+# useful to see what is going on:
+# python3 src/nmutil/test/test_lut.py
+# yosys <<<"read_ilang sim_test_out/__main__.TestBitwiseLut.test_tree/0.il; proc;;; show top"
 
-# useful to see what is going on: use yosys "read_ilang test_lut.il; show top"
 if __name__ == '__main__':
-    dut = BitwiseLut(3, 8)
+    dut = BitwiseLut(2, 64)
     vl = rtlil.convert(dut, ports=dut.ports())
-    with open("test_lut.il", "w") as f:
+    with open("test_lut2.il", "w") as f:
         f.write(vl)