speed up ==, hash, <, >, <=, and >= for plain_data
[nmutil.git] / src / nmutil / lut.py
1 # SPDX-License-Identifier: LGPL-3-or-later
2 # Copyright 2021 Jacob Lifshay
3 # Copyright (C) 2021 Luke Kenneth Casson Leighton <lkcl@lkcl.net>
4
5 # Funded by NLnet Assure Programme 2021-02-052, https://nlnet.nl/assure part
6 # of Horizon 2020 EU Programme 957073.
7
8 """Bitwise logic operators implemented using a look-up table, like LUTs in
9 FPGAs. Inspired by x86's `vpternlog[dq]` instructions.
10
11 https://bugs.libre-soc.org/show_bug.cgi?id=745
12 https://www.felixcloutier.com/x86/vpternlogd:vpternlogq
13 """
14
15 from nmigen.hdl.ast import Array, Cat, Repl, Signal
16 from nmigen.hdl.dsl import Module
17 from nmigen.hdl.ir import Elaboratable
18 from nmigen.cli import rtlil
19 from nmutil.plain_data import plain_data
20
21
class BitwiseMux(Elaboratable):
    """Per-bit multiplexer over `width`-bit Signals.

    The inputs and the output are handled as bit vectors, not as integers:
    every bit of `output` is selected on its own, taking the matching bit of
    `t` where the matching bit of `sel` is 1 and the matching bit of `f`
    where it is 0.
    """

    def __init__(self, width):
        """
        width: int
            the number of bits in each of `sel`, `t`, `f` and `output`.
        """
        self.sel = Signal(width)
        self.t = Signal(width)
        self.f = Signal(width)
        self.output = Signal(width)

    def elaborate(self, platform):
        """Return a Module driving `output` combinatorially.

        `platform` is accepted but not used here.
        """
        m = Module()
        # bits contributed by `f` (where sel is 0) and by `t` (where sel is 1)
        from_f = ~self.sel & self.f
        from_t = self.sel & self.t
        m.d.comb += self.output.eq(from_f | from_t)
        return m
38
39
class BitwiseLut(Elaboratable):
    """Look-up-table based bitwise logic operator, in the manner of FPGA LUTs
    and of x86's `vpternlog[dq]` instructions.

    Each output bit `i` is set to `lut[Cat(inp[i] for inp in self.inputs)]`
    """

    def __init__(self, input_count, width):
        """
        input_count: int
            the number of inputs. ternlog-style instructions have 3 inputs.
        width: int
            the number of bits in each input/output.
        """
        self.input_count = input_count
        self.width = width

        # one `width`-bit Signal per input, named input0, input1, ...
        self.inputs = tuple(Signal(width, name=f"input{idx}")
                            for idx in range(input_count))
        # the look-up table itself: one bit per combination of input bits
        self.lut = Signal(2 ** input_count)
        self.output = Signal(width)

    def elaborate(self, platform):
        """Return a Module computing `output` from `inputs` and `lut`.

        `platform` is accepted but not used here.
        """
        m = Module()
        table = Array(self.lut)  # the LUT bits, indexable by a Signal
        out_bits = []

        for pos in range(self.width):
            # gather bit number `pos` of every input into the LUT index
            index = Signal(self.input_count, name="index%d" % pos)
            selector = Cat(inp[pos] for inp in self.inputs)
            # each output bit gets its own Signal; they are joined with
            # Cat() after the loop (simplifies graphviz)
            outbit = Signal(name="out%d" % pos)
            m.d.comb += [
                index.eq(selector),
                outbit.eq(table[index]),
            ]
            out_bits.append(outbit)

        # join the per-bit results into the full-width output
        m.d.comb += self.output.eq(Cat(*out_bits))
        return m

    def ports(self):
        """Return a list of all inputs, followed by `lut` and `output`."""
        return [*self.inputs, self.lut, self.output]
