1 # SPDX-License-Identifier: LGPL-3-or-later
2 # Copyright 2022 Jacob Lifshay programmerjake@gmail.com
4 # Funded by NLnet Assure Programme 2021-02-052, https://nlnet.nl/assure part
5 # of Horizon 2020 EU Programme 957073.
7 from dataclasses
import dataclass
, field
10 from fractions
import Fraction
11 from types
import FunctionType
14 from functools
import cached_property
16 from cached_property
import cached_property
18 # fix broken IDE type detection for cached_property
19 from typing
import TYPE_CHECKING
21 from functools
import cached_property
27 def cache_on_self(func
):
28 """like `functools.cached_property`, except for methods. unlike
29 `lru_cache` the cache is per-class instance rather than a global cache
32 assert isinstance(func
, FunctionType
), \
33 "non-plain methods are not supported"
35 cache_name
= func
.__name
__ + "__cache"
37 def wrapper(self
, *args
, **kwargs
):
38 # specifically access through `__dict__` to bypass frozen=True
39 cache
= self
.__dict
__.get(cache_name
, _NOT_FOUND
)
40 if cache
is _NOT_FOUND
:
41 self
.__dict
__[cache_name
] = cache
= {}
42 key
= (args
, *kwargs
.items())
43 retval
= cache
.get(key
, _NOT_FOUND
)
44 if retval
is _NOT_FOUND
:
45 retval
= func(self
, *args
, **kwargs
)
49 wrapper
.__doc
__ = func
.__doc
__
54 class RoundDir(enum
.Enum
):
57 NEAREST_TIES_UP
= enum
.auto()
58 ERROR_IF_INEXACT
= enum
.auto()
61 @dataclass(frozen
=True)
def __post_init__(self):
    """check the field invariants of a freshly constructed instance.

    `bits` has to be an `int`, and `frac_wid` has to be a non-negative
    `int`. Any violation raises `AssertionError`.
    """
    bits = self.bits
    assert isinstance(bits, int)
    frac_wid = self.frac_wid
    assert isinstance(frac_wid, int)
    assert frac_wid >= 0
72 """convert `value` to a fixed-point number with enough fractional
73 bits to preserve its value."""
74 if isinstance(value
, FixedPoint
):
76 if isinstance(value
, int):
77 return FixedPoint(value
, 0)
78 if isinstance(value
, str):
80 neg
= value
.startswith("-")
81 if neg
or value
.startswith("+"):
83 if value
.startswith(("0x", "0X")) and "." in value
:
93 raise ValueError("too many `.` in string")
98 if not digit
.isalnum():
99 raise ValueError("invalid hexadecimal digit")
101 bits |
= int("0x" + digit
, base
=16)
103 bits
= int(value
, base
=0)
107 return FixedPoint(bits
, frac_wid
)
109 if isinstance(value
, float):
110 n
, d
= value
.as_integer_ratio()
111 log2_d
= d
.bit_length() - 1
112 assert d
== 1 << log2_d
, ("d isn't a power of 2 -- won't ever "
113 "fail with float being IEEE 754")
114 return FixedPoint(n
, log2_d
)
115 raise TypeError("can't convert type to FixedPoint")
118 def with_frac_wid(value
, frac_wid
, round_dir
=RoundDir
.ERROR_IF_INEXACT
):
119 """convert `value` to the nearest fixed-point number with `frac_wid`
120 fractional bits, rounding according to `round_dir`."""
121 assert isinstance(frac_wid
, int) and frac_wid
>= 0
122 assert isinstance(round_dir
, RoundDir
)
123 if isinstance(value
, Fraction
):
124 numerator
= value
.numerator
125 denominator
= value
.denominator
127 value
= FixedPoint
.cast(value
)
128 numerator
= value
.bits
129 denominator
= 1 << value
.frac_wid
131 numerator
= -numerator
132 denominator
= -denominator
133 bits
, remainder
= divmod(numerator
<< frac_wid
, denominator
)
134 if round_dir
== RoundDir
.DOWN
:
136 elif round_dir
== RoundDir
.UP
:
139 elif round_dir
== RoundDir
.NEAREST_TIES_UP
:
140 if remainder
* 2 >= denominator
:
142 elif round_dir
== RoundDir
.ERROR_IF_INEXACT
:
144 raise ValueError("inexact conversion")
146 assert False, "unimplemented round_dir"
147 return FixedPoint(bits
, frac_wid
)
def to_frac_wid(self, frac_wid, round_dir=RoundDir.ERROR_IF_INEXACT):
    """return `self` re-expressed with `frac_wid` fractional bits.

    This is a convenience wrapper that forwards to
    `FixedPoint.with_frac_wid`, with `self` as the value to convert;
    `round_dir` selects how an inexact conversion is handled.
    """
    converted = FixedPoint.with_frac_wid(
        value=self, frac_wid=frac_wid, round_dir=round_dir)
    return converted
155 # use truediv to get correct result even when bits
156 # and frac_wid are huge
157 return float(self
.bits
/ (1 << self
.frac_wid
))
def as_fraction(self):
    """return the exact value as a `Fraction`: `bits / 2 ** frac_wid`."""
    numerator = self.bits
    denominator = 1 << self.frac_wid
    return Fraction(numerator, denominator)
163 """compare self with rhs, returning a positive integer if self is
164 greater than rhs, zero if self is equal to rhs, and a negative integer
165 if self is less than rhs."""
166 rhs
= FixedPoint
.cast(rhs
)
167 common_frac_wid
= max(self
.frac_wid
, rhs
.frac_wid
)
168 lhs
= self
.to_frac_wid(common_frac_wid
)
169 rhs
= rhs
.to_frac_wid(common_frac_wid
)
170 return lhs
.bits
- rhs
.bits
def __eq__(self, rhs):
    """return whether `self` and `rhs` compare equal according to `cmp`.

    If `cmp` raises `TypeError` for `rhs` (an operand it cannot handle),
    `NotImplemented` is returned instead of letting the error escape.
    Python then tries the reflected comparison and finally falls back to
    `False`, so a plain `==` against an unrelated object no longer raises.
    """
    try:
        ordering = self.cmp(rhs)
    except TypeError:
        # unsupported operand type: defer to Python's comparison protocol
        return NotImplemented
    return ordering == 0
def __ne__(self, rhs):
    """return whether `self` and `rhs` compare unequal according to `cmp`.

    If `cmp` raises `TypeError` for `rhs` (an operand it cannot handle),
    `NotImplemented` is returned instead of letting the error escape.
    Python then tries the reflected comparison and finally falls back to
    `True`, so a plain `!=` against an unrelated object no longer raises.
    """
    try:
        ordering = self.cmp(rhs)
    except TypeError:
        # unsupported operand type: defer to Python's comparison protocol
        return NotImplemented
    return ordering != 0
def __gt__(self, rhs):
    """return whether `self` is strictly greater than `rhs`.

    The ordering comes from `self.cmp(rhs)`; a positive result means
    `self` is greater.
    """
    ordering = self.cmp(rhs)
    return ordering > 0
def __lt__(self, rhs):
    """return whether `self` is strictly less than `rhs`.

    The ordering comes from `self.cmp(rhs)`; a negative result means
    `self` is less.
    """
    ordering = self.cmp(rhs)
    return ordering < 0
def __ge__(self, rhs):
    """return whether `self` is greater than or equal to `rhs`.

    The ordering comes from `self.cmp(rhs)`; a non-negative result means
    `self` is not less than `rhs`.
    """
    ordering = self.cmp(rhs)
    return ordering >= 0
def __le__(self, rhs):
    """return whether `self` is less than or equal to `rhs`.

    The ordering comes from `self.cmp(rhs)`; a non-positive result means
    `self` is not greater than `rhs`.
    """
    ordering = self.cmp(rhs)
    return ordering <= 0
191 """return the fractional part of `self`.
192 that is `self - math.floor(self)`.
194 fract_mask
= (1 << self
.frac_wid
) - 1
195 return FixedPoint(self
.bits
& fract_mask
, self
.frac_wid
)
199 return "-" + str(-self
)
201 frac_digit_count
= (self
.frac_wid
+ digit_bits
- 1) // digit_bits
202 fract
= self
.fract().to_frac_wid(frac_digit_count
* digit_bits
)
203 frac_str
= hex(fract
.bits
)[2:].zfill(frac_digit_count
)
204 return hex(math
.floor(self
)) + "." + frac_str
207 return f
"FixedPoint.with_frac_wid({str(self)!r}, {self.frac_wid})"
def __add__(self, rhs):
    """return `self + rhs` as a `FixedPoint`.

    `rhs` is first converted with `FixedPoint.cast`. Both operands are
    then brought to the larger of the two fractional widths, so the raw
    `bits` can be added directly; the result keeps that width.
    """
    other = FixedPoint.cast(rhs)
    frac_wid = max(self.frac_wid, other.frac_wid)
    left_bits = self.to_frac_wid(frac_wid).bits
    right_bits = other.to_frac_wid(frac_wid).bits
    return FixedPoint(left_bits + right_bits, frac_wid)
