implement CLDivRemFSMStage
[nmigen-gf.git] / src / nmigen_gf / hdl / test / test_cldivrem.py
# SPDX-License-Identifier: LGPL-3.0-or-later
2 # Copyright 2022 Jacob Lifshay
3
4 # Funded by NLnet Assure Programme 2021-02-052, https://nlnet.nl/assure part
5 # of Horizon 2020 EU Programme 957073.
6
7 import unittest
8 from nmigen.hdl.ast import AnyConst, Assert, Signal, Const, unsigned
9 from nmigen.hdl.dsl import Module
10 from nmutil.formaltest import FHDLTestCase
11 from nmigen_gf.hdl.cldivrem import (CLDivRemFSMStage, CLDivRemInputData,
12 CLDivRemOutputData, CLDivRemShape, CLDivRemState,
13 equal_leading_zero_count_reference,
14 EqualLeadingZeroCount)
15 from nmigen.sim import Delay, Tick
16 from nmutil.sim_util import do_sim, hash_256
17 from nmigen_gf.reference.cldivrem import cldivrem
18
19
class TestEqualLeadingZeroCount(FHDLTestCase):
    """Tests for ``EqualLeadingZeroCount``: ``out`` must be 1 exactly when
    ``a`` and ``b`` have the same bit length (i.e. the same number of
    leading zeros at the given width)."""

    def tst(self, width, full):
        """Simulate ``EqualLeadingZeroCount(width)`` and cross-check the
        software model ``equal_leading_zero_count_reference``.

        :param width: bit width of the ``a`` and ``b`` inputs.
        :param full: if true, try every ``(a, b)`` pair; otherwise use 100
            pseudo-random pairs derived from ``hash_256``.
        """
        dut = EqualLeadingZeroCount(width)
        for port, wanted_shape in ((dut.a, unsigned(width)),
                                   (dut.b, unsigned(width)),
                                   (dut.out, unsigned(1))):
            self.assertEqual(port.shape(), wanted_shape)

        def case(a, b):
            assert isinstance(a, int)
            assert isinstance(b, int)
            # equal bit lengths <=> equal leading-zero counts at fixed width
            expected = a.bit_length() == b.bit_length()
            labels = dict(a=hex(a), b=hex(b), expected=expected)

            # first the software reference model...
            with self.subTest(**labels):
                reference = equal_leading_zero_count_reference(a, b, width)
                with self.subTest(reference=reference):
                    self.assertEqual(expected, reference)

            # ...then the hardware; kept as a separate subTest so a
            # reference failure does not skip the simulation steps
            with self.subTest(**labels):
                yield dut.a.eq(a)
                yield dut.b.eq(b)
                yield Delay(1e-6)
                out = yield dut.out
                with self.subTest(out=out):
                    self.assertEqual(expected, out)

        def operands():
            """Yield the ``(a, b)`` pairs to simulate."""
            if full:
                for a in range(1 << width):
                    for b in range(1 << width):
                        yield a, b
            else:
                for i in range(100):
                    a = Const.normalize(hash_256(f"eqlzc input a {i}"),
                                        dut.a.shape())
                    b = Const.normalize(hash_256(f"eqlzc input b {i}"),
                                        dut.b.shape())
                    yield a, b

        def process():
            for a, b in operands():
                yield from case(a, b)

        with do_sim(self, dut, [dut.a, dut.b, dut.out]) as sim:
            sim.add_process(process)
            sim.run()

    def tst_formal(self, width):
        """Formally check ``EqualLeadingZeroCount(width)`` against an
        independently built expected value, for arbitrary constant inputs."""
        dut = EqualLeadingZeroCount(width)
        m = Module()
        m.submodules.dut = dut
        m.d.comb += dut.a.eq(AnyConst(width))
        m.d.comb += dut.b.eq(AnyConst(width))
        # Build the expected result from Value.matches() and boolean logic
        # rather than one giant Switch()/If(), to avoid
        # https://github.com/YosysHQ/yosys/issues/3268
        expected_v = False
        for leading_zeros in range(width + 1):
            # pattern is MSB first: `leading_zeros` zeros, a one, then
            # don't-cares, cut to `width` chars (all zeros when
            # leading_zeros == width)
            pattern = ('0' * leading_zeros + '1').ljust(width, '-')[:width]
            a_has_count = Signal(name=f"a_has_{leading_zeros}")
            b_has_count = Signal(name=f"b_has_{leading_zeros}")
            m.d.comb += [
                a_has_count.eq(dut.a.matches(pattern)),
                b_has_count.eq(dut.b.matches(pattern)),
            ]
            expected_v = expected_v | (a_has_count & b_has_count)
        expected = Signal()
        m.d.comb += expected.eq(expected_v)
        m.d.comb += Assert(dut.out == expected)
        self.assertFormal(m)

    def test_64(self):
        self.tst(64, full=False)

    def test_8(self):
        self.tst(8, full=False)

    def test_3(self):
        self.tst(3, full=True)

    def test_formal_64(self):
        self.tst_formal(64)

    def test_formal_8(self):
        self.tst_formal(8)

    def test_formal_3(self):
        self.tst_formal(3)
105
106
class TestCLDivRemComb(FHDLTestCase):
    """Check a fully unrolled chain of ``CLDivRemState`` steps against the
    software ``cldivrem`` reference."""

    def tst(self, shape, full):
        """Build one ``CLDivRemState`` per entry of ``shape.step_range``,
        chained together, and simulate it.

        :param shape: the ``CLDivRemShape`` under test.
        :param full: if true, try every ``(n, d)`` pair; otherwise use 100
            pseudo-random pairs derived from ``hash_256``.
        """
        assert isinstance(shape, CLDivRemShape)
        m = Module()
        n_in = Signal(shape.n_width)
        d_in = Signal(shape.width)
        states: "list[CLDivRemState]" = []
        for i in shape.step_range:
            state = CLDivRemState(shape, name=f"state_{i}")
            states.append(state)
            # step 0 is seeded from the inputs; each later step is derived
            # from the step before it
            if i == 0:
                state.set_to_initial(m, n=n_in, d=d_in)
            else:
                state.set_to_next(m, states[i - 1])
        final = states[-1]

        def case(n, d):
            assert isinstance(n, int)
            assert isinstance(d, int)
            if d == 0:
                expected_q = expected_r = 0
            else:
                expected_q, expected_r = cldivrem(
                    n, d, width=max(shape.width, shape.n_width))
            # results are only checked for a non-zero divisor whose
            # quotient fits in `shape.width` bits
            check_results = d != 0 and (expected_q >> shape.width) == 0
            with self.subTest(n=hex(n), d=hex(d),
                              expected_q=hex(expected_q),
                              expected_r=hex(expected_r)):
                yield n_in.eq(n)
                yield d_in.eq(d)
                yield Delay(1e-6)
                for i in shape.step_range:
                    state = states[i]
                    with self.subTest(i=i):
                        done = yield state.done
                        step = yield state.step
                        self.assertEqual(done, i >= shape.done_step)
                        self.assertEqual(step, i)
                q = yield final.q
                r = yield final.r
                with self.subTest(q=hex(q), r=hex(r)):
                    if check_results:
                        self.assertEqual(q, expected_q)
                        self.assertEqual(r, expected_r)

        def inputs():
            """Yield the ``(n, d)`` pairs to simulate."""
            if full:
                for n in range(1 << shape.n_width):
                    for d in range(1 << shape.width):
                        yield n, d
            else:
                for i in range(100):
                    n = Const.normalize(hash_256(f"cldivrem comb n {i}"),
                                        unsigned(shape.n_width))
                    d = Const.normalize(hash_256(f"cldivrem comb d {i}"),
                                        unsigned(shape.width))
                    yield n, d

        def process():
            for n, d in inputs():
                yield from case(n, d)

        with do_sim(self, m, [n_in, d_in, final.q, final.r]) as sim:
            sim.add_process(process)
            sim.run()

    def test_4(self):
        self.tst(CLDivRemShape(width=4, n_width=4), full=True)

    def test_8_by_4(self):
        self.tst(CLDivRemShape(width=4, n_width=8), full=True)