84
85
@plain_data()
class _TreeMuxNode:
    """A single node of the mux tree used by `TreeBitwiseLut`.

    Attributes:
    out: Signal
        the Signal carrying this node's value.
    container: TreeBitwiseLut
        the LUT this node belongs to.
    parent: _TreeMuxNode | None
        the node above this one; None for the root.
    child0: _TreeMuxNode | None
        the child with index 0, once added.
    child1: _TreeMuxNode | None
        the child with index 1, once added.
    depth: int
        the number of parents between this node and the root.
    """
    __slots__ = "out", "container", "parent", "child0", "child1", "depth"

    def __init__(self, out, container, parent, child0, child1, depth):
        """ Arguments:
        out: Signal
        container: TreeBitwiseLut
        parent: _TreeMuxNode | None
        child0: _TreeMuxNode | None
        child1: _TreeMuxNode | None
        depth: int
        """
        self.out = out
        self.container = container
        self.parent = parent
        self.child0 = child0
        self.child1 = child1
        self.depth = depth

    @property
    def child_index(self):
        """position of this node among its parent's children: 1 when it is
        the parent's `child1`, otherwise 0. None when there is no parent.
        """
        parent = self.parent
        if parent is None:
            return None
        return 1 if parent.child1 is self else 0

    def add_child(self, child_index):
        """Create a new node one level deeper, attach it as `child1` when
        `child_index` is truthy (else as `child0`), and return it.

        The chosen child slot must still be empty.
        """
        container = self.container
        node = _TreeMuxNode(
            out=Signal(container.width),
            container=container, parent=self,
            child0=None, child1=None, depth=self.depth + 1)
        slot = "child1" if child_index else "child0"
        assert getattr(self, slot) is None
        setattr(self, slot, node)
        # the name depends on the node's position, so it can only be
        # assigned once the node is attached to its parent
        node.out.name = "node_out_" + node.key_str
        return node

    @property
    def key(self):
        """list of child indexes leading from the root down to this node
        (empty for the root).
        """
        path = []
        cursor = self
        while cursor.parent is not None:
            # walking upwards, so each step goes in front of the later ones
            path.insert(0, cursor.child_index)
            cursor = cursor.parent
        return path

    @property
    def key_str(self):
        """`key` formatted as a '0b'-prefixed string with one character per
        input of the container: the first key entry is the rightmost
        character, and positions not covered by `key` are 'x'.
        """
        digits = ['x'] * self.container.input_count
        for pos, bit in enumerate(self.key):
            digits[pos] = '1' if bit else '0'
        return '0b' + ''.join(digits[::-1])
154
155
class TreeBitwiseLut(Elaboratable):
    """Tree-based version of BitwiseLut. Has identical API, so see `BitwiseLut`
    for API documentation. This version may produce more efficient hardware.
    """

    def __init__(self, input_count, width):
        """
        input_count: int
            the number of inputs.
        width: int
            the number of bits in each input/output.
        """
        self.input_count = input_count
        self.width = width

        self.inputs = tuple(Signal(width, name=f"input{idx}")
                            for idx in range(input_count))
        self.output = Signal(width)
        self.lut = Signal(2 ** input_count)
        # the root node drives `output` directly; all other nodes are
        # created below it
        root = _TreeMuxNode(
            out=self.output, container=self, parent=None,
            child0=None, child1=None, depth=0)
        self._tree_root = root
        self._build_tree(root)

    def _build_tree(self, node):
        """Recursively give `node` two children until `input_count` levels
        exist below the root.
        """
        if node.depth >= self.input_count:
            return
        for child_index in (0, 1):
            self._build_tree(node.add_child(child_index))

    def _elaborate_tree(self, m, node):
        """Add the logic for `node` and everything below it to Module `m`."""
        if node.depth >= self.input_count:
            # deepest level: the node's key selects a single LUT bit, which
            # is replicated across the full width
            lut_index = int(node.key_str, base=2)
            m.d.comb += node.out.eq(Repl(self.lut[lut_index], self.width))
            return

        # inner node: a BitwiseMux picks between the two children, steered
        # by the input whose number equals this node's depth
        mux = BitwiseMux(self.width)
        setattr(m.submodules, "mux_" + node.key_str, mux)
        low, high = node.child0, node.child1
        m.d.comb += [
            mux.f.eq(low.out),
            mux.t.eq(high.out),
            mux.sel.eq(self.inputs[node.depth]),
            node.out.eq(mux.output),
        ]
        for child in (low, high):
            self._elaborate_tree(m, child)

    def elaborate(self, platform):
        """Return a Module holding the whole mux tree.

        `platform` is accepted but not used here.
        """
        m = Module()
        self._elaborate_tree(m, self._tree_root)
        return m

    def ports(self):
        """Return a list of all inputs, followed by `lut` and `output`."""
        return list(self.inputs) + [self.lut, self.output]
203
204
205 # useful to see what is going on:
206 # python3 src/nmutil/test/test_lut.py
207 # yosys <<<"read_ilang sim_test_out/__main__.TestBitwiseLut.test_tree/0.il; proc;;; show top"
208
if __name__ == '__main__':
    # convert a 2-input, 64-bit wide BitwiseLut with rtlil.convert() and
    # write the result to test_lut2.il
    lut_dut = BitwiseLut(2, 64)
    il_text = rtlil.convert(lut_dut, ports=lut_dut.ports())
    with open("test_lut2.il", "w") as il_file:
        il_file.write(il_text)