b30c4b7c801158e531a9619beead852a9b14d1c6
[nmutil.git] / src / nmutil / test / test_lut.py
# SPDX-License-Identifier: LGPL-3.0-or-later
2 # See Notices.txt for copyright information
3
4 import unittest
5 from nmigen.hdl.ast import AnyConst, Assert, Signal
6 from nmigen.hdl.dsl import Module
7 from nmutil.formaltest import FHDLTestCase
8 from nmutil.lut import BitwiseMux, BitwiseLut, TreeBitwiseLut
9 from nmigen.sim import Delay
10 from nmutil.sim_util import do_sim, hash_256
11
12
class TestBitwiseMux(FHDLTestCase):
    """Simulation and formal tests for ``BitwiseMux``.

    For every bit position ``i`` the mux output bit is expected to equal
    ``t[i]`` when ``sel[i]`` is set, and ``f[i]`` otherwise.
    """

    def test(self):
        """Exhaustively simulate a 2-bit ``BitwiseMux`` over all inputs."""
        width = 2
        dut = BitwiseMux(width)
        values = range(2 ** width)

        def check(sel, t, f, expected):
            # Drive one input combination, advance simulated time so the
            # combinational output reflects it, then compare.
            with self.subTest(sel=bin(sel), t=bin(t), f=bin(f)):
                for port, value in ((dut.sel, sel), (dut.t, t), (dut.f, f)):
                    yield port.eq(value)
                yield Delay(1e-6)
                actual = yield dut.output
                with self.subTest(output=bin(actual), expected=bin(expected)):
                    self.assertEqual(expected, actual)

        def process():
            # Same visiting order as three nested loops: sel, then t, then f.
            combos = ((s, tv, fv)
                      for s in values for tv in values for fv in values)
            for sel, t, f in combos:
                # Take bits from ``t`` where ``sel`` is 1 and from ``f``
                # where ``sel`` is 0.
                expected = (sel & t) | (f & ~sel)
                yield from check(sel, t, f, expected)

        traced = [dut.sel, dut.t, dut.f, dut.output]
        with do_sim(self, dut, traced) as sim:
            sim.add_process(process)
            sim.run()

    def test_formal(self):
        """Formally check the per-bit mux property for arbitrary inputs."""
        width = 2
        dut = BitwiseMux(width)
        m = Module()
        m.submodules.dut = dut
        # Let the solver choose any constant value for each input port.
        for port in (dut.sel, dut.f, dut.t):
            m.d.comb += port.eq(AnyConst(width))
        for bit in range(width):
            out_bit = dut.output[bit]
            with m.If(dut.sel[bit]):
                m.d.comb += Assert(dut.t[bit] == out_bit)
            with m.Else():
                m.d.comb += Assert(dut.f[bit] == out_bit)
        self.assertFormal(m)
58
59
class TestBitwiseLut(FHDLTestCase):
    """Simulation and formal tests for ``BitwiseLut`` and ``TreeBitwiseLut``.

    For every bit position ``i`` the output bit is expected to equal
    ``lut[index]``, where bit ``j`` of ``index`` is bit ``i`` of
    ``inputs[j]``.
    """

    def tst(self, cls):
        """Simulate ``cls(3, 16)`` over 100 hash-derived test cases.

        :param cls: the LUT implementation under test (``BitwiseLut`` or
            ``TreeBitwiseLut``).
        """
        dut = cls(3, 16)
        mask = 2 ** dut.width - 1
        lut_mask = 2 ** dut.lut.width - 1
        if cls is TreeBitwiseLut:
            # Check the generated names of the tree's internal mux-input
            # signals.
            mux_inputs = {k: s.name for k, s in dut._mux_inputs.items()}
            self.assertEqual(mux_inputs, {
                (): 'mux_input_0bxxx',
                (False,): 'mux_input_0bxx0',
                (False, False): 'mux_input_0bx00',
                (False, False, False): 'mux_input_0b000',
                (False, False, True): 'mux_input_0b100',
                (False, True): 'mux_input_0bx10',
                (False, True, False): 'mux_input_0b010',
                (False, True, True): 'mux_input_0b110',
                (True,): 'mux_input_0bxx1',
                (True, False): 'mux_input_0bx01',
                (True, False, False): 'mux_input_0b001',
                (True, False, True): 'mux_input_0b101',
                (True, True): 'mux_input_0bx11',
                (True, True, False): 'mux_input_0b011',
                (True, True, True): 'mux_input_0b111'
            })

        def case(input_values, lut):
            # Reference model: for each output bit, gather that bit from
            # every input to form a LUT index, then select that LUT bit.
            # Works for any number of inputs, matching ``tst_formal``.
            expected = 0
            for i in range(dut.width):
                lut_index = 0
                for j, value in enumerate(input_values):
                    if value & 2 ** i:
                        lut_index |= 2 ** j
                if lut & 2 ** lut_index:
                    expected |= 2 ** i
            # subTest labels are in0, in1, ... followed by lut.
            input_labels = {f"in{j}": bin(value)
                            for j, value in enumerate(input_values)}
            with self.subTest(**input_labels, lut=bin(lut)):
                for j, value in enumerate(input_values):
                    yield dut.inputs[j].eq(value)
                yield dut.lut.eq(lut)
                # Advance simulated time so the combinational output
                # reflects the new inputs before sampling it.
                yield Delay(1e-6)
                output = yield dut.output
                with self.subTest(output=bin(output), expected=bin(expected)):
                    self.assertEqual(expected, output)

        def process():
            for case_index in range(100):
                with self.subTest(case_index=case_index):
                    # Derive each value from hash_256 of a per-case key
                    # ("<case> in<j>" / "<case> lut"), masked to its width.
                    input_values = [
                        hash_256(f"{case_index} in{j}") & mask
                        for j in range(dut.input_count)
                    ]
                    lut = hash_256(f"{case_index} lut") & lut_mask
                    yield from case(input_values, lut)

        with do_sim(self, dut, [*dut.inputs, dut.lut, dut.output]) as sim:
            sim.add_process(process)
            sim.run()

    def tst_formal(self, cls):
        """Formally check the LUT property of ``cls(3, 16)``.

        Every input and the LUT are left unconstrained (``AnyConst``).
        For each output bit, an assertion checks that it equals the LUT bit
        selected by the corresponding bits of the inputs.

        :param cls: the LUT implementation under test.
        """
        dut = cls(3, 16)
        m = Module()
        m.submodules.dut = dut
        m.d.comb += dut.inputs[0].eq(AnyConst(dut.width))
        m.d.comb += dut.inputs[1].eq(AnyConst(dut.width))
        m.d.comb += dut.inputs[2].eq(AnyConst(dut.width))
        m.d.comb += dut.lut.eq(AnyConst(dut.lut.width))
        for i in range(dut.width):
            # Bit j of lut_index_<i> is bit i of inputs[j].
            lut_index = Signal(dut.input_count, name=f"lut_index_{i}")
            for j in range(dut.input_count):
                m.d.comb += lut_index[j].eq(dut.inputs[j][i])
            for j in range(dut.lut.width):
                with m.If(lut_index == j):
                    m.d.comb += Assert(dut.lut[j] == dut.output[i])
        self.assertFormal(m)

    def test(self):
        self.tst(BitwiseLut)

    def test_tree(self):
        self.tst(TreeBitwiseLut)

    def test_formal(self):
        self.tst_formal(BitwiseLut)

    def test_tree_formal(self):
        self.tst_formal(TreeBitwiseLut)
148
149
if __name__ == "__main__":
    # Run this module's test cases when the file is executed directly.
    unittest.main()