c3837e9afc2f661ffc864d088d3ead1a3b0cd2ab
[soc.git] / src / soc / fu / div / experiment / goldschmidt_div_sqrt.py
1 # SPDX-License-Identifier: LGPL-3-or-later
2 # Copyright 2022 Jacob Lifshay programmerjake@gmail.com
3
4 # Funded by NLnet Assure Programme 2021-02-052, https://nlnet.nl/assure part
5 # of Horizon 2020 EU Programme 957073.
6
7 from dataclasses import dataclass, field, fields, replace
8 import logging
9 import math
10 import enum
11 from fractions import Fraction
12 from types import FunctionType
13 from functools import lru_cache
14
15 try:
16 from functools import cached_property
17 except ImportError:
18 from cached_property import cached_property
19
20 # fix broken IDE type detection for cached_property
21 from typing import TYPE_CHECKING, Any
22 if TYPE_CHECKING:
23 from functools import cached_property
24
25
# private sentinel: distinguishes "nothing cached yet" from a cached `None`
_NOT_FOUND = object()


def cache_on_self(func):
    """like `functools.cached_property`, except for methods. unlike
    `lru_cache` the cache is per-class instance rather than a global cache
    per-method.

    The cache is a dict stored in the instance's `__dict__` under the name
    `<method name>__cache`, keyed on the call's arguments, so all arguments
    must be hashable.

    arguments:
    func: FunctionType
        the plain function to wrap. must take `self` as its first argument.
    returns:
        the caching wrapper, carrying `func`'s name, docstring and other
        metadata.
    """
    # imported here so this function doesn't depend on the file's import list
    from functools import wraps

    assert isinstance(func, FunctionType), \
        "non-plain methods are not supported"

    cache_name = func.__name__ + "__cache"

    @wraps(func)
    def wrapper(self, *args, **kwargs):
        # specifically access through `__dict__` to bypass frozen=True
        cache = self.__dict__.get(cache_name, _NOT_FOUND)
        if cache is _NOT_FOUND:
            self.__dict__[cache_name] = cache = {}
        # sort keyword arguments by name so that the order they were passed
        # in doesn't produce distinct cache entries for the same call.
        # names are unique, so the values never get compared while sorting.
        key = (args, *sorted(kwargs.items()))
        retval = cache.get(key, _NOT_FOUND)
        if retval is _NOT_FOUND:
            retval = func(self, *args, **kwargs)
            cache[key] = retval
        return retval

    return wrapper
53
54
@enum.unique
class RoundDir(enum.Enum):
    """how to round a value that can't be represented exactly."""

    # round towards negative infinity
    DOWN = 1
    # round towards positive infinity
    UP = 2
    # round to the nearest representable value, exact ties round up
    NEAREST_TIES_UP = 3
    # don't round at all: an inexact result is an error
    ERROR_IF_INEXACT = 4
