bf999bd850237387120691f16bcda7eb308e21df
[soc.git] / src / soc / fu / div / experiment / test / test_goldschmidt_div_sqrt.py
1 # SPDX-License-Identifier: LGPL-3-or-later
2 # Copyright 2022 Jacob Lifshay programmerjake@gmail.com
3
4 # Funded by NLnet Assure Programme 2021-02-052, https://nlnet.nl/assure part
5 # of Horizon 2020 EU Programme 957073.
6
7 from dataclasses import fields, replace
8 import math
9 import unittest
10 from nmutil.formaltest import FHDLTestCase
11 from nmutil.sim_util import do_sim, hash_256
12 from nmigen.sim import Tick, Delay
13 from nmigen.hdl.ast import Signal
14 from nmigen.hdl.dsl import Module
15 from soc.fu.div.experiment.goldschmidt_div_sqrt import (
16 GoldschmidtDivHDL, GoldschmidtDivHDLState, GoldschmidtDivOp, GoldschmidtDivParams,
17 GoldschmidtDivState, ParamsNotAccurateEnough, goldschmidt_div,
18 FixedPoint, RoundDir, goldschmidt_sqrt_rsqrt)
19
20
class TestFixedPoint(FHDLTestCase):
    """Exercise `FixedPoint` string round-tripping, `sqrt` and `rsqrt`.

    The `sqrt`/`rsqrt` checks use `math.sqrt` on the float conversion of
    each radicand as the reference, rounded to the radicand's own
    `frac_wid` with the rounding direction under test.
    """

    def test_str_roundtrip(self):
        # converting to `str` and back through `FixedPoint.cast` must give
        # an equal value, for negative as well as non-negative `bits`
        for width in range(8):
            for raw in range(-(1 << 9), 1 << 9):
                with self.subTest(bits=hex(raw), frac_wid=width):
                    original = FixedPoint(raw, width)
                    parsed = FixedPoint.cast(str(original))
                    self.assertEqual(original, parsed)

    @staticmethod
    def trap(f):
        """Call `f()` and fold its outcome into a comparable pair.

        Returns `(result, None)` when `f` returns normally, or
        `(None, exception_class_name)` when it raises `ValueError` or
        `ZeroDivisionError`. Any other exception propagates.
        """
        try:
            outcome = f()
        except (ValueError, ZeroDivisionError) as e:
            return None, type(e).__name__
        else:
            return outcome, None

    def test_sqrt(self):
        # both the reference and the value under test go through `trap`,
        # so a trapped exception on either side is compared by class name
        # instead of aborting the loop
        for frac_wid in range(8):
            for bits in range(1 << 9):
                for round_dir in RoundDir:
                    radicand = FixedPoint(bits, frac_wid)
                    reference = math.sqrt(float(radicand))

                    def make_expected():
                        return FixedPoint.with_frac_wid(
                            reference, frac_wid, round_dir)

                    def make_result():
                        return radicand.sqrt(round_dir)

                    expected = self.trap(make_expected)
                    labels = dict(radicand=repr(radicand),
                                  round_dir=str(round_dir),
                                  expected=repr(expected))
                    with self.subTest(**labels):
                        self.assertEqual(self.trap(make_result), expected)

    def test_rsqrt(self):
        # same scheme as test_sqrt, but for the reciprocal square root;
        # `bits` starts at 1 so the float reference never divides by zero
        for frac_wid in range(8):
            for bits in range(1, 1 << 9):
                for round_dir in RoundDir:
                    radicand = FixedPoint(bits, frac_wid)
                    reference = 1 / math.sqrt(float(radicand))

                    def make_expected():
                        return FixedPoint.with_frac_wid(
                            reference, frac_wid, round_dir)

                    def make_result():
                        return radicand.rsqrt(round_dir)

                    expected = self.trap(make_expected)
                    labels = dict(radicand=repr(radicand),
                                  round_dir=str(round_dir),
                                  expected=repr(expected))
                    with self.subTest(**labels):
                        self.assertEqual(self.trap(make_result), expected)
