03378048810b08eb2e87e10d77d4576625338279
1 # SPDX-License-Identifier: LGPL-3-or-later
2 # Copyright 2022 Jacob Lifshay programmerjake@gmail.com
4 # Funded by NLnet Assure Programme 2021-02-052, https://nlnet.nl/assure part
5 # of Horizon 2020 EU Programme 957073.
7 from dataclasses
import dataclass
, field
10 from fractions
import Fraction
11 from types
import FunctionType
14 from functools
import cached_property
16 from cached_property
import cached_property
18 # fix broken IDE type detection for cached_property
19 from typing
import TYPE_CHECKING
21 from functools
import cached_property
27 def cache_on_self(func
):
28 """like `functools.cached_property`, except for methods. unlike
29 `lru_cache` the cache is per-class instance rather than a global cache
32 assert isinstance(func
, FunctionType
), \
33 "non-plain methods are not supported"
35 cache_name
= func
.__name
__ + "__cache"
37 def wrapper(self
, *args
, **kwargs
):
38 # specifically access through `__dict__` to bypass frozen=True
39 cache
= self
.__dict
__.get(cache_name
, _NOT_FOUND
)
40 if cache
is _NOT_FOUND
:
41 self
.__dict
__[cache_name
] = cache
= {}
42 key
= (args
, *kwargs
.items())
43 retval
= cache
.get(key
, _NOT_FOUND
)
44 if retval
is _NOT_FOUND
:
45 retval
= func(self
, *args
, **kwargs
)
49 wrapper
.__doc
__ = func
.__doc
__
54 class RoundDir(enum
.Enum
):
57 NEAREST_TIES_UP
= enum
.auto()
58 ERROR_IF_INEXACT
= enum
.auto()
61 @dataclass(frozen
=True)
66 def __post_init__(self
):
67 assert isinstance(self
.bits
, int)
68 assert isinstance(self
.frac_wid
, int) and self
.frac_wid
>= 0
72 """convert `value` to a fixed-point number with enough fractional
73 bits to preserve its value."""
74 if isinstance(value
, FixedPoint
):
76 if isinstance(value
, int):
77 return FixedPoint(value
, 0)
78 if isinstance(value
, str):
80 neg
= value
.startswith("-")
81 if neg
or value
.startswith("+"):
83 if value
.startswith(("0x", "0X")) and "." in value
:
93 raise ValueError("too many `.` in string")
98 if not digit
.isalnum():
99 raise ValueError("invalid hexadecimal digit")
101 bits |
= int("0x" + digit
, base
=16)
103 bits
= int(value
, base
=0)
107 return FixedPoint(bits
, frac_wid
)
109 if isinstance(value
, float):
110 n
, d
= value
.as_integer_ratio()
111 log2_d
= d
.bit_length() - 1
112 assert d
== 1 << log2_d
, ("d isn't a power of 2 -- won't ever "
113 "fail with float being IEEE 754")
114 return FixedPoint(n
, log2_d
)
115 raise TypeError("can't convert type to FixedPoint")
118 def with_frac_wid(value
, frac_wid
, round_dir
=RoundDir
.ERROR_IF_INEXACT
):
119 """convert `value` to the nearest fixed-point number with `frac_wid`
120 fractional bits, rounding according to `round_dir`."""
121 assert isinstance(frac_wid
, int) and frac_wid
>= 0
122 assert isinstance(round_dir
, RoundDir
)
123 if isinstance(value
, Fraction
):
124 numerator
= value
.numerator
125 denominator
= value
.denominator
127 value
= FixedPoint
.cast(value
)
128 numerator
= value
.bits
129 denominator
= 1 << value
.frac_wid
131 numerator
= -numerator
132 denominator
= -denominator
133 bits
, remainder
= divmod(numerator
<< frac_wid
, denominator
)
134 if round_dir
== RoundDir
.DOWN
:
136 elif round_dir
== RoundDir
.UP
:
139 elif round_dir
== RoundDir
.NEAREST_TIES_UP
:
140 if remainder
* 2 >= denominator
:
142 elif round_dir
== RoundDir
.ERROR_IF_INEXACT
:
144 raise ValueError("inexact conversion")
146 assert False, "unimplemented round_dir"
147 return FixedPoint(bits
, frac_wid
)
def to_frac_wid(self, frac_wid, round_dir=RoundDir.ERROR_IF_INEXACT):
    """re-express `self` as the nearest fixed-point number that has
    `frac_wid` fractional bits.

    This is the instance-method spelling of `FixedPoint.with_frac_wid`:
    `self` is passed as the value to convert and `round_dir` is forwarded
    unchanged to select the rounding behaviour.
    """
    converted = FixedPoint.with_frac_wid(
        value=self, frac_wid=frac_wid, round_dir=round_dir)
    return converted
155 # use truediv to get correct result even when bits
156 # and frac_wid are huge
157 return float(self
.bits
/ (1 << self
.frac_wid
))
def as_fraction(self):
    """return the exact value of `self` as a `Fraction`, that is
    `bits / 2 ** frac_wid`."""
    numerator = self.bits
    denominator = 1 << self.frac_wid
    return Fraction(numerator, denominator)
163 """compare self with rhs, returning a positive integer if self is
164 greater than rhs, zero if self is equal to rhs, and a negative integer
165 if self is less than rhs."""
166 rhs
= FixedPoint
.cast(rhs
)
167 common_frac_wid
= max(self
.frac_wid
, rhs
.frac_wid
)
168 lhs
= self
.to_frac_wid(common_frac_wid
)
169 rhs
= rhs
.to_frac_wid(common_frac_wid
)
170 return lhs
.bits
- rhs
.bits
172 def __eq__(self
, rhs
):
173 return self
.cmp(rhs
) == 0
175 def __ne__(self
, rhs
):
176 return self
.cmp(rhs
) != 0
178 def __gt__(self
, rhs
):
179 return self
.cmp(rhs
) > 0
181 def __lt__(self
, rhs
):
182 return self
.cmp(rhs
) < 0
184 def __ge__(self
, rhs
):
185 return self
.cmp(rhs
) >= 0
187 def __le__(self
, rhs
):
188 return self
.cmp(rhs
) <= 0
191 """return the fractional part of `self`.
192 that is `self - math.floor(self)`.
194 fract_mask
= (1 << self
.frac_wid
) - 1
195 return FixedPoint(self
.bits
& fract_mask
, self
.frac_wid
)
199 return "-" + str(-self
)
201 frac_digit_count
= (self
.frac_wid
+ digit_bits
- 1) // digit_bits
202 fract
= self
.fract().to_frac_wid(frac_digit_count
* digit_bits
)
203 frac_str
= hex(fract
.bits
)[2:].zfill(frac_digit_count
)
204 return hex(math
.floor(self
)) + "." + frac_str
207 return f
"FixedPoint.with_frac_wid({str(self)!r}, {self.frac_wid})"
def __add__(self, rhs):
    """`self + rhs`.

    `rhs` is first converted with `FixedPoint.cast`. Both operands are then
    brought to the larger of the two fractional widths so their `bits` can
    be added directly; the result keeps that common width.
    """
    other = FixedPoint.cast(rhs)
    wid = max(self.frac_wid, other.frac_wid)
    left = self.to_frac_wid(wid)
    right = other.to_frac_wid(wid)
    total_bits = left.bits + right.bits
    return FixedPoint(total_bits, wid)
