fix tests
[nmigen-gf.git] / src / nmigen_gf / hdl / test / test_clmul.py
1 # SPDX-License-Identifier: LGPL-3-or-later
2 # Copyright 2022 Jacob Lifshay
3
4 # Funded by NLnet Assure Programme 2021-02-052, https://nlnet.nl/assure part
5 # of Horizon 2020 EU Programme 957073.
6
7 from functools import reduce
8 from operator import xor
9 import unittest
10 from nmigen.hdl.ast import (AnyConst, Assert, Signal, Const, unsigned, signed,
11 Mux)
12 from nmigen.hdl.dsl import Module
13 from nmutil.formaltest import FHDLTestCase
14 from nmigen_gf.reference.clmul import clmul
15 from nmutil.clmul import BitwiseXorReduce, CLMulAdd
16 from nmigen.sim import Delay
17 from nmutil.sim_util import do_sim, hash_256
18
19
class TestBitwiseXorReduce(FHDLTestCase):
    """Simulation and formal checks for `BitwiseXorReduce`.

    Every shape list is exercised two ways: `tst` simulates the reducer
    with hash-derived stimulus, and `tst_formal` asserts the same property
    for arbitrary constant inputs.  In both cases the reference result is
    the left-to-right XOR fold of all inputs.
    """

    def tst(self, input_shapes):
        """Simulate the reducer against 100 hash-derived input vectors.

        `input_shapes` is an iterable with one signal shape per reducer
        input (plain widths or `unsigned(...)` / `signed(...)`).
        """
        dut = BitwiseXorReduce(Signal(w, name=f"input_{i}")
                               for i, w in enumerate(input_shapes))

        # the output shape must match what a chain of `^` operators yields
        xor_of_signals = reduce(xor, dut.input_values)
        self.assertEqual(xor_of_signals.shape(), dut.output.shape())

        def check(inputs):
            # drive one input vector and compare with the Python-side XOR
            expected = reduce(xor, inputs)
            inputs_hex = [hex(value) for value in inputs]
            with self.subTest(inputs=inputs_hex, expected=hex(expected)):
                for port, value in zip(dut.input_values, inputs):
                    yield port.eq(value)
                # advance simulation time before sampling the output
                yield Delay(1e-6)
                output = yield dut.output
                with self.subTest(output=hex(output)):
                    self.assertEqual(expected, output)

        def stimulus():
            for i in range(100):
                # each value is derived from a hash of the iteration
                # number and signal name, then normalized to the shape
                # of the signal it drives
                inputs = [
                    Const.normalize(
                        hash_256(f"bxorr input {i} {inp.name}"),
                        inp.shape())
                    for inp in dut.input_values
                ]
                yield from check(inputs)

        traces = [*dut.input_values, dut.output]
        with do_sim(self, dut, traces) as sim:
            sim.add_process(stimulus)
            sim.run()

    def tst_formal(self, input_shapes):
        """Formally check that the output equals the XOR of all inputs.

        `input_shapes` has the same meaning as in `tst`.
        """
        dut = BitwiseXorReduce(Signal(w, name=f"input_{i}")
                               for i, w in enumerate(input_shapes))
        m = Module()
        m.submodules.dut = dut
        for inp in dut.input_values:
            # every input is tied to an unconstrained `AnyConst`
            m.d.comb += inp.eq(AnyConst(inp.shape()))
        reference = reduce(xor, dut.input_values)
        m.d.comb += Assert(dut.output == reference)
        self.assertFormal(m)

    # 65 inputs, each 64 bits wide
    def test_65_of_u64(self):
        self.tst(input_shapes=[64] * 65)

    def test_formal_65_of_u64(self):
        self.tst_formal(input_shapes=[64] * 65)

    # 5 inputs, each 6 bits wide
    def test_5_of_u6(self):
        self.tst(input_shapes=[6] * 5)

    def test_formal_5_of_u6(self):
        self.tst_formal(input_shapes=[6] * 5)

    # inputs of mixed width and signedness
    def test_u5_i6_u3_i10(self):
        self.tst(
            input_shapes=[unsigned(5), signed(6), unsigned(3), signed(10)])

    def test_formal_u5_i6_u3_i10(self):
        self.tst_formal(
            input_shapes=[unsigned(5), signed(6), unsigned(3), signed(10)])
