1 # SPDX-License-Identifier: LGPL-3-or-later
2 # Copyright 2022 Jacob Lifshay programmerjake@gmail.com
4 # Funded by NLnet Assure Programme 2021-02-052, https://nlnet.nl/assure part
5 # of Horizon 2020 EU Programme 957073.
7 from dataclasses
import dataclass
, field
, fields
, replace
11 from fractions
import Fraction
12 from types
import FunctionType
13 from functools
import lru_cache
16 from functools
import cached_property
18 from cached_property
import cached_property
20 # fix broken IDE type detection for cached_property
21 from typing
import TYPE_CHECKING
, Any
23 from functools
import cached_property
29 def cache_on_self(func
):
30 """like `functools.cached_property`, except for methods. unlike
31 `lru_cache` the cache is per-class instance rather than a global cache
34 assert isinstance(func
, FunctionType
), \
35 "non-plain methods are not supported"
37 cache_name
= func
.__name
__ + "__cache"
39 def wrapper(self
, *args
, **kwargs
):
40 # specifically access through `__dict__` to bypass frozen=True
41 cache
= self
.__dict
__.get(cache_name
, _NOT_FOUND
)
42 if cache
is _NOT_FOUND
:
43 self
.__dict
__[cache_name
] = cache
= {}
44 key
= (args
, *kwargs
.items())
45 retval
= cache
.get(key
, _NOT_FOUND
)
46 if retval
is _NOT_FOUND
:
47 retval
= func(self
, *args
, **kwargs
)
51 wrapper
.__doc
__ = func
.__doc
__
56 class RoundDir(enum
.Enum
):
59 NEAREST_TIES_UP
= enum
.auto()
60 ERROR_IF_INEXACT
= enum
.auto()
63 @dataclass(frozen
=True)
def __post_init__(self):
    """check the field values: `bits` has to be an `int`, and `frac_wid`
    has to be an `int` that isn't negative."""
    raw_bits = self.bits
    assert isinstance(raw_bits, int)
    # only read `frac_wid` once the `bits` check has passed
    wid = self.frac_wid
    wid_is_valid = isinstance(wid, int) and wid >= 0
    assert wid_is_valid
74 """convert `value` to a fixed-point number with enough fractional
75 bits to preserve its value."""
76 if isinstance(value
, FixedPoint
):
78 if isinstance(value
, int):
79 return FixedPoint(value
, 0)
80 if isinstance(value
, str):
82 neg
= value
.startswith("-")
83 if neg
or value
.startswith("+"):
85 if value
.startswith(("0x", "0X")) and "." in value
:
95 raise ValueError("too many `.` in string")
100 if not digit
.isalnum():
101 raise ValueError("invalid hexadecimal digit")
103 bits |
= int("0x" + digit
, base
=16)
105 bits
= int(value
, base
=0)
109 return FixedPoint(bits
, frac_wid
)
111 if isinstance(value
, float):
112 n
, d
= value
.as_integer_ratio()
113 log2_d
= d
.bit_length() - 1
114 assert d
== 1 << log2_d
, ("d isn't a power of 2 -- won't ever "
115 "fail with float being IEEE 754")
116 return FixedPoint(n
, log2_d
)
117 raise TypeError("can't convert type to FixedPoint")
120 def with_frac_wid(value
, frac_wid
, round_dir
=RoundDir
.ERROR_IF_INEXACT
):
121 """convert `value` to the nearest fixed-point number with `frac_wid`
122 fractional bits, rounding according to `round_dir`."""
123 assert isinstance(frac_wid
, int) and frac_wid
>= 0
124 assert isinstance(round_dir
, RoundDir
)
125 if isinstance(value
, Fraction
):
126 numerator
= value
.numerator
127 denominator
= value
.denominator
129 value
= FixedPoint
.cast(value
)
130 numerator
= value
.bits
131 denominator
= 1 << value
.frac_wid
133 numerator
= -numerator
134 denominator
= -denominator
135 bits
, remainder
= divmod(numerator
<< frac_wid
, denominator
)
136 if round_dir
== RoundDir
.DOWN
:
138 elif round_dir
== RoundDir
.UP
:
141 elif round_dir
== RoundDir
.NEAREST_TIES_UP
:
142 if remainder
* 2 >= denominator
:
144 elif round_dir
== RoundDir
.ERROR_IF_INEXACT
:
146 raise ValueError("inexact conversion")
148 assert False, "unimplemented round_dir"
149 return FixedPoint(bits
, frac_wid
)
def to_frac_wid(self, frac_wid, round_dir=RoundDir.ERROR_IF_INEXACT):
    """re-express `self` as a fixed-point number that has `frac_wid`
    fractional bits.

    This just forwards `self`, `frac_wid` and `round_dir` to
    `FixedPoint.with_frac_wid`, which picks the nearest representable
    value according to `round_dir`.
    """
    converted = FixedPoint.with_frac_wid(self, frac_wid, round_dir)
    return converted
157 # use truediv to get correct result even when bits
158 # and frac_wid are huge
159 return float(self
.bits
/ (1 << self
.frac_wid
))
def as_fraction(self):
    """return the value of `self` as a `Fraction`, that is
    `bits / (1 << frac_wid)`."""
    numerator = self.bits
    scale = 1 << self.frac_wid
    return Fraction(numerator, scale)
165 """compare self with rhs, returning a positive integer if self is
166 greater than rhs, zero if self is equal to rhs, and a negative integer
167 if self is less than rhs."""
168 rhs
= FixedPoint
.cast(rhs
)
169 common_frac_wid
= max(self
.frac_wid
, rhs
.frac_wid
)
170 lhs
= self
.to_frac_wid(common_frac_wid
)
171 rhs
= rhs
.to_frac_wid(common_frac_wid
)
172 return lhs
.bits
- rhs
.bits
def __eq__(self, rhs):
    """`self == rhs`: true exactly when `self.cmp(rhs)` is zero."""
    ordering = self.cmp(rhs)
    return ordering == 0
