remove BitwiseXorReduce
[nmigen-gf.git] / src / nmigen_gf / hdl / test / test_clmul.py
1 # SPDX-License-Identifier: LGPL-3-or-later
2 # Copyright 2022 Jacob Lifshay
3
4 # Funded by NLnet Assure Programme 2021-02-052, https://nlnet.nl/assure part
5 # of Horizon 2020 EU Programme 957073.
6
7 from functools import reduce
8 import operator
9 import unittest
10 from nmigen.hdl.ast import AnyConst, Assert, Signal, Const, unsigned, Mux
11 from nmigen.hdl.dsl import Module
12 from nmutil.formaltest import FHDLTestCase
13 from nmigen_gf.reference.clmul import clmul
14 from nmigen_gf.hdl.clmul import CLMulAdd
15 from nmigen.sim import Delay
16 from nmutil.sim_util import do_sim, hash_256
17
18
class TestCLMulAdd(FHDLTestCase):
    """Simulation and formal tests for the `CLMulAdd` hardware block.

    Each configuration is exercised twice: `tst` simulates the block and
    compares it against the pure-Python `clmul` reference, while
    `tst_formal` proves it equal to an explicit shift-and-XOR model
    written with nmigen expressions.
    """

    def tst(self, factor_width, terms_width):
        """Simulate `CLMulAdd` and check it against the software reference.

        factor_width -- bit width used for both `factor1` and `factor2`.
        terms_width -- sequence of bit widths, one per additional term
            that is XOR-ed into the carry-less product (may be empty).
        """
        dut = CLMulAdd(factor_width, terms_width)
        # the output is expected to be wide enough for the carry-less
        # product (2 * factor_width - 1 bits) and for the widest term.
        self.assertEqual(dut.output.width,
                         max((factor_width * 2 - 1, *terms_width)))

        def case(factor1, factor2, terms):
            # reference result: clmul(factor1, factor2) XOR all the terms
            expected = reduce(operator.xor, terms, clmul(factor1, factor2))
            with self.subTest(factor1=hex(factor1),
                              factor2=bin(factor2),
                              terms=list(map(hex, terms)),
                              expected=hex(expected)):
                yield dut.factor1.eq(factor1)
                yield dut.factor2.eq(factor2)
                for i, term in enumerate(terms):
                    yield dut.terms[i].eq(term)
                # let simulation time advance before sampling the output
                yield Delay(1e-6)
                output = yield dut.output
                # nested subTest so a failure report also shows the
                # value the hardware actually produced
                with self.subTest(output=hex(output)):
                    self.assertEqual(expected, output)

        def hashed_value(label, width):
            # input value derived from hashing a fixed label, normalized
            # to an unsigned constant of the requested width
            return Const.normalize(hash_256(label), unsigned(width))

        def process():
            for i in range(100):
                factor1 = hashed_value(
                    f"clmuladd term {i} factor1", factor_width)
                factor2 = hashed_value(
                    f"clmuladd term {i} factor2", factor_width)
                terms = [hashed_value(f"clmuladd term {i} {j}", term_width)
                         for j, term_width in enumerate(terms_width)]
                yield from case(factor1, factor2, terms)

        with do_sim(self, dut, [dut.factor1, dut.factor2, *dut.terms,
                                dut.output]) as sim:
            sim.add_process(process)
            sim.run()

    def test_4x4(self):
        self.tst(4, ())

    def test_4x4_8(self):
        self.tst(4, (8,))

    def test_64x64(self):
        self.tst(64, ())

    def test_64x64_64(self):
        self.tst(64, (64,))

    def test_8x8_16_16_16(self):
        self.tst(8, (16, 16, 16))

    def tst_formal(self, factor_width, terms_width):
        """Formally check `CLMulAdd` against a shift-and-XOR model.

        factor_width -- bit width used for both `factor1` and `factor2`.
        terms_width -- sequence of bit widths, one per additional term
            that is XOR-ed into the carry-less product (may be empty).
        """
        dut = CLMulAdd(factor_width, terms_width)
        m = Module()
        m.submodules.dut = dut
        # all inputs are unconstrained constants chosen by the solver
        m.d.comb += dut.factor1.eq(AnyConst(factor_width))
        m.d.comb += dut.factor2.eq(AnyConst(factor_width))

        # one partial product per bit of factor1: factor2 shifted left by
        # the bit position when that bit is set, otherwise zero
        partials = [Mux(dut.factor1[shift], dut.factor2 << shift, 0)
                    for shift in range(factor_width)]
        for term in dut.terms:
            m.d.comb += term.eq(AnyConst(term.shape()))
            partials.append(term)

        # give every XOR input its own explicitly named signal
        reduce_inputs = []
        for i, value in enumerate(partials):
            sig = Signal(value.shape(), name=f"reduce_input_{i}")
            m.d.comb += sig.eq(value)
            reduce_inputs.append(sig)

        # build the XOR of all inputs once and reuse it for both the
        # shape and the value of `expected`
        xor_of_inputs = reduce(operator.xor, reduce_inputs)
        expected = Signal(xor_of_inputs.shape())
        m.d.comb += expected.eq(xor_of_inputs)
        m.d.comb += Assert(dut.output == expected)
        self.assertFormal(m)

    def test_formal_4x4(self):
        self.tst_formal(4, ())

    def test_formal_4x4_8(self):
        self.tst_formal(4, (8,))

    def test_formal_64x64(self):
        self.tst_formal(64, ())

    def test_formal_64x64_64(self):
        self.tst_formal(64, (64,))

    def test_formal_8x8_16_16_16(self):
        self.tst_formal(8, (16, 16, 16))
107
108
# Run the whole test suite when this file is executed as a script.
if __name__ == "__main__":
    unittest.main()