1 # SPDX-License-Identifier: LGPL-3-or-later
2 # Copyright 2022 Jacob Lifshay programmerjake@gmail.com
4 # Funded by NLnet Assure Programme 2021-02-052, https://nlnet.nl/assure part
5 # of Horizon 2020 EU Programme 957073.
7 from dataclasses
import dataclass
, field
10 from fractions
import Fraction
11 from types
import FunctionType
14 from functools
import cached_property
16 from cached_property
import cached_property
18 # fix broken IDE type detection for cached_property
19 from typing
import TYPE_CHECKING
21 from functools
import cached_property
27 def cache_on_self(func
):
28 """like `functools.cached_property`, except for methods. unlike
29 `lru_cache` the cache is per-class instance rather than a global cache
32 assert isinstance(func
, FunctionType
), \
33 "non-plain methods are not supported"
35 cache_name
= func
.__name
__ + "__cache"
37 def wrapper(self
, *args
, **kwargs
):
38 # specifically access through `__dict__` to bypass frozen=True
39 cache
= self
.__dict
__.get(cache_name
, _NOT_FOUND
)
40 if cache
is _NOT_FOUND
:
41 self
.__dict
__[cache_name
] = cache
= {}
42 key
= (args
, *kwargs
.items())
43 retval
= cache
.get(key
, _NOT_FOUND
)
44 if retval
is _NOT_FOUND
:
45 retval
= func(self
, *args
, **kwargs
)
49 wrapper
.__doc
__ = func
.__doc
__
54 class RoundDir(enum
.Enum
):
57 NEAREST_TIES_UP
= enum
.auto()
58 ERROR_IF_INEXACT
= enum
.auto()
61 @dataclass(frozen
=True)
def __post_init__(self):
    """sanity-check the fields after the generated `__init__` runs.

    `bits` must be an `int`; `frac_wid` (number of fractional bits) must be
    a non-negative `int`. Violations raise `AssertionError`.
    """
    assert isinstance(self.bits, int)
    assert isinstance(self.frac_wid, int) and not self.frac_wid < 0
72 """convert `value` to a fixed-point number with enough fractional
73 bits to preserve its value."""
74 if isinstance(value
, FixedPoint
):
76 if isinstance(value
, int):
77 return FixedPoint(value
, 0)
78 if isinstance(value
, str):
80 neg
= value
.startswith("-")
81 if neg
or value
.startswith("+"):
83 if value
.startswith(("0x", "0X")) and "." in value
:
93 raise ValueError("too many `.` in string")
98 if not digit
.isalnum():
99 raise ValueError("invalid hexadecimal digit")
101 bits |
= int("0x" + digit
, base
=16)
103 bits
= int(value
, base
=0)
107 return FixedPoint(bits
, frac_wid
)
109 if isinstance(value
, float):
110 n
, d
= value
.as_integer_ratio()
111 log2_d
= d
.bit_length() - 1
112 assert d
== 1 << log2_d
, ("d isn't a power of 2 -- won't ever "
113 "fail with float being IEEE 754")
114 return FixedPoint(n
, log2_d
)
115 raise TypeError("can't convert type to FixedPoint")
118 def with_frac_wid(value
, frac_wid
, round_dir
=RoundDir
.ERROR_IF_INEXACT
):
119 """convert `value` to the nearest fixed-point number with `frac_wid`
120 fractional bits, rounding according to `round_dir`."""
121 assert isinstance(frac_wid
, int) and frac_wid
>= 0
122 assert isinstance(round_dir
, RoundDir
)
123 if isinstance(value
, Fraction
):
124 numerator
= value
.numerator
125 denominator
= value
.denominator
127 value
= FixedPoint
.cast(value
)
128 numerator
= value
.bits
129 denominator
= 1 << value
.frac_wid
131 numerator
= -numerator
132 denominator
= -denominator
133 bits
, remainder
= divmod(numerator
<< frac_wid
, denominator
)
134 if round_dir
== RoundDir
.DOWN
:
136 elif round_dir
== RoundDir
.UP
:
139 elif round_dir
== RoundDir
.NEAREST_TIES_UP
:
140 if remainder
* 2 >= denominator
:
142 elif round_dir
== RoundDir
.ERROR_IF_INEXACT
:
144 raise ValueError("inexact conversion")
146 assert False, "unimplemented round_dir"
147 return FixedPoint(bits
, frac_wid
)
def to_frac_wid(self, frac_wid, round_dir=RoundDir.ERROR_IF_INEXACT):
    """return `self` re-expressed with exactly `frac_wid` fractional bits.

    This delegates to `FixedPoint.with_frac_wid`. `round_dir` selects how
    low-order bits that don't fit are handled. With the default,
    `RoundDir.ERROR_IF_INEXACT`, `with_frac_wid` raises
    `ValueError("inexact conversion")` rather than rounding.
    """
    converted = FixedPoint.with_frac_wid(self, frac_wid, round_dir)
    return converted
155 # use truediv to get correct result even when bits
156 # and frac_wid are huge
157 return float(self
.bits
/ (1 << self
.frac_wid
))
def as_fraction(self):
    """return the exact value of `self`, i.e. `bits / 2 ** frac_wid`, as a
    `Fraction`."""
    scale = 1 << self.frac_wid
    return Fraction(self.bits, scale)
163 """compare self with rhs, returning a positive integer if self is
164 greater than rhs, zero if self is equal to rhs, and a negative integer
165 if self is less than rhs."""
166 rhs
= FixedPoint
.cast(rhs
)
167 common_frac_wid
= max(self
.frac_wid
, rhs
.frac_wid
)
168 lhs
= self
.to_frac_wid(common_frac_wid
)
169 rhs
= rhs
.to_frac_wid(common_frac_wid
)
170 return lhs
.bits
- rhs
.bits
def _cmp_or_not_implemented(self, rhs):
    """like `self.cmp(rhs)`, but return `NotImplemented` when `rhs` is of a
    type `FixedPoint.cast` can't convert (`cast` raises `TypeError` then).

    Returning `NotImplemented` from a rich comparison lets Python try the
    reflected operation on `rhs`. If that also declines, `==`/`!=` fall
    back to an identity comparison and the ordering operators raise
    `TypeError`, as the data model expects. Errors `cast` raises for
    values of a supported type (e.g. a malformed string) still propagate.
    """
    try:
        rhs = FixedPoint.cast(rhs)
    except TypeError:
        return NotImplemented
    return self.cmp(rhs)

def __eq__(self, rhs):
    c = self._cmp_or_not_implemented(rhs)
    return c if c is NotImplemented else c == 0

def __ne__(self, rhs):
    c = self._cmp_or_not_implemented(rhs)
    return c if c is NotImplemented else c != 0

def __gt__(self, rhs):
    c = self._cmp_or_not_implemented(rhs)
    return c if c is NotImplemented else c > 0

def __lt__(self, rhs):
    c = self._cmp_or_not_implemented(rhs)
    return c if c is NotImplemented else c < 0

def __ge__(self, rhs):
    c = self._cmp_or_not_implemented(rhs)
    return c if c is NotImplemented else c >= 0

def __le__(self, rhs):
    c = self._cmp_or_not_implemented(rhs)
    return c if c is NotImplemented else c <= 0

def __hash__(self):
    """hash of the exact numeric value, consistent with `__eq__`."""
    # NOTE(review): assuming the `@dataclass(frozen=True)` above
    # `__post_init__` decorates this class, dataclass would otherwise
    # generate a hash of the raw `(bits, frac_wid)` fields. Values that
    # compare equal, e.g. `FixedPoint(1, 0)` and `FixedPoint(2, 1)`, would
    # then hash differently.
    # Hashing the exact `Fraction` keeps equal values' hashes equal. It also
    # agrees with `hash()` of equal `int`/`float` values, because Python
    # numeric hashes are value-based.
    # Equality with `str` operands (which `cast` accepts) can't be made
    # hash-consistent.
    return hash(self.as_fraction())
191 """return the fractional part of `self`.
192 that is `self - math.floor(self)`.
194 fract_mask
= (1 << self
.frac_wid
) - 1
195 return FixedPoint(self
.bits
& fract_mask
, self
.frac_wid
)
199 return "-" + str(-self
)
201 frac_digit_count
= (self
.frac_wid
+ digit_bits
- 1) // digit_bits
202 fract
= self
.fract().to_frac_wid(frac_digit_count
* digit_bits
)
203 frac_str
= hex(fract
.bits
)[2:].zfill(frac_digit_count
)
204 return hex(math
.floor(self
)) + "." + frac_str
207 return f
"FixedPoint.with_frac_wid({str(self)!r}, {self.frac_wid})"
def __add__(self, rhs):
    """return `self + rhs` exactly.

    `rhs` is converted with `FixedPoint.cast`. Both operands are widened to
    the larger of the two fractional widths, which only shifts bits left
    and so never loses precision. The raw bits are then added.
    """
    other = FixedPoint.cast(rhs)
    wid = max(self.frac_wid, other.frac_wid)
    total = self.to_frac_wid(wid).bits + other.to_frac_wid(wid).bits
    return FixedPoint(total, wid)
