add WIP HDL version of goldschmidt division -- it's currently broken
[soc.git] / src / soc / fu / div / experiment / test / test_goldschmidt_div_sqrt.py
1 # SPDX-License-Identifier: LGPL-3-or-later
2 # Copyright 2022 Jacob Lifshay programmerjake@gmail.com
3
4 # Funded by NLnet Assure Programme 2021-02-052, https://nlnet.nl/assure part
5 # of Horizon 2020 EU Programme 957073.
6
7 import math
8 import unittest
9 from nmutil.formaltest import FHDLTestCase
10 from nmutil.sim_util import do_sim
11 from nmigen.sim import Tick, Delay
12 from nmigen.hdl.ast import Signal
13 from nmigen.hdl.dsl import Module
14 from soc.fu.div.experiment.goldschmidt_div_sqrt import (
15 GoldschmidtDivHDL, GoldschmidtDivParams, ParamsNotAccurateEnough,
16 goldschmidt_div, FixedPoint, RoundDir, goldschmidt_sqrt_rsqrt)
17
18
class TestFixedPoint(FHDLTestCase):
    """Exhaustive checks of ``FixedPoint`` string round-tripping and of its
    ``sqrt``/``rsqrt`` methods against a floating-point reference."""

    def test_str_roundtrip(self):
        """``FixedPoint.cast(str(v))`` must reproduce ``v`` exactly."""
        for frac_wid in range(8):
            for bits in range(-(1 << 9), 1 << 9):
                with self.subTest(bits=hex(bits), frac_wid=frac_wid):
                    original = FixedPoint(bits, frac_wid)
                    parsed = FixedPoint.cast(str(original))
                    self.assertEqual(original, parsed)

    @staticmethod
    def trap(f):
        """Call ``f()`` and return ``(result, None)``.

        If ``f`` raises ``ValueError`` or ``ZeroDivisionError``, return
        ``(None, <exception class name>)`` instead, so that error outcomes
        can be compared with ``assertEqual`` just like ordinary results.
        """
        try:
            outcome = f()
        except (ValueError, ZeroDivisionError) as exc:
            return None, type(exc).__name__
        return outcome, None

    def _check_against_float(self, bits_range, reference, operation):
        """Compare ``operation(radicand, round_dir)`` with ``reference``.

        For every ``frac_wid`` in ``range(8)``, every raw value in
        ``bits_range`` and every ``RoundDir``, the expected result is the
        float ``reference(float(radicand))`` converted back with
        ``FixedPoint.with_frac_wid`` using the same rounding direction.
        Both sides go through ``trap`` so matching exceptions also pass.
        """
        for frac_wid in range(8):
            for bits in bits_range:
                for round_dir in RoundDir:
                    radicand = FixedPoint(bits, frac_wid)
                    ref_value = reference(float(radicand))
                    expected = self.trap(lambda: FixedPoint.with_frac_wid(
                        ref_value, frac_wid, round_dir))
                    with self.subTest(radicand=repr(radicand),
                                      round_dir=str(round_dir),
                                      expected=repr(expected)):
                        actual = self.trap(
                            lambda: operation(radicand, round_dir))
                        self.assertEqual(actual, expected)

    def test_sqrt(self):
        """``FixedPoint.sqrt`` matches ``math.sqrt`` rounded per RoundDir."""
        self._check_against_float(
            range(1 << 9),
            lambda x: math.sqrt(x),
            lambda value, rd: value.sqrt(rd))

    def test_rsqrt(self):
        """``FixedPoint.rsqrt`` matches ``1 / math.sqrt`` rounded per
        RoundDir; zero is excluded since its reciprocal sqrt is undefined."""
        self._check_against_float(
            range(1, 1 << 9),
            lambda x: 1 / math.sqrt(x),
            lambda value, rd: value.rsqrt(rd))
62
63
class TestGoldschmidtDiv(FHDLTestCase):
    """Tests for the Goldschmidt integer-division model
    (``goldschmidt_div``), its parameter search (``GoldschmidtDivParams``),
    and the work-in-progress HDL implementation (``GoldschmidtDivHDL``)."""

    def test_case1(self):
        """These hand-picked parameters must be rejected as too inaccurate."""
        with self.assertRaises(ParamsNotAccurateEnough):
            GoldschmidtDivParams(io_width=3, extra_precision=2,
                                 table_addr_bits=3, table_data_bits=5,
                                 iter_count=2)

    def test_case2(self):
        """These hand-picked parameters must be rejected as too inaccurate."""
        with self.assertRaises(ParamsNotAccurateEnough):
            GoldschmidtDivParams(io_width=4, extra_precision=1,
                                 table_addr_bits=1, table_data_bits=5,
                                 iter_count=1)

    def tst(self, io_width):
        """Exhaustively compare ``goldschmidt_div`` with ``divmod``.

        Every divisor ``d`` in ``[1, 2**io_width)`` is paired with every
        dividend ``n < d << io_width``, i.e. exactly the inputs whose
        quotient fits in ``io_width`` bits.

        :param io_width: operand width in bits; parameters are obtained via
            ``GoldschmidtDivParams.get(io_width)``.
        """
        assert isinstance(io_width, int)
        params = GoldschmidtDivParams.get(io_width)
        with self.subTest(params=str(params)):
            for d in range(1, 1 << io_width):
                for n in range(d << io_width):
                    expected_q, expected_r = divmod(n, d)
                    with self.subTest(n=hex(n), d=hex(d),
                                      expected_q=hex(expected_q),
                                      expected_r=hex(expected_r)):
                        q, r = goldschmidt_div(n, d, params)
                        with self.subTest(q=hex(q), r=hex(r)):
                            self.assertEqual((q, r),
                                             (expected_q, expected_r))

    def tst_sim(self, io_width, cases=None, pipe_reg_indexes=(),
                sync_rom=False):
        """Simulate ``GoldschmidtDivHDL`` and check its ``q``/``r`` outputs.

        One ``(n, d)`` pair is driven per clock. The checker waits one
        clock plus ``dut.total_pipeline_registers`` clocks before comparing
        outputs against ``divmod`` once per clock.

        :param io_width: operand width in bits.
        :param cases: optional iterable of ``(n, d)`` pairs; when ``None``
            the same exhaustive set as ``tst`` is used. Any iterable
            (including a one-shot generator) is accepted.
        :param pipe_reg_indexes: forwarded to ``GoldschmidtDivHDL``.
        :param sync_rom: forwarded to ``GoldschmidtDivHDL``.
        """
        # FIXME: finish getting hdl/simulation to work
        assert isinstance(io_width, int)
        if cases is not None:
            # iter_cases() is consumed twice (by the input driver and by the
            # output checker); materialize so a generator isn't exhausted
            # by the first consumer, leaving the checker with no cases.
            cases = list(cases)
        params = GoldschmidtDivParams.get(io_width)
        m = Module()
        dut = GoldschmidtDivHDL(params, pipe_reg_indexes=pipe_reg_indexes,
                                sync_rom=sync_rom)
        m.submodules.dut = dut
        # make sync domain get added (a dummy sync assignment so the `sync`
        # domain exists for sim.add_clock even if the DUT adds none)
        m.d.sync += Signal().eq(0)

        def iter_cases():
            if cases is not None:
                yield from cases
                return
            for d in range(1, 1 << io_width):
                for n in range(d << io_width):
                    yield (n, d)

        def inputs_proc():
            yield Tick()
            for n, d in iter_cases():
                yield dut.n.eq(n)
                yield dut.d.eq(d)
                yield Tick()

        def check_outputs():
            yield Tick()
            # skip the pipeline latency before the first result appears
            for _ in range(dut.total_pipeline_registers):
                yield Tick()
            for n, d in iter_cases():
                # let combinational logic settle before sampling outputs
                yield Delay(0.1e-6)
                expected_q, expected_r = divmod(n, d)
                with self.subTest(n=hex(n), d=hex(d),
                                  expected_q=hex(expected_q),
                                  expected_r=hex(expected_r)):
                    q = yield dut.q
                    r = yield dut.r
                    with self.subTest(q=hex(q), r=hex(r)):
                        self.assertEqual((q, r), (expected_q, expected_r))
                yield Tick()

        with self.subTest(params=str(params)):
            with do_sim(self, m, (dut.n, dut.d, dut.q, dut.r)) as sim:
                sim.add_clock(1e-6)
                sim.add_process(inputs_proc)
                sim.add_process(check_outputs)
                sim.run()

    def test_1_through_4(self):
        for io_width in range(1, 4 + 1):
            with self.subTest(io_width=io_width):
                self.tst(io_width)

    def test_5(self):
        self.tst(5)

    def test_6(self):
        self.tst(6)

    # The skip belongs on the test, not on the tst_sim helper: decorating
    # the helper made *every* caller raise SkipTest unconditionally.
    @unittest.skip("hdl/simulation currently broken")
    def test_sim_5(self):
        self.tst_sim(5)

    def tst_params(self, io_width):
        """Check that ``GoldschmidtDivParams.get(io_width)`` succeeds and
        print the chosen parameters for inspection."""
        assert isinstance(io_width, int)
        params = GoldschmidtDivParams.get(io_width)
        print()
        print(params)

    def test_params_1(self):
        self.tst_params(1)

    def test_params_2(self):
        self.tst_params(2)

    def test_params_3(self):
        self.tst_params(3)

    def test_params_4(self):
        self.tst_params(4)

    def test_params_5(self):
        self.tst_params(5)

    def test_params_6(self):
        self.tst_params(6)

    def test_params_7(self):
        self.tst_params(7)

    def test_params_8(self):
        self.tst_params(8)

    def test_params_9(self):
        self.tst_params(9)

    def test_params_10(self):
        self.tst_params(10)

    def test_params_11(self):
        self.tst_params(11)

    def test_params_12(self):
        self.tst_params(12)

    def test_params_13(self):
        self.tst_params(13)

    def test_params_14(self):
        self.tst_params(14)

    def test_params_15(self):
        self.tst_params(15)

    def test_params_16(self):
        self.tst_params(16)

    def test_params_17(self):
        self.tst_params(17)

    def test_params_18(self):
        self.tst_params(18)

    def test_params_19(self):
        self.tst_params(19)

    def test_params_20(self):
        self.tst_params(20)

    def test_params_21(self):
        self.tst_params(21)

    def test_params_22(self):
        self.tst_params(22)

    def test_params_23(self):
        self.tst_params(23)

    def test_params_24(self):
        self.tst_params(24)

    def test_params_25(self):
        self.tst_params(25)

    def test_params_26(self):
        self.tst_params(26)

    def test_params_27(self):
        self.tst_params(27)

    def test_params_28(self):
        self.tst_params(28)

    def test_params_29(self):
        self.tst_params(29)

    def test_params_30(self):
        self.tst_params(30)

    def test_params_31(self):
        self.tst_params(31)

    def test_params_32(self):
        self.tst_params(32)

    def test_params_33(self):
        self.tst_params(33)

    def test_params_34(self):
        self.tst_params(34)

    def test_params_35(self):
        self.tst_params(35)

    def test_params_36(self):
        self.tst_params(36)

    def test_params_37(self):
        self.tst_params(37)

    def test_params_38(self):
        self.tst_params(38)

    def test_params_39(self):
        self.tst_params(39)

    def test_params_40(self):
        self.tst_params(40)

    def test_params_41(self):
        self.tst_params(41)

    def test_params_42(self):
        self.tst_params(42)

    def test_params_43(self):
        self.tst_params(43)

    def test_params_44(self):
        self.tst_params(44)

    def test_params_45(self):
        self.tst_params(45)

    def test_params_46(self):
        self.tst_params(46)

    def test_params_47(self):
        self.tst_params(47)

    def test_params_48(self):
        self.tst_params(48)

    def test_params_49(self):
        self.tst_params(49)

    def test_params_50(self):
        self.tst_params(50)

    def test_params_51(self):
        self.tst_params(51)

    def test_params_52(self):
        self.tst_params(52)

    def test_params_53(self):
        self.tst_params(53)

    def test_params_54(self):
        self.tst_params(54)

    def test_params_55(self):
        self.tst_params(55)

    def test_params_56(self):
        self.tst_params(56)

    def test_params_57(self):
        self.tst_params(57)

    def test_params_58(self):
        self.tst_params(58)

    def test_params_59(self):
        self.tst_params(59)

    def test_params_60(self):
        self.tst_params(60)

    def test_params_61(self):
        self.tst_params(61)

    def test_params_62(self):
        self.tst_params(62)

    def test_params_63(self):
        self.tst_params(63)

    def test_params_64(self):
        self.tst_params(64)
