# SPDX-License-Identifier: LGPL-3.0-or-later
# Copyright 2022 Jacob Lifshay
# Funded by NLnet Assure Programme 2021-02-052, https://nlnet.nl/assure,
# part of Horizon 2020 EU Programme 957073.
import unittest
from functools import reduce
from operator import xor

from nmigen.hdl.ast import (AnyConst, Assert, Signal, Const, unsigned, signed,
                            Mux)
from nmigen.hdl.dsl import Module
from nmigen.sim import Delay
from nmigen_gf.reference.clmul import clmul
from nmutil.clmul import BitwiseXorReduce, CLMulAdd
from nmutil.formaltest import FHDLTestCase
from nmutil.sim_util import do_sim, hash_256
class TestBitwiseXorReduce(FHDLTestCase):
    """Tests for ``BitwiseXorReduce``.

    Each test builds a DUT with one input per entry of ``input_shapes`` and
    checks ``dut.output`` against ``functools.reduce(operator.xor, ...)`` over
    the same inputs, either by simulation (``tst``) or by formal
    verification (``tst_formal``).
    """

    @staticmethod
    def _make_dut(input_shapes):
        """Return a ``BitwiseXorReduce`` whose inputs are named ``input_<i>``.

        ``input_shapes`` holds one ``Signal`` shape per input: an int width
        or an ``unsigned(...)``/``signed(...)`` shape.
        """
        return BitwiseXorReduce(Signal(w, name=f"input_{i}")
                                for i, w in enumerate(input_shapes))

    def tst(self, input_shapes):
        """Simulate the DUT on 100 input vectors and compare to Python XOR.

        Input values come from ``hash_256`` of a label containing the
        iteration number and input name (not from an RNG), normalized to
        each input's shape.
        """
        dut = self._make_dut(input_shapes)
        # the output must have exactly the shape of the XOR of all inputs
        self.assertEqual(reduce(xor, dut.input_values).shape(),
                         dut.output.shape())

        def case(inputs):
            """Drive ``inputs`` onto the DUT and check the resulting output."""
            expected = reduce(xor, inputs)
            with self.subTest(inputs=list(map(hex, inputs)),
                              expected=hex(expected)):
                # inputs is built one value per DUT input, so lengths match
                for inp_sig, value in zip(dut.input_values, inputs):
                    yield inp_sig.eq(value)
                # advance simulated time so the combinational output settles
                # before it is sampled
                yield Delay(1e-6)
                output = yield dut.output
                with self.subTest(output=hex(output)):
                    self.assertEqual(expected, output)

        def process():
            for i in range(100):
                inputs = [
                    Const.normalize(hash_256(f"bxorr input {i} {inp.name}"),
                                    inp.shape())
                    for inp in dut.input_values
                ]
                yield from case(inputs)

        with do_sim(self, dut, [*dut.input_values, dut.output]) as sim:
            sim.add_process(process)
            # without run() the process is registered but never executed
            sim.run()

    def tst_formal(self, input_shapes):
        """Formally prove ``dut.output`` equals the XOR of arbitrary inputs."""
        dut = self._make_dut(input_shapes)
        m = Module()
        m.submodules.dut = dut
        for inp in dut.input_values:
            # AnyConst lets the solver pick any (constant) value per input
            m.d.comb += inp.eq(AnyConst(inp.shape()))
        m.d.comb += Assert(dut.output == reduce(xor, dut.input_values))
        self.assertFormal(m)

    def test_65_of_u64(self):
        self.tst([64] * 65)

    def test_formal_65_of_u64(self):
        self.tst_formal([64] * 65)

    def test_5_of_u6(self):
        self.tst([6] * 5)

    def test_formal_5_of_u6(self):
        self.tst_formal([6] * 5)

    def test_u5_i6_u3_i10(self):
        self.tst([unsigned(5), signed(6), unsigned(3), signed(10)])

    def test_formal_u5_i6_u3_i10(self):
        self.tst_formal([unsigned(5), signed(6), unsigned(3), signed(10)])
