1 # SPDX-License-Identifier: LGPL-3-or-later
2 # Copyright 2022 Jacob Lifshay programmerjake@gmail.com
4 # Funded by NLnet Assure Programme 2021-02-052, https://nlnet.nl/assure part
5 # of Horizon 2020 EU Programme 957073.
7 from dataclasses
import dataclass
, field
, fields
, replace
11 from fractions
import Fraction
12 from types
import FunctionType
13 from functools
import lru_cache
16 from functools
import cached_property
18 from cached_property
import cached_property
20 # fix broken IDE type detection for cached_property
21 from typing
import TYPE_CHECKING
, Any
23 from functools
import cached_property
29 def cache_on_self(func
):
30 """like `functools.cached_property`, except for methods. unlike
31 `lru_cache` the cache is per-class instance rather than a global cache
34 assert isinstance(func
, FunctionType
), \
35 "non-plain methods are not supported"
37 cache_name
= func
.__name
__ + "__cache"
39 def wrapper(self
, *args
, **kwargs
):
40 # specifically access through `__dict__` to bypass frozen=True
41 cache
= self
.__dict
__.get(cache_name
, _NOT_FOUND
)
42 if cache
is _NOT_FOUND
:
43 self
.__dict
__[cache_name
] = cache
= {}
44 key
= (args
, *kwargs
.items())
45 retval
= cache
.get(key
, _NOT_FOUND
)
46 if retval
is _NOT_FOUND
:
47 retval
= func(self
, *args
, **kwargs
)
51 wrapper
.__doc
__ = func
.__doc
__
56 class RoundDir(enum
.Enum
):
59 NEAREST_TIES_UP
= enum
.auto()
60 ERROR_IF_INEXACT
= enum
.auto()
63 @dataclass(frozen
=True)
68 def __post_init__(self
):
69 assert isinstance(self
.bits
, int)
70 assert isinstance(self
.frac_wid
, int) and self
.frac_wid
>= 0
74 """convert `value` to a fixed-point number with enough fractional
75 bits to preserve its value."""
76 if isinstance(value
, FixedPoint
):
78 if isinstance(value
, int):
79 return FixedPoint(value
, 0)
80 if isinstance(value
, str):
82 neg
= value
.startswith("-")
83 if neg
or value
.startswith("+"):
85 if value
.startswith(("0x", "0X")) and "." in value
:
95 raise ValueError("too many `.` in string")
100 if not digit
.isalnum():
101 raise ValueError("invalid hexadecimal digit")
103 bits |
= int("0x" + digit
, base
=16)
105 bits
= int(value
, base
=0)
109 return FixedPoint(bits
, frac_wid
)
111 if isinstance(value
, float):
112 n
, d
= value
.as_integer_ratio()
113 log2_d
= d
.bit_length() - 1
114 assert d
== 1 << log2_d
, ("d isn't a power of 2 -- won't ever "
115 "fail with float being IEEE 754")
116 return FixedPoint(n
, log2_d
)
117 raise TypeError("can't convert type to FixedPoint")
120 def with_frac_wid(value
, frac_wid
, round_dir
=RoundDir
.ERROR_IF_INEXACT
):
121 """convert `value` to the nearest fixed-point number with `frac_wid`
122 fractional bits, rounding according to `round_dir`."""
123 assert isinstance(frac_wid
, int) and frac_wid
>= 0
124 assert isinstance(round_dir
, RoundDir
)
125 if isinstance(value
, Fraction
):
126 numerator
= value
.numerator
127 denominator
= value
.denominator
129 value
= FixedPoint
.cast(value
)
130 numerator
= value
.bits
131 denominator
= 1 << value
.frac_wid
133 numerator
= -numerator
134 denominator
= -denominator
135 bits
, remainder
= divmod(numerator
<< frac_wid
, denominator
)
136 if round_dir
== RoundDir
.DOWN
:
138 elif round_dir
== RoundDir
.UP
:
141 elif round_dir
== RoundDir
.NEAREST_TIES_UP
:
142 if remainder
* 2 >= denominator
:
144 elif round_dir
== RoundDir
.ERROR_IF_INEXACT
:
146 raise ValueError("inexact conversion")
148 assert False, "unimplemented round_dir"
149 return FixedPoint(bits
, frac_wid
)
151 def to_frac_wid(self
, frac_wid
, round_dir
=RoundDir
.ERROR_IF_INEXACT
):
152 """convert to the nearest fixed-point number with `frac_wid`
153 fractional bits, rounding according to `round_dir`."""
154 return FixedPoint
.with_frac_wid(self
, frac_wid
, round_dir
)
157 # use truediv to get correct result even when bits
158 # and frac_wid are huge
159 return float(self
.bits
/ (1 << self
.frac_wid
))
161 def as_fraction(self
):
162 return Fraction(self
.bits
, 1 << self
.frac_wid
)
165 """compare self with rhs, returning a positive integer if self is
166 greater than rhs, zero if self is equal to rhs, and a negative integer
167 if self is less than rhs."""
168 rhs
= FixedPoint
.cast(rhs
)
169 common_frac_wid
= max(self
.frac_wid
, rhs
.frac_wid
)
170 lhs
= self
.to_frac_wid(common_frac_wid
)
171 rhs
= rhs
.to_frac_wid(common_frac_wid
)
172 return lhs
.bits
- rhs
.bits
174 def __eq__(self
, rhs
):
175 return self
.cmp(rhs
) == 0
177 def __ne__(self
, rhs
):
178 return self
.cmp(rhs
) != 0
180 def __gt__(self
, rhs
):
181 return self
.cmp(rhs
) > 0
183 def __lt__(self
, rhs
):
184 return self
.cmp(rhs
) < 0
186 def __ge__(self
, rhs
):
187 return self
.cmp(rhs
) >= 0
189 def __le__(self
, rhs
):
190 return self
.cmp(rhs
) <= 0
193 """return the fractional part of `self`.
194 that is `self - math.floor(self)`.
196 fract_mask
= (1 << self
.frac_wid
) - 1
197 return FixedPoint(self
.bits
& fract_mask
, self
.frac_wid
)
201 return "-" + str(-self
)
203 frac_digit_count
= (self
.frac_wid
+ digit_bits
- 1) // digit_bits
204 fract
= self
.fract().to_frac_wid(frac_digit_count
* digit_bits
)
205 frac_str
= hex(fract
.bits
)[2:].zfill(frac_digit_count
)
206 return hex(math
.floor(self
)) + "." + frac_str
209 return f
"FixedPoint.with_frac_wid({str(self)!r}, {self.frac_wid})"
211 def __add__(self
, rhs
):
212 rhs
= FixedPoint
.cast(rhs
)
213 common_frac_wid
= max(self
.frac_wid
, rhs
.frac_wid
)
214 lhs
= self
.to_frac_wid(common_frac_wid
)
215 rhs
= rhs
.to_frac_wid(common_frac_wid
)
216 return FixedPoint(lhs
.bits
+ rhs
.bits
, common_frac_wid
)
218 def __radd__(self
, lhs
):
220 return self
.__add
__(lhs
)
223 return FixedPoint(-self
.bits
, self
.frac_wid
)
225 def __sub__(self
, rhs
):
226 rhs
= FixedPoint
.cast(rhs
)
227 common_frac_wid
= max(self
.frac_wid
, rhs
.frac_wid
)
228 lhs
= self
.to_frac_wid(common_frac_wid
)
229 rhs
= rhs
.to_frac_wid(common_frac_wid
)
230 return FixedPoint(lhs
.bits
- rhs
.bits
, common_frac_wid
)
232 def __rsub__(self
, lhs
):
234 return -self
.__sub
__(lhs
)
236 def __mul__(self
, rhs
):
237 rhs
= FixedPoint
.cast(rhs
)
238 return FixedPoint(self
.bits
* rhs
.bits
, self
.frac_wid
+ rhs
.frac_wid
)
240 def __rmul__(self
, lhs
):
242 return self
.__mul
__(lhs
)
245 return self
.bits
>> self
.frac_wid
247 def div(self
, rhs
, frac_wid
, round_dir
=RoundDir
.ERROR_IF_INEXACT
):
248 assert isinstance(frac_wid
, int) and frac_wid
>= 0
249 assert isinstance(round_dir
, RoundDir
)
250 rhs
= FixedPoint
.cast(rhs
)
251 return FixedPoint
.with_frac_wid(self
.as_fraction()
255 def sqrt(self
, round_dir
=RoundDir
.ERROR_IF_INEXACT
):
256 assert isinstance(round_dir
, RoundDir
)
258 raise ValueError("can't compute sqrt of negative number")
261 retval
= FixedPoint(0, self
.frac_wid
)
262 int_part_wid
= self
.bits
.bit_length() - self
.frac_wid
263 first_bit_index
= -(-int_part_wid
// 2) # division rounds up
264 last_bit_index
= -self
.frac_wid
265 for bit_index
in range(first_bit_index
, last_bit_index
- 1, -1):
266 trial
= retval
+ FixedPoint(1 << (bit_index
+ self
.frac_wid
),
268 if trial
* trial
<= self
:
270 if round_dir
== RoundDir
.DOWN
:
272 elif round_dir
== RoundDir
.UP
:
273 if retval
* retval
< self
:
274 retval
+= FixedPoint(1, self
.frac_wid
)
275 elif round_dir
== RoundDir
.NEAREST_TIES_UP
:
276 half_way
= retval
+ FixedPoint(1, self
.frac_wid
+ 1)
277 if half_way
* half_way
<= self
:
278 retval
+= FixedPoint(1, self
.frac_wid
)
279 elif round_dir
== RoundDir
.ERROR_IF_INEXACT
:
280 if retval
* retval
!= self
:
281 raise ValueError("inexact sqrt")
283 assert False, "unimplemented round_dir"
286 def rsqrt(self
, round_dir
=RoundDir
.ERROR_IF_INEXACT
):
287 """compute the reciprocal-sqrt of `self`"""
288 assert isinstance(round_dir
, RoundDir
)
290 raise ValueError("can't compute rsqrt of negative number")
292 raise ZeroDivisionError("can't compute rsqrt of zero")
293 retval
= FixedPoint(0, self
.frac_wid
)
294 first_bit_index
= -(-self
.frac_wid
// 2) # division rounds up
295 last_bit_index
= -self
.frac_wid
296 for bit_index
in range(first_bit_index
, last_bit_index
- 1, -1):
297 trial
= retval
+ FixedPoint(1 << (bit_index
+ self
.frac_wid
),
299 if trial
* trial
* self
<= 1:
301 if round_dir
== RoundDir
.DOWN
:
303 elif round_dir
== RoundDir
.UP
:
304 if retval
* retval
* self
< 1:
305 retval
+= FixedPoint(1, self
.frac_wid
)
306 elif round_dir
== RoundDir
.NEAREST_TIES_UP
:
307 half_way
= retval
+ FixedPoint(1, self
.frac_wid
+ 1)
308 if half_way
* half_way
* self
<= 1:
309 retval
+= FixedPoint(1, self
.frac_wid
)
310 elif round_dir
== RoundDir
.ERROR_IF_INEXACT
:
311 if retval
* retval
* self
!= 1:
312 raise ValueError("inexact rsqrt")
314 assert False, "unimplemented round_dir"
319 class GoldschmidtDivState
:
321 """original numerator"""
324 """original denominator"""
327 """numerator -- N_prime[i] in the paper's algorithm 2"""
330 """denominator -- D_prime[i] in the paper's algorithm 2"""
332 f
: "FixedPoint | None" = None
333 """current factor -- F_prime[i] in the paper's algorithm 2"""
335 quotient
: "int | None" = None
338 remainder
: "int | None" = None
339 """final remainder"""
341 n_shift
: "int | None" = None
342 """amount the numerator needs to be left-shifted at the end of the
347 class ParamsNotAccurateEnough(Exception):
348 """raised when the parameters aren't accurate enough to have goldschmidt
352 def _assert_accuracy(condition
, msg
="not accurate enough"):
355 raise ParamsNotAccurateEnough(msg
)
358 @dataclass(frozen
=True, unsafe_hash
=True)
359 class GoldschmidtDivParamsBase
:
360 """parameters for a Goldschmidt division algorithm, excluding derived
365 """bit-width of the input divisor and the result.
366 the input numerator is `2 * io_width`-bits wide.
370 """number of bits of additional precision used inside the algorithm."""
373 """the number of address bits used in the lookup-table."""
376 """the number of data bits used in the lookup-table."""
379 """the total number of iterations of the division algorithm's loop"""
382 @dataclass(frozen
=True, unsafe_hash
=True)
383 class GoldschmidtDivParams(GoldschmidtDivParamsBase
):
384 """parameters for a Goldschmidt division algorithm.
385 Use `GoldschmidtDivParams.get` to find a efficient set of parameters.
388 # tuple to be immutable, repr=False so repr() works for debugging even when
389 # __post_init__ hasn't finished running yet
390 table
: "tuple[FixedPoint, ...]" = field(init
=False, repr=False)
391 """the lookup-table"""
393 ops
: "tuple[GoldschmidtDivOp, ...]" = field(init
=False, repr=False)
394 """the operations needed to perform the goldschmidt division algorithm."""
396 def _shrink_bound(self
, bound
, round_dir
):
397 """prevent fractions from having huge numerators/denominators by
398 rounding to a `FixedPoint` and converting back to a `Fraction`.
400 This is intended only for values used to compute bounds, and not for
401 values that end up in the hardware.
403 assert isinstance(bound
, (Fraction
, int))
404 assert round_dir
is RoundDir
.DOWN
or round_dir
is RoundDir
.UP
, \
405 "you shouldn't use that round_dir on bounds"
406 frac_wid
= self
.io_width
* 4 + 100 # should be enough precision
407 fixed
= FixedPoint
.with_frac_wid(bound
, frac_wid
, round_dir
)
408 return fixed
.as_fraction()
410 def _shrink_min(self
, min_bound
):
411 """prevent fractions used as minimum bounds from having huge
412 numerators/denominators by rounding down to a `FixedPoint` and
413 converting back to a `Fraction`.
415 This is intended only for values used to compute bounds, and not for
416 values that end up in the hardware.
418 return self
._shrink
_bound
(min_bound
, RoundDir
.DOWN
)
420 def _shrink_max(self
, max_bound
):
421 """prevent fractions used as maximum bounds from having huge
422 numerators/denominators by rounding up to a `FixedPoint` and
423 converting back to a `Fraction`.
425 This is intended only for values used to compute bounds, and not for
426 values that end up in the hardware.
428 return self
._shrink
_bound
(max_bound
, RoundDir
.UP
)
431 def table_addr_count(self
):
432 """number of distinct addresses in the lookup-table."""
433 # used while computing self.table, so can't just do len(self.table)
434 return 1 << self
.table_addr_bits
436 def table_input_exact_range(self
, addr
):
437 """return the range of inputs as `Fraction`s used for the table entry
438 with address `addr`."""
439 assert isinstance(addr
, int)
440 assert 0 <= addr
< self
.table_addr_count
441 _assert_accuracy(self
.io_width
>= self
.table_addr_bits
)
442 addr_shift
= self
.io_width
- self
.table_addr_bits
443 min_numerator
= (1 << self
.io_width
) + (addr
<< addr_shift
)
444 denominator
= 1 << self
.io_width
445 values_per_table_entry
= 1 << addr_shift
446 max_numerator
= min_numerator
+ values_per_table_entry
- 1
447 min_input
= Fraction(min_numerator
, denominator
)
448 max_input
= Fraction(max_numerator
, denominator
)
449 min_input
= self
._shrink
_min
(min_input
)
450 max_input
= self
._shrink
_max
(max_input
)
451 assert 1 <= min_input
<= max_input
< 2
452 return min_input
, max_input
454 def table_value_exact_range(self
, addr
):
455 """return the range of values as `Fraction`s used for the table entry
456 with address `addr`."""
457 min_input
, max_input
= self
.table_input_exact_range(addr
)
458 # division swaps min/max
459 min_value
= 1 / max_input
460 max_value
= 1 / min_input
461 min_value
= self
._shrink
_min
(min_value
)
462 max_value
= self
._shrink
_max
(max_value
)
463 assert 0.5 < min_value
<= max_value
<= 1
464 return min_value
, max_value
466 def table_exact_value(self
, index
):
467 min_value
, max_value
= self
.table_value_exact_range(index
)
471 def __post_init__(self
):
472 # called by the autogenerated __init__
473 _assert_accuracy(self
.io_width
>= 1, "io_width out of range")
474 _assert_accuracy(self
.extra_precision
>= 0,
475 "extra_precision out of range")
476 _assert_accuracy(self
.table_addr_bits
>= 1,
477 "table_addr_bits out of range")
478 _assert_accuracy(self
.table_data_bits
>= 1,
479 "table_data_bits out of range")
480 _assert_accuracy(self
.iter_count
>= 1, "iter_count out of range")
482 for addr
in range(1 << self
.table_addr_bits
):
483 table
.append(FixedPoint
.with_frac_wid(self
.table_exact_value(addr
),
484 self
.table_data_bits
,
486 # we have to use object.__setattr__ since frozen=True
487 object.__setattr
__(self
, "table", tuple(table
))
488 object.__setattr
__(self
, "ops", tuple(self
.__make
_ops
()))
491 def expanded_width(self
):
492 """the total number of bits of precision used inside the algorithm."""
493 return self
.io_width
+ self
.extra_precision
496 def max_neps(self
, i
):
497 """maximum value of `neps[i]`.
498 `neps[i]` is defined to be `n[i] * N_prime[i - 1] * F_prime[i - 1]`.
500 assert isinstance(i
, int) and 0 <= i
< self
.iter_count
501 return Fraction(1, 1 << self
.expanded_width
)
504 def max_deps(self
, i
):
505 """maximum value of `deps[i]`.
506 `deps[i]` is defined to be `d[i] * D_prime[i - 1] * F_prime[i - 1]`.
508 assert isinstance(i
, int) and 0 <= i
< self
.iter_count
509 return Fraction(1, 1 << self
.expanded_width
)
512 def max_feps(self
, i
):
513 """maximum value of `feps[i]`.
514 `feps[i]` is defined to be `f[i] * (2 - D_prime[i - 1])`.
516 assert isinstance(i
, int) and 0 <= i
< self
.iter_count
517 # zero, because the computation of `F_prime[i]` in
518 # `GoldschmidtDivOp.MulDByF.run(...)` is exact.
523 """minimum and maximum values of `e[0]`
524 (the relative error in `F_prime[-1]`)
528 for addr
in range(self
.table_addr_count
):
529 # `F_prime[-1] = (1 - e[0]) / B`
530 # => `e[0] = 1 - B * F_prime[-1]`
531 min_b
, max_b
= self
.table_input_exact_range(addr
)
532 f_prime_m1
= self
.table
[addr
].as_fraction()
533 assert min_b
>= 0 and f_prime_m1
>= 0, \
534 "only positive quadrant of interval multiplication implemented"
535 min_product
= min_b
* f_prime_m1
536 max_product
= max_b
* f_prime_m1
537 # negation swaps min/max
538 cur_min_e0
= 1 - max_product
539 cur_max_e0
= 1 - min_product
540 min_e0
= min(min_e0
, cur_min_e0
)
541 max_e0
= max(max_e0
, cur_max_e0
)
542 min_e0
= self
._shrink
_min
(min_e0
)
543 max_e0
= self
._shrink
_max
(max_e0
)
544 return min_e0
, max_e0
548 """minimum value of `e[0]` (the relative error in `F_prime[-1]`)
550 min_e0
, max_e0
= self
.e0_range
555 """maximum value of `e[0]` (the relative error in `F_prime[-1]`)
557 min_e0
, max_e0
= self
.e0_range
561 def max_abs_e0(self
):
562 """maximum value of `abs(e[0])`."""
563 return max(abs(self
.min_e0
), abs(self
.max_e0
))
566 def min_abs_e0(self
):
567 """minimum value of `abs(e[0])`."""
572 """maximum value of `n[i]` (the relative error in `N_prime[i]`
573 relative to the previous iteration)
575 assert isinstance(i
, int) and 0 <= i
< self
.iter_count
578 # `n[0] = neps[0] / ((1 - e[0]) * (A / B))`
579 # `n[0] <= 2 * neps[0] / (1 - e[0])`
581 assert self
.max_e0
< 1 and self
.max_neps(0) >= 0, \
582 "only one quadrant of interval division implemented"
583 retval
= 2 * self
.max_neps(0) / (1 - self
.max_e0
)
586 # `n[1] <= neps[1] / ((1 - f[0]) * (1 - pi[0] - delta[0]))`
587 min_mpd
= 1 - self
.max_pi(0) - self
.max_delta(0)
588 assert self
.max_f(0) <= 1 and min_mpd
>= 0, \
589 "only one quadrant of interval multiplication implemented"
590 prod
= (1 - self
.max_f(0)) * min_mpd
591 assert self
.max_neps(1) >= 0 and prod
> 0, \
592 "only one quadrant of interval division implemented"
593 retval
= self
.max_neps(1) / prod
596 # `0 <= n[i] <= 2 * max_neps[i] / (1 - pi[i - 1] - delta[i - 1])`
597 min_mpd
= 1 - self
.max_pi(i
- 1) - self
.max_delta(i
- 1)
598 assert self
.max_neps(i
) >= 0 and min_mpd
> 0, \
599 "only one quadrant of interval division implemented"
600 retval
= self
.max_neps(i
) / min_mpd
602 return self
._shrink
_max
(retval
)
606 """maximum value of `d[i]` (the relative error in `D_prime[i]`
607 relative to the previous iteration)
609 assert isinstance(i
, int) and 0 <= i
< self
.iter_count
612 # `d[0] = deps[0] / (1 - e[0])`
614 assert self
.max_e0
< 1 and self
.max_deps(0) >= 0, \
615 "only one quadrant of interval division implemented"
616 retval
= self
.max_deps(0) / (1 - self
.max_e0
)
619 # `d[1] <= deps[1] / ((1 - f[0]) * (1 - delta[0] ** 2))`
620 assert self
.max_f(0) <= 1 and self
.max_delta(0) <= 1, \
621 "only one quadrant of interval multiplication implemented"
622 divisor
= (1 - self
.max_f(0)) * (1 - self
.max_delta(0) ** 2)
623 assert self
.max_deps(1) >= 0 and divisor
> 0, \
624 "only one quadrant of interval division implemented"
625 retval
= self
.max_deps(1) / divisor
628 # `0 <= d[i] <= max_deps[i] / (1 - delta[i - 1])`
629 assert self
.max_deps(i
) >= 0 and self
.max_delta(i
- 1) < 1, \
630 "only one quadrant of interval division implemented"
631 retval
= self
.max_deps(i
) / (1 - self
.max_delta(i
- 1))
633 return self
._shrink
_max
(retval
)
637 """maximum value of `f[i]` (the relative error in `F_prime[i]`
638 relative to the previous iteration)
640 assert isinstance(i
, int) and 0 <= i
< self
.iter_count
643 # `f[0] = feps[0] / (1 - delta[0])`
645 assert self
.max_delta(0) < 1 and self
.max_feps(0) >= 0, \
646 "only one quadrant of interval division implemented"
647 retval
= self
.max_feps(0) / (1 - self
.max_delta(0))
651 retval
= self
.max_feps(1)
654 # `f[i] <= max_feps[i]`
655 retval
= self
.max_feps(i
)
657 return self
._shrink
_max
(retval
)
660 def max_delta(self
, i
):
661 """ maximum value of `delta[i]`.
662 `delta[i]` is defined in Definition 4 of paper.
664 assert isinstance(i
, int) and 0 <= i
< self
.iter_count
666 # `delta[0] = abs(e[0]) + 3 * d[0] / 2`
667 retval
= self
.max_abs_e0
+ Fraction(3, 2) * self
.max_d(0)
669 # `delta[i] = delta[i - 1] ** 2 + f[i - 1]`
670 prev_max_delta
= self
.max_delta(i
- 1)
671 assert prev_max_delta
>= 0
672 retval
= prev_max_delta
** 2 + self
.max_f(i
- 1)
674 # `delta[i]` has to be smaller than one otherwise errors would go off
676 _assert_accuracy(retval
< 1)
678 return self
._shrink
_max
(retval
)
682 """ maximum value of `pi[i]`.
683 `pi[i]` is defined right below Theorem 5 of paper.
685 assert isinstance(i
, int) and 0 <= i
< self
.iter_count
686 # `pi[i] = 1 - (1 - n[i]) * prod`
687 # where `prod` is the product of,
688 # for `j` in `0 <= j < i`, `(1 - n[j]) / (1 + d[j])`
689 min_prod
= Fraction(1)
691 max_n_j
= self
.max_n(j
)
692 max_d_j
= self
.max_d(j
)
693 assert max_n_j
<= 1 and max_d_j
> -1, \
694 "only one quadrant of interval division implemented"
695 min_prod
*= (1 - max_n_j
) / (1 + max_d_j
)
696 max_n_i
= self
.max_n(i
)
697 assert max_n_i
<= 1 and min_prod
>= 0, \
698 "only one quadrant of interval multiplication implemented"
699 retval
= 1 - (1 - max_n_i
) * min_prod
700 return self
._shrink
_max
(retval
)
703 def max_n_shift(self
):
704 """ maximum value of `state.n_shift`.
706 # input numerator is `2*io_width`-bits
707 max_n
= (1 << (self
.io_width
* 2)) - 1
709 # normalize so 1 <= n < 2
717 """ maximum value of, for all `i`, `max_n(i)` and `max_d(i)`
720 for i
in range(self
.iter_count
):
721 n_hat
= max(n_hat
, self
.max_n(i
), self
.max_d(i
))
722 return self
._shrink
_max
(n_hat
)
724 def __make_ops(self
):
725 """ Goldschmidt division algorithm.
728 Even, G., Seidel, P. M., & Ferguson, W. E. (2003).
729 A Parametric Error Analysis of Goldschmidt's Division Algorithm.
730 https://citeseerx.ist.psu.edu/viewdoc/download?doi=10.1.1.90.1238&rep=rep1&type=pdf
732 yields: GoldschmidtDivOp
733 the operations needed to perform the division.
735 # establish assumptions of the paper's error analysis (section 3.1):
737 # 1. normalize so A (numerator) and B (denominator) are in [1, 2)
738 yield GoldschmidtDivOp
.Normalize
740 # 2. ensure all relative errors from directed rounding are <= 1 / 4.
741 # the assumption is met by multipliers with > 4-bits precision
742 _assert_accuracy(self
.expanded_width
> 4)
744 # 3. require `abs(e[0]) + 3 * d[0] / 2 + f[0] < 1 / 2`.
745 _assert_accuracy(self
.max_abs_e0
+ 3 * self
.max_d(0) / 2
746 + self
.max_f(0) < Fraction(1, 2))
748 # 4. the initial approximation F'[-1] of 1/B is in [1/2, 1].
749 # (B is the denominator)
751 for addr
in range(self
.table_addr_count
):
752 f_prime_m1
= self
.table
[addr
]
753 _assert_accuracy(0.5 <= f_prime_m1
<= 1)
755 yield GoldschmidtDivOp
.FEqTableLookup
757 # we use Setting I (section 4.1 of the paper):
758 # Require `n[i] <= n_hat` and `d[i] <= n_hat` and `f[i] = 0`:
759 # the conditions on n_hat are satisfied by construction.
760 for i
in range(self
.iter_count
):
761 _assert_accuracy(self
.max_f(i
) == 0)
762 yield GoldschmidtDivOp
.MulNByF
763 if i
!= self
.iter_count
- 1:
764 yield GoldschmidtDivOp
.MulDByF
765 yield GoldschmidtDivOp
.FEq2MinusD
767 # relative approximation error `p(N_prime[i])`:
768 # `p(N_prime[i]) = (A / B - N_prime[i]) / (A / B)`
769 # `0 <= p(N_prime[i])`
770 # `p(N_prime[i]) <= (2 * i) * n_hat \`
771 # ` + (abs(e[0]) + 3 * n_hat / 2) ** (2 ** i)`
772 i
= self
.iter_count
- 1 # last used `i`
773 # compute power manually to prevent huge intermediate values
774 power
= self
._shrink
_max
(self
.max_abs_e0
+ 3 * self
.n_hat
/ 2)
776 power
= self
._shrink
_max
(power
* power
)
778 max_rel_error
= (2 * i
) * self
.n_hat
+ power
780 min_a_over_b
= Fraction(1, 2)
781 max_a_over_b
= Fraction(2)
782 max_allowed_abs_error
= max_a_over_b
/ (1 << self
.max_n_shift
)
783 max_allowed_rel_error
= max_allowed_abs_error
/ min_a_over_b
785 _assert_accuracy(max_rel_error
< max_allowed_rel_error
,
786 f
"not accurate enough: max_rel_error={max_rel_error}"
787 f
" max_allowed_rel_error={max_allowed_rel_error}")
789 yield GoldschmidtDivOp
.CalcResult
792 def default_cost_fn(self
):
793 """ calculate the estimated cost on an arbitrary scale of implementing
794 goldschmidt division with the specified parameters. larger cost
795 values mean worse parameters.
797 This is the default cost function for `GoldschmidtDivParams.get`.
801 rom_cells
= self
.table_data_bits
<< self
.table_addr_bits
802 cost
= float(rom_cells
)
804 if op
== GoldschmidtDivOp
.MulNByF \
805 or op
== GoldschmidtDivOp
.MulDByF
:
806 mul_cost
= self
.expanded_width
** 2
807 mul_cost
*= self
.expanded_width
.bit_length()
809 cost
+= 5e7
* self
.iter_count
813 @lru_cache(maxsize
=1 << 16)
814 def __cached_new(base_params
):
815 assert isinstance(base_params
, GoldschmidtDivParamsBase
)
816 # can't use dataclasses.asdict, since it's recursive and will also give
817 # child class fields too, which we don't want.
819 for field
in fields(GoldschmidtDivParamsBase
):
820 kwargs
[field
.name
] = getattr(base_params
, field
.name
)
822 return GoldschmidtDivParams(**kwargs
), None
823 except ParamsNotAccurateEnough
as e
:
827 def __raise(e
): # type: (ParamsNotAccurateEnough) -> Any
831 def cached_new(base_params
, handle_error
=__raise
):
832 assert isinstance(base_params
, GoldschmidtDivParamsBase
)
833 params
, error
= GoldschmidtDivParams
.__cached
_new
(base_params
)
837 return handle_error(error
)
840 def get(io_width
, cost_fn
=default_cost_fn
, max_table_addr_bits
=12):
841 """ find efficient parameters for a goldschmidt division algorithm
842 with `params.io_width == io_width`.
846 bit-width of the input divisor and the result.
847 the input numerator is `2 * io_width`-bits wide.
848 cost_fn: Callable[[GoldschmidtDivParams], float]
849 return the estimated cost on an arbitrary scale of implementing
850 goldschmidt division with the specified parameters. larger cost
851 values mean worse parameters.
852 max_table_addr_bits: int
853 maximum allowable value of `table_addr_bits`
855 assert isinstance(io_width
, int) and io_width
>= 1
856 assert callable(cost_fn
)
859 last_error_params
= None
861 def cached_new(base_params
):
863 nonlocal last_error
, last_error_params
865 last_error_params
= base_params
868 retval
= GoldschmidtDivParams
.cached_new(base_params
, handle_error
)
870 logging
.debug(f
"GoldschmidtDivParams.get: err: {base_params}")
872 logging
.debug(f
"GoldschmidtDivParams.get: ok: {base_params}")
875 @lru_cache(maxsize
=None)
876 def get_cost(base_params
):
877 params
= cached_new(base_params
)
880 retval
= cost_fn(params
)
881 logging
.debug(f
"GoldschmidtDivParams.get: cost={retval}: {params}")
884 # start with parameters big enough to always work.
885 initial_extra_precision
= io_width
* 2 + 4
886 initial_params
= GoldschmidtDivParamsBase(
888 extra_precision
=initial_extra_precision
,
889 table_addr_bits
=min(max_table_addr_bits
, io_width
),
890 table_data_bits
=io_width
+ initial_extra_precision
,
891 iter_count
=1 + io_width
.bit_length())
893 if cached_new(initial_params
) is None:
894 raise ValueError(f
"initial goldschmidt division algorithm "
895 f
"parameters are invalid: {initial_params}"
898 # find good initial `iter_count`
899 params
= initial_params
900 for iter_count
in range(1, initial_params
.iter_count
):
901 trial_params
= replace(params
, iter_count
=iter_count
)
902 if cached_new(trial_params
) is not None:
903 params
= trial_params
906 # now find `table_addr_bits`
907 cost
= get_cost(params
)
908 for table_addr_bits
in range(1, max_table_addr_bits
):
909 trial_params
= replace(params
, table_addr_bits
=table_addr_bits
)
910 trial_cost
= get_cost(trial_params
)
911 if trial_cost
< cost
:
912 params
= trial_params
916 # check one higher `iter_count` to see if it has lower cost
917 for table_addr_bits
in range(1, max_table_addr_bits
+ 1):
918 trial_params
= replace(params
,
919 table_addr_bits
=table_addr_bits
,
920 iter_count
=params
.iter_count
+ 1)
921 trial_cost
= get_cost(trial_params
)
922 if trial_cost
< cost
:
923 params
= trial_params
927 # now shrink `table_data_bits`
929 trial_params
= replace(params
,
930 table_data_bits
=params
.table_data_bits
- 1)
931 trial_cost
= get_cost(trial_params
)
932 if trial_cost
< cost
:
933 params
= trial_params
938 # and shrink `extra_precision`
940 trial_params
= replace(params
,
941 extra_precision
=params
.extra_precision
- 1)
942 trial_cost
= get_cost(trial_params
)
943 if trial_cost
< cost
:
944 params
= trial_params
949 return cached_new(params
)
953 class GoldschmidtDivOp(enum
.Enum
):
954 Normalize
= "n, d, n_shift = normalize(n, d)"
955 FEqTableLookup
= "f = table_lookup(d)"
958 FEq2MinusD
= "f = 2 - d"
959 CalcResult
= "result = unnormalize_and_round(n)"
961 def run(self
, params
, state
):
962 assert isinstance(params
, GoldschmidtDivParams
)
963 assert isinstance(state
, GoldschmidtDivState
)
964 expanded_width
= params
.expanded_width
965 table_addr_bits
= params
.table_addr_bits
966 if self
== GoldschmidtDivOp
.Normalize
:
967 # normalize so 1 <= d < 2
968 # can easily be done with count-leading-zeros and left shift
970 state
.n
= (state
.n
* 2).to_frac_wid(expanded_width
)
971 state
.d
= (state
.d
* 2).to_frac_wid(expanded_width
)
974 # normalize so 1 <= n < 2
976 state
.n
= (state
.n
* 0.5).to_frac_wid(expanded_width
)
978 elif self
== GoldschmidtDivOp
.FEqTableLookup
:
979 # compute initial f by table lookup
981 d_m_1
= d_m_1
.to_frac_wid(table_addr_bits
, RoundDir
.DOWN
)
982 assert 0 <= d_m_1
.bits
< (1 << params
.table_addr_bits
)
983 state
.f
= params
.table
[d_m_1
.bits
]
984 elif self
== GoldschmidtDivOp
.MulNByF
:
985 assert state
.f
is not None
986 n
= state
.n
* state
.f
987 state
.n
= n
.to_frac_wid(expanded_width
, round_dir
=RoundDir
.DOWN
)
988 elif self
== GoldschmidtDivOp
.MulDByF
:
989 assert state
.f
is not None
990 d
= state
.d
* state
.f
991 state
.d
= d
.to_frac_wid(expanded_width
, round_dir
=RoundDir
.UP
)
992 elif self
== GoldschmidtDivOp
.FEq2MinusD
:
993 state
.f
= (2 - state
.d
).to_frac_wid(expanded_width
)
994 elif self
== GoldschmidtDivOp
.CalcResult
:
995 assert state
.n_shift
is not None
996 # scale to correct value
997 n
= state
.n
* (1 << state
.n_shift
)
999 state
.quotient
= math
.floor(n
)
1000 state
.remainder
= state
.orig_n
- state
.quotient
* state
.orig_d
1001 if state
.remainder
>= state
.orig_d
:
1003 state
.remainder
-= state
.orig_d
1005 assert False, f
"unimplemented GoldschmidtDivOp: {self}"
1008 def goldschmidt_div(n
, d
, params
):
1009 """ Goldschmidt division algorithm.
1012 Even, G., Seidel, P. M., & Ferguson, W. E. (2003).
1013 A Parametric Error Analysis of Goldschmidt's Division Algorithm.
1014 https://citeseerx.ist.psu.edu/viewdoc/download?doi=10.1.1.90.1238&rep=rep1&type=pdf
1018 numerator. a `2*width`-bit unsigned integer.
1019 must be less than `d << width`, otherwise the quotient wouldn't
1020 fit in `width` bits.
1022 denominator. a `width`-bit unsigned integer. must not be zero.
1024 the bit-width of the inputs/outputs. must be a positive integer.
1026 returns: tuple[int, int]
1027 the quotient and remainder. a tuple of two `width`-bit unsigned
1030 assert isinstance(params
, GoldschmidtDivParams
)
1031 assert isinstance(d
, int) and 0 < d
< (1 << params
.io_width
)
1032 assert isinstance(n
, int) and 0 <= n
< (d
<< params
.io_width
)
1034 # this whole algorithm is done with fixed-point arithmetic where values
1035 # have `width` fractional bits
1037 state
= GoldschmidtDivState(
1040 n
=FixedPoint(n
, params
.io_width
),
1041 d
=FixedPoint(d
, params
.io_width
),
1044 for op
in params
.ops
:
1045 op
.run(params
, state
)
1047 assert state
.quotient
is not None
1048 assert state
.remainder
is not None
1050 return state
.quotient
, state
.remainder
1053 GOLDSCHMIDT_SQRT_RSQRT_TABLE_ADDR_INT_WID
= 2
1057 def goldschmidt_sqrt_rsqrt_table(table_addr_bits
, table_data_bits
):
1058 """Generate the look-up table needed for Goldschmidt's square-root and
1059 reciprocal-square-root algorithm.
1062 table_addr_bits: int
1063 the number of address bits for the look-up table.
1064 table_data_bits: int
1065 the number of data bits for the look-up table.
1067 assert isinstance(table_addr_bits
, int) and \
1068 table_addr_bits
>= GOLDSCHMIDT_SQRT_RSQRT_TABLE_ADDR_INT_WID
1069 assert isinstance(table_data_bits
, int) and table_data_bits
>= 1
1071 table_len
= 1 << table_addr_bits
1072 for addr
in range(table_len
):
1074 value
= FixedPoint(0, table_data_bits
)
1075 elif (addr
<< 2) < table_len
:
1076 value
= None # table entries should be unused
1078 table_addr_frac_wid
= table_addr_bits
1079 table_addr_frac_wid
-= GOLDSCHMIDT_SQRT_RSQRT_TABLE_ADDR_INT_WID
1080 max_input_value
= FixedPoint(addr
+ 1, table_addr_bits
- 2)
1081 max_frac_wid
= max(max_input_value
.frac_wid
, table_data_bits
)
1082 value
= max_input_value
.to_frac_wid(max_frac_wid
)
1083 value
= value
.rsqrt(RoundDir
.DOWN
)
1084 value
= value
.to_frac_wid(table_data_bits
, RoundDir
.DOWN
)
1087 # tuple for immutability
1091 def goldschmidt_sqrt_rsqrt(radicand
, io_width
, frac_wid
, extra_precision
,
1092 table_addr_bits
, table_data_bits
, iter_count
):
1093 """Goldschmidt's square-root and reciprocal-square-root algorithm.
1095 uses algorithm based on second method at:
1096 https://en.wikipedia.org/wiki/Methods_of_computing_square_roots#Goldschmidt%E2%80%99s_algorithm
1099 radicand: FixedPoint(frac_wid=frac_wid)
1100 the input value to take the square-root and reciprocal-square-root of.
1102 the number of bits in the input (`radicand`) and output values.
1104 the number of fraction bits in the input (`radicand`) and output
1106 extra_precision: int
1107 the number of bits of internal extra precision.
1108 table_addr_bits: int
1109 the number of address bits for the look-up table.
1110 table_data_bits: int
1111 the number of data bits for the look-up table.
1113 returns: tuple[FixedPoint, FixedPoint]
1114 the square-root and reciprocal-square-root, rounded down to the
1115 nearest representable value. If `radicand == 0`, then the
1116 reciprocal-square-root value returned is zero.
1118 assert (isinstance(radicand
, FixedPoint
)
1119 and radicand
.frac_wid
== frac_wid
1120 and 0 <= radicand
.bits
< (1 << io_width
))
1121 assert isinstance(io_width
, int) and io_width
>= 1
1122 assert isinstance(frac_wid
, int) and 0 <= frac_wid
< io_width
1123 assert isinstance(extra_precision
, int) and extra_precision
>= io_width
1124 assert isinstance(table_addr_bits
, int) and table_addr_bits
>= 1
1125 assert isinstance(table_data_bits
, int) and table_data_bits
>= 1
1126 assert isinstance(iter_count
, int) and iter_count
>= 0
1127 expanded_frac_wid
= frac_wid
+ extra_precision
1128 s
= radicand
.to_frac_wid(expanded_frac_wid
)
1129 sqrt_rshift
= extra_precision
1130 rsqrt_rshift
= extra_precision
1131 while s
!= 0 and s
< 1:
1132 s
= (s
* 4).to_frac_wid(expanded_frac_wid
)
1136 s
= s
.div(4, expanded_frac_wid
)
1139 table
= goldschmidt_sqrt_rsqrt_table(table_addr_bits
=table_addr_bits
,
1140 table_data_bits
=table_data_bits
)
1141 # core goldschmidt sqrt/rsqrt algorithm:
1143 table_addr_frac_wid
= table_addr_bits
1144 table_addr_frac_wid
-= GOLDSCHMIDT_SQRT_RSQRT_TABLE_ADDR_INT_WID
1145 addr
= s
.to_frac_wid(table_addr_frac_wid
, RoundDir
.DOWN
)
1146 assert 0 <= addr
.bits
< (1 << table_addr_bits
), "table addr out of range"
1147 f
= table
[addr
.bits
]
1148 assert f
is not None, "accessed invalid table entry"
1149 # use with_frac_wid to fix IDE type deduction
1150 f
= FixedPoint
.with_frac_wid(f
, expanded_frac_wid
, RoundDir
.DOWN
)
1151 x
= (s
* f
).to_frac_wid(expanded_frac_wid
, RoundDir
.DOWN
)
1152 h
= (f
* 0.5).to_frac_wid(expanded_frac_wid
, RoundDir
.DOWN
)
1153 for _
in range(iter_count
):
1155 f
= (1.5 - x
* h
).to_frac_wid(expanded_frac_wid
, RoundDir
.DOWN
)
1156 x
= (x
* f
).to_frac_wid(expanded_frac_wid
, RoundDir
.DOWN
)
1157 h
= (h
* f
).to_frac_wid(expanded_frac_wid
, RoundDir
.DOWN
)
1159 # now `x` is approximately `sqrt(s)` and `r` is approximately `rsqrt(s)`
1161 sqrt
= FixedPoint(x
.bits
>> sqrt_rshift
, frac_wid
)
1162 rsqrt
= FixedPoint(r
.bits
>> rsqrt_rshift
, frac_wid
)
1164 next_sqrt
= FixedPoint(sqrt
.bits
+ 1, frac_wid
)
1165 if next_sqrt
* next_sqrt
<= radicand
:
1168 next_rsqrt
= FixedPoint(rsqrt
.bits
+ 1, frac_wid
)
1169 if next_rsqrt
* next_rsqrt
* radicand
<= 1 and radicand
!= 0: