e2984dc16db684cc31ea20973d235ba72e9aa01d
1 # SPDX-License-Identifier: LGPL-3-or-later
2 # Copyright 2022 Jacob Lifshay programmerjake@gmail.com
4 # Funded by NLnet Assure Programme 2021-02-052, https://nlnet.nl/assure part
5 # of Horizon 2020 EU Programme 957073.
9 from nmutil
.formaltest
import FHDLTestCase
10 from soc
.fu
.div
.experiment
.goldschmidt_div_sqrt
import (
11 GoldschmidtDivParams
, ParamsNotAccurateEnough
, goldschmidt_div
,
12 FixedPoint
, RoundDir
, goldschmidt_sqrt_rsqrt
)
15 class TestFixedPoint(FHDLTestCase
):
16 def test_str_roundtrip(self
):
17 for frac_wid
in range(8):
18 for bits
in range(-1 << 9, 1 << 9):
19 with self
.subTest(bits
=hex(bits
), frac_wid
=frac_wid
):
20 value
= FixedPoint(bits
, frac_wid
)
21 round_trip_value
= FixedPoint
.cast(str(value
))
22 self
.assertEqual(value
, round_trip_value
)
28 except (ValueError, ZeroDivisionError) as e
:
29 return None, e
.__class
__.__name
__
32 for frac_wid
in range(8):
33 for bits
in range(1 << 9):
34 for round_dir
in RoundDir
:
35 radicand
= FixedPoint(bits
, frac_wid
)
36 expected_f
= math
.sqrt(float(radicand
))
37 expected
= self
.trap(lambda: FixedPoint
.with_frac_wid(
38 expected_f
, frac_wid
, round_dir
))
39 with self
.subTest(radicand
=repr(radicand
),
40 round_dir
=str(round_dir
),
41 expected
=repr(expected
)):
42 result
= self
.trap(lambda: radicand
.sqrt(round_dir
))
43 self
.assertEqual(result
, expected
)
46 for frac_wid
in range(8):
47 for bits
in range(1, 1 << 9):
48 for round_dir
in RoundDir
:
49 radicand
= FixedPoint(bits
, frac_wid
)
50 expected_f
= 1 / math
.sqrt(float(radicand
))
51 expected
= self
.trap(lambda: FixedPoint
.with_frac_wid(
52 expected_f
, frac_wid
, round_dir
))
53 with self
.subTest(radicand
=repr(radicand
),
54 round_dir
=str(round_dir
),
55 expected
=repr(expected
)):
56 result
= self
.trap(lambda: radicand
.rsqrt(round_dir
))
57 self
.assertEqual(result
, expected
)
60 class TestGoldschmidtDiv(FHDLTestCase
):
62 with self
.assertRaises(ParamsNotAccurateEnough
):
63 GoldschmidtDivParams(io_width
=3, extra_precision
=2,
64 table_addr_bits
=3, table_data_bits
=5,
68 with self
.assertRaises(ParamsNotAccurateEnough
):
69 GoldschmidtDivParams(io_width
=4, extra_precision
=1,
70 table_addr_bits
=1, table_data_bits
=5,
73 def tst(self
, io_width
):
74 assert isinstance(io_width
, int)
75 params
= GoldschmidtDivParams
.get(io_width
)
76 with self
.subTest(params
=str(params
)):
77 for d
in range(1, 1 << io_width
):
78 for n
in range(d
<< io_width
):
79 expected_q
, expected_r
= divmod(n
, d
)
80 with self
.subTest(n
=hex(n
), d
=hex(d
),
81 expected_q
=hex(expected_q
),
82 expected_r
=hex(expected_r
)):
83 q
, r
= goldschmidt_div(n
, d
, params
)
84 with self
.subTest(q
=hex(q
), r
=hex(r
)):
85 self
.assertEqual((q
, r
), (expected_q
, expected_r
))
87 def test_1_through_4(self
):
88 for io_width
in range(1, 4 + 1):
89 with self
.subTest(io_width
=io_width
):
98 def tst_params(self
, io_width
):
99 assert isinstance(io_width
, int)
100 params
= GoldschmidtDivParams
.get(io_width
)
104 def test_params_1(self
):
107 def test_params_2(self
):
110 def test_params_3(self
):
113 def test_params_4(self
):
116 def test_params_5(self
):
119 def test_params_6(self
):
122 def test_params_7(self
):
125 def test_params_8(self
):
128 def test_params_9(self
):
131 def test_params_10(self
):
134 def test_params_11(self
):
137 def test_params_12(self
):
140 def test_params_13(self
):
143 def test_params_14(self
):
146 def test_params_15(self
):
149 def test_params_16(self
):
152 def test_params_17(self
):
155 def test_params_18(self
):
158 def test_params_19(self
):
161 def test_params_20(self
):
164 def test_params_21(self
):
167 def test_params_22(self
):
170 def test_params_23(self
):
173 def test_params_24(self
):
176 def test_params_25(self
):
179 def test_params_26(self
):
182 def test_params_27(self
):
185 def test_params_28(self
):
188 def test_params_29(self
):
191 def test_params_30(self
):
194 def test_params_31(self
):
197 def test_params_32(self
):
200 def test_params_33(self
):
203 def test_params_34(self
):
206 def test_params_35(self
):
209 def test_params_36(self
):
212 def test_params_37(self
):
215 def test_params_38(self
):
218 def test_params_39(self
):
221 def test_params_40(self
):
224 def test_params_41(self
):
227 def test_params_42(self
):
230 def test_params_43(self
):
233 def test_params_44(self
):
236 def test_params_45(self
):
239 def test_params_46(self
):
242 def test_params_47(self
):
245 def test_params_48(self
):
248 def test_params_49(self
):
251 def test_params_50(self
):
254 def test_params_51(self
):
257 def test_params_52(self
):
260 def test_params_53(self
):
263 def test_params_54(self
):
266 def test_params_55(self
):
269 def test_params_56(self
):
272 def test_params_57(self
):
275 def test_params_58(self
):
278 def test_params_59(self
):
281 def test_params_60(self
):
284 def test_params_61(self
):
287 def test_params_62(self
):
290 def test_params_63(self
):
293 def test_params_64(self
):
297 class TestGoldschmidtSqrtRSqrt(FHDLTestCase
):
298 def tst(self
, io_width
, frac_wid
, extra_precision
,
299 table_addr_bits
, table_data_bits
, iter_count
):
300 assert isinstance(io_width
, int)
301 assert isinstance(frac_wid
, int)
302 assert isinstance(extra_precision
, int)
303 assert isinstance(table_addr_bits
, int)
304 assert isinstance(table_data_bits
, int)
305 assert isinstance(iter_count
, int)
306 with self
.subTest(io_width
=io_width
, frac_wid
=frac_wid
,
307 extra_precision
=extra_precision
,
308 table_addr_bits
=table_addr_bits
,
309 table_data_bits
=table_data_bits
,
310 iter_count
=iter_count
):
311 for bits
in range(1 << io_width
):
312 radicand
= FixedPoint(bits
, frac_wid
)
313 expected_sqrt
= radicand
.sqrt(RoundDir
.DOWN
)
314 expected_rsqrt
= FixedPoint(0, frac_wid
)
316 expected_rsqrt
= radicand
.rsqrt(RoundDir
.DOWN
)
317 with self
.subTest(radicand
=repr(radicand
),
318 expected_sqrt
=repr(expected_sqrt
),
319 expected_rsqrt
=repr(expected_rsqrt
)):
320 sqrt
, rsqrt
= goldschmidt_sqrt_rsqrt(
321 radicand
=radicand
, io_width
=io_width
,
323 extra_precision
=extra_precision
,
324 table_addr_bits
=table_addr_bits
,
325 table_data_bits
=table_data_bits
,
326 iter_count
=iter_count
)
327 with self
.subTest(sqrt
=repr(sqrt
), rsqrt
=repr(rsqrt
)):
328 self
.assertEqual((sqrt
, rsqrt
),
329 (expected_sqrt
, expected_rsqrt
))
332 self
.tst(io_width
=16, frac_wid
=8, extra_precision
=20,
333 table_addr_bits
=4, table_data_bits
=28, iter_count
=4)
336 if __name__
== "__main__":