simplify lut.py
[nmutil.git] / src / nmutil / lut.py
1 # SPDX-License-Identifier: LGPL-3-or-later
2 # See Notices.txt for copyright information
3
4 from nmigen.hdl.ast import Array, Cat, Repl, Signal
5 from nmigen.hdl.dsl import Module
6 from nmigen.hdl.ir import Elaboratable
7
8
class BitwiseMux(Elaboratable):
    """ Per-bit multiplexer.

    Unlike an ordinary Mux, `sel` is as wide as the data: every output bit
    is chosen separately, taking the bit from `t` where the matching `sel`
    bit is 1 and the bit from `f` where it is 0.
    """

    def __init__(self, width):
        """ width: bit width shared by `sel`, `t`, `f` and `output`. """
        # NOTE: keep these as direct attribute assignments -- nmigen derives
        # each Signal's name from the assignment target.
        self.sel = Signal(width)
        self.t = Signal(width)
        self.f = Signal(width)
        self.output = Signal(width)

    def elaborate(self, platform):
        m = Module()
        # bits taken from `f` (where sel is 0) and from `t` (where sel is 1)
        from_f = ~self.sel & self.f
        from_t = self.sel & self.t
        m.d.comb += self.output.eq(from_f | from_t)
        return m
25
26
class BitwiseLut(Elaboratable):
    """ Look-up table applied independently to every bit position.

    For each bit position, the bits found at that position in all of
    `inputs` are concatenated into an index (`inputs[0]` is the first
    operand of the concatenation), and the `lut` bit at that index drives
    the same bit position of `output`.
    """

    def __init__(self, input_count, width):
        """ input_count: number of input Signals; `lut` is
                `2 ** input_count` bits wide.
            width: bit width of every input and of `output`.
        """
        self.input_count = input_count
        self.width = width
        self.inputs = tuple(Signal(width, name=f"input{i}")
                            for i in range(input_count))
        # direct attribute assignments so nmigen infers the Signal names
        self.output = Signal(width)
        self.lut = Signal(2 ** input_count)

    def elaborate(self, platform):
        m = Module()
        # wrap the LUT bits in an Array so they can be indexed by a value
        table = Array(self.lut)
        # one assignment per output bit, in ascending bit order
        m.d.comb += [
            self.output[pos].eq(table[Cat(sig[pos] for sig in self.inputs)])
            for pos in range(self.width)
        ]
        return m
45
46
class TreeBitwiseLut(Elaboratable):
    """ Tree-based version of BitwiseLut.

    The look-up is built as a binary tree of BitwiseMux instances. Each
    tree node is identified by a tuple of booleans (one per input already
    decided, starting with `inputs[0]`) and owns a `width`-bit Signal in
    `_mux_inputs`. Leaves replicate a single `lut` bit across the full
    width; inner nodes select between their two children using the next
    input.
    """

    def __init__(self, input_count, width):
        """ input_count: number of input Signals; `lut` is
                `2 ** input_count` bits wide.
            width: bit width of every input and of `output`.
        """
        self.input_count = input_count
        self.width = width
        self.inputs = tuple(Signal(width, name=f"input{i}")
                            for i in range(input_count))
        # direct attribute assignments so nmigen infers the Signal names
        self.output = Signal(width)
        self.lut = Signal(2 ** input_count)
        # maps a tuple of selector values -> Signal for that tree node
        self._mux_inputs = {}
        self._build_mux_inputs()

    def _make_key_str(self, *sel_values):
        """ Format selector values as a string like `0bxx10`.

        Undecided inputs are shown as `x`; the value for input 0 ends up
        as the right-most character.
        """
        digits = ['x'] * self.input_count
        for pos, chosen in enumerate(sel_values):
            digits[pos] = '1' if chosen else '0'
        return '0b' + ''.join(digits[::-1])

    def _build_mux_inputs(self, *sel_values):
        """ Create the Signal for node `sel_values` and for all nodes
        below it, storing them in `_mux_inputs`.

        Nodes are inserted parent first, then the whole False subtree,
        then the whole True subtree.
        """
        pending = [sel_values]
        while pending:
            key = pending.pop()
            name = f"mux_input_{self._make_key_str(*key)}"
            self._mux_inputs[key] = Signal(self.width, name=name)
            if len(key) < self.input_count:
                # push True first so the False child is popped (and hence
                # inserted) before anything in the True subtree
                pending.append((*key, True))
                pending.append((*key, False))

    def elaborate(self, platform):
        m = Module()
        # the root node (no inputs decided yet) is the final result
        m.d.comb += self.output.eq(self._mux_inputs[()])
        for path, node in self._mux_inputs.items():
            depth = len(path)
            if depth >= self.input_count:
                # leaf: every input is decided, so the node is one LUT bit
                # (input i contributes bit i of the index) repeated across
                # the whole width
                lut_index = sum(1 << i
                                for i in range(self.input_count)
                                if path[i])
                m.d.comb += node.eq(Repl(self.lut[lut_index], self.width))
                continue
            # inner node: choose between the two children bit-by-bit,
            # using the next undecided input as the selector
            mux = BitwiseMux(self.width)
            setattr(m.submodules, f"mux_{self._make_key_str(*path)}", mux)
            m.d.comb += [
                mux.f.eq(self._mux_inputs[(*path, False)]),
                mux.t.eq(self._mux_inputs[(*path, True)]),
                mux.sel.eq(self.inputs[depth]),
                node.eq(mux.output),
            ]
        return m