1 # SPDX-License-Identifier: LGPL-3-or-later
2 # Copyright 2022 Jacob Lifshay programmerjake@gmail.com
4 # Funded by NLnet Assure Programme 2021-02-052, https://nlnet.nl/assure part
5 # of Horizon 2020 EU Programme 957073.
7 from dataclasses
import dataclass
, field
10 from fractions
import Fraction
11 from types
import FunctionType
14 from functools
import cached_property
16 from cached_property
import cached_property
18 # fix broken IDE type detection for cached_property
19 from typing
import TYPE_CHECKING
21 from functools
import cached_property
27 def cache_on_self(func
):
28 """like `functools.cached_property`, except for methods. unlike
29 `lru_cache` the cache is per-class instance rather than a global cache
32 assert isinstance(func
, FunctionType
), \
33 "non-plain methods are not supported"
35 cache_name
= func
.__name
__ + "__cache"
37 def wrapper(self
, *args
, **kwargs
):
38 # specifically access through `__dict__` to bypass frozen=True
39 cache
= self
.__dict
__.get(cache_name
, _NOT_FOUND
)
40 if cache
is _NOT_FOUND
:
41 self
.__dict
__[cache_name
] = cache
= {}
42 key
= (args
, *kwargs
.items())
43 retval
= cache
.get(key
, _NOT_FOUND
)
44 if retval
is _NOT_FOUND
:
45 retval
= func(self
, *args
, **kwargs
)
49 wrapper
.__doc
__ = func
.__doc
__
54 class RoundDir(enum
.Enum
):
57 NEAREST_TIES_UP
= enum
.auto()
58 ERROR_IF_INEXACT
= enum
.auto()
61 @dataclass(frozen
=True)
def __post_init__(self):
    """validate the fields filled in by the dataclass-generated `__init__`.

    `bits` is the raw integer and `frac_wid` the number of fractional bits,
    so the represented value is `bits / 2 ** frac_wid`.
    """
    raw_bits = self.bits
    assert isinstance(raw_bits, int)
    width = self.frac_wid
    assert isinstance(width, int) and width >= 0
72 """convert `value` to a fixed-point number with enough fractional
73 bits to preserve its value."""
74 if isinstance(value
, FixedPoint
):
76 if isinstance(value
, int):
77 return FixedPoint(value
, 0)
78 if isinstance(value
, str):
80 neg
= value
.startswith("-")
81 if neg
or value
.startswith("+"):
83 if value
.startswith(("0x", "0X")) and "." in value
:
93 raise ValueError("too many `.` in string")
98 if not digit
.isalnum():
99 raise ValueError("invalid hexadecimal digit")
101 bits |
= int("0x" + digit
, base
=16)
103 bits
= int(value
, base
=0)
107 return FixedPoint(bits
, frac_wid
)
109 if isinstance(value
, float):
110 n
, d
= value
.as_integer_ratio()
111 log2_d
= d
.bit_length() - 1
112 assert d
== 1 << log2_d
, ("d isn't a power of 2 -- won't ever "
113 "fail with float being IEEE 754")
114 return FixedPoint(n
, log2_d
)
115 raise TypeError("can't convert type to FixedPoint")
118 def with_frac_wid(value
, frac_wid
, round_dir
=RoundDir
.ERROR_IF_INEXACT
):
119 """convert `value` to the nearest fixed-point number with `frac_wid`
120 fractional bits, rounding according to `round_dir`."""
121 assert isinstance(frac_wid
, int) and frac_wid
>= 0
122 assert isinstance(round_dir
, RoundDir
)
123 if isinstance(value
, Fraction
):
124 numerator
= value
.numerator
125 denominator
= value
.denominator
127 value
= FixedPoint
.cast(value
)
128 numerator
= value
.bits
129 denominator
= 1 << value
.frac_wid
131 numerator
= -numerator
132 denominator
= -denominator
133 bits
, remainder
= divmod(numerator
<< frac_wid
, denominator
)
134 if round_dir
== RoundDir
.DOWN
:
136 elif round_dir
== RoundDir
.UP
:
139 elif round_dir
== RoundDir
.NEAREST_TIES_UP
:
140 if remainder
* 2 >= denominator
:
142 elif round_dir
== RoundDir
.ERROR_IF_INEXACT
:
144 raise ValueError("inexact conversion")
146 assert False, "unimplemented round_dir"
147 return FixedPoint(bits
, frac_wid
)
def to_frac_wid(self, frac_wid, round_dir=RoundDir.ERROR_IF_INEXACT):
    """return `self` re-expressed with exactly `frac_wid` fractional bits.

    Thin wrapper around `FixedPoint.with_frac_wid`. `round_dir` selects how
    bits that don't fit are rounded; the default raises `ValueError` for an
    inexact conversion.
    """
    return FixedPoint.with_frac_wid(
        self, frac_wid=frac_wid, round_dir=round_dir)
155 # use truediv to get correct result even when bits
156 # and frac_wid are huge
157 return float(self
.bits
/ (1 << self
.frac_wid
))
def as_fraction(self):
    """return the exact value `bits / 2 ** frac_wid` as a `Fraction`."""
    return Fraction(self.bits) / 2 ** self.frac_wid
163 """compare self with rhs, returning a positive integer if self is
164 greater than rhs, zero if self is equal to rhs, and a negative integer
165 if self is less than rhs."""
166 rhs
= FixedPoint
.cast(rhs
)
167 common_frac_wid
= max(self
.frac_wid
, rhs
.frac_wid
)
168 lhs
= self
.to_frac_wid(common_frac_wid
)
169 rhs
= rhs
.to_frac_wid(common_frac_wid
)
170 return lhs
.bits
- rhs
.bits
def __eq__(self, rhs):
    """return `True` if `self` and `rhs` represent the same numeric value.

    `rhs` may be anything `FixedPoint.cast` accepts. The comparison is by
    value (via `cmp`), so operands with different `frac_wid` can be equal.
    Returns `NotImplemented` when `FixedPoint.cast` rejects `rhs` with
    `TypeError`, so `fp == None` evaluates to `False` instead of raising.

    NOTE(review): if `@dataclass(frozen=True)` generates `__hash__` from
    `(bits, frac_wid)`, equal values such as `FixedPoint(1, 0)` and
    `FixedPoint(2, 1)` hash differently -- confirm before using
    `FixedPoint` as a dict/set key.
    """
    try:
        order = self.cmp(rhs)
    except TypeError:
        # unsupported operand type: let Python try the reflected operation
        # (and fall back to identity comparison for `==`)
        return NotImplemented
    return order == 0
def __ne__(self, rhs):
    """return `True` if `self` and `rhs` represent different numeric values.

    Returns `NotImplemented` when `FixedPoint.cast` rejects `rhs` with
    `TypeError`, so `fp != None` evaluates to `True` instead of raising.
    """
    try:
        order = self.cmp(rhs)
    except TypeError:
        # unsupported operand type: defer to Python's fallback
        return NotImplemented
    return order != 0
def __gt__(self, rhs):
    """return `True` if `self` is numerically greater than `rhs`.

    Returns `NotImplemented` when `FixedPoint.cast` rejects `rhs` with
    `TypeError`, so Python can try the reflected comparison (and raises
    `TypeError` itself if that also fails).
    """
    try:
        order = self.cmp(rhs)
    except TypeError:
        return NotImplemented
    return order > 0
def __lt__(self, rhs):
    """return `True` if `self` is numerically less than `rhs`.

    Returns `NotImplemented` when `FixedPoint.cast` rejects `rhs` with
    `TypeError`, so Python can try the reflected comparison (and raises
    `TypeError` itself if that also fails).
    """
    try:
        order = self.cmp(rhs)
    except TypeError:
        return NotImplemented
    return order < 0
def __ge__(self, rhs):
    """return `True` if `self` is numerically greater than or equal to `rhs`.

    Returns `NotImplemented` when `FixedPoint.cast` rejects `rhs` with
    `TypeError`, so Python can try the reflected comparison (and raises
    `TypeError` itself if that also fails).
    """
    try:
        order = self.cmp(rhs)
    except TypeError:
        return NotImplemented
    return order >= 0
def __le__(self, rhs):
    """return `True` if `self` is numerically less than or equal to `rhs`.

    Returns `NotImplemented` when `FixedPoint.cast` rejects `rhs` with
    `TypeError`, so Python can try the reflected comparison (and raises
    `TypeError` itself if that also fails).
    """
    try:
        order = self.cmp(rhs)
    except TypeError:
        return NotImplemented
    return order <= 0
191 """return the fractional part of `self`.
192 that is `self - math.floor(self)`.
194 fract_mask
= (1 << self
.frac_wid
) - 1
195 return FixedPoint(self
.bits
& fract_mask
, self
.frac_wid
)
199 return "-" + str(-self
)
201 frac_digit_count
= (self
.frac_wid
+ digit_bits
- 1) // digit_bits
202 fract
= self
.fract().to_frac_wid(frac_digit_count
* digit_bits
)
203 frac_str
= hex(fract
.bits
)[2:].zfill(frac_digit_count
)
204 return hex(math
.floor(self
)) + "." + frac_str
207 return f
"FixedPoint.with_frac_wid({str(self)!r}, {self.frac_wid})"
def __add__(self, rhs):
    """return `self + rhs` exactly, as a `FixedPoint`.

    Both operands are widened to the larger of their fractional widths,
    so their raw `bits` can be summed directly.
    """
    other = FixedPoint.cast(rhs)
    width = max(self.frac_wid, other.frac_wid)
    # widening to `width` never drops bits, so the default (exact) rounding
    # mode of `to_frac_wid` can't raise here
    total = self.to_frac_wid(width).bits + other.to_frac_wid(width).bits
    return FixedPoint(total, width)
def __radd__(self, lhs):
    """return `lhs + self`; addition is commutative, so defer to `__add__`."""
    add = self.__add__
    return add(lhs)
221 return FixedPoint(-self
.bits
, self
.frac_wid
)
def __sub__(self, rhs):
    """return `self - rhs` exactly, as a `FixedPoint`.

    Both operands are widened to the larger of their fractional widths,
    so their raw `bits` can be subtracted directly.
    """
    other = FixedPoint.cast(rhs)
    width = max(self.frac_wid, other.frac_wid)
    # widening to `width` is exact, so the default rounding mode is safe
    difference = self.to_frac_wid(width).bits - other.to_frac_wid(width).bits
    return FixedPoint(difference, width)
def __rsub__(self, lhs):
    """return `lhs - self`, computed as `-(self - lhs)`."""
    difference = self.__sub__(lhs)
    return -difference
def __mul__(self, rhs):
    """return `self * rhs` exactly, as a `FixedPoint`.

    Multiplying the raw `bits` also multiplies the implicit `2 ** -frac_wid`
    scale factors, so the product's fractional width is the sum of the
    operands' widths and no rounding occurs.
    """
    other = FixedPoint.cast(rhs)
    product_bits = self.bits * other.bits
    product_frac_wid = self.frac_wid + other.frac_wid
    return FixedPoint(product_bits, product_frac_wid)
