add cldivrem_shifting as a more efficient algorithm
[nmigen-gf.git] / src / nmigen_gf / hdl / test / test_cldivrem.py
1 # SPDX-License-Identifier: LGPL-3-or-later
2 # Copyright 2022 Jacob Lifshay
3
4 # Funded by NLnet Assure Programme 2021-02-052, https://nlnet.nl/assure part
5 # of Horizon 2020 EU Programme 957073.
6
7 import unittest
8 from nmigen.hdl.ast import AnyConst, Assert, Signal, Const, unsigned
9 from nmigen.hdl.dsl import Module
10 from nmutil.formaltest import FHDLTestCase
11 from nmigen_gf.hdl.cldivrem import (CLDivRemFSMStage, CLDivRemInputData,
12 CLDivRemOutputData, CLDivRemShape,
13 cldivrem_shifting, CLDivRemState,
14 equal_leading_zero_count_reference,
15 EqualLeadingZeroCount)
16 from nmigen.sim import Delay, Tick
17 from nmutil.sim_util import do_sim, hash_256
18 from nmigen_gf.reference.cldivrem import cldivrem
19
20
class TestEqualLeadingZeroCount(FHDLTestCase):
    """Simulation and formal tests for `EqualLeadingZeroCount`."""

    def tst(self, width, full):
        """Simulate a `width`-bit DUT and compare `out` to the expectation.

        The expectation is `a.bit_length() == b.bit_length()`; it is also
        cross-checked against `equal_leading_zero_count_reference`.  With
        `full` set every `(a, b)` pair is tried, otherwise 100 pairs
        derived from `hash_256` are used.
        """
        dut = EqualLeadingZeroCount(width)
        port_widths = ((dut.a, width), (dut.b, width), (dut.out, 1))
        for port, port_width in port_widths:
            self.assertEqual(port.shape(), unsigned(port_width))

        def check_pair(a, b):
            # generator run by the simulator for a single input pair
            assert isinstance(a, int)
            assert isinstance(b, int)
            expected = a.bit_length() == b.bit_length()
            labels = dict(a=hex(a), b=hex(b), expected=expected)

            # the software reference has to agree with the expectation
            with self.subTest(**labels):
                reference = equal_leading_zero_count_reference(a, b, width)
                with self.subTest(reference=reference):
                    self.assertEqual(expected, reference)

            # drive the inputs, let the logic settle, then sample `out`
            with self.subTest(**labels):
                yield dut.a.eq(a)
                yield dut.b.eq(b)
                yield Delay(1e-6)
                out = yield dut.out
                with self.subTest(out=out):
                    self.assertEqual(expected, out)

        def input_pairs():
            # exhaustive sweep, or 100 hash-derived pairs
            if full:
                for a in range(1 << width):
                    for b in range(1 << width):
                        yield a, b
                return
            for i in range(100):
                a = hash_256(f"eqlzc input a {i}")
                b = hash_256(f"eqlzc input b {i}")
                yield (Const.normalize(a, dut.a.shape()),
                       Const.normalize(b, dut.b.shape()))

        def process():
            for a, b in input_pairs():
                yield from check_pair(a, b)

        with do_sim(self, dut, [dut.a, dut.b, dut.out]) as sim:
            sim.add_process(process)
            sim.run()

    def tst_formal(self, width):
        """Formally assert `out` for a `width`-bit DUT with arbitrary inputs.

        `out` must be set exactly when `a` and `b` match the same
        leading-zeros bit pattern.
        """
        dut = EqualLeadingZeroCount(width)
        m = Module()
        m.submodules.dut = dut
        m.d.comb += dut.a.eq(AnyConst(width))
        m.d.comb += dut.b.eq(AnyConst(width))
        # use a bunch of Value.matches() and boolean logic rather than a
        # giant Switch()/If() to avoid
        # https://github.com/YosysHQ/yosys/issues/3268
        expected_v = False
        for leading_zeros in range(width + 1):
            # `leading_zeros` zeros, a one, then don't-cares, cut to `width`
            # characters (the last pattern is therefore all zeros)
            pattern = ('0' * leading_zeros + '1' + '-' * width)[0:width]
            a_has_count = Signal(name=f"a_has_{leading_zeros}")
            b_has_count = Signal(name=f"b_has_{leading_zeros}")
            m.d.comb += a_has_count.eq(dut.a.matches(pattern))
            m.d.comb += b_has_count.eq(dut.b.matches(pattern))
            expected_v |= a_has_count & b_has_count
        expected = Signal()
        m.d.comb += expected.eq(expected_v)
        m.d.comb += Assert(dut.out == expected)
        self.assertFormal(m)

    def test_64(self):
        self.tst(64, full=False)

    def test_8(self):
        self.tst(8, full=False)

    def test_3(self):
        self.tst(3, full=True)

    def test_formal_64(self):
        self.tst_formal(64)

    def test_formal_8(self):
        self.tst_formal(8)

    def test_formal_3(self):
        self.tst_formal(3)
