speed up ==, hash, <, >, <=, and >= for plain_data
[nmutil.git] / src / nmutil / lut.py
index 5705d3d96e21c050b229e65b9b88d249dbdfe356..755747ab2073dbf1a7620f9ac31e592b2bf63a44 100644 (file)
@@ -1,13 +1,26 @@
 # SPDX-License-Identifier: LGPL-3-or-later
-# See Notices.txt for copyright information
+# Copyright 2021 Jacob Lifshay
+# Copyright (C) 2021 Luke Kenneth Casson Leighton <lkcl@lkcl.net>
 
-from nmigen.hdl.ast import Repl, Signal
+# Funded by NLnet Assure Programme 2021-02-052, https://nlnet.nl/assure part
+# of Horizon 2020 EU Programme 957073.
+
+"""Bitwise logic operators implemented using a look-up table, like LUTs in
+FPGAs. Inspired by x86's `vpternlog[dq]` instructions.
+
+https://bugs.libre-soc.org/show_bug.cgi?id=745
+https://www.felixcloutier.com/x86/vpternlogd:vpternlogq
+"""
+
+from nmigen.hdl.ast import Array, Cat, Repl, Signal
 from nmigen.hdl.dsl import Module
 from nmigen.hdl.ir import Elaboratable
+from nmigen.cli import rtlil
+from nmutil.plain_data import plain_data
 
 
 class BitwiseMux(Elaboratable):