216 def __radd__(self
, lhs
):
218 return self
.__add
__(lhs
)
221 return FixedPoint(-self
.bits
, self
.frac_wid
)
def __sub__(self, rhs):
    """return `self - rhs` exactly.

    `rhs` is converted with `FixedPoint.cast`. Both operands are widened to
    the larger of the two fractional widths so no bits are lost, and then
    the raw bits are subtracted.
    """
    other = FixedPoint.cast(rhs)
    wid = max(self.frac_wid, other.frac_wid)
    difference = self.to_frac_wid(wid).bits - other.to_frac_wid(wid).bits
    return FixedPoint(difference, wid)
230 def __rsub__(self
, lhs
):
232 return -self
.__sub
__(lhs
)
def __mul__(self, rhs):
    """return `self * rhs` exactly.

    The raw bit patterns multiply and the fractional widths add, since
    `(a / 2**x) * (b / 2**y) == (a * b) / 2**(x + y)`.
    """
    other = FixedPoint.cast(rhs)
    product_bits = self.bits * other.bits
    product_frac_wid = self.frac_wid + other.frac_wid
    return FixedPoint(product_bits, product_frac_wid)
238 def __rmul__(self
, lhs
):
240 return self
.__mul
__(lhs
)
243 return self
.bits
>> self
.frac_wid
247 class GoldschmidtDivState
:
249 """original numerator"""
252 """original denominator"""
255 """numerator -- N_prime[i] in the paper's algorithm 2"""
258 """denominator -- D_prime[i] in the paper's algorithm 2"""
260 f
: "FixedPoint | None" = None
261 """current factor -- F_prime[i] in the paper's algorithm 2"""
263 quotient
: "int | None" = None
266 remainder
: "int | None" = None
267 """final remainder"""
269 n_shift
: "int | None" = None
270 """amount the numerator needs to be left-shifted at the end of the
275 class ParamsNotAccurateEnough(Exception):
276 """raised when the parameters aren't accurate enough to have goldschmidt
280 def _assert_accuracy(condition
, msg
="not accurate enough"):
283 raise ParamsNotAccurateEnough(msg
)
286 @dataclass(frozen
=True, unsafe_hash
=True)
287 class GoldschmidtDivParams
:
288 """parameters for a Goldschmidt division algorithm.
289 Use `GoldschmidtDivParams.get` to find a efficient set of parameters.
293 """bit-width of the input divisor and the result.
294 the input numerator is `2 * io_width`-bits wide.
298 """number of bits of additional precision used inside the algorithm."""
301 """the number of address bits used in the lookup-table."""
304 """the number of data bits used in the lookup-table."""
307 """the total number of iterations of the division algorithm's loop"""
309 # tuple to be immutable, default so repr() works for debugging even when
310 # __post_init__ hasn't finished running yet
311 table
: "tuple[FixedPoint, ...]" = field(init
=False, default
=NotImplemented)
312 """the lookup-table"""
314 ops
: "tuple[GoldschmidtDivOp, ...]" = field(init
=False,
315 default
=NotImplemented)
316 """the operations needed to perform the goldschmidt division algorithm."""
def _shrink_bound(self, bound, round_dir):
    """keep a bound's numerator/denominator small by rounding it to a
    `FixedPoint` and converting the result back to a `Fraction`.

    `round_dir` must be `RoundDir.DOWN` (minimum bounds) or `RoundDir.UP`
    (maximum bounds), so the rounded value still bounds the original.

    Only meant for values used to compute bounds, never for values that end
    up in the hardware.
    """
    assert isinstance(bound, (Fraction, int))
    directed = round_dir is RoundDir.DOWN or round_dir is RoundDir.UP
    assert directed, "you shouldn't use that round_dir on bounds"
    # generously more fractional bits than the io width; assumed to be
    # enough precision for the bound computations
    precision = self.io_width * 4 + 100
    rounded = FixedPoint.with_frac_wid(bound, precision, round_dir)
    return rounded.as_fraction()
