84c515ec52f45d881383e2b241f04bce9c02c9fc
[nmutil.git] / src / nmutil / lut.py
1 # SPDX-License-Identifier: LGPL-3-or-later
2 # Copyright 2021 Jacob Lifshay
3 # Copyright (C) 2021 Luke Kenneth Casson Leighton <lkcl@lkcl.net>
4
5 # Funded by NLnet Assure Programme 2021-02-052, https://nlnet.nl/assure part
6 # of Horizon 2020 EU Programme 957073.
7
8 """Bitwise logic operators implemented using a look-up table, like LUTs in
9 FPGAs. Inspired by x86's `vpternlog[dq]` instructions.
10
11 https://bugs.libre-soc.org/show_bug.cgi?id=745
12 https://www.felixcloutier.com/x86/vpternlogd:vpternlogq
13 """
14
15 from nmigen.hdl.ast import Array, Cat, Repl, Signal
16 from nmigen.hdl.dsl import Module
17 from nmigen.hdl.ir import Elaboratable
18 from nmigen.cli import rtlil
19 from nmutil.plain_data import plain_data
20
class BitwiseMux(Elaboratable):
    """Per-bit multiplexer.

    Unlike a regular Mux, where one select value picks a whole operand, here
    every bit position is selected on its own: output bit `i` is `t[i]` when
    `sel[i]` is set and `f[i]` when it is clear.
    """

    def __init__(self, width):
        """
        width: int
            the number of bits in `sel`, `t`, `f` and `output`.
        """
        self.sel = Signal(width)
        self.t = Signal(width)
        self.f = Signal(width)
        self.output = Signal(width)

    def elaborate(self, platform):
        m = Module()
        # bits taken from `f` (select bit clear) and from `t` (select bit set)
        from_f = ~self.sel & self.f
        from_t = self.sel & self.t
        m.d.comb += self.output.eq(from_f | from_t)
        return m
37
38
class BitwiseLut(Elaboratable):
    """Look-up-table based bitwise logic operator, in the style of FPGA LUTs
    and of x86's `vpternlog[dq]` instructions.

    For every bit position `i`, the `i`th bits of all the inputs are
    concatenated (input 0 is the least-significant bit) and the result is
    used to select one bit of `lut`, which becomes output bit `i`:
    `output[i] = lut[Cat(inp[i] for inp in self.inputs)]`
    """

    def __init__(self, input_count, width):
        """
        input_count: int
            the number of inputs. ternlog-style instructions have 3 inputs.
        width: int
            the number of bits in each input/output.
        """
        self.input_count = input_count
        self.width = width

        # inputs
        self.inputs = tuple(Signal(width, name=f"input{i}")
                            for i in range(input_count))
        # lookup input: one table entry per combination of input bits
        self.lut = Signal(2 ** input_count)
        # output
        self.output = Signal(width)

    def elaborate(self, platform):
        m = Module()
        # Array allows the LUT bits to be indexed by a Signal
        table = Array(self.lut)
        result_bits = []

        for pos in range(self.width):
            # gather bit `pos` of each input into the LUT index
            selector = Signal(self.input_count, name=f"index{pos}")
            m.d.comb += selector.eq(Cat(src[pos] for src in self.inputs))
            # one named Signal per output bit; they are joined with Cat()
            # after the loop (simplifies graphviz)
            picked = Signal(name=f"out{pos}")
            m.d.comb += picked.eq(table[selector])
            result_bits.append(picked)

        # join the per-bit results into the output
        m.d.comb += self.output.eq(Cat(*result_bits))
        return m

    def ports(self):
        """Return the inputs, then `lut`, then `output`, as a list."""
        return [*self.inputs, self.lut, self.output]