353
354
class TestGoldschmidtSqrtRSqrt(FHDLTestCase):
    """Exhaustive checks of ``goldschmidt_sqrt_rsqrt`` against
    ``FixedPoint.sqrt``/``FixedPoint.rsqrt`` rounded down."""

    def tst(self, io_width, frac_wid, extra_precision,
            table_addr_bits, table_data_bits, iter_count):
        """Check every ``io_width``-bit radicand with ``frac_wid``
        fractional bits.

        All arguments must be ``int`` and are forwarded unchanged, as
        keyword arguments, to ``goldschmidt_sqrt_rsqrt``. For a zero
        radicand the expected rsqrt is ``FixedPoint(0, frac_wid)``.
        """
        config = dict(io_width=io_width, frac_wid=frac_wid,
                      extra_precision=extra_precision,
                      table_addr_bits=table_addr_bits,
                      table_data_bits=table_data_bits,
                      iter_count=iter_count)
        for setting in config.values():
            assert isinstance(setting, int)
        with self.subTest(**config):
            for bits in range(1 << io_width):
                radicand = FixedPoint(bits, frac_wid)
                want_sqrt = radicand.sqrt(RoundDir.DOWN)
                if radicand > 0:
                    want_rsqrt = radicand.rsqrt(RoundDir.DOWN)
                else:
                    # 1/sqrt(0) is undefined; the test expects 0 here
                    want_rsqrt = FixedPoint(0, frac_wid)
                with self.subTest(radicand=repr(radicand),
                                  expected_sqrt=repr(want_sqrt),
                                  expected_rsqrt=repr(want_rsqrt)):
                    got_sqrt, got_rsqrt = goldschmidt_sqrt_rsqrt(
                        radicand=radicand, **config)
                    with self.subTest(sqrt=repr(got_sqrt),
                                      rsqrt=repr(got_rsqrt)):
                        self.assertEqual((got_sqrt, got_rsqrt),
                                         (want_sqrt, want_rsqrt))

    def test1(self):
        self.tst(io_width=16, frac_wid=8, extra_precision=20,
                 table_addr_bits=4, table_data_bits=28, iter_count=4)
392
393
if __name__ == "__main__":
    # Running this file directly executes every TestCase defined above.
    unittest.main()