-    """ Mux, but treating input/output Signals as bit vectors, rather than
+    """Mux, but treating input/output Signals as bit vectors, rather than
     integers. This means each bit in the output is independently multiplexed
     based on the corresponding bit in each of the inputs.
     """
@@ -25,9 +38,127 @@ class BitwiseMux(Elaboratable):
 
 
 class BitwiseLut(Elaboratable):
+    """Bitwise logic operators implemented using a look-up table, like LUTs in
+    FPGAs. Inspired by x86's `vpternlog[dq]` instructions.
+
+    Each output bit `i` is set to `lut[Cat(inp[i] for inp in self.inputs)]`
+    """
+
+    def __init__(self, input_count, width):
+        """
+        input_count: int
+            the number of inputs. ternlog-style instructions have 3 inputs.
+        width: int
+            the number of bits in each input/output.
+        """
+        self.input_count = input_count
+        self.width = width
+
+        def inp(i):
+            return Signal(width, name=f"input{i}")
+        self.inputs = tuple(inp(i) for i in range(input_count))  # inputs
+        self.lut = Signal(2 ** input_count)                     # lookup input
+        self.output = Signal(width)                             # output
+
+    def elaborate(self, platform):
+        m = Module()
+        comb = m.d.comb
+        lut_array = Array(self.lut)  # create dynamic-indexable LUT array
+        out = []
+
+        for bit in range(self.width):
+            # take the bit'th bit of every input, create a LUT index from it
+            index = Signal(self.input_count, name="index%d" % bit)
+            comb += index.eq(Cat(inp[bit] for inp in self.inputs))
+            # store output bit in a list - Cat() it after (simplifies graphviz)
+            outbit = Signal(name="out%d" % bit)
+            comb += outbit.eq(lut_array[index])
+            out.append(outbit)
+
+        # finally Cat() all the output bits together
+        comb += self.output.eq(Cat(*out))
+        return m
+
+    def ports(self):
+        return list(self.inputs) + [self.lut, self.output]
+
+
+@plain_data()
+class _TreeMuxNode:
+    """Mux in tree for `TreeBitwiseLut`.
+
+    Attributes:
+    out: Signal
+    container: TreeBitwiseLut
+    parent: _TreeMuxNode | None
+    child0: _TreeMuxNode | None
+    child1: _TreeMuxNode | None
+    depth: int
+    """
+    __slots__ = "out", "container", "parent", "child0", "child1", "depth"
+
+    def __init__(self, out, container, parent, child0, child1, depth):
+        """ Arguments:
+        out: Signal
+        container: TreeBitwiseLut
+        parent: _TreeMuxNode | None
+        child0: _TreeMuxNode | None
+        child1: _TreeMuxNode | None
+        depth: int
+        """
+        self.out = out
+        self.container = container
+        self.parent = parent
+        self.child0 = child0
+        self.child1 = child1
+        self.depth = depth
+
+    @property
+    def child_index(self):
+        """index of this node, when looked up in this node's parent's children.
+        """
+        if self.parent is None:
+            return None
+        return int(self.parent.child1 is self)
+
+    def add_child(self, child_index):
+        node = _TreeMuxNode(
+            out=Signal(self.container.width),
+            container=self.container, parent=self,
+            child0=None, child1=None, depth=1 + self.depth)
+        if child_index:
+            assert self.child1 is None
+            self.child1 = node
+        else:
+            assert self.child0 is None
+            self.child0 = node
+        node.out.name = "node_out_" + node.key_str
+        return node
+
+    @property
+    def key(self):
+        retval = []
+        node = self
+        while node.parent is not None:
+            retval.append(node.child_index)
+            node = node.parent
+        retval.reverse()
+        return retval
+
+    @property
+    def key_str(self):
+        k = ['x'] * self.container.input_count
+        for i, v in enumerate(self.key):
+            k[i] = '1' if v else '0'
+        return '0b' + ''.join(reversed(k))
+
+
+class TreeBitwiseLut(Elaboratable):
+    """Tree-based version of BitwiseLut. Has identical API, so see `BitwiseLut`
+    for API documentation. This version may produce more efficient hardware.
+    """
+
     def __init__(self, input_count, width):
-        assert isinstance(input_count, int)
-        assert isinstance(width, int)
         self.input_count = input_count
         self.width = width
 
@@ -36,40 +167,47 @@ class BitwiseLut(Elaboratable):
         self.inputs = tuple(inp(i) for i in range(input_count))
         self.output = Signal(width)
         self.lut = Signal(2 ** input_count)
-        self._mux_inputs = {}
-        self._build_mux_inputs()
+        self._tree_root = _TreeMuxNode(
+            out=self.output, container=self, parent=None,
+            child0=None, child1=None, depth=0)
+        self._build_tree(self._tree_root)
 
-    def _make_key_str(self, *sel_values):
-        k = ['x'] * self.input_count
-        for i, v in enumerate(sel_values):
-            k[i] = '1' if v else '0'
-        return '0b' + ''.join(reversed(k))
+    def _build_tree(self, node):
+        if node.depth < self.input_count:
+            self._build_tree(node.add_child(0))
+            self._build_tree(node.add_child(1))
 
-    def _build_mux_inputs(self, *sel_values):
-        name = f"mux_input_{self._make_key_str(*sel_values)}"
-        self._mux_inputs[sel_values] = Signal(self.width, name=name)
-        if len(sel_values) < self.input_count:
-            self._build_mux_inputs(*sel_values, False)
-            self._build_mux_inputs(*sel_values, True)
+    def _elaborate_tree(self, m, node):
+        if node.depth < self.input_count:
+            mux = BitwiseMux(self.width)
+            setattr(m.submodules, "mux_" + node.key_str, mux)
+            m.d.comb += [
+                mux.f.eq(node.child0.out),
+                mux.t.eq(node.child1.out),
+                mux.sel.eq(self.inputs[node.depth]),
+                node.out.eq(mux.output),
+            ]
+            self._elaborate_tree(m, node.child0)
+            self._elaborate_tree(m, node.child1)
+        else:
+            index = int(node.key_str, base=2)
+            m.d.comb += node.out.eq(Repl(self.lut[index], self.width))
 
     def elaborate(self, platform):
         m = Module()
-        m.d.comb += self.output.eq(self._mux_inputs[()])
-        for sel_values, v in self._mux_inputs.items():
-            if len(sel_values) < self.input_count:
-                mux_name = f"mux_{self._make_key_str(*sel_values)}"
-                mux = BitwiseMux(self.width)
-                setattr(m.submodules, mux_name, mux)
-                m.d.comb += [
-                    mux.f.eq(self._mux_inputs[(*sel_values, False)]),
-                    mux.t.eq(self._mux_inputs[(*sel_values, True)]),
-                    mux.sel.eq(self.inputs[len(sel_values)]),
-                    v.eq(mux.output),
-                ]
-            else:
-                lut_index = 0
-                for i in range(self.input_count):
-                    if sel_values[i]:
-                        lut_index |= 2 ** i
-                m.d.comb += v.eq(Repl(self.lut[lut_index], self.width))
+        self._elaborate_tree(m, self._tree_root)
         return m
+
+    def ports(self):
+        return [*self.inputs, self.lut, self.output]
+
+
+# useful to see what is going on:
+# python3 src/nmutil/test/test_lut.py
+# yosys <<<"read_ilang sim_test_out/__main__.TestBitwiseLut.test_tree/0.il; proc;;; show top"
+
+if __name__ == '__main__':
+    dut = BitwiseLut(2, 64)
+    vl = rtlil.convert(dut, ports=dut.ports())
+    with open("test_lut2.il", "w") as f:
+        f.write(vl)