# SPDX-License-Identifier: LGPL-3-or-later
# Copyright 2022 Jacob Lifshay
#
# Funded by NLnet Assure Programme 2021-02-052, https://nlnet.nl/assure part
# of Horizon 2020 EU Programme 957073.
from functools import reduce
import operator

from nmigen.hdl.ast import AnyConst, Assert, Signal, Const, unsigned, Mux
from nmigen.hdl.dsl import Module
from nmigen.sim import Delay
from nmutil.formaltest import FHDLTestCase
from nmutil.sim_util import do_sim, hash_256

from nmigen_gf.reference.clmul import clmul
from nmigen_gf.hdl.clmul import CLMulAdd
class TestCLMulAdd(FHDLTestCase):
    """Simulation and formal tests for ``CLMulAdd``.

    The DUT's ``output`` is expected to equal ``clmul(factor1, factor2)``
    (the software reference model) XORed with every entry of ``terms``.
    """

    def tst(self, factor_width, terms_width):
        """Simulate ``CLMulAdd(factor_width, terms_width)`` and compare it
        against the reference model on deterministic pseudo-random inputs.

        :param factor_width: bit width used for ``factor1`` and ``factor2``.
        :param terms_width: sequence of bit widths, one per XORed-in term.
        """
        dut = CLMulAdd(factor_width, terms_width)
        # output must be wide enough for the (2 * factor_width - 1)-bit
        # product and for every term.
        self.assertEqual(dut.output.width,
                         max((factor_width * 2 - 1, *terms_width)))

        def case(factor1, factor2, terms):
            """Drive one input set into the DUT and check ``dut.output``."""
            expected = reduce(operator.xor, terms, clmul(factor1, factor2))
            with self.subTest(factor1=hex(factor1),
                              factor2=hex(factor2),
                              terms=list(map(hex, terms)),
                              expected=hex(expected)):
                yield dut.factor1.eq(factor1)
                yield dut.factor2.eq(factor2)
                for i, term in enumerate(terms):
                    yield dut.terms[i].eq(term)
                # advance simulation time so the new inputs propagate to
                # dut.output before it is sampled
                yield Delay(1e-6)
                output = yield dut.output
                with self.subTest(output=hex(output)):
                    self.assertEqual(expected, output)

        def process():
            """Simulator process: generate cases from ``hash_256`` so every
            run uses the same inputs."""
            # NOTE(review): the original iteration count was lost from this
            # copy of the file; 100 is a chosen value -- confirm.
            for i in range(100):
                v = hash_256(f"clmuladd term {i} factor1")
                factor1 = Const.normalize(v, unsigned(factor_width))
                v = hash_256(f"clmuladd term {i} factor2")
                factor2 = Const.normalize(v, unsigned(factor_width))
                terms = [
                    Const.normalize(hash_256(f"clmuladd term {i} {j}"),
                                    unsigned(term_width))
                    for j, term_width in enumerate(terms_width)
                ]
                yield from case(factor1, factor2, terms)

        with do_sim(self, dut, [dut.factor1, dut.factor2, *dut.terms,
                                dut.output]) as sim:
            sim.add_process(process)
            sim.run()

    def test_64x64_64(self):
        self.tst(64, (64,))

    def test_8x8_16_16_16(self):
        self.tst(8, (16, 16, 16))

    def tst_formal(self, factor_width, terms_width):
        """Formally check ``CLMulAdd`` against a shift-and-XOR model.

        Both factors and all terms are driven from ``AnyConst``.  The model
        XORs one partial product per bit of ``factor1`` (``factor2`` shifted
        left by that bit's index, or zero when the bit is clear) with every
        term, and asserts ``dut.output`` equals the result.

        :param factor_width: bit width used for ``factor1`` and ``factor2``.
        :param terms_width: sequence of bit widths, one per XORed-in term.
        """
        dut = CLMulAdd(factor_width, terms_width)
        m = Module()
        m.submodules.dut = dut
        m.d.comb += dut.factor1.eq(AnyConst(factor_width))
        m.d.comb += dut.factor2.eq(AnyConst(factor_width))

        reduce_inputs = [Mux(dut.factor1[shift], dut.factor2 << shift, 0)
                         for shift in range(factor_width)]
        for term in dut.terms:
            m.d.comb += term.eq(AnyConst(term.shape()))
            reduce_inputs.append(term)

        # copy each XOR input into an explicitly named signal
        # (reduce_input_<n>) so it is identifiable in counterexample traces
        named_inputs = []
        for index, value in enumerate(reduce_inputs):
            sig = Signal(value.shape(), name=f"reduce_input_{index}")
            m.d.comb += sig.eq(value)
            named_inputs.append(sig)

        xor_all = reduce(operator.xor, named_inputs)
        expected = Signal(xor_all.shape())
        m.d.comb += expected.eq(xor_all)
        m.d.comb += Assert(dut.output == expected)
        # without this call the model is built but never checked
        self.assertFormal(m)

    def test_formal_4x4(self):
        self.tst_formal(4, ())

    def test_formal_4x4_8(self):
        self.tst_formal(4, (8,))

    def test_formal_64x64(self):
        self.tst_formal(64, ())

    def test_formal_64x64_64(self):
        self.tst_formal(64, (64,))

    def test_formal_8x8_16_16_16(self):
        self.tst_formal(8, (16, 16, 16))
109 if __name__
== "__main__":