216 def __radd__(self
, lhs
):
218 return self
.__add
__(lhs
)
221 return FixedPoint(-self
.bits
, self
.frac_wid
)
def __sub__(self, rhs):
    """`self - rhs`.

    `rhs` is first converted with `FixedPoint.cast`. Both operands are then
    brought to the larger of the two fractional widths so their `bits` can
    be subtracted directly; the result keeps that common width.
    """
    other = FixedPoint.cast(rhs)
    wid = max(self.frac_wid, other.frac_wid)
    left = self.to_frac_wid(wid)
    right = other.to_frac_wid(wid)
    difference_bits = left.bits - right.bits
    return FixedPoint(difference_bits, wid)
230 def __rsub__(self
, lhs
):
232 return -self
.__sub
__(lhs
)
def __mul__(self, rhs):
    """`self * rhs`.

    `rhs` is first converted with `FixedPoint.cast`. The product's `bits`
    is the product of both operands' `bits`, and its fractional width is
    the sum of both fractional widths, so no rounding takes place here.
    """
    other = FixedPoint.cast(rhs)
    product_bits = self.bits * other.bits
    product_frac_wid = self.frac_wid + other.frac_wid
    return FixedPoint(product_bits, product_frac_wid)
238 def __rmul__(self
, lhs
):
240 return self
.__mul
__(lhs
)
243 return self
.bits
>> self
.frac_wid
247 class GoldschmidtDivState
:
249 """original numerator"""
252 """original denominator"""
255 """numerator -- N_prime[i] in the paper's algorithm 2"""
258 """denominator -- D_prime[i] in the paper's algorithm 2"""
260 f
: "FixedPoint | None" = None
261 """current factor -- F_prime[i] in the paper's algorithm 2"""
263 quotient
: "int | None" = None
266 remainder
: "int | None" = None
267 """final remainder"""
269 n_shift
: "int | None" = None
270 """amount the numerator needs to be left-shifted at the end of the
275 class ParamsNotAccurateEnough(Exception):
276 """raised when the parameters aren't accurate enough to have goldschmidt
280 def _assert_accuracy(condition
, msg
="not accurate enough"):
283 raise ParamsNotAccurateEnough(msg
)
286 @dataclass(frozen
=True, unsafe_hash
=True)
287 class GoldschmidtDivParams
:
288 """parameters for a Goldschmidt division algorithm.
289 Use `GoldschmidtDivParams.get` to find an efficient set of parameters.
293 """bit-width of the input divisor and the result.
294 the input numerator is `2 * io_width`-bits wide.
298 """number of bits of additional precision used inside the algorithm."""
301 """the number of address bits used in the lookup-table."""
304 """the number of data bits used in the lookup-table."""
307 """the total number of iterations of the division algorithm's loop"""
309 # tuple to be immutable, default so repr() works for debugging even when
310 # __post_init__ hasn't finished running yet
311 table
: "tuple[FixedPoint, ...]" = field(init
=False, default
=NotImplemented)
312 """the lookup-table"""
314 ops
: "tuple[GoldschmidtDivOp, ...]" = field(init
=False,
315 default
=NotImplemented)
316 """the operations needed to perform the goldschmidt division algorithm."""
def _shrink_bound(self, bound, round_dir):
    """keep a bound's numerator/denominator from growing without limit.

    `bound` (a `Fraction` or `int`) is rounded to a `FixedPoint` with
    `io_width * 4 + 100` fractional bits in the direction `round_dir`,
    and the rounded value is handed back as a `Fraction`.

    Only `RoundDir.DOWN` and `RoundDir.UP` are accepted for `round_dir`.

    This is intended only for values used to compute bounds, and not for
    values that end up in the hardware.
    """
    assert isinstance(bound, (Fraction, int))
    is_directed = round_dir is RoundDir.DOWN or round_dir is RoundDir.UP
    assert is_directed, "you shouldn't use that round_dir on bounds"
    # should be enough precision
    precision = self.io_width * 4 + 100
    return FixedPoint.with_frac_wid(bound, precision, round_dir).as_fraction()