def __rmul__(self, lhs):
    """return `lhs * self`; multiplication is commutative, so defer to
    `__mul__`."""
    mul = self.__mul__
    return mul(lhs)
243 return self
.bits
>> self
.frac_wid
247 class GoldschmidtDivState
:
249 """original numerator"""
252 """original denominator"""
255 """numerator -- N_prime[i] in the paper's algorithm 2"""
258 """denominator -- D_prime[i] in the paper's algorithm 2"""
260 f
: "FixedPoint | None" = None
261 """current factor -- F_prime[i] in the paper's algorithm 2"""
263 quotient
: "int | None" = None
266 remainder
: "int | None" = None
267 """final remainder"""
269 n_shift
: "int | None" = None
270 """amount the numerator needs to be left-shifted at the end of the
275 class ParamsNotAccurateEnough(Exception):
276 """raised when the parameters aren't accurate enough to have goldschmidt
280 def _assert_accuracy(condition
, msg
="not accurate enough"):
283 raise ParamsNotAccurateEnough(msg
)
286 @dataclass(frozen
=True, unsafe_hash
=True)
287 class GoldschmidtDivParams
:
288 """parameters for a Goldschmidt division algorithm.
289 Use `GoldschmidtDivParams.get` to find a efficient set of parameters.
293 """bit-width of the input divisor and the result.
294 the input numerator is `2 * io_width`-bits wide.
298 """number of bits of additional precision used inside the algorithm."""
301 """the number of address bits used in the lookup-table."""
304 """the number of data bits used in the lookup-table."""
307 """the total number of iterations of the division algorithm's loop"""
309 # tuple to be immutable, default so repr() works for debugging even when
310 # __post_init__ hasn't finished running yet
311 table
: "tuple[FixedPoint, ...]" = field(init
=False, default
=NotImplemented)
312 """the lookup-table"""
314 ops
: "tuple[GoldschmidtDivOp, ...]" = field(init
=False,
315 default
=NotImplemented)
316 """the operations needed to perform the goldschmidt division algorithm."""
def _shrink_bound(self, bound, round_dir):
    """keep a bound's numerator/denominator from growing huge.

    `bound` (a `Fraction` or `int`) is rounded in direction `round_dir` to
    a `FixedPoint` with `io_width * 4 + 100` fractional bits, then converted
    back to a `Fraction`.

    Only meant for values used to compute bounds, never for values that end
    up in the hardware. `round_dir` must be `RoundDir.DOWN` or `RoundDir.UP`
    so that the rounded bound stays on the conservative side.
    """
    assert isinstance(bound, (Fraction, int))
    directed = round_dir is RoundDir.DOWN or round_dir is RoundDir.UP
    assert directed, "you shouldn't use that round_dir on bounds"
    # generous precision: rounding only loosens the bound negligibly
    precision = self.io_width * 4 + 100
    rounded = FixedPoint.with_frac_wid(bound, precision, round_dir)
    return rounded.as_fraction()
