Fix the HDL so it works for 5-, 8-, 16-, 32-, and 64-bit widths.
[soc.git] / src / soc / fu / div / experiment / test / test_goldschmidt_div_sqrt.py
1 # SPDX-License-Identifier: LGPL-3-or-later
2 # Copyright 2022 Jacob Lifshay programmerjake@gmail.com
3
4 # Funded by NLnet Assure Programme 2021-02-052, https://nlnet.nl/assure part
5 # of Horizon 2020 EU Programme 957073.
6
7 from dataclasses import fields, replace
8 import math
9 import unittest
10 from nmutil.formaltest import FHDLTestCase
11 from nmutil.sim_util import do_sim, hash_256
12 from nmigen.sim import Tick, Delay
13 from nmigen.hdl.ast import Signal
14 from nmigen.hdl.dsl import Module
15 from soc.fu.div.experiment.goldschmidt_div_sqrt import (
16 GoldschmidtDivHDL, GoldschmidtDivHDLState, GoldschmidtDivOp, GoldschmidtDivParams,
17 GoldschmidtDivState, ParamsNotAccurateEnough, goldschmidt_div,
18 FixedPoint, RoundDir, goldschmidt_sqrt_rsqrt)
19
20
class TestFixedPoint(FHDLTestCase):
    """Exhaustive checks of ``FixedPoint`` against Python reference values.

    Every test sweeps ``frac_wid`` over ``0..7`` and a 9-bit span of raw
    ``bits`` values (signed for the string round trip, non-negative for
    ``sqrt`` and strictly positive for ``rsqrt``).
    """

    def test_str_roundtrip(self):
        """``FixedPoint.cast(str(value))`` must reproduce ``value`` exactly."""
        for frac_wid in range(8):
            for bits in range(-1 << 9, 1 << 9):
                with self.subTest(bits=hex(bits), frac_wid=frac_wid):
                    original = FixedPoint(bits, frac_wid)
                    parsed = FixedPoint.cast(str(original))
                    self.assertEqual(original, parsed)

    @staticmethod
    def trap(f):
        """Call ``f()`` and capture the outcome as a comparable pair.

        Returns ``(result, None)`` on success, or
        ``(None, exception_class_name)`` if ``f`` raised ``ValueError`` or
        ``ZeroDivisionError``.  This lets a test assert that the reference
        and the implementation fail in the same way.
        """
        try:
            outcome = f()
        except (ValueError, ZeroDivisionError) as err:
            return None, type(err).__name__
        else:
            return outcome, None

    def test_sqrt(self):
        """``FixedPoint.sqrt`` must match ``math.sqrt`` of the float value,
        rounded with ``FixedPoint.with_frac_wid`` in every ``RoundDir``
        (including raising the same exception class, if any)."""
        for frac_wid in range(8):
            for bits in range(1 << 9):
                # independent of round_dir, so computed once per value
                radicand = FixedPoint(bits, frac_wid)
                reference = math.sqrt(float(radicand))
                for round_dir in RoundDir:
                    expected = self.trap(lambda: FixedPoint.with_frac_wid(
                        reference, frac_wid, round_dir))
                    with self.subTest(radicand=repr(radicand),
                                      round_dir=str(round_dir),
                                      expected=repr(expected)):
                        actual = self.trap(lambda: radicand.sqrt(round_dir))
                        self.assertEqual(actual, expected)

    def test_rsqrt(self):
        """``FixedPoint.rsqrt`` must match ``1 / math.sqrt`` of the float
        value, rounded with ``FixedPoint.with_frac_wid`` in every
        ``RoundDir`` (including raising the same exception class, if any)."""
        for frac_wid in range(8):
            # bits starts at 1: the reciprocal of sqrt(0) is undefined
            for bits in range(1, 1 << 9):
                radicand = FixedPoint(bits, frac_wid)
                reference = 1 / math.sqrt(float(radicand))
                for round_dir in RoundDir:
                    expected = self.trap(lambda: FixedPoint.with_frac_wid(
                        reference, frac_wid, round_dir))
                    with self.subTest(radicand=repr(radicand),
                                      round_dir=str(round_dir),
                                      expected=repr(expected)):
                        actual = self.trap(lambda: radicand.rsqrt(round_dir))
                        self.assertEqual(actual, expected)
