add the goldschmidt sqrt/rsqrt algorithm, still need code to calculate good parameters
[soc.git] / src / soc / fu / div / experiment / test / test_goldschmidt_div_sqrt.py
1 # SPDX-License-Identifier: LGPL-3-or-later
2 # Copyright 2022 Jacob Lifshay programmerjake@gmail.com
3
4 # Funded by NLnet Assure Programme 2021-02-052, https://nlnet.nl/assure part
5 # of Horizon 2020 EU Programme 957073.
6
7 import math
8 import unittest
9 from nmutil.formaltest import FHDLTestCase
10 from soc.fu.div.experiment.goldschmidt_div_sqrt import (
11 GoldschmidtDivParams, ParamsNotAccurateEnough, goldschmidt_div,
12 FixedPoint, RoundDir, goldschmidt_sqrt_rsqrt)
13
14
class TestFixedPoint(FHDLTestCase):
    """Exhaustive checks of ``FixedPoint`` conversions and roots."""

    def test_str_roundtrip(self):
        """``FixedPoint.cast(str(x))`` must reproduce ``x`` exactly."""
        for frac_wid in range(8):
            for bits in range(-1 << 9, 1 << 9):
                with self.subTest(bits=hex(bits), frac_wid=frac_wid):
                    original = FixedPoint(bits, frac_wid)
                    parsed = FixedPoint.cast(str(original))
                    self.assertEqual(original, parsed)

    @staticmethod
    def trap(f):
        """Call ``f()`` and capture its outcome as a comparable pair.

        Returns ``(result, None)`` when ``f`` succeeds, or
        ``(None, exception_class_name)`` when it raises ``ValueError`` or
        ``ZeroDivisionError``; any other exception propagates.
        """
        try:
            value = f()
        except (ValueError, ZeroDivisionError) as e:
            return None, type(e).__name__
        return value, None

    def _compare_with_float(self, first_bits, float_fn, fixed_fn):
        """Check a ``FixedPoint`` operation against a float reference.

        For every ``frac_wid`` in ``range(8)``, every raw value in
        ``[first_bits, 1 << 9)`` and every ``RoundDir``, the outcome of
        ``fixed_fn(radicand, round_dir)`` must equal ``float_fn`` of the
        radicand converted with ``FixedPoint.with_frac_wid`` using the same
        ``frac_wid`` and ``round_dir``.  Both sides go through ``trap`` so
        a matching ``ValueError``/``ZeroDivisionError`` also counts as equal.
        """
        for frac_wid in range(8):
            for bits in range(first_bits, 1 << 9):
                radicand = FixedPoint(bits, frac_wid)
                reference = float_fn(float(radicand))
                for round_dir in RoundDir:
                    # the lambdas are invoked immediately by ``trap``, so
                    # late binding of the loop variables is not an issue
                    expected = self.trap(
                        lambda: FixedPoint.with_frac_wid(
                            reference, frac_wid, round_dir))
                    with self.subTest(radicand=repr(radicand),
                                      round_dir=str(round_dir),
                                      expected=repr(expected)):
                        result = self.trap(
                            lambda: fixed_fn(radicand, round_dir))
                        self.assertEqual(result, expected)

    def test_sqrt(self):
        """``FixedPoint.sqrt`` matches ``math.sqrt`` for all radicands."""
        self._compare_with_float(
            0, math.sqrt, lambda r, d: r.sqrt(d))

    def test_rsqrt(self):
        """``FixedPoint.rsqrt`` matches ``1 / math.sqrt`` (radicand > 0)."""
        self._compare_with_float(
            1, lambda x: 1 / math.sqrt(x), lambda r, d: r.rsqrt(d))
