1 # SPDX-License-Identifier: LGPL-3-or-later
2 # Copyright 2022 Jacob Lifshay programmerjake@gmail.com
4 # Funded by NLnet Assure Programme 2021-02-052, https://nlnet.nl/assure part
5 # of Horizon 2020 EU Programme 957073.
7 from collections
import defaultdict
11 from fractions
import Fraction
12 from types
import FunctionType
13 from functools
import lru_cache
14 from nmigen
.hdl
.ast
import Signal
, unsigned
, signed
, Const
15 from nmigen
.hdl
.dsl
import Module
, Elaboratable
16 from nmigen
.hdl
.mem
import Memory
17 from nmutil
.clz
import CLZ
18 from nmutil
.plain_data
import plain_data
, fields
, replace
21 from functools
import cached_property
23 from cached_property
import cached_property
25 # fix broken IDE type detection for cached_property
26 from typing
import TYPE_CHECKING
, Any
28 from functools
import cached_property
34 def cache_on_self(func
):
35 """like `functools.cached_property`, except for methods. unlike
36 `lru_cache` the cache is per-class instance rather than a global cache
39 assert isinstance(func
, FunctionType
), \
40 "non-plain methods are not supported"
42 cache_name
= func
.__name
__ + "__cache"
44 def wrapper(self
, *args
, **kwargs
):
45 # specifically access through `__dict__` to bypass frozen=True
46 cache
= self
.__dict
__.get(cache_name
, _NOT_FOUND
)
47 if cache
is _NOT_FOUND
:
48 self
.__dict
__[cache_name
] = cache
= {}
49 key
= (args
, *kwargs
.items())
50 retval
= cache
.get(key
, _NOT_FOUND
)
51 if retval
is _NOT_FOUND
:
52 retval
= func(self
, *args
, **kwargs
)
56 wrapper
.__doc
__ = func
.__doc
__
61 class RoundDir(enum
.Enum
):
64 NEAREST_TIES_UP
= enum
.auto()
65 ERROR_IF_INEXACT
= enum
.auto()
68 @plain_data(frozen
=True, eq
=False, repr=False)
70 __slots__
= "bits", "frac_wid"
72 def __init__(self
, bits
, frac_wid
):
74 self
.frac_wid
= frac_wid
75 assert isinstance(self
.bits
, int)
76 assert isinstance(self
.frac_wid
, int) and self
.frac_wid
>= 0
80 """convert `value` to a fixed-point number with enough fractional
81 bits to preserve its value."""
82 if isinstance(value
, FixedPoint
):
84 if isinstance(value
, int):
85 return FixedPoint(value
, 0)
86 if isinstance(value
, str):
88 neg
= value
.startswith("-")
89 if neg
or value
.startswith("+"):
91 if value
.startswith(("0x", "0X")) and "." in value
:
101 raise ValueError("too many `.` in string")
106 if not digit
.isalnum():
107 raise ValueError("invalid hexadecimal digit")
109 bits |
= int("0x" + digit
, base
=16)
111 bits
= int(value
, base
=0)
115 return FixedPoint(bits
, frac_wid
)
117 if isinstance(value
, float):
118 n
, d
= value
.as_integer_ratio()
119 log2_d
= d
.bit_length() - 1
120 assert d
== 1 << log2_d
, ("d isn't a power of 2 -- won't ever "
121 "fail with float being IEEE 754")
122 return FixedPoint(n
, log2_d
)
123 raise TypeError("can't convert type to FixedPoint")
126 def with_frac_wid(value
, frac_wid
, round_dir
=RoundDir
.ERROR_IF_INEXACT
):
127 """convert `value` to the nearest fixed-point number with `frac_wid`
128 fractional bits, rounding according to `round_dir`."""
129 assert isinstance(frac_wid
, int) and frac_wid
>= 0
130 assert isinstance(round_dir
, RoundDir
)
131 if isinstance(value
, Fraction
):
132 numerator
= value
.numerator
133 denominator
= value
.denominator
135 value
= FixedPoint
.cast(value
)
136 numerator
= value
.bits
137 denominator
= 1 << value
.frac_wid
139 numerator
= -numerator
140 denominator
= -denominator
141 bits
, remainder
= divmod(numerator
<< frac_wid
, denominator
)
142 if round_dir
== RoundDir
.DOWN
:
144 elif round_dir
== RoundDir
.UP
:
147 elif round_dir
== RoundDir
.NEAREST_TIES_UP
:
148 if remainder
* 2 >= denominator
:
150 elif round_dir
== RoundDir
.ERROR_IF_INEXACT
:
152 raise ValueError("inexact conversion")
154 assert False, "unimplemented round_dir"
155 return FixedPoint(bits
, frac_wid
)
157 def to_frac_wid(self
, frac_wid
, round_dir
=RoundDir
.ERROR_IF_INEXACT
):
158 """convert to the nearest fixed-point number with `frac_wid`
159 fractional bits, rounding according to `round_dir`."""
160 return FixedPoint
.with_frac_wid(self
, frac_wid
, round_dir
)
163 # use truediv to get correct result even when bits
164 # and frac_wid are huge
165 return float(self
.bits
/ (1 << self
.frac_wid
))
167 def as_fraction(self
):
168 return Fraction(self
.bits
, 1 << self
.frac_wid
)
171 """compare self with rhs, returning a positive integer if self is
172 greater than rhs, zero if self is equal to rhs, and a negative integer
173 if self is less than rhs."""
174 rhs
= FixedPoint
.cast(rhs
)
175 common_frac_wid
= max(self
.frac_wid
, rhs
.frac_wid
)
176 lhs
= self
.to_frac_wid(common_frac_wid
)
177 rhs
= rhs
.to_frac_wid(common_frac_wid
)
178 return lhs
.bits
- rhs
.bits
180 def __eq__(self
, rhs
):
181 return self
.cmp(rhs
) == 0
183 def __ne__(self
, rhs
):
184 return self
.cmp(rhs
) != 0
186 def __gt__(self
, rhs
):
187 return self
.cmp(rhs
) > 0
189 def __lt__(self
, rhs
):
190 return self
.cmp(rhs
) < 0
192 def __ge__(self
, rhs
):
193 return self
.cmp(rhs
) >= 0
195 def __le__(self
, rhs
):
196 return self
.cmp(rhs
) <= 0
199 """return the fractional part of `self`.
200 that is `self - math.floor(self)`.
202 fract_mask
= (1 << self
.frac_wid
) - 1
203 return FixedPoint(self
.bits
& fract_mask
, self
.frac_wid
)
207 return "-" + str(-self
)
209 frac_digit_count
= (self
.frac_wid
+ digit_bits
- 1) // digit_bits
210 fract
= self
.fract().to_frac_wid(frac_digit_count
* digit_bits
)
211 frac_str
= hex(fract
.bits
)[2:].zfill(frac_digit_count
)
212 return hex(math
.floor(self
)) + "." + frac_str
215 return f
"FixedPoint.with_frac_wid({str(self)!r}, {self.frac_wid})"
217 def __add__(self
, rhs
):
218 rhs
= FixedPoint
.cast(rhs
)
219 common_frac_wid
= max(self
.frac_wid
, rhs
.frac_wid
)
220 lhs
= self
.to_frac_wid(common_frac_wid
)
221 rhs
= rhs
.to_frac_wid(common_frac_wid
)
222 return FixedPoint(lhs
.bits
+ rhs
.bits
, common_frac_wid
)
224 def __radd__(self
, lhs
):
226 return self
.__add
__(lhs
)
229 return FixedPoint(-self
.bits
, self
.frac_wid
)
231 def __sub__(self
, rhs
):
232 rhs
= FixedPoint
.cast(rhs
)
233 common_frac_wid
= max(self
.frac_wid
, rhs
.frac_wid
)
234 lhs
= self
.to_frac_wid(common_frac_wid
)
235 rhs
= rhs
.to_frac_wid(common_frac_wid
)
236 return FixedPoint(lhs
.bits
- rhs
.bits
, common_frac_wid
)
238 def __rsub__(self
, lhs
):
240 return -self
.__sub
__(lhs
)
242 def __mul__(self
, rhs
):
243 rhs
= FixedPoint
.cast(rhs
)
244 return FixedPoint(self
.bits
* rhs
.bits
, self
.frac_wid
+ rhs
.frac_wid
)
246 def __rmul__(self
, lhs
):
248 return self
.__mul
__(lhs
)
251 return self
.bits
>> self
.frac_wid
253 def div(self
, rhs
, frac_wid
, round_dir
=RoundDir
.ERROR_IF_INEXACT
):
254 assert isinstance(frac_wid
, int) and frac_wid
>= 0
255 assert isinstance(round_dir
, RoundDir
)
256 rhs
= FixedPoint
.cast(rhs
)
257 return FixedPoint
.with_frac_wid(self
.as_fraction()
261 def sqrt(self
, round_dir
=RoundDir
.ERROR_IF_INEXACT
):
262 assert isinstance(round_dir
, RoundDir
)
264 raise ValueError("can't compute sqrt of negative number")
267 retval
= FixedPoint(0, self
.frac_wid
)
268 int_part_wid
= self
.bits
.bit_length() - self
.frac_wid
269 first_bit_index
= -(-int_part_wid
// 2) # division rounds up
270 last_bit_index
= -self
.frac_wid
271 for bit_index
in range(first_bit_index
, last_bit_index
- 1, -1):
272 trial
= retval
+ FixedPoint(1 << (bit_index
+ self
.frac_wid
),
274 if trial
* trial
<= self
:
276 if round_dir
== RoundDir
.DOWN
:
278 elif round_dir
== RoundDir
.UP
:
279 if retval
* retval
< self
:
280 retval
+= FixedPoint(1, self
.frac_wid
)
281 elif round_dir
== RoundDir
.NEAREST_TIES_UP
:
282 half_way
= retval
+ FixedPoint(1, self
.frac_wid
+ 1)
283 if half_way
* half_way
<= self
:
284 retval
+= FixedPoint(1, self
.frac_wid
)
285 elif round_dir
== RoundDir
.ERROR_IF_INEXACT
:
286 if retval
* retval
!= self
:
287 raise ValueError("inexact sqrt")
289 assert False, "unimplemented round_dir"
292 def rsqrt(self
, round_dir
=RoundDir
.ERROR_IF_INEXACT
):
293 """compute the reciprocal-sqrt of `self`"""
294 assert isinstance(round_dir
, RoundDir
)
296 raise ValueError("can't compute rsqrt of negative number")
298 raise ZeroDivisionError("can't compute rsqrt of zero")
299 retval
= FixedPoint(0, self
.frac_wid
)
300 first_bit_index
= -(-self
.frac_wid
// 2) # division rounds up
301 last_bit_index
= -self
.frac_wid
302 for bit_index
in range(first_bit_index
, last_bit_index
- 1, -1):
303 trial
= retval
+ FixedPoint(1 << (bit_index
+ self
.frac_wid
),
305 if trial
* trial
* self
<= 1:
307 if round_dir
== RoundDir
.DOWN
:
309 elif round_dir
== RoundDir
.UP
:
310 if retval
* retval
* self
< 1:
311 retval
+= FixedPoint(1, self
.frac_wid
)
312 elif round_dir
== RoundDir
.NEAREST_TIES_UP
:
313 half_way
= retval
+ FixedPoint(1, self
.frac_wid
+ 1)
314 if half_way
* half_way
* self
<= 1:
315 retval
+= FixedPoint(1, self
.frac_wid
)
316 elif round_dir
== RoundDir
.ERROR_IF_INEXACT
:
317 if retval
* retval
* self
!= 1:
318 raise ValueError("inexact rsqrt")
320 assert False, "unimplemented round_dir"
324 class ParamsNotAccurateEnough(Exception):
325 """raised when the parameters aren't accurate enough to have goldschmidt
329 def _assert_accuracy(condition
, msg
="not accurate enough"):
332 raise ParamsNotAccurateEnough(msg
)
335 @plain_data(frozen
=True, unsafe_hash
=True)
336 class GoldschmidtDivParamsBase
:
337 """parameters for a Goldschmidt division algorithm, excluding derived
341 __slots__
= ("io_width", "extra_precision", "table_addr_bits",
342 "table_data_bits", "iter_count")
344 def __init__(self
, io_width
, extra_precision
, table_addr_bits
,
345 table_data_bits
, iter_count
):
346 assert isinstance(io_width
, int)
347 assert isinstance(extra_precision
, int)
348 assert isinstance(table_addr_bits
, int)
349 assert isinstance(table_data_bits
, int)
350 assert isinstance(iter_count
, int)
351 self
.io_width
= io_width
352 """bit-width of the input divisor and the result.
353 the input numerator is `2 * io_width`-bits wide.
356 self
.extra_precision
= extra_precision
357 """number of bits of additional precision used inside the algorithm."""
359 self
.table_addr_bits
= table_addr_bits
360 """the number of address bits used in the lookup-table."""
362 self
.table_data_bits
= table_data_bits
363 """the number of data bits used in the lookup-table."""
365 self
.iter_count
= iter_count
366 """the total number of iterations of the division algorithm's loop"""
369 @plain_data(frozen
=True, unsafe_hash
=True)
370 class GoldschmidtDivParams(GoldschmidtDivParamsBase
):
371 """parameters for a Goldschmidt division algorithm.
372 Use `GoldschmidtDivParams.get` to find a efficient set of parameters.
375 __slots__
= "table", "ops"
377 def _shrink_bound(self
, bound
, round_dir
):
378 """prevent fractions from having huge numerators/denominators by
379 rounding to a `FixedPoint` and converting back to a `Fraction`.
381 This is intended only for values used to compute bounds, and not for
382 values that end up in the hardware.
384 assert isinstance(bound
, (Fraction
, int))
385 assert round_dir
is RoundDir
.DOWN
or round_dir
is RoundDir
.UP
, \
386 "you shouldn't use that round_dir on bounds"
387 frac_wid
= self
.io_width
* 4 + 100 # should be enough precision
388 fixed
= FixedPoint
.with_frac_wid(bound
, frac_wid
, round_dir
)
389 return fixed
.as_fraction()
391 def _shrink_min(self
, min_bound
):
392 """prevent fractions used as minimum bounds from having huge
393 numerators/denominators by rounding down to a `FixedPoint` and
394 converting back to a `Fraction`.
396 This is intended only for values used to compute bounds, and not for
397 values that end up in the hardware.
399 return self
._shrink
_bound
(min_bound
, RoundDir
.DOWN
)
401 def _shrink_max(self
, max_bound
):
402 """prevent fractions used as maximum bounds from having huge
403 numerators/denominators by rounding up to a `FixedPoint` and
404 converting back to a `Fraction`.
406 This is intended only for values used to compute bounds, and not for
407 values that end up in the hardware.
409 return self
._shrink
_bound
(max_bound
, RoundDir
.UP
)
412 def table_addr_count(self
):
413 """number of distinct addresses in the lookup-table."""
414 # used while computing self.table, so can't just do len(self.table)
415 return 1 << self
.table_addr_bits
417 def table_input_exact_range(self
, addr
):
418 """return the range of inputs as `Fraction`s used for the table entry
419 with address `addr`."""
420 assert isinstance(addr
, int)
421 assert 0 <= addr
< self
.table_addr_count
422 _assert_accuracy(self
.io_width
>= self
.table_addr_bits
)
423 addr_shift
= self
.io_width
- self
.table_addr_bits
424 min_numerator
= (1 << self
.io_width
) + (addr
<< addr_shift
)
425 denominator
= 1 << self
.io_width
426 values_per_table_entry
= 1 << addr_shift
427 max_numerator
= min_numerator
+ values_per_table_entry
- 1
428 min_input
= Fraction(min_numerator
, denominator
)
429 max_input
= Fraction(max_numerator
, denominator
)
430 min_input
= self
._shrink
_min
(min_input
)
431 max_input
= self
._shrink
_max
(max_input
)
432 assert 1 <= min_input
<= max_input
< 2
433 return min_input
, max_input
435 def table_value_exact_range(self
, addr
):
436 """return the range of values as `Fraction`s used for the table entry
437 with address `addr`."""
438 min_input
, max_input
= self
.table_input_exact_range(addr
)
439 # division swaps min/max
440 min_value
= 1 / max_input
441 max_value
= 1 / min_input
442 min_value
= self
._shrink
_min
(min_value
)
443 max_value
= self
._shrink
_max
(max_value
)
444 assert 0.5 < min_value
<= max_value
<= 1
445 return min_value
, max_value
447 def table_exact_value(self
, index
):
448 min_value
, max_value
= self
.table_value_exact_range(index
)
452 def __init__(self
, io_width
, extra_precision
, table_addr_bits
,
453 table_data_bits
, iter_count
):
454 super().__init
__(io_width
=io_width
,
455 extra_precision
=extra_precision
,
456 table_addr_bits
=table_addr_bits
,
457 table_data_bits
=table_data_bits
,
458 iter_count
=iter_count
)
459 _assert_accuracy(self
.io_width
>= 1, "io_width out of range")
460 _assert_accuracy(self
.extra_precision
>= 0,
461 "extra_precision out of range")
462 _assert_accuracy(self
.table_addr_bits
>= 1,
463 "table_addr_bits out of range")
464 _assert_accuracy(self
.table_data_bits
>= 1,
465 "table_data_bits out of range")
466 _assert_accuracy(self
.iter_count
>= 1, "iter_count out of range")
468 for addr
in range(1 << self
.table_addr_bits
):
469 table
.append(FixedPoint
.with_frac_wid(self
.table_exact_value(addr
),
470 self
.table_data_bits
,
473 self
.table
= tuple(table
)
474 """ the lookup-table.
475 type: tuple[FixedPoint, ...]
478 self
.ops
= tuple(self
.__make
_ops
())
479 "the operations needed to perform the goldschmidt division algorithm."
482 def expanded_width(self
):
483 """the total number of bits of precision used inside the algorithm."""
484 return self
.io_width
+ self
.extra_precision
487 def n_d_f_int_wid(self
):
488 """the number of bits in the integer part of `state.n`, `state.d`, and
489 `state.f` during the main iteration loop.
494 def n_d_f_total_wid(self
):
495 """the total number of bits (both integer and fraction bits) in
496 `state.n`, `state.d`, and `state.f` during the main iteration loop.
498 return self
.n_d_f_int_wid
+ self
.expanded_width
501 def max_neps(self
, i
):
502 """maximum value of `neps[i]`.
503 `neps[i]` is defined to be `n[i] * N_prime[i - 1] * F_prime[i - 1]`.
505 assert isinstance(i
, int) and 0 <= i
< self
.iter_count
506 return Fraction(1, 1 << self
.expanded_width
)
509 def max_deps(self
, i
):
510 """maximum value of `deps[i]`.
511 `deps[i]` is defined to be `d[i] * D_prime[i - 1] * F_prime[i - 1]`.
513 assert isinstance(i
, int) and 0 <= i
< self
.iter_count
514 return Fraction(1, 1 << self
.expanded_width
)
517 def max_feps(self
, i
):
518 """maximum value of `feps[i]`.
519 `feps[i]` is defined to be `f[i] * (2 - D_prime[i - 1])`.
521 assert isinstance(i
, int) and 0 <= i
< self
.iter_count
522 # zero, because the computation of `F_prime[i]` in
523 # `GoldschmidtDivOp.MulDByF.run(...)` is exact.
528 """minimum and maximum values of `e[0]`
529 (the relative error in `F_prime[-1]`)
533 for addr
in range(self
.table_addr_count
):
534 # `F_prime[-1] = (1 - e[0]) / B`
535 # => `e[0] = 1 - B * F_prime[-1]`
536 min_b
, max_b
= self
.table_input_exact_range(addr
)
537 f_prime_m1
= self
.table
[addr
].as_fraction()
538 assert min_b
>= 0 and f_prime_m1
>= 0, \
539 "only positive quadrant of interval multiplication implemented"
540 min_product
= min_b
* f_prime_m1
541 max_product
= max_b
* f_prime_m1
542 # negation swaps min/max
543 cur_min_e0
= 1 - max_product
544 cur_max_e0
= 1 - min_product
545 min_e0
= min(min_e0
, cur_min_e0
)
546 max_e0
= max(max_e0
, cur_max_e0
)
547 min_e0
= self
._shrink
_min
(min_e0
)
548 max_e0
= self
._shrink
_max
(max_e0
)
549 return min_e0
, max_e0
553 """minimum value of `e[0]` (the relative error in `F_prime[-1]`)
555 min_e0
, max_e0
= self
.e0_range
560 """maximum value of `e[0]` (the relative error in `F_prime[-1]`)
562 min_e0
, max_e0
= self
.e0_range
566 def max_abs_e0(self
):
567 """maximum value of `abs(e[0])`."""
568 return max(abs(self
.min_e0
), abs(self
.max_e0
))
571 def min_abs_e0(self
):
572 """minimum value of `abs(e[0])`."""
577 """maximum value of `n[i]` (the relative error in `N_prime[i]`
578 relative to the previous iteration)
580 assert isinstance(i
, int) and 0 <= i
< self
.iter_count
583 # `n[0] = neps[0] / ((1 - e[0]) * (A / B))`
584 # `n[0] <= 2 * neps[0] / (1 - e[0])`
586 assert self
.max_e0
< 1 and self
.max_neps(0) >= 0, \
587 "only one quadrant of interval division implemented"
588 retval
= 2 * self
.max_neps(0) / (1 - self
.max_e0
)
591 # `n[1] <= neps[1] / ((1 - f[0]) * (1 - pi[0] - delta[0]))`
592 min_mpd
= 1 - self
.max_pi(0) - self
.max_delta(0)
593 assert self
.max_f(0) <= 1 and min_mpd
>= 0, \
594 "only one quadrant of interval multiplication implemented"
595 prod
= (1 - self
.max_f(0)) * min_mpd
596 assert self
.max_neps(1) >= 0 and prod
> 0, \
597 "only one quadrant of interval division implemented"
598 retval
= self
.max_neps(1) / prod
601 # `0 <= n[i] <= 2 * max_neps[i] / (1 - pi[i - 1] - delta[i - 1])`
602 min_mpd
= 1 - self
.max_pi(i
- 1) - self
.max_delta(i
- 1)
603 assert self
.max_neps(i
) >= 0 and min_mpd
> 0, \
604 "only one quadrant of interval division implemented"
605 retval
= self
.max_neps(i
) / min_mpd
607 return self
._shrink
_max
(retval
)
611 """maximum value of `d[i]` (the relative error in `D_prime[i]`
612 relative to the previous iteration)
614 assert isinstance(i
, int) and 0 <= i
< self
.iter_count
617 # `d[0] = deps[0] / (1 - e[0])`
619 assert self
.max_e0
< 1 and self
.max_deps(0) >= 0, \
620 "only one quadrant of interval division implemented"
621 retval
= self
.max_deps(0) / (1 - self
.max_e0
)
624 # `d[1] <= deps[1] / ((1 - f[0]) * (1 - delta[0] ** 2))`
625 assert self
.max_f(0) <= 1 and self
.max_delta(0) <= 1, \
626 "only one quadrant of interval multiplication implemented"
627 divisor
= (1 - self
.max_f(0)) * (1 - self
.max_delta(0) ** 2)
628 assert self
.max_deps(1) >= 0 and divisor
> 0, \
629 "only one quadrant of interval division implemented"
630 retval
= self
.max_deps(1) / divisor
633 # `0 <= d[i] <= max_deps[i] / (1 - delta[i - 1])`
634 assert self
.max_deps(i
) >= 0 and self
.max_delta(i
- 1) < 1, \
635 "only one quadrant of interval division implemented"
636 retval
= self
.max_deps(i
) / (1 - self
.max_delta(i
- 1))
638 return self
._shrink
_max
(retval
)
642 """maximum value of `f[i]` (the relative error in `F_prime[i]`
643 relative to the previous iteration)
645 assert isinstance(i
, int) and 0 <= i
< self
.iter_count
648 # `f[0] = feps[0] / (1 - delta[0])`
650 assert self
.max_delta(0) < 1 and self
.max_feps(0) >= 0, \
651 "only one quadrant of interval division implemented"
652 retval
= self
.max_feps(0) / (1 - self
.max_delta(0))
656 retval
= self
.max_feps(1)
659 # `f[i] <= max_feps[i]`
660 retval
= self
.max_feps(i
)
662 return self
._shrink
_max
(retval
)
665 def max_delta(self
, i
):
666 """ maximum value of `delta[i]`.
667 `delta[i]` is defined in Definition 4 of paper.
669 assert isinstance(i
, int) and 0 <= i
< self
.iter_count
671 # `delta[0] = abs(e[0]) + 3 * d[0] / 2`
672 retval
= self
.max_abs_e0
+ Fraction(3, 2) * self
.max_d(0)
674 # `delta[i] = delta[i - 1] ** 2 + f[i - 1]`
675 prev_max_delta
= self
.max_delta(i
- 1)
676 assert prev_max_delta
>= 0
677 retval
= prev_max_delta
** 2 + self
.max_f(i
- 1)
679 # `delta[i]` has to be smaller than one otherwise errors would go off
681 _assert_accuracy(retval
< 1)
683 return self
._shrink
_max
(retval
)
687 """ maximum value of `pi[i]`.
688 `pi[i]` is defined right below Theorem 5 of paper.
690 assert isinstance(i
, int) and 0 <= i
< self
.iter_count
691 # `pi[i] = 1 - (1 - n[i]) * prod`
692 # where `prod` is the product of,
693 # for `j` in `0 <= j < i`, `(1 - n[j]) / (1 + d[j])`
694 min_prod
= Fraction(1)
696 max_n_j
= self
.max_n(j
)
697 max_d_j
= self
.max_d(j
)
698 assert max_n_j
<= 1 and max_d_j
> -1, \
699 "only one quadrant of interval division implemented"
700 min_prod
*= (1 - max_n_j
) / (1 + max_d_j
)
701 max_n_i
= self
.max_n(i
)
702 assert max_n_i
<= 1 and min_prod
>= 0, \
703 "only one quadrant of interval multiplication implemented"
704 retval
= 1 - (1 - max_n_i
) * min_prod
705 return self
._shrink
_max
(retval
)
708 def max_n_shift(self
):
709 """ maximum value of `state.n_shift`.
711 # numerator must be less than `denominator << self.io_width`, so
712 # `n_shift` is at most `self.io_width`
717 """ maximum value of, for all `i`, `max_n(i)` and `max_d(i)`
720 for i
in range(self
.iter_count
):
721 n_hat
= max(n_hat
, self
.max_n(i
), self
.max_d(i
))
722 return self
._shrink
_max
(n_hat
)
724 def __make_ops(self
):
725 """ Goldschmidt division algorithm.
728 Even, G., Seidel, P. M., & Ferguson, W. E. (2003).
729 A Parametric Error Analysis of Goldschmidt's Division Algorithm.
730 https://citeseerx.ist.psu.edu/viewdoc/download?doi=10.1.1.90.1238&rep=rep1&type=pdf
732 yields: GoldschmidtDivOp
733 the operations needed to perform the division.
735 # establish assumptions of the paper's error analysis (section 3.1):
737 # 1. normalize so A (numerator) and B (denominator) are in [1, 2)
738 yield GoldschmidtDivOp
.Normalize
740 # 2. ensure all relative errors from directed rounding are <= 1 / 4.
741 # the assumption is met by multipliers with > 4-bits precision
742 _assert_accuracy(self
.expanded_width
> 4)
744 # 3. require `abs(e[0]) + 3 * d[0] / 2 + f[0] < 1 / 2`.
745 _assert_accuracy(self
.max_abs_e0
+ 3 * self
.max_d(0) / 2
746 + self
.max_f(0) < Fraction(1, 2))
748 # 4. the initial approximation F'[-1] of 1/B is in [1/2, 1].
749 # (B is the denominator)
751 for addr
in range(self
.table_addr_count
):
752 f_prime_m1
= self
.table
[addr
]
753 _assert_accuracy(0.5 <= f_prime_m1
<= 1)
755 yield GoldschmidtDivOp
.FEqTableLookup
757 # we use Setting I (section 4.1 of the paper):
758 # Require `n[i] <= n_hat` and `d[i] <= n_hat` and `f[i] = 0`:
759 # the conditions on n_hat are satisfied by construction.
760 for i
in range(self
.iter_count
):
761 _assert_accuracy(self
.max_f(i
) == 0)
762 yield GoldschmidtDivOp
.MulNByF
763 if i
!= self
.iter_count
- 1:
764 yield GoldschmidtDivOp
.MulDByF
765 yield GoldschmidtDivOp
.FEq2MinusD
767 # relative approximation error `p(N_prime[i])`:
768 # `p(N_prime[i]) = (A / B - N_prime[i]) / (A / B)`
769 # `0 <= p(N_prime[i])`
770 # `p(N_prime[i]) <= (2 * i) * n_hat \`
771 # ` + (abs(e[0]) + 3 * n_hat / 2) ** (2 ** i)`
772 i
= self
.iter_count
- 1 # last used `i`
773 # compute power manually to prevent huge intermediate values
774 power
= self
._shrink
_max
(self
.max_abs_e0
+ 3 * self
.n_hat
/ 2)
776 power
= self
._shrink
_max
(power
* power
)
778 max_rel_error
= (2 * i
) * self
.n_hat
+ power
780 min_a_over_b
= Fraction(1, 2)
781 min_abs_error_for_correctness
= min_a_over_b
/ (1 << self
.max_n_shift
)
782 min_rel_error_for_correctness
= (min_abs_error_for_correctness
786 max_rel_error
< min_rel_error_for_correctness
,
787 f
"not accurate enough: max_rel_error={max_rel_error}"
788 f
" min_rel_error_for_correctness={min_rel_error_for_correctness}")
790 yield GoldschmidtDivOp
.CalcResult
793 def default_cost_fn(self
):
794 """ calculate the estimated cost on an arbitrary scale of implementing
795 goldschmidt division with the specified parameters. larger cost
796 values mean worse parameters.
798 This is the default cost function for `GoldschmidtDivParams.get`.
802 rom_cells
= self
.table_data_bits
<< self
.table_addr_bits
803 cost
= float(rom_cells
)
805 if op
== GoldschmidtDivOp
.MulNByF \
806 or op
== GoldschmidtDivOp
.MulDByF
:
807 mul_cost
= self
.expanded_width
** 2
808 mul_cost
*= self
.expanded_width
.bit_length()
810 cost
+= 5e7
* self
.iter_count
814 @lru_cache(maxsize
=1 << 16)
815 def __cached_new(base_params
):
816 assert isinstance(base_params
, GoldschmidtDivParamsBase
)
818 for field
in fields(GoldschmidtDivParamsBase
):
819 kwargs
[field
] = getattr(base_params
, field
)
821 return GoldschmidtDivParams(**kwargs
), None
822 except ParamsNotAccurateEnough
as e
:
826 def __raise(e
): # type: (ParamsNotAccurateEnough) -> Any
830 def cached_new(base_params
, handle_error
=__raise
):
831 assert isinstance(base_params
, GoldschmidtDivParamsBase
)
832 params
, error
= GoldschmidtDivParams
.__cached
_new
(base_params
)
836 return handle_error(error
)
839 def get(io_width
, cost_fn
=default_cost_fn
, max_table_addr_bits
=12):
840 """ find efficient parameters for a goldschmidt division algorithm
841 with `params.io_width == io_width`.
845 bit-width of the input divisor and the result.
846 the input numerator is `2 * io_width`-bits wide.
847 cost_fn: Callable[[GoldschmidtDivParams], float]
848 return the estimated cost on an arbitrary scale of implementing
849 goldschmidt division with the specified parameters. larger cost
850 values mean worse parameters.
851 max_table_addr_bits: int
852 maximum allowable value of `table_addr_bits`
854 assert isinstance(io_width
, int) and io_width
>= 1
855 assert callable(cost_fn
)
858 last_error_params
= None
860 def cached_new(base_params
):
862 nonlocal last_error
, last_error_params
864 last_error_params
= base_params
867 retval
= GoldschmidtDivParams
.cached_new(base_params
, handle_error
)
869 logging
.debug(f
"GoldschmidtDivParams.get: err: {base_params}")
871 logging
.debug(f
"GoldschmidtDivParams.get: ok: {base_params}")
874 @lru_cache(maxsize
=None)
875 def get_cost(base_params
):
876 params
= cached_new(base_params
)
879 retval
= cost_fn(params
)
880 logging
.debug(f
"GoldschmidtDivParams.get: cost={retval}: {params}")
883 # start with parameters big enough to always work.
884 initial_extra_precision
= io_width
* 2 + 4
885 initial_params
= GoldschmidtDivParamsBase(
887 extra_precision
=initial_extra_precision
,
888 table_addr_bits
=min(max_table_addr_bits
, io_width
),
889 table_data_bits
=io_width
+ initial_extra_precision
,
890 iter_count
=1 + io_width
.bit_length())
892 if cached_new(initial_params
) is None:
893 raise ValueError(f
"initial goldschmidt division algorithm "
894 f
"parameters are invalid: {initial_params}"
897 # find good initial `iter_count`
898 params
= initial_params
899 for iter_count
in range(1, initial_params
.iter_count
):
900 trial_params
= replace(params
, iter_count
=iter_count
)
901 if cached_new(trial_params
) is not None:
902 params
= trial_params
905 # now find `table_addr_bits`
906 cost
= get_cost(params
)
907 for table_addr_bits
in range(1, max_table_addr_bits
):
908 trial_params
= replace(params
, table_addr_bits
=table_addr_bits
)
909 trial_cost
= get_cost(trial_params
)
910 if trial_cost
< cost
:
911 params
= trial_params
915 # check one higher `iter_count` to see if it has lower cost
916 for table_addr_bits
in range(1, max_table_addr_bits
+ 1):
917 trial_params
= replace(params
,
918 table_addr_bits
=table_addr_bits
,
919 iter_count
=params
.iter_count
+ 1)
920 trial_cost
= get_cost(trial_params
)
921 if trial_cost
< cost
:
922 params
= trial_params
926 # now shrink `table_data_bits`
928 trial_params
= replace(params
,
929 table_data_bits
=params
.table_data_bits
- 1)
930 trial_cost
= get_cost(trial_params
)
931 if trial_cost
< cost
:
932 params
= trial_params
937 # and shrink `extra_precision`
939 trial_params
= replace(params
,
940 extra_precision
=params
.extra_precision
- 1)
941 trial_cost
= get_cost(trial_params
)
942 if trial_cost
< cost
:
943 params
= trial_params
948 retval
= cached_new(params
)
949 assert isinstance(retval
, GoldschmidtDivParams
)
954 """count leading zeros -- handy for debugging."""
955 assert isinstance(wid
, int)
956 assert isinstance(v
, int) and 0 <= v
< (1 << wid
)
957 return (1 << wid
).bit_length() - v
.bit_length()
961 class GoldschmidtDivOp(enum
.Enum
):
962 Normalize
= "n, d, n_shift = normalize(n, d)"
963 FEqTableLookup
= "f = table_lookup(d)"
966 FEq2MinusD
= "f = 2 - d"
967 CalcResult
= "result = unnormalize_and_round(n)"
969 def run(self
, params
, state
):
970 assert isinstance(params
, GoldschmidtDivParams
)
971 assert isinstance(state
, GoldschmidtDivState
)
972 expanded_width
= params
.expanded_width
973 table_addr_bits
= params
.table_addr_bits
974 if self
== GoldschmidtDivOp
.Normalize
:
975 # normalize so 1 <= d < 2
976 # can easily be done with count-leading-zeros and left shift
978 state
.n
= (state
.n
* 2).to_frac_wid(expanded_width
)
979 state
.d
= (state
.d
* 2).to_frac_wid(expanded_width
)
982 # normalize so 1 <= n < 2
984 state
.n
= (state
.n
* 0.5).to_frac_wid(expanded_width
,
985 round_dir
=RoundDir
.DOWN
)
987 elif self
== GoldschmidtDivOp
.FEqTableLookup
:
988 # compute initial f by table lookup
990 d_m_1
= d_m_1
.to_frac_wid(table_addr_bits
, RoundDir
.DOWN
)
991 assert 0 <= d_m_1
.bits
< (1 << params
.table_addr_bits
)
992 state
.f
= params
.table
[d_m_1
.bits
]
993 state
.f
= state
.f
.to_frac_wid(expanded_width
,
994 round_dir
=RoundDir
.DOWN
)
995 elif self
== GoldschmidtDivOp
.MulNByF
:
996 assert state
.f
is not None
997 n
= state
.n
* state
.f
998 state
.n
= n
.to_frac_wid(expanded_width
, round_dir
=RoundDir
.DOWN
)
999 elif self
== GoldschmidtDivOp
.MulDByF
:
1000 assert state
.f
is not None
1001 d
= state
.d
* state
.f
1002 state
.d
= d
.to_frac_wid(expanded_width
, round_dir
=RoundDir
.UP
)
1003 elif self
== GoldschmidtDivOp
.FEq2MinusD
:
1004 state
.f
= (2 - state
.d
).to_frac_wid(expanded_width
)
1005 elif self
== GoldschmidtDivOp
.CalcResult
:
1006 assert state
.n_shift
is not None
1007 # scale to correct value
1008 n
= state
.n
* (1 << state
.n_shift
)
1010 state
.quotient
= math
.floor(n
)
1011 state
.remainder
= state
.orig_n
- state
.quotient
* state
.orig_d
1012 if state
.remainder
>= state
.orig_d
:
1014 state
.remainder
-= state
.orig_d
1016 assert False, f
"unimplemented GoldschmidtDivOp: {self}"
1018 def gen_hdl(self
, params
, state
, sync_rom
):
1019 """generate the hdl for this operation.
1022 params: GoldschmidtDivParams
1023 the goldschmidt division parameters.
1024 state: GoldschmidtDivHDLState
1025 the input/output state
1027 true if the rom should be read synchronously rather than
1028 combinatorially, incurring an extra clock cycle of latency.
1030 assert isinstance(params
, GoldschmidtDivParams
)
1031 assert isinstance(state
, GoldschmidtDivHDLState
)
1033 if self
== GoldschmidtDivOp
.Normalize
:
1034 # normalize so 1 <= d < 2
1035 assert state
.d
.width
== params
.io_width
1036 assert state
.n
.width
== 2 * params
.io_width
1037 d_leading_zeros
= CLZ(params
.io_width
)
1038 m
.submodules
.d_leading_zeros
= d_leading_zeros
1039 m
.d
.comb
+= d_leading_zeros
.sig_in
.eq(state
.d
)
1040 d_shift_out
= Signal
.like(state
.d
)
1041 m
.d
.comb
+= d_shift_out
.eq(state
.d
<< d_leading_zeros
.lz
)
1042 d
= Signal(params
.n_d_f_total_wid
)
1043 m
.d
.comb
+= d
.eq((d_shift_out
<< (1 + params
.expanded_width
))
1046 # normalize so 1 <= n < 2
1047 n_leading_zeros
= CLZ(2 * params
.io_width
)
1048 m
.submodules
.n_leading_zeros
= n_leading_zeros
1049 m
.d
.comb
+= n_leading_zeros
.sig_in
.eq(state
.n
)
1050 signed_zero
= Const(0, signed(1)) # force subtraction to be signed
1051 n_shift_s_v
= (params
.io_width
+ signed_zero
+ d_leading_zeros
.lz
1052 - n_leading_zeros
.lz
)
1053 n_shift_s
= Signal
.like(n_shift_s_v
)
1054 n_shift_n_lz_out
= Signal
.like(state
.n
)
1055 n_shift_d_lz_out
= Signal
.like(state
.n
<< d_leading_zeros
.lz
)
1057 n_shift_s
.eq(n_shift_s_v
),
1058 n_shift_d_lz_out
.eq(state
.n
<< d_leading_zeros
.lz
),
1059 n_shift_n_lz_out
.eq(state
.n
<< n_leading_zeros
.lz
),
1061 state
.n_shift
= Signal(d_leading_zeros
.lz
.width
)
1062 n
= Signal(params
.n_d_f_total_wid
)
1063 with m
.If(n_shift_s
< 0):
1065 state
.n_shift
.eq(0),
1066 n
.eq((n_shift_d_lz_out
<< (1 + params
.expanded_width
))
1071 state
.n_shift
.eq(n_shift_s
),
1072 n
.eq((n_shift_n_lz_out
<< (1 + params
.expanded_width
))
1077 elif self
== GoldschmidtDivOp
.FEqTableLookup
:
1078 assert state
.d
.width
== params
.n_d_f_total_wid
, "invalid d width"
1079 # compute initial f by table lookup
1081 # extra bit for table entries == 1.0
1082 table_width
= 1 + params
.table_data_bits
1083 table
= Memory(width
=table_width
, depth
=len(params
.table
),
1084 init
=[i
.bits
for i
in params
.table
])
1085 addr
= state
.d
[:-params
.n_d_f_int_wid
][-params
.table_addr_bits
:]
1087 table_read
= table
.read_port()
1088 m
.d
.comb
+= table_read
.addr
.eq(addr
)
1089 state
.insert_pipeline_register()
1091 table_read
= table
.read_port(domain
="comb")
1092 m
.d
.comb
+= table_read
.addr
.eq(addr
)
1093 m
.submodules
.table_read
= table_read
1094 state
.f
= Signal(params
.n_d_f_int_wid
+ params
.expanded_width
)
1095 data_shift
= params
.expanded_width
- params
.table_data_bits
1096 m
.d
.comb
+= state
.f
.eq(table_read
.data
<< data_shift
)
1097 elif self
== GoldschmidtDivOp
.MulNByF
:
1098 assert state
.n
.width
== params
.n_d_f_total_wid
, "invalid n width"
1099 assert state
.f
is not None
1100 assert state
.f
.width
== params
.n_d_f_total_wid
, "invalid f width"
1101 n
= Signal
.like(state
.n
)
1102 m
.d
.comb
+= n
.eq((state
.n
* state
.f
) >> params
.expanded_width
)
1104 elif self
== GoldschmidtDivOp
.MulDByF
:
1105 assert state
.d
.width
== params
.n_d_f_total_wid
, "invalid d width"
1106 assert state
.f
is not None
1107 assert state
.f
.width
== params
.n_d_f_total_wid
, "invalid f width"
1108 d
= Signal
.like(state
.d
)
1109 d_times_f
= Signal
.like(state
.d
* state
.f
)
1111 d_times_f
.eq(state
.d
* state
.f
),
1112 # round the multiplication up
1113 d
.eq((d_times_f
>> params
.expanded_width
)
1114 + (d_times_f
[:params
.expanded_width
] != 0)),
1117 elif self
== GoldschmidtDivOp
.FEq2MinusD
:
1118 assert state
.d
.width
== params
.n_d_f_total_wid
, "invalid d width"
1119 f
= Signal
.like(state
.d
)
1120 m
.d
.comb
+= f
.eq((2 << params
.expanded_width
) - state
.d
)
1122 elif self
== GoldschmidtDivOp
.CalcResult
:
1123 assert state
.n
.width
== params
.n_d_f_total_wid
, "invalid n width"
1124 assert state
.n_shift
is not None
1125 # scale to correct value
1126 n
= state
.n
* (1 << state
.n_shift
)
1127 q_approx
= Signal(params
.io_width
)
1128 # extra bit for if it's bigger than orig_d
1129 r_approx
= Signal(params
.io_width
+ 1)
1130 adjusted_r
= Signal(signed(1 + params
.io_width
))
1132 q_approx
.eq((state
.n
<< state
.n_shift
)
1133 >> params
.expanded_width
),
1134 r_approx
.eq(state
.orig_n
- q_approx
* state
.orig_d
),
1135 adjusted_r
.eq(r_approx
- state
.orig_d
),
1137 state
.quotient
= Signal(params
.io_width
)
1138 state
.remainder
= Signal(params
.io_width
)
1140 with m
.If(adjusted_r
>= 0):
1142 state
.quotient
.eq(q_approx
+ 1),
1143 state
.remainder
.eq(adjusted_r
),
1147 state
.quotient
.eq(q_approx
),
1148 state
.remainder
.eq(r_approx
),
1151 assert False, f
"unimplemented GoldschmidtDivOp: {self}"
@plain_data(repr=False)
class GoldschmidtDivState:
    """state of the software goldschmidt division algorithm.

    created by `goldschmidt_div` and updated in-place by each
    `GoldschmidtDivOp.run` call.
    """
    __slots__ = ("orig_n", "orig_d", "n", "d",
                 "f", "quotient", "remainder", "n_shift")

    def __init__(self, orig_n, orig_d, n, d,
                 f=None, quotient=None, remainder=None, n_shift=None):
        assert isinstance(orig_n, int)
        assert isinstance(orig_d, int)
        assert isinstance(n, FixedPoint)
        assert isinstance(d, FixedPoint)
        assert f is None or isinstance(f, FixedPoint)
        assert quotient is None or isinstance(quotient, int)
        assert remainder is None or isinstance(remainder, int)
        assert n_shift is None or isinstance(n_shift, int)
        # original numerator
        self.orig_n = orig_n
        # original denominator
        self.orig_d = orig_d
        # numerator -- N_prime[i] in the paper's algorithm 2
        self.n = n
        # denominator -- D_prime[i] in the paper's algorithm 2
        self.d = d
        # current factor -- F_prime[i] in the paper's algorithm 2
        self.f = f
        # final quotient
        self.quotient = quotient
        # final remainder
        self.remainder = remainder
        # amount the numerator needs to be left-shifted at the end of the
        # algorithm
        self.n_shift = n_shift

    def __repr__(self):
        # unset (None) fields are omitted; integer fields other than
        # `n_shift` are shown in hex
        parts = []
        for name in fields(GoldschmidtDivState):
            val = getattr(self, name)
            if val is None:
                continue
            as_hex = isinstance(val, int) and name != "n_shift"
            parts.append(f"{name}={hex(val)}" if as_hex
                         else f"{name}={val!r}")
        return "GoldschmidtDivState(" + ", ".join(parts) + ")"
def goldschmidt_div(n, d, params, trace=lambda state: None):
    """ Goldschmidt division algorithm.

    based on:
    Even, G., Seidel, P. M., & Ferguson, W. E. (2003).
    A Parametric Error Analysis of Goldschmidt's Division Algorithm.
    https://citeseerx.ist.psu.edu/viewdoc/download?doi=10.1.1.90.1238&rep=rep1&type=pdf

    arguments:
    n: int
        numerator. a `2 * params.io_width`-bit unsigned integer.
        must be less than `d << params.io_width`, otherwise the quotient
        wouldn't fit in `params.io_width` bits.
    d: int
        denominator. a `params.io_width`-bit unsigned integer. must not be
        zero.
    params: GoldschmidtDivParams
        the goldschmidt division algorithm parameters. `params.io_width` is
        the bit-width of the inputs/outputs and `params.ops` is the sequence
        of operations that is executed.
    trace: Function[[GoldschmidtDivState], None]
        called with the initial state and the state after executing each
        operation in `params.ops`.

    returns: tuple[int, int]
        the quotient and remainder. a tuple of two `params.io_width`-bit
        unsigned integers.
    """
    assert isinstance(params, GoldschmidtDivParams)
    assert isinstance(d, int) and 0 < d < (1 << params.io_width)
    assert isinstance(n, int) and 0 <= n < (d << params.io_width)

    # this whole algorithm is done with fixed-point arithmetic; values start
    # with `params.io_width` fractional bits

    state = GoldschmidtDivState(
        orig_n=n,
        orig_d=d,
        n=FixedPoint(n, params.io_width),
        d=FixedPoint(d, params.io_width),
    )

    trace(state)
    for op in params.ops:
        op.run(params, state)
        trace(state)

    assert state.quotient is not None
    assert state.remainder is not None

    return state.quotient, state.remainder
@plain_data(eq=False)
class GoldschmidtDivHDLState:
    """mutable state used while generating the HDL for goldschmidt division.

    once `__init__` has finished, every value assigned to a public state
    field must be a `Signal`. `__setattr__` renames it, and the signal it
    replaces is appended to `old_signals[field_name]`.
    """
    __slots__ = ("m", "orig_n", "orig_d", "n", "d",
                 "f", "quotient", "remainder", "n_shift")

    # prefix prepended by `__setattr__` to the names of assigned signals
    __signal_name_prefix = "state_"

    def __init__(self, m, orig_n, orig_d, n, d,
                 f=None, quotient=None, remainder=None, n_shift=None):
        assert isinstance(m, Module)
        assert isinstance(orig_n, Signal)
        assert isinstance(orig_d, Signal)
        assert isinstance(n, Signal)
        assert isinstance(d, Signal)
        assert f is None or isinstance(f, Signal)
        assert quotient is None or isinstance(quotient, Signal)
        assert remainder is None or isinstance(remainder, Signal)
        assert n_shift is None or isinstance(n_shift, Signal)

        self.m = m
        """The HDL Module"""

        self.orig_n = orig_n
        """original numerator"""

        self.orig_d = orig_d
        """original denominator"""

        self.n = n
        """numerator -- N_prime[i] in the paper's algorithm 2"""

        self.d = d
        """denominator -- D_prime[i] in the paper's algorithm 2"""

        self.f = f
        """current factor -- F_prime[i] in the paper's algorithm 2"""

        self.quotient = quotient
        """final quotient"""

        self.remainder = remainder
        """final remainder"""

        self.n_shift = n_shift
        """amount the numerator needs to be left-shifted at the end of the
        algorithm.
        """

        # old_signals must be set last: until it exists, `__setattr__`
        # stores values directly, without renaming or recording them.
        # NOTE(review): "old_signals" is not listed in `__slots__`; confirm
        # that `plain_data` leaves instances somewhere to store it.
        self.old_signals = defaultdict(list)

    def __setattr__(self, name, value):
        """rename and record signals assigned to state fields.

        names starting with `_` are stored directly. while `old_signals`
        doesn't exist yet (still inside `__init__`) values are stored
        directly too. otherwise `value` must be a `Signal`. it is renamed to
        `<prefix><name>_<k>`, where `k` is how many old signals are already
        recorded for `name`. the signal it replaces (if any) is appended to
        `old_signals[name]`. `m` and `old_signals` can't be reassigned.
        """
        assert isinstance(name, str)
        if name.startswith("_"):
            return super().__setattr__(name, value)
        try:
            old_signals = self.old_signals[name]
        except AttributeError:
            # haven't yet finished __init__ (old_signals is assigned last)
            return super().__setattr__(name, value)
        assert name != "m" and name != "old_signals", f"can't write to {name}"
        assert isinstance(value, Signal)
        value.name = f"{self.__signal_name_prefix}{name}_{len(old_signals)}"
        old_signal = getattr(self, name, None)
        if old_signal is not None:
            assert isinstance(old_signal, Signal)
            old_signals.append(old_signal)
        return super().__setattr__(name, value)

    def insert_pipeline_register(self):
        """delay the whole state by one clock cycle.

        each non-None field other than `m` is replaced with a new signal of
        the same shape. the new signal is driven from the old one in
        `m.d.sync`.
        """
        old_prefix = self.__signal_name_prefix
        try:
            for field in fields(GoldschmidtDivHDLState):
                if field.startswith("_") or field == "m":
                    continue
                old_sig = getattr(self, field, None)
                if old_sig is None:
                    continue
                assert isinstance(old_sig, Signal)
                new_sig = Signal.like(old_sig)
                setattr(self, field, new_sig)
                self.m.d.sync += new_sig.eq(old_sig)
        finally:
            # NOTE(review): nothing above changes the prefix, so this restore
            # is currently a no-op -- confirm whether pipeline-register
            # signals were meant to get a different name prefix.
            self.__signal_name_prefix = old_prefix
class GoldschmidtDivHDL(Elaboratable):
    """ Goldschmidt division algorithm.

    based on:
    Even, G., Seidel, P. M., & Ferguson, W. E. (2003).
    A Parametric Error Analysis of Goldschmidt's Division Algorithm.
    https://citeseerx.ist.psu.edu/viewdoc/download?doi=10.1.1.90.1238&rep=rep1&type=pdf

    attributes:
    params: GoldschmidtDivParams
        the goldschmidt division algorithm parameters.
    pipe_reg_indexes: list[int]
        sorted operation indexes where pipeline registers are inserted
        (before that operation). duplicate values insert several registers
        at the same point -- useful to let yosys spread a multiplication
        across those pipeline stages. indexes at or past the end of
        `params.ops` insert registers after the last operation.
    sync_rom: bool
        true if the rom should be read synchronously rather than
        combinatorially, incurring an extra clock cycle of latency.
    n: Signal(unsigned(2 * params.io_width))
        input numerator. a `2 * params.io_width`-bit unsigned integer.
        must be less than `d << params.io_width`, otherwise the quotient
        wouldn't fit in `params.io_width` bits.
    d: Signal(unsigned(params.io_width))
        input denominator. a `params.io_width`-bit unsigned integer.
    q: Signal(unsigned(params.io_width))
        output quotient. only valid when `n < (d << params.io_width)`.
    r: Signal(unsigned(params.io_width))
        output remainder. only valid when `n < (d << params.io_width)`.
    trace: list[GoldschmidtDivHDLState]
        snapshots of the initial state and of the state after each
        operation in `params.ops`.
    """

    @property
    def total_pipeline_registers(self):
        """the total number of pipeline registers"""
        rom_registers = 1 if self.sync_rom else 0
        return len(self.pipe_reg_indexes) + rom_registers

    def __init__(self, params, pipe_reg_indexes=(), sync_rom=False):
        assert isinstance(params, GoldschmidtDivParams)
        assert isinstance(sync_rom, bool)
        self.params = params
        self.pipe_reg_indexes = sorted(int(i) for i in pipe_reg_indexes)
        self.sync_rom = sync_rom
        io_width = params.io_width
        self.n = Signal(unsigned(2 * io_width))
        self.d = Signal(unsigned(io_width))
        self.q = Signal(unsigned(io_width))
        self.r = Signal(unsigned(io_width))

        # the hardware is generated here (not in `elaborate`) so `trace` is
        # available without elaborating
        state = GoldschmidtDivHDLState(
            m=Module(),
            orig_n=self.n,
            orig_d=self.d,
            n=self.n,
            d=self.d,
        )
        self.trace = [replace(state)]

        # `next_reg` walks the sorted register indexes in ascending order
        reg_indexes = self.pipe_reg_indexes
        next_reg = 0
        for op_index, op in enumerate(self.params.ops):
            while next_reg < len(reg_indexes) \
                    and reg_indexes[next_reg] <= op_index:
                next_reg += 1
                state.insert_pipeline_register()
            op.gen_hdl(self.params, state, self.sync_rom)
            self.trace.append(replace(state))

        # any registers placed at or past the end go after the last op
        for _ in range(len(reg_indexes) - next_reg):
            state.insert_pipeline_register()

        state.m.d.comb += [
            self.q.eq(state.quotient),
            self.r.eq(state.remainder),
        ]

    def elaborate(self, platform):
        return self.trace[0].m
# number of integer bits in the fixed-point value used to address the
# goldschmidt sqrt/rsqrt look-up table; the remaining
# `table_addr_bits - GOLDSCHMIDT_SQRT_RSQRT_TABLE_ADDR_INT_WID` address bits
# are fraction bits.
GOLDSCHMIDT_SQRT_RSQRT_TABLE_ADDR_INT_WID = 2
def goldschmidt_sqrt_rsqrt_table(table_addr_bits, table_data_bits):
    """Generate the look-up table needed for Goldschmidt's square-root and
    reciprocal-square-root algorithm.

    arguments:
    table_addr_bits: int
        the number of address bits for the look-up table. the address is a
        fixed-point value with `GOLDSCHMIDT_SQRT_RSQRT_TABLE_ADDR_INT_WID`
        integer bits; the remaining address bits are fraction bits.
    table_data_bits: int
        the number of data bits for the look-up table.

    returns: tuple[FixedPoint | None, ...]
        `1 << table_addr_bits` entries:
        - entry 0 is zero;
        - entries with `addr < table_len / 4` are `None` (unused);
        - every other entry is the reciprocal square root of `addr + 1`
          (read with the address fraction width), rounded down to
          `table_data_bits` fraction bits.
    """
    assert isinstance(table_addr_bits, int) and \
        table_addr_bits >= GOLDSCHMIDT_SQRT_RSQRT_TABLE_ADDR_INT_WID
    assert isinstance(table_data_bits, int) and table_data_bits >= 1

    # loop-invariant: fraction width of the table address
    table_addr_frac_wid = (table_addr_bits
                           - GOLDSCHMIDT_SQRT_RSQRT_TABLE_ADDR_INT_WID)
    table_len = 1 << table_addr_bits
    table = []
    for addr in range(table_len):
        if addr == 0:
            value = FixedPoint(0, table_data_bits)
        elif (addr << 2) < table_len:
            value = None  # table entries should be unused
        else:
            # upper (exclusive) bound of the inputs mapping to `addr`. rsqrt
            # is decreasing, so rounding rsqrt(bound) down never exceeds
            # rsqrt of any input mapping to this address.
            max_input_value = FixedPoint(addr + 1, table_addr_frac_wid)
            max_frac_wid = max(max_input_value.frac_wid, table_data_bits)
            value = max_input_value.to_frac_wid(max_frac_wid)
            value = value.rsqrt(RoundDir.DOWN)
            value = value.to_frac_wid(table_data_bits, RoundDir.DOWN)
        table.append(value)

    # tuple for immutability
    return tuple(table)
# FIXME: add code to calculate error bounds and check that the algorithm will
# actually work (like in the goldschmidt division algorithm).
# FIXME: add code to calculate a good set of parameters based on the error
# analysis.
1472 def goldschmidt_sqrt_rsqrt(radicand
, io_width
, frac_wid
, extra_precision
,
1473 table_addr_bits
, table_data_bits
, iter_count
):
1474 """Goldschmidt's square-root and reciprocal-square-root algorithm.
1476 uses algorithm based on second method at:
1477 https://en.wikipedia.org/wiki/Methods_of_computing_square_roots#Goldschmidt%E2%80%99s_algorithm
1480 radicand: FixedPoint(frac_wid=frac_wid)
1481 the input value to take the square-root and reciprocal-square-root of.
1483 the number of bits in the input (`radicand`) and output values.
1485 the number of fraction bits in the input (`radicand`) and output
1487 extra_precision: int
1488 the number of bits of internal extra precision.
1489 table_addr_bits: int
1490 the number of address bits for the look-up table.
1491 table_data_bits: int
1492 the number of data bits for the look-up table.
1494 returns: tuple[FixedPoint, FixedPoint]
1495 the square-root and reciprocal-square-root, rounded down to the
1496 nearest representable value. If `radicand == 0`, then the
1497 reciprocal-square-root value returned is zero.
1499 assert (isinstance(radicand
, FixedPoint
)
1500 and radicand
.frac_wid
== frac_wid
1501 and 0 <= radicand
.bits
< (1 << io_width
))
1502 assert isinstance(io_width
, int) and io_width
>= 1
1503 assert isinstance(frac_wid
, int) and 0 <= frac_wid
< io_width
1504 assert isinstance(extra_precision
, int) and extra_precision
>= io_width
1505 assert isinstance(table_addr_bits
, int) and table_addr_bits
>= 1
1506 assert isinstance(table_data_bits
, int) and table_data_bits
>= 1
1507 assert isinstance(iter_count
, int) and iter_count
>= 0
1508 expanded_frac_wid
= frac_wid
+ extra_precision
1509 s
= radicand
.to_frac_wid(expanded_frac_wid
)
1510 sqrt_rshift
= extra_precision
1511 rsqrt_rshift
= extra_precision
1512 while s
!= 0 and s
< 1:
1513 s
= (s
* 4).to_frac_wid(expanded_frac_wid
)
1517 s
= s
.div(4, expanded_frac_wid
)
1520 table
= goldschmidt_sqrt_rsqrt_table(table_addr_bits
=table_addr_bits
,
1521 table_data_bits
=table_data_bits
)
1522 # core goldschmidt sqrt/rsqrt algorithm:
1524 table_addr_frac_wid
= table_addr_bits
1525 table_addr_frac_wid
-= GOLDSCHMIDT_SQRT_RSQRT_TABLE_ADDR_INT_WID
1526 addr
= s
.to_frac_wid(table_addr_frac_wid
, RoundDir
.DOWN
)
1527 assert 0 <= addr
.bits
< (1 << table_addr_bits
), "table addr out of range"
1528 f
= table
[addr
.bits
]
1529 assert f
is not None, "accessed invalid table entry"
1530 # use with_frac_wid to fix IDE type deduction
1531 f
= FixedPoint
.with_frac_wid(f
, expanded_frac_wid
, RoundDir
.DOWN
)
1532 x
= (s
* f
).to_frac_wid(expanded_frac_wid
, RoundDir
.DOWN
)
1533 h
= (f
* 0.5).to_frac_wid(expanded_frac_wid
, RoundDir
.DOWN
)
1534 for _
in range(iter_count
):
1536 f
= (1.5 - x
* h
).to_frac_wid(expanded_frac_wid
, RoundDir
.DOWN
)
1537 x
= (x
* f
).to_frac_wid(expanded_frac_wid
, RoundDir
.DOWN
)
1538 h
= (h
* f
).to_frac_wid(expanded_frac_wid
, RoundDir
.DOWN
)
1540 # now `x` is approximately `sqrt(s)` and `r` is approximately `rsqrt(s)`
1542 sqrt
= FixedPoint(x
.bits
>> sqrt_rshift
, frac_wid
)
1543 rsqrt
= FixedPoint(r
.bits
>> rsqrt_rshift
, frac_wid
)
1545 next_sqrt
= FixedPoint(sqrt
.bits
+ 1, frac_wid
)
1546 if next_sqrt
* next_sqrt
<= radicand
:
1549 next_rsqrt
= FixedPoint(rsqrt
.bits
+ 1, frac_wid
)
1550 if next_rsqrt
* next_rsqrt
* radicand
<= 1 and radicand
!= 0: