1 # SPDX-License-Identifier: LGPL-3-or-later
2 # Copyright 2022 Jacob Lifshay programmerjake@gmail.com
4 # Funded by NLnet Assure Programme 2021-02-052, https://nlnet.nl/assure part
5 # of Horizon 2020 EU Programme 957073.
9 from nmutil
.formaltest
import FHDLTestCase
10 from nmutil
.sim_util
import do_sim
11 from nmigen
.sim
import Tick
, Delay
12 from nmigen
.hdl
.ast
import Signal
13 from nmigen
.hdl
.dsl
import Module
14 from soc
.fu
.div
.experiment
.goldschmidt_div_sqrt
import (
15 GoldschmidtDivHDL
, GoldschmidtDivParams
, ParamsNotAccurateEnough
,
16 goldschmidt_div
, FixedPoint
, RoundDir
, goldschmidt_sqrt_rsqrt
)
19 class TestFixedPoint(FHDLTestCase
):
def test_str_roundtrip(self):
    """Check that FixedPoint values survive a str() -> FixedPoint.cast() trip.

    Every combination of frac_wid in range(8) and bits in
    range(-512, 512) is exercised in its own subTest, so one failing
    combination does not hide the others.
    """
    # same span as the literal range(-1 << 9, 1 << 9)
    bit_patterns = range(-1 << 9, 1 << 9)
    for frac_wid in range(8):
        for bits in bit_patterns:
            with self.subTest(bits=hex(bits), frac_wid=frac_wid):
                original = FixedPoint(bits, frac_wid)
                text = str(original)
                # parsing the textual form must give back an equal value
                self.assertEqual(original, FixedPoint.cast(text))