77
78
class TestCLMulAdd(FHDLTestCase):
    """Simulation and formal checks for `CLMulAdd`.

    The simulated reference is `clmul(factor1, factor2)` XORed with every
    extra term.  The formal reference is built from `Mux`-selected shifted
    copies of `factor2` XORed together with the terms.
    """

    def tst(self, factor_width, terms_width):
        """Simulate `CLMulAdd` against 100 hash-derived input sets.

        `factor_width` is the width of both factors; `terms_width` is a
        sequence holding the width of each additional term.
        """
        dut = CLMulAdd(factor_width, terms_width)

        # output width: the widest of the product and every term
        widest = max((factor_width * 2 - 1, *terms_width))
        self.assertEqual(dut.output.width, widest)

        def check(factor1, factor2, terms):
            # drive one input set and compare with the reference model
            expected = reduce(xor, terms, clmul(factor1, factor2))
            labels = dict(factor1=hex(factor1),
                          factor2=bin(factor2),
                          terms=list(map(hex, terms)),
                          expected=hex(expected))
            with self.subTest(**labels):
                yield dut.factor1.eq(factor1)
                yield dut.factor2.eq(factor2)
                for index, value in enumerate(terms):
                    yield dut.terms[index].eq(value)
                # advance simulation time before sampling the output
                yield Delay(1e-6)
                output = yield dut.output
                with self.subTest(output=hex(output)):
                    self.assertEqual(expected, output)

        def hashed(label, width):
            # hash of `label`, normalized to an unsigned `width`-bit value
            return Const.normalize(hash_256(label), unsigned(width))

        def stimulus():
            for i in range(100):
                factor1 = hashed(f"clmuladd term {i} factor1", factor_width)
                factor2 = hashed(f"clmuladd term {i} factor2", factor_width)
                terms = [hashed(f"clmuladd term {i} {j}", term_width)
                         for j, term_width in enumerate(terms_width)]
                yield from check(factor1, factor2, terms)

        traces = [dut.factor1, dut.factor2, *dut.terms, dut.output]
        with do_sim(self, dut, traces) as sim:
            sim.add_process(stimulus)
            sim.run()

    def test_4x4(self):
        self.tst(factor_width=4, terms_width=())

    def test_4x4_8(self):
        self.tst(factor_width=4, terms_width=(8,))

    def test_64x64(self):
        self.tst(factor_width=64, terms_width=())

    def test_64x64_64(self):
        self.tst(factor_width=64, terms_width=(64,))

    def test_8x8_16_16_16(self):
        self.tst(factor_width=8, terms_width=(16, 16, 16))

    def tst_formal(self, factor_width, terms_width):
        """Formally check `CLMulAdd` against a shift-and-XOR reference.

        Arguments have the same meaning as in `tst`.
        """
        dut = CLMulAdd(factor_width, terms_width)
        m = Module()
        m.submodules.dut = dut
        m.d.comb += dut.factor1.eq(AnyConst(factor_width))
        m.d.comb += dut.factor2.eq(AnyConst(factor_width))

        # one partial product per bit of factor1: factor2 shifted left by
        # the bit position when that bit is set, otherwise zero
        partial_products = [
            Mux(dut.factor1[shift], dut.factor2 << shift, 0)
            for shift in range(factor_width)
        ]

        for term in dut.terms:
            # every term is tied to an unconstrained `AnyConst`
            m.d.comb += term.eq(AnyConst(term.shape()))

        # copy every XOR operand into its own named signal
        reduce_inputs = []
        for i, value in enumerate([*partial_products, *dut.terms]):
            sig = Signal(value.shape(), name=f"reduce_input_{i}")
            m.d.comb += sig.eq(value)
            reduce_inputs.append(sig)

        xor_all = reduce(xor, reduce_inputs)
        expected = Signal(xor_all.shape())
        m.d.comb += expected.eq(xor_all)
        m.d.comb += Assert(dut.output == expected)
        self.assertFormal(m)

    def test_formal_4x4(self):
        self.tst_formal(factor_width=4, terms_width=())

    def test_formal_4x4_8(self):
        self.tst_formal(factor_width=4, terms_width=(8,))

    def test_formal_64x64(self):
        self.tst_formal(factor_width=64, terms_width=())

    def test_formal_64x64_64(self):
        self.tst_formal(factor_width=64, terms_width=(64,))

    def test_formal_8x8_16_16_16(self):
        self.tst_formal(factor_width=8, terms_width=(16, 16, 16))
167
168
# Running this file directly hands control to unittest's command-line
# runner; importing it only defines the test classes.
if __name__ == "__main__":
    unittest.main()