64
65
class TestGoldschmidtDiv(FHDLTestCase):
    """Tests for Goldschmidt division.

    Covers three things: rejection of parameter sets that are not accurate
    enough, the pure-Python `goldschmidt_div` reference (`tst`), and the
    `GoldschmidtDivHDL` hardware model run in the simulator (`tst_sim`).
    The `test_params_N` methods just compute and print the parameters
    chosen for each `io_width` from 1 through 64.
    """

    def test_case1(self):
        # constructing these parameters must raise ParamsNotAccurateEnough
        rejected = dict(io_width=3, extra_precision=2,
                        table_addr_bits=3, table_data_bits=5,
                        iter_count=2)
        with self.assertRaises(ParamsNotAccurateEnough):
            GoldschmidtDivParams(**rejected)

    def test_case2(self):
        # constructing these parameters must raise ParamsNotAccurateEnough
        rejected = dict(io_width=4, extra_precision=1,
                        table_addr_bits=1, table_data_bits=5,
                        iter_count=1)
        with self.assertRaises(ParamsNotAccurateEnough):
            GoldschmidtDivParams(**rejected)

    @staticmethod
    def cases(io_width, cases=None):
        """Generate `(n, d)` numerator/divisor pairs for `io_width`.

        * If `cases` is given, each pair is validated
          (`0 < d < 2**io_width` and `0 <= n < d << io_width`) and
          yielded unchanged.
        * Otherwise, for `io_width <= 6`, every valid pair is yielded.
        * Otherwise 10000 pairs derived from `hash_256` are yielded.
        """
        assert isinstance(io_width, int) and io_width >= 1

        if cases is not None:
            for numerator, divisor in cases:
                assert isinstance(divisor, int) \
                    and 0 < divisor < (1 << io_width), "invalid case"
                assert isinstance(numerator, int) \
                    and 0 <= numerator < (divisor << io_width), \
                    "invalid case"
                yield (numerator, divisor)
            return

        if io_width <= 6:
            # small enough to test exhaustively
            for divisor in range(1, 1 << io_width):
                for numerator in range(divisor << io_width):
                    yield (numerator, divisor)
            return

        # too many combinations: fall back to hash-derived samples
        assert io_width * 2 <= 256, \
            "can't generate big enough numbers for test cases"
        for i in range(10000):
            # a divisor of zero is replaced by one
            divisor = (hash_256(f'd {i}') % (1 << io_width)) or 1
            numerator = hash_256(f'n {i}') % (divisor << io_width)
            yield (numerator, divisor)

    def tst(self, io_width, cases=None):
        """Check `goldschmidt_div` against `divmod` for `io_width`.

        `cases` is forwarded to `self.cases`. The states reported through
        the `trace` callback are collected and attached to the sub-test
        so they show up when an assertion fails.
        """
        assert isinstance(io_width, int)
        params = GoldschmidtDivParams.get(io_width)
        with self.subTest(params=str(params)):
            for n, d in self.cases(io_width, cases):
                expected_q, expected_r = divmod(n, d)
                with self.subTest(n=hex(n), d=hex(d),
                                  expected_q=hex(expected_q),
                                  expected_r=hex(expected_r)):
                    trace = []

                    def trace_fn(state):
                        assert isinstance(state, GoldschmidtDivState)
                        # store a copy rather than the object passed in
                        snapshot = replace(state)
                        trace.append(snapshot)

                    q, r = goldschmidt_div(n, d, params, trace=trace_fn)
                    with self.subTest(q=hex(q), r=hex(r),
                                      trace=repr(trace)):
                        self.assertEqual((q, r), (expected_q, expected_r))

    def tst_sim(self, io_width, cases=None, pipe_reg_indexes=(),
                sync_rom=False):
        """Simulate `GoldschmidtDivHDL` and check its `q`/`r` outputs.

        One process drives `dut.n`/`dut.d` with a new case every clock
        tick; a second process waits for the pipeline latency and then
        compares `dut.q`/`dut.r` with `divmod(n, d)` for each case. When
        the dut has no pipeline registers, the internal signals in
        `dut.trace` are also compared with the states traced by the
        `goldschmidt_div` reference.
        """
        assert isinstance(io_width, int)
        params = GoldschmidtDivParams.get(io_width)
        dut = GoldschmidtDivHDL(params, pipe_reg_indexes=pipe_reg_indexes,
                                sync_rom=sync_rom)
        m = Module()
        m.submodules.dut = dut
        # make sync domain get added
        m.d.sync += Signal().eq(0)

        def inputs_proc():
            yield Tick()
            for n, d in self.cases(io_width, cases):
                yield dut.n.eq(n)
                yield dut.d.eq(d)
                yield Tick()

        def check_internals(n, d):
            # check internals only if dut is completely combinatorial
            # so we don't have to figure out how to read values in
            # previous clock cycles
            if dut.total_pipeline_registers != 0:
                return

            ref_trace = []

            def ref_trace_fn(state):
                assert isinstance(state, GoldschmidtDivState)
                ref_trace.append(replace(state))

            goldschmidt_div(n=n, d=d, params=params, trace=ref_trace_fn)
            self.assertEqual(len(dut.trace), len(ref_trace))

            paired = zip(dut.trace, ref_trace)
            for index, (state, ref_state) in enumerate(paired):
                # op that led to this state; the first state has none
                last_op = params.ops[index - 1] if index != 0 else None
                with self.subTest(index=index, state=repr(state),
                                  ref_state=repr(ref_state),
                                  last_op=str(last_op)):
                    for field in fields(GoldschmidtDivHDLState):
                        sig = getattr(state, field.name)
                        if not isinstance(sig, Signal):
                            # only Signal-valued fields can be simulated
                            continue
                        ref_value = getattr(ref_state, field.name)
                        if isinstance(ref_value, int):
                            ref_value_str = hex(ref_value)
                        else:
                            ref_value_str = repr(ref_value)
                        value = yield sig
                        with self.subTest(field_name=field.name,
                                          sig=repr(sig),
                                          sig_shape=repr(sig.shape()),
                                          value=hex(value),
                                          ref_value=ref_value_str):
                            # the simulator reads back a plain integer,
                            # so FixedPoint references are compared
                            # through their `bits`
                            if isinstance(ref_value, int):
                                expected_bits = ref_value
                            else:
                                assert isinstance(ref_value, FixedPoint)
                                expected_bits = ref_value.bits
                            self.assertEqual(value, expected_bits)

        def check_outputs():
            # mirror the initial tick of inputs_proc, then wait one more
            # tick per pipeline register before looking at the outputs
            yield Tick()
            for _ in range(dut.total_pipeline_registers):
                yield Tick()
            for n, d in self.cases(io_width, cases):
                # step a little past the clock edge before sampling
                yield Delay(0.1e-6)
                expected_q, expected_r = divmod(n, d)
                with self.subTest(n=hex(n), d=hex(d),
                                  expected_q=hex(expected_q),
                                  expected_r=hex(expected_r)):
                    q = yield dut.q
                    r = yield dut.r
                    with self.subTest(q=hex(q), r=hex(r)):
                        self.assertEqual((q, r), (expected_q, expected_r))
                    yield from check_internals(n, d)

                yield Tick()

        with self.subTest(params=str(params)):
            with do_sim(self, m, (dut.n, dut.d, dut.q, dut.r)) as sim:
                sim.add_clock(1e-6)
                sim.add_process(inputs_proc)
                sim.add_process(check_outputs)
                sim.run()

    def test_1_through_4(self):
        for io_width in range(1, 4 + 1):
            with self.subTest(io_width=io_width):
                self.tst(io_width)

    def test_5(self):
        self.tst(5)

    def test_6(self):
        self.tst(6)

    def test_8(self):
        self.tst(8)

    def test_16(self):
        self.tst(16)

    def test_32(self):
        self.tst(32)

    def test_64(self):
        self.tst(64)

    def test_sim_5(self):
        self.tst_sim(5)

    def test_sim_8(self):
        self.tst_sim(8)

    def test_sim_16(self):
        self.tst_sim(16)

    def test_sim_32(self):
        self.tst_sim(32)

    def test_sim_64(self):
        self.tst_sim(64)

    def tst_params(self, io_width):
        """Compute the parameters for `io_width` and print them.

        Nothing is asserted about the result; the test passes as long as
        `GoldschmidtDivParams.get` does not raise.
        """
        assert isinstance(io_width, int)
        params = GoldschmidtDivParams.get(io_width)
        # blank line first so the parameters start on a fresh line
        print()
        print(params)

    # one method per io_width (1 through 64) so that each width can be
    # selected and reported individually by the test runner

    def test_params_1(self):
        self.tst_params(1)

    def test_params_2(self):
        self.tst_params(2)

    def test_params_3(self):
        self.tst_params(3)

    def test_params_4(self):
        self.tst_params(4)

    def test_params_5(self):
        self.tst_params(5)

    def test_params_6(self):
        self.tst_params(6)

    def test_params_7(self):
        self.tst_params(7)

    def test_params_8(self):
        self.tst_params(8)

    def test_params_9(self):
        self.tst_params(9)

    def test_params_10(self):
        self.tst_params(10)

    def test_params_11(self):
        self.tst_params(11)

    def test_params_12(self):
        self.tst_params(12)

    def test_params_13(self):
        self.tst_params(13)

    def test_params_14(self):
        self.tst_params(14)

    def test_params_15(self):
        self.tst_params(15)

    def test_params_16(self):
        self.tst_params(16)

    def test_params_17(self):
        self.tst_params(17)

    def test_params_18(self):
        self.tst_params(18)

    def test_params_19(self):
        self.tst_params(19)

    def test_params_20(self):
        self.tst_params(20)

    def test_params_21(self):
        self.tst_params(21)

    def test_params_22(self):
        self.tst_params(22)

    def test_params_23(self):
        self.tst_params(23)

    def test_params_24(self):
        self.tst_params(24)

    def test_params_25(self):
        self.tst_params(25)

    def test_params_26(self):
        self.tst_params(26)

    def test_params_27(self):
        self.tst_params(27)

    def test_params_28(self):
        self.tst_params(28)

    def test_params_29(self):
        self.tst_params(29)

    def test_params_30(self):
        self.tst_params(30)

    def test_params_31(self):
        self.tst_params(31)

    def test_params_32(self):
        self.tst_params(32)

    def test_params_33(self):
        self.tst_params(33)

    def test_params_34(self):
        self.tst_params(34)

    def test_params_35(self):
        self.tst_params(35)

    def test_params_36(self):
        self.tst_params(36)

    def test_params_37(self):
        self.tst_params(37)

    def test_params_38(self):
        self.tst_params(38)

    def test_params_39(self):
        self.tst_params(39)

    def test_params_40(self):
        self.tst_params(40)

    def test_params_41(self):
        self.tst_params(41)

    def test_params_42(self):
        self.tst_params(42)

    def test_params_43(self):
        self.tst_params(43)

    def test_params_44(self):
        self.tst_params(44)

    def test_params_45(self):
        self.tst_params(45)

    def test_params_46(self):
        self.tst_params(46)

    def test_params_47(self):
        self.tst_params(47)

    def test_params_48(self):
        self.tst_params(48)

    def test_params_49(self):
        self.tst_params(49)

    def test_params_50(self):
        self.tst_params(50)

    def test_params_51(self):
        self.tst_params(51)

    def test_params_52(self):
        self.tst_params(52)

    def test_params_53(self):
        self.tst_params(53)

    def test_params_54(self):
        self.tst_params(54)

    def test_params_55(self):
        self.tst_params(55)

    def test_params_56(self):
        self.tst_params(56)

    def test_params_57(self):
        self.tst_params(57)

    def test_params_58(self):
        self.tst_params(58)

    def test_params_59(self):
        self.tst_params(59)

    def test_params_60(self):
        self.tst_params(60)

    def test_params_61(self):
        self.tst_params(61)

    def test_params_62(self):
        self.tst_params(62)

    def test_params_63(self):
        self.tst_params(63)

    def test_params_64(self):
        self.tst_params(64)
