1 # SPDX-License-Identifier: LGPL-3-or-later
2 # Copyright 2022 Jacob Lifshay programmerjake@gmail.com
4 # Funded by NLnet Assure Programme 2021-02-052, https://nlnet.nl/assure part
5 # of Horizon 2020 EU Programme 957073.
7 from dataclasses
import dataclass
, field
, fields
, replace
11 from fractions
import Fraction
12 from types
import FunctionType
13 from functools
import lru_cache
16 from functools
import cached_property
18 from cached_property
import cached_property
20 # fix broken IDE type detection for cached_property
21 from typing
import TYPE_CHECKING
, Any
23 from functools
import cached_property
29 def cache_on_self(func
):
30 """like `functools.cached_property`, except for methods. unlike
31 `lru_cache` the cache is per-class instance rather than a global cache
34 assert isinstance(func
, FunctionType
), \
35 "non-plain methods are not supported"
37 cache_name
= func
.__name
__ + "__cache"
39 def wrapper(self
, *args
, **kwargs
):
40 # specifically access through `__dict__` to bypass frozen=True
41 cache
= self
.__dict
__.get(cache_name
, _NOT_FOUND
)
42 if cache
is _NOT_FOUND
:
43 self
.__dict
__[cache_name
] = cache
= {}
44 key
= (args
, *kwargs
.items())
45 retval
= cache
.get(key
, _NOT_FOUND
)
46 if retval
is _NOT_FOUND
:
47 retval
= func(self
, *args
, **kwargs
)
51 wrapper
.__doc
__ = func
.__doc
__
56 class RoundDir(enum
.Enum
):
59 NEAREST_TIES_UP
= enum
.auto()
60 ERROR_IF_INEXACT
= enum
.auto()
63 @dataclass(frozen
=True)
def __post_init__(self):
    """validate the fields right after the dataclass-generated `__init__`.

    `bits` is the raw integer; the represented value is
    `bits / 2 ** frac_wid`. Invalid fields raise `AssertionError`.
    """
    assert isinstance(self.bits, int)
    # `frac_wid` counts fractional bits, so it must be a non-negative int
    assert isinstance(self.frac_wid, int)
    assert self.frac_wid >= 0
74 """convert `value` to a fixed-point number with enough fractional
75 bits to preserve its value."""
76 if isinstance(value
, FixedPoint
):
78 if isinstance(value
, int):
79 return FixedPoint(value
, 0)
80 if isinstance(value
, str):
82 neg
= value
.startswith("-")
83 if neg
or value
.startswith("+"):
85 if value
.startswith(("0x", "0X")) and "." in value
:
95 raise ValueError("too many `.` in string")
100 if not digit
.isalnum():
101 raise ValueError("invalid hexadecimal digit")
103 bits |
= int("0x" + digit
, base
=16)
105 bits
= int(value
, base
=0)
109 return FixedPoint(bits
, frac_wid
)
111 if isinstance(value
, float):
112 n
, d
= value
.as_integer_ratio()
113 log2_d
= d
.bit_length() - 1
114 assert d
== 1 << log2_d
, ("d isn't a power of 2 -- won't ever "
115 "fail with float being IEEE 754")
116 return FixedPoint(n
, log2_d
)
117 raise TypeError("can't convert type to FixedPoint")
120 def with_frac_wid(value
, frac_wid
, round_dir
=RoundDir
.ERROR_IF_INEXACT
):
121 """convert `value` to the nearest fixed-point number with `frac_wid`
122 fractional bits, rounding according to `round_dir`."""
123 assert isinstance(frac_wid
, int) and frac_wid
>= 0
124 assert isinstance(round_dir
, RoundDir
)
125 if isinstance(value
, Fraction
):
126 numerator
= value
.numerator
127 denominator
= value
.denominator
129 value
= FixedPoint
.cast(value
)
130 numerator
= value
.bits
131 denominator
= 1 << value
.frac_wid
133 numerator
= -numerator
134 denominator
= -denominator
135 bits
, remainder
= divmod(numerator
<< frac_wid
, denominator
)
136 if round_dir
== RoundDir
.DOWN
:
138 elif round_dir
== RoundDir
.UP
:
141 elif round_dir
== RoundDir
.NEAREST_TIES_UP
:
142 if remainder
* 2 >= denominator
:
144 elif round_dir
== RoundDir
.ERROR_IF_INEXACT
:
146 raise ValueError("inexact conversion")
148 assert False, "unimplemented round_dir"
149 return FixedPoint(bits
, frac_wid
)
def to_frac_wid(self, frac_wid, round_dir=RoundDir.ERROR_IF_INEXACT):
    """return `self` re-expressed with exactly `frac_wid` fractional bits.

    Method-style shorthand for `FixedPoint.with_frac_wid(self, ...)`.
    `round_dir` picks how an unrepresentable value is handled. With the
    default `RoundDir.ERROR_IF_INEXACT`, `with_frac_wid` raises `ValueError`
    instead of rounding.
    """
    converted = FixedPoint.with_frac_wid(self, frac_wid, round_dir)
    return converted
157 # use truediv to get correct result even when bits
158 # and frac_wid are huge
159 return float(self
.bits
/ (1 << self
.frac_wid
))
def as_fraction(self):
    """return the exact value `bits / 2 ** frac_wid` as a `Fraction`."""
    denominator = 1 << self.frac_wid
    return Fraction(self.bits, denominator)
165 """compare self with rhs, returning a positive integer if self is
166 greater than rhs, zero if self is equal to rhs, and a negative integer
167 if self is less than rhs."""
168 rhs
= FixedPoint
.cast(rhs
)
169 common_frac_wid
= max(self
.frac_wid
, rhs
.frac_wid
)
170 lhs
= self
.to_frac_wid(common_frac_wid
)
171 rhs
= rhs
.to_frac_wid(common_frac_wid
)
172 return lhs
.bits
- rhs
.bits
# Rich comparisons: all six delegate to `cmp`. Per its docstring, `cmp`
# returns a positive, zero, or negative integer when `self` is greater than,
# equal to, or less than `rhs`. `rhs` may be anything `cmp` accepts.
# NOTE(review): if these belong to the `@dataclass(frozen=True)` class, the
# dataclass-generated `__hash__` hashes the raw `(bits, frac_wid)` fields.
# Values that compare equal here at different `frac_wid`s (e.g. bits=1,
# frac_wid=0 vs bits=2, frac_wid=1) may then hash differently. Confirm, and
# consider defining `__hash__` consistently with `__eq__`.
def __eq__(self, rhs):
    """return whether `self` and `rhs` represent the same value."""
    ordering = self.cmp(rhs)
    return ordering == 0


def __ne__(self, rhs):
    """return whether `self` and `rhs` represent different values."""
    ordering = self.cmp(rhs)
    return ordering != 0


def __gt__(self, rhs):
    """return whether `self` is strictly greater than `rhs`."""
    ordering = self.cmp(rhs)
    return ordering > 0


def __lt__(self, rhs):
    """return whether `self` is strictly less than `rhs`."""
    ordering = self.cmp(rhs)
    return ordering < 0


def __ge__(self, rhs):
    """return whether `self` is greater than or equal to `rhs`."""
    ordering = self.cmp(rhs)
    return ordering >= 0


def __le__(self, rhs):
    """return whether `self` is less than or equal to `rhs`."""
    ordering = self.cmp(rhs)
    return ordering <= 0
193 """return the fractional part of `self`.
194 that is `self - math.floor(self)`.
196 fract_mask
= (1 << self
.frac_wid
) - 1
197 return FixedPoint(self
.bits
& fract_mask
, self
.frac_wid
)
201 return "-" + str(-self
)
203 frac_digit_count
= (self
.frac_wid
+ digit_bits
- 1) // digit_bits
204 fract
= self
.fract().to_frac_wid(frac_digit_count
* digit_bits
)
205 frac_str
= hex(fract
.bits
)[2:].zfill(frac_digit_count
)
206 return hex(math
.floor(self
)) + "." + frac_str
209 return f
"FixedPoint.with_frac_wid({str(self)!r}, {self.frac_wid})"
def __add__(self, rhs):
    """return `self + rhs` with no rounding.

    `rhs` is converted with `FixedPoint.cast`. Both operands are widened to
    the larger of their fractional widths, which is always exact, and then
    the raw `bits` are summed.
    """
    other = FixedPoint.cast(rhs)
    frac_wid = max(self.frac_wid, other.frac_wid)
    total_bits = (self.to_frac_wid(frac_wid).bits
                  + other.to_frac_wid(frac_wid).bits)
    return FixedPoint(total_bits, frac_wid)
