1 # SPDX-License-Identifier: LGPL-3-or-later
2 # Copyright 2022 Jacob Lifshay programmerjake@gmail.com
4 # Funded by NLnet Assure Programme 2021-02-052, https://nlnet.nl/assure part
5 # of Horizon 2020 EU Programme 957073.
7 from dataclasses
import fields
, replace
10 from nmutil
.formaltest
import FHDLTestCase
11 from nmutil
.sim_util
import do_sim
, hash_256
12 from nmigen
.sim
import Tick
, Delay
13 from nmigen
.hdl
.ast
import Signal
14 from nmigen
.hdl
.dsl
import Module
15 from soc
.fu
.div
.experiment
.goldschmidt_div_sqrt
import (
16 GoldschmidtDivHDL
, GoldschmidtDivHDLState
, GoldschmidtDivOp
, GoldschmidtDivParams
,
17 GoldschmidtDivState
, ParamsNotAccurateEnough
, goldschmidt_div
,
18 FixedPoint
, RoundDir
, goldschmidt_sqrt_rsqrt
)
21 class TestFixedPoint(FHDLTestCase
):
def test_str_roundtrip(self):
    """Check that `FixedPoint` values survive a `str` -> `cast` round trip.

    Every combination of `frac_wid` in `range(8)` and `bits` in
    `range(-1 << 9, 1 << 9)` is built into a `FixedPoint`, rendered with
    `str()`, parsed back with `FixedPoint.cast()`, and compared for
    equality with the value it started from.
    """
    all_bits = range(-(1 << 9), 1 << 9)
    for frac_wid in range(8):
        for bits in all_bits:
            # one sub-test per input so a failure reports the exact case
            with self.subTest(bits=hex(bits), frac_wid=frac_wid):
                original = FixedPoint(bits, frac_wid)
                text = str(original)
                parsed = FixedPoint.cast(text)
                self.assertEqual(original, parsed)