64
65
class TestGoldschmidtDiv(FHDLTestCase):
    """Tests for the Goldschmidt integer divider.

    Covers parameter validation (``GoldschmidtDivParams``), the Python
    reference model (``goldschmidt_div``) and a clocked simulation of the
    nmigen implementation (``GoldschmidtDivHDL``).  The per-width
    ``test_params_<N>`` methods (N = 1..64) are attached right after the
    class body.
    """

    def test_case1(self):
        # these parameters must be rejected as not accurate enough
        with self.assertRaises(ParamsNotAccurateEnough):
            GoldschmidtDivParams(io_width=3, extra_precision=2,
                                 table_addr_bits=3, table_data_bits=5,
                                 iter_count=2)

    def test_case2(self):
        # these parameters must be rejected as not accurate enough
        with self.assertRaises(ParamsNotAccurateEnough):
            GoldschmidtDivParams(io_width=4, extra_precision=1,
                                 table_addr_bits=1, table_data_bits=5,
                                 iter_count=1)

    @staticmethod
    def cases(io_width, cases=None):
        """Yield ``(n, d)`` dividend/divisor pairs for ``io_width``-bit tests.

        Every pair satisfies ``0 < d < 2**io_width`` and
        ``0 <= n < d << io_width``, so the quotient fits in ``io_width``
        bits.

        * If ``cases`` is given, each pair is validated and yielded as-is.
        * Otherwise, for ``io_width > 6``, 10000 pseudo-random pairs derived
          from ``hash_256`` of fixed strings are yielded.
        * Otherwise every valid pair is yielded (exhaustive).
        """
        assert isinstance(io_width, int) and io_width >= 1
        if cases is not None:
            for n, d in cases:
                assert isinstance(d, int) \
                    and 0 < d < (1 << io_width), "invalid case"
                assert isinstance(n, int) \
                    and 0 <= n < (d << io_width), "invalid case"
                yield (n, d)
        elif io_width > 6:
            # n can need up to 2 * io_width bits, which must fit in the
            # hash output (presumably 256 bits, per hash_256's name)
            assert io_width * 2 <= 256, \
                "can't generate big enough numbers for test cases"
            for i in range(10000):
                d = hash_256(f'd {i}') % (1 << io_width)
                if d == 0:
                    d = 1  # the divisor must be non-zero
                n = hash_256(f'n {i}') % (d << io_width)
                yield (n, d)
        else:
            for d in range(1, 1 << io_width):
                for n in range(d << io_width):
                    yield (n, d)

    def tst(self, io_width, cases=None):
        """Check the reference model ``goldschmidt_div`` against ``divmod``
        for every pair produced by :meth:`cases`."""
        assert isinstance(io_width, int)
        params = GoldschmidtDivParams.get(io_width)
        with self.subTest(params=str(params)):
            for n, d in self.cases(io_width, cases):
                expected_q, expected_r = divmod(n, d)
                with self.subTest(n=hex(n), d=hex(d),
                                  expected_q=hex(expected_q),
                                  expected_r=hex(expected_r)):
                    q, r = goldschmidt_div(n, d, params)
                    with self.subTest(q=hex(q), r=hex(r)):
                        self.assertEqual((q, r), (expected_q, expected_r))

    def tst_sim(self, io_width, cases=None, pipe_reg_indexes=(),
                sync_rom=False):
        """Simulate ``GoldschmidtDivHDL`` and compare ``q``/``r`` with
        ``divmod`` for every pair produced by :meth:`cases`.

        One pair is driven per clock cycle; outputs are sampled after one
        initial cycle plus ``dut.total_pipeline_registers`` cycles.  When
        the DUT has no pipeline registers, every signal in ``dut.trace`` is
        also compared with the reference model's trace.
        """
        assert isinstance(io_width, int)
        params = GoldschmidtDivParams.get(io_width)
        m = Module()
        dut = GoldschmidtDivHDL(params, pipe_reg_indexes=pipe_reg_indexes,
                                sync_rom=sync_rom)
        m.submodules.dut = dut
        # make sync domain get added
        m.d.sync += Signal().eq(0)

        def inputs_proc():
            # drive one (n, d) pair per clock cycle
            yield Tick()
            for n, d in self.cases(io_width, cases):
                yield dut.n.eq(n)
                yield dut.d.eq(d)
                yield Tick()

        def check_internals(n, d):
            # check internals only if dut is completely combinatorial
            # so we don't have to figure out how to read values in
            # previous clock cycles
            if dut.total_pipeline_registers != 0:
                return
            ref_trace = []

            def ref_trace_fn(state):
                assert isinstance(state, GoldschmidtDivState)
                # store a copy (replace() with no changes), not the object
                # passed in
                ref_trace.append(replace(state))
            goldschmidt_div(n=n, d=d, params=params, trace=ref_trace_fn)
            self.assertEqual(len(dut.trace), len(ref_trace))
            for index, (state, ref_state) in enumerate(
                    zip(dut.trace, ref_trace)):
                # the op that produced this state (none for the initial one)
                last_op = None if index == 0 else params.ops[index - 1]
                with self.subTest(index=index, state=repr(state),
                                  ref_state=repr(ref_state),
                                  last_op=str(last_op)):
                    for field in fields(GoldschmidtDivHDLState):
                        sig = getattr(state, field.name)
                        if not isinstance(sig, Signal):
                            continue
                        ref_value = getattr(ref_state, field.name)
                        ref_value_str = repr(ref_value)
                        if isinstance(ref_value, int):
                            ref_value_str = hex(ref_value)
                        value = yield sig
                        with self.subTest(field_name=field.name,
                                          sig=repr(sig),
                                          sig_shape=repr(sig.shape()),
                                          value=hex(value),
                                          ref_value=ref_value_str):
                            if isinstance(ref_value, int):
                                self.assertEqual(value, ref_value)
                            else:
                                assert isinstance(ref_value, FixedPoint)
                                self.assertEqual(value, ref_value.bits)

        def check_outputs():
            # skip the initial cycle plus the pipeline latency
            yield Tick()
            for _ in range(dut.total_pipeline_registers):
                yield Tick()
            for n, d in self.cases(io_width, cases):
                # sample outputs shortly after the clock edge
                yield Delay(0.1e-6)
                expected_q, expected_r = divmod(n, d)
                with self.subTest(n=hex(n), d=hex(d),
                                  expected_q=hex(expected_q),
                                  expected_r=hex(expected_r)):
                    q = yield dut.q
                    r = yield dut.r
                    with self.subTest(q=hex(q), r=hex(r)):
                        self.assertEqual((q, r), (expected_q, expected_r))
                    yield from check_internals(n, d)

                yield Tick()

        with self.subTest(params=str(params)):
            with do_sim(self, m, (dut.n, dut.d, dut.q, dut.r)) as sim:
                sim.add_clock(1e-6)
                sim.add_process(inputs_proc)
                sim.add_process(check_outputs)
                sim.run()

    def test_1_through_4(self):
        for io_width in range(1, 4 + 1):
            with self.subTest(io_width=io_width):
                self.tst(io_width)

    def test_5(self):
        self.tst(5)

    def test_6(self):
        self.tst(6)

    def test_8(self):
        self.tst(8)

    def test_16(self):
        self.tst(16)

    def test_32(self):
        self.tst(32)

    def test_64(self):
        self.tst(64)

    def test_sim_5(self):
        self.tst_sim(5)

    def test_sim_8(self):
        self.tst_sim(8)

    def test_sim_16(self):
        self.tst_sim(16)

    def test_sim_32(self):
        self.tst_sim(32)

    def test_sim_64(self):
        self.tst_sim(64)

    def tst_params(self, io_width):
        """Check that ``GoldschmidtDivParams.get(io_width)`` succeeds and
        print the parameters it selected."""
        assert isinstance(io_width, int)
        params = GoldschmidtDivParams.get(io_width)
        print()
        print(params)


