1 # SPDX-License-Identifier: LGPL-3-or-later
2 # Copyright 2022 Jacob Lifshay programmerjake@gmail.com
4 # Funded by NLnet Assure Programme 2021-02-052, https://nlnet.nl/assure part
5 # of Horizon 2020 EU Programme 957073.
7 from dataclasses
import fields
, replace
10 from nmutil
.formaltest
import FHDLTestCase
11 from nmutil
.sim_util
import do_sim
12 from nmigen
.sim
import Tick
, Delay
13 from nmigen
.hdl
.ast
import Signal
14 from nmigen
.hdl
.dsl
import Module
15 from soc
.fu
.div
.experiment
.goldschmidt_div_sqrt
import (
16 GoldschmidtDivHDL
, GoldschmidtDivHDLState
, GoldschmidtDivOp
, GoldschmidtDivParams
,
17 GoldschmidtDivState
, ParamsNotAccurateEnough
, goldschmidt_div
,
18 FixedPoint
, RoundDir
, goldschmidt_sqrt_rsqrt
)
21 class TestFixedPoint(FHDLTestCase
):
def test_str_roundtrip(self):
    """Check that ``FixedPoint.cast(str(x)) == x`` across a sweep of values.

    Every ``frac_wid`` in ``range(8)`` is combined with every ``bits`` in
    ``range(-1 << 9, 1 << 9)``; each combination runs in its own sub-test
    so a failure reports the offending ``bits`` (as hex) and ``frac_wid``.
    """
    limit = 1 << 9
    for width in range(8):
        for raw in range(-limit, limit):
            with self.subTest(bits=hex(raw), frac_wid=width):
                original = FixedPoint(raw, width)
                # convert to text and parse it back again
                parsed = FixedPoint.cast(str(original))
                self.assertEqual(original, parsed)