34 except (ValueError, ZeroDivisionError) as e
:
35 return None, e
.__class
__.__name
__
38 for frac_wid
in range(8):
39 for bits
in range(1 << 9):
40 for round_dir
in RoundDir
:
41 radicand
= FixedPoint(bits
, frac_wid
)
42 expected_f
= math
.sqrt(float(radicand
))
43 expected
= self
.trap(lambda: FixedPoint
.with_frac_wid(
44 expected_f
, frac_wid
, round_dir
))
45 with self
.subTest(radicand
=repr(radicand
),
46 round_dir
=str(round_dir
),
47 expected
=repr(expected
)):
48 result
= self
.trap(lambda: radicand
.sqrt(round_dir
))
49 self
.assertEqual(result
, expected
)
52 for frac_wid
in range(8):
53 for bits
in range(1, 1 << 9):
54 for round_dir
in RoundDir
:
55 radicand
= FixedPoint(bits
, frac_wid
)
56 expected_f
= 1 / math
.sqrt(float(radicand
))
57 expected
= self
.trap(lambda: FixedPoint
.with_frac_wid(
58 expected_f
, frac_wid
, round_dir
))
59 with self
.subTest(radicand
=repr(radicand
),
60 round_dir
=str(round_dir
),
61 expected
=repr(expected
)):
62 result
= self
.trap(lambda: radicand
.rsqrt(round_dir
))
63 self
.assertEqual(result
, expected
)
66 class TestGoldschmidtDiv(FHDLTestCase
):
68 with self
.assertRaises(ParamsNotAccurateEnough
):
69 GoldschmidtDivParams(io_width
=3, extra_precision
=2,
70 table_addr_bits
=3, table_data_bits
=5,
74 with self
.assertRaises(ParamsNotAccurateEnough
):
75 GoldschmidtDivParams(io_width
=4, extra_precision
=1,
76 table_addr_bits
=1, table_data_bits
=5,
80 def cases(io_width
, cases
=None):
81 assert isinstance(io_width
, int) and io_width
>= 1
84 assert isinstance(d
, int) \
85 and 0 < d
< (1 << io_width
), "invalid case"
86 assert isinstance(n
, int) \
87 and 0 <= n
< (d
<< io_width
), "invalid case"
90 assert io_width
* 2 <= 256, \
91 "can't generate big enough numbers for test cases"
92 for i
in range(10000):
93 d
= hash_256(f
'd {i}') % (1 << io_width
)
96 n
= hash_256(f
'n {i}') % (d
<< io_width
)
99 for d
in range(1, 1 << io_width
):
100 for n
in range(d
<< io_width
):
def tst(self, io_width, cases=None):
    """Compare `goldschmidt_div` against Python's `divmod` for `io_width`.

    `io_width` must be an `int`; it is used to look up the parameters via
    `GoldschmidtDivParams.get` and is forwarded, together with `cases`,
    to `self.cases`, which supplies the `(n, d)` pairs to check.

    For every pair the quotient/remainder returned by
    `goldschmidt_div(n, d, params)` must equal `divmod(n, d)`.
    """
    assert isinstance(io_width, int)
    params = GoldschmidtDivParams.get(io_width)
    with self.subTest(params=str(params)):
        for numerator, denominator in self.cases(io_width, cases):
            expected_q, expected_r = divmod(numerator, denominator)
            # sub-test labels are all hex strings to keep reports readable
            input_labels = dict(
                n=hex(numerator),
                d=hex(denominator),
                expected_q=hex(expected_q),
                expected_r=hex(expected_r),
            )
            with self.subTest(**input_labels):
                quotient, remainder = goldschmidt_div(
                    numerator, denominator, params)
                with self.subTest(q=hex(quotient), r=hex(remainder)):
                    actual = (quotient, remainder)
                    self.assertEqual(actual, (expected_q, expected_r))
116 def tst_sim(self
, io_width
, cases
=None, pipe_reg_indexes
=(),
118 assert isinstance(io_width
, int)
119 params
= GoldschmidtDivParams
.get(io_width
)
121 dut
= GoldschmidtDivHDL(params
, pipe_reg_indexes
=pipe_reg_indexes
,
123 m
.submodules
.dut
= dut
124 # make sync domain get added
125 m
.d
.sync
+= Signal().eq(0)
129 for n
, d
in self
.cases(io_width
, cases
):
134 def check_interals(n
, d
):
135 # check internals only if dut is completely combinatorial
136 # so we don't have to figure out how to read values in
137 # previous clock cycles
138 if dut
.total_pipeline_registers
!= 0:
142 def ref_trace_fn(state
):
143 assert isinstance(state
, GoldschmidtDivState
)
144 ref_trace
.append((replace(state
)))
145 goldschmidt_div(n
=n
, d
=d
, params
=params
, trace
=ref_trace_fn
)
146 self
.assertEqual(len(dut
.trace
), len(ref_trace
))
147 for index
, state
in enumerate(dut
.trace
):
148 ref_state
= ref_trace
[index
]
149 last_op
= None if index
== 0 else params
.ops
[index
- 1]
150 with self
.subTest(index
=index
, state
=repr(state
),
151 ref_state
=repr(ref_state
),
152 last_op
=str(last_op
)):
153 for field
in fields(GoldschmidtDivHDLState
):
154 sig
= getattr(state
, field
.name
)
155 if not isinstance(sig
, Signal
):
157 ref_value
= getattr(ref_state
, field
.name
)
158 ref_value_str
= repr(ref_value
)
159 if isinstance(ref_value
, int):
160 ref_value_str
= hex(ref_value
)
162 with self
.subTest(field_name
=field
.name
,
164 sig_shape
=repr(sig
.shape()),
166 ref_value
=ref_value_str
):
167 if isinstance(ref_value
, int):
168 self
.assertEqual(value
, ref_value
)
170 assert isinstance(ref_value
, FixedPoint
)
171 self
.assertEqual(value
, ref_value
.bits
)
175 for _
in range(dut
.total_pipeline_registers
):
177 for n
, d
in self
.cases(io_width
, cases
):
179 expected_q
, expected_r
= divmod(n
, d
)
180 with self
.subTest(n
=hex(n
), d
=hex(d
),
181 expected_q
=hex(expected_q
),
182 expected_r
=hex(expected_r
)):
185 with self
.subTest(q
=hex(q
), r
=hex(r
)):
186 self
.assertEqual((q
, r
), (expected_q
, expected_r
))
187 yield from check_interals(n
, d
)
191 with self
.subTest(params
=str(params
)):
192 with
do_sim(self
, m
, (dut
.n
, dut
.d
, dut
.q
, dut
.r
)) as sim
:
194 sim
.add_process(inputs_proc
)
195 sim
.add_process(check_outputs
)
198 def test_1_through_4(self
):
199 for io_width
in range(1, 4 + 1):
200 with self
.subTest(io_width
=io_width
):
221 def test_sim_5(self
):
224 def test_sim_8(self
):
227 def test_sim_16(self
):
230 def test_sim_32(self
):
233 def test_sim_64(self
):
236 def tst_params(self
, io_width
):
237 assert isinstance(io_width
, int)
238 params
= GoldschmidtDivParams
.get(io_width
)
242 def test_params_1(self
):
245 def test_params_2(self
):
248 def test_params_3(self
):
251 def test_params_4(self
):
254 def test_params_5(self
):
257 def test_params_6(self
):
260 def test_params_7(self
):
263 def test_params_8(self
):
266 def test_params_9(self
):
269 def test_params_10(self
):
272 def test_params_11(self
):
275 def test_params_12(self
):
278 def test_params_13(self
):
281 def test_params_14(self
):
284 def test_params_15(self
):
287 def test_params_16(self
):
290 def test_params_17(self
):
293 def test_params_18(self
):
296 def test_params_19(self
):
299 def test_params_20(self
):
302 def test_params_21(self
):
305 def test_params_22(self
):
308 def test_params_23(self
):
311 def test_params_24(self
):
314 def test_params_25(self
):
317 def test_params_26(self
):
320 def test_params_27(self
):
323 def test_params_28(self
):
326 def test_params_29(self
):
329 def test_params_30(self
):
332 def test_params_31(self
):
335 def test_params_32(self
):
338 def test_params_33(self
):
341 def test_params_34(self
):
344 def test_params_35(self
):
347 def test_params_36(self
):
350 def test_params_37(self
):
353 def test_params_38(self
):
356 def test_params_39(self
):
359 def test_params_40(self
):
362 def test_params_41(self
):
365 def test_params_42(self
):
368 def test_params_43(self
):
371 def test_params_44(self
):
374 def test_params_45(self
):
377 def test_params_46(self
):
380 def test_params_47(self
):
383 def test_params_48(self
):
386 def test_params_49(self
):
389 def test_params_50(self
):
392 def test_params_51(self
):
395 def test_params_52(self
):
398 def test_params_53(self
):
401 def test_params_54(self
):
404 def test_params_55(self
):
407 def test_params_56(self
):
410 def test_params_57(self
):
413 def test_params_58(self
):
416 def test_params_59(self
):
419 def test_params_60(self
):
422 def test_params_61(self
):
425 def test_params_62(self
):
428 def test_params_63(self
):
431 def test_params_64(self
):
435 class TestGoldschmidtSqrtRSqrt(FHDLTestCase
):
436 def tst(self
, io_width
, frac_wid
, extra_precision
,
437 table_addr_bits
, table_data_bits
, iter_count
):
438 assert isinstance(io_width
, int)
439 assert isinstance(frac_wid
, int)
440 assert isinstance(extra_precision
, int)
441 assert isinstance(table_addr_bits
, int)
442 assert isinstance(table_data_bits
, int)
443 assert isinstance(iter_count
, int)
444 with self
.subTest(io_width
=io_width
, frac_wid
=frac_wid
,
445 extra_precision
=extra_precision
,
446 table_addr_bits
=table_addr_bits
,
447 table_data_bits
=table_data_bits
,
448 iter_count
=iter_count
):
449 for bits
in range(1 << io_width
):
450 radicand
= FixedPoint(bits
, frac_wid
)
451 expected_sqrt
= radicand
.sqrt(RoundDir
.DOWN
)
452 expected_rsqrt
= FixedPoint(0, frac_wid
)
454 expected_rsqrt
= radicand
.rsqrt(RoundDir
.DOWN
)
455 with self
.subTest(radicand
=repr(radicand
),
456 expected_sqrt
=repr(expected_sqrt
),
457 expected_rsqrt
=repr(expected_rsqrt
)):
458 sqrt
, rsqrt
= goldschmidt_sqrt_rsqrt(
459 radicand
=radicand
, io_width
=io_width
,
461 extra_precision
=extra_precision
,
462 table_addr_bits
=table_addr_bits
,
463 table_data_bits
=table_data_bits
,
464 iter_count
=iter_count
)
465 with self
.subTest(sqrt
=repr(sqrt
), rsqrt
=repr(rsqrt
)):
466 self
.assertEqual((sqrt
, rsqrt
),
467 (expected_sqrt
, expected_rsqrt
))
470 self
.tst(io_width
=16, frac_wid
=8, extra_precision
=20,
471 table_addr_bits
=4, table_data_bits
=28, iter_count
=4)
474 if __name__
== "__main__":