def _make_tst_params_test(io_width):
    """Return a ``test_params_<io_width>`` method that calls ``tst_params``."""
    def test_params(self):
        self.tst_params(io_width)
    name = f"test_params_{io_width}"
    test_params.__name__ = name
    test_params.__qualname__ = f"TestGoldschmidtDiv.{name}"
    return test_params


# Attach test_params_1 .. test_params_64: one unittest method per width, so
# each width is still reported and can be run individually by name.
for _io_width in range(1, 64 + 1):
    setattr(TestGoldschmidtDiv, f"test_params_{_io_width}",
            _make_tst_params_test(_io_width))
del _io_width
433
434
class TestGoldschmidtSqrtRSqrt(FHDLTestCase):
    """Checks ``goldschmidt_sqrt_rsqrt`` against ``FixedPoint.sqrt`` and
    ``FixedPoint.rsqrt`` rounded with ``RoundDir.DOWN``."""

    def tst(self, io_width, frac_wid, extra_precision,
            table_addr_bits, table_data_bits, iter_count):
        """Compare ``goldschmidt_sqrt_rsqrt`` with the exact results for
        every ``io_width``-bit radicand.

        For a zero radicand the expected ``rsqrt`` is
        ``FixedPoint(0, frac_wid)``, since ``rsqrt`` is only evaluated for
        positive radicands.
        """
        settings = dict(io_width=io_width, frac_wid=frac_wid,
                        extra_precision=extra_precision,
                        table_addr_bits=table_addr_bits,
                        table_data_bits=table_data_bits,
                        iter_count=iter_count)
        assert all(isinstance(value, int) for value in settings.values())
        with self.subTest(**settings):
            for bits in range(1 << io_width):
                radicand = FixedPoint(bits, frac_wid)
                expected_sqrt = radicand.sqrt(RoundDir.DOWN)
                if radicand > 0:
                    expected_rsqrt = radicand.rsqrt(RoundDir.DOWN)
                else:
                    expected_rsqrt = FixedPoint(0, frac_wid)
                with self.subTest(radicand=repr(radicand),
                                  expected_sqrt=repr(expected_sqrt),
                                  expected_rsqrt=repr(expected_rsqrt)):
                    sqrt, rsqrt = goldschmidt_sqrt_rsqrt(
                        radicand=radicand, **settings)
                    with self.subTest(sqrt=repr(sqrt), rsqrt=repr(rsqrt)):
                        self.assertEqual((sqrt, rsqrt),
                                         (expected_sqrt, expected_rsqrt))

    def test1(self):
        self.tst(io_width=16, frac_wid=8, extra_precision=20,
                 table_addr_bits=4, table_data_bits=28, iter_count=4)
472
473
if __name__ == "__main__":
    # Executed only when this file is run directly: hand control to the
    # unittest command-line runner for every TestCase in this module.
    unittest.main(module=__name__)