34 except (ValueError, ZeroDivisionError) as e
:
35 return None, e
.__class
__.__name
__
38 for frac_wid
in range(8):
39 for bits
in range(1 << 9):
40 for round_dir
in RoundDir
:
41 radicand
= FixedPoint(bits
, frac_wid
)
42 expected_f
= math
.sqrt(float(radicand
))
43 expected
= self
.trap(lambda: FixedPoint
.with_frac_wid(
44 expected_f
, frac_wid
, round_dir
))
45 with self
.subTest(radicand
=repr(radicand
),
46 round_dir
=str(round_dir
),
47 expected
=repr(expected
)):
48 result
= self
.trap(lambda: radicand
.sqrt(round_dir
))
49 self
.assertEqual(result
, expected
)
52 for frac_wid
in range(8):
53 for bits
in range(1, 1 << 9):
54 for round_dir
in RoundDir
:
55 radicand
= FixedPoint(bits
, frac_wid
)
56 expected_f
= 1 / math
.sqrt(float(radicand
))
57 expected
= self
.trap(lambda: FixedPoint
.with_frac_wid(
58 expected_f
, frac_wid
, round_dir
))
59 with self
.subTest(radicand
=repr(radicand
),
60 round_dir
=str(round_dir
),
61 expected
=repr(expected
)):
62 result
= self
.trap(lambda: radicand
.rsqrt(round_dir
))
63 self
.assertEqual(result
, expected
)
66 class TestGoldschmidtDiv(FHDLTestCase
):
68 with self
.assertRaises(ParamsNotAccurateEnough
):
69 GoldschmidtDivParams(io_width
=3, extra_precision
=2,
70 table_addr_bits
=3, table_data_bits
=5,
74 with self
.assertRaises(ParamsNotAccurateEnough
):
75 GoldschmidtDivParams(io_width
=4, extra_precision
=1,
76 table_addr_bits
=1, table_data_bits
=5,
def tst(self, io_width):
    """Exhaustively compare ``goldschmidt_div`` against ``divmod``.

    Uses the parameters returned by ``GoldschmidtDivParams.get(io_width)``.
    Every divisor ``d`` in ``range(1, 1 << io_width)`` is paired with every
    dividend ``n`` in ``range(d << io_width)``, and the ``(q, r)`` pair
    returned by ``goldschmidt_div(n, d, params)`` must equal
    ``divmod(n, d)``.

    Nested sub-tests record the parameters, the inputs with their expected
    results, and the actual results (all integers shown as hex).
    """
    assert isinstance(io_width, int)
    params = GoldschmidtDivParams.get(io_width)
    divisor_limit = 1 << io_width
    with self.subTest(params=str(params)):
        for divisor in range(1, divisor_limit):
            for dividend in range(divisor << io_width):
                # reference result from Python's integer division
                want_q, want_r = divmod(dividend, divisor)
                with self.subTest(n=hex(dividend), d=hex(divisor),
                                  expected_q=hex(want_q),
                                  expected_r=hex(want_r)):
                    got_q, got_r = goldschmidt_div(
                        dividend, divisor, params)
                    with self.subTest(q=hex(got_q), r=hex(got_r)):
                        self.assertEqual((got_q, got_r),
                                         (want_q, want_r))
93 def tst_sim(self
, io_width
, cases
=None, pipe_reg_indexes
=(),
95 assert isinstance(io_width
, int)
96 params
= GoldschmidtDivParams
.get(io_width
)
98 dut
= GoldschmidtDivHDL(params
, pipe_reg_indexes
=pipe_reg_indexes
,
100 m
.submodules
.dut
= dut
101 # make sync domain get added
102 m
.d
.sync
+= Signal().eq(0)
105 if cases
is not None:
107 assert isinstance(d
, int) \
108 and 0 < d
< (1 << params
.io_width
), "invalid case"
109 assert isinstance(n
, int) \
110 and 0 <= n
< (d
<< params
.io_width
), "invalid case"
113 for d
in range(1, 1 << io_width
):
114 for n
in range(d
<< io_width
):
119 for n
, d
in iter_cases():
124 def check_interals(n
, d
):
125 # check internals only if dut is completely combinatorial
126 # so we don't have to figure out how to read values in
127 # previous clock cycles
128 if dut
.total_pipeline_registers
!= 0:
132 def ref_trace_fn(state
):
133 assert isinstance(state
, GoldschmidtDivState
)
134 ref_trace
.append((replace(state
)))
135 goldschmidt_div(n
=n
, d
=d
, params
=params
, trace
=ref_trace_fn
)
136 self
.assertEqual(len(dut
.trace
), len(ref_trace
))
137 for index
, state
in enumerate(dut
.trace
):
138 ref_state
= ref_trace
[index
]
139 last_op
= None if index
== 0 else params
.ops
[index
- 1]
140 with self
.subTest(index
=index
, state
=repr(state
),
141 ref_state
=repr(ref_state
),
142 last_op
=str(last_op
)):
143 for field
in fields(GoldschmidtDivHDLState
):
144 sig
= getattr(state
, field
.name
)
145 if not isinstance(sig
, Signal
):
147 ref_value
= getattr(ref_state
, field
.name
)
148 ref_value_str
= repr(ref_value
)
149 if isinstance(ref_value
, int):
150 ref_value_str
= hex(ref_value
)
152 with self
.subTest(field_name
=field
.name
,
154 sig_shape
=repr(sig
.shape()),
156 ref_value
=ref_value_str
):
157 if isinstance(ref_value
, int):
158 self
.assertEqual(value
, ref_value
)
160 assert isinstance(ref_value
, FixedPoint
)
161 self
.assertEqual(value
, ref_value
.bits
)
165 for _
in range(dut
.total_pipeline_registers
):
167 for n
, d
in iter_cases():
169 expected_q
, expected_r
= divmod(n
, d
)
170 with self
.subTest(n
=hex(n
), d
=hex(d
),
171 expected_q
=hex(expected_q
),
172 expected_r
=hex(expected_r
)):
175 with self
.subTest(q
=hex(q
), r
=hex(r
)):
176 self
.assertEqual((q
, r
), (expected_q
, expected_r
))
177 yield from check_interals(n
, d
)
181 with self
.subTest(params
=str(params
)):
182 with
do_sim(self
, m
, (dut
.n
, dut
.d
, dut
.q
, dut
.r
)) as sim
:
184 sim
.add_process(inputs_proc
)
185 sim
.add_process(check_outputs
)
188 def test_1_through_4(self
):
189 for io_width
in range(1, 4 + 1):
190 with self
.subTest(io_width
=io_width
):
199 def test_sim_5(self
):
202 def tst_params(self
, io_width
):
203 assert isinstance(io_width
, int)
204 params
= GoldschmidtDivParams
.get(io_width
)
208 def test_params_1(self
):
211 def test_params_2(self
):
214 def test_params_3(self
):
217 def test_params_4(self
):
220 def test_params_5(self
):
223 def test_params_6(self
):
226 def test_params_7(self
):
229 def test_params_8(self
):
232 def test_params_9(self
):
235 def test_params_10(self
):
238 def test_params_11(self
):
241 def test_params_12(self
):
244 def test_params_13(self
):
247 def test_params_14(self
):
250 def test_params_15(self
):
253 def test_params_16(self
):
256 def test_params_17(self
):
259 def test_params_18(self
):
262 def test_params_19(self
):
265 def test_params_20(self
):
268 def test_params_21(self
):
271 def test_params_22(self
):
274 def test_params_23(self
):
277 def test_params_24(self
):
280 def test_params_25(self
):
283 def test_params_26(self
):
286 def test_params_27(self
):
289 def test_params_28(self
):
292 def test_params_29(self
):
295 def test_params_30(self
):
298 def test_params_31(self
):
301 def test_params_32(self
):
304 def test_params_33(self
):
307 def test_params_34(self
):
310 def test_params_35(self
):
313 def test_params_36(self
):
316 def test_params_37(self
):
319 def test_params_38(self
):
322 def test_params_39(self
):
325 def test_params_40(self
):
328 def test_params_41(self
):
331 def test_params_42(self
):
334 def test_params_43(self
):
337 def test_params_44(self
):
340 def test_params_45(self
):
343 def test_params_46(self
):
346 def test_params_47(self
):
349 def test_params_48(self
):
352 def test_params_49(self
):
355 def test_params_50(self
):
358 def test_params_51(self
):
361 def test_params_52(self
):
364 def test_params_53(self
):
367 def test_params_54(self
):
370 def test_params_55(self
):
373 def test_params_56(self
):
376 def test_params_57(self
):
379 def test_params_58(self
):
382 def test_params_59(self
):
385 def test_params_60(self
):
388 def test_params_61(self
):
391 def test_params_62(self
):
394 def test_params_63(self
):
397 def test_params_64(self
):
401 class TestGoldschmidtSqrtRSqrt(FHDLTestCase
):
402 def tst(self
, io_width
, frac_wid
, extra_precision
,
403 table_addr_bits
, table_data_bits
, iter_count
):
404 assert isinstance(io_width
, int)
405 assert isinstance(frac_wid
, int)
406 assert isinstance(extra_precision
, int)
407 assert isinstance(table_addr_bits
, int)
408 assert isinstance(table_data_bits
, int)
409 assert isinstance(iter_count
, int)
410 with self
.subTest(io_width
=io_width
, frac_wid
=frac_wid
,
411 extra_precision
=extra_precision
,
412 table_addr_bits
=table_addr_bits
,
413 table_data_bits
=table_data_bits
,
414 iter_count
=iter_count
):
415 for bits
in range(1 << io_width
):
416 radicand
= FixedPoint(bits
, frac_wid
)
417 expected_sqrt
= radicand
.sqrt(RoundDir
.DOWN
)
418 expected_rsqrt
= FixedPoint(0, frac_wid
)
420 expected_rsqrt
= radicand
.rsqrt(RoundDir
.DOWN
)
421 with self
.subTest(radicand
=repr(radicand
),
422 expected_sqrt
=repr(expected_sqrt
),
423 expected_rsqrt
=repr(expected_rsqrt
)):
424 sqrt
, rsqrt
= goldschmidt_sqrt_rsqrt(
425 radicand
=radicand
, io_width
=io_width
,
427 extra_precision
=extra_precision
,
428 table_addr_bits
=table_addr_bits
,
429 table_data_bits
=table_data_bits
,
430 iter_count
=iter_count
)
431 with self
.subTest(sqrt
=repr(sqrt
), rsqrt
=repr(rsqrt
)):
432 self
.assertEqual((sqrt
, rsqrt
),
433 (expected_sqrt
, expected_rsqrt
))
436 self
.tst(io_width
=16, frac_wid
=8, extra_precision
=20,
437 table_addr_bits
=4, table_data_bits
=28, iter_count
=4)
440 if __name__
== "__main__":