def _shrink_min(self, min_bound):
    """keep a lower bound's numerator/denominator small.

    Delegates to `_shrink_bound` with `RoundDir.DOWN`, so the result is
    still a valid lower bound.

    Only meant for values used to compute bounds, never for values that end
    up in the hardware.
    """
    return self._shrink_bound(min_bound, round_dir=RoundDir.DOWN)
def _shrink_max(self, max_bound):
    """keep an upper bound's numerator/denominator small.

    Delegates to `_shrink_bound` with `RoundDir.UP`, so the result is still
    a valid upper bound.

    Only meant for values used to compute bounds, never for values that end
    up in the hardware.
    """
    return self._shrink_bound(max_bound, round_dir=RoundDir.UP)
def table_addr_count(self):
    """number of distinct lookup-table addresses: `2 ** table_addr_bits`."""
    # derived from `table_addr_bits` instead of `len(self.table)` because
    # it is needed while `self.table` is still being built
    addr_bits = self.table_addr_bits
    return 1 << addr_bits
def table_input_exact_range(self, addr):
    """return `(min_input, max_input)` as `Fraction`s: the range of inputs
    that use the lookup-table entry with address `addr`.

    Inputs are the values `k / 2 ** io_width` in `[1, 2)`. Entry `addr`
    covers the `2 ** (io_width - table_addr_bits)` consecutive inputs that
    start at `1 + addr / 2 ** table_addr_bits`.
    """
    assert isinstance(addr, int)
    assert 0 <= addr < self.table_addr_count
    _assert_accuracy(self.io_width >= self.table_addr_bits)
    # number of low input bits that don't take part in the table address
    low_bits = self.io_width - self.table_addr_bits
    scale = 1 << self.io_width
    first_numerator = scale + (addr << low_bits)
    last_numerator = first_numerator + (1 << low_bits) - 1
    lowest = self._shrink_min(Fraction(first_numerator, scale))
    highest = self._shrink_max(Fraction(last_numerator, scale))
    assert 1 <= lowest <= highest < 2
    return lowest, highest
def table_value_exact_range(self, addr):
    """return `(min_value, max_value)` as `Fraction`s: the range of
    reciprocals `1 / x` over the inputs `x` that use table entry `addr`.

    `1 / x` is decreasing for positive `x`, so the largest input gives the
    smallest value and vice versa.
    """
    lowest_input, highest_input = self.table_input_exact_range(addr)
    smallest = self._shrink_min(1 / highest_input)
    largest = self._shrink_max(1 / lowest_input)
    assert 0.5 < smallest <= largest <= 1
    return smallest, largest
