remove unnecessary <no space here> messages
[nmutil.git] / src / nmutil / lut.py
1 # SPDX-License-Identifier: LGPL-3-or-later
2 # TODO: Copyright notice (standard style, plenty of examples)
3 # Copyright (C) 2021 Luke Kenneth Casson Leighton <lkcl@lkcl.net>
4 # TODO: credits to NLnet for funding
5
6 """Bitwise logic operators implemented using a look-up table, like LUTs in
7 FPGAs. Inspired by x86's `vpternlog[dq]` instructions.
8
9 https://bugs.libre-soc.org/show_bug.cgi?id=745
10 https://www.felixcloutier.com/x86/vpternlogd:vpternlogq
11 """
12
13 from nmigen.hdl.ast import Array, Cat, Repl, Signal
14 from nmigen.hdl.dsl import Module
15 from nmigen.hdl.ir import Elaboratable
16 from nmigen.cli import rtlil
17
18
class BitwiseMux(Elaboratable):
    """Per-bit multiplexer over equal-width bit vectors.

    Unlike a Mux with a single select line, `sel` here is as wide as the
    data: output bit `i` takes bit `i` of `t` where bit `i` of `sel` is 1,
    and bit `i` of `f` where it is 0.
    """

    def __init__(self, width):
        """
        width: int
            the number of bits in `sel`, `t`, `f` and `output`.
        """
        self.sel = Signal(width)  # per-bit select
        self.t = Signal(width)  # value used where the sel bit is 1
        self.f = Signal(width)  # value used where the sel bit is 0
        self.output = Signal(width)  # muxed result

    def elaborate(self, platform):
        m = Module()
        sel = self.sel
        from_f = ~sel & self.f  # bits of f let through where sel is 0
        from_t = sel & self.t  # bits of t let through where sel is 1
        m.d.comb += self.output.eq(from_f | from_t)
        return m
35
36
class BitwiseLut(Elaboratable):
    """Look-up-table based bitwise logic, in the manner of FPGA LUTs and
    of x86's `vpternlog[dq]` instructions.

    For each bit position `i`, output bit `i` is
    `lut[Cat(inp[i] for inp in self.inputs)]`.
    """

    def __init__(self, input_count, width):
        """
        input_count: int
            how many input vectors there are. ternlog-style instructions
            have 3.
        width: int
            the bit-width shared by every input and by the output.
        """
        self.input_count = input_count
        self.width = width

        # one explicitly-named Signal per input vector
        self.inputs = tuple(
            Signal(width, name=f"input{i}") for i in range(input_count)
        )
        self.lut = Signal(2 ** input_count)  # one bit per input combination
        self.output = Signal(width)  # result vector

    def elaborate(self, platform):
        m = Module()
        # Array() over the lut bits lets them be indexed by a Signal
        table = Array(self.lut)
        out_bits = []

        for pos in range(self.width):
            # gather bit `pos` of every input to form the LUT index
            idx = Signal(self.input_count, name="index%d" % pos)
            m.d.comb += idx.eq(Cat(inp[pos] for inp in self.inputs))
            # keep each looked-up bit in its own Signal; they are joined by
            # a single Cat() afterwards (simplifies graphviz)
            picked = Signal(name="out%d" % pos)
            m.d.comb += picked.eq(table[idx])
            out_bits.append(picked)

        # assemble the per-bit results into the output vector
        m.d.comb += self.output.eq(Cat(*out_bits))
        return m

    def ports(self):
        """Return every input Signal, followed by `lut` and `output`."""
        return [*self.inputs, self.lut, self.output]
81
82
class TreeBitwiseLut(Elaboratable):
    """Tree-based version of BitwiseLut. The API (constructor arguments,
    `inputs`, `lut`, `output`, `ports()`) is identical to BitwiseLut: see
    BitwiseLut for its documentation.

    Rather than dynamically indexing the LUT, this builds a binary tree of
    BitwiseMux submodules. Each leaf is one `lut` bit replicated to `width`
    bits; each inner node at depth `n` uses `inputs[n]` to select, per bit,
    between the sub-tree for that input being 0 and the one for it being 1.

    TODO: document why this second implementation exists alongside
    BitwiseLut.
    """

    def __init__(self, input_count, width):
        """
        input_count: int
            the number of inputs. ternlog-style instructions have 3 inputs.
        width: int
            the number of bits in each input/output.
        """
        self.input_count = input_count
        self.width = width

        def inp(i):
            return Signal(width, name=f"input{i}")
        self.inputs = tuple(inp(i) for i in range(input_count))
        self.output = Signal(width)
        self.lut = Signal(2 ** input_count)
        # maps a tuple of already-decided input values (input 0 first) to
        # the Signal carrying the result of that sub-tree
        self._mux_inputs = {}
        self._build_mux_inputs()

    def _make_key_str(self, *sel_values):
        """Return a name fragment such as `0bx10` for `sel_values`.

        There is one character per input, written with the last input
        first: '1' or '0' for inputs whose value is given in `sel_values`,
        'x' for inputs that are not yet decided.
        """
        k = ['x'] * self.input_count
        for i, v in enumerate(sel_values):
            k[i] = '1' if v else '0'
        return '0b' + ''.join(reversed(k))

    def _build_mux_inputs(self, *sel_values):
        """Recursively create one Signal per tree node, keyed by the tuple
        of input values decided so far (`()` is the root).
        """
        key = self._make_key_str(*sel_values)
        name = f"mux_input_{key}"
        self._mux_inputs[sel_values] = Signal(self.width, name=name)
        if len(sel_values) < self.input_count:
            # not a leaf yet: branch on the next input being 0 and 1
            self._build_mux_inputs(*sel_values, False)
            self._build_mux_inputs(*sel_values, True)

    def elaborate(self, platform):
        m = Module()
        # the root of the tree is the overall result
        m.d.comb += self.output.eq(self._mux_inputs[()])
        for sel_values, v in self._mux_inputs.items():
            if len(sel_values) < self.input_count:
                # inner node: bitwise-mux the two children on the next input
                key = self._make_key_str(*sel_values)
                mux_name = f"mux_{key}"
                mux = BitwiseMux(self.width)
                setattr(m.submodules, mux_name, mux)
                m.d.comb += [
                    mux.f.eq(self._mux_inputs[(*sel_values, False)]),
                    mux.t.eq(self._mux_inputs[(*sel_values, True)]),
                    mux.sel.eq(self.inputs[len(sel_values)]),
                    v.eq(mux.output),
                ]
            else:
                # leaf: every input is decided, so pick the matching lut
                # bit (input i contributes 2**i) and copy it to all bits
                lut_index = sum(
                    2 ** i for i, bit in enumerate(sel_values) if bit
                )
                m.d.comb += v.eq(Repl(self.lut[lut_index], self.width))
        return m

    def ports(self):
        """Return every input Signal, followed by `lut` and `output`."""
        # previously referenced self.input / self.chunk_sizes, which do not
        # exist on this class and raised AttributeError
        return [*self.inputs, self.lut, self.output]
143
144
# Handy for inspecting the generated logic: in yosys, run
# "read_ilang test_lut.il; show top"
if __name__ == '__main__':
    design = BitwiseLut(3, 8)
    il_text = rtlil.convert(design, ports=design.ports())
    with open("test_lut.il", "w") as il_file:
        il_file.write(il_text)