def _shrink_min(self, min_bound):
    """simplify a lower bound without making it any larger.

    Delegates to `_shrink_bound` with `RoundDir.DOWN`, so the returned
    `Fraction` is `min_bound` rounded down to a `FixedPoint` and converted
    back.

    This is intended only for values used to compute bounds, and not for
    values that end up in the hardware.
    """
    shrunk = self._shrink_bound(bound=min_bound, round_dir=RoundDir.DOWN)
    return shrunk
def _shrink_max(self, max_bound):
    """simplify an upper bound without making it any smaller.

    Delegates to `_shrink_bound` with `RoundDir.UP`, so the returned
    `Fraction` is `max_bound` rounded up to a `FixedPoint` and converted
    back.

    This is intended only for values used to compute bounds, and not for
    values that end up in the hardware.
    """
    shrunk = self._shrink_bound(bound=max_bound, round_dir=RoundDir.UP)
    return shrunk
def table_addr_count(self):
    """how many distinct addresses the lookup-table has, which is
    `2 ** table_addr_bits`."""
    # derived from the address width rather than len(self.table), because
    # this is needed while self.table is still being computed
    addr_bits = self.table_addr_bits
    return 1 << addr_bits
def table_input_exact_range(self, addr):
    """return `(min_input, max_input)`, the range of inputs (as
    `Fraction`s) that map onto the table entry with address `addr`.

    Inputs are `io_width`-fractional-bit values with an implicit leading
    one, so both ends of the range lie in `[1, 2)`. The top
    `table_addr_bits` fractional bits select the entry; the remaining low
    bits span the entry's range.

    Calls `_assert_accuracy` to require `io_width >= table_addr_bits`.
    """
    assert isinstance(addr, int)
    assert 0 <= addr < self.table_addr_count
    io_width = self.io_width
    addr_bits = self.table_addr_bits
    _assert_accuracy(io_width >= addr_bits)
    # number of low-order input bits that don't take part in the address
    shift = io_width - addr_bits
    scale = 1 << io_width
    first_numerator = scale + (addr << shift)
    # an entry covers `1 << shift` consecutive input values
    last_numerator = first_numerator + (1 << shift) - 1
    lowest = Fraction(first_numerator, scale)
    highest = Fraction(last_numerator, scale)
    lowest = self._shrink_min(lowest)
    highest = self._shrink_max(highest)
    assert 1 <= lowest <= highest < 2
    return lowest, highest