def _shrink_min(self, min_bound):
    """shrink a minimum bound's numerator/denominator by rounding it *down*
    to a `FixedPoint` and back to a `Fraction`.

    Rounding down keeps it a valid minimum. Only meant for bound
    computations, not for values that end up in the hardware.
    """
    direction = RoundDir.DOWN
    return self._shrink_bound(min_bound, direction)
def _shrink_max(self, max_bound):
    """shrink a maximum bound's numerator/denominator by rounding it *up*
    to a `FixedPoint` and back to a `Fraction`.

    Rounding up keeps it a valid maximum. Only meant for bound
    computations, not for values that end up in the hardware.
    """
    direction = RoundDir.UP
    return self._shrink_bound(max_bound, direction)
def table_addr_count(self):
    """how many distinct addresses the lookup-table has
    (`2 ** table_addr_bits`)."""
    # derived from `table_addr_bits` instead of `len(self.table)` because
    # this value is needed while `self.table` is being built
    addr_bits = self.table_addr_bits
    return 1 << addr_bits
def table_input_exact_range(self, addr):
    """return `(min_input, max_input)`: the range, as `Fraction`s, of the
    inputs that use table entry `addr`.

    Inputs are `io_width`-fractional-bit values in `[1, 2)`. Entry `addr`
    covers the `2 ** (io_width - table_addr_bits)` consecutive values
    starting at `1 + addr / 2 ** table_addr_bits`. The endpoints are
    widened outward with `_shrink_min`/`_shrink_max`.
    """
    assert isinstance(addr, int)
    assert 0 <= addr < self.table_addr_count
    _assert_accuracy(self.io_width >= self.table_addr_bits)
    shift = self.io_width - self.table_addr_bits
    one = 1 << self.io_width  # 1.0 with `io_width` fractional bits
    first = one + (addr << shift)
    last = first + (1 << shift) - 1  # inclusive upper end of the entry
    lo = self._shrink_min(Fraction(first, one))
    hi = self._shrink_max(Fraction(last, one))
    assert 1 <= lo <= hi < 2
    return lo, hi