61
62
@dataclass(frozen=True)
class FixedPoint:
    """an immutable fixed-point number with the value
    `bits / (1 << frac_wid)`.

    Comparisons (`==`, `<`, ...) compare the numeric value, so two instances
    with different `frac_wid` can be equal. `__hash__` is defined to match.
    """

    bits: int
    frac_wid: int

    def __post_init__(self):
        assert isinstance(self.bits, int)
        assert isinstance(self.frac_wid, int) and self.frac_wid >= 0

    @staticmethod
    def cast(value):
        """convert `value` to a fixed-point number with enough fractional
        bits to preserve its value.

        `value` can be a `FixedPoint` (returned as-is), an `int`, a `float`,
        or a `str`. Strings are either hexadecimal with a `.` (like
        `"-0x1.8"`, `_` is ignored) or anything `int(value, base=0)` accepts.

        raises `ValueError` for malformed strings and `TypeError` for
        unsupported types.
        """
        if isinstance(value, FixedPoint):
            return value
        if isinstance(value, int):
            return FixedPoint(value, 0)
        if isinstance(value, str):
            value = value.strip()
            neg = value.startswith("-")
            if neg or value.startswith("+"):
                value = value[1:]
            if value.startswith(("0x", "0X")) and "." in value:
                value = value[2:]
                got_dot = False
                bits = 0
                frac_wid = 0
                for digit in value:
                    if digit == "_":
                        continue
                    if got_dot:
                        if digit == ".":
                            raise ValueError("too many `.` in string")
                        # each hex digit after the `.` is 4 fractional bits
                        frac_wid += 4
                    if digit == ".":
                        got_dot = True
                        continue
                    if not digit.isalnum():
                        raise ValueError("invalid hexadecimal digit")
                    bits <<= 4
                    bits |= int("0x" + digit, base=16)
            else:
                bits = int(value, base=0)
                frac_wid = 0
            if neg:
                bits = -bits
            return FixedPoint(bits, frac_wid)

        if isinstance(value, float):
            n, d = value.as_integer_ratio()
            log2_d = d.bit_length() - 1
            assert d == 1 << log2_d, ("d isn't a power of 2 -- won't ever "
                                      "fail with float being IEEE 754")
            return FixedPoint(n, log2_d)
        raise TypeError("can't convert type to FixedPoint")

    @staticmethod
    def with_frac_wid(value, frac_wid, round_dir=RoundDir.ERROR_IF_INEXACT):
        """convert `value` to the nearest fixed-point number with `frac_wid`
        fractional bits, rounding according to `round_dir`.

        `value` can be a `Fraction` or anything `FixedPoint.cast` accepts.
        raises `ValueError` if the conversion is inexact and `round_dir` is
        `RoundDir.ERROR_IF_INEXACT`.
        """
        assert isinstance(frac_wid, int) and frac_wid >= 0
        assert isinstance(round_dir, RoundDir)
        if isinstance(value, Fraction):
            numerator = value.numerator
            denominator = value.denominator
        else:
            value = FixedPoint.cast(value)
            numerator = value.bits
            denominator = 1 << value.frac_wid
        if denominator < 0:
            numerator = -numerator
            denominator = -denominator
        # divmod rounds towards negative infinity, so `bits` starts out
        # rounded down and `0 <= remainder < denominator`
        bits, remainder = divmod(numerator << frac_wid, denominator)
        if round_dir == RoundDir.DOWN:
            pass
        elif round_dir == RoundDir.UP:
            if remainder != 0:
                bits += 1
        elif round_dir == RoundDir.NEAREST_TIES_UP:
            if remainder * 2 >= denominator:
                bits += 1
        elif round_dir == RoundDir.ERROR_IF_INEXACT:
            if remainder != 0:
                raise ValueError("inexact conversion")
        else:
            assert False, "unimplemented round_dir"
        return FixedPoint(bits, frac_wid)

    def to_frac_wid(self, frac_wid, round_dir=RoundDir.ERROR_IF_INEXACT):
        """convert to the nearest fixed-point number with `frac_wid`
        fractional bits, rounding according to `round_dir`."""
        return FixedPoint.with_frac_wid(self, frac_wid, round_dir)

    def __float__(self):
        # use truediv to get correct result even when bits
        # and frac_wid are huge
        return float(self.bits / (1 << self.frac_wid))

    def as_fraction(self):
        """return the exact value of `self` as a `Fraction`."""
        return Fraction(self.bits, 1 << self.frac_wid)

    def cmp(self, rhs):
        """compare self with rhs, returning a positive integer if self is
        greater than rhs, zero if self is equal to rhs, and a negative integer
        if self is less than rhs."""
        rhs = FixedPoint.cast(rhs)
        # widening to the larger `frac_wid` is always exact
        common_frac_wid = max(self.frac_wid, rhs.frac_wid)
        lhs = self.to_frac_wid(common_frac_wid)
        rhs = rhs.to_frac_wid(common_frac_wid)
        return lhs.bits - rhs.bits

    def __eq__(self, rhs):
        return self.cmp(rhs) == 0

    def __ne__(self, rhs):
        return self.cmp(rhs) != 0

    def __hash__(self):
        # `__eq__` compares by numeric value (ignoring `frac_wid`), so the
        # hash has to depend only on the numeric value too, otherwise equal
        # objects would misbehave in sets and as dict keys. hashing the
        # `Fraction` also matches the hash of an equal `int` or `float`.
        return hash(self.as_fraction())

    def __gt__(self, rhs):
        return self.cmp(rhs) > 0

    def __lt__(self, rhs):
        return self.cmp(rhs) < 0

    def __ge__(self, rhs):
        return self.cmp(rhs) >= 0

    def __le__(self, rhs):
        return self.cmp(rhs) <= 0

    def fract(self):
        """return the fractional part of `self`.
        that is `self - math.floor(self)`.
        """
        fract_mask = (1 << self.frac_wid) - 1
        return FixedPoint(self.bits & fract_mask, self.frac_wid)

    def __str__(self):
        """format as hexadecimal with a `.`, like `0x1.8` or `-0x1.8`."""
        if self < 0:
            return "-" + str(-self)
        digit_bits = 4
        # round the fractional width up to a whole number of hex digits
        frac_digit_count = (self.frac_wid + digit_bits - 1) // digit_bits
        fract = self.fract().to_frac_wid(frac_digit_count * digit_bits)
        frac_str = hex(fract.bits)[2:].zfill(frac_digit_count)
        return hex(math.floor(self)) + "." + frac_str

    def __repr__(self):
        return f"FixedPoint.with_frac_wid({str(self)!r}, {self.frac_wid})"

    def __add__(self, rhs):
        rhs = FixedPoint.cast(rhs)
        common_frac_wid = max(self.frac_wid, rhs.frac_wid)
        lhs = self.to_frac_wid(common_frac_wid)
        rhs = rhs.to_frac_wid(common_frac_wid)
        return FixedPoint(lhs.bits + rhs.bits, common_frac_wid)

    def __radd__(self, lhs):
        # symmetric
        return self.__add__(lhs)

    def __neg__(self):
        return FixedPoint(-self.bits, self.frac_wid)

    def __sub__(self, rhs):
        rhs = FixedPoint.cast(rhs)
        common_frac_wid = max(self.frac_wid, rhs.frac_wid)
        lhs = self.to_frac_wid(common_frac_wid)
        rhs = rhs.to_frac_wid(common_frac_wid)
        return FixedPoint(lhs.bits - rhs.bits, common_frac_wid)

    def __rsub__(self, lhs):
        # a - b == -(b - a)
        return -self.__sub__(lhs)

    def __mul__(self, rhs):
        """exact product: the result has `frac_wid` equal to the sum of the
        operands' `frac_wid`."""
        rhs = FixedPoint.cast(rhs)
        return FixedPoint(self.bits * rhs.bits, self.frac_wid + rhs.frac_wid)

    def __rmul__(self, lhs):
        # symmetric
        return self.__mul__(lhs)

    def __floor__(self):
        return self.bits >> self.frac_wid

    def div(self, rhs, frac_wid, round_dir=RoundDir.ERROR_IF_INEXACT):
        """compute `self / rhs` with `frac_wid` fractional bits, rounding
        according to `round_dir`."""
        assert isinstance(frac_wid, int) and frac_wid >= 0
        assert isinstance(round_dir, RoundDir)
        rhs = FixedPoint.cast(rhs)
        return FixedPoint.with_frac_wid(self.as_fraction()
                                        / rhs.as_fraction(),
                                        frac_wid, round_dir)

    def sqrt(self, round_dir=RoundDir.ERROR_IF_INEXACT):
        """compute the square-root of `self`, with the same `frac_wid` as
        `self`, rounding according to `round_dir`.

        raises `ValueError` if `self` is negative, or if the result is
        inexact and `round_dir` is `RoundDir.ERROR_IF_INEXACT`.
        """
        assert isinstance(round_dir, RoundDir)
        if self < 0:
            raise ValueError("can't compute sqrt of negative number")
        if self == 0:
            return self
        retval = FixedPoint(0, self.frac_wid)
        int_part_wid = self.bits.bit_length() - self.frac_wid
        first_bit_index = -(-int_part_wid // 2)  # division rounds up
        last_bit_index = -self.frac_wid
        # build the rounded-down result one bit at a time, from the most
        # significant bit to the least, keeping each bit only if the
        # square doesn't exceed `self`
        for bit_index in range(first_bit_index, last_bit_index - 1, -1):
            trial = retval + FixedPoint(1 << (bit_index + self.frac_wid),
                                        self.frac_wid)
            if trial * trial <= self:
                retval = trial
        if round_dir == RoundDir.DOWN:
            pass
        elif round_dir == RoundDir.UP:
            if retval * retval < self:
                retval += FixedPoint(1, self.frac_wid)
        elif round_dir == RoundDir.NEAREST_TIES_UP:
            # half of the least-significant bit above the rounded-down value
            half_way = retval + FixedPoint(1, self.frac_wid + 1)
            if half_way * half_way <= self:
                retval += FixedPoint(1, self.frac_wid)
        elif round_dir == RoundDir.ERROR_IF_INEXACT:
            if retval * retval != self:
                raise ValueError("inexact sqrt")
        else:
            assert False, "unimplemented round_dir"
        return retval

    def rsqrt(self, round_dir=RoundDir.ERROR_IF_INEXACT):
        """compute the reciprocal-sqrt of `self`, with the same `frac_wid`
        as `self`, rounding according to `round_dir`.

        raises `ValueError` if `self` is negative, `ZeroDivisionError` if
        `self` is zero, and `ValueError` if the result is inexact and
        `round_dir` is `RoundDir.ERROR_IF_INEXACT`.
        """
        assert isinstance(round_dir, RoundDir)
        if self < 0:
            raise ValueError("can't compute rsqrt of negative number")
        if self == 0:
            raise ZeroDivisionError("can't compute rsqrt of zero")
        retval = FixedPoint(0, self.frac_wid)
        first_bit_index = -(-self.frac_wid // 2)  # division rounds up
        last_bit_index = -self.frac_wid
        # same bit-at-a-time search as `sqrt`, except the condition is
        # `retval ** 2 * self <= 1`
        for bit_index in range(first_bit_index, last_bit_index - 1, -1):
            trial = retval + FixedPoint(1 << (bit_index + self.frac_wid),
                                        self.frac_wid)
            if trial * trial * self <= 1:
                retval = trial
        if round_dir == RoundDir.DOWN:
            pass
        elif round_dir == RoundDir.UP:
            if retval * retval * self < 1:
                retval += FixedPoint(1, self.frac_wid)
        elif round_dir == RoundDir.NEAREST_TIES_UP:
            half_way = retval + FixedPoint(1, self.frac_wid + 1)
            if half_way * half_way * self <= 1:
                retval += FixedPoint(1, self.frac_wid)
        elif round_dir == RoundDir.ERROR_IF_INEXACT:
            if retval * retval * self != 1:
                raise ValueError("inexact rsqrt")
        else:
            assert False, "unimplemented round_dir"
        return retval
316
317
class ParamsNotAccurateEnough(Exception):
    """error signalling that a set of parameters doesn't give enough
    accuracy for goldschmidt division to work."""
321
322
323 def _assert_accuracy(condition, msg="not accurate enough"):
324 if condition:
325 return
326 raise ParamsNotAccurateEnough(msg)
327
328
@dataclass(frozen=True, unsafe_hash=True)
class GoldschmidtDivParamsBase:
    """the non-derived parameters of a Goldschmidt division algorithm.

    Instances are immutable and hashable.
    """

    io_width: int
    """width in bits of both the input divisor and the result.
    the input numerator has twice as many bits: `2 * io_width`.
    """

    extra_precision: int
    """how many additional bits of precision the algorithm uses
    internally."""

    table_addr_bits: int
    """how many address bits the lookup-table has."""

    table_data_bits: int
    """how many data bits the lookup-table has."""

    iter_count: int
    """how many iterations in total the division algorithm's loop runs"""
351
352
@dataclass(frozen=True, unsafe_hash=True)
class GoldschmidtDivParams(GoldschmidtDivParamsBase):
    """parameters for a Goldschmidt division algorithm.
    Use `GoldschmidtDivParams.get` to find an efficient set of parameters.

    Construction computes the lookup-table and the list of operations, and
    raises `ParamsNotAccurateEnough` if the parameters can't be shown to be
    accurate enough.
    """

    # tuple to be immutable, repr=False so repr() works for debugging even when
    # __post_init__ hasn't finished running yet
    table: "tuple[FixedPoint, ...]" = field(init=False, repr=False)
    """the lookup-table"""

    ops: "tuple[GoldschmidtDivOp, ...]" = field(init=False, repr=False)
    """the operations needed to perform the goldschmidt division algorithm."""

    def _shrink_bound(self, bound, round_dir):
        """prevent fractions from having huge numerators/denominators by
        rounding to a `FixedPoint` and converting back to a `Fraction`.

        This is intended only for values used to compute bounds, and not for
        values that end up in the hardware.
        """
        assert isinstance(bound, (Fraction, int))
        assert round_dir is RoundDir.DOWN or round_dir is RoundDir.UP, \
            "you shouldn't use that round_dir on bounds"
        frac_wid = self.io_width * 4 + 100  # should be enough precision
        fixed = FixedPoint.with_frac_wid(bound, frac_wid, round_dir)
        return fixed.as_fraction()

    def _shrink_min(self, min_bound):
        """prevent fractions used as minimum bounds from having huge
        numerators/denominators by rounding down to a `FixedPoint` and
        converting back to a `Fraction`.

        This is intended only for values used to compute bounds, and not for
        values that end up in the hardware.
        """
        return self._shrink_bound(min_bound, RoundDir.DOWN)

    def _shrink_max(self, max_bound):
        """prevent fractions used as maximum bounds from having huge
        numerators/denominators by rounding up to a `FixedPoint` and
        converting back to a `Fraction`.

        This is intended only for values used to compute bounds, and not for
        values that end up in the hardware.
        """
        return self._shrink_bound(max_bound, RoundDir.UP)

    @property
    def table_addr_count(self):
        """number of distinct addresses in the lookup-table."""
        # used while computing self.table, so can't just do len(self.table)
        return 1 << self.table_addr_bits

    def table_input_exact_range(self, addr):
        """return the range of inputs as `Fraction`s used for the table entry
        with address `addr`.

        raises `ParamsNotAccurateEnough` if `table_addr_bits` is larger than
        `io_width`.
        """
        assert isinstance(addr, int)
        assert 0 <= addr < self.table_addr_count
        _assert_accuracy(self.io_width >= self.table_addr_bits)
        addr_shift = self.io_width - self.table_addr_bits
        min_numerator = (1 << self.io_width) + (addr << addr_shift)
        denominator = 1 << self.io_width
        values_per_table_entry = 1 << addr_shift
        max_numerator = min_numerator + values_per_table_entry - 1
        min_input = Fraction(min_numerator, denominator)
        max_input = Fraction(max_numerator, denominator)
        min_input = self._shrink_min(min_input)
        max_input = self._shrink_max(max_input)
        assert 1 <= min_input <= max_input < 2
        return min_input, max_input

    def table_value_exact_range(self, addr):
        """return the range of values as `Fraction`s used for the table entry
        with address `addr`."""
        min_input, max_input = self.table_input_exact_range(addr)
        # division swaps min/max
        min_value = 1 / max_input
        max_value = 1 / min_input
        min_value = self._shrink_min(min_value)
        max_value = self._shrink_max(max_value)
        assert 0.5 < min_value <= max_value <= 1
        return min_value, max_value

    def table_exact_value(self, index):
        """return the value as a `Fraction` that the table entry with
        address `index` approximates (before rounding to
        `table_data_bits`)."""
        min_value, _max_value = self.table_value_exact_range(index)
        # we round down
        return min_value

    def __post_init__(self):
        # called by the autogenerated __init__
        _assert_accuracy(self.io_width >= 1, "io_width out of range")
        _assert_accuracy(self.extra_precision >= 0,
                         "extra_precision out of range")
        _assert_accuracy(self.table_addr_bits >= 1,
                         "table_addr_bits out of range")
        _assert_accuracy(self.table_data_bits >= 1,
                         "table_data_bits out of range")
        _assert_accuracy(self.iter_count >= 1, "iter_count out of range")
        table = []
        for addr in range(1 << self.table_addr_bits):
            table.append(FixedPoint.with_frac_wid(self.table_exact_value(addr),
                                                  self.table_data_bits,
                                                  RoundDir.DOWN))
        # we have to use object.__setattr__ since frozen=True
        object.__setattr__(self, "table", tuple(table))
        object.__setattr__(self, "ops", tuple(self.__make_ops()))

    @property
    def expanded_width(self):
        """the total number of bits of precision used inside the algorithm."""
        return self.io_width + self.extra_precision

    @cache_on_self
    def max_neps(self, i):
        """maximum value of `neps[i]`.
        `neps[i]` is defined to be `n[i] * N_prime[i - 1] * F_prime[i - 1]`.
        """
        assert isinstance(i, int) and 0 <= i < self.iter_count
        return Fraction(1, 1 << self.expanded_width)

    @cache_on_self
    def max_deps(self, i):
        """maximum value of `deps[i]`.
        `deps[i]` is defined to be `d[i] * D_prime[i - 1] * F_prime[i - 1]`.
        """
        assert isinstance(i, int) and 0 <= i < self.iter_count
        return Fraction(1, 1 << self.expanded_width)

    @cache_on_self
    def max_feps(self, i):
        """maximum value of `feps[i]`.
        `feps[i]` is defined to be `f[i] * (2 - D_prime[i - 1])`.
        """
        assert isinstance(i, int) and 0 <= i < self.iter_count
        # zero, because the computation of `F_prime[i]` in
        # `GoldschmidtDivOp.MulDByF.run(...)` is exact.
        return Fraction(0)

    @cached_property
    def e0_range(self):
        """minimum and maximum values of `e[0]`
        (the relative error in `F_prime[-1]`)
        """
        min_e0 = Fraction(0)
        max_e0 = Fraction(0)
        for addr in range(self.table_addr_count):
            # `F_prime[-1] = (1 - e[0]) / B`
            # => `e[0] = 1 - B * F_prime[-1]`
            min_b, max_b = self.table_input_exact_range(addr)
            f_prime_m1 = self.table[addr].as_fraction()
            assert min_b >= 0 and f_prime_m1 >= 0, \
                "only positive quadrant of interval multiplication implemented"
            min_product = min_b * f_prime_m1
            max_product = max_b * f_prime_m1
            # negation swaps min/max
            cur_min_e0 = 1 - max_product
            cur_max_e0 = 1 - min_product
            min_e0 = min(min_e0, cur_min_e0)
            max_e0 = max(max_e0, cur_max_e0)
        min_e0 = self._shrink_min(min_e0)
        max_e0 = self._shrink_max(max_e0)
        return min_e0, max_e0

    @cached_property
    def min_e0(self):
        """minimum value of `e[0]` (the relative error in `F_prime[-1]`)
        """
        min_e0, _max_e0 = self.e0_range
        return min_e0

    @cached_property
    def max_e0(self):
        """maximum value of `e[0]` (the relative error in `F_prime[-1]`)
        """
        _min_e0, max_e0 = self.e0_range
        return max_e0

    @cached_property
    def max_abs_e0(self):
        """maximum value of `abs(e[0])`."""
        return max(abs(self.min_e0), abs(self.max_e0))

    @cached_property
    def min_abs_e0(self):
        """minimum value of `abs(e[0])`."""
        return Fraction(0)

    @cache_on_self
    def max_n(self, i):
        """maximum value of `n[i]` (the relative error in `N_prime[i]`
        relative to the previous iteration)
        """
        assert isinstance(i, int) and 0 <= i < self.iter_count
        if i == 0:
            # from Claim 10
            # `n[0] = neps[0] / ((1 - e[0]) * (A / B))`
            # `n[0] <= 2 * neps[0] / (1 - e[0])`

            assert self.max_e0 < 1 and self.max_neps(0) >= 0, \
                "only one quadrant of interval division implemented"
            retval = 2 * self.max_neps(0) / (1 - self.max_e0)
        elif i == 1:
            # from Claim 10
            # `n[1] <= neps[1] / ((1 - f[0]) * (1 - pi[0] - delta[0]))`
            min_mpd = 1 - self.max_pi(0) - self.max_delta(0)
            assert self.max_f(0) <= 1 and min_mpd >= 0, \
                "only one quadrant of interval multiplication implemented"
            prod = (1 - self.max_f(0)) * min_mpd
            assert self.max_neps(1) >= 0 and prod > 0, \
                "only one quadrant of interval division implemented"
            retval = self.max_neps(1) / prod
        else:
            # from Claim 6
            # `0 <= n[i] <= 2 * max_neps[i] / (1 - pi[i - 1] - delta[i - 1])`
            # NOTE(review): the formula quoted above has a factor of 2 that
            # the code below doesn't apply -- confirm against the paper
            # which of the two is intended.
            min_mpd = 1 - self.max_pi(i - 1) - self.max_delta(i - 1)
            assert self.max_neps(i) >= 0 and min_mpd > 0, \
                "only one quadrant of interval division implemented"
            retval = self.max_neps(i) / min_mpd

        return self._shrink_max(retval)

    @cache_on_self
    def max_d(self, i):
        """maximum value of `d[i]` (the relative error in `D_prime[i]`
        relative to the previous iteration)
        """
        assert isinstance(i, int) and 0 <= i < self.iter_count
        if i == 0:
            # from Claim 10
            # `d[0] = deps[0] / (1 - e[0])`

            assert self.max_e0 < 1 and self.max_deps(0) >= 0, \
                "only one quadrant of interval division implemented"
            retval = self.max_deps(0) / (1 - self.max_e0)
        elif i == 1:
            # from Claim 10
            # `d[1] <= deps[1] / ((1 - f[0]) * (1 - delta[0] ** 2))`
            assert self.max_f(0) <= 1 and self.max_delta(0) <= 1, \
                "only one quadrant of interval multiplication implemented"
            divisor = (1 - self.max_f(0)) * (1 - self.max_delta(0) ** 2)
            assert self.max_deps(1) >= 0 and divisor > 0, \
                "only one quadrant of interval division implemented"
            retval = self.max_deps(1) / divisor
        else:
            # from Claim 6
            # `0 <= d[i] <= max_deps[i] / (1 - delta[i - 1])`
            assert self.max_deps(i) >= 0 and self.max_delta(i - 1) < 1, \
                "only one quadrant of interval division implemented"
            retval = self.max_deps(i) / (1 - self.max_delta(i - 1))

        return self._shrink_max(retval)

    @cache_on_self
    def max_f(self, i):
        """maximum value of `f[i]` (the relative error in `F_prime[i]`
        relative to the previous iteration)
        """
        assert isinstance(i, int) and 0 <= i < self.iter_count
        if i == 0:
            # from Claim 10
            # `f[0] = feps[0] / (1 - delta[0])`

            assert self.max_delta(0) < 1 and self.max_feps(0) >= 0, \
                "only one quadrant of interval division implemented"
            retval = self.max_feps(0) / (1 - self.max_delta(0))
        elif i == 1:
            # from Claim 10
            # `f[1] = feps[1]`
            retval = self.max_feps(1)
        else:
            # from Claim 6
            # `f[i] <= max_feps[i]`
            retval = self.max_feps(i)

        return self._shrink_max(retval)

    @cache_on_self
    def max_delta(self, i):
        """ maximum value of `delta[i]`.
        `delta[i]` is defined in Definition 4 of paper.

        raises `ParamsNotAccurateEnough` if the bound isn't less than one.
        """
        assert isinstance(i, int) and 0 <= i < self.iter_count
        if i == 0:
            # `delta[0] = abs(e[0]) + 3 * d[0] / 2`
            retval = self.max_abs_e0 + Fraction(3, 2) * self.max_d(0)
        else:
            # `delta[i] = delta[i - 1] ** 2 + f[i - 1]`
            prev_max_delta = self.max_delta(i - 1)
            assert prev_max_delta >= 0
            retval = prev_max_delta ** 2 + self.max_f(i - 1)

        # `delta[i]` has to be smaller than one otherwise errors would go off
        # to infinity
        _assert_accuracy(retval < 1)

        return self._shrink_max(retval)

    @cache_on_self
    def max_pi(self, i):
        """ maximum value of `pi[i]`.
        `pi[i]` is defined right below Theorem 5 of paper.
        """
        assert isinstance(i, int) and 0 <= i < self.iter_count
        # `pi[i] = 1 - (1 - n[i]) * prod`
        # where `prod` is the product of,
        # for `j` in `0 <= j < i`, `(1 - n[j]) / (1 + d[j])`
        min_prod = Fraction(1)
        for j in range(i):
            max_n_j = self.max_n(j)
            max_d_j = self.max_d(j)
            assert max_n_j <= 1 and max_d_j > -1, \
                "only one quadrant of interval division implemented"
            min_prod *= (1 - max_n_j) / (1 + max_d_j)
        max_n_i = self.max_n(i)
        assert max_n_i <= 1 and min_prod >= 0, \
            "only one quadrant of interval multiplication implemented"
        retval = 1 - (1 - max_n_i) * min_prod
        return self._shrink_max(retval)

    @cached_property
    def max_n_shift(self):
        """ maximum value of `state.n_shift`.
        """
        # input numerator is `2*io_width`-bits
        max_n = (1 << (self.io_width * 2)) - 1
        max_n_shift = 0
        # normalize so 1 <= n < 2
        while max_n >= 2:
            max_n >>= 1
            max_n_shift += 1
        return max_n_shift

    @cached_property
    def n_hat(self):
        """ maximum value of, for all `i`, `max_n(i)` and `max_d(i)`
        """
        n_hat = Fraction(0)
        for i in range(self.iter_count):
            n_hat = max(n_hat, self.max_n(i), self.max_d(i))
        return self._shrink_max(n_hat)

    def __make_ops(self):
        """ Goldschmidt division algorithm.

            based on:
            Even, G., Seidel, P. M., & Ferguson, W. E. (2003).
            A Parametric Error Analysis of Goldschmidt's Division Algorithm.
            https://citeseerx.ist.psu.edu/viewdoc/download?doi=10.1.1.90.1238&rep=rep1&type=pdf

            yields: GoldschmidtDivOp
                the operations needed to perform the division.

            raises `ParamsNotAccurateEnough` (while being iterated) if the
            parameters don't meet the error analysis's requirements.
        """
        # establish assumptions of the paper's error analysis (section 3.1):

        # 1. normalize so A (numerator) and B (denominator) are in [1, 2)
        yield GoldschmidtDivOp.Normalize

        # 2. ensure all relative errors from directed rounding are <= 1 / 4.
        # the assumption is met by multipliers with > 4-bits precision
        _assert_accuracy(self.expanded_width > 4)

        # 3. require `abs(e[0]) + 3 * d[0] / 2 + f[0] < 1 / 2`.
        _assert_accuracy(self.max_abs_e0 + 3 * self.max_d(0) / 2
                         + self.max_f(0) < Fraction(1, 2))

        # 4. the initial approximation F'[-1] of 1/B is in [1/2, 1].
        # (B is the denominator)

        for addr in range(self.table_addr_count):
            f_prime_m1 = self.table[addr]
            _assert_accuracy(0.5 <= f_prime_m1 <= 1)

        yield GoldschmidtDivOp.FEqTableLookup

        # we use Setting I (section 4.1 of the paper):
        # Require `n[i] <= n_hat` and `d[i] <= n_hat` and `f[i] = 0`:
        # the conditions on n_hat are satisfied by construction.
        for i in range(self.iter_count):
            _assert_accuracy(self.max_f(i) == 0)
            yield GoldschmidtDivOp.MulNByF
            # the last iteration only needs the numerator updated
            if i != self.iter_count - 1:
                yield GoldschmidtDivOp.MulDByF
                yield GoldschmidtDivOp.FEq2MinusD

        # relative approximation error `p(N_prime[i])`:
        # `p(N_prime[i]) = (A / B - N_prime[i]) / (A / B)`
        # `0 <= p(N_prime[i])`
        # `p(N_prime[i]) <= (2 * i) * n_hat \`
        # ` + (abs(e[0]) + 3 * n_hat / 2) ** (2 ** i)`
        i = self.iter_count - 1  # last used `i`
        # compute power manually to prevent huge intermediate values
        power = self._shrink_max(self.max_abs_e0 + 3 * self.n_hat / 2)
        for _ in range(i):
            power = self._shrink_max(power * power)

        max_rel_error = (2 * i) * self.n_hat + power

        min_a_over_b = Fraction(1, 2)
        max_a_over_b = Fraction(2)
        max_allowed_abs_error = max_a_over_b / (1 << self.max_n_shift)
        max_allowed_rel_error = max_allowed_abs_error / min_a_over_b

        _assert_accuracy(max_rel_error < max_allowed_rel_error,
                         f"not accurate enough: max_rel_error={max_rel_error}"
                         f" max_allowed_rel_error={max_allowed_rel_error}")

        yield GoldschmidtDivOp.CalcResult

    @cache_on_self
    def default_cost_fn(self):
        """ calculate the estimated cost on an arbitrary scale of implementing
        goldschmidt division with the specified parameters. larger cost
        values mean worse parameters.

        This is the default cost function for `GoldschmidtDivParams.get`.

        returns: float
        """
        rom_cells = self.table_data_bits << self.table_addr_bits
        cost = float(rom_cells)
        for op in self.ops:
            if op == GoldschmidtDivOp.MulNByF \
                    or op == GoldschmidtDivOp.MulDByF:
                mul_cost = self.expanded_width ** 2
                mul_cost *= self.expanded_width.bit_length()
                cost += mul_cost
        cost += 5e7 * self.iter_count
        return cost

    @staticmethod
    @lru_cache(maxsize=1 << 16)
    def __cached_new(base_params):
        """construct `GoldschmidtDivParams` from `base_params`, caching the
        outcome.

        returns: `(params, None)` on success or `(None, error)` if
            construction raised `ParamsNotAccurateEnough`.
        """
        assert isinstance(base_params, GoldschmidtDivParamsBase)
        # can't use dataclasses.asdict, since it's recursive and will also give
        # child class fields too, which we don't want.
        kwargs = {}
        for base_field in fields(GoldschmidtDivParamsBase):
            kwargs[base_field.name] = getattr(base_params, base_field.name)
        try:
            return GoldschmidtDivParams(**kwargs), None
        except ParamsNotAccurateEnough as e:
            return None, e

    # deliberately a plain function at this point: it's used as the default
    # value of `cached_new`'s `handle_error` below, and `staticmethod` objects
    # aren't callable before Python 3.10. it's wrapped in `staticmethod`
    # right after `cached_new` is defined.
    def __raise(e):  # type: (ParamsNotAccurateEnough) -> Any
        raise e

    @staticmethod
    def cached_new(base_params, handle_error=__raise):
        """construct `GoldschmidtDivParams` from `base_params`, reusing
        previously computed results.

        arguments:
        base_params: GoldschmidtDivParamsBase
            the parameters to use.
        handle_error: Callable[[ParamsNotAccurateEnough], Any]
            called with the error if the parameters aren't accurate enough,
            its return value is returned. the default raises the error.
        """
        assert isinstance(base_params, GoldschmidtDivParamsBase)
        params, error = GoldschmidtDivParams.__cached_new(base_params)
        if error is None:
            return params
        else:
            return handle_error(error)

    # keep the class attribute a staticmethod, like it always was
    __raise = staticmethod(__raise)

    @staticmethod
    def get(io_width, cost_fn=default_cost_fn, max_table_addr_bits=12):
        """ find efficient parameters for a goldschmidt division algorithm
        with `params.io_width == io_width`.

        arguments:
        io_width: int
            bit-width of the input divisor and the result.
            the input numerator is `2 * io_width`-bits wide.
        cost_fn: Callable[[GoldschmidtDivParams], float]
            return the estimated cost on an arbitrary scale of implementing
            goldschmidt division with the specified parameters. larger cost
            values mean worse parameters.
        max_table_addr_bits: int
            maximum allowable value of `table_addr_bits`

        returns: GoldschmidtDivParams
            the parameters found.

        raises `ValueError` if the initial (largest) parameters tried aren't
        accurate enough.
        """
        assert isinstance(io_width, int) and io_width >= 1
        assert callable(cost_fn)

        last_error = None
        last_error_params = None

        def cached_new(base_params):
            # like `GoldschmidtDivParams.cached_new`, except it returns
            # `None` on error after recording the error for later reporting
            def handle_error(e):
                nonlocal last_error, last_error_params
                last_error = e
                last_error_params = base_params
                return None

            retval = GoldschmidtDivParams.cached_new(base_params, handle_error)
            # lazy %-formatting: only builds the message if it gets logged
            if retval is None:
                logging.debug("GoldschmidtDivParams.get: err: %s",
                              base_params)
            else:
                logging.debug("GoldschmidtDivParams.get: ok: %s",
                              base_params)
            return retval

        @lru_cache(maxsize=None)
        def get_cost(base_params):
            # parameters that aren't accurate enough cost infinity, so they
            # never compare as better than working parameters
            params = cached_new(base_params)
            if params is None:
                return math.inf
            retval = cost_fn(params)
            logging.debug("GoldschmidtDivParams.get: cost=%s: %s",
                          retval, params)
            return retval

        # start with parameters big enough to always work.
        initial_extra_precision = io_width * 2 + 4
        initial_params = GoldschmidtDivParamsBase(
            io_width=io_width,
            extra_precision=initial_extra_precision,
            table_addr_bits=min(max_table_addr_bits, io_width),
            table_data_bits=io_width + initial_extra_precision,
            iter_count=1 + io_width.bit_length())

        if cached_new(initial_params) is None:
            raise ValueError(f"initial goldschmidt division algorithm "
                             f"parameters are invalid: {initial_params}"
                             ) from last_error

        # find good initial `iter_count`
        params = initial_params
        for iter_count in range(1, initial_params.iter_count):
            trial_params = replace(params, iter_count=iter_count)
            if cached_new(trial_params) is not None:
                params = trial_params
                break

        # now find `table_addr_bits`
        cost = get_cost(params)
        for table_addr_bits in range(1, max_table_addr_bits):
            trial_params = replace(params, table_addr_bits=table_addr_bits)
            trial_cost = get_cost(trial_params)
            if trial_cost < cost:
                params = trial_params
                cost = trial_cost
                break

        # check one higher `iter_count` to see if it has lower cost
        for table_addr_bits in range(1, max_table_addr_bits + 1):
            trial_params = replace(params,
                                   table_addr_bits=table_addr_bits,
                                   iter_count=params.iter_count + 1)
            trial_cost = get_cost(trial_params)
            if trial_cost < cost:
                params = trial_params
                cost = trial_cost
                break

        # now shrink `table_data_bits`
        while True:
            trial_params = replace(params,
                                   table_data_bits=params.table_data_bits - 1)
            trial_cost = get_cost(trial_params)
            if trial_cost < cost:
                params = trial_params
                cost = trial_cost
            else:
                break

        # and shrink `extra_precision`
        while True:
            trial_params = replace(params,
                                   extra_precision=params.extra_precision - 1)
            trial_cost = get_cost(trial_params)
            if trial_cost < cost:
                params = trial_params
                cost = trial_cost
            else:
                break

        return cached_new(params)
921
922
@enum.unique
class GoldschmidtDivOp(enum.Enum):
    """one step of the goldschmidt division algorithm. each member's value
    is pseudo-code describing what the step does."""

    Normalize = "n, d, n_shift = normalize(n, d)"
    FEqTableLookup = "f = table_lookup(d)"
    MulNByF = "n *= f"
    MulDByF = "d *= f"
    FEq2MinusD = "f = 2 - d"
    CalcResult = "result = unnormalize_and_round(n)"

    def run(self, params, state):
        """perform this operation, updating `state` in-place.

        arguments:
        params: GoldschmidtDivParams
            the parameters of the algorithm.
        state: GoldschmidtDivState
            the state to read from and write to.
        """
        assert isinstance(params, GoldschmidtDivParams)
        assert isinstance(state, GoldschmidtDivState)
        width = params.expanded_width
        addr_bits = params.table_addr_bits

        if self is GoldschmidtDivOp.Normalize:
            # double both until 1 <= d < 2
            # in hardware: a count-leading-zeros followed by a left shift
            while state.d < 1:
                state.n = (state.n * 2).to_frac_wid(width)
                state.d = (state.d * 2).to_frac_wid(width)

            # halve n until 1 <= n < 2, counting the halvings so they can
            # be undone at the end of the algorithm
            state.n_shift = 0
            while state.n >= 2:
                state.n = (state.n * 0.5).to_frac_wid(width)
                state.n_shift += 1
            return

        if self is GoldschmidtDivOp.FEqTableLookup:
            # the table address is the top `table_addr_bits` bits of the
            # fractional part of d
            addr = (state.d - 1).to_frac_wid(addr_bits, RoundDir.DOWN)
            assert 0 <= addr.bits < (1 << params.table_addr_bits)
            state.f = params.table[addr.bits]
            return

        if self is GoldschmidtDivOp.MulNByF:
            assert state.f is not None
            product = state.n * state.f
            # the numerator is rounded down
            state.n = product.to_frac_wid(width, round_dir=RoundDir.DOWN)
            return

        if self is GoldschmidtDivOp.MulDByF:
            assert state.f is not None
            product = state.d * state.f
            # the denominator is rounded up
            state.d = product.to_frac_wid(width, round_dir=RoundDir.UP)
            return

        if self is GoldschmidtDivOp.FEq2MinusD:
            state.f = (2 - state.d).to_frac_wid(width)
            return

        if self is GoldschmidtDivOp.CalcResult:
            assert state.n_shift is not None
            # undo the normalization of n
            scaled = state.n * (1 << state.n_shift)

            state.quotient = math.floor(scaled)
            state.remainder = state.orig_n - state.quotient * state.orig_d
            # the quotient can be one too small, fix it up
            if state.remainder >= state.orig_d:
                state.quotient += 1
                state.remainder -= state.orig_d
            return

        assert False, f"unimplemented GoldschmidtDivOp: {self}"
977
978
@dataclass
class GoldschmidtDivState:
    """Mutable working state of one Goldschmidt division.

    attributes:
    orig_n: int
        original numerator.
    orig_d: int
        original denominator.
    n: FixedPoint
        numerator -- N_prime[i] in the paper's algorithm 2.
    d: FixedPoint
        denominator -- D_prime[i] in the paper's algorithm 2.
    f: FixedPoint | None
        current factor -- F_prime[i] in the paper's algorithm 2.
    quotient: int | None
        final quotient.
    remainder: int | None
        final remainder.
    n_shift: int | None
        amount the numerator needs to be left-shifted at the end of the
        algorithm.
    """

    # inputs, exactly as passed by the caller
    orig_n: int
    orig_d: int

    # fixed-point working values
    n: FixedPoint
    d: FixedPoint
    f: "FixedPoint | None" = None

    # results -- `None` until they have been computed
    quotient: "int | None" = None
    remainder: "int | None" = None

    # left-shift still owed to the numerator -- `None` until it is known
    n_shift: "int | None" = None
1006
1007
def goldschmidt_div(n, d, params):
    """Divide `n` by `d` using Goldschmidt's division algorithm.

    based on:
    Even, G., Seidel, P. M., & Ferguson, W. E. (2003).
    A Parametric Error Analysis of Goldschmidt's Division Algorithm.
    https://citeseerx.ist.psu.edu/viewdoc/download?doi=10.1.1.90.1238&rep=rep1&type=pdf

    arguments:
    n: int
        numerator. a `2 * params.io_width`-bit unsigned integer.
        must be less than `d << params.io_width`, otherwise the quotient
        wouldn't fit in `params.io_width` bits.
    d: int
        denominator. a `params.io_width`-bit unsigned integer. must not be
        zero.
    params: GoldschmidtDivParams
        the algorithm's parameters. `params.io_width` is the bit-width of
        the inputs/outputs and `params.ops` is the sequence of
        `GoldschmidtDivOp` steps that are run in order.

    returns: tuple[int, int]
        the quotient and remainder. a tuple of two `params.io_width`-bit
        unsigned integers.
    """
    assert isinstance(params, GoldschmidtDivParams)
    io_width = params.io_width
    assert isinstance(d, int) and 0 < d < (1 << io_width)
    assert isinstance(n, int) and 0 <= n < (d << io_width)

    # both inputs start out as fixed-point values with `io_width` fraction
    # bits; the ops then do all of the work on this shared state
    fixed_n = FixedPoint(n, io_width)
    fixed_d = FixedPoint(d, io_width)
    state = GoldschmidtDivState(orig_n=n, orig_d=d, n=fixed_n, d=fixed_d)

    for step in params.ops:
        step.run(params, state)

    # the ops must have produced both results
    quotient, remainder = state.quotient, state.remainder
    assert quotient is not None
    assert remainder is not None

    return quotient, remainder
1051
1052
# number of integer bits in an address of the sqrt/rsqrt look-up table; the
# remaining address bits are fraction bits. `goldschmidt_sqrt_rsqrt`
# normalizes its input to be either zero or in the range 1 <= s < 4 before
# the look-up, which is what needs two integer bits.
GOLDSCHMIDT_SQRT_RSQRT_TABLE_ADDR_INT_WID = 2
1054
1055
@lru_cache()
def goldschmidt_sqrt_rsqrt_table(table_addr_bits, table_data_bits):
    """Generate the look-up table needed for Goldschmidt's square-root and
    reciprocal-square-root algorithm.

    The table address is a fixed-point value with
    `GOLDSCHMIDT_SQRT_RSQRT_TABLE_ADDR_INT_WID` integer bits, the rest of
    the address bits being fraction bits.

    arguments:
    table_addr_bits: int
        the number of address bits for the look-up table. must be at least
        `GOLDSCHMIDT_SQRT_RSQRT_TABLE_ADDR_INT_WID`.
    table_data_bits: int
        the number of data bits for the look-up table. must be positive.

    returns: tuple[FixedPoint | None, ...]
        `1 << table_addr_bits` entries indexed by address. entry 0 is zero,
        entries for addresses representing values strictly between 0 and 1
        are `None`, all others have `table_data_bits` fraction bits.
    """
    assert isinstance(table_addr_bits, int) and \
        table_addr_bits >= GOLDSCHMIDT_SQRT_RSQRT_TABLE_ADDR_INT_WID
    assert isinstance(table_data_bits, int) and table_data_bits >= 1

    # everything below is the same for every address, so compute it once
    addr_int_wid = GOLDSCHMIDT_SQRT_RSQRT_TABLE_ADDR_INT_WID
    table_addr_frac_wid = table_addr_bits - addr_int_wid
    table_len = 1 << table_addr_bits
    # take the rsqrt with at least `table_data_bits` fraction bits
    max_frac_wid = max(table_addr_frac_wid, table_data_bits)

    table = []
    for addr in range(table_len):
        if addr == 0:
            # a zero input gets a zero factor
            value = FixedPoint(0, table_data_bits)
        elif (addr << addr_int_wid) < table_len:
            # address represents a value in 0 < s < 1, which a normalized
            # input never has
            value = None  # table entries should be unused
        else:
            # use the upper end of the input range that maps to this
            # address, then round down at every step
            max_input_value = FixedPoint(addr + 1, table_addr_frac_wid)
            value = max_input_value.to_frac_wid(max_frac_wid)
            value = value.rsqrt(RoundDir.DOWN)
            value = value.to_frac_wid(table_data_bits, RoundDir.DOWN)
        table.append(value)

    # tuple for immutability
    return tuple(table)
1089
1090 # FIXME: add code to calculate error bounds and check that the algorithm will
1091 # actually work (like in the goldschmidt division algorithm).
1092 # FIXME: add code to calculate a good set of parameters based on the error
1093 # bounds checking.
1094
1095
def goldschmidt_sqrt_rsqrt(radicand, io_width, frac_wid, extra_precision,
                           table_addr_bits, table_data_bits, iter_count):
    """Goldschmidt's square-root and reciprocal-square-root algorithm.

    uses algorithm based on second method at:
    https://en.wikipedia.org/wiki/Methods_of_computing_square_roots#Goldschmidt%E2%80%99s_algorithm

    arguments:
    radicand: FixedPoint(frac_wid=frac_wid)
        the input value to take the square-root and reciprocal-square-root of.
    io_width: int
        the number of bits in the input (`radicand`) and output values.
    frac_wid: int
        the number of fraction bits in the input (`radicand`) and output
        values.
    extra_precision: int
        the number of bits of internal extra precision. must be at least
        `io_width`.
    table_addr_bits: int
        the number of address bits for the look-up table. must be at least
        `GOLDSCHMIDT_SQRT_RSQRT_TABLE_ADDR_INT_WID`.
    table_data_bits: int
        the number of data bits for the look-up table.
    iter_count: int
        the number of iteration steps to run after the initial table
        look-up. must not be negative.

    returns: tuple[FixedPoint, FixedPoint]
        the square-root and reciprocal-square-root, rounded down to the
        nearest representable value. If `radicand == 0`, then the
        reciprocal-square-root value returned is zero.
    """
    # check the plain integer parameters first: the `radicand` check uses
    # `io_width` and `frac_wid`, so they have to be known-good by then
    assert isinstance(io_width, int) and io_width >= 1
    assert isinstance(frac_wid, int) and 0 <= frac_wid < io_width
    assert isinstance(extra_precision, int) and extra_precision >= io_width
    # same lower bound as `goldschmidt_sqrt_rsqrt_table` requires
    assert isinstance(table_addr_bits, int) and \
        table_addr_bits >= GOLDSCHMIDT_SQRT_RSQRT_TABLE_ADDR_INT_WID
    assert isinstance(table_data_bits, int) and table_data_bits >= 1
    assert isinstance(iter_count, int) and iter_count >= 0
    assert (isinstance(radicand, FixedPoint)
            and radicand.frac_wid == frac_wid
            and 0 <= radicand.bits < (1 << io_width))

    expanded_frac_wid = frac_wid + extra_precision
    s = radicand.to_frac_wid(expanded_frac_wid)

    # normalize so that `s == 0` or `1 <= s < 4`. multiplying `s` by 4
    # doubles its square-root and halves its reciprocal-square-root, so
    # keep track of the right-shifts needed to undo that at the end. both
    # start at `extra_precision` to drop the extra fraction bits.
    sqrt_rshift = extra_precision
    rsqrt_rshift = extra_precision
    while s != 0 and s < 1:
        s = (s * 4).to_frac_wid(expanded_frac_wid)
        sqrt_rshift += 1
        rsqrt_rshift -= 1
    while s >= 4:
        s = s.div(4, expanded_frac_wid)
        sqrt_rshift -= 1
        rsqrt_rshift += 1

    table = goldschmidt_sqrt_rsqrt_table(table_addr_bits=table_addr_bits,
                                         table_data_bits=table_data_bits)
    # core goldschmidt sqrt/rsqrt algorithm:
    # initial setup:
    table_addr_frac_wid = table_addr_bits
    table_addr_frac_wid -= GOLDSCHMIDT_SQRT_RSQRT_TABLE_ADDR_INT_WID
    addr = s.to_frac_wid(table_addr_frac_wid, RoundDir.DOWN)
    assert 0 <= addr.bits < (1 << table_addr_bits), "table addr out of range"
    f = table[addr.bits]
    assert f is not None, "accessed invalid table entry"
    # use with_frac_wid to fix IDE type deduction
    f = FixedPoint.with_frac_wid(f, expanded_frac_wid, RoundDir.DOWN)
    x = (s * f).to_frac_wid(expanded_frac_wid, RoundDir.DOWN)
    h = (f * 0.5).to_frac_wid(expanded_frac_wid, RoundDir.DOWN)
    for _ in range(iter_count):
        # iteration step:
        f = (1.5 - x * h).to_frac_wid(expanded_frac_wid, RoundDir.DOWN)
        x = (x * f).to_frac_wid(expanded_frac_wid, RoundDir.DOWN)
        h = (h * f).to_frac_wid(expanded_frac_wid, RoundDir.DOWN)
    r = 2 * h
    # now `x` is approximately `sqrt(s)` and `r` is approximately `rsqrt(s)`

    # undo the normalization and drop the extra precision
    sqrt = FixedPoint(x.bits >> sqrt_rshift, frac_wid)
    rsqrt = FixedPoint(r.bits >> rsqrt_rshift, frac_wid)

    # the truncated results can be one step too low, so try the next
    # representable value and keep it if it doesn't overshoot
    next_sqrt = FixedPoint(sqrt.bits + 1, frac_wid)
    if next_sqrt * next_sqrt <= radicand:
        sqrt = next_sqrt

    # rsqrt of zero is defined to be zero here, so never bump it up
    next_rsqrt = FixedPoint(rsqrt.bits + 1, frac_wid)
    if radicand != 0 and next_rsqrt * next_rsqrt * radicand <= 1:
        rsqrt = next_rsqrt
    return sqrt, rsqrt