388 def table_exact_value(self
, index
):
389 min_value
, max_value
= self
.table_value_exact_range(index
)
393 def __post_init__(self
):
394 # called by the autogenerated __init__
395 assert self
.io_width
>= 1
396 assert self
.extra_precision
>= 0
397 assert self
.table_addr_bits
>= 1
398 assert self
.table_data_bits
>= 1
399 assert self
.iter_count
>= 1
401 for addr
in range(1 << self
.table_addr_bits
):
402 table
.append(FixedPoint
.with_frac_wid(self
.table_exact_value(addr
),
403 self
.table_data_bits
,
405 # we have to use object.__setattr__ since frozen=True
406 object.__setattr
__(self
, "table", tuple(table
))
407 object.__setattr
__(self
, "ops", tuple(_goldschmidt_div_ops(self
)))
411 """ find efficient parameters for a goldschmidt division algorithm
412 with `params.io_width == io_width`.
414 assert isinstance(io_width
, int) and io_width
>= 1
417 for extra_precision
in range(io_width
* 2 + 4):
418 for table_addr_bits
in range(1, 7 + 1):
419 table_data_bits
= io_width
+ extra_precision
420 for iter_count
in range(1, 2 * io_width
.bit_length()):
422 return GoldschmidtDivParams(
424 extra_precision
=extra_precision
,
425 table_addr_bits
=table_addr_bits
,
426 table_data_bits
=table_data_bits
,
427 iter_count
=iter_count
)
428 except ParamsNotAccurateEnough
as e
:
429 last_params
= (f
"GoldschmidtDivParams("
430 f
"io_width={io_width!r}, "
431 f
"extra_precision={extra_precision!r}, "
432 f
"table_addr_bits={table_addr_bits!r}, "
433 f
"table_data_bits={table_data_bits!r}, "
434 f
"iter_count={iter_count!r})")
436 raise ValueError(f
"can't find working parameters for a goldschmidt "
437 f
"division algorithm: last params: {last_params}"
def expanded_width(self):
    """total bits of precision used inside the algorithm: the I/O width
    plus the extra guard precision."""
    base_width, guard_bits = self.io_width, self.extra_precision
    return base_width + guard_bits
def max_neps(self, i):
    """maximum value of `neps[i]`.

    `neps[i]` is defined to be `n[i] * N_prime[i - 1] * F_prime[i - 1]`.
    The bound is one unit in the last place of the `expanded_width`
    fractional bits kept after each numerator multiply.
    """
    assert isinstance(i, int) and 0 <= i < self.iter_count
    ulp = Fraction(1, 1 << self.expanded_width)
    return ulp
def max_deps(self, i):
    """maximum value of `deps[i]`.

    `deps[i]` is defined to be `d[i] * D_prime[i - 1] * F_prime[i - 1]`.
    The bound is one unit in the last place of the `expanded_width`
    fractional bits kept after each denominator multiply.
    """
    assert isinstance(i, int) and 0 <= i < self.iter_count
    ulp = Fraction(1, 1 << self.expanded_width)
    return ulp
462 def max_feps(self
, i
):
463 """maximum value of `feps[i]`.
464 `feps[i]` is defined to be `f[i] * (2 - D_prime[i - 1])`.
466 assert isinstance(i
, int) and 0 <= i
< self
.iter_count
467 # zero, because the computation of `F_prime[i]` in
468 # `GoldschmidtDivOp.MulDByF.run(...)` is exact.
473 """minimum and maximum values of `e[0]`
474 (the relative error in `F_prime[-1]`)
478 for addr
in range(self
.table_addr_count
):
479 # `F_prime[-1] = (1 - e[0]) / B`
480 # => `e[0] = 1 - B * F_prime[-1]`
481 min_b
, max_b
= self
.table_input_exact_range(addr
)
482 f_prime_m1
= self
.table
[addr
].as_fraction()
483 assert min_b
>= 0 and f_prime_m1
>= 0, \
484 "only positive quadrant of interval multiplication implemented"
485 min_product
= min_b
* f_prime_m1
486 max_product
= max_b
* f_prime_m1
487 # negation swaps min/max
488 cur_min_e0
= 1 - max_product
489 cur_max_e0
= 1 - min_product
490 min_e0
= min(min_e0
, cur_min_e0
)
491 max_e0
= max(max_e0
, cur_max_e0
)
492 min_e0
= self
._shrink
_min
(min_e0
)
493 max_e0
= self
._shrink
_max
(max_e0
)
494 return min_e0
, max_e0
498 """minimum value of `e[0]` (the relative error in `F_prime[-1]`)
500 min_e0
, max_e0
= self
.e0_range
505 """maximum value of `e[0]` (the relative error in `F_prime[-1]`)
507 min_e0
, max_e0
= self
.e0_range
def max_abs_e0(self):
    """maximum value of `abs(e[0])`: the larger magnitude of the two ends
    of the `e[0]` range."""
    low_magnitude = abs(self.min_e0)
    high_magnitude = abs(self.max_e0)
    return max(low_magnitude, high_magnitude)
516 def min_abs_e0(self
):
517 """minimum value of `abs(e[0])`."""
522 """maximum value of `n[i]` (the relative error in `N_prime[i]`
523 relative to the previous iteration)
525 assert isinstance(i
, int) and 0 <= i
< self
.iter_count
528 # `n[0] = neps[0] / ((1 - e[0]) * (A / B))`
529 # `n[0] <= 2 * neps[0] / (1 - e[0])`
531 assert self
.max_e0
< 1 and self
.max_neps(0) >= 0, \
532 "only one quadrant of interval division implemented"
533 retval
= 2 * self
.max_neps(0) / (1 - self
.max_e0
)
536 # `n[1] <= neps[1] / ((1 - f[0]) * (1 - pi[0] - delta[0]))`
537 min_mpd
= 1 - self
.max_pi(0) - self
.max_delta(0)
538 assert self
.max_f(0) <= 1 and min_mpd
>= 0, \
539 "only one quadrant of interval multiplication implemented"
540 prod
= (1 - self
.max_f(0)) * min_mpd
541 assert self
.max_neps(1) >= 0 and prod
> 0, \
542 "only one quadrant of interval division implemented"
543 retval
= self
.max_neps(1) / prod
546 # `0 <= n[i] <= 2 * max_neps[i] / (1 - pi[i - 1] - delta[i - 1])`
547 min_mpd
= 1 - self
.max_pi(i
- 1) - self
.max_delta(i
- 1)
548 assert self
.max_neps(i
) >= 0 and min_mpd
> 0, \
549 "only one quadrant of interval division implemented"
550 retval
= self
.max_neps(i
) / min_mpd
552 return self
._shrink
_max
(retval
)
556 """maximum value of `d[i]` (the relative error in `D_prime[i]`
557 relative to the previous iteration)
559 assert isinstance(i
, int) and 0 <= i
< self
.iter_count
562 # `d[0] = deps[0] / (1 - e[0])`
564 assert self
.max_e0
< 1 and self
.max_deps(0) >= 0, \
565 "only one quadrant of interval division implemented"
566 retval
= self
.max_deps(0) / (1 - self
.max_e0
)
569 # `d[1] <= deps[1] / ((1 - f[0]) * (1 - delta[0] ** 2))`
570 assert self
.max_f(0) <= 1 and self
.max_delta(0) <= 1, \
571 "only one quadrant of interval multiplication implemented"
572 divisor
= (1 - self
.max_f(0)) * (1 - self
.max_delta(0) ** 2)
573 assert self
.max_deps(1) >= 0 and divisor
> 0, \
574 "only one quadrant of interval division implemented"
575 retval
= self
.max_deps(1) / divisor
578 # `0 <= d[i] <= max_deps[i] / (1 - delta[i - 1])`
579 assert self
.max_deps(i
) >= 0 and self
.max_delta(i
- 1) < 1, \
580 "only one quadrant of interval division implemented"
581 retval
= self
.max_deps(i
) / (1 - self
.max_delta(i
- 1))
583 return self
._shrink
_max
(retval
)
587 """maximum value of `f[i]` (the relative error in `F_prime[i]`
588 relative to the previous iteration)
590 assert isinstance(i
, int) and 0 <= i
< self
.iter_count
593 # `f[0] = feps[0] / (1 - delta[0])`
595 assert self
.max_delta(0) < 1 and self
.max_feps(0) >= 0, \
596 "only one quadrant of interval division implemented"
597 retval
= self
.max_feps(0) / (1 - self
.max_delta(0))
601 retval
= self
.max_feps(1)
604 # `f[i] <= max_feps[i]`
605 retval
= self
.max_feps(i
)
607 return self
._shrink
_max
(retval
)
610 def max_delta(self
, i
):
611 """ maximum value of `delta[i]`.
612 `delta[i]` is defined in Definition 4 of paper.
614 assert isinstance(i
, int) and 0 <= i
< self
.iter_count
616 # `delta[0] = abs(e[0]) + 3 * d[0] / 2`
617 retval
= self
.max_abs_e0
+ Fraction(3, 2) * self
.max_d(0)
619 # `delta[i] = delta[i - 1] ** 2 + f[i - 1]`
620 prev_max_delta
= self
.max_delta(i
- 1)
621 assert prev_max_delta
>= 0
622 retval
= prev_max_delta
** 2 + self
.max_f(i
- 1)
624 # `delta[i]` has to be smaller than one otherwise errors would go off
626 _assert_accuracy(retval
< 1)
628 return self
._shrink
_max
(retval
)
632 """ maximum value of `pi[i]`.
633 `pi[i]` is defined right below Theorem 5 of paper.
635 assert isinstance(i
, int) and 0 <= i
< self
.iter_count
636 # `pi[i] = 1 - (1 - n[i]) * prod`
637 # where `prod` is the product of,
638 # for `j` in `0 <= j < i`, `(1 - n[j]) / (1 + d[j])`
639 min_prod
= Fraction(1)
641 max_n_j
= self
.max_n(j
)
642 max_d_j
= self
.max_d(j
)
643 assert max_n_j
<= 1 and max_d_j
> -1, \
644 "only one quadrant of interval division implemented"
645 min_prod
*= (1 - max_n_j
) / (1 + max_d_j
)
646 max_n_i
= self
.max_n(i
)
647 assert max_n_i
<= 1 and min_prod
>= 0, \
648 "only one quadrant of interval multiplication implemented"
649 retval
= 1 - (1 - max_n_i
) * min_prod
650 return self
._shrink
_max
(retval
)
653 def max_n_shift(self
):
654 """ maximum value of `state.n_shift`.
656 # input numerator is `2*io_width`-bits
657 max_n
= (1 << (self
.io_width
* 2)) - 1
659 # normalize so 1 <= n < 2
667 class GoldschmidtDivOp(enum
.Enum
):
668 Normalize
= "n, d, n_shift = normalize(n, d)"
669 FEqTableLookup
= "f = table_lookup(d)"
672 FEq2MinusD
= "f = 2 - d"
673 CalcResult
= "result = unnormalize_and_round(n)"
675 def run(self
, params
, state
):
676 assert isinstance(params
, GoldschmidtDivParams
)
677 assert isinstance(state
, GoldschmidtDivState
)
678 expanded_width
= params
.expanded_width
679 table_addr_bits
= params
.table_addr_bits
680 if self
== GoldschmidtDivOp
.Normalize
:
681 # normalize so 1 <= d < 2
682 # can easily be done with count-leading-zeros and left shift
684 state
.n
= (state
.n
* 2).to_frac_wid(expanded_width
)
685 state
.d
= (state
.d
* 2).to_frac_wid(expanded_width
)
688 # normalize so 1 <= n < 2
690 state
.n
= (state
.n
* 0.5).to_frac_wid(expanded_width
)
692 elif self
== GoldschmidtDivOp
.FEqTableLookup
:
693 # compute initial f by table lookup
695 d_m_1
= d_m_1
.to_frac_wid(table_addr_bits
, RoundDir
.DOWN
)
696 assert 0 <= d_m_1
.bits
< (1 << params
.table_addr_bits
)
697 state
.f
= params
.table
[d_m_1
.bits
]
698 elif self
== GoldschmidtDivOp
.MulNByF
:
699 assert state
.f
is not None
700 n
= state
.n
* state
.f
701 state
.n
= n
.to_frac_wid(expanded_width
, round_dir
=RoundDir
.DOWN
)
702 elif self
== GoldschmidtDivOp
.MulDByF
:
703 assert state
.f
is not None
704 d
= state
.d
* state
.f
705 state
.d
= d
.to_frac_wid(expanded_width
, round_dir
=RoundDir
.UP
)
706 elif self
== GoldschmidtDivOp
.FEq2MinusD
:
707 state
.f
= (2 - state
.d
).to_frac_wid(expanded_width
)
708 elif self
== GoldschmidtDivOp
.CalcResult
:
709 assert state
.n_shift
is not None
710 # scale to correct value
711 n
= state
.n
* (1 << state
.n_shift
)
713 state
.quotient
= math
.floor(n
)
714 state
.remainder
= state
.orig_n
- state
.quotient
* state
.orig_d
715 if state
.remainder
>= state
.orig_d
:
717 state
.remainder
-= state
.orig_d
719 assert False, f
"unimplemented GoldschmidtDivOp: {self}"
722 def _goldschmidt_div_ops(params
):
723 """ Goldschmidt division algorithm.
726 Even, G., Seidel, P. M., & Ferguson, W. E. (2003).
727 A Parametric Error Analysis of Goldschmidt's Division Algorithm.
728 https://citeseerx.ist.psu.edu/viewdoc/download?doi=10.1.1.90.1238&rep=rep1&type=pdf
731 params: GoldschmidtDivParams
732 the parameters for the algorithm
734 yields: GoldschmidtDivOp
735 the operations needed to perform the division.
737 assert isinstance(params
, GoldschmidtDivParams
)
739 # establish assumptions of the paper's error analysis (section 3.1):
741 # 1. normalize so A (numerator) and B (denominator) are in [1, 2)
742 yield GoldschmidtDivOp
.Normalize
744 # 2. ensure all relative errors from directed rounding are <= 1 / 4.
745 # the assumption is met by multipliers with > 4-bits precision
746 _assert_accuracy(params
.expanded_width
> 4)
748 # 3. require `abs(e[0]) + 3 * d[0] / 2 + f[0] < 1 / 2`.
749 _assert_accuracy(params
.max_abs_e0
+ 3 * params
.max_d(0) / 2
750 + params
.max_f(0) < Fraction(1, 2))
752 # 4. the initial approximation F'[-1] of 1/B is in [1/2, 1].
753 # (B is the denominator)
755 for addr
in range(params
.table_addr_count
):
756 f_prime_m1
= params
.table
[addr
]
757 _assert_accuracy(0.5 <= f_prime_m1
<= 1)
759 yield GoldschmidtDivOp
.FEqTableLookup
761 # we use Setting I (section 4.1 of the paper):
762 # Require `n[i] <= n_hat` and `d[i] <= n_hat` and `f[i] = 0`
764 for i
in range(params
.iter_count
):
765 _assert_accuracy(params
.max_f(i
) == 0)
766 n_hat
= max(n_hat
, params
.max_n(i
), params
.max_d(i
))
767 yield GoldschmidtDivOp
.MulNByF
768 if i
!= params
.iter_count
- 1:
769 yield GoldschmidtDivOp
.MulDByF
770 yield GoldschmidtDivOp
.FEq2MinusD
772 # relative approximation error `p(N_prime[i])`:
773 # `p(N_prime[i]) = (A / B - N_prime[i]) / (A / B)`
774 # `0 <= p(N_prime[i])`
775 # `p(N_prime[i]) <= (2 * i) * n_hat \`
776 # ` + (abs(e[0]) + 3 * n_hat / 2) ** (2 ** i)`
777 i
= params
.iter_count
- 1 # last used `i`
778 # compute power manually to prevent huge intermediate values
779 power
= params
._shrink
_max
(params
.max_abs_e0
+ 3 * n_hat
/ 2)
781 power
= params
._shrink
_max
(power
* power
)
783 max_rel_error
= (2 * i
) * n_hat
+ power
785 min_a_over_b
= Fraction(1, 2)
786 max_a_over_b
= Fraction(2)
787 max_allowed_abs_error
= max_a_over_b
/ (1 << params
.max_n_shift
)
788 max_allowed_rel_error
= max_allowed_abs_error
/ min_a_over_b
790 _assert_accuracy(max_rel_error
< max_allowed_rel_error
,
791 f
"not accurate enough: max_rel_error={max_rel_error} "
792 f
"max_allowed_rel_error={max_allowed_rel_error}")
794 yield GoldschmidtDivOp
.CalcResult
797 def goldschmidt_div(n
, d
, params
):
798 """ Goldschmidt division algorithm.
801 Even, G., Seidel, P. M., & Ferguson, W. E. (2003).
802 A Parametric Error Analysis of Goldschmidt's Division Algorithm.
803 https://citeseerx.ist.psu.edu/viewdoc/download?doi=10.1.1.90.1238&rep=rep1&type=pdf
807 numerator. a `2*width`-bit unsigned integer.
808 must be less than `d << width`, otherwise the quotient wouldn't
811 denominator. a `width`-bit unsigned integer. must not be zero.
813 the bit-width of the inputs/outputs. must be a positive integer.
815 returns: tuple[int, int]
816 the quotient and remainder. a tuple of two `width`-bit unsigned
819 assert isinstance(params
, GoldschmidtDivParams
)
820 assert isinstance(d
, int) and 0 < d
< (1 << params
.io_width
)
821 assert isinstance(n
, int) and 0 <= n
< (d
<< params
.io_width
)
823 # this whole algorithm is done with fixed-point arithmetic where values
824 # have `width` fractional bits
826 state
= GoldschmidtDivState(
829 n
=FixedPoint(n
, params
.io_width
),
830 d
=FixedPoint(d
, params
.io_width
),
833 for op
in params
.ops
:
834 op
.run(params
, state
)
836 assert state
.quotient
is not None
837 assert state
.remainder
is not None
839 return state
.quotient
, state
.remainder