170
171
class TestCLDivRemFSM(FHDLTestCase):
    """Cycle-level tests of ``CLDivRemFSMStage``: the ready/valid handshake
    and the computed quotient/remainder."""

    def tst(self, shape, full, steps_per_clock):
        """Drive ``CLDivRemFSMStage`` through one full handshake per
        ``(n, d)`` pair.

        Each case:

        1. Presents the inputs with ``i_valid`` set for one clock.
        2. Overwrites the inputs with all ones and drops ``i_valid``. A
           stage that kept reading its inputs would then compute a wrong
           result.
        3. Checks that the stage stays busy until ``o_valid`` rises.
        4. Compares ``q``/``r`` with the software ``cldivrem``.
        5. Accepts the output with ``i_ready``.

        :param shape: the ``CLDivRemShape`` to build the stage with.
        :param full: if true, try every ``(n, d)`` pair; otherwise use 100
            pseudo-random pairs derived from ``hash_256``.
        :param steps_per_clock: division steps the stage performs per
            clock; must be an int >= 1.
        """
        assert isinstance(shape, CLDivRemShape)
        assert isinstance(steps_per_clock, int) and steps_per_clock >= 1
        pspec = {}
        dut = CLDivRemFSMStage(pspec, shape, steps_per_clock=steps_per_clock)
        i_data: CLDivRemInputData = dut.p.i_data
        o_data: CLDivRemOutputData = dut.n.o_data
        self.assertEqual(i_data.n.shape(), unsigned(shape.n_width))
        self.assertEqual(i_data.d.shape(), unsigned(shape.width))
        self.assertEqual(o_data.q.shape(), unsigned(shape.width))
        self.assertEqual(o_data.r.shape(), unsigned(shape.width))

        def check_handshake(expect_o_valid, expect_o_ready, **labels):
            """Let combinational logic settle, then check ``n.o_valid`` and
            ``p.o_ready``.

            ``labels`` name the phase in the subTest, so a failure shows
            which point in the handshake went wrong.
            """
            yield Delay(0.1e-6)
            valid = yield dut.n.o_valid
            ready = yield dut.p.o_ready
            with self.subTest(**labels):
                self.assertEqual(bool(valid), expect_o_valid)
                self.assertEqual(bool(ready), expect_o_ready)

        def case(n, d):
            assert isinstance(n, int)
            assert isinstance(d, int)
            max_width = max(shape.width, shape.n_width)
            if d != 0:
                expected_q, expected_r = cldivrem(n, d, width=max_width)
            else:
                expected_q = expected_r = 0
            with self.subTest(n=hex(n), d=hex(d),
                              expected_q=hex(expected_q),
                              expected_r=hex(expected_r)):
                # one idle clock with no valid input
                yield dut.p.i_valid.eq(0)
                yield Tick()
                # offer the inputs: stage must be ready, with no result yet
                yield i_data.n.eq(n)
                yield i_data.d.eq(d)
                yield dut.p.i_valid.eq(1)
                yield from check_handshake(False, True, phase="offer")
                yield Tick()
                # clobber the inputs; the result must not depend on them
                # after they have been accepted
                yield i_data.n.eq(-1)
                yield i_data.d.eq(-1)
                yield dut.p.i_valid.eq(0)
                # one iteration per clock in which the stage must still be
                # busy: no output yet and not accepting new input
                for i in range(steps_per_clock * 2, shape.done_step,
                               steps_per_clock):
                    yield from check_handshake(False, False,
                                               phase="busy", i=i)
                    yield Tick()
                # result is now available and held until accepted
                yield from check_handshake(True, False, phase="done")
                q = yield o_data.q
                r = yield o_data.r
                with self.subTest(q=hex(q), r=hex(r)):
                    # only check results when inputs are valid: non-zero
                    # divisor and a quotient that fits the shape.width-bit q
                    if d != 0 and (expected_q >> shape.width) == 0:
                        self.assertEqual(q, expected_q)
                        self.assertEqual(r, expected_r)
                # accept the result; stage must return to ready/idle
                yield dut.n.i_ready.eq(1)
                yield Tick()
                yield from check_handshake(False, True, phase="accepted")
                yield dut.n.i_ready.eq(0)

        def process():
            if full:
                for n in range(1 << shape.n_width):
                    for d in range(1 << shape.width):
                        yield from case(n, d)
            else:
                for i in range(100):
                    n = hash_256(f"cldivrem fsm n {i}")
                    n = Const.normalize(n, unsigned(shape.n_width))
                    d = hash_256(f"cldivrem fsm d {i}")
                    d = Const.normalize(d, unsigned(shape.width))
                    yield from case(n, d)

        with do_sim(self, dut, list(dut.ports())) as sim:
            sim.add_process(process)
            sim.add_clock(1e-6)
            sim.run()

    def test_4_step_1(self):
        self.tst(CLDivRemShape(width=4, n_width=4),
                 full=True,
                 steps_per_clock=1)

    def test_4_step_2(self):
        self.tst(CLDivRemShape(width=4, n_width=4),
                 full=True,
                 steps_per_clock=2)

    def test_4_step_3(self):
        self.tst(CLDivRemShape(width=4, n_width=4),
                 full=True,
                 steps_per_clock=3)
275
276
if __name__ == "__main__":
    # allow running this test module directly as a script
    unittest.main()