83
84
@plain_data()
class _TreeMuxNode:
    """A single node of the mux tree used by `TreeBitwiseLut`.

    Attributes:
    out: Signal
        the Signal carrying this node's value.
    container: TreeBitwiseLut
        the LUT that this tree belongs to.
    parent: _TreeMuxNode | None
        the node above this one, `None` for the root.
    child0: _TreeMuxNode | None
        the child stored in slot 0, `None` if not (yet) added.
    child1: _TreeMuxNode | None
        the child stored in slot 1, `None` if not (yet) added.
    depth: int
        0 for the root, one more than the parent's depth otherwise.
    """
    __slots__ = "out", "container", "parent", "child0", "child1", "depth"

    def __init__(self, out, container, parent, child0, child1, depth):
        """ Arguments:
        out: Signal
        container: TreeBitwiseLut
        parent: _TreeMuxNode | None
        child0: _TreeMuxNode | None
        child1: _TreeMuxNode | None
        depth: int
        """
        self.out = out
        self.container = container
        self.parent = parent
        self.child0 = child0
        self.child1 = child1
        self.depth = depth

    @property
    def child_index(self):
        """which of the parent's child slots (0 or 1) holds this node, or
        `None` when this node has no parent.
        """
        owner = self.parent
        return None if owner is None else int(owner.child1 is self)

    def add_child(self, child_index):
        """Create a new node one level down, store it in slot 0 (falsy
        `child_index`) or slot 1 (truthy `child_index`), and return it.

        The slot must still be empty.
        """
        tree = self.container
        child = _TreeMuxNode(
            out=Signal(tree.width),
            container=tree, parent=self,
            child0=None, child1=None, depth=self.depth + 1)
        if not child_index:
            assert self.child0 is None
            self.child0 = child
        else:
            assert self.child1 is None
            self.child1 = child
        # the name depends on key_str, which is only correct once the child
        # has been attached to this node
        child.out.name = "node_out_" + child.key_str
        return child

    @property
    def key(self):
        """list of child indexes leading from the root down to this node
        (empty for the root).
        """
        path = []
        current = self
        while current.parent is not None:
            # walking upwards, so each step goes in front of the later ones
            path.insert(0, current.child_index)
            current = current.parent
        return path

    @property
    def key_str(self):
        """`key` formatted as a `0b`-prefixed string with one character per
        LUT input: the character for the first step is rightmost, and inputs
        not yet decided at this node's depth are shown as `x`.
        """
        chars = ['x'] * self.container.input_count
        for position, choice in enumerate(self.key):
            chars[position] = '1' if choice else '0'
        return '0b' + ''.join(chars[::-1])
153
154
class TreeBitwiseLut(Elaboratable):
    """Tree-based version of BitwiseLut. Has identical API, so see `BitwiseLut`
    for API documentation. This version may produce more efficient hardware.

    The LUT is built as a binary tree of `BitwiseMux`es: the muxes at depth
    `d` are all selected by `inputs[d]`, and each leaf replicates a single
    bit of `lut` across the full width.
    """

    def __init__(self, input_count, width):
        """
        input_count: int
            the number of inputs. ternlog-style instructions have 3 inputs.
        width: int
            the number of bits in each input/output.
        """
        self.input_count = input_count
        self.width = width

        def inp(i):
            return Signal(width, name=f"input{i}")
        self.inputs = tuple(inp(i) for i in range(input_count))
        self.output = Signal(width)
        self.lut = Signal(2 ** input_count)
        # the root node drives `output` directly; all other nodes get their
        # own Signal in `_TreeMuxNode.add_child`
        self._tree_root = _TreeMuxNode(
            out=self.output, container=self, parent=None,
            child0=None, child1=None, depth=0)
        self._build_tree(self._tree_root)

    def _build_tree(self, node):
        """Recursively add both children below `node`, stopping once the
        nodes are at depth `input_count` (those nodes are the leaves).
        """
        if node.depth < self.input_count:
            self._build_tree(node.add_child(0))
            self._build_tree(node.add_child(1))

    def _elaborate_tree(self, m, node):
        """Recursively add the logic for `node` and everything below it to
        the Module `m`.
        """
        if node.depth < self.input_count:
            # interior node: per-bit mux between the two children, selected
            # by the input corresponding to this node's depth
            mux = BitwiseMux(self.width)
            setattr(m.submodules, "mux_" + node.key_str, mux)
            m.d.comb += [
                mux.f.eq(node.child0.out),
                mux.t.eq(node.child1.out),
                mux.sel.eq(self.inputs[node.depth]),
                node.out.eq(mux.output),
            ]
            self._elaborate_tree(m, node.child0)
            self._elaborate_tree(m, node.child1)
        else:
            # leaf node: the choice made at depth `i` is the value of input
            # `i`, which is bit `i` of the LUT index. Computed arithmetically
            # rather than by parsing `key_str`, since for input_count == 0
            # `key_str` is just '0b', which int() rejects.
            index = sum(choice << i for i, choice in enumerate(node.key))
            m.d.comb += node.out.eq(Repl(self.lut[index], self.width))

    def elaborate(self, platform):
        m = Module()
        self._elaborate_tree(m, self._tree_root)
        return m

    def ports(self):
        """Return the inputs, then `lut`, then `output`, as a list."""
        return [*self.inputs, self.lut, self.output]
202
203
204 # useful to see what is going on:
205 # python3 src/nmutil/test/test_lut.py
206 # yosys <<<"read_ilang sim_test_out/__main__.TestBitwiseLut.test_tree/0.il; proc;;; show top"
207
if __name__ == '__main__':
    # convert a 2-input, 64-bit wide BitwiseLut to RTLIL and save it
    lut2 = BitwiseLut(2, 64)
    il_text = rtlil.convert(lut2, ports=lut2.ports())
    with open("test_lut2.il", "w") as il_file:
        il_file.write(il_text)