32 except (ValueError, ZeroDivisionError) as e
:
33 return None, e
.__class
__.__name
__
36 for frac_wid
in range(8):
37 for bits
in range(1 << 9):
38 for round_dir
in RoundDir
:
39 radicand
= FixedPoint(bits
, frac_wid
)
40 expected_f
= math
.sqrt(float(radicand
))
41 expected
= self
.trap(lambda: FixedPoint
.with_frac_wid(
42 expected_f
, frac_wid
, round_dir
))
43 with self
.subTest(radicand
=repr(radicand
),
44 round_dir
=str(round_dir
),
45 expected
=repr(expected
)):
46 result
= self
.trap(lambda: radicand
.sqrt(round_dir
))
47 self
.assertEqual(result
, expected
)
50 for frac_wid
in range(8):
51 for bits
in range(1, 1 << 9):
52 for round_dir
in RoundDir
:
53 radicand
= FixedPoint(bits
, frac_wid
)
54 expected_f
= 1 / math
.sqrt(float(radicand
))
55 expected
= self
.trap(lambda: FixedPoint
.with_frac_wid(
56 expected_f
, frac_wid
, round_dir
))
57 with self
.subTest(radicand
=repr(radicand
),
58 round_dir
=str(round_dir
),
59 expected
=repr(expected
)):
60 result
= self
.trap(lambda: radicand
.rsqrt(round_dir
))
61 self
.assertEqual(result
, expected
)
64 class TestGoldschmidtDiv(FHDLTestCase
):
66 with self
.assertRaises(ParamsNotAccurateEnough
):
67 GoldschmidtDivParams(io_width
=3, extra_precision
=2,
68 table_addr_bits
=3, table_data_bits
=5,
72 with self
.assertRaises(ParamsNotAccurateEnough
):
73 GoldschmidtDivParams(io_width
=4, extra_precision
=1,
74 table_addr_bits
=1, table_data_bits
=5,
def tst(self, io_width):
    """Exhaustively compare goldschmidt_div against Python's divmod.

    The parameters come from GoldschmidtDivParams.get(io_width).  Every
    divisor d in [1, 2**io_width) is paired with every dividend n in
    [0, d << io_width), i.e. every pair whose quotient is below
    2**io_width.  Nested subTests label a failure with the parameters,
    the inputs and expected results, and the results actually produced.
    """
    assert isinstance(io_width, int)
    params = GoldschmidtDivParams.get(io_width)
    with self.subTest(params=str(params)):
        for divisor in range(1, 1 << io_width):
            for dividend in range(divisor << io_width):
                # reference answer from Python's integer division
                want_q, want_r = divmod(dividend, divisor)
                labels = dict(n=hex(dividend), d=hex(divisor),
                              expected_q=hex(want_q),
                              expected_r=hex(want_r))
                with self.subTest(**labels):
                    got_q, got_r = goldschmidt_div(
                        dividend, divisor, params)
                    with self.subTest(q=hex(got_q), r=hex(got_r)):
                        self.assertEqual((got_q, got_r),
                                         (want_q, want_r))
91 @unittest.skip("hdl/simulation currently broken")
92 def tst_sim(self
, io_width
, cases
=None, pipe_reg_indexes
=(),
94 # FIXME: finish getting hdl/simulation to work
95 assert isinstance(io_width
, int)
96 params
= GoldschmidtDivParams
.get(io_width
)
98 dut
= GoldschmidtDivHDL(params
, pipe_reg_indexes
=pipe_reg_indexes
,
100 m
.submodules
.dut
= dut
101 # make sync domain get added
102 m
.d
.sync
+= Signal().eq(0)
105 if cases
is not None:
108 for d
in range(1, 1 << io_width
):
109 for n
in range(d
<< io_width
):
114 for n
, d
in iter_cases():
121 for _
in range(dut
.total_pipeline_registers
):
123 for n
, d
in iter_cases():
125 expected_q
, expected_r
= divmod(n
, d
)
126 with self
.subTest(n
=hex(n
), d
=hex(d
),
127 expected_q
=hex(expected_q
),
128 expected_r
=hex(expected_r
)):
131 with self
.subTest(q
=hex(q
), r
=hex(r
)):
132 self
.assertEqual((q
, r
), (expected_q
, expected_r
))
135 with self
.subTest(params
=str(params
)):
136 with
do_sim(self
, m
, (dut
.n
, dut
.d
, dut
.q
, dut
.r
)) as sim
:
138 sim
.add_process(inputs_proc
)
139 sim
.add_process(check_outputs
)
142 def test_1_through_4(self
):
143 for io_width
in range(1, 4 + 1):
144 with self
.subTest(io_width
=io_width
):
153 def test_sim_5(self
):
156 def tst_params(self
, io_width
):
157 assert isinstance(io_width
, int)
158 params
= GoldschmidtDivParams
.get(io_width
)
162 def test_params_1(self
):
165 def test_params_2(self
):
168 def test_params_3(self
):
171 def test_params_4(self
):
174 def test_params_5(self
):
177 def test_params_6(self
):
180 def test_params_7(self
):
183 def test_params_8(self
):
186 def test_params_9(self
):
189 def test_params_10(self
):
192 def test_params_11(self
):
195 def test_params_12(self
):
198 def test_params_13(self
):
201 def test_params_14(self
):
204 def test_params_15(self
):
207 def test_params_16(self
):
210 def test_params_17(self
):
213 def test_params_18(self
):
216 def test_params_19(self
):
219 def test_params_20(self
):
222 def test_params_21(self
):
225 def test_params_22(self
):
228 def test_params_23(self
):
231 def test_params_24(self
):
234 def test_params_25(self
):
237 def test_params_26(self
):
240 def test_params_27(self
):
243 def test_params_28(self
):
246 def test_params_29(self
):
249 def test_params_30(self
):
252 def test_params_31(self
):
255 def test_params_32(self
):
258 def test_params_33(self
):
261 def test_params_34(self
):
264 def test_params_35(self
):
267 def test_params_36(self
):
270 def test_params_37(self
):
273 def test_params_38(self
):
276 def test_params_39(self
):
279 def test_params_40(self
):
282 def test_params_41(self
):
285 def test_params_42(self
):
288 def test_params_43(self
):
291 def test_params_44(self
):
294 def test_params_45(self
):
297 def test_params_46(self
):
300 def test_params_47(self
):
303 def test_params_48(self
):
306 def test_params_49(self
):
309 def test_params_50(self
):
312 def test_params_51(self
):
315 def test_params_52(self
):
318 def test_params_53(self
):
321 def test_params_54(self
):
324 def test_params_55(self
):
327 def test_params_56(self
):
330 def test_params_57(self
):
333 def test_params_58(self
):
336 def test_params_59(self
):
339 def test_params_60(self
):
342 def test_params_61(self
):
345 def test_params_62(self
):
348 def test_params_63(self
):
351 def test_params_64(self
):
355 class TestGoldschmidtSqrtRSqrt(FHDLTestCase
):
356 def tst(self
, io_width
, frac_wid
, extra_precision
,
357 table_addr_bits
, table_data_bits
, iter_count
):
358 assert isinstance(io_width
, int)
359 assert isinstance(frac_wid
, int)
360 assert isinstance(extra_precision
, int)
361 assert isinstance(table_addr_bits
, int)
362 assert isinstance(table_data_bits
, int)
363 assert isinstance(iter_count
, int)
364 with self
.subTest(io_width
=io_width
, frac_wid
=frac_wid
,
365 extra_precision
=extra_precision
,
366 table_addr_bits
=table_addr_bits
,
367 table_data_bits
=table_data_bits
,
368 iter_count
=iter_count
):
369 for bits
in range(1 << io_width
):
370 radicand
= FixedPoint(bits
, frac_wid
)
371 expected_sqrt
= radicand
.sqrt(RoundDir
.DOWN
)
372 expected_rsqrt
= FixedPoint(0, frac_wid
)
374 expected_rsqrt
= radicand
.rsqrt(RoundDir
.DOWN
)
375 with self
.subTest(radicand
=repr(radicand
),
376 expected_sqrt
=repr(expected_sqrt
),
377 expected_rsqrt
=repr(expected_rsqrt
)):
378 sqrt
, rsqrt
= goldschmidt_sqrt_rsqrt(
379 radicand
=radicand
, io_width
=io_width
,
381 extra_precision
=extra_precision
,
382 table_addr_bits
=table_addr_bits
,
383 table_data_bits
=table_data_bits
,
384 iter_count
=iter_count
)
385 with self
.subTest(sqrt
=repr(sqrt
), rsqrt
=repr(rsqrt
)):
386 self
.assertEqual((sqrt
, rsqrt
),
387 (expected_sqrt
, expected_rsqrt
))
390 self
.tst(io_width
=16, frac_wid
=8, extra_precision
=20,
391 table_addr_bits
=4, table_data_bits
=28, iter_count
=4)
394 if __name__
== "__main__":