49c19a876983f9f150251c2c2254b8874fb56aac
[nmutil.git] / src / nmutil / lut.py
1 # SPDX-License-Identifier: LGPL-3-or-later
2 # Copyright 2021 Jacob Lifshay
3 # Copyright (C) 2021 Luke Kenneth Casson Leighton <lkcl@lkcl.net>
4
5 # Funded by NLnet Assure Programme 2021-02-052, https://nlnet.nl/assure part
6 # of Horizon 2020 EU Programme 957073.
7
8 """Bitwise logic operators implemented using a look-up table, like LUTs in
9 FPGAs. Inspired by x86's `vpternlog[dq]` instructions.
10
11 https://bugs.libre-soc.org/show_bug.cgi?id=745
12 https://www.felixcloutier.com/x86/vpternlogd:vpternlogq
13 """
14
15 from nmigen.hdl.ast import Array, Cat, Repl, Signal
16 from nmigen.hdl.dsl import Module
17 from nmigen.hdl.ir import Elaboratable
18 from nmigen.cli import rtlil
19 from dataclasses import dataclass
20
21
class BitwiseMux(Elaboratable):
    """Mux, but treating input/output Signals as bit vectors, rather than
    integers. This means each bit in the output is independently multiplexed
    based on the corresponding bit in each of the inputs.

    For every bit position ``i``: ``output[i] = t[i] if sel[i] else f[i]``.
    """

    def __init__(self, width):
        """
        width: int
            the number of bits in each of ``sel``, ``t``, ``f`` and ``output``.
        """
        self.sel = Signal(width)  # per-bit select: 1 picks `t`, 0 picks `f`
        self.t = Signal(width)  # source for bits where `sel` is 1
        self.f = Signal(width)  # source for bits where `sel` is 0
        self.output = Signal(width)

    def elaborate(self, platform):
        m = Module()
        # bits of `f` where `sel` is clear, bits of `t` where `sel` is set
        m.d.comb += self.output.eq((~self.sel & self.f) | (self.sel & self.t))
        return m

    def ports(self):
        """Return the signals making up this module's external interface."""
        return [self.sel, self.t, self.f, self.output]
38
39
class BitwiseLut(Elaboratable):
    """Look-up-table driven bitwise logic, in the spirit of FPGA LUTs and of
    x86's `vpternlog[dq]` instructions.

    For each bit position `i`, bit `i` of every input is gathered (input 0
    in the least-significant position) to form an index into `lut`, so that
    output bit `i` is `lut[Cat(inp[i] for inp in self.inputs)]`.
    """

    def __init__(self, input_count, width):
        """
        input_count: int
            how many input operands there are; ternlog-style instructions
            use 3.
        width: int
            bit-width of every input and of the output.
        """
        self.input_count = input_count
        self.width = width
        self.inputs = tuple(Signal(width, name=f"input{idx}")
                            for idx in range(input_count))
        self.lut = Signal(2 ** input_count)  # truth table, one bit per index
        self.output = Signal(width)

    def elaborate(self, platform):
        m = Module()
        # wrap the LUT bits so they can be selected by a Signal-valued index
        table = Array(self.lut)
        out_bits = []

        for pos in range(self.width):
            # index formed from bit `pos` of each input
            idx_sig = Signal(self.input_count, name=f"index{pos}")
            m.d.comb += idx_sig.eq(Cat(src[pos] for src in self.inputs))
            # one named signal per output bit, concatenated afterwards
            # (keeps the generated graph readable)
            bit_sig = Signal(name=f"out{pos}")
            m.d.comb += bit_sig.eq(table[idx_sig])
            out_bits.append(bit_sig)

        m.d.comb += self.output.eq(Cat(*out_bits))
        return m

    def ports(self):
        """Return the inputs, the LUT and the output as a list."""
        return [*self.inputs, self.lut, self.output]