218 def __radd__(self
, lhs
):
220 return self
.__add
__(lhs
)
223 return FixedPoint(-self
.bits
, self
.frac_wid
)
def __sub__(self, rhs):
    """return `self - rhs` with no rounding.

    `rhs` is converted with `FixedPoint.cast`. Both operands are widened to
    the larger of their fractional widths, which is always exact, and then
    the raw `bits` are subtracted.
    """
    other = FixedPoint.cast(rhs)
    frac_wid = max(self.frac_wid, other.frac_wid)
    difference_bits = (self.to_frac_wid(frac_wid).bits
                       - other.to_frac_wid(frac_wid).bits)
    return FixedPoint(difference_bits, frac_wid)
232 def __rsub__(self
, lhs
):
234 return -self
.__sub
__(lhs
)
def __mul__(self, rhs):
    """return `self * rhs` with no rounding.

    Multiplying the raw `bits` multiplies the values. The product's
    fractional width is the sum of the operands' widths, so the result is
    exact.
    """
    other = FixedPoint.cast(rhs)
    product_bits = self.bits * other.bits
    product_frac_wid = self.frac_wid + other.frac_wid
    return FixedPoint(product_bits, product_frac_wid)
240 def __rmul__(self
, lhs
):
242 return self
.__mul
__(lhs
)
245 return self
.bits
>> self
.frac_wid
249 class GoldschmidtDivState
:
251 """original numerator"""
254 """original denominator"""
257 """numerator -- N_prime[i] in the paper's algorithm 2"""
260 """denominator -- D_prime[i] in the paper's algorithm 2"""
262 f
: "FixedPoint | None" = None
263 """current factor -- F_prime[i] in the paper's algorithm 2"""
265 quotient
: "int | None" = None
268 remainder
: "int | None" = None
269 """final remainder"""
271 n_shift
: "int | None" = None
272 """amount the numerator needs to be left-shifted at the end of the
277 class ParamsNotAccurateEnough(Exception):
278 """raised when the parameters aren't accurate enough to have goldschmidt
282 def _assert_accuracy(condition
, msg
="not accurate enough"):
285 raise ParamsNotAccurateEnough(msg
)
288 @dataclass(frozen
=True, unsafe_hash
=True)
289 class GoldschmidtDivParamsBase
:
290 """parameters for a Goldschmidt division algorithm, excluding derived
295 """bit-width of the input divisor and the result.
296 the input numerator is `2 * io_width`-bits wide.
300 """number of bits of additional precision used inside the algorithm."""
303 """the number of address bits used in the lookup-table."""
306 """the number of data bits used in the lookup-table."""
309 """the total number of iterations of the division algorithm's loop"""
312 @dataclass(frozen
=True, unsafe_hash
=True)
313 class GoldschmidtDivParams(GoldschmidtDivParamsBase
):
314 """parameters for a Goldschmidt division algorithm.
315 Use `GoldschmidtDivParams.get` to find a efficient set of parameters.
318 # tuple to be immutable, repr=False so repr() works for debugging even when
319 # __post_init__ hasn't finished running yet
320 table
: "tuple[FixedPoint, ...]" = field(init
=False, repr=False)
321 """the lookup-table"""
323 ops
: "tuple[GoldschmidtDivOp, ...]" = field(init
=False, repr=False)
324 """the operations needed to perform the goldschmidt division algorithm."""
def _shrink_bound(self, bound, round_dir):
    """snap `bound` onto a fine fixed-point grid and return it as a `Fraction`.

    Chains of `Fraction` arithmetic let numerators/denominators grow without
    limit. Rounding to `self.io_width * 4 + 100` fractional bits keeps them
    small. `round_dir` must be `RoundDir.DOWN` or `RoundDir.UP`; any other
    value fails an assertion.

    Only meant for values used while computing bounds, never for values
    that end up in the hardware.
    """
    assert isinstance(bound, (Fraction, int))
    assert round_dir is RoundDir.DOWN or round_dir is RoundDir.UP, \
        "you shouldn't use that round_dir on bounds"
    # generous margin beyond the I/O width; the original author's
    # "should be enough precision"
    precision = self.io_width * 4 + 100
    return FixedPoint.with_frac_wid(bound, precision, round_dir).as_fraction()
def _shrink_min(self, min_bound):
    """return `min_bound` rounded *down* by `_shrink_bound`.

    Rounding down keeps the result a valid lower bound while keeping its
    numerator/denominator small. Only for values used to compute bounds,
    not for values that end up in the hardware.
    """
    direction = RoundDir.DOWN
    return self._shrink_bound(min_bound, direction)
def _shrink_max(self, max_bound):
    """return `max_bound` rounded *up* by `_shrink_bound`.

    Rounding up keeps the result a valid upper bound while keeping its
    numerator/denominator small. Only for values used to compute bounds,
    not for values that end up in the hardware.
    """
    direction = RoundDir.UP
    return self._shrink_bound(max_bound, direction)
def table_addr_count(self):
    """return how many addresses the lookup-table has (`2 ** table_addr_bits`).

    Derived from `table_addr_bits` rather than `len(self.table)`, because it
    is needed while `self.table` is still being built.
    """
    entries = 1 << self.table_addr_bits
    return entries
def table_input_exact_range(self, addr):
    """return `(min_input, max_input)` as `Fraction`s: the inclusive range of
    denominator values in `[1, 2)` that select the table entry at `addr`.

    Inputs are `(2 ** io_width + x) / 2 ** io_width` for integer `x`. Entry
    `addr` covers the `2 ** (io_width - table_addr_bits)` consecutive values
    of `x` starting at `addr << (io_width - table_addr_bits)`. Both ends are
    passed through `_shrink_min` / `_shrink_max`.

    `io_width >= table_addr_bits` is checked with `_assert_accuracy`.
    """
    assert isinstance(addr, int)
    assert 0 <= addr < self.table_addr_count
    _assert_accuracy(self.io_width >= self.table_addr_bits)
    one = 1 << self.io_width  # the value 1 scaled by 2 ** io_width
    entry_shift = self.io_width - self.table_addr_bits
    first = one + (addr << entry_shift)
    last = first + (1 << entry_shift) - 1  # inclusive upper end
    lo = self._shrink_min(Fraction(first, one))
    hi = self._shrink_max(Fraction(last, one))
    assert 1 <= lo <= hi < 2
    return lo, hi
def table_value_exact_range(self, addr):
    """return `(min_value, max_value)` as `Fraction`s: the range of
    reciprocals `1 / input` over the inputs that map to table address `addr`.
    """
    lo_input, hi_input = self.table_input_exact_range(addr)
    # 1/x is decreasing for x > 0: the largest input gives the smallest value
    lo_value = self._shrink_min(1 / hi_input)
    hi_value = self._shrink_max(1 / lo_input)
    assert 0.5 < lo_value <= hi_value <= 1
    return lo_value, hi_value
396 def table_exact_value(self
, index
):
397 min_value
, max_value
= self
.table_value_exact_range(index
)
401 def __post_init__(self
):
402 # called by the autogenerated __init__
403 _assert_accuracy(self
.io_width
>= 1, "io_width out of range")
404 _assert_accuracy(self
.extra_precision
>= 0,
405 "extra_precision out of range")
406 _assert_accuracy(self
.table_addr_bits
>= 1,
407 "table_addr_bits out of range")
408 _assert_accuracy(self
.table_data_bits
>= 1,
409 "table_data_bits out of range")
410 _assert_accuracy(self
.iter_count
>= 1, "iter_count out of range")
412 for addr
in range(1 << self
.table_addr_bits
):
413 table
.append(FixedPoint
.with_frac_wid(self
.table_exact_value(addr
),
414 self
.table_data_bits
,
416 # we have to use object.__setattr__ since frozen=True
417 object.__setattr
__(self
, "table", tuple(table
))
418 object.__setattr
__(self
, "ops", tuple(self
.__make
_ops
()))
def expanded_width(self):
    """return the total working precision in bits: the I/O width plus the
    additional precision bits (`io_width + extra_precision`)."""
    width = self.io_width
    width += self.extra_precision
    return width
def max_neps(self, i):
    """return the upper bound of `neps[i]`, where
    `neps[i] = n[i] * N_prime[i - 1] * F_prime[i - 1]`.

    The bound is the same for every valid `i`: the weight of the least
    significant bit at the working precision, `2 ** -expanded_width`.
    """
    assert isinstance(i, int)
    assert 0 <= i < self.iter_count
    lsb_weight = Fraction(1, 1 << self.expanded_width)
    return lsb_weight
def max_deps(self, i):
    """return the upper bound of `deps[i]`, where
    `deps[i] = d[i] * D_prime[i - 1] * F_prime[i - 1]`.

    The bound is the same for every valid `i`: the weight of the least
    significant bit at the working precision, `2 ** -expanded_width`.
    """
    assert isinstance(i, int)
    assert 0 <= i < self.iter_count
    lsb_weight = Fraction(1, 1 << self.expanded_width)
    return lsb_weight
442 def max_feps(self
, i
):
443 """maximum value of `feps[i]`.
444 `feps[i]` is defined to be `f[i] * (2 - D_prime[i - 1])`.
446 assert isinstance(i
, int) and 0 <= i
< self
.iter_count
447 # zero, because the computation of `F_prime[i]` in
448 # `GoldschmidtDivOp.MulDByF.run(...)` is exact.
453 """minimum and maximum values of `e[0]`
454 (the relative error in `F_prime[-1]`)
458 for addr
in range(self
.table_addr_count
):
459 # `F_prime[-1] = (1 - e[0]) / B`
460 # => `e[0] = 1 - B * F_prime[-1]`
461 min_b
, max_b
= self
.table_input_exact_range(addr
)
462 f_prime_m1
= self
.table
[addr
].as_fraction()
463 assert min_b
>= 0 and f_prime_m1
>= 0, \
464 "only positive quadrant of interval multiplication implemented"
465 min_product
= min_b
* f_prime_m1
466 max_product
= max_b
* f_prime_m1
467 # negation swaps min/max
468 cur_min_e0
= 1 - max_product
469 cur_max_e0
= 1 - min_product
470 min_e0
= min(min_e0
, cur_min_e0
)
471 max_e0
= max(max_e0
, cur_max_e0
)
472 min_e0
= self
._shrink
_min
(min_e0
)
473 max_e0
= self
._shrink
_max
(max_e0
)
474 return min_e0
, max_e0
478 """minimum value of `e[0]` (the relative error in `F_prime[-1]`)
480 min_e0
, max_e0
= self
.e0_range
485 """maximum value of `e[0]` (the relative error in `F_prime[-1]`)
487 min_e0
, max_e0
= self
.e0_range
def max_abs_e0(self):
    """return the largest magnitude `e[0]` can have: the larger of
    `abs(min_e0)` and `abs(max_e0)`."""
    return max(abs(bound) for bound in (self.min_e0, self.max_e0))
496 def min_abs_e0(self
):
497 """minimum value of `abs(e[0])`."""
502 """maximum value of `n[i]` (the relative error in `N_prime[i]`
503 relative to the previous iteration)
505 assert isinstance(i
, int) and 0 <= i
< self
.iter_count
508 # `n[0] = neps[0] / ((1 - e[0]) * (A / B))`
509 # `n[0] <= 2 * neps[0] / (1 - e[0])`
511 assert self
.max_e0
< 1 and self
.max_neps(0) >= 0, \
512 "only one quadrant of interval division implemented"
513 retval
= 2 * self
.max_neps(0) / (1 - self
.max_e0
)
516 # `n[1] <= neps[1] / ((1 - f[0]) * (1 - pi[0] - delta[0]))`
517 min_mpd
= 1 - self
.max_pi(0) - self
.max_delta(0)
518 assert self
.max_f(0) <= 1 and min_mpd
>= 0, \
519 "only one quadrant of interval multiplication implemented"
520 prod
= (1 - self
.max_f(0)) * min_mpd
521 assert self
.max_neps(1) >= 0 and prod
> 0, \
522 "only one quadrant of interval division implemented"
523 retval
= self
.max_neps(1) / prod
526 # `0 <= n[i] <= 2 * max_neps[i] / (1 - pi[i - 1] - delta[i - 1])`
527 min_mpd
= 1 - self
.max_pi(i
- 1) - self
.max_delta(i
- 1)
528 assert self
.max_neps(i
) >= 0 and min_mpd
> 0, \
529 "only one quadrant of interval division implemented"
530 retval
= self
.max_neps(i
) / min_mpd
532 return self
._shrink
_max
(retval
)
536 """maximum value of `d[i]` (the relative error in `D_prime[i]`
537 relative to the previous iteration)
539 assert isinstance(i
, int) and 0 <= i
< self
.iter_count
542 # `d[0] = deps[0] / (1 - e[0])`
544 assert self
.max_e0
< 1 and self
.max_deps(0) >= 0, \
545 "only one quadrant of interval division implemented"
546 retval
= self
.max_deps(0) / (1 - self
.max_e0
)
549 # `d[1] <= deps[1] / ((1 - f[0]) * (1 - delta[0] ** 2))`
550 assert self
.max_f(0) <= 1 and self
.max_delta(0) <= 1, \
551 "only one quadrant of interval multiplication implemented"
552 divisor
= (1 - self
.max_f(0)) * (1 - self
.max_delta(0) ** 2)
553 assert self
.max_deps(1) >= 0 and divisor
> 0, \
554 "only one quadrant of interval division implemented"
555 retval
= self
.max_deps(1) / divisor
558 # `0 <= d[i] <= max_deps[i] / (1 - delta[i - 1])`
559 assert self
.max_deps(i
) >= 0 and self
.max_delta(i
- 1) < 1, \
560 "only one quadrant of interval division implemented"
561 retval
= self
.max_deps(i
) / (1 - self
.max_delta(i
- 1))
563 return self
._shrink
_max
(retval
)
567 """maximum value of `f[i]` (the relative error in `F_prime[i]`
568 relative to the previous iteration)
570 assert isinstance(i
, int) and 0 <= i
< self
.iter_count
573 # `f[0] = feps[0] / (1 - delta[0])`
575 assert self
.max_delta(0) < 1 and self
.max_feps(0) >= 0, \
576 "only one quadrant of interval division implemented"
577 retval
= self
.max_feps(0) / (1 - self
.max_delta(0))
581 retval
= self
.max_feps(1)
584 # `f[i] <= max_feps[i]`
585 retval
= self
.max_feps(i
)
587 return self
._shrink
_max
(retval
)
590 def max_delta(self
, i
):
591 """ maximum value of `delta[i]`.
592 `delta[i]` is defined in Definition 4 of paper.
594 assert isinstance(i
, int) and 0 <= i
< self
.iter_count
596 # `delta[0] = abs(e[0]) + 3 * d[0] / 2`
597 retval
= self
.max_abs_e0
+ Fraction(3, 2) * self
.max_d(0)
599 # `delta[i] = delta[i - 1] ** 2 + f[i - 1]`
600 prev_max_delta
= self
.max_delta(i
- 1)
601 assert prev_max_delta
>= 0
602 retval
= prev_max_delta
** 2 + self
.max_f(i
- 1)
604 # `delta[i]` has to be smaller than one otherwise errors would go off
606 _assert_accuracy(retval
< 1)
608 return self
._shrink
_max
(retval
)
612 """ maximum value of `pi[i]`.
613 `pi[i]` is defined right below Theorem 5 of paper.
615 assert isinstance(i
, int) and 0 <= i
< self
.iter_count
616 # `pi[i] = 1 - (1 - n[i]) * prod`
617 # where `prod` is the product of,
618 # for `j` in `0 <= j < i`, `(1 - n[j]) / (1 + d[j])`
619 min_prod
= Fraction(1)
621 max_n_j
= self
.max_n(j
)
622 max_d_j
= self
.max_d(j
)
623 assert max_n_j
<= 1 and max_d_j
> -1, \
624 "only one quadrant of interval division implemented"
625 min_prod
*= (1 - max_n_j
) / (1 + max_d_j
)
626 max_n_i
= self
.max_n(i
)
627 assert max_n_i
<= 1 and min_prod
>= 0, \
628 "only one quadrant of interval multiplication implemented"
629 retval
= 1 - (1 - max_n_i
) * min_prod
630 return self
._shrink
_max
(retval
)
633 def max_n_shift(self
):
634 """ maximum value of `state.n_shift`.
636 # input numerator is `2*io_width`-bits
637 max_n
= (1 << (self
.io_width
* 2)) - 1
639 # normalize so 1 <= n < 2
647 """ maximum value of, for all `i`, `max_n(i)` and `max_d(i)`
650 for i
in range(self
.iter_count
):
651 n_hat
= max(n_hat
, self
.max_n(i
), self
.max_d(i
))
652 return self
._shrink
_max
(n_hat
)
654 def __make_ops(self
):
655 """ Goldschmidt division algorithm.
658 Even, G., Seidel, P. M., & Ferguson, W. E. (2003).
659 A Parametric Error Analysis of Goldschmidt's Division Algorithm.
660 https://citeseerx.ist.psu.edu/viewdoc/download?doi=10.1.1.90.1238&rep=rep1&type=pdf
662 yields: GoldschmidtDivOp
663 the operations needed to perform the division.
665 # establish assumptions of the paper's error analysis (section 3.1):
667 # 1. normalize so A (numerator) and B (denominator) are in [1, 2)
668 yield GoldschmidtDivOp
.Normalize
670 # 2. ensure all relative errors from directed rounding are <= 1 / 4.
671 # the assumption is met by multipliers with > 4-bits precision
672 _assert_accuracy(self
.expanded_width
> 4)
674 # 3. require `abs(e[0]) + 3 * d[0] / 2 + f[0] < 1 / 2`.
675 _assert_accuracy(self
.max_abs_e0
+ 3 * self
.max_d(0) / 2
676 + self
.max_f(0) < Fraction(1, 2))
678 # 4. the initial approximation F'[-1] of 1/B is in [1/2, 1].
679 # (B is the denominator)
681 for addr
in range(self
.table_addr_count
):
682 f_prime_m1
= self
.table
[addr
]
683 _assert_accuracy(0.5 <= f_prime_m1
<= 1)
685 yield GoldschmidtDivOp
.FEqTableLookup
687 # we use Setting I (section 4.1 of the paper):
688 # Require `n[i] <= n_hat` and `d[i] <= n_hat` and `f[i] = 0`:
689 # the conditions on n_hat are satisfied by construction.
690 for i
in range(self
.iter_count
):
691 _assert_accuracy(self
.max_f(i
) == 0)
692 yield GoldschmidtDivOp
.MulNByF
693 if i
!= self
.iter_count
- 1:
694 yield GoldschmidtDivOp
.MulDByF
695 yield GoldschmidtDivOp
.FEq2MinusD
697 # relative approximation error `p(N_prime[i])`:
698 # `p(N_prime[i]) = (A / B - N_prime[i]) / (A / B)`
699 # `0 <= p(N_prime[i])`
700 # `p(N_prime[i]) <= (2 * i) * n_hat \`
701 # ` + (abs(e[0]) + 3 * n_hat / 2) ** (2 ** i)`
702 i
= self
.iter_count
- 1 # last used `i`
703 # compute power manually to prevent huge intermediate values
704 power
= self
._shrink
_max
(self
.max_abs_e0
+ 3 * self
.n_hat
/ 2)
706 power
= self
._shrink
_max
(power
* power
)
708 max_rel_error
= (2 * i
) * self
.n_hat
+ power
710 min_a_over_b
= Fraction(1, 2)
711 max_a_over_b
= Fraction(2)
712 max_allowed_abs_error
= max_a_over_b
/ (1 << self
.max_n_shift
)
713 max_allowed_rel_error
= max_allowed_abs_error
/ min_a_over_b
715 _assert_accuracy(max_rel_error
< max_allowed_rel_error
,
716 f
"not accurate enough: max_rel_error={max_rel_error}"
717 f
" max_allowed_rel_error={max_allowed_rel_error}")
719 yield GoldschmidtDivOp
.CalcResult
722 def default_cost_fn(self
):
723 """ calculate the estimated cost on an arbitrary scale of implementing
724 goldschmidt division with the specified parameters. larger cost
725 values mean worse parameters.
727 This is the default cost function for `GoldschmidtDivParams.get`.
731 rom_cells
= self
.table_data_bits
<< self
.table_addr_bits
732 cost
= float(rom_cells
)
734 if op
== GoldschmidtDivOp
.MulNByF \
735 or op
== GoldschmidtDivOp
.MulDByF
:
736 mul_cost
= self
.expanded_width
** 2
737 mul_cost
*= self
.expanded_width
.bit_length()
739 cost
+= 5e7
* self
.iter_count
743 @lru_cache(maxsize
=1 << 16)
744 def __cached_new(base_params
):
745 assert isinstance(base_params
, GoldschmidtDivParamsBase
)
746 # can't use dataclasses.asdict, since it's recursive and will also give
747 # child class fields too, which we don't want.
749 for field
in fields(GoldschmidtDivParamsBase
):
750 kwargs
[field
.name
] = getattr(base_params
, field
.name
)
752 return GoldschmidtDivParams(**kwargs
), None
753 except ParamsNotAccurateEnough
as e
:
757 def __raise(e
): # type: (ParamsNotAccurateEnough) -> Any
761 def cached_new(base_params
, handle_error
=__raise
):
762 assert isinstance(base_params
, GoldschmidtDivParamsBase
)
763 params
, error
= GoldschmidtDivParams
.__cached
_new
(base_params
)
767 return handle_error(error
)
770 def get(io_width
, cost_fn
=default_cost_fn
, max_table_addr_bits
=12):
771 """ find efficient parameters for a goldschmidt division algorithm
772 with `params.io_width == io_width`.
776 bit-width of the input divisor and the result.
777 the input numerator is `2 * io_width`-bits wide.
778 cost_fn: Callable[[GoldschmidtDivParams], float]
779 return the estimated cost on an arbitrary scale of implementing
780 goldschmidt division with the specified parameters. larger cost
781 values mean worse parameters.
782 max_table_addr_bits: int
783 maximum allowable value of `table_addr_bits`
785 assert isinstance(io_width
, int) and io_width
>= 1
786 assert callable(cost_fn
)
789 last_error_params
= None
791 def cached_new(base_params
):
793 nonlocal last_error
, last_error_params
795 last_error_params
= base_params
798 retval
= GoldschmidtDivParams
.cached_new(base_params
, handle_error
)
800 logging
.debug(f
"GoldschmidtDivParams.get: err: {base_params}")
802 logging
.debug(f
"GoldschmidtDivParams.get: ok: {base_params}")
805 @lru_cache(maxsize
=None)
806 def get_cost(base_params
):
807 params
= cached_new(base_params
)
810 retval
= cost_fn(params
)
811 logging
.debug(f
"GoldschmidtDivParams.get: cost={retval}: {params}")
814 # start with parameters big enough to always work.
815 initial_extra_precision
= io_width
* 2 + 4
816 initial_params
= GoldschmidtDivParamsBase(
818 extra_precision
=initial_extra_precision
,
819 table_addr_bits
=min(max_table_addr_bits
, io_width
),
820 table_data_bits
=io_width
+ initial_extra_precision
,
821 iter_count
=1 + io_width
.bit_length())
823 if cached_new(initial_params
) is None:
824 raise ValueError(f
"initial goldschmidt division algorithm "
825 f
"parameters are invalid: {initial_params}"
828 # find good initial `iter_count`
829 params
= initial_params
830 for iter_count
in range(1, initial_params
.iter_count
):
831 trial_params
= replace(params
, iter_count
=iter_count
)
832 if cached_new(trial_params
) is not None:
833 params
= trial_params
836 # now find `table_addr_bits`
837 cost
= get_cost(params
)
838 for table_addr_bits
in range(1, max_table_addr_bits
):
839 trial_params
= replace(params
, table_addr_bits
=table_addr_bits
)
840 trial_cost
= get_cost(trial_params
)
841 if trial_cost
< cost
:
842 params
= trial_params
846 # check one higher `iter_count` to see if it has lower cost
847 for table_addr_bits
in range(1, max_table_addr_bits
+ 1):
848 trial_params
= replace(params
,
849 table_addr_bits
=table_addr_bits
,
850 iter_count
=params
.iter_count
+ 1)
851 trial_cost
= get_cost(trial_params
)
852 if trial_cost
< cost
:
853 params
= trial_params
857 # now shrink `table_data_bits`
859 trial_params
= replace(params
,
860 table_data_bits
=params
.table_data_bits
- 1)
861 trial_cost
= get_cost(trial_params
)
862 if trial_cost
< cost
:
863 params
= trial_params
868 # and shrink `extra_precision`
870 trial_params
= replace(params
,
871 extra_precision
=params
.extra_precision
- 1)
872 trial_cost
= get_cost(trial_params
)
873 if trial_cost
< cost
:
874 params
= trial_params
879 return cached_new(params
)
883 class GoldschmidtDivOp(enum
.Enum
):
884 Normalize
= "n, d, n_shift = normalize(n, d)"
885 FEqTableLookup
= "f = table_lookup(d)"
888 FEq2MinusD
= "f = 2 - d"
889 CalcResult
= "result = unnormalize_and_round(n)"
891 def run(self
, params
, state
):
892 assert isinstance(params
, GoldschmidtDivParams
)
893 assert isinstance(state
, GoldschmidtDivState
)
894 expanded_width
= params
.expanded_width
895 table_addr_bits
= params
.table_addr_bits
896 if self
== GoldschmidtDivOp
.Normalize
:
897 # normalize so 1 <= d < 2
898 # can easily be done with count-leading-zeros and left shift
900 state
.n
= (state
.n
* 2).to_frac_wid(expanded_width
)
901 state
.d
= (state
.d
* 2).to_frac_wid(expanded_width
)
904 # normalize so 1 <= n < 2
906 state
.n
= (state
.n
* 0.5).to_frac_wid(expanded_width
)
908 elif self
== GoldschmidtDivOp
.FEqTableLookup
:
909 # compute initial f by table lookup
911 d_m_1
= d_m_1
.to_frac_wid(table_addr_bits
, RoundDir
.DOWN
)
912 assert 0 <= d_m_1
.bits
< (1 << params
.table_addr_bits
)
913 state
.f
= params
.table
[d_m_1
.bits
]
914 elif self
== GoldschmidtDivOp
.MulNByF
:
915 assert state
.f
is not None
916 n
= state
.n
* state
.f
917 state
.n
= n
.to_frac_wid(expanded_width
, round_dir
=RoundDir
.DOWN
)
918 elif self
== GoldschmidtDivOp
.MulDByF
:
919 assert state
.f
is not None
920 d
= state
.d
* state
.f
921 state
.d
= d
.to_frac_wid(expanded_width
, round_dir
=RoundDir
.UP
)
922 elif self
== GoldschmidtDivOp
.FEq2MinusD
:
923 state
.f
= (2 - state
.d
).to_frac_wid(expanded_width
)
924 elif self
== GoldschmidtDivOp
.CalcResult
:
925 assert state
.n_shift
is not None
926 # scale to correct value
927 n
= state
.n
* (1 << state
.n_shift
)
929 state
.quotient
= math
.floor(n
)
930 state
.remainder
= state
.orig_n
- state
.quotient
* state
.orig_d
931 if state
.remainder
>= state
.orig_d
:
933 state
.remainder
-= state
.orig_d
935 assert False, f
"unimplemented GoldschmidtDivOp: {self}"
938 def goldschmidt_div(n
, d
, params
):
939 """ Goldschmidt division algorithm.
942 Even, G., Seidel, P. M., & Ferguson, W. E. (2003).
943 A Parametric Error Analysis of Goldschmidt's Division Algorithm.
944 https://citeseerx.ist.psu.edu/viewdoc/download?doi=10.1.1.90.1238&rep=rep1&type=pdf
948 numerator. a `2*width`-bit unsigned integer.
949 must be less than `d << width`, otherwise the quotient wouldn't
952 denominator. a `width`-bit unsigned integer. must not be zero.
954 the bit-width of the inputs/outputs. must be a positive integer.
956 returns: tuple[int, int]
957 the quotient and remainder. a tuple of two `width`-bit unsigned
960 assert isinstance(params
, GoldschmidtDivParams
)
961 assert isinstance(d
, int) and 0 < d
< (1 << params
.io_width
)
962 assert isinstance(n
, int) and 0 <= n
< (d
<< params
.io_width
)
964 # this whole algorithm is done with fixed-point arithmetic where values
965 # have `width` fractional bits
967 state
= GoldschmidtDivState(
970 n
=FixedPoint(n
, params
.io_width
),
971 d
=FixedPoint(d
, params
.io_width
),
974 for op
in params
.ops
:
975 op
.run(params
, state
)
977 assert state
.quotient
is not None
978 assert state
.remainder
is not None
980 return state
.quotient
, state
.remainder