216 def __radd__(self
, lhs
):
218 return self
.__add
__(lhs
)
221 return FixedPoint(-self
.bits
, self
.frac_wid
)
def __sub__(self, rhs):
    """return `self - rhs` as a `FixedPoint`.

    `rhs` is first converted with `FixedPoint.cast`. Both operands are
    then brought to the larger of the two fractional widths, so the raw
    `bits` can be subtracted directly; the result keeps that width.
    """
    other = FixedPoint.cast(rhs)
    frac_wid = max(self.frac_wid, other.frac_wid)
    left_bits = self.to_frac_wid(frac_wid).bits
    right_bits = other.to_frac_wid(frac_wid).bits
    return FixedPoint(left_bits - right_bits, frac_wid)
230 def __rsub__(self
, lhs
):
232 return -self
.__sub
__(lhs
)
def __mul__(self, rhs):
    """return the exact product `self * rhs` as a `FixedPoint`.

    `rhs` is first converted with `FixedPoint.cast`. The raw `bits` are
    multiplied and the fractional widths are added, so no rounding takes
    place.
    """
    other = FixedPoint.cast(rhs)
    product_bits = self.bits * other.bits
    product_frac_wid = self.frac_wid + other.frac_wid
    return FixedPoint(product_bits, product_frac_wid)