84
85
@dataclass(eq=False)
class _TreeMuxNode:
    """Mux in tree for `TreeBitwiseLut`.

    Nodes form a binary tree linked in both directions (`parent` and
    `child0`/`child1`) and are compared by identity (see `child_index`).
    `eq=False` keeps object identity for `__eq__`/`__hash__`: a generated
    field-wise `__eq__` would compare nmigen `Signal`s, whose `==` builds a
    hardware expression rather than a Python bool, and would also make the
    nodes unhashable.
    """
    out: Signal  # output signal of this node
    container: "TreeBitwiseLut"  # the LUT that owns this tree
    parent: "_TreeMuxNode | None"  # None for the root
    child0: "_TreeMuxNode | None"  # None until add_child(0) is called
    child1: "_TreeMuxNode | None"  # None until add_child(1) is called
    depth: int  # distance from the root; the root has depth 0

    @property
    def child_index(self):
        """index of this node, when looked up in this node's parent's children.

        Returns 0 or 1, or None for the root.
        """
        if self.parent is None:
            return None
        return int(self.parent.child1 is self)

    def add_child(self, child_index):
        """Create, attach and return a new child in slot `child_index`.

        The child gets a fresh `container.width`-bit output signal named
        after its position in the tree. Asserts the slot is still empty.
        """
        node = _TreeMuxNode(
            out=Signal(self.container.width),
            container=self.container, parent=self,
            child0=None, child1=None, depth=1 + self.depth)
        if child_index:
            assert self.child1 is None
            self.child1 = node
        else:
            assert self.child0 is None
            self.child0 = node
        # the name depends on the parent link, so set it after attaching
        node.out.name = "node_out_" + node.key_str
        return node

    @property
    def key(self):
        """List of child indices along the path from the root to this node.

        Element `i` is the branch taken at depth `i` (empty for the root).
        """
        retval = []
        node = self
        while node.parent is not None:
            retval.append(node.child_index)
            node = node.parent
        retval.reverse()
        return retval

    @property
    def key_str(self):
        """`key` as a binary string, e.g. '0bx10' for key [0, 1] with 3 inputs.

        The first element of `key` is the least-significant (rightmost)
        digit. Positions not yet decided are shown as 'x'.
        """
        k = ['x'] * self.container.input_count
        for i, v in enumerate(self.key):
            k[i] = '1' if v else '0'
        return '0b' + ''.join(reversed(k))
134
135
class TreeBitwiseLut(Elaboratable):
    """Tree-based version of BitwiseLut. Has identical API, so see `BitwiseLut`
    for API documentation. This version may produce more efficient hardware.

    The logic is a binary tree of `BitwiseMux`es:
    - the mux at depth `d` selects on `inputs[d]`;
    - each leaf, at depth `input_count`, drives `width` copies of one bit
      of `lut`.
    """

    def __init__(self, input_count, width):
        """
        input_count: int
            the number of inputs; also the depth of the mux tree.
        width: int
            the number of bits in each input/output.
        """
        self.input_count = input_count
        self.width = width

        def inp(i):
            return Signal(width, name=f"input{i}")
        self.inputs = tuple(inp(i) for i in range(input_count))
        self.output = Signal(width)
        self.lut = Signal(2 ** input_count)
        # the root node drives `self.output` directly
        self._tree_root = _TreeMuxNode(
            out=self.output, container=self, parent=None,
            child0=None, child1=None, depth=0)
        self._build_tree(self._tree_root)

    def _build_tree(self, node):
        """Recursively give `node` two children until depth `input_count`."""
        if node.depth < self.input_count:
            self._build_tree(node.add_child(0))
            self._build_tree(node.add_child(1))

    def _elaborate_tree(self, m, node):
        """Add the logic for `node` and its whole subtree to module `m`."""
        if node.depth < self.input_count:
            # inner node: bitwise-mux the two children on inputs[depth]
            mux = BitwiseMux(self.width)
            setattr(m.submodules, "mux_" + node.key_str, mux)
            m.d.comb += [
                mux.f.eq(node.child0.out),
                mux.t.eq(node.child1.out),
                mux.sel.eq(self.inputs[node.depth]),
                node.out.eq(mux.output),
            ]
            self._elaborate_tree(m, node.child0)
            self._elaborate_tree(m, node.child1)
        else:
            # Leaf: its root-to-leaf path is the LUT index. Bit `i` of the
            # index is the branch taken on inputs[i] (same order as
            # `key_str`). The index is computed directly from `key` rather
            # than by parsing `key_str`: with input_count == 0, `key_str` is
            # just "0b", which int(..., base=2) rejects.
            index = sum(bit << pos for pos, bit in enumerate(node.key))
            m.d.comb += node.out.eq(Repl(self.lut[index], self.width))

    def elaborate(self, platform):
        m = Module()
        self._elaborate_tree(m, self._tree_root)
        return m

    def ports(self):
        """Return the inputs, the LUT and the output as a list."""
        return [*self.inputs, self.lut, self.output]
183
184
185 # useful to see what is going on:
186 # yosys <<<"read_ilang sim_test_out/__main__.TestBitwiseLut.test_tree/0.il; proc;;; show top"