58
59
class TestGoldschmidtDiv(FHDLTestCase):
    """Tests for ``GoldschmidtDivParams`` and ``goldschmidt_div``.

    The ``test_params_1`` .. ``test_params_64`` methods are generated
    right after this class body (see ``_make_params_test``) instead of
    being written out 64 times by hand.
    """

    def test_case1(self):
        """These parameters must be rejected as not accurate enough."""
        with self.assertRaises(ParamsNotAccurateEnough):
            GoldschmidtDivParams(io_width=3, extra_precision=2,
                                 table_addr_bits=3, table_data_bits=5,
                                 iter_count=2)

    def test_case2(self):
        """These parameters must be rejected as not accurate enough."""
        with self.assertRaises(ParamsNotAccurateEnough):
            GoldschmidtDivParams(io_width=4, extra_precision=1,
                                 table_addr_bits=1, table_data_bits=5,
                                 iter_count=1)

    def tst(self, io_width):
        """Exhaustively compare ``goldschmidt_div`` with ``divmod``.

        Uses ``GoldschmidtDivParams.get(io_width)`` and checks every
        divisor ``d`` in ``[1, 2**io_width)`` against every numerator
        ``n`` in ``[0, d << io_width)`` -- i.e. exactly the numerators
        whose quotient ``n // d`` fits in ``io_width`` bits.
        """
        assert isinstance(io_width, int)
        params = GoldschmidtDivParams.get(io_width)
        with self.subTest(params=str(params)):
            for d in range(1, 1 << io_width):
                for n in range(d << io_width):
                    expected_q, expected_r = divmod(n, d)
                    with self.subTest(n=hex(n), d=hex(d),
                                      expected_q=hex(expected_q),
                                      expected_r=hex(expected_r)):
                        q, r = goldschmidt_div(n, d, params)
                        with self.subTest(q=hex(q), r=hex(r)):
                            self.assertEqual((q, r), (expected_q, expected_r))

    def test_1_through_4(self):
        for io_width in range(1, 4 + 1):
            with self.subTest(io_width=io_width):
                self.tst(io_width)

    def test_5(self):
        self.tst(5)

    def test_6(self):
        self.tst(6)

    def tst_params(self, io_width):
        """Compute and print the parameters chosen for ``io_width``.

        The only assertion is implicit: ``GoldschmidtDivParams.get`` must
        return without raising.  The printout is for inspecting the
        selected parameters.
        """
        assert isinstance(io_width, int)
        params = GoldschmidtDivParams.get(io_width)
        print()
        print(params)


def _make_params_test(io_width):
    """Build a ``test_params_<io_width>`` method calling ``tst_params``."""
    def test(self):
        self.tst_params(io_width)
    name = "test_params_{}".format(io_width)
    test.__name__ = name
    test.__qualname__ = "TestGoldschmidtDiv." + name
    return test


# One separately selectable test per width (test_params_1 .. test_params_64),
# keeping the same method names the hand-written versions had.
for _io_width in range(1, 64 + 1):
    _params_test = _make_params_test(_io_width)
    setattr(TestGoldschmidtDiv, _params_test.__name__, _params_test)
del _io_width, _params_test, _make_params_test
295
296
class TestGoldschmidtSqrtRSqrt(FHDLTestCase):
    """Exhaustive checks of ``goldschmidt_sqrt_rsqrt`` for given parameters."""

    def tst(self, io_width, frac_wid, extra_precision,
            table_addr_bits, table_data_bits, iter_count):
        """Compare ``goldschmidt_sqrt_rsqrt`` against ``FixedPoint`` math.

        Every raw value in ``[0, 2**io_width)`` is interpreted as a
        ``FixedPoint`` with ``frac_wid`` fractional bits.  The expected
        outputs are ``sqrt``/``rsqrt`` rounded with ``RoundDir.DOWN``; for a
        zero radicand the expected reciprocal square root is zero.
        """
        config = dict(io_width=io_width, frac_wid=frac_wid,
                      extra_precision=extra_precision,
                      table_addr_bits=table_addr_bits,
                      table_data_bits=table_data_bits,
                      iter_count=iter_count)
        for setting in config.values():
            assert isinstance(setting, int)
        with self.subTest(**config):
            for bits in range(1 << io_width):
                radicand = FixedPoint(bits, frac_wid)
                expected_sqrt = radicand.sqrt(RoundDir.DOWN)
                if radicand > 0:
                    expected_rsqrt = radicand.rsqrt(RoundDir.DOWN)
                else:
                    # rsqrt(0) has no finite value; this test expects 0
                    expected_rsqrt = FixedPoint(0, frac_wid)
                with self.subTest(radicand=repr(radicand),
                                  expected_sqrt=repr(expected_sqrt),
                                  expected_rsqrt=repr(expected_rsqrt)):
                    sqrt, rsqrt = goldschmidt_sqrt_rsqrt(
                        radicand=radicand, **config)
                    with self.subTest(sqrt=repr(sqrt), rsqrt=repr(rsqrt)):
                        self.assertEqual((sqrt, rsqrt),
                                         (expected_sqrt, expected_rsqrt))

    def test1(self):
        self.tst(io_width=16, frac_wid=8, extra_precision=20,
                 table_addr_bits=4, table_data_bits=28, iter_count=4)
333 table_addr_bits=4, table_data_bits=28, iter_count=4)
334
335
336 if __name__ == "__main__":
337 unittest.main()