238 def __rmul__(self
, lhs
):
240 return self
.__mul
__(lhs
)
243 return self
.bits
>> self
.frac_wid
247 class GoldschmidtDivState
:
249 """original numerator"""
252 """original denominator"""
255 """numerator -- N_prime[i] in the paper's algorithm 2"""
258 """denominator -- D_prime[i] in the paper's algorithm 2"""
260 f
: "FixedPoint | None" = None
261 """current factor -- F_prime[i] in the paper's algorithm 2"""
263 quotient
: "int | None" = None
266 remainder
: "int | None" = None
267 """final remainder"""
269 n_shift
: "int | None" = None
270 """amount the numerator needs to be left-shifted at the end of the
275 class ParamsNotAccurateEnough(Exception):
276 """raised when the parameters aren't accurate enough to have goldschmidt
280 def _assert_accuracy(condition
, msg
="not accurate enough"):
283 raise ParamsNotAccurateEnough(msg
)
286 @dataclass(frozen
=True, unsafe_hash
=True)
287 class GoldschmidtDivParamsBase
:
288 """parameters for a Goldschmidt division algorithm, excluding derived
293 """bit-width of the input divisor and the result.
294 the input numerator is `2 * io_width`-bits wide.
298 """number of bits of additional precision used inside the algorithm."""
301 """the number of address bits used in the lookup-table."""
304 """the number of data bits used in the lookup-table."""
307 """the total number of iterations of the division algorithm's loop"""
310 @dataclass(frozen
=True, unsafe_hash
=True)
311 class GoldschmidtDivParams(GoldschmidtDivParamsBase
):
312 """parameters for a Goldschmidt division algorithm.
313 Use `GoldschmidtDivParams.get` to find a efficient set of parameters.
316 # tuple to be immutable, default so repr() works for debugging even when
317 # __post_init__ hasn't finished running yet
318 table
: "tuple[FixedPoint, ...]" = field(init
=False, default
=NotImplemented)
319 """the lookup-table"""
321 ops
: "tuple[GoldschmidtDivOp, ...]" = field(init
=False,
322 default
=NotImplemented)
323 """the operations needed to perform the goldschmidt division algorithm."""
def _shrink_bound(self, bound, round_dir):
    """keep the numerator/denominator of a bound from growing without limit.

    `bound` (a `Fraction` or `int`) is rounded in direction `round_dir`
    to a `FixedPoint` with `io_width * 4 + 100` fractional bits and then
    converted back to a `Fraction`. Only `RoundDir.DOWN` and `RoundDir.UP`
    are accepted.

    This is intended only for values used to compute bounds, and not for
    values that end up in the hardware.
    """
    assert isinstance(bound, (Fraction, int))
    is_directed = round_dir is RoundDir.DOWN or round_dir is RoundDir.UP
    assert is_directed, "you shouldn't use that round_dir on bounds"
    # should be enough precision
    frac_wid = self.io_width * 4 + 100
    return FixedPoint.with_frac_wid(bound, frac_wid, round_dir).as_fraction()
def _shrink_min(self, min_bound):
    """shrink a lower bound, rounding down.

    Forwards `min_bound` to `_shrink_bound` with `RoundDir.DOWN`, which
    rounds it to a `FixedPoint` and converts it back to a `Fraction` so
    the numerator/denominator don't become huge.

    This is intended only for values used to compute bounds, and not for
    values that end up in the hardware.
    """
    rounded_down = self._shrink_bound(min_bound, round_dir=RoundDir.DOWN)
    return rounded_down
def _shrink_max(self, max_bound):
    """shrink an upper bound, rounding up.

    Forwards `max_bound` to `_shrink_bound` with `RoundDir.UP`, which
    rounds it to a `FixedPoint` and converts it back to a `Fraction` so
    the numerator/denominator don't become huge.

    This is intended only for values used to compute bounds, and not for
    values that end up in the hardware.
    """
    rounded_up = self._shrink_bound(max_bound, round_dir=RoundDir.UP)
    return rounded_up
def table_addr_count(self):
    """how many distinct addresses the lookup-table has:
    `2 ** table_addr_bits`."""
    # `self.table` is still being built when this is needed, so
    # `len(self.table)` can't be used here
    addr_bits = self.table_addr_bits
    return 1 << addr_bits
def table_input_exact_range(self, addr):
    """return `(min_input, max_input)` as `Fraction`s: the inclusive range
    of inputs that map to the table entry with address `addr`.

    Inputs are multiples of `2 ** -io_width` in `[1, 2)`; each address
    covers `2 ** (io_width - table_addr_bits)` consecutive inputs.
    Raises `ParamsNotAccurateEnough` (through `_assert_accuracy`) when
    `io_width < table_addr_bits`.
    """
    assert isinstance(addr, int)
    assert 0 <= addr < self.table_addr_count
    _assert_accuracy(self.io_width >= self.table_addr_bits)
    shift = self.io_width - self.table_addr_bits
    # `one` is both the fixed-point representation of 1.0 and the common
    # denominator of every input
    one = 1 << self.io_width
    lowest = one + (addr << shift)
    # last input belonging to this address (inclusive)
    highest = lowest + (1 << shift) - 1
    lowest_input = Fraction(lowest, one)
    highest_input = Fraction(highest, one)
    min_input = self._shrink_min(lowest_input)
    max_input = self._shrink_max(highest_input)
    assert 1 <= min_input <= max_input < 2
    return min_input, max_input
def table_value_exact_range(self, addr):
    """return `(min_value, max_value)` as `Fraction`s: the range of exact
    reciprocals of the inputs that map to the table entry with address
    `addr`."""
    lowest_input, highest_input = self.table_input_exact_range(addr)
    # taking the reciprocal reverses the ordering, so the largest input
    # gives the smallest value and vice versa
    smallest = 1 / highest_input
    largest = 1 / lowest_input
    min_value = self._shrink_min(smallest)
    max_value = self._shrink_max(largest)
    assert 0.5 < min_value <= max_value <= 1
    return min_value, max_value
395 def table_exact_value(self
, index
):
396 min_value
, max_value
= self
.table_value_exact_range(index
)
400 def __post_init__(self
):
401 # called by the autogenerated __init__
402 assert self
.io_width
>= 1
403 assert self
.extra_precision
>= 0
404 assert self
.table_addr_bits
>= 1
405 assert self
.table_data_bits
>= 1
406 assert self
.iter_count
>= 1
408 for addr
in range(1 << self
.table_addr_bits
):
409 table
.append(FixedPoint
.with_frac_wid(self
.table_exact_value(addr
),
410 self
.table_data_bits
,
412 # we have to use object.__setattr__ since frozen=True
413 object.__setattr
__(self
, "table", tuple(table
))
414 object.__setattr
__(self
, "ops", tuple(self
.__make
_ops
()))
def expanded_width(self):
    """the total number of bits of precision used inside the algorithm:
    `io_width` plus `extra_precision`."""
    base_bits = self.io_width
    extra_bits = self.extra_precision
    return base_bits + extra_bits
def max_neps(self, i):
    """maximum value of `neps[i]`.
    `neps[i]` is defined to be `n[i] * N_prime[i - 1] * F_prime[i - 1]`.

    `i` must be an `int` in `range(self.iter_count)`. The returned bound is
    `2 ** -expanded_width` and does not depend on `i`.
    """
    assert isinstance(i, int)
    assert 0 <= i < self.iter_count
    denominator = 1 << self.expanded_width
    return Fraction(1, denominator)
def max_deps(self, i):
    """maximum value of `deps[i]`.
    `deps[i]` is defined to be `d[i] * D_prime[i - 1] * F_prime[i - 1]`.

    `i` must be an `int` in `range(self.iter_count)`. The returned bound is
    `2 ** -expanded_width` and does not depend on `i`.
    """
    assert isinstance(i, int)
    assert 0 <= i < self.iter_count
    denominator = 1 << self.expanded_width
    return Fraction(1, denominator)
438 def max_feps(self
, i
):
439 """maximum value of `feps[i]`.
440 `feps[i]` is defined to be `f[i] * (2 - D_prime[i - 1])`.
442 assert isinstance(i
, int) and 0 <= i
< self
.iter_count
443 # zero, because the computation of `F_prime[i]` in
444 # `GoldschmidtDivOp.MulDByF.run(...)` is exact.
449 """minimum and maximum values of `e[0]`
450 (the relative error in `F_prime[-1]`)
454 for addr
in range(self
.table_addr_count
):
455 # `F_prime[-1] = (1 - e[0]) / B`
456 # => `e[0] = 1 - B * F_prime[-1]`
457 min_b
, max_b
= self
.table_input_exact_range(addr
)
458 f_prime_m1
= self
.table
[addr
].as_fraction()
459 assert min_b
>= 0 and f_prime_m1
>= 0, \
460 "only positive quadrant of interval multiplication implemented"
461 min_product
= min_b
* f_prime_m1
462 max_product
= max_b
* f_prime_m1
463 # negation swaps min/max
464 cur_min_e0
= 1 - max_product
465 cur_max_e0
= 1 - min_product
466 min_e0
= min(min_e0
, cur_min_e0
)
467 max_e0
= max(max_e0
, cur_max_e0
)
468 min_e0
= self
._shrink
_min
(min_e0
)
469 max_e0
= self
._shrink
_max
(max_e0
)
470 return min_e0
, max_e0
474 """minimum value of `e[0]` (the relative error in `F_prime[-1]`)
476 min_e0
, max_e0
= self
.e0_range
481 """maximum value of `e[0]` (the relative error in `F_prime[-1]`)
483 min_e0
, max_e0
= self
.e0_range
def max_abs_e0(self):
    """maximum value of `abs(e[0])`: the larger of the magnitudes of
    `self.min_e0` and `self.max_e0`."""
    low_magnitude = abs(self.min_e0)
    high_magnitude = abs(self.max_e0)
    return max(low_magnitude, high_magnitude)
492 def min_abs_e0(self
):
493 """minimum value of `abs(e[0])`."""
498 """maximum value of `n[i]` (the relative error in `N_prime[i]`
499 relative to the previous iteration)
501 assert isinstance(i
, int) and 0 <= i
< self
.iter_count
504 # `n[0] = neps[0] / ((1 - e[0]) * (A / B))`
505 # `n[0] <= 2 * neps[0] / (1 - e[0])`
507 assert self
.max_e0
< 1 and self
.max_neps(0) >= 0, \
508 "only one quadrant of interval division implemented"
509 retval
= 2 * self
.max_neps(0) / (1 - self
.max_e0
)
512 # `n[1] <= neps[1] / ((1 - f[0]) * (1 - pi[0] - delta[0]))`
513 min_mpd
= 1 - self
.max_pi(0) - self
.max_delta(0)
514 assert self
.max_f(0) <= 1 and min_mpd
>= 0, \
515 "only one quadrant of interval multiplication implemented"
516 prod
= (1 - self
.max_f(0)) * min_mpd
517 assert self
.max_neps(1) >= 0 and prod
> 0, \
518 "only one quadrant of interval division implemented"
519 retval
= self
.max_neps(1) / prod
522 # `0 <= n[i] <= 2 * max_neps[i] / (1 - pi[i - 1] - delta[i - 1])`
523 min_mpd
= 1 - self
.max_pi(i
- 1) - self
.max_delta(i
- 1)
524 assert self
.max_neps(i
) >= 0 and min_mpd
> 0, \
525 "only one quadrant of interval division implemented"
526 retval
= self
.max_neps(i
) / min_mpd
528 return self
._shrink
_max
(retval
)
532 """maximum value of `d[i]` (the relative error in `D_prime[i]`
533 relative to the previous iteration)
535 assert isinstance(i
, int) and 0 <= i
< self
.iter_count
538 # `d[0] = deps[0] / (1 - e[0])`
540 assert self
.max_e0
< 1 and self
.max_deps(0) >= 0, \
541 "only one quadrant of interval division implemented"
542 retval
= self
.max_deps(0) / (1 - self
.max_e0
)
545 # `d[1] <= deps[1] / ((1 - f[0]) * (1 - delta[0] ** 2))`
546 assert self
.max_f(0) <= 1 and self
.max_delta(0) <= 1, \
547 "only one quadrant of interval multiplication implemented"
548 divisor
= (1 - self
.max_f(0)) * (1 - self
.max_delta(0) ** 2)
549 assert self
.max_deps(1) >= 0 and divisor
> 0, \
550 "only one quadrant of interval division implemented"
551 retval
= self
.max_deps(1) / divisor
554 # `0 <= d[i] <= max_deps[i] / (1 - delta[i - 1])`
555 assert self
.max_deps(i
) >= 0 and self
.max_delta(i
- 1) < 1, \
556 "only one quadrant of interval division implemented"
557 retval
= self
.max_deps(i
) / (1 - self
.max_delta(i
- 1))
559 return self
._shrink
_max
(retval
)
563 """maximum value of `f[i]` (the relative error in `F_prime[i]`
564 relative to the previous iteration)
566 assert isinstance(i
, int) and 0 <= i
< self
.iter_count
569 # `f[0] = feps[0] / (1 - delta[0])`
571 assert self
.max_delta(0) < 1 and self
.max_feps(0) >= 0, \
572 "only one quadrant of interval division implemented"
573 retval
= self
.max_feps(0) / (1 - self
.max_delta(0))
577 retval
= self
.max_feps(1)
580 # `f[i] <= max_feps[i]`
581 retval
= self
.max_feps(i
)
583 return self
._shrink
_max
(retval
)
586 def max_delta(self
, i
):
587 """ maximum value of `delta[i]`.
588 `delta[i]` is defined in Definition 4 of paper.
590 assert isinstance(i
, int) and 0 <= i
< self
.iter_count
592 # `delta[0] = abs(e[0]) + 3 * d[0] / 2`
593 retval
= self
.max_abs_e0
+ Fraction(3, 2) * self
.max_d(0)
595 # `delta[i] = delta[i - 1] ** 2 + f[i - 1]`
596 prev_max_delta
= self
.max_delta(i
- 1)
597 assert prev_max_delta
>= 0
598 retval
= prev_max_delta
** 2 + self
.max_f(i
- 1)
600 # `delta[i]` has to be smaller than one otherwise errors would go off
602 _assert_accuracy(retval
< 1)
604 return self
._shrink
_max
(retval
)
608 """ maximum value of `pi[i]`.
609 `pi[i]` is defined right below Theorem 5 of paper.
611 assert isinstance(i
, int) and 0 <= i
< self
.iter_count
612 # `pi[i] = 1 - (1 - n[i]) * prod`
613 # where `prod` is the product of,
614 # for `j` in `0 <= j < i`, `(1 - n[j]) / (1 + d[j])`
615 min_prod
= Fraction(1)
617 max_n_j
= self
.max_n(j
)
618 max_d_j
= self
.max_d(j
)
619 assert max_n_j
<= 1 and max_d_j
> -1, \
620 "only one quadrant of interval division implemented"
621 min_prod
*= (1 - max_n_j
) / (1 + max_d_j
)
622 max_n_i
= self
.max_n(i
)
623 assert max_n_i
<= 1 and min_prod
>= 0, \
624 "only one quadrant of interval multiplication implemented"
625 retval
= 1 - (1 - max_n_i
) * min_prod
626 return self
._shrink
_max
(retval
)
629 def max_n_shift(self
):
630 """ maximum value of `state.n_shift`.
632 # input numerator is `2*io_width`-bits
633 max_n
= (1 << (self
.io_width
* 2)) - 1
635 # normalize so 1 <= n < 2
643 """ maximum value of, for all `i`, `max_n(i)` and `max_d(i)`
646 for i
in range(self
.iter_count
):
647 n_hat
= max(n_hat
, self
.max_n(i
), self
.max_d(i
))
648 return self
._shrink
_max
(n_hat
)
650 def __make_ops(self
):
651 """ Goldschmidt division algorithm.
654 Even, G., Seidel, P. M., & Ferguson, W. E. (2003).
655 A Parametric Error Analysis of Goldschmidt's Division Algorithm.
656 https://citeseerx.ist.psu.edu/viewdoc/download?doi=10.1.1.90.1238&rep=rep1&type=pdf
658 yields: GoldschmidtDivOp
659 the operations needed to perform the division.
661 # establish assumptions of the paper's error analysis (section 3.1):
663 # 1. normalize so A (numerator) and B (denominator) are in [1, 2)
664 yield GoldschmidtDivOp
.Normalize
666 # 2. ensure all relative errors from directed rounding are <= 1 / 4.
667 # the assumption is met by multipliers with > 4-bits precision
668 _assert_accuracy(self
.expanded_width
> 4)
670 # 3. require `abs(e[0]) + 3 * d[0] / 2 + f[0] < 1 / 2`.
671 _assert_accuracy(self
.max_abs_e0
+ 3 * self
.max_d(0) / 2
672 + self
.max_f(0) < Fraction(1, 2))
674 # 4. the initial approximation F'[-1] of 1/B is in [1/2, 1].
675 # (B is the denominator)
677 for addr
in range(self
.table_addr_count
):
678 f_prime_m1
= self
.table
[addr
]
679 _assert_accuracy(0.5 <= f_prime_m1
<= 1)
681 yield GoldschmidtDivOp
.FEqTableLookup
683 # we use Setting I (section 4.1 of the paper):
684 # Require `n[i] <= n_hat` and `d[i] <= n_hat` and `f[i] = 0`:
685 # the conditions on n_hat are satisfied by construction.
686 for i
in range(self
.iter_count
):
687 _assert_accuracy(self
.max_f(i
) == 0)
688 yield GoldschmidtDivOp
.MulNByF
689 if i
!= self
.iter_count
- 1:
690 yield GoldschmidtDivOp
.MulDByF
691 yield GoldschmidtDivOp
.FEq2MinusD
693 # relative approximation error `p(N_prime[i])`:
694 # `p(N_prime[i]) = (A / B - N_prime[i]) / (A / B)`
695 # `0 <= p(N_prime[i])`
696 # `p(N_prime[i]) <= (2 * i) * n_hat \`
697 # ` + (abs(e[0]) + 3 * n_hat / 2) ** (2 ** i)`
698 i
= self
.iter_count
- 1 # last used `i`
699 # compute power manually to prevent huge intermediate values
700 power
= self
._shrink
_max
(self
.max_abs_e0
+ 3 * self
.n_hat
/ 2)
702 power
= self
._shrink
_max
(power
* power
)
704 max_rel_error
= (2 * i
) * self
.n_hat
+ power
706 min_a_over_b
= Fraction(1, 2)
707 max_a_over_b
= Fraction(2)
708 max_allowed_abs_error
= max_a_over_b
/ (1 << self
.max_n_shift
)
709 max_allowed_rel_error
= max_allowed_abs_error
/ min_a_over_b
711 _assert_accuracy(max_rel_error
< max_allowed_rel_error
,
712 f
"not accurate enough: max_rel_error={max_rel_error}"
713 f
" max_allowed_rel_error={max_allowed_rel_error}")
715 yield GoldschmidtDivOp
.CalcResult
717 def default_cost_fn(self
):
718 """ calculate the estimated cost on an arbitrary scale of implementing
719 goldschmidt division with the specified parameters. larger cost
720 values mean worse parameters.
722 This is the default cost function for `GoldschmidtDivParams.get`.
726 rom_cells
= self
.table_data_bits
<< self
.table_addr_bits
727 cost
= float(rom_cells
)
729 if op
== GoldschmidtDivOp
.MulNByF \
730 or op
== GoldschmidtDivOp
.MulDByF
:
731 mul_cost
= self
.expanded_width
** 2
732 mul_cost
*= self
.expanded_width
.bit_length()
734 cost
+= 1e6
* self
.iter_count
739 """ find efficient parameters for a goldschmidt division algorithm
740 with `params.io_width == io_width`.
742 assert isinstance(io_width
, int) and io_width
>= 1
745 for extra_precision
in range(io_width
* 2 + 4):
746 for table_addr_bits
in range(1, 7 + 1):
747 table_data_bits
= io_width
+ extra_precision
748 for iter_count
in range(1, 2 * io_width
.bit_length()):
750 return GoldschmidtDivParams(
752 extra_precision
=extra_precision
,
753 table_addr_bits
=table_addr_bits
,
754 table_data_bits
=table_data_bits
,
755 iter_count
=iter_count
)
756 except ParamsNotAccurateEnough
as e
:
757 last_params
= (f
"GoldschmidtDivParams("
758 f
"io_width={io_width!r}, "
759 f
"extra_precision={extra_precision!r}, "
760 f
"table_addr_bits={table_addr_bits!r}, "
761 f
"table_data_bits={table_data_bits!r}, "
762 f
"iter_count={iter_count!r})")
764 raise ValueError(f
"can't find working parameters for a goldschmidt "
765 f
"division algorithm: last params: {last_params}"
770 class GoldschmidtDivOp(enum
.Enum
):
771 Normalize
= "n, d, n_shift = normalize(n, d)"
772 FEqTableLookup
= "f = table_lookup(d)"
775 FEq2MinusD
= "f = 2 - d"
776 CalcResult
= "result = unnormalize_and_round(n)"
778 def run(self
, params
, state
):
779 assert isinstance(params
, GoldschmidtDivParams
)
780 assert isinstance(state
, GoldschmidtDivState
)
781 expanded_width
= params
.expanded_width
782 table_addr_bits
= params
.table_addr_bits
783 if self
== GoldschmidtDivOp
.Normalize
:
784 # normalize so 1 <= d < 2
785 # can easily be done with count-leading-zeros and left shift
787 state
.n
= (state
.n
* 2).to_frac_wid(expanded_width
)
788 state
.d
= (state
.d
* 2).to_frac_wid(expanded_width
)
791 # normalize so 1 <= n < 2
793 state
.n
= (state
.n
* 0.5).to_frac_wid(expanded_width
)
795 elif self
== GoldschmidtDivOp
.FEqTableLookup
:
796 # compute initial f by table lookup
798 d_m_1
= d_m_1
.to_frac_wid(table_addr_bits
, RoundDir
.DOWN
)
799 assert 0 <= d_m_1
.bits
< (1 << params
.table_addr_bits
)
800 state
.f
= params
.table
[d_m_1
.bits
]
801 elif self
== GoldschmidtDivOp
.MulNByF
:
802 assert state
.f
is not None
803 n
= state
.n
* state
.f
804 state
.n
= n
.to_frac_wid(expanded_width
, round_dir
=RoundDir
.DOWN
)
805 elif self
== GoldschmidtDivOp
.MulDByF
:
806 assert state
.f
is not None
807 d
= state
.d
* state
.f
808 state
.d
= d
.to_frac_wid(expanded_width
, round_dir
=RoundDir
.UP
)
809 elif self
== GoldschmidtDivOp
.FEq2MinusD
:
810 state
.f
= (2 - state
.d
).to_frac_wid(expanded_width
)
811 elif self
== GoldschmidtDivOp
.CalcResult
:
812 assert state
.n_shift
is not None
813 # scale to correct value
814 n
= state
.n
* (1 << state
.n_shift
)
816 state
.quotient
= math
.floor(n
)
817 state
.remainder
= state
.orig_n
- state
.quotient
* state
.orig_d
818 if state
.remainder
>= state
.orig_d
:
820 state
.remainder
-= state
.orig_d
822 assert False, f
"unimplemented GoldschmidtDivOp: {self}"
825 def goldschmidt_div(n
, d
, params
):
826 """ Goldschmidt division algorithm.
829 Even, G., Seidel, P. M., & Ferguson, W. E. (2003).
830 A Parametric Error Analysis of Goldschmidt's Division Algorithm.
831 https://citeseerx.ist.psu.edu/viewdoc/download?doi=10.1.1.90.1238&rep=rep1&type=pdf
835 numerator. a `2*width`-bit unsigned integer.
836 must be less than `d << width`, otherwise the quotient wouldn't
839 denominator. a `width`-bit unsigned integer. must not be zero.
841 the bit-width of the inputs/outputs. must be a positive integer.
843 returns: tuple[int, int]
844 the quotient and remainder. a tuple of two `width`-bit unsigned
847 assert isinstance(params
, GoldschmidtDivParams
)
848 assert isinstance(d
, int) and 0 < d
< (1 << params
.io_width
)
849 assert isinstance(n
, int) and 0 <= n
< (d
<< params
.io_width
)
851 # this whole algorithm is done with fixed-point arithmetic where values
852 # have `width` fractional bits
854 state
= GoldschmidtDivState(
857 n
=FixedPoint(n
, params
.io_width
),
858 d
=FixedPoint(d
, params
.io_width
),
861 for op
in params
.ops
:
862 op
.run(params
, state
)
864 assert state
.quotient
is not None
865 assert state
.remainder
is not None
867 return state
.quotient
, state
.remainder