106
107
class TestCLDivRemShifting(FHDLTestCase):
    """Compare `cldivrem_shifting` against the reference `cldivrem`."""

    def tst(self, width, full):
        """Check that both functions agree for `width`-bit inputs.

        With `full` set every `(n, d)` pair is tried, otherwise 100 pairs
        derived from `hash_256` are used.
        """

        def check_pair(n, d):
            assert isinstance(n, int)
            assert isinstance(d, int)
            if d == 0:
                # neither function is called for a zero divisor; the
                # comparison below is trivially true in that case
                expected_q = expected_r = 0
                q = r = 0
            else:
                expected_q, expected_r = cldivrem(n, d, width=width)
                q, r = cldivrem_shifting(n, d, width=width)
            with self.subTest(n=hex(n), d=hex(d),
                              expected_q=hex(expected_q),
                              expected_r=hex(expected_r),
                              q=hex(q), r=hex(r)):
                self.assertEqual(expected_q, q)
                self.assertEqual(expected_r, r)

        def input_pairs():
            # exhaustive sweep, or 100 hash-derived pairs
            if full:
                for n in range(1 << width):
                    for d in range(1 << width):
                        yield n, d
                return
            for i in range(100):
                n = hash_256(f"cldivrem comb n {i}")
                d = hash_256(f"cldivrem comb d {i}")
                yield (Const.normalize(n, unsigned(width)),
                       Const.normalize(d, unsigned(width)))

        for n, d in input_pairs():
            check_pair(n, d)

    def test_6(self):
        self.tst(6, full=True)

    def test_64(self):
        self.tst(64, full=False)
142
143
class TestCLDivRemComb(FHDLTestCase):
    """Simulate a combinational chain of `CLDivRemState` steps."""

    def tst(self, shape, full):
        """Chain one state per step of `shape` and check the last one.

        Each state's `done` and `step` are checked for every input; the
        final `q` and `r` are compared with the reference `cldivrem` when
        the divisor is non-zero and the expected quotient fits in
        `shape.width` bits.  With `full` set every `(n, d)` pair is tried,
        otherwise 100 pairs derived from `hash_256` are used.
        """
        assert isinstance(shape, CLDivRemShape)
        m = Module()
        n_in = Signal(shape.n_width)
        d_in = Signal(shape.width)
        states: "list[CLDivRemState]" = []
        for i in shape.step_range:
            states.append(CLDivRemState(shape, name=f"state_{i}"))
            if i != 0:
                # every later state is derived from its predecessor
                states[i].set_to_next(m, states[i - 1])
            else:
                states[i].set_to_initial(m, n=n_in, d=d_in)

        def check_pair(n, d):
            # generator run by the simulator for a single input pair
            assert isinstance(n, int)
            assert isinstance(d, int)
            max_width = max(shape.width, shape.n_width)
            expected_q = expected_r = 0
            if d != 0:
                expected_q, expected_r = cldivrem(n, d, width=max_width)
            with self.subTest(n=hex(n), d=hex(d),
                              expected_q=hex(expected_q),
                              expected_r=hex(expected_r)):
                yield n_in.eq(n)
                yield d_in.eq(d)
                yield Delay(1e-6)
                for i in shape.step_range:
                    with self.subTest(i=i):
                        done = yield states[i].done
                        step = yield states[i].step
                        self.assertEqual(done, i >= shape.done_step)
                        self.assertEqual(step, i)
                final_state = states[-1]
                q = yield final_state.q
                r = yield final_state.r
                with self.subTest(q=hex(q), r=hex(r)):
                    # only check results when inputs are valid
                    inputs_valid = (d != 0
                                    and (expected_q >> shape.width) == 0)
                    if inputs_valid:
                        self.assertEqual(q, expected_q)
                        self.assertEqual(r, expected_r)

        def input_pairs():
            # exhaustive sweep, or 100 hash-derived pairs
            if full:
                for n in range(1 << shape.n_width):
                    for d in range(1 << shape.width):
                        yield n, d
                return
            for i in range(100):
                n = hash_256(f"cldivrem comb n {i}")
                d = hash_256(f"cldivrem comb d {i}")
                yield (Const.normalize(n, unsigned(shape.n_width)),
                       Const.normalize(d, unsigned(shape.width)))

        def process():
            for n, d in input_pairs():
                yield from check_pair(n, d)

        traces = [n_in, d_in, states[-1].q, states[-1].r]
        with do_sim(self, m, traces) as sim:
            sim.add_process(process)
            sim.run()

    def test_4(self):
        self.tst(CLDivRemShape(width=4, n_width=4), full=True)

    def test_8_by_4(self):
        self.tst(CLDivRemShape(width=4, n_width=8), full=True)