438
439
class TestGoldschmidtSqrtRSqrt(FHDLTestCase):
    """Check `goldschmidt_sqrt_rsqrt` against `FixedPoint.sqrt`/`rsqrt`."""

    def tst(self, io_width, frac_wid, extra_precision,
            table_addr_bits, table_data_bits, iter_count):
        """Compare `goldschmidt_sqrt_rsqrt` with round-down references.

        Every radicand with `bits` in `range(1 << io_width)` is tried.
        The expected sqrt is `radicand.sqrt(RoundDir.DOWN)`; the expected
        rsqrt is `radicand.rsqrt(RoundDir.DOWN)`, or zero when the
        radicand is not greater than zero. All arguments must be `int`
        and are passed on to `goldschmidt_sqrt_rsqrt` as keywords.
        """
        # collected once: reused for the type checks, the sub-test labels
        # and the call to the function under test
        algo_args = dict(io_width=io_width, frac_wid=frac_wid,
                         extra_precision=extra_precision,
                         table_addr_bits=table_addr_bits,
                         table_data_bits=table_data_bits,
                         iter_count=iter_count)
        for arg in algo_args.values():
            assert isinstance(arg, int)

        with self.subTest(**algo_args):
            for bits in range(1 << io_width):
                radicand = FixedPoint(bits, frac_wid)
                expected_sqrt = radicand.sqrt(RoundDir.DOWN)
                if radicand > 0:
                    expected_rsqrt = radicand.rsqrt(RoundDir.DOWN)
                else:
                    # no reciprocal for zero: the expected value is zero
                    expected_rsqrt = FixedPoint(0, frac_wid)
                with self.subTest(radicand=repr(radicand),
                                  expected_sqrt=repr(expected_sqrt),
                                  expected_rsqrt=repr(expected_rsqrt)):
                    sqrt, rsqrt = goldschmidt_sqrt_rsqrt(
                        radicand=radicand, **algo_args)
                    with self.subTest(sqrt=repr(sqrt), rsqrt=repr(rsqrt)):
                        self.assertEqual((sqrt, rsqrt),
                                         (expected_sqrt, expected_rsqrt))

    def test1(self):
        # 16-bit radicands with 8 fractional bits
        self.tst(io_width=16, frac_wid=8, extra_precision=20,
                 table_addr_bits=4, table_data_bits=28, iter_count=4)
477
478
# run every test in this module when the file is executed as a script
if __name__ == "__main__":
    unittest.main()