def table_value_exact_range(self, addr):
    """return `(min_value, max_value)`: the range, as `Fraction`s, of the
    reciprocals `1 / input` over the inputs that use table entry `addr`."""
    lo_in, hi_in = self.table_input_exact_range(addr)
    # 1/x is decreasing for positive x: the largest input gives the
    # smallest reciprocal and vice versa
    lo = self._shrink_min(1 / hi_in)
    hi = self._shrink_max(1 / lo_in)
    assert 0.5 < lo <= hi <= 1
    return lo, hi
388 def table_exact_value(self
, index
):
389 min_value
, max_value
= self
.table_value_exact_range(index
)
393 def __post_init__(self
):
394 # called by the autogenerated __init__
395 assert self
.io_width
>= 1
396 assert self
.extra_precision
>= 0
397 assert self
.table_addr_bits
>= 1
398 assert self
.table_data_bits
>= 1
399 assert self
.iter_count
>= 1
401 for addr
in range(1 << self
.table_addr_bits
):
402 table
.append(FixedPoint
.with_frac_wid(self
.table_exact_value(addr
),
403 self
.table_data_bits
,
405 # we have to use object.__setattr__ since frozen=True
406 object.__setattr
__(self
, "table", tuple(table
))
407 object.__setattr
__(self
, "ops", tuple(self
.__make
_ops
()))
def expanded_width(self):
    """total bits of precision used inside the algorithm: the I/O width plus
    the extra precision bits."""
    io_bits = self.io_width
    extra_bits = self.extra_precision
    return io_bits + extra_bits
def max_neps(self, i):
    """upper bound on `neps[i]`, where
    `neps[i]` is defined to be `n[i] * N_prime[i - 1] * F_prime[i - 1]`.

    The same bound, `1 / 2 ** expanded_width`, is used for every valid `i`.
    """
    assert isinstance(i, int) and 0 <= i < self.iter_count
    ulp = Fraction(1, 1 << self.expanded_width)
    return ulp
def max_deps(self, i):
    """upper bound on `deps[i]`, where
    `deps[i]` is defined to be `d[i] * D_prime[i - 1] * F_prime[i - 1]`.

    The same bound, `1 / 2 ** expanded_width`, is used for every valid `i`.
    """
    assert isinstance(i, int) and 0 <= i < self.iter_count
    ulp = Fraction(1, 1 << self.expanded_width)
    return ulp
431 def max_feps(self
, i
):
432 """maximum value of `feps[i]`.
433 `feps[i]` is defined to be `f[i] * (2 - D_prime[i - 1])`.
435 assert isinstance(i
, int) and 0 <= i
< self
.iter_count
436 # zero, because the computation of `F_prime[i]` in
437 # `GoldschmidtDivOp.MulDByF.run(...)` is exact.
442 """minimum and maximum values of `e[0]`
443 (the relative error in `F_prime[-1]`)
447 for addr
in range(self
.table_addr_count
):
448 # `F_prime[-1] = (1 - e[0]) / B`
449 # => `e[0] = 1 - B * F_prime[-1]`
450 min_b
, max_b
= self
.table_input_exact_range(addr
)
451 f_prime_m1
= self
.table
[addr
].as_fraction()
452 assert min_b
>= 0 and f_prime_m1
>= 0, \
453 "only positive quadrant of interval multiplication implemented"
454 min_product
= min_b
* f_prime_m1
455 max_product
= max_b
* f_prime_m1
456 # negation swaps min/max
457 cur_min_e0
= 1 - max_product
458 cur_max_e0
= 1 - min_product
459 min_e0
= min(min_e0
, cur_min_e0
)
460 max_e0
= max(max_e0
, cur_max_e0
)
461 min_e0
= self
._shrink
_min
(min_e0
)
462 max_e0
= self
._shrink
_max
(max_e0
)
463 return min_e0
, max_e0
467 """minimum value of `e[0]` (the relative error in `F_prime[-1]`)
469 min_e0
, max_e0
= self
.e0_range
474 """maximum value of `e[0]` (the relative error in `F_prime[-1]`)
476 min_e0
, max_e0
= self
.e0_range
def max_abs_e0(self):
    """upper bound on `abs(e[0])`, where `e[0]` is the relative error in
    `F_prime[-1]`.

    `e[0]` lies in `[min_e0, max_e0]`, so its largest magnitude is at one of
    the two endpoints.
    """
    magnitudes = (abs(self.min_e0), abs(self.max_e0))
    return max(magnitudes)
485 def min_abs_e0(self
):
486 """minimum value of `abs(e[0])`."""
491 """maximum value of `n[i]` (the relative error in `N_prime[i]`
492 relative to the previous iteration)
494 assert isinstance(i
, int) and 0 <= i
< self
.iter_count
497 # `n[0] = neps[0] / ((1 - e[0]) * (A / B))`
498 # `n[0] <= 2 * neps[0] / (1 - e[0])`
500 assert self
.max_e0
< 1 and self
.max_neps(0) >= 0, \
501 "only one quadrant of interval division implemented"
502 retval
= 2 * self
.max_neps(0) / (1 - self
.max_e0
)
505 # `n[1] <= neps[1] / ((1 - f[0]) * (1 - pi[0] - delta[0]))`
506 min_mpd
= 1 - self
.max_pi(0) - self
.max_delta(0)
507 assert self
.max_f(0) <= 1 and min_mpd
>= 0, \
508 "only one quadrant of interval multiplication implemented"
509 prod
= (1 - self
.max_f(0)) * min_mpd
510 assert self
.max_neps(1) >= 0 and prod
> 0, \
511 "only one quadrant of interval division implemented"
512 retval
= self
.max_neps(1) / prod
515 # `0 <= n[i] <= 2 * max_neps[i] / (1 - pi[i - 1] - delta[i - 1])`
516 min_mpd
= 1 - self
.max_pi(i
- 1) - self
.max_delta(i
- 1)
517 assert self
.max_neps(i
) >= 0 and min_mpd
> 0, \
518 "only one quadrant of interval division implemented"
519 retval
= self
.max_neps(i
) / min_mpd
521 return self
._shrink
_max
(retval
)
525 """maximum value of `d[i]` (the relative error in `D_prime[i]`
526 relative to the previous iteration)
528 assert isinstance(i
, int) and 0 <= i
< self
.iter_count
531 # `d[0] = deps[0] / (1 - e[0])`
533 assert self
.max_e0
< 1 and self
.max_deps(0) >= 0, \
534 "only one quadrant of interval division implemented"
535 retval
= self
.max_deps(0) / (1 - self
.max_e0
)
538 # `d[1] <= deps[1] / ((1 - f[0]) * (1 - delta[0] ** 2))`
539 assert self
.max_f(0) <= 1 and self
.max_delta(0) <= 1, \
540 "only one quadrant of interval multiplication implemented"
541 divisor
= (1 - self
.max_f(0)) * (1 - self
.max_delta(0) ** 2)
542 assert self
.max_deps(1) >= 0 and divisor
> 0, \
543 "only one quadrant of interval division implemented"
544 retval
= self
.max_deps(1) / divisor
547 # `0 <= d[i] <= max_deps[i] / (1 - delta[i - 1])`
548 assert self
.max_deps(i
) >= 0 and self
.max_delta(i
- 1) < 1, \
549 "only one quadrant of interval division implemented"
550 retval
= self
.max_deps(i
) / (1 - self
.max_delta(i
- 1))
552 return self
._shrink
_max
(retval
)
556 """maximum value of `f[i]` (the relative error in `F_prime[i]`
557 relative to the previous iteration)
559 assert isinstance(i
, int) and 0 <= i
< self
.iter_count
562 # `f[0] = feps[0] / (1 - delta[0])`
564 assert self
.max_delta(0) < 1 and self
.max_feps(0) >= 0, \
565 "only one quadrant of interval division implemented"
566 retval
= self
.max_feps(0) / (1 - self
.max_delta(0))
570 retval
= self
.max_feps(1)
573 # `f[i] <= max_feps[i]`
574 retval
= self
.max_feps(i
)
576 return self
._shrink
_max
(retval
)
579 def max_delta(self
, i
):
580 """ maximum value of `delta[i]`.
581 `delta[i]` is defined in Definition 4 of paper.
583 assert isinstance(i
, int) and 0 <= i
< self
.iter_count
585 # `delta[0] = abs(e[0]) + 3 * d[0] / 2`
586 retval
= self
.max_abs_e0
+ Fraction(3, 2) * self
.max_d(0)
588 # `delta[i] = delta[i - 1] ** 2 + f[i - 1]`
589 prev_max_delta
= self
.max_delta(i
- 1)
590 assert prev_max_delta
>= 0
591 retval
= prev_max_delta
** 2 + self
.max_f(i
- 1)
593 # `delta[i]` has to be smaller than one otherwise errors would go off
595 _assert_accuracy(retval
< 1)
597 return self
._shrink
_max
(retval
)
601 """ maximum value of `pi[i]`.
602 `pi[i]` is defined right below Theorem 5 of paper.
604 assert isinstance(i
, int) and 0 <= i
< self
.iter_count
605 # `pi[i] = 1 - (1 - n[i]) * prod`
606 # where `prod` is the product of,
607 # for `j` in `0 <= j < i`, `(1 - n[j]) / (1 + d[j])`
608 min_prod
= Fraction(1)
610 max_n_j
= self
.max_n(j
)
611 max_d_j
= self
.max_d(j
)
612 assert max_n_j
<= 1 and max_d_j
> -1, \
613 "only one quadrant of interval division implemented"
614 min_prod
*= (1 - max_n_j
) / (1 + max_d_j
)
615 max_n_i
= self
.max_n(i
)
616 assert max_n_i
<= 1 and min_prod
>= 0, \
617 "only one quadrant of interval multiplication implemented"
618 retval
= 1 - (1 - max_n_i
) * min_prod
619 return self
._shrink
_max
(retval
)
622 def max_n_shift(self
):
623 """ maximum value of `state.n_shift`.
625 # input numerator is `2*io_width`-bits
626 max_n
= (1 << (self
.io_width
* 2)) - 1
628 # normalize so 1 <= n < 2
636 """ maximum value of, for all `i`, `max_n(i)` and `max_d(i)`
639 for i
in range(self
.iter_count
):
640 n_hat
= max(n_hat
, self
.max_n(i
), self
.max_d(i
))
641 return self
._shrink
_max
(n_hat
)
643 def __make_ops(self
):
644 """ Goldschmidt division algorithm.
647 Even, G., Seidel, P. M., & Ferguson, W. E. (2003).
648 A Parametric Error Analysis of Goldschmidt's Division Algorithm.
649 https://citeseerx.ist.psu.edu/viewdoc/download?doi=10.1.1.90.1238&rep=rep1&type=pdf
651 yields: GoldschmidtDivOp
652 the operations needed to perform the division.
654 # establish assumptions of the paper's error analysis (section 3.1):
656 # 1. normalize so A (numerator) and B (denominator) are in [1, 2)
657 yield GoldschmidtDivOp
.Normalize
659 # 2. ensure all relative errors from directed rounding are <= 1 / 4.
660 # the assumption is met by multipliers with > 4-bits precision
661 _assert_accuracy(self
.expanded_width
> 4)
663 # 3. require `abs(e[0]) + 3 * d[0] / 2 + f[0] < 1 / 2`.
664 _assert_accuracy(self
.max_abs_e0
+ 3 * self
.max_d(0) / 2
665 + self
.max_f(0) < Fraction(1, 2))
667 # 4. the initial approximation F'[-1] of 1/B is in [1/2, 1].
668 # (B is the denominator)
670 for addr
in range(self
.table_addr_count
):
671 f_prime_m1
= self
.table
[addr
]
672 _assert_accuracy(0.5 <= f_prime_m1
<= 1)
674 yield GoldschmidtDivOp
.FEqTableLookup
676 # we use Setting I (section 4.1 of the paper):
677 # Require `n[i] <= n_hat` and `d[i] <= n_hat` and `f[i] = 0`:
678 # the conditions on n_hat are satisfied by construction.
679 for i
in range(self
.iter_count
):
680 _assert_accuracy(self
.max_f(i
) == 0)
681 yield GoldschmidtDivOp
.MulNByF
682 if i
!= self
.iter_count
- 1:
683 yield GoldschmidtDivOp
.MulDByF
684 yield GoldschmidtDivOp
.FEq2MinusD
686 # relative approximation error `p(N_prime[i])`:
687 # `p(N_prime[i]) = (A / B - N_prime[i]) / (A / B)`
688 # `0 <= p(N_prime[i])`
689 # `p(N_prime[i]) <= (2 * i) * n_hat \`
690 # ` + (abs(e[0]) + 3 * n_hat / 2) ** (2 ** i)`
691 i
= self
.iter_count
- 1 # last used `i`
692 # compute power manually to prevent huge intermediate values
693 power
= self
._shrink
_max
(self
.max_abs_e0
+ 3 * self
.n_hat
/ 2)
695 power
= self
._shrink
_max
(power
* power
)
697 max_rel_error
= (2 * i
) * self
.n_hat
+ power
699 min_a_over_b
= Fraction(1, 2)
700 max_a_over_b
= Fraction(2)
701 max_allowed_abs_error
= max_a_over_b
/ (1 << self
.max_n_shift
)
702 max_allowed_rel_error
= max_allowed_abs_error
/ min_a_over_b
704 _assert_accuracy(max_rel_error
< max_allowed_rel_error
,
705 f
"not accurate enough: max_rel_error={max_rel_error}"
706 f
" max_allowed_rel_error={max_allowed_rel_error}")
708 yield GoldschmidtDivOp
.CalcResult
710 def default_cost_fn(self
):
711 """ calculate the estimated cost on an arbitrary scale of implementing
712 goldschmidt division with the specified parameters. larger cost
713 values mean worse parameters.
715 This is the default cost function for `GoldschmidtDivParams.get`.
719 rom_cells
= self
.table_data_bits
<< self
.table_addr_bits
720 cost
= float(rom_cells
)
722 if op
== GoldschmidtDivOp
.MulNByF \
723 or op
== GoldschmidtDivOp
.MulDByF
:
724 mul_cost
= self
.expanded_width
** 2
725 mul_cost
*= self
.expanded_width
.bit_length()
727 cost
+= 1e6
* self
.iter_count
732 """ find efficient parameters for a goldschmidt division algorithm
733 with `params.io_width == io_width`.
735 assert isinstance(io_width
, int) and io_width
>= 1
738 for extra_precision
in range(io_width
* 2 + 4):
739 for table_addr_bits
in range(1, 7 + 1):
740 table_data_bits
= io_width
+ extra_precision
741 for iter_count
in range(1, 2 * io_width
.bit_length()):
743 return GoldschmidtDivParams(
745 extra_precision
=extra_precision
,
746 table_addr_bits
=table_addr_bits
,
747 table_data_bits
=table_data_bits
,
748 iter_count
=iter_count
)
749 except ParamsNotAccurateEnough
as e
:
750 last_params
= (f
"GoldschmidtDivParams("
751 f
"io_width={io_width!r}, "
752 f
"extra_precision={extra_precision!r}, "
753 f
"table_addr_bits={table_addr_bits!r}, "
754 f
"table_data_bits={table_data_bits!r}, "
755 f
"iter_count={iter_count!r})")
757 raise ValueError(f
"can't find working parameters for a goldschmidt "
758 f
"division algorithm: last params: {last_params}"
763 class GoldschmidtDivOp(enum
.Enum
):
764 Normalize
= "n, d, n_shift = normalize(n, d)"
765 FEqTableLookup
= "f = table_lookup(d)"
768 FEq2MinusD
= "f = 2 - d"
769 CalcResult
= "result = unnormalize_and_round(n)"
771 def run(self
, params
, state
):
772 assert isinstance(params
, GoldschmidtDivParams
)
773 assert isinstance(state
, GoldschmidtDivState
)
774 expanded_width
= params
.expanded_width
775 table_addr_bits
= params
.table_addr_bits
776 if self
== GoldschmidtDivOp
.Normalize
:
777 # normalize so 1 <= d < 2
778 # can easily be done with count-leading-zeros and left shift
780 state
.n
= (state
.n
* 2).to_frac_wid(expanded_width
)
781 state
.d
= (state
.d
* 2).to_frac_wid(expanded_width
)
784 # normalize so 1 <= n < 2
786 state
.n
= (state
.n
* 0.5).to_frac_wid(expanded_width
)
788 elif self
== GoldschmidtDivOp
.FEqTableLookup
:
789 # compute initial f by table lookup
791 d_m_1
= d_m_1
.to_frac_wid(table_addr_bits
, RoundDir
.DOWN
)
792 assert 0 <= d_m_1
.bits
< (1 << params
.table_addr_bits
)
793 state
.f
= params
.table
[d_m_1
.bits
]
794 elif self
== GoldschmidtDivOp
.MulNByF
:
795 assert state
.f
is not None
796 n
= state
.n
* state
.f
797 state
.n
= n
.to_frac_wid(expanded_width
, round_dir
=RoundDir
.DOWN
)
798 elif self
== GoldschmidtDivOp
.MulDByF
:
799 assert state
.f
is not None
800 d
= state
.d
* state
.f
801 state
.d
= d
.to_frac_wid(expanded_width
, round_dir
=RoundDir
.UP
)
802 elif self
== GoldschmidtDivOp
.FEq2MinusD
:
803 state
.f
= (2 - state
.d
).to_frac_wid(expanded_width
)
804 elif self
== GoldschmidtDivOp
.CalcResult
:
805 assert state
.n_shift
is not None
806 # scale to correct value
807 n
= state
.n
* (1 << state
.n_shift
)
809 state
.quotient
= math
.floor(n
)
810 state
.remainder
= state
.orig_n
- state
.quotient
* state
.orig_d
811 if state
.remainder
>= state
.orig_d
:
813 state
.remainder
-= state
.orig_d
815 assert False, f
"unimplemented GoldschmidtDivOp: {self}"
818 def goldschmidt_div(n
, d
, params
):
819 """ Goldschmidt division algorithm.
822 Even, G., Seidel, P. M., & Ferguson, W. E. (2003).
823 A Parametric Error Analysis of Goldschmidt's Division Algorithm.
824 https://citeseerx.ist.psu.edu/viewdoc/download?doi=10.1.1.90.1238&rep=rep1&type=pdf
828 numerator. a `2*width`-bit unsigned integer.
829 must be less than `d << width`, otherwise the quotient wouldn't
832 denominator. a `width`-bit unsigned integer. must not be zero.
834 the bit-width of the inputs/outputs. must be a positive integer.
836 returns: tuple[int, int]
837 the quotient and remainder. a tuple of two `width`-bit unsigned
840 assert isinstance(params
, GoldschmidtDivParams
)
841 assert isinstance(d
, int) and 0 < d
< (1 << params
.io_width
)
842 assert isinstance(n
, int) and 0 <= n
< (d
<< params
.io_width
)
844 # this whole algorithm is done with fixed-point arithmetic where values
845 # have `width` fractional bits
847 state
= GoldschmidtDivState(
850 n
=FixedPoint(n
, params
.io_width
),
851 d
=FixedPoint(d
, params
.io_width
),
854 for op
in params
.ops
:
855 op
.run(params
, state
)
857 assert state
.quotient
is not None
858 assert state
.remainder
is not None
860 return state
.quotient
, state
.remainder