add the goldschmidt sqrt/rsqrt algorithm, still need code to calculate good parameters
[soc.git] / src / soc / fu / div / experiment / goldschmidt_div_sqrt.py
1 # SPDX-License-Identifier: LGPL-3-or-later
2 # Copyright 2022 Jacob Lifshay programmerjake@gmail.com
3
4 # Funded by NLnet Assure Programme 2021-02-052, https://nlnet.nl/assure part
5 # of Horizon 2020 EU Programme 957073.
6
7 from dataclasses import dataclass, field, fields, replace
8 import logging
9 import math
10 import enum
11 from fractions import Fraction
12 from types import FunctionType
13 from functools import lru_cache
14
15 try:
16 from functools import cached_property
17 except ImportError:
18 from cached_property import cached_property
19
20 # fix broken IDE type detection for cached_property
21 from typing import TYPE_CHECKING, Any
22 if TYPE_CHECKING:
23 from functools import cached_property
24
25
# private sentinel meaning "no cached value yet"; needed because `None` is a
# perfectly valid cached return value.
_NOT_FOUND = object()


def cache_on_self(func):
    """like `functools.cached_property`, except for methods. unlike
    `lru_cache` the cache is per-class instance rather than a global cache
    per-method.

    The cache is a dict stored in the instance's `__dict__` under the name
    `<method name>__cache`, so it works on `frozen=True` dataclasses and is
    discarded together with the instance. All arguments must be hashable.
    Keyword arguments are normalized by name, so `f(a=1, b=2)` and
    `f(b=2, a=1)` share one cache entry.
    """
    import functools

    assert isinstance(func, FunctionType), \
        "non-plain methods are not supported"

    cache_name = func.__name__ + "__cache"

    # `wraps` keeps `__name__`, `__qualname__`, `__doc__` and `__wrapped__`
    @functools.wraps(func)
    def wrapper(self, *args, **kwargs):
        # specifically access through `__dict__` to bypass frozen=True
        cache = self.__dict__.get(cache_name, _NOT_FOUND)
        if cache is _NOT_FOUND:
            self.__dict__[cache_name] = cache = {}
        # keyword names are unique, so sorting never compares the values
        key = (args, *sorted(kwargs.items()))
        retval = cache.get(key, _NOT_FOUND)
        if retval is _NOT_FOUND:
            retval = func(self, *args, **kwargs)
            cache[key] = retval
        return retval

    return wrapper
53
54
@enum.unique
class RoundDir(enum.Enum):
    """how to round a result that isn't exactly representable."""

    # round toward negative infinity (floor)
    DOWN = enum.auto()
    # round toward positive infinity (ceiling)
    UP = enum.auto()
    # round to nearest, ties toward positive infinity
    NEAREST_TIES_UP = enum.auto()
    # raise `ValueError` instead of rounding
    ERROR_IF_INEXACT = enum.auto()