class TestCLMulAdd(FHDLTestCase):
    """Tests for ``CLMulAdd``.

    ``CLMulAdd(factor_width, terms_width)`` is checked to produce
    ``clmul(factor1, factor2) ^ terms[0] ^ terms[1] ^ ...``, either by
    simulation against the ``clmul`` reference (``tst``) or by formal
    verification against a shift-and-XOR model (``tst_formal``).
    """

    def tst(self, factor_width, terms_width):
        """Simulate the DUT on 100 input sets and compare to the reference.

        :param factor_width: bit width of both unsigned factors.
        :param terms_width: sequence of bit widths, one per XOR-ed term.
        """
        dut = CLMulAdd(factor_width, terms_width)
        # a carry-less product of two N-bit factors has 2*N-1 bits; the
        # output must also be as wide as the widest term
        self.assertEqual(dut.output.width,
                         max((factor_width * 2 - 1, *terms_width)))

        def case(factor1, factor2, terms):
            """Drive one input set onto the DUT and check the output."""
            expected = reduce(xor, terms, clmul(factor1, factor2))
            with self.subTest(factor1=hex(factor1),
                              factor2=hex(factor2),
                              terms=list(map(hex, terms)),
                              expected=hex(expected)):
                yield dut.factor1.eq(factor1)
                yield dut.factor2.eq(factor2)
                # terms is built one value per DUT term, so lengths match
                for term_sig, term in zip(dut.terms, terms):
                    yield term_sig.eq(term)
                # advance simulated time so the combinational output settles
                # before it is sampled
                yield Delay(1e-6)
                output = yield dut.output
                with self.subTest(output=hex(output)):
                    self.assertEqual(expected, output)

        def hashed_const(i, tag, width):
            """Unsigned ``width``-bit value from ``hash_256`` of a label."""
            v = hash_256(f"clmuladd term {i} {tag}")
            return Const.normalize(v, unsigned(width))

        def process():
            for i in range(100):
                factor1 = hashed_const(i, "factor1", factor_width)
                factor2 = hashed_const(i, "factor2", factor_width)
                terms = [hashed_const(i, j, term_width)
                         for j, term_width in enumerate(terms_width)]
                yield from case(factor1, factor2, terms)

        with do_sim(self, dut, [dut.factor1, dut.factor2, *dut.terms,
                                dut.output]) as sim:
            sim.add_process(process)
            # without run() the process is registered but never executed
            sim.run()

    def test_4x4(self):
        self.tst(4, ())

    def test_4x4_8(self):
        self.tst(4, (8,))

    def test_64x64(self):
        self.tst(64, ())

    def test_64x64_64(self):
        self.tst(64, (64,))

    def test_8x8_16_16_16(self):
        self.tst(8, (16, 16, 16))

    def tst_formal(self, factor_width, terms_width):
        """Formally prove the DUT matches a shift-and-XOR reference model.

        :param factor_width: bit width of both factors.
        :param terms_width: sequence of bit widths, one per XOR-ed term.
        """
        dut = CLMulAdd(factor_width, terms_width)
        m = Module()
        m.submodules.dut = dut
        m.d.comb += dut.factor1.eq(AnyConst(factor_width))
        m.d.comb += dut.factor2.eq(AnyConst(factor_width))
        # reference carry-less multiply: for every set bit of factor1, XOR
        # in factor2 shifted left by that bit's index
        reduce_inputs = [Mux(dut.factor1[shift], dut.factor2 << shift, 0)
                         for shift in range(factor_width)]
        for term in dut.terms:
            m.d.comb += term.eq(AnyConst(term.shape()))
            reduce_inputs.append(term)
        # give every XOR operand its own named signal (presumably so each one
        # is identifiable by name in counterexample traces)
        named_inputs = []
        for i, value in enumerate(reduce_inputs):
            sig = Signal(value.shape(), name=f"reduce_input_{i}")
            m.d.comb += sig.eq(value)
            named_inputs.append(sig)
        expected = Signal(reduce(xor, named_inputs).shape())
        m.d.comb += expected.eq(reduce(xor, named_inputs))
        m.d.comb += Assert(dut.output == expected)
        self.assertFormal(m)

    def test_formal_4x4(self):
        self.tst_formal(4, ())

    def test_formal_4x4_8(self):
        self.tst_formal(4, (8,))

    def test_formal_64x64(self):
        self.tst_formal(64, ())

    def test_formal_64x64_64(self):
        self.tst_formal(64, (64,))

    def test_formal_8x8_16_16_16(self):
        self.tst_formal(8, (16, 16, 16))
169 if __name__
== "__main__":