207
208
class TestCLDivRemFSM(FHDLTestCase):
    """Clocked simulation tests for `CLDivRemFSMStage`."""

    def tst(self, shape, full, steps_per_clock):
        """Push inputs through the stage and check handshake and results.

        For every input pair the test samples `dut.n.o_valid` and
        `dut.p.o_ready` before each clock edge, and compares `q`/`r` with
        the reference `cldivrem` when the divisor is non-zero and the
        expected quotient fits in `shape.width` bits.  With `full` set
        every `(n, d)` pair is tried, otherwise 100 pairs derived from
        `hash_256` are used.
        """
        assert isinstance(shape, CLDivRemShape)
        assert isinstance(steps_per_clock, int) and steps_per_clock >= 1
        pspec = {}
        dut = CLDivRemFSMStage(pspec, shape, steps_per_clock=steps_per_clock)
        i_data: CLDivRemInputData = dut.p.i_data
        o_data: CLDivRemOutputData = dut.n.o_data
        self.assertEqual(i_data.n.shape(), unsigned(shape.n_width))
        self.assertEqual(i_data.d.shape(), unsigned(shape.width))
        self.assertEqual(o_data.q.shape(), unsigned(shape.width))
        self.assertEqual(o_data.r.shape(), unsigned(shape.width))

        def expect_handshake(valid_expected, ready_expected):
            # wait a little after the clock edge, then sample both
            # handshake outputs and check them in one anonymous sub-test
            check_valid = (self.assertTrue if valid_expected
                           else self.assertFalse)
            check_ready = (self.assertTrue if ready_expected
                           else self.assertFalse)
            yield Delay(0.1e-6)
            valid = yield dut.n.o_valid
            ready = yield dut.p.o_ready
            with self.subTest():
                check_valid(valid)
                check_ready(ready)

        def check_pair(n, d):
            # generator run by the simulator for a single input pair
            assert isinstance(n, int)
            assert isinstance(d, int)
            max_width = max(shape.width, shape.n_width)
            expected_q = expected_r = 0
            if d != 0:
                expected_q, expected_r = cldivrem(n, d, width=max_width)
            with self.subTest(n=hex(n), d=hex(d),
                              expected_q=hex(expected_q),
                              expected_r=hex(expected_r)):
                # one clock with no input offered
                yield dut.p.i_valid.eq(0)
                yield Tick()

                # offer the operands
                yield i_data.n.eq(n)
                yield i_data.d.eq(d)
                yield dut.p.i_valid.eq(1)
                yield from expect_handshake(False, True)
                yield Tick()

                # withdraw the input and overwrite the operands with -1
                yield i_data.n.eq(-1)
                yield i_data.d.eq(-1)
                yield dut.p.i_valid.eq(0)

                # clocks during which neither output may be asserted
                for _ in range(steps_per_clock * 2, shape.done_step,
                               steps_per_clock):
                    yield from expect_handshake(False, False)
                    yield Tick()

                # the output is now expected to be valid
                yield from expect_handshake(True, False)
                q = yield o_data.q
                r = yield o_data.r
                with self.subTest(q=hex(q), r=hex(r)):
                    # only check results when inputs are valid
                    inputs_valid = (d != 0
                                    and (expected_q >> shape.width) == 0)
                    if inputs_valid:
                        self.assertEqual(q, expected_q)
                        self.assertEqual(r, expected_r)

                # accept the output for one clock
                yield dut.n.i_ready.eq(1)
                yield Tick()
                yield from expect_handshake(False, True)
                yield dut.n.i_ready.eq(0)

        def input_pairs():
            # exhaustive sweep, or 100 hash-derived pairs
            if full:
                for n in range(1 << shape.n_width):
                    for d in range(1 << shape.width):
                        yield n, d
                return
            for i in range(100):
                n = hash_256(f"cldivrem fsm n {i}")
                d = hash_256(f"cldivrem fsm d {i}")
                yield (Const.normalize(n, unsigned(shape.n_width)),
                       Const.normalize(d, unsigned(shape.width)))

        def process():
            for n, d in input_pairs():
                yield from check_pair(n, d)

        with do_sim(self, dut, list(dut.ports())) as sim:
            sim.add_process(process)
            sim.add_clock(1e-6)
            sim.run()

    def test_4_step_1(self):
        self.tst(CLDivRemShape(width=4, n_width=4),
                 full=True,
                 steps_per_clock=1)

    def test_4_step_2(self):
        self.tst(CLDivRemShape(width=4, n_width=4),
                 full=True,
                 steps_per_clock=2)

    def test_4_step_3(self):
        self.tst(CLDivRemShape(width=4, n_width=4),
                 full=True,
                 steps_per_clock=3)
312
313
# run every test in this module when the file is executed as a script
if __name__ == "__main__":
    unittest.main()