@dataclass(frozen=True)
class FixedPoint:
    """an immutable binary fixed-point number whose value is
    `bits / 2 ** frac_wid`.

    `+`, `-` and `*` are exact: they widen `frac_wid` as needed. Use
    `to_frac_wid`, `with_frac_wid`, `div`, `sqrt` and `rsqrt` with a
    `RoundDir` to get a result with a chosen number of fractional bits.

    Comparison is by numeric value against `FixedPoint`, `int`, `float`,
    `str` (parsed with `cast`) and `Fraction`. Numerically equal values also
    hash equal, even when their `frac_wid` differs.
    """

    bits: int
    frac_wid: int

    def __post_init__(self):
        assert isinstance(self.bits, int)
        assert isinstance(self.frac_wid, int) and self.frac_wid >= 0

    @staticmethod
    def cast(value):
        """convert `value` to a fixed-point number with enough fractional
        bits to preserve its value.

        accepts:
        * `FixedPoint`: returned unchanged.
        * `int`: `frac_wid` is 0.
        * `float`: converted exactly.
        * `str`: optional leading `+`/`-`. Either a hexadecimal number
          containing a `.` (e.g. `"0x1.8"`, where every fractional hex digit
          adds 4 to `frac_wid` and `_` separators are ignored), or anything
          `int(value, base=0)` accepts.

        raises:
            TypeError: for any other type.
            ValueError: for malformed strings.
        """
        if isinstance(value, FixedPoint):
            return value
        if isinstance(value, int):
            return FixedPoint(value, 0)
        if isinstance(value, str):
            value = value.strip()
            neg = value.startswith("-")
            if neg or value.startswith("+"):
                value = value[1:]
            if value.startswith(("0x", "0X")) and "." in value:
                value = value[2:]
                got_dot = False
                bits = 0
                frac_wid = 0
                for digit in value:
                    if digit == "_":
                        continue
                    if got_dot:
                        if digit == ".":
                            raise ValueError("too many `.` in string")
                        frac_wid += 4
                    if digit == ".":
                        got_dot = True
                        continue
                    if not digit.isalnum():
                        raise ValueError("invalid hexadecimal digit")
                    bits <<= 4
                    bits |= int("0x" + digit, base=16)
            else:
                bits = int(value, base=0)
                frac_wid = 0
            if neg:
                bits = -bits
            return FixedPoint(bits, frac_wid)

        if isinstance(value, float):
            n, d = value.as_integer_ratio()
            log2_d = d.bit_length() - 1
            assert d == 1 << log2_d, ("d isn't a power of 2 -- won't ever "
                                      "fail with float being IEEE 754")
            return FixedPoint(n, log2_d)
        raise TypeError("can't convert type to FixedPoint")

    @staticmethod
    def with_frac_wid(value, frac_wid, round_dir=RoundDir.ERROR_IF_INEXACT):
        """convert `value` to the nearest fixed-point number with `frac_wid`
        fractional bits, rounding according to `round_dir`.

        `value` may be a `Fraction` or anything `FixedPoint.cast` accepts.

        raises:
            ValueError: if `round_dir` is `ERROR_IF_INEXACT` and rounding
                would be needed.
        """
        assert isinstance(frac_wid, int) and frac_wid >= 0
        assert isinstance(round_dir, RoundDir)
        if isinstance(value, Fraction):
            numerator = value.numerator
            denominator = value.denominator
        else:
            value = FixedPoint.cast(value)
            numerator = value.bits
            denominator = 1 << value.frac_wid
        # both denominators are always positive, so `divmod` floors and
        # `remainder` is in `[0, denominator)`
        bits, remainder = divmod(numerator << frac_wid, denominator)
        if round_dir == RoundDir.DOWN:
            pass
        elif round_dir == RoundDir.UP:
            if remainder != 0:
                bits += 1
        elif round_dir == RoundDir.NEAREST_TIES_UP:
            if remainder * 2 >= denominator:
                bits += 1
        elif round_dir == RoundDir.ERROR_IF_INEXACT:
            if remainder != 0:
                raise ValueError("inexact conversion")
        else:
            assert False, "unimplemented round_dir"
        return FixedPoint(bits, frac_wid)

    def to_frac_wid(self, frac_wid, round_dir=RoundDir.ERROR_IF_INEXACT):
        """convert to the nearest fixed-point number with `frac_wid`
        fractional bits, rounding according to `round_dir`."""
        return FixedPoint.with_frac_wid(self, frac_wid, round_dir)

    def __float__(self):
        # use truediv to get correct result even when bits
        # and frac_wid are huge
        return float(self.bits / (1 << self.frac_wid))

    def as_fraction(self):
        """return the exact value of `self` as a `Fraction`."""
        return Fraction(self.bits, 1 << self.frac_wid)

    def cmp(self, rhs):
        """compare self with rhs, returning a positive integer if self is
        greater than rhs, zero if self is equal to rhs, and a negative integer
        if self is less than rhs.

        `rhs` may be a `Fraction` or anything `FixedPoint.cast` accepts.
        """
        if isinstance(rhs, Fraction):
            # a `Fraction` such as 1/3 may have no fixed-point representation,
            # so compare exactly as fractions instead of casting
            diff = self.as_fraction() - rhs
            return (diff > 0) - (diff < 0)
        rhs = FixedPoint.cast(rhs)
        common_frac_wid = max(self.frac_wid, rhs.frac_wid)
        lhs = self.to_frac_wid(common_frac_wid)
        rhs = rhs.to_frac_wid(common_frac_wid)
        return lhs.bits - rhs.bits

    def _cmp_or_not_implemented(self, rhs):
        """like `cmp`, but returns `NotImplemented` for operand types `cmp`
        can't handle. This lets Python's rich-comparison protocol fall back
        to the other operand (or to identity for `==`/`!=`)."""
        if isinstance(rhs, (FixedPoint, int, float, str, Fraction)):
            return self.cmp(rhs)
        return NotImplemented

    def __eq__(self, rhs):
        result = self._cmp_or_not_implemented(rhs)
        if result is NotImplemented:
            return result
        return result == 0

    def __ne__(self, rhs):
        result = self._cmp_or_not_implemented(rhs)
        if result is NotImplemented:
            return result
        return result != 0

    def __gt__(self, rhs):
        result = self._cmp_or_not_implemented(rhs)
        if result is NotImplemented:
            return result
        return result > 0

    def __lt__(self, rhs):
        result = self._cmp_or_not_implemented(rhs)
        if result is NotImplemented:
            return result
        return result < 0

    def __ge__(self, rhs):
        result = self._cmp_or_not_implemented(rhs)
        if result is NotImplemented:
            return result
        return result >= 0

    def __le__(self, rhs):
        result = self._cmp_or_not_implemented(rhs)
        if result is NotImplemented:
            return result
        return result <= 0

    def __hash__(self):
        # hash by exact numeric value: `__eq__` treats e.g. FixedPoint(1, 0)
        # and FixedPoint(2, 1) as equal, so they must hash the same. Hashing
        # the `Fraction` also agrees with `hash` of equal ints/floats.
        return hash(self.as_fraction())

    def fract(self):
        """return the fractional part of `self`.
        that is `self - math.floor(self)`.
        """
        # masking a Python int acts like two's complement, so this is also
        # correct for negative values
        fract_mask = (1 << self.frac_wid) - 1
        return FixedPoint(self.bits & fract_mask, self.frac_wid)

    def __str__(self):
        """hexadecimal form like `0x1.8` (or `-0x1.8`). It always has at
        least one fractional hex digit."""
        if self < 0:
            return "-" + str(-self)
        digit_bits = 4
        frac_digit_count = (self.frac_wid + digit_bits - 1) // digit_bits
        fract = self.fract().to_frac_wid(frac_digit_count * digit_bits)
        frac_str = hex(fract.bits)[2:].zfill(frac_digit_count)
        return hex(math.floor(self)) + "." + frac_str

    def __repr__(self):
        return f"FixedPoint.with_frac_wid({str(self)!r}, {self.frac_wid})"

    def __add__(self, rhs):
        rhs = FixedPoint.cast(rhs)
        common_frac_wid = max(self.frac_wid, rhs.frac_wid)
        lhs = self.to_frac_wid(common_frac_wid)
        rhs = rhs.to_frac_wid(common_frac_wid)
        return FixedPoint(lhs.bits + rhs.bits, common_frac_wid)

    def __radd__(self, lhs):
        # symmetric
        return self.__add__(lhs)

    def __neg__(self):
        return FixedPoint(-self.bits, self.frac_wid)

    def __sub__(self, rhs):
        rhs = FixedPoint.cast(rhs)
        common_frac_wid = max(self.frac_wid, rhs.frac_wid)
        lhs = self.to_frac_wid(common_frac_wid)
        rhs = rhs.to_frac_wid(common_frac_wid)
        return FixedPoint(lhs.bits - rhs.bits, common_frac_wid)

    def __rsub__(self, lhs):
        # a - b == -(b - a)
        return -self.__sub__(lhs)

    def __mul__(self, rhs):
        rhs = FixedPoint.cast(rhs)
        return FixedPoint(self.bits * rhs.bits, self.frac_wid + rhs.frac_wid)

    def __rmul__(self, lhs):
        # symmetric
        return self.__mul__(lhs)

    def __floor__(self):
        return self.bits >> self.frac_wid

    def div(self, rhs, frac_wid, round_dir=RoundDir.ERROR_IF_INEXACT):
        """return `self / rhs` with `frac_wid` fractional bits, rounded
        according to `round_dir`.

        raises:
            ZeroDivisionError: if `rhs` is zero.
        """
        assert isinstance(frac_wid, int) and frac_wid >= 0
        assert isinstance(round_dir, RoundDir)
        rhs = FixedPoint.cast(rhs)
        return FixedPoint.with_frac_wid(self.as_fraction()
                                        / rhs.as_fraction(),
                                        frac_wid, round_dir)

    def sqrt(self, round_dir=RoundDir.ERROR_IF_INEXACT):
        """return the square root of `self` with the same `frac_wid` as
        `self`, rounded according to `round_dir`.

        computed bit-by-bit from the most significant bit down.

        raises:
            ValueError: if `self` is negative, or if `round_dir` is
                `ERROR_IF_INEXACT` and the result isn't exact.
        """
        assert isinstance(round_dir, RoundDir)
        if self < 0:
            raise ValueError("can't compute sqrt of negative number")
        if self == 0:
            return self
        retval = FixedPoint(0, self.frac_wid)
        int_part_wid = self.bits.bit_length() - self.frac_wid
        first_bit_index = -(-int_part_wid // 2)  # division rounds up
        last_bit_index = -self.frac_wid
        for bit_index in range(first_bit_index, last_bit_index - 1, -1):
            trial = retval + FixedPoint(1 << (bit_index + self.frac_wid),
                                        self.frac_wid)
            if trial * trial <= self:
                retval = trial
        # `retval` is now the largest representable value whose square is
        # `<= self`, i.e. the result rounded down
        if round_dir == RoundDir.DOWN:
            pass
        elif round_dir == RoundDir.UP:
            if retval * retval < self:
                retval += FixedPoint(1, self.frac_wid)
        elif round_dir == RoundDir.NEAREST_TIES_UP:
            half_way = retval + FixedPoint(1, self.frac_wid + 1)
            if half_way * half_way <= self:
                retval += FixedPoint(1, self.frac_wid)
        elif round_dir == RoundDir.ERROR_IF_INEXACT:
            if retval * retval != self:
                raise ValueError("inexact sqrt")
        else:
            assert False, "unimplemented round_dir"
        return retval

    def rsqrt(self, round_dir=RoundDir.ERROR_IF_INEXACT):
        """compute the reciprocal-sqrt of `self` with the same `frac_wid` as
        `self`, rounded according to `round_dir`.

        raises:
            ValueError: if `self` is negative, or if `round_dir` is
                `ERROR_IF_INEXACT` and the result isn't exact.
            ZeroDivisionError: if `self` is zero.
        """
        assert isinstance(round_dir, RoundDir)
        if self < 0:
            raise ValueError("can't compute rsqrt of negative number")
        if self == 0:
            raise ZeroDivisionError("can't compute rsqrt of zero")
        retval = FixedPoint(0, self.frac_wid)
        # the smallest positive input is 2 ** -frac_wid, so the result is at
        # most 2 ** (frac_wid / 2)
        first_bit_index = -(-self.frac_wid // 2)  # division rounds up
        last_bit_index = -self.frac_wid
        for bit_index in range(first_bit_index, last_bit_index - 1, -1):
            trial = retval + FixedPoint(1 << (bit_index + self.frac_wid),
                                        self.frac_wid)
            if trial * trial * self <= 1:
                retval = trial
        if round_dir == RoundDir.DOWN:
            pass
        elif round_dir == RoundDir.UP:
            if retval * retval * self < 1:
                retval += FixedPoint(1, self.frac_wid)
        elif round_dir == RoundDir.NEAREST_TIES_UP:
            half_way = retval + FixedPoint(1, self.frac_wid + 1)
            if half_way * half_way * self <= 1:
                retval += FixedPoint(1, self.frac_wid)
        elif round_dir == RoundDir.ERROR_IF_INEXACT:
            if retval * retval * self != 1:
                raise ValueError("inexact rsqrt")
        else:
            assert False, "unimplemented round_dir"
        return retval
316
317
@dataclass
class GoldschmidtDivState:
    """mutable working state of one goldschmidt division.

    Each `GoldschmidtDivOp` step reads and updates these fields in place.
    The primed names (`N_prime[i]`, `D_prime[i]`, `F_prime[i]`) follow the
    notation of algorithm 2 in the paper the algorithm is based on.
    """

    # original numerator
    orig_n: int

    # original denominator
    orig_d: int

    # numerator -- N_prime[i] in the paper's algorithm 2
    n: FixedPoint

    # denominator -- D_prime[i] in the paper's algorithm 2
    d: FixedPoint

    # current factor -- F_prime[i] in the paper's algorithm 2
    f: "FixedPoint | None" = None

    # final quotient
    quotient: "int | None" = None

    # final remainder
    remainder: "int | None" = None

    # amount the numerator needs to be left-shifted at the end of the
    # algorithm
    n_shift: "int | None" = None
345
346
347 class ParamsNotAccurateEnough(Exception):
348 """raised when the parameters aren't accurate enough to have goldschmidt
349 division work."""
350
351
352 def _assert_accuracy(condition, msg="not accurate enough"):
353 if condition:
354 return
355 raise ParamsNotAccurateEnough(msg)
356
357
@dataclass(frozen=True, unsafe_hash=True)
class GoldschmidtDivParamsBase:
    """the independent (non-derived) parameters of a Goldschmidt division
    algorithm.

    Instances are immutable and hashable, so they can be used directly as
    cache keys.
    """

    # bit-width of the input divisor and of the result; the input numerator
    # is `2 * io_width` bits wide.
    io_width: int

    # number of bits of additional precision used inside the algorithm.
    extra_precision: int

    # number of address bits used in the lookup-table.
    table_addr_bits: int

    # number of data bits used in the lookup-table.
    table_data_bits: int

    # total number of iterations of the division algorithm's loop.
    iter_count: int
380
381
@dataclass(frozen=True, unsafe_hash=True)
class GoldschmidtDivParams(GoldschmidtDivParamsBase):
    """parameters for a Goldschmidt division algorithm.
    Use `GoldschmidtDivParams.get` to find an efficient set of parameters.

    The autogenerated `__init__` calls `__post_init__`, which builds the
    lookup-table and the list of ops. It raises `ParamsNotAccurateEnough`
    if the error analysis can't show the parameters are accurate enough.
    """

    # tuple to be immutable, repr=False so repr() works for debugging even when
    # __post_init__ hasn't finished running yet
    table: "tuple[FixedPoint, ...]" = field(init=False, repr=False)
    """the lookup-table"""

    ops: "tuple[GoldschmidtDivOp, ...]" = field(init=False, repr=False)
    """the operations needed to perform the goldschmidt division algorithm."""

    def _shrink_bound(self, bound, round_dir):
        """prevent fractions from having huge numerators/denominators by
        rounding to a `FixedPoint` and converting back to a `Fraction`.

        This is intended only for values used to compute bounds, and not for
        values that end up in the hardware.
        """
        assert isinstance(bound, (Fraction, int))
        assert round_dir is RoundDir.DOWN or round_dir is RoundDir.UP, \
            "you shouldn't use that round_dir on bounds"
        frac_wid = self.io_width * 4 + 100  # should be enough precision
        fixed = FixedPoint.with_frac_wid(bound, frac_wid, round_dir)
        return fixed.as_fraction()

    def _shrink_min(self, min_bound):
        """prevent fractions used as minimum bounds from having huge
        numerators/denominators by rounding down to a `FixedPoint` and
        converting back to a `Fraction`.

        This is intended only for values used to compute bounds, and not for
        values that end up in the hardware.
        """
        return self._shrink_bound(min_bound, RoundDir.DOWN)

    def _shrink_max(self, max_bound):
        """prevent fractions used as maximum bounds from having huge
        numerators/denominators by rounding up to a `FixedPoint` and
        converting back to a `Fraction`.

        This is intended only for values used to compute bounds, and not for
        values that end up in the hardware.
        """
        return self._shrink_bound(max_bound, RoundDir.UP)

    @property
    def table_addr_count(self):
        """number of distinct addresses in the lookup-table."""
        # used while computing self.table, so can't just do len(self.table)
        return 1 << self.table_addr_bits

    def table_input_exact_range(self, addr):
        """return the range of inputs as `Fraction`s used for the table entry
        with address `addr`.

        The normalized divisor `1 <= B < 2` selects entry `addr` through its
        `table_addr_bits` most-significant fractional bits.
        """
        assert isinstance(addr, int)
        assert 0 <= addr < self.table_addr_count
        _assert_accuracy(self.io_width >= self.table_addr_bits)
        addr_shift = self.io_width - self.table_addr_bits
        min_numerator = (1 << self.io_width) + (addr << addr_shift)
        denominator = 1 << self.io_width
        values_per_table_entry = 1 << addr_shift
        max_numerator = min_numerator + values_per_table_entry - 1
        min_input = Fraction(min_numerator, denominator)
        max_input = Fraction(max_numerator, denominator)
        min_input = self._shrink_min(min_input)
        max_input = self._shrink_max(max_input)
        assert 1 <= min_input <= max_input < 2
        return min_input, max_input

    def table_value_exact_range(self, addr):
        """return the range of values as `Fraction`s used for the table entry
        with address `addr`."""
        min_input, max_input = self.table_input_exact_range(addr)
        # division swaps min/max
        min_value = 1 / max_input
        max_value = 1 / min_input
        min_value = self._shrink_min(min_value)
        max_value = self._shrink_max(max_value)
        assert 0.5 < min_value <= max_value <= 1
        return min_value, max_value

    def table_exact_value(self, index):
        """return the exact (not yet rounded to `table_data_bits`) value for
        the table entry at address `index`."""
        min_value, _max_value = self.table_value_exact_range(index)
        # we round down
        return min_value

    def __post_init__(self):
        # called by the autogenerated __init__
        _assert_accuracy(self.io_width >= 1, "io_width out of range")
        _assert_accuracy(self.extra_precision >= 0,
                         "extra_precision out of range")
        _assert_accuracy(self.table_addr_bits >= 1,
                         "table_addr_bits out of range")
        _assert_accuracy(self.table_data_bits >= 1,
                         "table_data_bits out of range")
        _assert_accuracy(self.iter_count >= 1, "iter_count out of range")
        table = tuple(
            FixedPoint.with_frac_wid(self.table_exact_value(addr),
                                     self.table_data_bits,
                                     RoundDir.DOWN)
            for addr in range(self.table_addr_count))
        # we have to use object.__setattr__ since frozen=True
        object.__setattr__(self, "table", table)
        object.__setattr__(self, "ops", tuple(self.__make_ops()))

    @property
    def expanded_width(self):
        """the total number of bits of precision used inside the algorithm."""
        return self.io_width + self.extra_precision

    @cache_on_self
    def max_neps(self, i):
        """maximum value of `neps[i]`.
        `neps[i]` is defined to be `n[i] * N_prime[i - 1] * F_prime[i - 1]`.
        """
        assert isinstance(i, int) and 0 <= i < self.iter_count
        return Fraction(1, 1 << self.expanded_width)

    @cache_on_self
    def max_deps(self, i):
        """maximum value of `deps[i]`.
        `deps[i]` is defined to be `d[i] * D_prime[i - 1] * F_prime[i - 1]`.
        """
        assert isinstance(i, int) and 0 <= i < self.iter_count
        return Fraction(1, 1 << self.expanded_width)

    @cache_on_self
    def max_feps(self, i):
        """maximum value of `feps[i]`.
        `feps[i]` is defined to be `f[i] * (2 - D_prime[i - 1])`.
        """
        assert isinstance(i, int) and 0 <= i < self.iter_count
        # zero, because the computation of `F_prime[i]` in
        # `GoldschmidtDivOp.MulDByF.run(...)` is exact.
        return Fraction(0)

    @cached_property
    def e0_range(self):
        """minimum and maximum values of `e[0]`
        (the relative error in `F_prime[-1]`)
        """
        min_e0 = Fraction(0)
        max_e0 = Fraction(0)
        for addr in range(self.table_addr_count):
            # `F_prime[-1] = (1 - e[0]) / B`
            # => `e[0] = 1 - B * F_prime[-1]`
            min_b, max_b = self.table_input_exact_range(addr)
            f_prime_m1 = self.table[addr].as_fraction()
            assert min_b >= 0 and f_prime_m1 >= 0, \
                "only positive quadrant of interval multiplication implemented"
            min_product = min_b * f_prime_m1
            max_product = max_b * f_prime_m1
            # negation swaps min/max
            cur_min_e0 = 1 - max_product
            cur_max_e0 = 1 - min_product
            min_e0 = min(min_e0, cur_min_e0)
            max_e0 = max(max_e0, cur_max_e0)
        min_e0 = self._shrink_min(min_e0)
        max_e0 = self._shrink_max(max_e0)
        return min_e0, max_e0

    @cached_property
    def min_e0(self):
        """minimum value of `e[0]` (the relative error in `F_prime[-1]`)
        """
        min_e0, _max_e0 = self.e0_range
        return min_e0

    @cached_property
    def max_e0(self):
        """maximum value of `e[0]` (the relative error in `F_prime[-1]`)
        """
        _min_e0, max_e0 = self.e0_range
        return max_e0

    @cached_property
    def max_abs_e0(self):
        """maximum value of `abs(e[0])`."""
        return max(abs(self.min_e0), abs(self.max_e0))

    @cached_property
    def min_abs_e0(self):
        """minimum value of `abs(e[0])`."""
        return Fraction(0)

    @cache_on_self
    def max_n(self, i):
        """maximum value of `n[i]` (the relative error in `N_prime[i]`
        relative to the previous iteration)
        """
        assert isinstance(i, int) and 0 <= i < self.iter_count
        if i == 0:
            # from Claim 10
            # `n[0] = neps[0] / ((1 - e[0]) * (A / B))`
            # `n[0] <= 2 * neps[0] / (1 - e[0])`

            assert self.max_e0 < 1 and self.max_neps(0) >= 0, \
                "only one quadrant of interval division implemented"
            retval = 2 * self.max_neps(0) / (1 - self.max_e0)
        elif i == 1:
            # from Claim 10
            # `n[1] <= neps[1] / ((1 - f[0]) * (1 - pi[0] - delta[0]))`
            min_mpd = 1 - self.max_pi(0) - self.max_delta(0)
            assert self.max_f(0) <= 1 and min_mpd >= 0, \
                "only one quadrant of interval multiplication implemented"
            prod = (1 - self.max_f(0)) * min_mpd
            assert self.max_neps(1) >= 0 and prod > 0, \
                "only one quadrant of interval division implemented"
            retval = self.max_neps(1) / prod
        else:
            # from Claim 6
            # `0 <= n[i] <= 2 * max_neps[i] / (1 - pi[i - 1] - delta[i - 1])`
            # the factor of 2 presumably comes from `A / B >= 1 / 2`, like in
            # the `i == 0` case; it makes the bound more conservative.
            min_mpd = 1 - self.max_pi(i - 1) - self.max_delta(i - 1)
            assert self.max_neps(i) >= 0 and min_mpd > 0, \
                "only one quadrant of interval division implemented"
            retval = 2 * self.max_neps(i) / min_mpd

        return self._shrink_max(retval)

    @cache_on_self
    def max_d(self, i):
        """maximum value of `d[i]` (the relative error in `D_prime[i]`
        relative to the previous iteration)
        """
        assert isinstance(i, int) and 0 <= i < self.iter_count
        if i == 0:
            # from Claim 10
            # `d[0] = deps[0] / (1 - e[0])`

            assert self.max_e0 < 1 and self.max_deps(0) >= 0, \
                "only one quadrant of interval division implemented"
            retval = self.max_deps(0) / (1 - self.max_e0)
        elif i == 1:
            # from Claim 10
            # `d[1] <= deps[1] / ((1 - f[0]) * (1 - delta[0] ** 2))`
            assert self.max_f(0) <= 1 and self.max_delta(0) <= 1, \
                "only one quadrant of interval multiplication implemented"
            divisor = (1 - self.max_f(0)) * (1 - self.max_delta(0) ** 2)
            assert self.max_deps(1) >= 0 and divisor > 0, \
                "only one quadrant of interval division implemented"
            retval = self.max_deps(1) / divisor
        else:
            # from Claim 6
            # `0 <= d[i] <= max_deps[i] / (1 - delta[i - 1])`
            assert self.max_deps(i) >= 0 and self.max_delta(i - 1) < 1, \
                "only one quadrant of interval division implemented"
            retval = self.max_deps(i) / (1 - self.max_delta(i - 1))

        return self._shrink_max(retval)

    @cache_on_self
    def max_f(self, i):
        """maximum value of `f[i]` (the relative error in `F_prime[i]`
        relative to the previous iteration)
        """
        assert isinstance(i, int) and 0 <= i < self.iter_count
        if i == 0:
            # from Claim 10
            # `f[0] = feps[0] / (1 - delta[0])`

            assert self.max_delta(0) < 1 and self.max_feps(0) >= 0, \
                "only one quadrant of interval division implemented"
            retval = self.max_feps(0) / (1 - self.max_delta(0))
        elif i == 1:
            # from Claim 10
            # `f[1] = feps[1]`
            retval = self.max_feps(1)
        else:
            # from Claim 6
            # `f[i] <= max_feps[i]`
            retval = self.max_feps(i)

        return self._shrink_max(retval)

    @cache_on_self
    def max_delta(self, i):
        """ maximum value of `delta[i]`.
        `delta[i]` is defined in Definition 4 of paper.
        """
        assert isinstance(i, int) and 0 <= i < self.iter_count
        if i == 0:
            # `delta[0] = abs(e[0]) + 3 * d[0] / 2`
            retval = self.max_abs_e0 + Fraction(3, 2) * self.max_d(0)
        else:
            # `delta[i] = delta[i - 1] ** 2 + f[i - 1]`
            prev_max_delta = self.max_delta(i - 1)
            assert prev_max_delta >= 0
            retval = prev_max_delta ** 2 + self.max_f(i - 1)

        # `delta[i]` has to be smaller than one otherwise errors would go off
        # to infinity
        _assert_accuracy(retval < 1)

        return self._shrink_max(retval)

    @cache_on_self
    def max_pi(self, i):
        """ maximum value of `pi[i]`.
        `pi[i]` is defined right below Theorem 5 of paper.
        """
        assert isinstance(i, int) and 0 <= i < self.iter_count
        # `pi[i] = 1 - (1 - n[i]) * prod`
        # where `prod` is the product of,
        # for `j` in `0 <= j < i`, `(1 - n[j]) / (1 + d[j])`
        min_prod = Fraction(1)
        for j in range(i):
            max_n_j = self.max_n(j)
            max_d_j = self.max_d(j)
            assert max_n_j <= 1 and max_d_j > -1, \
                "only one quadrant of interval division implemented"
            min_prod *= (1 - max_n_j) / (1 + max_d_j)
        max_n_i = self.max_n(i)
        assert max_n_i <= 1 and min_prod >= 0, \
            "only one quadrant of interval multiplication implemented"
        retval = 1 - (1 - max_n_i) * min_prod
        return self._shrink_max(retval)

    @cached_property
    def max_n_shift(self):
        """ maximum value of `state.n_shift`.
        """
        # input numerator is `2*io_width`-bits
        max_n = (1 << (self.io_width * 2)) - 1
        # normalizing so 1 <= n < 2 shifts right by floor(log2(max_n))
        return max_n.bit_length() - 1

    @cached_property
    def n_hat(self):
        """ maximum value of, for all `i`, `max_n(i)` and `max_d(i)`
        """
        n_hat = Fraction(0)
        for i in range(self.iter_count):
            n_hat = max(n_hat, self.max_n(i), self.max_d(i))
        return self._shrink_max(n_hat)

    def __make_ops(self):
        """ Goldschmidt division algorithm.

            based on:
            Even, G., Seidel, P. M., & Ferguson, W. E. (2003).
            A Parametric Error Analysis of Goldschmidt's Division Algorithm.
            https://citeseerx.ist.psu.edu/viewdoc/download?doi=10.1.1.90.1238&rep=rep1&type=pdf

            yields: GoldschmidtDivOp
                the operations needed to perform the division.

            raises: ParamsNotAccurateEnough
                if any of the paper's assumptions can't be established.
        """
        # establish assumptions of the paper's error analysis (section 3.1):

        # 1. normalize so A (numerator) and B (denominator) are in [1, 2)
        yield GoldschmidtDivOp.Normalize

        # 2. ensure all relative errors from directed rounding are <= 1 / 4.
        # the assumption is met by multipliers with > 4-bits precision
        _assert_accuracy(self.expanded_width > 4)

        # 3. require `abs(e[0]) + 3 * d[0] / 2 + f[0] < 1 / 2`.
        _assert_accuracy(self.max_abs_e0 + 3 * self.max_d(0) / 2
                         + self.max_f(0) < Fraction(1, 2))

        # 4. the initial approximation F'[-1] of 1/B is in [1/2, 1].
        # (B is the denominator)

        for addr in range(self.table_addr_count):
            f_prime_m1 = self.table[addr]
            _assert_accuracy(0.5 <= f_prime_m1 <= 1)

        yield GoldschmidtDivOp.FEqTableLookup

        # we use Setting I (section 4.1 of the paper):
        # Require `n[i] <= n_hat` and `d[i] <= n_hat` and `f[i] = 0`:
        # the conditions on n_hat are satisfied by construction.
        for i in range(self.iter_count):
            _assert_accuracy(self.max_f(i) == 0)
            yield GoldschmidtDivOp.MulNByF
            # the last iteration only needs the final numerator
            if i != self.iter_count - 1:
                yield GoldschmidtDivOp.MulDByF
                yield GoldschmidtDivOp.FEq2MinusD

        # relative approximation error `p(N_prime[i])`:
        # `p(N_prime[i]) = (A / B - N_prime[i]) / (A / B)`
        # `0 <= p(N_prime[i])`
        # `p(N_prime[i]) <= (2 * i) * n_hat \`
        # `        + (abs(e[0]) + 3 * n_hat / 2) ** (2 ** i)`
        i = self.iter_count - 1  # last used `i`
        # compute power manually to prevent huge intermediate values
        power = self._shrink_max(self.max_abs_e0 + 3 * self.n_hat / 2)
        for _ in range(i):
            power = self._shrink_max(power * power)

        max_rel_error = (2 * i) * self.n_hat + power

        min_a_over_b = Fraction(1, 2)
        max_a_over_b = Fraction(2)
        max_allowed_abs_error = max_a_over_b / (1 << self.max_n_shift)
        max_allowed_rel_error = max_allowed_abs_error / min_a_over_b

        _assert_accuracy(max_rel_error < max_allowed_rel_error,
                         f"not accurate enough: max_rel_error={max_rel_error}"
                         f" max_allowed_rel_error={max_allowed_rel_error}")

        yield GoldschmidtDivOp.CalcResult

    @cache_on_self
    def default_cost_fn(self):
        """ calculate the estimated cost on an arbitrary scale of implementing
        goldschmidt division with the specified parameters. larger cost
        values mean worse parameters.

        This is the default cost function for `GoldschmidtDivParams.get`.

        The cost is the ROM size in bits, plus a per-multiply cost that grows
        like `expanded_width ** 2 * log2(expanded_width)`, plus a large
        per-iteration penalty.

        returns: float
        """
        rom_cells = self.table_data_bits << self.table_addr_bits
        cost = float(rom_cells)
        for op in self.ops:
            if op == GoldschmidtDivOp.MulNByF \
                    or op == GoldschmidtDivOp.MulDByF:
                mul_cost = self.expanded_width ** 2
                mul_cost *= self.expanded_width.bit_length()
                cost += mul_cost
        cost += 5e7 * self.iter_count
        return cost

    @staticmethod
    @lru_cache(maxsize=1 << 16)
    def __cached_new(base_params):
        """build `GoldschmidtDivParams` from `base_params`, memoized globally.

        returns: `(params, None)` on success or `(None, error)` on failure.
        The error is returned rather than raised so that failures are cached
        too.
        """
        assert isinstance(base_params, GoldschmidtDivParamsBase)
        # can't use dataclasses.asdict, since it's recursive and will also give
        # child class fields too, which we don't want.
        kwargs = {base_field.name: getattr(base_params, base_field.name)
                  for base_field in fields(GoldschmidtDivParamsBase)}
        try:
            return GoldschmidtDivParams(**kwargs), None
        except ParamsNotAccurateEnough as e:
            return None, e

    @staticmethod
    def __raise(e):  # type: (ParamsNotAccurateEnough) -> Any
        raise e

    @staticmethod
    def cached_new(base_params, handle_error=None):
        """return the (globally cached) `GoldschmidtDivParams` for
        `base_params`.

        arguments:
            base_params: GoldschmidtDivParamsBase
            handle_error: Callable[[ParamsNotAccurateEnough], Any] | None
                called with the error when the parameters aren't accurate
                enough; its return value is returned. `None` (the default)
                re-raises the `ParamsNotAccurateEnough` error.
        """
        assert isinstance(base_params, GoldschmidtDivParamsBase)
        if handle_error is None:
            # resolved through the class so we get the plain function: a
            # bare `staticmethod` object (as seen in the class body, e.g. in
            # a default argument) isn't callable before Python 3.10.
            handle_error = GoldschmidtDivParams.__raise
        params, error = GoldschmidtDivParams.__cached_new(base_params)
        if error is None:
            return params
        else:
            return handle_error(error)

    @staticmethod
    def get(io_width, cost_fn=default_cost_fn, max_table_addr_bits=12):
        """ find efficient parameters for a goldschmidt division algorithm
        with `params.io_width == io_width`.

        arguments:
        io_width: int
            bit-width of the input divisor and the result.
            the input numerator is `2 * io_width`-bits wide.
        cost_fn: Callable[[GoldschmidtDivParams], float]
            return the estimated cost on an arbitrary scale of implementing
            goldschmidt division with the specified parameters. larger cost
            values mean worse parameters.
        max_table_addr_bits: int
            maximum allowable value of `table_addr_bits`

        returns: GoldschmidtDivParams

        raises: ValueError
            if even the deliberately generous initial parameters are
            rejected.
        """
        assert isinstance(io_width, int) and io_width >= 1
        assert callable(cost_fn)

        last_error = None
        last_error_params = None

        def cached_new(base_params):
            # like `GoldschmidtDivParams.cached_new`, but returns `None` on
            # failure and remembers the most recent error for reporting.
            def handle_error(e):
                nonlocal last_error, last_error_params
                last_error = e
                last_error_params = base_params
                return None

            retval = GoldschmidtDivParams.cached_new(base_params, handle_error)
            if retval is None:
                logging.debug("GoldschmidtDivParams.get: err: %s", base_params)
            else:
                logging.debug("GoldschmidtDivParams.get: ok: %s", base_params)
            return retval

        @lru_cache(maxsize=None)
        def get_cost(base_params):
            # invalid parameters get infinite cost so they never win
            params = cached_new(base_params)
            if params is None:
                return math.inf
            retval = cost_fn(params)
            logging.debug("GoldschmidtDivParams.get: cost=%s: %s",
                          retval, params)
            return retval

        # start with parameters big enough to always work.
        initial_extra_precision = io_width * 2 + 4
        initial_params = GoldschmidtDivParamsBase(
            io_width=io_width,
            extra_precision=initial_extra_precision,
            table_addr_bits=min(max_table_addr_bits, io_width),
            table_data_bits=io_width + initial_extra_precision,
            iter_count=1 + io_width.bit_length())

        if cached_new(initial_params) is None:
            raise ValueError(f"initial goldschmidt division algorithm "
                             f"parameters are invalid: {initial_params}"
                             ) from last_error

        # find good initial `iter_count`
        params = initial_params
        for iter_count in range(1, initial_params.iter_count):
            trial_params = replace(params, iter_count=iter_count)
            if cached_new(trial_params) is not None:
                params = trial_params
                break

        # now find `table_addr_bits`
        cost = get_cost(params)
        for table_addr_bits in range(1, max_table_addr_bits + 1):
            trial_params = replace(params, table_addr_bits=table_addr_bits)
            trial_cost = get_cost(trial_params)
            if trial_cost < cost:
                params = trial_params
                cost = trial_cost
                break

        # check one higher `iter_count` to see if it has lower cost
        for table_addr_bits in range(1, max_table_addr_bits + 1):
            trial_params = replace(params,
                                   table_addr_bits=table_addr_bits,
                                   iter_count=params.iter_count + 1)
            trial_cost = get_cost(trial_params)
            if trial_cost < cost:
                params = trial_params
                cost = trial_cost
                break

        # now shrink `table_data_bits`
        # (terminates: going below 1 bit is rejected, giving infinite cost)
        while True:
            trial_params = replace(params,
                                   table_data_bits=params.table_data_bits - 1)
            trial_cost = get_cost(trial_params)
            if trial_cost < cost:
                params = trial_params
                cost = trial_cost
            else:
                break

        # and shrink `extra_precision`
        while True:
            trial_params = replace(params,
                                   extra_precision=params.extra_precision - 1)
            trial_cost = get_cost(trial_params)
            if trial_cost < cost:
                params = trial_params
                cost = trial_cost
            else:
                break

        return cached_new(params)
950
951
@enum.unique
class GoldschmidtDivOp(enum.Enum):
    """one step of the goldschmidt division algorithm.

    Each member's value is a pseudo-code description of the step. `run`
    applies the step to a `GoldschmidtDivState` in place.
    """

    Normalize = "n, d, n_shift = normalize(n, d)"
    FEqTableLookup = "f = table_lookup(d)"
    MulNByF = "n *= f"
    MulDByF = "d *= f"
    FEq2MinusD = "f = 2 - d"
    CalcResult = "result = unnormalize_and_round(n)"

    def run(self, params, state):
        """apply this step to `state`, mutating it in place.

        arguments:
            params: GoldschmidtDivParams
            state: GoldschmidtDivState

        raises:
            ZeroDivisionError: `Normalize` with a zero divisor.
            ValueError: `Normalize` with a negative divisor.
        """
        assert isinstance(params, GoldschmidtDivParams)
        assert isinstance(state, GoldschmidtDivState)
        expanded_width = params.expanded_width
        table_addr_bits = params.table_addr_bits
        if self == GoldschmidtDivOp.Normalize:
            # doubling a zero or negative `d` never reaches `d >= 1`, so the
            # normalization loop below would never terminate
            if state.d == 0:
                raise ZeroDivisionError("goldschmidt division by zero")
            if state.d < 0:
                raise ValueError(
                    "goldschmidt division requires a positive divisor")
            # normalize so 1 <= d < 2
            # can easily be done with count-leading-zeros and left shift
            # (n is scaled by the same factor, so n / d is unchanged)
            while state.d < 1:
                state.n = (state.n * 2).to_frac_wid(expanded_width)
                state.d = (state.d * 2).to_frac_wid(expanded_width)

            state.n_shift = 0
            # normalize so 1 <= n < 2
            # (undone in `CalcResult` by shifting left by `n_shift`)
            while state.n >= 2:
                state.n = (state.n * 0.5).to_frac_wid(expanded_width)
                state.n_shift += 1
        elif self == GoldschmidtDivOp.FEqTableLookup:
            # compute initial f by table lookup, addressed by the top
            # `table_addr_bits` fractional bits of d (1 <= d < 2)
            d_m_1 = state.d - 1
            d_m_1 = d_m_1.to_frac_wid(table_addr_bits, RoundDir.DOWN)
            assert 0 <= d_m_1.bits < (1 << params.table_addr_bits)
            state.f = params.table[d_m_1.bits]
        elif self == GoldschmidtDivOp.MulNByF:
            assert state.f is not None
            n = state.n * state.f
            state.n = n.to_frac_wid(expanded_width, round_dir=RoundDir.DOWN)
        elif self == GoldschmidtDivOp.MulDByF:
            assert state.f is not None
            d = state.d * state.f
            state.d = d.to_frac_wid(expanded_width, round_dir=RoundDir.UP)
        elif self == GoldschmidtDivOp.FEq2MinusD:
            state.f = (2 - state.d).to_frac_wid(expanded_width)
        elif self == GoldschmidtDivOp.CalcResult:
            assert state.n_shift is not None
            # scale to correct value
            n = state.n * (1 << state.n_shift)

            state.quotient = math.floor(n)
            state.remainder = state.orig_n - state.quotient * state.orig_d
            # a single correction step; presumably sufficient because the
            # accuracy checks done when building the params bound the error
            if state.remainder >= state.orig_d:
                state.quotient += 1
                state.remainder -= state.orig_d
        else:
            assert False, f"unimplemented GoldschmidtDivOp: {self}"
1006
1007
def goldschmidt_div(n, d, params):
    """Goldschmidt division algorithm.

    based on:
    Even, G., Seidel, P. M., & Ferguson, W. E. (2003).
    A Parametric Error Analysis of Goldschmidt's Division Algorithm.
    https://citeseerx.ist.psu.edu/viewdoc/download?doi=10.1.1.90.1238&rep=rep1&type=pdf

    the division is carried out by running each step in `params.ops`, in
    order, on a shared `GoldschmidtDivState`.

    arguments:
    n: int
        numerator. a `2 * params.io_width`-bit unsigned integer. must be
        less than `d << params.io_width`, otherwise the quotient wouldn't
        fit in `params.io_width` bits.
    d: int
        denominator. a `params.io_width`-bit unsigned integer. must not be
        zero.
    params: GoldschmidtDivParams
        the algorithm parameters, including the bit-width `io_width` of the
        inputs/outputs and the sequence of steps `ops` to run.

    returns: tuple[int, int]
        the quotient and remainder. a tuple of two `params.io_width`-bit
        unsigned integers.
    """
    assert isinstance(params, GoldschmidtDivParams)
    io_width = params.io_width
    assert isinstance(d, int) and 0 < d < (1 << io_width)
    assert isinstance(n, int) and 0 <= n < (d << io_width)

    # the algorithm works on fixed-point values; n and d start out with
    # `io_width` fraction bits.
    state = GoldschmidtDivState(
        orig_n=n,
        orig_d=d,
        n=FixedPoint(n, io_width),
        d=FixedPoint(d, io_width),
    )

    for step in params.ops:
        step.run(params, state)

    quotient, remainder = state.quotient, state.remainder
    assert quotient is not None
    assert remainder is not None
    return quotient, remainder
1051
1052
# number of integer bits in the sqrt/rsqrt look-up table address.
# `goldschmidt_sqrt_rsqrt` normalizes the radicand to `0` or `1 <= s < 4`
# before the look-up, so two integer bits are needed; the remaining
# `table_addr_bits - GOLDSCHMIDT_SQRT_RSQRT_TABLE_ADDR_INT_WID` address bits
# are fraction bits.
GOLDSCHMIDT_SQRT_RSQRT_TABLE_ADDR_INT_WID = 2
1054
1055
@lru_cache()
def goldschmidt_sqrt_rsqrt_table(table_addr_bits, table_data_bits):
    """Generate the look-up table needed for Goldschmidt's square-root and
    reciprocal-square-root algorithm.

    the table is addressed by the normalized radicand `s` truncated to
    `table_addr_bits` bits, of which
    `GOLDSCHMIDT_SQRT_RSQRT_TABLE_ADDR_INT_WID` are integer bits and the
    rest are fraction bits. each valid entry is the reciprocal-square-root
    of the exclusive upper end of the input range that maps to its address,
    rounded down. since rsqrt is decreasing, the entry never exceeds
    `rsqrt(s)` for any `s` with that address.

    arguments:
        table_addr_bits: int
            the number of address bits for the look-up table. must be at
            least `GOLDSCHMIDT_SQRT_RSQRT_TABLE_ADDR_INT_WID`.
        table_data_bits: int
            the number of data bits (fraction bits of each entry) for the
            look-up table.

    returns: tuple[FixedPoint | None, ...]
        `1 << table_addr_bits` entries:
        * entry `0` is zero.
        * entries whose address corresponds to an input below `1` are
          `None`; they should never be accessed.
        * every other entry is a `FixedPoint` with `table_data_bits`
          fraction bits.
    """
    assert isinstance(table_addr_bits, int) and \
        table_addr_bits >= GOLDSCHMIDT_SQRT_RSQRT_TABLE_ADDR_INT_WID
    assert isinstance(table_data_bits, int) and table_data_bits >= 1
    # loop-invariant: the number of fraction bits in a table address
    table_addr_frac_wid = table_addr_bits
    table_addr_frac_wid -= GOLDSCHMIDT_SQRT_RSQRT_TABLE_ADDR_INT_WID
    # addresses below this correspond to inputs `s < 1`
    min_valid_addr = 1 << table_addr_frac_wid
    table = []
    for addr in range(1 << table_addr_bits):
        if addr == 0:
            value = FixedPoint(0, table_data_bits)
        elif addr < min_valid_addr:
            value = None  # table entries should be unused
        else:
            # exclusive upper end of the inputs mapping to `addr`
            max_input_value = FixedPoint(addr + 1, table_addr_frac_wid)
            max_frac_wid = max(max_input_value.frac_wid, table_data_bits)
            value = max_input_value.to_frac_wid(max_frac_wid)
            value = value.rsqrt(RoundDir.DOWN)
            value = value.to_frac_wid(table_data_bits, RoundDir.DOWN)
        table.append(value)

    # tuple for immutability
    return tuple(table)
1089
1090
def goldschmidt_sqrt_rsqrt(radicand, io_width, frac_wid, extra_precision,
                           table_addr_bits, table_data_bits, iter_count):
    """Goldschmidt's square-root and reciprocal-square-root algorithm.

    uses algorithm based on second method at:
    https://en.wikipedia.org/wiki/Methods_of_computing_square_roots#Goldschmidt%E2%80%99s_algorithm

    arguments:
        radicand: FixedPoint(frac_wid=frac_wid)
            the input value to take the square-root and reciprocal-square-root
            of. `radicand.bits` must fit in `io_width` unsigned bits.
        io_width: int
            the number of bits in the input (`radicand`) and output values.
        frac_wid: int
            the number of fraction bits in the input (`radicand`) and output
            values. must be less than `io_width`.
        extra_precision: int
            the number of bits of internal extra precision. must be at least
            `io_width`.
        table_addr_bits: int
            the number of address bits for the look-up table. must be at
            least `GOLDSCHMIDT_SQRT_RSQRT_TABLE_ADDR_INT_WID`.
        table_data_bits: int
            the number of data bits for the look-up table.
        iter_count: int
            the number of Goldschmidt iterations to run after the initial
            table look-up. may be zero.

    returns: tuple[FixedPoint, FixedPoint]
        the square-root and reciprocal-square-root, both with `frac_wid`
        fraction bits, intended to be rounded down to the nearest
        representable value. the final fix-up only corrects an error of one
        unit in the last place, so exactness depends on the parameters
        providing enough precision. If `radicand == 0`, then the
        reciprocal-square-root value returned is zero. the
        reciprocal-square-root is not range-checked against `io_width`.
    """
    # validate the widths first since the `radicand` check depends on them
    assert isinstance(io_width, int) and io_width >= 1
    assert isinstance(frac_wid, int) and 0 <= frac_wid < io_width
    assert (isinstance(radicand, FixedPoint)
            and radicand.frac_wid == frac_wid
            and 0 <= radicand.bits < (1 << io_width))
    assert isinstance(extra_precision, int) and extra_precision >= io_width
    # same lower bound as asserted by `goldschmidt_sqrt_rsqrt_table`, so an
    # invalid value is reported here rather than inside the table builder
    assert isinstance(table_addr_bits, int) and \
        table_addr_bits >= GOLDSCHMIDT_SQRT_RSQRT_TABLE_ADDR_INT_WID
    assert isinstance(table_data_bits, int) and table_data_bits >= 1
    assert isinstance(iter_count, int) and iter_count >= 0
    expanded_frac_wid = frac_wid + extra_precision
    s = radicand.to_frac_wid(expanded_frac_wid)
    # normalize `s` by powers of 4 so that `s == 0` or `1 <= s < 4`.
    # both right-shifts start at `extra_precision`, which converts from
    # `expanded_frac_wid` back to `frac_wid` fraction bits. each
    # multiplication of `s` by 4 doubles the computed square-root and halves
    # the computed reciprocal-square-root, so `sqrt_rshift` grows by one and
    # `rsqrt_rshift` shrinks by one to compensate (and vice versa for each
    # division by 4).
    sqrt_rshift = extra_precision
    rsqrt_rshift = extra_precision
    while s != 0 and s < 1:
        s = (s * 4).to_frac_wid(expanded_frac_wid)
        sqrt_rshift += 1
        rsqrt_rshift -= 1
    while s >= 4:
        s = s.div(4, expanded_frac_wid)
        sqrt_rshift -= 1
        rsqrt_rshift += 1
    table = goldschmidt_sqrt_rsqrt_table(table_addr_bits=table_addr_bits,
                                         table_data_bits=table_data_bits)
    # core goldschmidt sqrt/rsqrt algorithm:
    # initial setup:
    table_addr_frac_wid = table_addr_bits
    table_addr_frac_wid -= GOLDSCHMIDT_SQRT_RSQRT_TABLE_ADDR_INT_WID
    addr = s.to_frac_wid(table_addr_frac_wid, RoundDir.DOWN)
    assert 0 <= addr.bits < (1 << table_addr_bits), "table addr out of range"
    f = table[addr.bits]
    assert f is not None, "accessed invalid table entry"
    # use with_frac_wid to fix IDE type deduction
    f = FixedPoint.with_frac_wid(f, expanded_frac_wid, RoundDir.DOWN)
    # `f` approximates rsqrt(s), so `x` approximates sqrt(s) and `h`
    # approximates rsqrt(s) / 2
    x = (s * f).to_frac_wid(expanded_frac_wid, RoundDir.DOWN)
    h = (f * 0.5).to_frac_wid(expanded_frac_wid, RoundDir.DOWN)
    for _ in range(iter_count):
        # iteration step:
        f = (1.5 - x * h).to_frac_wid(expanded_frac_wid, RoundDir.DOWN)
        x = (x * f).to_frac_wid(expanded_frac_wid, RoundDir.DOWN)
        h = (h * f).to_frac_wid(expanded_frac_wid, RoundDir.DOWN)
    r = 2 * h
    # now `x` is approximately `sqrt(s)` and `r` is approximately `rsqrt(s)`

    # undo the normalization and drop the extra precision (truncating)
    sqrt = FixedPoint(x.bits >> sqrt_rshift, frac_wid)
    rsqrt = FixedPoint(r.bits >> rsqrt_rshift, frac_wid)

    # the table entries, iteration steps and shifts above all round down, so
    # bump each result up by one unit in the last place if that still
    # doesn't overshoot the exact value.
    next_sqrt = FixedPoint(sqrt.bits + 1, frac_wid)
    if next_sqrt * next_sqrt <= radicand:
        sqrt = next_sqrt

    next_rsqrt = FixedPoint(rsqrt.bits + 1, frac_wid)
    if next_rsqrt * next_rsqrt * radicand <= 1 and radicand != 0:
        rsqrt = next_rsqrt
    return sqrt, rsqrt