add Array-based version of BitwiseLut, renaming old version to TreeBitwiseLut in...
[nmutil.git] / src / nmutil / lut.py
1 # SPDX-License-Identifier: LGPL-3.0-or-later
2 # See Notices.txt for copyright information
3
4 from nmigen.hdl.ast import Array, Cat, Repl, Signal
5 from nmigen.hdl.dsl import Module
6 from nmigen.hdl.ir import Elaboratable
7
8
class BitwiseMux(Elaboratable):
    """Bit-vector multiplexer.

    Unlike an ordinary ``Mux``, which chooses between two whole values, every
    bit of ``output`` is selected on its own: bit ``i`` of ``output`` is bit
    ``i`` of ``t`` when bit ``i`` of ``sel`` is 1, and bit ``i`` of ``f``
    when it is 0.

    Attributes (each a ``Signal(width)``):
        sel: per-bit select vector.
        t: value taken where ``sel`` is 1.
        f: value taken where ``sel`` is 0.
        output: the per-bit multiplexed result.
    """

    def __init__(self, width):
        self.sel = Signal(width)
        self.t = Signal(width)
        self.f = Signal(width)
        self.output = Signal(width)

    def elaborate(self, platform):
        m = Module()
        # and-or form of a 2:1 mux, applied across the whole vector at once
        from_f = ~self.sel & self.f
        from_t = self.sel & self.t
        m.d.comb += self.output.eq(from_f | from_t)
        return m
25
26
class BitwiseLut(Elaboratable):
    """Bitwise look-up table implemented with one ``Array`` lookup per bit.

    For each bit position ``i`` in ``range(width)``, the bits
    ``inputs[0][i], inputs[1][i], ..., inputs[input_count - 1][i]`` form an
    index (``inputs[j][i]`` supplies bit ``j`` of that index), and
    ``output[i]`` is driven with ``lut[index]``.

    Parameters:
        input_count: number of input signals; must be an ``int``.
        width: bit width of every input and of ``output``; must be an ``int``.

    Attributes:
        inputs: tuple of ``input_count`` ``Signal(width)`` objects named
            ``input0``, ``input1``, ...
        output: ``Signal(width)`` carrying the per-bit lookup results.
        lut: ``Signal(2 ** input_count)`` holding the truth table.
    """

    def __init__(self, input_count, width):
        assert isinstance(input_count, int)
        assert isinstance(width, int)
        self.input_count = input_count
        self.width = width

        self.inputs = tuple(Signal(width, name=f"input{idx}")
                            for idx in range(input_count))
        self.output = Signal(width)
        self.lut = Signal(2 ** input_count)
        # one index signal per output bit; driven bit-by-bit in elaborate()
        self._lut_indexes = [Signal(input_count, name=f"lut_index_{bit}")
                             for bit in range(width)]

    def elaborate(self, platform):
        m = Module()
        # expose the lut bits as an Array so a Signal can index into it
        table = Array(self.lut[k] for k in range(self.lut.width))
        for bit, index in enumerate(self._lut_indexes):
            # gather bit `bit` of every input into this lane's index
            for sel, inp in enumerate(self.inputs):
                m.d.comb += index[sel].eq(inp[bit])
            m.d.comb += self.output[bit].eq(table[index])
        return m
52
53
class TreeBitwiseLut(Elaboratable):
    """Tree-based version of BitwiseLut.

    Builds a binary tree of ``BitwiseMux`` instances. Each tree node is keyed
    by a tuple of bools giving the values already chosen for ``inputs[0]``,
    ``inputs[1]``, ... on the path from the root. The root key is ``()`` and
    drives ``output``. A node whose key has length ``k < input_count`` is a
    mux selected by ``inputs[k]``: its ``False`` child feeds ``f`` and its
    ``True`` child feeds ``t``. A leaf (key length ``input_count``) is driven
    by the single ``lut`` bit whose index has bit ``j`` set exactly when the
    key's ``j``-th value is true, replicated across all ``width`` bits.

    Parameters:
        input_count: number of input signals; must be an ``int``.
        width: bit width of every input and of ``output``; must be an ``int``.

    Attributes:
        inputs: tuple of ``input_count`` ``Signal(width)`` objects named
            ``input0``, ``input1``, ...
        output: ``Signal(width)`` carrying the per-bit lookup results.
        lut: ``Signal(2 ** input_count)`` holding the truth table.
    """

    def __init__(self, input_count, width):
        assert isinstance(input_count, int)
        assert isinstance(width, int)
        self.input_count = input_count
        self.width = width

        self.inputs = tuple(Signal(width, name=f"input{idx}")
                            for idx in range(input_count))
        self.output = Signal(width)
        self.lut = Signal(2 ** input_count)
        # tree-node key (tuple of chosen bools) -> Signal(width) for that node
        self._mux_inputs = {}
        self._build_mux_inputs()

    def _make_key_str(self, *sel_values):
        # Render a node key as a binary literal, last input first; inputs
        # not yet decided at this node are shown as 'x'.
        decided = ['1' if v else '0' for v in sel_values]
        undecided = ['x'] * (self.input_count - len(decided))
        return '0b' + ''.join(reversed(decided + undecided))

    def _build_mux_inputs(self, *sel_values):
        # Create one named Signal per tree node at and below `sel_values`,
        # visiting nodes depth-first in pre-order with the False branch
        # before the True branch (stack pushes True first so False pops first).
        pending = [sel_values]
        while pending:
            key = pending.pop()
            name = f"mux_input_{self._make_key_str(*key)}"
            self._mux_inputs[key] = Signal(self.width, name=name)
            if len(key) < self.input_count:
                pending.append((*key, True))
                pending.append((*key, False))

    def elaborate(self, platform):
        m = Module()
        m.d.comb += self.output.eq(self._mux_inputs[()])
        for key, node in self._mux_inputs.items():
            depth = len(key)
            if depth >= self.input_count:
                # leaf: one fixed lut bit, copied into every bit lane
                lut_index = sum(1 << pos
                                for pos, chosen in enumerate(key[:self.input_count])
                                if chosen)
                m.d.comb += node.eq(Repl(self.lut[lut_index], self.width))
                continue
            # interior node: inputs[depth] picks between the two children
            mux = BitwiseMux(self.width)
            setattr(m.submodules, f"mux_{self._make_key_str(*key)}", mux)
            m.d.comb += [
                mux.f.eq(self._mux_inputs[(*key, False)]),
                mux.t.eq(self._mux_inputs[(*key, True)]),
                mux.sel.eq(self.inputs[depth]),
                node.eq(mux.output),
            ]
        return m