add copyright notices
[nmutil.git] / src / nmutil / test / test_lut.py
1 # SPDX-License-Identifier: LGPL-3-or-later
2 # Copyright 2021 Jacob Lifshay
3
4 # Funded by NLnet Assure Programme 2021-02-052, https://nlnet.nl/assure part
5 # of Horizon 2020 EU Programme 957073.
6
7 import unittest
8 from nmigen.hdl.ast import AnyConst, Assert, Signal
9 from nmigen.hdl.dsl import Module
10 from nmutil.formaltest import FHDLTestCase
11 from nmutil.lut import BitwiseMux, BitwiseLut, TreeBitwiseLut
12 from nmigen.sim import Delay
13 from nmutil.sim_util import do_sim, hash_256
14
15
class TestBitwiseMux(FHDLTestCase):
    """Tests for ``BitwiseMux``.

    Each output bit must equal the matching bit of ``t`` when that bit of
    ``sel`` is set, and the matching bit of ``f`` otherwise.
    """

    def test(self):
        """Simulate every (sel, t, f) combination of a 2-bit mux."""
        width = 2
        dut = BitwiseMux(width)
        value_count = 2 ** width

        def check(sel_value, t_value, f_value, want):
            # Drive the inputs, give the combinational logic time to settle,
            # then compare the sampled output against the reference value.
            with self.subTest(sel=bin(sel_value), t=bin(t_value),
                              f=bin(f_value)):
                yield dut.sel.eq(sel_value)
                yield dut.t.eq(t_value)
                yield dut.f.eq(f_value)
                yield Delay(1e-6)
                got = yield dut.output
                with self.subTest(output=bin(got), expected=bin(want)):
                    self.assertEqual(want, got)

        def stimulus():
            combos = ((s, t, f)
                      for s in range(value_count)
                      for t in range(value_count)
                      for f in range(value_count))
            for s, t, f in combos:
                # Per-bit mux reference model: bits set in ``s`` come from
                # ``t``, the remaining bits come from ``f``. ``~s & f`` is
                # non-negative because ``f`` is, and ``f < 2 ** width``
                # keeps the result within ``width`` bits.
                want = (s & t) | (~s & f)
                yield from check(s, t, f, want)

        traced = [dut.sel, dut.t, dut.f, dut.output]
        with do_sim(self, dut, traced) as sim:
            sim.add_process(stimulus)
            sim.run()

    def test_formal(self):
        """Formally check the per-bit mux property for arbitrary inputs."""
        width = 2
        dut = BitwiseMux(width)
        m = Module()
        m.submodules.dut = dut
        # Leave every input as an unconstrained solver-chosen constant.
        for port in (dut.sel, dut.f, dut.t):
            m.d.comb += port.eq(AnyConst(width))
        for bit in range(width):
            chosen = dut.output[bit]
            with m.If(dut.sel[bit]):
                m.d.comb += Assert(dut.t[bit] == chosen)
            with m.Else():
                m.d.comb += Assert(dut.f[bit] == chosen)
        self.assertFormal(m)
