add docs
[nmutil.git] / src / nmutil / lut.py
1 # SPDX-License-Identifier: LGPL-3-or-later
2 """
3 Bitwise logic operators implemented using a look-up table, like LUTs in
4 FPGAs. Inspired by x86's `vpternlog[dq]` instructions.
5
6 https://bugs.libre-soc.org/show_bug.cgi?id=745
7 https://www.felixcloutier.com/x86/vpternlogd:vpternlogq
8 """
9
10 from nmigen.hdl.ast import Array, Cat, Repl, Signal
11 from nmigen.hdl.dsl import Module
12 from nmigen.hdl.ir import Elaboratable
13
14
class BitwiseMux(Elaboratable):
    """ Per-bit multiplexer: `sel` is a bit vector rather than a single
    select line, so every output bit is chosen independently.

    Output bit `i` is `t[i]` when `sel[i]` is set, and `f[i]` otherwise.
    """

    def __init__(self, width):
        """
        width: int
            the number of bits in `sel`, `t`, `f` and `output`.
        """
        self.sel = Signal(width)
        self.t = Signal(width)
        self.f = Signal(width)
        self.output = Signal(width)

    def elaborate(self, platform):
        m = Module()
        # mask each data input with the select bits that pick it, then merge
        picked_from_f = ~self.sel & self.f
        picked_from_t = self.sel & self.t
        m.d.comb += self.output.eq(picked_from_f | picked_from_t)
        return m
31
32
class BitwiseLut(Elaboratable):
    """ Look-up-table based bitwise logic, in the manner of the LUTs found in
    FPGAs. Inspired by x86's `vpternlog[dq]` instructions.

    Each output bit `i` is set to `lut[Cat(inp[i] for inp in self.inputs)]`
    """

    def __init__(self, input_count, width):
        """
        input_count: int
            the number of inputs. ternlog-style instructions have 3 inputs.
        width: int
            the number of bits in each input/output.
        """
        self.input_count = input_count
        self.width = width
        self.inputs = tuple(Signal(width, name=f"input{i}")
                            for i in range(input_count))
        """ the inputs """
        self.output = Signal(width)
        """ the output """
        self.lut = Signal(2 ** input_count)
        """ the look-up table. Is `2 ** input_count` bits wide."""

    def elaborate(self, platform):
        m = Module()
        lut_bits = Array(self.lut)

        def lookup(bit):
            # bit number `bit` of every input, concatenated, is the index
            # into the look-up table
            return lut_bits[Cat(inp[bit] for inp in self.inputs)]

        m.d.comb += [self.output[bit].eq(lookup(bit))
                     for bit in range(self.width)]
        return m
66
67
class TreeBitwiseLut(Elaboratable):
    """ BitwiseLut built as a tree of BitwiseMux stages. See BitwiseLut for
    API documentation.
    """

    def __init__(self, input_count, width):
        self.input_count = input_count
        self.width = width
        self.inputs = tuple(Signal(width, name=f"input{i}")
                            for i in range(input_count))
        self.output = Signal(width)
        self.lut = Signal(2 ** input_count)
        # maps a tuple of select values already chosen for inputs 0, 1, ...
        # to the Signal carrying the result for that partial selection
        self._mux_inputs = {}
        self._build_mux_inputs()

    def _make_key_str(self, *sel_values):
        """ format `sel_values` as a `0b...` string with input 0 rightmost,
        using `x` for every input that has no select value yet.
        """
        digits = ['x'] * self.input_count
        for pos, value in enumerate(sel_values):
            digits[pos] = '1' if value else '0'
        return '0b' + ''.join(digits[::-1])

    def _build_mux_inputs(self, *sel_values):
        """ create the Signal for `sel_values` and for every longer selection
        that starts with it, filling in `self._mux_inputs`.
        """
        pending = [sel_values]
        while pending:
            key = pending.pop()
            name = f"mux_input_{self._make_key_str(*key)}"
            self._mux_inputs[key] = Signal(self.width, name=name)
            if len(key) < self.input_count:
                # pushed in reverse so the False branch is handled first,
                # giving pre-order traversal of the tree
                pending.append((*key, True))
                pending.append((*key, False))

    def elaborate(self, platform):
        m = Module()
        m.d.comb += self.output.eq(self._mux_inputs[()])
        for sel_values, node in self._mux_inputs.items():
            depth = len(sel_values)
            if depth >= self.input_count:
                # leaf: every input has a select value, so the node is a
                # single lut bit replicated across the whole width
                lut_index = sum(1 << i
                                for i, chosen in enumerate(sel_values)
                                if chosen)
                m.d.comb += node.eq(Repl(self.lut[lut_index], self.width))
                continue
            # inner node: input number `depth` selects between the two
            # child nodes, bit by bit
            mux_name = f"mux_{self._make_key_str(*sel_values)}"
            mux = BitwiseMux(self.width)
            setattr(m.submodules, mux_name, mux)
            m.d.comb += [
                mux.f.eq(self._mux_inputs[(*sel_values, False)]),
                mux.t.eq(self._mux_inputs[(*sel_values, True)]),
                mux.sel.eq(self.inputs[depth]),
                node.eq(mux.output),
            ]
        return m
116 m.d.comb += v.eq(Repl(self.lut[lut_index], self.width))
117 return m