def table_value_exact_range(self, addr):
    """return `(min_value, max_value)`, the range of exact reciprocals (as
    `Fraction`s) of the inputs covered by the table entry with address
    `addr`."""
    lowest_input, highest_input = self.table_input_exact_range(addr)
    # taking the reciprocal reverses the ordering, so the largest input
    # produces the smallest value and vice versa
    smallest = 1 / highest_input
    largest = 1 / lowest_input
    smallest = self._shrink_min(smallest)
    largest = self._shrink_max(largest)
    assert 0.5 < smallest <= largest <= 1
    return smallest, largest
388 def table_exact_value(self
, index
):
389 min_value
, max_value
= self
.table_value_exact_range(index
)
393 def __post_init__(self
):
394 # called by the autogenerated __init__
395 assert self
.io_width
>= 1
396 assert self
.extra_precision
>= 0
397 assert self
.table_addr_bits
>= 1
398 assert self
.table_data_bits
>= 1
399 assert self
.iter_count
>= 1
401 for addr
in range(1 << self
.table_addr_bits
):
402 table
.append(FixedPoint
.with_frac_wid(self
.table_exact_value(addr
),
403 self
.table_data_bits
,
405 # we have to use object.__setattr__ since frozen=True
406 object.__setattr
__(self
, "table", tuple(table
))
407 object.__setattr
__(self
, "ops", tuple(self
.__make
_ops
()))
def expanded_width(self):
    """the total number of bits of precision used inside the algorithm:
    `io_width` plus `extra_precision`."""
    base_width = self.io_width
    extra = self.extra_precision
    return base_width + extra
def max_neps(self, i):
    """maximum value of `neps[i]`.
    `neps[i]` is defined to be `n[i] * N_prime[i - 1] * F_prime[i - 1]`.

    The bound is the same for every valid `i`: `2 ** -expanded_width`.
    `i` must be an `int` in `range(iter_count)`.
    """
    assert isinstance(i, int) and 0 <= i < self.iter_count
    denominator = 1 << self.expanded_width
    return Fraction(1, denominator)
def max_deps(self, i):
    """maximum value of `deps[i]`.
    `deps[i]` is defined to be `d[i] * D_prime[i - 1] * F_prime[i - 1]`.

    The bound is the same for every valid `i`: `2 ** -expanded_width`.
    `i` must be an `int` in `range(iter_count)`.
    """
    assert isinstance(i, int) and 0 <= i < self.iter_count
    denominator = 1 << self.expanded_width
    return Fraction(1, denominator)
431 def max_feps(self
, i
):
432 """maximum value of `feps[i]`.
433 `feps[i]` is defined to be `f[i] * (2 - D_prime[i - 1])`.
435 assert isinstance(i
, int) and 0 <= i
< self
.iter_count
436 # zero, because the computation of `F_prime[i]` in
437 # `GoldschmidtDivOp.MulDByF.run(...)` is exact.
442 """minimum and maximum values of `e[0]`
443 (the relative error in `F_prime[-1]`)
447 for addr
in range(self
.table_addr_count
):
448 # `F_prime[-1] = (1 - e[0]) / B`
449 # => `e[0] = 1 - B * F_prime[-1]`
450 min_b
, max_b
= self
.table_input_exact_range(addr
)
451 f_prime_m1
= self
.table
[addr
].as_fraction()
452 assert min_b
>= 0 and f_prime_m1
>= 0, \
453 "only positive quadrant of interval multiplication implemented"
454 min_product
= min_b
* f_prime_m1
455 max_product
= max_b
* f_prime_m1
456 # negation swaps min/max
457 cur_min_e0
= 1 - max_product
458 cur_max_e0
= 1 - min_product
459 min_e0
= min(min_e0
, cur_min_e0
)
460 max_e0
= max(max_e0
, cur_max_e0
)
461 min_e0
= self
._shrink
_min
(min_e0
)
462 max_e0
= self
._shrink
_max
(max_e0
)
463 return min_e0
, max_e0
467 """minimum value of `e[0]` (the relative error in `F_prime[-1]`)
469 min_e0
, max_e0
= self
.e0_range
474 """maximum value of `e[0]` (the relative error in `F_prime[-1]`)
476 min_e0
, max_e0
= self
.e0_range
def max_abs_e0(self):
    """maximum value of `abs(e[0])`: the larger of the magnitudes of
    `min_e0` and `max_e0`."""
    magnitude_of_min = abs(self.min_e0)
    magnitude_of_max = abs(self.max_e0)
    return max(magnitude_of_min, magnitude_of_max)
485 def min_abs_e0(self
):
486 """minimum value of `abs(e[0])`."""
491 """maximum value of `n[i]` (the relative error in `N_prime[i]`
492 relative to the previous iteration)
494 assert isinstance(i
, int) and 0 <= i
< self
.iter_count
497 # `n[0] = neps[0] / ((1 - e[0]) * (A / B))`
498 # `n[0] <= 2 * neps[0] / (1 - e[0])`
500 assert self
.max_e0
< 1 and self
.max_neps(0) >= 0, \
501 "only one quadrant of interval division implemented"
502 retval
= 2 * self
.max_neps(0) / (1 - self
.max_e0
)
505 # `n[1] <= neps[1] / ((1 - f[0]) * (1 - pi[0] - delta[0]))`
506 min_mpd
= 1 - self
.max_pi(0) - self
.max_delta(0)
507 assert self
.max_f(0) <= 1 and min_mpd
>= 0, \
508 "only one quadrant of interval multiplication implemented"
509 prod
= (1 - self
.max_f(0)) * min_mpd
510 assert self
.max_neps(1) >= 0 and prod
> 0, \
511 "only one quadrant of interval division implemented"
512 retval
= self
.max_neps(1) / prod
515 # `0 <= n[i] <= 2 * max_neps[i] / (1 - pi[i - 1] - delta[i - 1])`
516 min_mpd
= 1 - self
.max_pi(i
- 1) - self
.max_delta(i
- 1)
517 assert self
.max_neps(i
) >= 0 and min_mpd
> 0, \
518 "only one quadrant of interval division implemented"
519 retval
= self
.max_neps(i
) / min_mpd
521 return self
._shrink
_max
(retval
)
525 """maximum value of `d[i]` (the relative error in `D_prime[i]`
526 relative to the previous iteration)
528 assert isinstance(i
, int) and 0 <= i
< self
.iter_count
531 # `d[0] = deps[0] / (1 - e[0])`
533 assert self
.max_e0
< 1 and self
.max_deps(0) >= 0, \
534 "only one quadrant of interval division implemented"
535 retval
= self
.max_deps(0) / (1 - self
.max_e0
)
538 # `d[1] <= deps[1] / ((1 - f[0]) * (1 - delta[0] ** 2))`
539 assert self
.max_f(0) <= 1 and self
.max_delta(0) <= 1, \
540 "only one quadrant of interval multiplication implemented"
541 divisor
= (1 - self
.max_f(0)) * (1 - self
.max_delta(0) ** 2)
542 assert self
.max_deps(1) >= 0 and divisor
> 0, \
543 "only one quadrant of interval division implemented"
544 retval
= self
.max_deps(1) / divisor
547 # `0 <= d[i] <= max_deps[i] / (1 - delta[i - 1])`
548 assert self
.max_deps(i
) >= 0 and self
.max_delta(i
- 1) < 1, \
549 "only one quadrant of interval division implemented"
550 retval
= self
.max_deps(i
) / (1 - self
.max_delta(i
- 1))
552 return self
._shrink
_max
(retval
)
556 """maximum value of `f[i]` (the relative error in `F_prime[i]`
557 relative to the previous iteration)
559 assert isinstance(i
, int) and 0 <= i
< self
.iter_count
562 # `f[0] = feps[0] / (1 - delta[0])`
564 assert self
.max_delta(0) < 1 and self
.max_feps(0) >= 0, \
565 "only one quadrant of interval division implemented"
566 retval
= self
.max_feps(0) / (1 - self
.max_delta(0))
570 retval
= self
.max_feps(1)
573 # `f[i] <= max_feps[i]`
574 retval
= self
.max_feps(i
)
576 return self
._shrink
_max
(retval
)
579 def max_delta(self
, i
):
580 """ maximum value of `delta[i]`.
581 `delta[i]` is defined in Definition 4 of paper.
583 assert isinstance(i
, int) and 0 <= i
< self
.iter_count
585 # `delta[0] = abs(e[0]) + 3 * d[0] / 2`
586 retval
= self
.max_abs_e0
+ Fraction(3, 2) * self
.max_d(0)
588 # `delta[i] = delta[i - 1] ** 2 + f[i - 1]`
589 prev_max_delta
= self
.max_delta(i
- 1)
590 assert prev_max_delta
>= 0
591 retval
= prev_max_delta
** 2 + self
.max_f(i
- 1)
593 # `delta[i]` has to be smaller than one otherwise errors would go off
595 _assert_accuracy(retval
< 1)
597 return self
._shrink
_max
(retval
)
601 """ maximum value of `pi[i]`.
602 `pi[i]` is defined right below Theorem 5 of paper.
604 assert isinstance(i
, int) and 0 <= i
< self
.iter_count
605 # `pi[i] = 1 - (1 - n[i]) * prod`
606 # where `prod` is the product of,
607 # for `j` in `0 <= j < i`, `(1 - n[j]) / (1 + d[j])`
608 min_prod
= Fraction(1)
610 max_n_j
= self
.max_n(j
)
611 max_d_j
= self
.max_d(j
)
612 assert max_n_j
<= 1 and max_d_j
> -1, \
613 "only one quadrant of interval division implemented"
614 min_prod
*= (1 - max_n_j
) / (1 + max_d_j
)
615 max_n_i
= self
.max_n(i
)
616 assert max_n_i
<= 1 and min_prod
>= 0, \
617 "only one quadrant of interval multiplication implemented"
618 retval
= 1 - (1 - max_n_i
) * min_prod
619 return self
._shrink
_max
(retval
)
622 def max_n_shift(self
):
623 """ maximum value of `state.n_shift`.
625 # input numerator is `2*io_width`-bits
626 max_n
= (1 << (self
.io_width
* 2)) - 1
628 # normalize so 1 <= n < 2
634 def __make_ops(self
):
635 """ Goldschmidt division algorithm.
638 Even, G., Seidel, P. M., & Ferguson, W. E. (2003).
639 A Parametric Error Analysis of Goldschmidt's Division Algorithm.
640 https://citeseerx.ist.psu.edu/viewdoc/download?doi=10.1.1.90.1238&rep=rep1&type=pdf
642 yields: GoldschmidtDivOp
643 the operations needed to perform the division.
645 # establish assumptions of the paper's error analysis (section 3.1):
647 # 1. normalize so A (numerator) and B (denominator) are in [1, 2)
648 yield GoldschmidtDivOp
.Normalize
650 # 2. ensure all relative errors from directed rounding are <= 1 / 4.
651 # the assumption is met by multipliers with > 4-bits precision
652 _assert_accuracy(self
.expanded_width
> 4)
654 # 3. require `abs(e[0]) + 3 * d[0] / 2 + f[0] < 1 / 2`.
655 _assert_accuracy(self
.max_abs_e0
+ 3 * self
.max_d(0) / 2
656 + self
.max_f(0) < Fraction(1, 2))
658 # 4. the initial approximation F'[-1] of 1/B is in [1/2, 1].
659 # (B is the denominator)
661 for addr
in range(self
.table_addr_count
):
662 f_prime_m1
= self
.table
[addr
]
663 _assert_accuracy(0.5 <= f_prime_m1
<= 1)
665 yield GoldschmidtDivOp
.FEqTableLookup
667 # we use Setting I (section 4.1 of the paper):
668 # Require `n[i] <= n_hat` and `d[i] <= n_hat` and `f[i] = 0`
670 for i
in range(self
.iter_count
):
671 _assert_accuracy(self
.max_f(i
) == 0)
672 n_hat
= max(n_hat
, self
.max_n(i
), self
.max_d(i
))
673 yield GoldschmidtDivOp
.MulNByF
674 if i
!= self
.iter_count
- 1:
675 yield GoldschmidtDivOp
.MulDByF
676 yield GoldschmidtDivOp
.FEq2MinusD
678 # relative approximation error `p(N_prime[i])`:
679 # `p(N_prime[i]) = (A / B - N_prime[i]) / (A / B)`
680 # `0 <= p(N_prime[i])`
681 # `p(N_prime[i]) <= (2 * i) * n_hat \`
682 # ` + (abs(e[0]) + 3 * n_hat / 2) ** (2 ** i)`
683 i
= self
.iter_count
- 1 # last used `i`
684 # compute power manually to prevent huge intermediate values
685 power
= self
._shrink
_max
(self
.max_abs_e0
+ 3 * n_hat
/ 2)
687 power
= self
._shrink
_max
(power
* power
)
689 max_rel_error
= (2 * i
) * n_hat
+ power
691 min_a_over_b
= Fraction(1, 2)
692 max_a_over_b
= Fraction(2)
693 max_allowed_abs_error
= max_a_over_b
/ (1 << self
.max_n_shift
)
694 max_allowed_rel_error
= max_allowed_abs_error
/ min_a_over_b
696 _assert_accuracy(max_rel_error
< max_allowed_rel_error
,
697 f
"not accurate enough: max_rel_error={max_rel_error}"
698 f
" max_allowed_rel_error={max_allowed_rel_error}")
700 yield GoldschmidtDivOp
.CalcResult
704 """ find efficient parameters for a goldschmidt division algorithm
705 with `params.io_width == io_width`.
707 assert isinstance(io_width
, int) and io_width
>= 1
710 for extra_precision
in range(io_width
* 2 + 4):
711 for table_addr_bits
in range(1, 7 + 1):
712 table_data_bits
= io_width
+ extra_precision
713 for iter_count
in range(1, 2 * io_width
.bit_length()):
715 return GoldschmidtDivParams(
717 extra_precision
=extra_precision
,
718 table_addr_bits
=table_addr_bits
,
719 table_data_bits
=table_data_bits
,
720 iter_count
=iter_count
)
721 except ParamsNotAccurateEnough
as e
:
722 last_params
= (f
"GoldschmidtDivParams("
723 f
"io_width={io_width!r}, "
724 f
"extra_precision={extra_precision!r}, "
725 f
"table_addr_bits={table_addr_bits!r}, "
726 f
"table_data_bits={table_data_bits!r}, "
727 f
"iter_count={iter_count!r})")
729 raise ValueError(f
"can't find working parameters for a goldschmidt "
730 f
"division algorithm: last params: {last_params}"
735 class GoldschmidtDivOp(enum
.Enum
):
736 Normalize
= "n, d, n_shift = normalize(n, d)"
737 FEqTableLookup
= "f = table_lookup(d)"
740 FEq2MinusD
= "f = 2 - d"
741 CalcResult
= "result = unnormalize_and_round(n)"
743 def run(self
, params
, state
):
744 assert isinstance(params
, GoldschmidtDivParams
)
745 assert isinstance(state
, GoldschmidtDivState
)
746 expanded_width
= params
.expanded_width
747 table_addr_bits
= params
.table_addr_bits
748 if self
== GoldschmidtDivOp
.Normalize
:
749 # normalize so 1 <= d < 2
750 # can easily be done with count-leading-zeros and left shift
752 state
.n
= (state
.n
* 2).to_frac_wid(expanded_width
)
753 state
.d
= (state
.d
* 2).to_frac_wid(expanded_width
)
756 # normalize so 1 <= n < 2
758 state
.n
= (state
.n
* 0.5).to_frac_wid(expanded_width
)
760 elif self
== GoldschmidtDivOp
.FEqTableLookup
:
761 # compute initial f by table lookup
763 d_m_1
= d_m_1
.to_frac_wid(table_addr_bits
, RoundDir
.DOWN
)
764 assert 0 <= d_m_1
.bits
< (1 << params
.table_addr_bits
)
765 state
.f
= params
.table
[d_m_1
.bits
]
766 elif self
== GoldschmidtDivOp
.MulNByF
:
767 assert state
.f
is not None
768 n
= state
.n
* state
.f
769 state
.n
= n
.to_frac_wid(expanded_width
, round_dir
=RoundDir
.DOWN
)
770 elif self
== GoldschmidtDivOp
.MulDByF
:
771 assert state
.f
is not None
772 d
= state
.d
* state
.f
773 state
.d
= d
.to_frac_wid(expanded_width
, round_dir
=RoundDir
.UP
)
774 elif self
== GoldschmidtDivOp
.FEq2MinusD
:
775 state
.f
= (2 - state
.d
).to_frac_wid(expanded_width
)
776 elif self
== GoldschmidtDivOp
.CalcResult
:
777 assert state
.n_shift
is not None
778 # scale to correct value
779 n
= state
.n
* (1 << state
.n_shift
)
781 state
.quotient
= math
.floor(n
)
782 state
.remainder
= state
.orig_n
- state
.quotient
* state
.orig_d
783 if state
.remainder
>= state
.orig_d
:
785 state
.remainder
-= state
.orig_d
787 assert False, f
"unimplemented GoldschmidtDivOp: {self}"
790 def goldschmidt_div(n
, d
, params
):
791 """ Goldschmidt division algorithm.
794 Even, G., Seidel, P. M., & Ferguson, W. E. (2003).
795 A Parametric Error Analysis of Goldschmidt's Division Algorithm.
796 https://citeseerx.ist.psu.edu/viewdoc/download?doi=10.1.1.90.1238&rep=rep1&type=pdf
800 numerator. a `2*width`-bit unsigned integer.
801 must be less than `d << width`, otherwise the quotient wouldn't
804 denominator. a `width`-bit unsigned integer. must not be zero.
806 the bit-width of the inputs/outputs. must be a positive integer.
808 returns: tuple[int, int]
809 the quotient and remainder. a tuple of two `width`-bit unsigned
812 assert isinstance(params
, GoldschmidtDivParams
)
813 assert isinstance(d
, int) and 0 < d
< (1 << params
.io_width
)
814 assert isinstance(n
, int) and 0 <= n
< (d
<< params
.io_width
)
816 # this whole algorithm is done with fixed-point arithmetic where values
817 # have `width` fractional bits
819 state
= GoldschmidtDivState(
822 n
=FixedPoint(n
, params
.io_width
),
823 d
=FixedPoint(d
, params
.io_width
),
826 for op
in params
.ops
:
827 op
.run(params
, state
)
829 assert state
.quotient
is not None
830 assert state
.remainder
is not None
832 return state
.quotient
, state
.remainder