61
62
class TestBitwiseLut(FHDLTestCase):
    """Tests shared by ``BitwiseLut`` and ``TreeBitwiseLut``.

    For each bit position ``i``, the bits ``inputs[j][i]`` form a LUT index
    (``inputs[0]`` supplies the least-significant index bit), and
    ``output[i]`` must equal bit ``index`` of ``lut``.
    """

    def tst(self, cls):
        """Simulate 100 pseudo-random cases of ``cls(3, 16)``.

        Each case is compared against a software reference model. All
        inputs are driven from ``dut.input_count`` / ``dut.inputs`` rather
        than a hard-coded three. Changing the constructor arguments
        therefore still exercises every input, except for the
        ``TreeBitwiseLut`` name check below.

        :param cls: the LUT class under test.
        """
        dut = cls(3, 16)
        mask = 2 ** dut.width - 1
        lut_mask = 2 ** dut.lut.width - 1
        if cls is TreeBitwiseLut:
            # The expected names enumerate the mux tree for exactly three
            # inputs, matching the ``cls(3, 16)`` construction above.
            mux_inputs = {k: s.name for k, s in dut._mux_inputs.items()}
            self.assertEqual(mux_inputs, {
                (): 'mux_input_0bxxx',
                (False,): 'mux_input_0bxx0',
                (False, False): 'mux_input_0bx00',
                (False, False, False): 'mux_input_0b000',
                (False, False, True): 'mux_input_0b100',
                (False, True): 'mux_input_0bx10',
                (False, True, False): 'mux_input_0b010',
                (False, True, True): 'mux_input_0b110',
                (True,): 'mux_input_0bxx1',
                (True, False): 'mux_input_0bx01',
                (True, False, False): 'mux_input_0b001',
                (True, False, True): 'mux_input_0b101',
                (True, True): 'mux_input_0bx11',
                (True, True, False): 'mux_input_0b011',
                (True, True, True): 'mux_input_0b111'
            })

        def expected_output(input_values, lut):
            # Software reference model: for each output bit, gather bit i of
            # every input (input j -> index bit j) and look it up in ``lut``.
            expected = 0
            for i in range(dut.width):
                lut_index = 0
                for j, value in enumerate(input_values):
                    if value & 2 ** i:
                        lut_index |= 2 ** j
                if lut & 2 ** lut_index:
                    expected |= 2 ** i
            return expected

        def case(input_values, lut):
            expected = expected_output(input_values, lut)
            # Subtest labels stay ``in0=..., in1=..., ..., lut=...``.
            labels = {f"in{j}": bin(v) for j, v in enumerate(input_values)}
            with self.subTest(**labels, lut=bin(lut)):
                for port, value in zip(dut.inputs, input_values):
                    yield port.eq(value)
                yield dut.lut.eq(lut)
                yield Delay(1e-6)
                output = yield dut.output
                with self.subTest(output=bin(output), expected=bin(expected)):
                    self.assertEqual(expected, output)

        def process():
            for case_index in range(100):
                with self.subTest(case_index=case_index):
                    # Deterministic pseudo-random stimulus; the seed strings
                    # are "<case> in<j>" and "<case> lut".
                    input_values = [hash_256(f"{case_index} in{j}") & mask
                                    for j in range(dut.input_count)]
                    lut = hash_256(f"{case_index} lut") & lut_mask
                    yield from case(input_values, lut)

        with do_sim(self, dut, [*dut.inputs, dut.lut, dut.output]) as sim:
            sim.add_process(process)
            sim.run()

    def tst_formal(self, cls):
        """Formally check ``output[i] == lut[index_i]`` for ``cls(3, 16)``.

        ``index_i`` is built from bit ``i`` of every input. All inputs and
        the LUT contents are unconstrained solver-chosen constants.

        :param cls: the LUT class under test.
        """
        dut = cls(3, 16)
        m = Module()
        m.submodules.dut = dut
        for port in dut.inputs:
            m.d.comb += port.eq(AnyConst(dut.width))
        m.d.comb += dut.lut.eq(AnyConst(dut.lut.width))
        for i in range(dut.width):
            lut_index = Signal(dut.input_count, name=f"lut_index_{i}")
            for j in range(dut.input_count):
                m.d.comb += lut_index[j].eq(dut.inputs[j][i])
            for j in range(dut.lut.width):
                with m.If(lut_index == j):
                    m.d.comb += Assert(dut.lut[j] == dut.output[i])
        self.assertFormal(m)

    def test(self):
        self.tst(BitwiseLut)

    def test_tree(self):
        self.tst(TreeBitwiseLut)

    def test_formal(self):
        self.tst_formal(BitwiseLut)

    def test_tree_formal(self):
        self.tst_formal(TreeBitwiseLut)
151
152
if __name__ == "__main__":
    # Running this file directly hands control to unittest's command-line
    # runner, which collects and runs the TestCase classes in this module.
    unittest.main()