def __ne__(self, rhs):
    """`self != rhs`: true exactly when `self.cmp(rhs)` is non-zero."""
    ordering = self.cmp(rhs)
    return ordering != 0
def __gt__(self, rhs):
    """`self > rhs`: true exactly when `self.cmp(rhs)` is positive."""
    ordering = self.cmp(rhs)
    return ordering > 0
def __lt__(self, rhs):
    """`self < rhs`: true exactly when `self.cmp(rhs)` is negative."""
    ordering = self.cmp(rhs)
    return ordering < 0
def __ge__(self, rhs):
    """`self >= rhs`: true exactly when `self.cmp(rhs)` isn't negative."""
    ordering = self.cmp(rhs)
    return ordering >= 0
def __le__(self, rhs):
    """`self <= rhs`: true exactly when `self.cmp(rhs)` isn't positive."""
    ordering = self.cmp(rhs)
    return ordering <= 0
193 """return the fractional part of `self`.
194 that is `self - math.floor(self)`.
196 fract_mask
= (1 << self
.frac_wid
) - 1
197 return FixedPoint(self
.bits
& fract_mask
, self
.frac_wid
)
201 return "-" + str(-self
)
203 frac_digit_count
= (self
.frac_wid
+ digit_bits
- 1) // digit_bits
204 fract
= self
.fract().to_frac_wid(frac_digit_count
* digit_bits
)
205 frac_str
= hex(fract
.bits
)[2:].zfill(frac_digit_count
)
206 return hex(math
.floor(self
)) + "." + frac_str
209 return f
"FixedPoint.with_frac_wid({str(self)!r}, {self.frac_wid})"
def __add__(self, rhs):
    """return `self + rhs` as a `FixedPoint`.

    `rhs` is converted with `FixedPoint.cast`, then both operands are
    brought to the larger of the two `frac_wid`s so their `bits` can be
    added directly.
    """
    other = FixedPoint.cast(rhs)
    wid = max(self.frac_wid, other.frac_wid)
    # both conversions happen before the addition, left operand first
    left_bits = self.to_frac_wid(wid).bits
    right_bits = other.to_frac_wid(wid).bits
    return FixedPoint(left_bits + right_bits, wid)
218 def __radd__(self
, lhs
):
220 return self
.__add
__(lhs
)
223 return FixedPoint(-self
.bits
, self
.frac_wid
)
def __sub__(self, rhs):
    """return `self - rhs` as a `FixedPoint`.

    `rhs` is converted with `FixedPoint.cast`, then both operands are
    brought to the larger of the two `frac_wid`s so their `bits` can be
    subtracted directly.
    """
    other = FixedPoint.cast(rhs)
    wid = max(self.frac_wid, other.frac_wid)
    # both conversions happen before the subtraction, left operand first
    left_bits = self.to_frac_wid(wid).bits
    right_bits = other.to_frac_wid(wid).bits
    return FixedPoint(left_bits - right_bits, wid)
232 def __rsub__(self
, lhs
):
234 return -self
.__sub
__(lhs
)
def __mul__(self, rhs):
    """return `self * rhs` as a `FixedPoint`.

    `rhs` is converted with `FixedPoint.cast`; the result's `bits` is the
    product of both `bits` and its `frac_wid` is the sum of both
    `frac_wid`s.
    """
    other = FixedPoint.cast(rhs)
    product_bits = self.bits * other.bits
    product_frac_wid = self.frac_wid + other.frac_wid
    return FixedPoint(product_bits, product_frac_wid)
240 def __rmul__(self
, lhs
):
242 return self
.__mul
__(lhs
)
245 return self
.bits
>> self
.frac_wid
247 def div(self
, rhs
, frac_wid
, round_dir
=RoundDir
.ERROR_IF_INEXACT
):
248 assert isinstance(frac_wid
, int) and frac_wid
>= 0
249 assert isinstance(round_dir
, RoundDir
)
250 rhs
= FixedPoint
.cast(rhs
)
251 return FixedPoint
.with_frac_wid(self
.as_fraction()
255 def sqrt(self
, round_dir
=RoundDir
.ERROR_IF_INEXACT
):
256 assert isinstance(round_dir
, RoundDir
)
258 raise ValueError("can't compute sqrt of negative number")
261 retval
= FixedPoint(0, self
.frac_wid
)
262 int_part_wid
= self
.bits
.bit_length() - self
.frac_wid
263 first_bit_index
= -(-int_part_wid
// 2) # division rounds up
264 last_bit_index
= -self
.frac_wid
265 for bit_index
in range(first_bit_index
, last_bit_index
- 1, -1):
266 trial
= retval
+ FixedPoint(1 << (bit_index
+ self
.frac_wid
),
268 if trial
* trial
<= self
:
270 if round_dir
== RoundDir
.DOWN
:
272 elif round_dir
== RoundDir
.UP
:
273 if retval
* retval
< self
:
274 retval
+= FixedPoint(1, self
.frac_wid
)
275 elif round_dir
== RoundDir
.NEAREST_TIES_UP
:
276 half_way
= retval
+ FixedPoint(1, self
.frac_wid
+ 1)
277 if half_way
* half_way
<= self
:
278 retval
+= FixedPoint(1, self
.frac_wid
)
279 elif round_dir
== RoundDir
.ERROR_IF_INEXACT
:
280 if retval
* retval
!= self
:
281 raise ValueError("inexact sqrt")
283 assert False, "unimplemented round_dir"
286 def rsqrt(self
, round_dir
=RoundDir
.ERROR_IF_INEXACT
):
287 """compute the reciprocal-sqrt of `self`"""
288 assert isinstance(round_dir
, RoundDir
)
290 raise ValueError("can't compute rsqrt of negative number")
292 raise ZeroDivisionError("can't compute rsqrt of zero")
293 retval
= FixedPoint(0, self
.frac_wid
)
294 first_bit_index
= -(-self
.frac_wid
// 2) # division rounds up
295 last_bit_index
= -self
.frac_wid
296 for bit_index
in range(first_bit_index
, last_bit_index
- 1, -1):
297 trial
= retval
+ FixedPoint(1 << (bit_index
+ self
.frac_wid
),
299 if trial
* trial
* self
<= 1:
301 if round_dir
== RoundDir
.DOWN
:
303 elif round_dir
== RoundDir
.UP
:
304 if retval
* retval
* self
< 1:
305 retval
+= FixedPoint(1, self
.frac_wid
)
306 elif round_dir
== RoundDir
.NEAREST_TIES_UP
:
307 half_way
= retval
+ FixedPoint(1, self
.frac_wid
+ 1)
308 if half_way
* half_way
* self
<= 1:
309 retval
+= FixedPoint(1, self
.frac_wid
)
310 elif round_dir
== RoundDir
.ERROR_IF_INEXACT
:
311 if retval
* retval
* self
!= 1:
312 raise ValueError("inexact rsqrt")
314 assert False, "unimplemented round_dir"
318 class ParamsNotAccurateEnough(Exception):
319 """raised when the parameters aren't accurate enough to have goldschmidt
323 def _assert_accuracy(condition
, msg
="not accurate enough"):
326 raise ParamsNotAccurateEnough(msg
)
329 @dataclass(frozen
=True, unsafe_hash
=True)
330 class GoldschmidtDivParamsBase
:
331 """parameters for a Goldschmidt division algorithm, excluding derived
336 """bit-width of the input divisor and the result.
337 the input numerator is `2 * io_width`-bits wide.
341 """number of bits of additional precision used inside the algorithm."""
344 """the number of address bits used in the lookup-table."""
347 """the number of data bits used in the lookup-table."""
350 """the total number of iterations of the division algorithm's loop"""
353 @dataclass(frozen
=True, unsafe_hash
=True)
354 class GoldschmidtDivParams(GoldschmidtDivParamsBase
):
355 """parameters for a Goldschmidt division algorithm.
356 Use `GoldschmidtDivParams.get` to find a efficient set of parameters.
359 # tuple to be immutable, repr=False so repr() works for debugging even when
360 # __post_init__ hasn't finished running yet
361 table
: "tuple[FixedPoint, ...]" = field(init
=False, repr=False)
362 """the lookup-table"""
364 ops
: "tuple[GoldschmidtDivOp, ...]" = field(init
=False, repr=False)
365 """the operations needed to perform the goldschmidt division algorithm."""
def _shrink_bound(self, bound, round_dir):
    """keep a bound's numerator/denominator from growing huge by rounding
    it to a `FixedPoint` and then turning that back into a `Fraction`.

    `bound` has to be a `Fraction` or an `int`, and `round_dir` has to be
    `RoundDir.DOWN` or `RoundDir.UP`.

    Only meant for values used while computing bounds, not for values
    that end up in the hardware.
    """
    assert isinstance(bound, (Fraction, int))
    is_directed = not (round_dir is not RoundDir.DOWN
                       and round_dir is not RoundDir.UP)
    assert is_directed, "you shouldn't use that round_dir on bounds"
    # should be enough precision
    precision = self.io_width * 4 + 100
    rounded = FixedPoint.with_frac_wid(bound, precision, round_dir)
    return rounded.as_fraction()
def _shrink_min(self, min_bound):
    """shrink a minimum bound: calls `self._shrink_bound` with
    `RoundDir.DOWN`, so the returned `Fraction` is rounded down.

    Only meant for values used while computing bounds, not for values
    that end up in the hardware.
    """
    rounded_down = self._shrink_bound(min_bound, RoundDir.DOWN)
    return rounded_down
def _shrink_max(self, max_bound):
    """shrink a maximum bound: calls `self._shrink_bound` with
    `RoundDir.UP`, so the returned `Fraction` is rounded up.

    Only meant for values used while computing bounds, not for values
    that end up in the hardware.
    """
    rounded_up = self._shrink_bound(max_bound, RoundDir.UP)
    return rounded_up
def table_addr_count(self):
    """how many distinct addresses the lookup-table has, that is
    `1 << table_addr_bits`."""
    # this is needed while `self.table` is still being computed, so
    # `len(self.table)` can't be used here
    addr_bits = self.table_addr_bits
    return 1 << addr_bits
def table_input_exact_range(self, addr):
    """return `(min_input, max_input)`, the range of inputs (as
    `Fraction`s) that map to the table entry with address `addr`.

    Raises `ParamsNotAccurateEnough` (through `_assert_accuracy`) when
    `io_width` is smaller than `table_addr_bits`.
    """
    assert isinstance(addr, int)
    assert 0 <= addr < self.table_addr_count
    width = self.io_width
    addr_bits = self.table_addr_bits
    _assert_accuracy(width >= addr_bits)
    # input bits below the address bits don't affect which entry is used
    shift = width - addr_bits
    one = 1 << width  # common denominator of both bounds
    lowest_numerator = one + (addr << shift)
    highest_numerator = lowest_numerator + (1 << shift) - 1
    lowest = Fraction(lowest_numerator, one)
    highest = Fraction(highest_numerator, one)
    lowest = self._shrink_min(lowest)
    highest = self._shrink_max(highest)
    assert 1 <= lowest <= highest < 2
    return lowest, highest
def table_value_exact_range(self, addr):
    """return `(min_value, max_value)`, the range of values (as
    `Fraction`s) for the table entry with address `addr`.

    The values are the reciprocals of the bounds returned by
    `self.table_input_exact_range(addr)`.
    """
    smallest_input, largest_input = self.table_input_exact_range(addr)
    # taking the reciprocal swaps which bound is the min and which the max
    lower, upper = 1 / largest_input, 1 / smallest_input
    lower = self._shrink_min(lower)
    upper = self._shrink_max(upper)
    assert 0.5 < lower <= upper <= 1
    return lower, upper
437 def table_exact_value(self
, index
):
438 min_value
, max_value
= self
.table_value_exact_range(index
)
442 def __post_init__(self
):
443 # called by the autogenerated __init__
444 _assert_accuracy(self
.io_width
>= 1, "io_width out of range")
445 _assert_accuracy(self
.extra_precision
>= 0,
446 "extra_precision out of range")
447 _assert_accuracy(self
.table_addr_bits
>= 1,
448 "table_addr_bits out of range")
449 _assert_accuracy(self
.table_data_bits
>= 1,
450 "table_data_bits out of range")
451 _assert_accuracy(self
.iter_count
>= 1, "iter_count out of range")
453 for addr
in range(1 << self
.table_addr_bits
):
454 table
.append(FixedPoint
.with_frac_wid(self
.table_exact_value(addr
),
455 self
.table_data_bits
,
457 # we have to use object.__setattr__ since frozen=True
458 object.__setattr
__(self
, "table", tuple(table
))
459 object.__setattr
__(self
, "ops", tuple(self
.__make
_ops
()))
def expanded_width(self):
    """number of bits of precision used inside the algorithm in total:
    `io_width` plus `extra_precision`."""
    base_width = self.io_width
    extra = self.extra_precision
    return base_width + extra
def max_neps(self, i):
    """upper bound on `neps[i]`, where `neps[i]` is defined to be
    `n[i] * N_prime[i - 1] * F_prime[i - 1]`.

    `i` has to be an `int` in `range(self.iter_count)`. The returned
    bound is `Fraction(1, 1 << self.expanded_width)` for every `i`.
    """
    index_ok = isinstance(i, int) and 0 <= i < self.iter_count
    assert index_ok
    denominator = 1 << self.expanded_width
    return Fraction(1, denominator)
def max_deps(self, i):
    """upper bound on `deps[i]`, where `deps[i]` is defined to be
    `d[i] * D_prime[i - 1] * F_prime[i - 1]`.

    `i` has to be an `int` in `range(self.iter_count)`. The returned
    bound is `Fraction(1, 1 << self.expanded_width)` for every `i`.
    """
    index_ok = isinstance(i, int) and 0 <= i < self.iter_count
    assert index_ok
    denominator = 1 << self.expanded_width
    return Fraction(1, denominator)
483 def max_feps(self
, i
):
484 """maximum value of `feps[i]`.
485 `feps[i]` is defined to be `f[i] * (2 - D_prime[i - 1])`.
487 assert isinstance(i
, int) and 0 <= i
< self
.iter_count
488 # zero, because the computation of `F_prime[i]` in
489 # `GoldschmidtDivOp.MulDByF.run(...)` is exact.
494 """minimum and maximum values of `e[0]`
495 (the relative error in `F_prime[-1]`)
499 for addr
in range(self
.table_addr_count
):
500 # `F_prime[-1] = (1 - e[0]) / B`
501 # => `e[0] = 1 - B * F_prime[-1]`
502 min_b
, max_b
= self
.table_input_exact_range(addr
)
503 f_prime_m1
= self
.table
[addr
].as_fraction()
504 assert min_b
>= 0 and f_prime_m1
>= 0, \
505 "only positive quadrant of interval multiplication implemented"
506 min_product
= min_b
* f_prime_m1
507 max_product
= max_b
* f_prime_m1
508 # negation swaps min/max
509 cur_min_e0
= 1 - max_product
510 cur_max_e0
= 1 - min_product
511 min_e0
= min(min_e0
, cur_min_e0
)
512 max_e0
= max(max_e0
, cur_max_e0
)
513 min_e0
= self
._shrink
_min
(min_e0
)
514 max_e0
= self
._shrink
_max
(max_e0
)
515 return min_e0
, max_e0
519 """minimum value of `e[0]` (the relative error in `F_prime[-1]`)
521 min_e0
, max_e0
= self
.e0_range
526 """maximum value of `e[0]` (the relative error in `F_prime[-1]`)
528 min_e0
, max_e0
= self
.e0_range
def max_abs_e0(self):
    """upper bound on `abs(e[0])`: the larger of `abs(self.min_e0)` and
    `abs(self.max_e0)`."""
    magnitude_of_min = abs(self.min_e0)
    magnitude_of_max = abs(self.max_e0)
    return max(magnitude_of_min, magnitude_of_max)
537 def min_abs_e0(self
):
538 """minimum value of `abs(e[0])`."""
543 """maximum value of `n[i]` (the relative error in `N_prime[i]`
544 relative to the previous iteration)
546 assert isinstance(i
, int) and 0 <= i
< self
.iter_count
549 # `n[0] = neps[0] / ((1 - e[0]) * (A / B))`
550 # `n[0] <= 2 * neps[0] / (1 - e[0])`
552 assert self
.max_e0
< 1 and self
.max_neps(0) >= 0, \
553 "only one quadrant of interval division implemented"
554 retval
= 2 * self
.max_neps(0) / (1 - self
.max_e0
)
557 # `n[1] <= neps[1] / ((1 - f[0]) * (1 - pi[0] - delta[0]))`
558 min_mpd
= 1 - self
.max_pi(0) - self
.max_delta(0)
559 assert self
.max_f(0) <= 1 and min_mpd
>= 0, \
560 "only one quadrant of interval multiplication implemented"
561 prod
= (1 - self
.max_f(0)) * min_mpd
562 assert self
.max_neps(1) >= 0 and prod
> 0, \
563 "only one quadrant of interval division implemented"
564 retval
= self
.max_neps(1) / prod
567 # `0 <= n[i] <= 2 * max_neps[i] / (1 - pi[i - 1] - delta[i - 1])`
568 min_mpd
= 1 - self
.max_pi(i
- 1) - self
.max_delta(i
- 1)
569 assert self
.max_neps(i
) >= 0 and min_mpd
> 0, \
570 "only one quadrant of interval division implemented"
571 retval
= self
.max_neps(i
) / min_mpd
573 return self
._shrink
_max
(retval
)
577 """maximum value of `d[i]` (the relative error in `D_prime[i]`
578 relative to the previous iteration)
580 assert isinstance(i
, int) and 0 <= i
< self
.iter_count
583 # `d[0] = deps[0] / (1 - e[0])`
585 assert self
.max_e0
< 1 and self
.max_deps(0) >= 0, \
586 "only one quadrant of interval division implemented"
587 retval
= self
.max_deps(0) / (1 - self
.max_e0
)
590 # `d[1] <= deps[1] / ((1 - f[0]) * (1 - delta[0] ** 2))`
591 assert self
.max_f(0) <= 1 and self
.max_delta(0) <= 1, \
592 "only one quadrant of interval multiplication implemented"
593 divisor
= (1 - self
.max_f(0)) * (1 - self
.max_delta(0) ** 2)
594 assert self
.max_deps(1) >= 0 and divisor
> 0, \
595 "only one quadrant of interval division implemented"
596 retval
= self
.max_deps(1) / divisor
599 # `0 <= d[i] <= max_deps[i] / (1 - delta[i - 1])`
600 assert self
.max_deps(i
) >= 0 and self
.max_delta(i
- 1) < 1, \
601 "only one quadrant of interval division implemented"
602 retval
= self
.max_deps(i
) / (1 - self
.max_delta(i
- 1))
604 return self
._shrink
_max
(retval
)
608 """maximum value of `f[i]` (the relative error in `F_prime[i]`
609 relative to the previous iteration)
611 assert isinstance(i
, int) and 0 <= i
< self
.iter_count
614 # `f[0] = feps[0] / (1 - delta[0])`
616 assert self
.max_delta(0) < 1 and self
.max_feps(0) >= 0, \
617 "only one quadrant of interval division implemented"
618 retval
= self
.max_feps(0) / (1 - self
.max_delta(0))
622 retval
= self
.max_feps(1)
625 # `f[i] <= max_feps[i]`
626 retval
= self
.max_feps(i
)
628 return self
._shrink
_max
(retval
)
631 def max_delta(self
, i
):
632 """ maximum value of `delta[i]`.
633 `delta[i]` is defined in Definition 4 of paper.
635 assert isinstance(i
, int) and 0 <= i
< self
.iter_count
637 # `delta[0] = abs(e[0]) + 3 * d[0] / 2`
638 retval
= self
.max_abs_e0
+ Fraction(3, 2) * self
.max_d(0)
640 # `delta[i] = delta[i - 1] ** 2 + f[i - 1]`
641 prev_max_delta
= self
.max_delta(i
- 1)
642 assert prev_max_delta
>= 0
643 retval
= prev_max_delta
** 2 + self
.max_f(i
- 1)
645 # `delta[i]` has to be smaller than one otherwise errors would go off
647 _assert_accuracy(retval
< 1)
649 return self
._shrink
_max
(retval
)
653 """ maximum value of `pi[i]`.
654 `pi[i]` is defined right below Theorem 5 of paper.
656 assert isinstance(i
, int) and 0 <= i
< self
.iter_count
657 # `pi[i] = 1 - (1 - n[i]) * prod`
658 # where `prod` is the product of,
659 # for `j` in `0 <= j < i`, `(1 - n[j]) / (1 + d[j])`
660 min_prod
= Fraction(1)
662 max_n_j
= self
.max_n(j
)
663 max_d_j
= self
.max_d(j
)
664 assert max_n_j
<= 1 and max_d_j
> -1, \
665 "only one quadrant of interval division implemented"
666 min_prod
*= (1 - max_n_j
) / (1 + max_d_j
)
667 max_n_i
= self
.max_n(i
)
668 assert max_n_i
<= 1 and min_prod
>= 0, \
669 "only one quadrant of interval multiplication implemented"
670 retval
= 1 - (1 - max_n_i
) * min_prod
671 return self
._shrink
_max
(retval
)
674 def max_n_shift(self
):
675 """ maximum value of `state.n_shift`.
677 # input numerator is `2*io_width`-bits
678 max_n
= (1 << (self
.io_width
* 2)) - 1
680 # normalize so 1 <= n < 2
688 """ maximum value of, for all `i`, `max_n(i)` and `max_d(i)`
691 for i
in range(self
.iter_count
):
692 n_hat
= max(n_hat
, self
.max_n(i
), self
.max_d(i
))
693 return self
._shrink
_max
(n_hat
)
695 def __make_ops(self
):
696 """ Goldschmidt division algorithm.
699 Even, G., Seidel, P. M., & Ferguson, W. E. (2003).
700 A Parametric Error Analysis of Goldschmidt's Division Algorithm.
701 https://citeseerx.ist.psu.edu/viewdoc/download?doi=10.1.1.90.1238&rep=rep1&type=pdf
703 yields: GoldschmidtDivOp
704 the operations needed to perform the division.
706 # establish assumptions of the paper's error analysis (section 3.1):
708 # 1. normalize so A (numerator) and B (denominator) are in [1, 2)
709 yield GoldschmidtDivOp
.Normalize
711 # 2. ensure all relative errors from directed rounding are <= 1 / 4.
712 # the assumption is met by multipliers with > 4-bits precision
713 _assert_accuracy(self
.expanded_width
> 4)
715 # 3. require `abs(e[0]) + 3 * d[0] / 2 + f[0] < 1 / 2`.
716 _assert_accuracy(self
.max_abs_e0
+ 3 * self
.max_d(0) / 2
717 + self
.max_f(0) < Fraction(1, 2))
719 # 4. the initial approximation F'[-1] of 1/B is in [1/2, 1].
720 # (B is the denominator)
722 for addr
in range(self
.table_addr_count
):
723 f_prime_m1
= self
.table
[addr
]
724 _assert_accuracy(0.5 <= f_prime_m1
<= 1)
726 yield GoldschmidtDivOp
.FEqTableLookup
728 # we use Setting I (section 4.1 of the paper):
729 # Require `n[i] <= n_hat` and `d[i] <= n_hat` and `f[i] = 0`:
730 # the conditions on n_hat are satisfied by construction.
731 for i
in range(self
.iter_count
):
732 _assert_accuracy(self
.max_f(i
) == 0)
733 yield GoldschmidtDivOp
.MulNByF
734 if i
!= self
.iter_count
- 1:
735 yield GoldschmidtDivOp
.MulDByF
736 yield GoldschmidtDivOp
.FEq2MinusD
738 # relative approximation error `p(N_prime[i])`:
739 # `p(N_prime[i]) = (A / B - N_prime[i]) / (A / B)`
740 # `0 <= p(N_prime[i])`
741 # `p(N_prime[i]) <= (2 * i) * n_hat \`
742 # ` + (abs(e[0]) + 3 * n_hat / 2) ** (2 ** i)`
743 i
= self
.iter_count
- 1 # last used `i`
744 # compute power manually to prevent huge intermediate values
745 power
= self
._shrink
_max
(self
.max_abs_e0
+ 3 * self
.n_hat
/ 2)
747 power
= self
._shrink
_max
(power
* power
)
749 max_rel_error
= (2 * i
) * self
.n_hat
+ power
751 min_a_over_b
= Fraction(1, 2)
752 max_a_over_b
= Fraction(2)
753 max_allowed_abs_error
= max_a_over_b
/ (1 << self
.max_n_shift
)
754 max_allowed_rel_error
= max_allowed_abs_error
/ min_a_over_b
756 _assert_accuracy(max_rel_error
< max_allowed_rel_error
,
757 f
"not accurate enough: max_rel_error={max_rel_error}"
758 f
" max_allowed_rel_error={max_allowed_rel_error}")
760 yield GoldschmidtDivOp
.CalcResult
763 def default_cost_fn(self
):
764 """ calculate the estimated cost on an arbitrary scale of implementing
765 goldschmidt division with the specified parameters. larger cost
766 values mean worse parameters.
768 This is the default cost function for `GoldschmidtDivParams.get`.
772 rom_cells
= self
.table_data_bits
<< self
.table_addr_bits
773 cost
= float(rom_cells
)
775 if op
== GoldschmidtDivOp
.MulNByF \
776 or op
== GoldschmidtDivOp
.MulDByF
:
777 mul_cost
= self
.expanded_width
** 2
778 mul_cost
*= self
.expanded_width
.bit_length()
780 cost
+= 5e7
* self
.iter_count
784 @lru_cache(maxsize
=1 << 16)
785 def __cached_new(base_params
):
786 assert isinstance(base_params
, GoldschmidtDivParamsBase
)
787 # can't use dataclasses.asdict, since it's recursive and will also give
788 # child class fields too, which we don't want.
790 for field
in fields(GoldschmidtDivParamsBase
):
791 kwargs
[field
.name
] = getattr(base_params
, field
.name
)
793 return GoldschmidtDivParams(**kwargs
), None
794 except ParamsNotAccurateEnough
as e
:
798 def __raise(e
): # type: (ParamsNotAccurateEnough) -> Any
802 def cached_new(base_params
, handle_error
=__raise
):
803 assert isinstance(base_params
, GoldschmidtDivParamsBase
)
804 params
, error
= GoldschmidtDivParams
.__cached
_new
(base_params
)
808 return handle_error(error
)
811 def get(io_width
, cost_fn
=default_cost_fn
, max_table_addr_bits
=12):
812 """ find efficient parameters for a goldschmidt division algorithm
813 with `params.io_width == io_width`.
817 bit-width of the input divisor and the result.
818 the input numerator is `2 * io_width`-bits wide.
819 cost_fn: Callable[[GoldschmidtDivParams], float]
820 return the estimated cost on an arbitrary scale of implementing
821 goldschmidt division with the specified parameters. larger cost
822 values mean worse parameters.
823 max_table_addr_bits: int
824 maximum allowable value of `table_addr_bits`
826 assert isinstance(io_width
, int) and io_width
>= 1
827 assert callable(cost_fn
)
830 last_error_params
= None
832 def cached_new(base_params
):
834 nonlocal last_error
, last_error_params
836 last_error_params
= base_params
839 retval
= GoldschmidtDivParams
.cached_new(base_params
, handle_error
)
841 logging
.debug(f
"GoldschmidtDivParams.get: err: {base_params}")
843 logging
.debug(f
"GoldschmidtDivParams.get: ok: {base_params}")
846 @lru_cache(maxsize
=None)
847 def get_cost(base_params
):
848 params
= cached_new(base_params
)
851 retval
= cost_fn(params
)
852 logging
.debug(f
"GoldschmidtDivParams.get: cost={retval}: {params}")
855 # start with parameters big enough to always work.
856 initial_extra_precision
= io_width
* 2 + 4
857 initial_params
= GoldschmidtDivParamsBase(
859 extra_precision
=initial_extra_precision
,
860 table_addr_bits
=min(max_table_addr_bits
, io_width
),
861 table_data_bits
=io_width
+ initial_extra_precision
,
862 iter_count
=1 + io_width
.bit_length())
864 if cached_new(initial_params
) is None:
865 raise ValueError(f
"initial goldschmidt division algorithm "
866 f
"parameters are invalid: {initial_params}"
869 # find good initial `iter_count`
870 params
= initial_params
871 for iter_count
in range(1, initial_params
.iter_count
):
872 trial_params
= replace(params
, iter_count
=iter_count
)
873 if cached_new(trial_params
) is not None:
874 params
= trial_params
877 # now find `table_addr_bits`
878 cost
= get_cost(params
)
879 for table_addr_bits
in range(1, max_table_addr_bits
):
880 trial_params
= replace(params
, table_addr_bits
=table_addr_bits
)
881 trial_cost
= get_cost(trial_params
)
882 if trial_cost
< cost
:
883 params
= trial_params
887 # check one higher `iter_count` to see if it has lower cost
888 for table_addr_bits
in range(1, max_table_addr_bits
+ 1):
889 trial_params
= replace(params
,
890 table_addr_bits
=table_addr_bits
,
891 iter_count
=params
.iter_count
+ 1)
892 trial_cost
= get_cost(trial_params
)
893 if trial_cost
< cost
:
894 params
= trial_params
898 # now shrink `table_data_bits`
900 trial_params
= replace(params
,
901 table_data_bits
=params
.table_data_bits
- 1)
902 trial_cost
= get_cost(trial_params
)
903 if trial_cost
< cost
:
904 params
= trial_params
909 # and shrink `extra_precision`
911 trial_params
= replace(params
,
912 extra_precision
=params
.extra_precision
- 1)
913 trial_cost
= get_cost(trial_params
)
914 if trial_cost
< cost
:
915 params
= trial_params
920 return cached_new(params
)
924 class GoldschmidtDivOp(enum
.Enum
):
925 Normalize
= "n, d, n_shift = normalize(n, d)"
926 FEqTableLookup
= "f = table_lookup(d)"
929 FEq2MinusD
= "f = 2 - d"
930 CalcResult
= "result = unnormalize_and_round(n)"
932 def run(self
, params
, state
):
933 assert isinstance(params
, GoldschmidtDivParams
)
934 assert isinstance(state
, GoldschmidtDivState
)
935 expanded_width
= params
.expanded_width
936 table_addr_bits
= params
.table_addr_bits
937 if self
== GoldschmidtDivOp
.Normalize
:
938 # normalize so 1 <= d < 2
939 # can easily be done with count-leading-zeros and left shift
941 state
.n
= (state
.n
* 2).to_frac_wid(expanded_width
)
942 state
.d
= (state
.d
* 2).to_frac_wid(expanded_width
)
945 # normalize so 1 <= n < 2
947 state
.n
= (state
.n
* 0.5).to_frac_wid(expanded_width
)
949 elif self
== GoldschmidtDivOp
.FEqTableLookup
:
950 # compute initial f by table lookup
952 d_m_1
= d_m_1
.to_frac_wid(table_addr_bits
, RoundDir
.DOWN
)
953 assert 0 <= d_m_1
.bits
< (1 << params
.table_addr_bits
)
954 state
.f
= params
.table
[d_m_1
.bits
]
955 elif self
== GoldschmidtDivOp
.MulNByF
:
956 assert state
.f
is not None
957 n
= state
.n
* state
.f
958 state
.n
= n
.to_frac_wid(expanded_width
, round_dir
=RoundDir
.DOWN
)
959 elif self
== GoldschmidtDivOp
.MulDByF
:
960 assert state
.f
is not None
961 d
= state
.d
* state
.f
962 state
.d
= d
.to_frac_wid(expanded_width
, round_dir
=RoundDir
.UP
)
963 elif self
== GoldschmidtDivOp
.FEq2MinusD
:
964 state
.f
= (2 - state
.d
).to_frac_wid(expanded_width
)
965 elif self
== GoldschmidtDivOp
.CalcResult
:
966 assert state
.n_shift
is not None
967 # scale to correct value
968 n
= state
.n
* (1 << state
.n_shift
)
970 state
.quotient
= math
.floor(n
)
971 state
.remainder
= state
.orig_n
- state
.quotient
* state
.orig_d
972 if state
.remainder
>= state
.orig_d
:
974 state
.remainder
-= state
.orig_d
976 assert False, f
"unimplemented GoldschmidtDivOp: {self}"
980 class GoldschmidtDivState
:
982 """original numerator"""
985 """original denominator"""
988 """numerator -- N_prime[i] in the paper's algorithm 2"""
991 """denominator -- D_prime[i] in the paper's algorithm 2"""
993 f
: "FixedPoint | None" = None
994 """current factor -- F_prime[i] in the paper's algorithm 2"""
996 quotient
: "int | None" = None
999 remainder
: "int | None" = None
1000 """final remainder"""
1002 n_shift
: "int | None" = None
1003 """amount the numerator needs to be left-shifted at the end of the
1008 def goldschmidt_div(n
, d
, params
):
1009 """ Goldschmidt division algorithm.
1012 Even, G., Seidel, P. M., & Ferguson, W. E. (2003).
1013 A Parametric Error Analysis of Goldschmidt's Division Algorithm.
1014 https://citeseerx.ist.psu.edu/viewdoc/download?doi=10.1.1.90.1238&rep=rep1&type=pdf
1018 numerator. a `2*width`-bit unsigned integer.
1019 must be less than `d << width`, otherwise the quotient wouldn't
1020 fit in `width` bits.
1022 denominator. a `width`-bit unsigned integer. must not be zero.
1024 the bit-width of the inputs/outputs. must be a positive integer.
1026 returns: tuple[int, int]
1027 the quotient and remainder. a tuple of two `width`-bit unsigned
1030 assert isinstance(params
, GoldschmidtDivParams
)
1031 assert isinstance(d
, int) and 0 < d
< (1 << params
.io_width
)
1032 assert isinstance(n
, int) and 0 <= n
< (d
<< params
.io_width
)
1034 # this whole algorithm is done with fixed-point arithmetic where values
1035 # have `width` fractional bits
1037 state
= GoldschmidtDivState(
1040 n
=FixedPoint(n
, params
.io_width
),
1041 d
=FixedPoint(d
, params
.io_width
),
1044 for op
in params
.ops
:
1045 op
.run(params
, state
)
1047 assert state
.quotient
is not None
1048 assert state
.remainder
is not None
1050 return state
.quotient
, state
.remainder
1053 GOLDSCHMIDT_SQRT_RSQRT_TABLE_ADDR_INT_WID
= 2
def goldschmidt_sqrt_rsqrt_table(table_addr_bits, table_data_bits):
    """Generate the look-up table needed for Goldschmidt's square-root and
    reciprocal-square-root algorithm.

    A table address is interpreted as a fixed-point number with
    `GOLDSCHMIDT_SQRT_RSQRT_TABLE_ADDR_INT_WID` integer bits; the remaining
    address bits are fraction bits.

    arguments:
    table_addr_bits: int
        the number of address bits for the look-up table. must be at least
        `GOLDSCHMIDT_SQRT_RSQRT_TABLE_ADDR_INT_WID`.
    table_data_bits: int
        the number of data bits for the look-up table. must be at least 1.

    returns: tuple
        a tuple with `1 << table_addr_bits` entries, indexed by address:
        * entry 0 is a `FixedPoint` zero with `table_data_bits` fraction
          bits.
        * every other entry whose address represents a value less than 1
          is `None`, since those entries should be unused.
        * all remaining entries are the reciprocal-square-root of the
          address's maximum input value (`addr + 1` in address
          fixed-point format), rounded down to `table_data_bits` fraction
          bits.
    """
    assert isinstance(table_addr_bits, int) and \
        table_addr_bits >= GOLDSCHMIDT_SQRT_RSQRT_TABLE_ADDR_INT_WID
    assert isinstance(table_data_bits, int) and table_data_bits >= 1

    # loop-invariant: number of fraction bits in a table address
    table_addr_frac_wid = table_addr_bits
    table_addr_frac_wid -= GOLDSCHMIDT_SQRT_RSQRT_TABLE_ADDR_INT_WID
    # addresses below this represent input values less than 1
    first_used_addr = 1 << table_addr_frac_wid

    table = []
    table_len = 1 << table_addr_bits
    for addr in range(table_len):
        if addr == 0:
            value = FixedPoint(0, table_data_bits)
        elif addr < first_used_addr:
            value = None  # table entries should be unused
        else:
            # use the upper end of the input range covered by this address
            max_input_value = FixedPoint(addr + 1, table_addr_frac_wid)
            # widen first so rsqrt is computed with at least
            # `table_data_bits` fraction bits before the final rounding
            max_frac_wid = max(max_input_value.frac_wid, table_data_bits)
            value = max_input_value.to_frac_wid(max_frac_wid)
            value = value.rsqrt(RoundDir.DOWN)
            value = value.to_frac_wid(table_data_bits, RoundDir.DOWN)
        table.append(value)

    # tuple for immutability
    return tuple(table)
1090 # FIXME: add code to calculate error bounds and check that the algorithm will
1091 # actually work (like in the goldschmidt division algorithm).
1092 # FIXME: add code to calculate a good set of parameters based on the error
1096 def goldschmidt_sqrt_rsqrt(radicand
, io_width
, frac_wid
, extra_precision
,
1097 table_addr_bits
, table_data_bits
, iter_count
):
1098 """Goldschmidt's square-root and reciprocal-square-root algorithm.
1100 uses algorithm based on second method at:
1101 https://en.wikipedia.org/wiki/Methods_of_computing_square_roots#Goldschmidt%E2%80%99s_algorithm
1104 radicand: FixedPoint(frac_wid=frac_wid)
1105 the input value to take the square-root and reciprocal-square-root of.
1107 the number of bits in the input (`radicand`) and output values.
1109 the number of fraction bits in the input (`radicand`) and output
1111 extra_precision: int
1112 the number of bits of internal extra precision.
1113 table_addr_bits: int
1114 the number of address bits for the look-up table.
1115 table_data_bits: int
1116 the number of data bits for the look-up table.
1118 returns: tuple[FixedPoint, FixedPoint]
1119 the square-root and reciprocal-square-root, rounded down to the
1120 nearest representable value. If `radicand == 0`, then the
1121 reciprocal-square-root value returned is zero.
1123 assert (isinstance(radicand
, FixedPoint
)
1124 and radicand
.frac_wid
== frac_wid
1125 and 0 <= radicand
.bits
< (1 << io_width
))
1126 assert isinstance(io_width
, int) and io_width
>= 1
1127 assert isinstance(frac_wid
, int) and 0 <= frac_wid
< io_width
1128 assert isinstance(extra_precision
, int) and extra_precision
>= io_width
1129 assert isinstance(table_addr_bits
, int) and table_addr_bits
>= 1
1130 assert isinstance(table_data_bits
, int) and table_data_bits
>= 1
1131 assert isinstance(iter_count
, int) and iter_count
>= 0
1132 expanded_frac_wid
= frac_wid
+ extra_precision
1133 s
= radicand
.to_frac_wid(expanded_frac_wid
)
1134 sqrt_rshift
= extra_precision
1135 rsqrt_rshift
= extra_precision
1136 while s
!= 0 and s
< 1:
1137 s
= (s
* 4).to_frac_wid(expanded_frac_wid
)
1141 s
= s
.div(4, expanded_frac_wid
)
1144 table
= goldschmidt_sqrt_rsqrt_table(table_addr_bits
=table_addr_bits
,
1145 table_data_bits
=table_data_bits
)
1146 # core goldschmidt sqrt/rsqrt algorithm:
1148 table_addr_frac_wid
= table_addr_bits
1149 table_addr_frac_wid
-= GOLDSCHMIDT_SQRT_RSQRT_TABLE_ADDR_INT_WID
1150 addr
= s
.to_frac_wid(table_addr_frac_wid
, RoundDir
.DOWN
)
1151 assert 0 <= addr
.bits
< (1 << table_addr_bits
), "table addr out of range"
1152 f
= table
[addr
.bits
]
1153 assert f
is not None, "accessed invalid table entry"
1154 # use with_frac_wid to fix IDE type deduction
1155 f
= FixedPoint
.with_frac_wid(f
, expanded_frac_wid
, RoundDir
.DOWN
)
1156 x
= (s
* f
).to_frac_wid(expanded_frac_wid
, RoundDir
.DOWN
)
1157 h
= (f
* 0.5).to_frac_wid(expanded_frac_wid
, RoundDir
.DOWN
)
1158 for _
in range(iter_count
):
1160 f
= (1.5 - x
* h
).to_frac_wid(expanded_frac_wid
, RoundDir
.DOWN
)
1161 x
= (x
* f
).to_frac_wid(expanded_frac_wid
, RoundDir
.DOWN
)
1162 h
= (h
* f
).to_frac_wid(expanded_frac_wid
, RoundDir
.DOWN
)
1164 # now `x` is approximately `sqrt(s)` and `r` is approximately `rsqrt(s)`
1166 sqrt
= FixedPoint(x
.bits
>> sqrt_rshift
, frac_wid
)
1167 rsqrt
= FixedPoint(r
.bits
>> rsqrt_rshift
, frac_wid
)
1169 next_sqrt
= FixedPoint(sqrt
.bits
+ 1, frac_wid
)
1170 if next_sqrt
* next_sqrt
<= radicand
:
1173 next_rsqrt
= FixedPoint(rsqrt
.bits
+ 1, frac_wid
)
1174 if next_rsqrt
* next_rsqrt
* radicand
<= 1 and radicand
!= 0: