4c90e41206e828ebe9e14143bdb4c9f8fd4e3f18
[nmigen-gf.git] / src / nmigen_gf / hdl / test / test_cldivrem.py
1 # SPDX-License-Identifier: LGPL-3-or-later
2 # Copyright 2022 Jacob Lifshay
3
4 # Funded by NLnet Assure Programme 2021-02-052, https://nlnet.nl/assure part
5 # of Horizon 2020 EU Programme 957073.
6
7 import unittest
8 from nmigen.hdl.ast import (AnyConst, Assert, Signal, Const, unsigned, Cat)
9 from nmigen.hdl.dsl import Module
10 from nmutil.formaltest import FHDLTestCase
11 from nmigen_gf.hdl.cldivrem import (equal_leading_zero_count_reference,
12 EqualLeadingZeroCount)
13 from nmigen.sim import Delay
14 from nmutil.sim_util import do_sim, hash_256
15
16
class TestEqualLeadingZeroCount(FHDLTestCase):
    """Simulation and formal tests for ``EqualLeadingZeroCount``.

    The DUT has two ``width``-bit unsigned inputs ``a`` and ``b`` and a
    one-bit output ``out``.  These tests expect ``out`` to be 1 exactly when
    ``a.bit_length() == b.bit_length()``, i.e. when both inputs have the
    same number of leading zeros at that width.
    """

    def tst(self, width, full):
        """Simulate an ``EqualLeadingZeroCount`` of the given ``width``.

        First checks the port shapes.  Then, for each ``(a, b)`` pair, it
        checks ``equal_leading_zero_count_reference`` and the simulated
        ``out`` against ``a.bit_length() == b.bit_length()``.

        :param width: bit width of the ``a`` and ``b`` inputs.
        :param full: if true, every ``(a, b)`` pair is tried
            (``2 ** (2 * width)`` cases); otherwise 100 deterministic
            pseudo-random pairs are derived from ``hash_256`` and truncated
            to the input shape.
        """
        dut = EqualLeadingZeroCount(width)
        input_shape = unsigned(width)
        self.assertEqual(dut.a.shape(), input_shape)
        self.assertEqual(dut.b.shape(), input_shape)
        self.assertEqual(dut.out.shape(), unsigned(1))

        def check_pair(lhs, rhs):
            # Simulator sub-process: verify a single (a, b) input pair.
            assert isinstance(lhs, int)
            assert isinstance(rhs, int)
            # Same width, so equal leading-zero count <=> equal bit length.
            want = lhs.bit_length() == rhs.bit_length()
            labels = dict(a=hex(lhs), b=hex(rhs), expected=want)

            # The reference model is checked in its own subTest.  A mismatch
            # there is recorded, but the hardware check below still runs.
            with self.subTest(**labels):
                ref = equal_leading_zero_count_reference(lhs, rhs, width)
                with self.subTest(reference=ref):
                    self.assertEqual(want, ref)

            # Drive the DUT inputs, let the combinational logic settle,
            # then sample and check the output.
            with self.subTest(**labels):
                yield dut.a.eq(lhs)
                yield dut.b.eq(rhs)
                yield Delay(1e-6)
                got = yield dut.out
                with self.subTest(out=got):
                    self.assertEqual(want, got)

        def stimulus():
            # Top-level simulator process: choose which pairs to check.
            if full:
                span = range(1 << width)
                for lhs in span:
                    for rhs in span:
                        yield from check_pair(lhs, rhs)
                return
            for i in range(100):
                lhs = Const.normalize(hash_256(f"eqlzc input a {i}"),
                                      dut.a.shape())
                rhs = Const.normalize(hash_256(f"eqlzc input b {i}"),
                                      dut.b.shape())
                yield from check_pair(lhs, rhs)

        with do_sim(self, dut, [dut.a, dut.b, dut.out]) as sim:
            sim.add_process(stimulus)
            sim.run()

    def tst_formal(self, width):
        """Formally prove ``EqualLeadingZeroCount`` at the given ``width``.

        Both inputs are free ``AnyConst`` values.  An independent spec is
        built from ``m.Switch`` over ``Cat(a, b)``: it is 1 when both inputs
        are zero, or when both halves have a 1 after the same number of
        leading zeros; otherwise it is 0.  ``out`` is asserted equal to it.
        """
        dut = EqualLeadingZeroCount(width)
        m = Module()
        m.submodules.dut = dut
        m.d.comb += [
            dut.a.eq(AnyConst(width)),
            dut.b.eq(AnyConst(width)),
        ]

        expected = Signal()
        with m.Switch(Cat(dut.a, dut.b)):
            # Both inputs zero: each has `width` leading zeros.
            with m.Case('0' * (2 * width)):
                m.d.comb += expected.eq(1)
            for lz in range(width):
                # Exactly `lz` leading zeros, then a 1, then don't-cares.
                # Using the same pattern for both halves of the
                # concatenation means a and b match it simultaneously.
                half = '0' * lz + '1' + '-' * (width - lz - 1)
                with m.Case(half + half):
                    m.d.comb += expected.eq(1)
            with m.Default():
                m.d.comb += expected.eq(0)

        m.d.comb += Assert(dut.out == expected)
        self.assertFormal(m)

    # Too wide to enumerate: use hashed samples.
    def test_64(self):
        self.tst(64, full=False)

    # Hashed samples at a medium width.
    def test_8(self):
        self.tst(8, full=False)

    # Small enough to check every (a, b) pair.
    def test_3(self):
        self.tst(3, full=True)

    # Widest formal check: yosys crashes at widths 32 and 64.
    def test_formal_16(self):
        self.tst_formal(16)

    def test_formal_8(self):
        self.tst_formal(8)

    def test_formal_3(self):
        self.tst_formal(3)

    # TODO: add TestCLDivRem
101
102
if __name__ == "__main__":
    # Script entry point: run every TestCase defined in this module through
    # the standard unittest command-line runner.  "__main__" is also
    # unittest.main's default module.
    unittest.main(module="__main__")