improved goldschmidt division algorithm parameter optimization algorithm
[soc.git] / src / soc / fu / div / experiment / goldschmidt_div_sqrt.py
1 # SPDX-License-Identifier: LGPL-3.0-or-later
2 # Copyright 2022 Jacob Lifshay programmerjake@gmail.com
3
4 # Funded by NLnet Assure Programme 2021-02-052, https://nlnet.nl/assure part
5 # of Horizon 2020 EU Programme 957073.
6
7 from dataclasses import dataclass, field, fields, replace
8 import logging
9 import math
10 import enum
11 from fractions import Fraction
12 from types import FunctionType
13 from functools import lru_cache
14
15 try:
16 from functools import cached_property
17 except ImportError:
18 from cached_property import cached_property
19
20 # fix broken IDE type detection for cached_property
21 from typing import TYPE_CHECKING, Any
22 if TYPE_CHECKING:
23 from functools import cached_property
24
25
# sentinel meaning "no value present": a fresh `object()` is distinct from
# every other object, so an identity check (`x is _NOT_FOUND`) can never be
# confused with a real stored value such as `None`.
_NOT_FOUND = object()
27
28
def cache_on_self(func):
    """like `functools.cached_property`, except for methods. unlike
    `lru_cache` the cache is per-class instance rather than a global cache
    per-method.

    results are stored in a dict kept in the instance's `__dict__` under the
    key `func.__name__ + "__cache"`. That dict is keyed by the positional
    arguments plus the keyword arguments, and keyword order doesn't matter,
    so all arguments must be hashable. Writing through `__dict__` directly
    means this also works on `frozen=True` dataclasses.

    the returned wrapper carries `func`'s metadata (`__name__`, `__doc__`,
    `__qualname__`, `__wrapped__`, ...) via `functools.wraps`.
    """
    from functools import wraps

    assert isinstance(func, FunctionType), \
        "non-plain methods are not supported"

    cache_name = func.__name__ + "__cache"

    @wraps(func)
    def wrapper(self, *args, **kwargs):
        # specifically access through `__dict__` to bypass frozen=True
        cache = self.__dict__.setdefault(cache_name, {})
        # sort keyword arguments so `f(a=1, b=2)` and `f(b=2, a=1)` share a
        # cache entry; keyword names are unique so values are never compared
        key = (args, tuple(sorted(kwargs.items())))
        try:
            return cache[key]
        except KeyError:
            pass
        retval = func(self, *args, **kwargs)
        cache[key] = retval
        return retval

    return wrapper
53
54
@enum.unique
class RoundDir(enum.Enum):
    """rounding direction used when converting to a `FixedPoint` with a
    given number of fractional bits."""

    # values match what `enum.auto()` would assign (1, 2, 3, ...)
    DOWN = 1              # round toward negative infinity
    UP = 2                # round toward positive infinity
    NEAREST_TIES_UP = 3   # round to nearest, ties toward positive infinity
    ERROR_IF_INEXACT = 4  # raise `ValueError` instead of rounding
61
62
@dataclass(frozen=True)
class FixedPoint:
    """an exact binary fixed-point number whose value is
    `bits / 2 ** frac_wid`.

    `bits` is an arbitrary-precision (possibly negative) `int`; `frac_wid`
    is the non-negative number of fractional bits. Equality, ordering and
    hashing are all by numeric value. For example,
    `FixedPoint(1, 0) == FixedPoint(2, 1)`, and both hash the same as the
    `int` `1`.
    """
    bits: int
    frac_wid: int

    def __post_init__(self):
        assert isinstance(self.bits, int)
        assert isinstance(self.frac_wid, int) and self.frac_wid >= 0

    @staticmethod
    def cast(value):
        """convert `value` to a fixed-point number with enough fractional
        bits to preserve its value.

        accepts `FixedPoint` (returned unchanged), `int`, `float` and `str`.
        Strings may start with `+`/`-`. A `0x`-prefixed string that contains
        a `.` is parsed as a hexadecimal fixed-point literal, with 4
        fractional bits per digit after the `.` and `_` separators ignored.
        Any other string is parsed with `int(value, base=0)`.

        raises `TypeError` for unsupported types and `ValueError` for
        malformed strings.
        """
        if isinstance(value, FixedPoint):
            return value
        if isinstance(value, int):
            return FixedPoint(value, 0)
        if isinstance(value, str):
            value = value.strip()
            neg = value.startswith("-")
            if neg or value.startswith("+"):
                value = value[1:]
            if value.startswith(("0x", "0X")) and "." in value:
                value = value[2:]
                got_dot = False
                bits = 0
                frac_wid = 0
                for digit in value:
                    if digit == "_":
                        continue
                    if got_dot:
                        if digit == ".":
                            raise ValueError("too many `.` in string")
                        # every digit after the `.` is 4 fractional bits
                        frac_wid += 4
                    if digit == ".":
                        got_dot = True
                        continue
                    if not digit.isalnum():
                        raise ValueError("invalid hexadecimal digit")
                    bits <<= 4
                    bits |= int("0x" + digit, base=16)
            else:
                bits = int(value, base=0)
                frac_wid = 0
            if neg:
                bits = -bits
            return FixedPoint(bits, frac_wid)

        if isinstance(value, float):
            n, d = value.as_integer_ratio()
            log2_d = d.bit_length() - 1
            assert d == 1 << log2_d, ("d isn't a power of 2 -- won't ever "
                                      "fail with float being IEEE 754")
            return FixedPoint(n, log2_d)
        raise TypeError("can't convert type to FixedPoint")

    @staticmethod
    def with_frac_wid(value, frac_wid, round_dir=RoundDir.ERROR_IF_INEXACT):
        """convert `value` to the nearest fixed-point number with `frac_wid`
        fractional bits, rounding according to `round_dir`.

        `value` may be a `Fraction` or anything accepted by `cast`.
        raises `ValueError` if `round_dir` is `ERROR_IF_INEXACT` and the
        conversion isn't exact.
        """
        assert isinstance(frac_wid, int) and frac_wid >= 0
        assert isinstance(round_dir, RoundDir)
        if isinstance(value, Fraction):
            numerator = value.numerator
            denominator = value.denominator
        else:
            value = FixedPoint.cast(value)
            numerator = value.bits
            denominator = 1 << value.frac_wid
        if denominator < 0:
            numerator = -numerator
            denominator = -denominator
        # floor division, so `remainder` is in `[0, denominator)`
        bits, remainder = divmod(numerator << frac_wid, denominator)
        if round_dir == RoundDir.DOWN:
            pass
        elif round_dir == RoundDir.UP:
            if remainder != 0:
                bits += 1
        elif round_dir == RoundDir.NEAREST_TIES_UP:
            if remainder * 2 >= denominator:
                bits += 1
        elif round_dir == RoundDir.ERROR_IF_INEXACT:
            if remainder != 0:
                raise ValueError("inexact conversion")
        else:
            assert False, "unimplemented round_dir"
        return FixedPoint(bits, frac_wid)

    def to_frac_wid(self, frac_wid, round_dir=RoundDir.ERROR_IF_INEXACT):
        """convert to the nearest fixed-point number with `frac_wid`
        fractional bits, rounding according to `round_dir`."""
        return FixedPoint.with_frac_wid(self, frac_wid, round_dir)

    def __float__(self):
        # use truediv to get correct result even when bits
        # and frac_wid are huge
        return float(self.bits / (1 << self.frac_wid))

    def as_fraction(self):
        """return the exact value of `self` as a `Fraction`."""
        return Fraction(self.bits, 1 << self.frac_wid)

    def cmp(self, rhs):
        """compare self with rhs, returning a positive integer if self is
        greater than rhs, zero if self is equal to rhs, and a negative integer
        if self is less than rhs.

        `rhs` may be a `Fraction` (compared exactly) or anything accepted by
        `cast`; raises `TypeError` for other types.
        """
        if isinstance(rhs, Fraction):
            lhs = self.as_fraction()
            return (lhs > rhs) - (lhs < rhs)
        rhs = FixedPoint.cast(rhs)
        common_frac_wid = max(self.frac_wid, rhs.frac_wid)
        # widening is always exact, so the default ERROR_IF_INEXACT is fine
        lhs = self.to_frac_wid(common_frac_wid)
        rhs = rhs.to_frac_wid(common_frac_wid)
        return lhs.bits - rhs.bits

    def __eq__(self, rhs):
        try:
            return self.cmp(rhs) == 0
        except TypeError:
            # `rhs` isn't convertible to `FixedPoint` (e.g. `None`); let
            # Python try the reflected comparison / fall back to identity
            return NotImplemented

    def __ne__(self, rhs):
        eq = self.__eq__(rhs)
        if eq is NotImplemented:
            return NotImplemented
        return not eq

    def __hash__(self):
        # must agree with the value-based `__eq__`: equal values with
        # different `frac_wid` must hash the same. Hashing the exact
        # `Fraction` also matches `hash()` of equal `int`s and `float`s.
        return hash(self.as_fraction())

    def __gt__(self, rhs):
        return self.cmp(rhs) > 0

    def __lt__(self, rhs):
        return self.cmp(rhs) < 0

    def __ge__(self, rhs):
        return self.cmp(rhs) >= 0

    def __le__(self, rhs):
        return self.cmp(rhs) <= 0

    def fract(self):
        """return the fractional part of `self`.
        that is `self - math.floor(self)`.
        """
        fract_mask = (1 << self.frac_wid) - 1
        # `&` on Python ints acts on infinite two's complement, so this is
        # also correct for negative `bits`
        return FixedPoint(self.bits & fract_mask, self.frac_wid)

    def __str__(self):
        if self < 0:
            return "-" + str(-self)
        digit_bits = 4
        frac_digit_count = (self.frac_wid + digit_bits - 1) // digit_bits
        fract = self.fract().to_frac_wid(frac_digit_count * digit_bits)
        frac_str = hex(fract.bits)[2:].zfill(frac_digit_count)
        return hex(math.floor(self)) + "." + frac_str

    def __repr__(self):
        return f"FixedPoint.with_frac_wid({str(self)!r}, {self.frac_wid})"

    def __add__(self, rhs):
        rhs = FixedPoint.cast(rhs)
        common_frac_wid = max(self.frac_wid, rhs.frac_wid)
        lhs = self.to_frac_wid(common_frac_wid)
        rhs = rhs.to_frac_wid(common_frac_wid)
        return FixedPoint(lhs.bits + rhs.bits, common_frac_wid)

    def __radd__(self, lhs):
        # symmetric
        return self.__add__(lhs)

    def __neg__(self):
        return FixedPoint(-self.bits, self.frac_wid)

    def __sub__(self, rhs):
        rhs = FixedPoint.cast(rhs)
        common_frac_wid = max(self.frac_wid, rhs.frac_wid)
        lhs = self.to_frac_wid(common_frac_wid)
        rhs = rhs.to_frac_wid(common_frac_wid)
        return FixedPoint(lhs.bits - rhs.bits, common_frac_wid)

    def __rsub__(self, lhs):
        # a - b == -(b - a)
        return -self.__sub__(lhs)

    def __mul__(self, rhs):
        rhs = FixedPoint.cast(rhs)
        # exact: fractional widths add
        return FixedPoint(self.bits * rhs.bits, self.frac_wid + rhs.frac_wid)

    def __rmul__(self, lhs):
        # symmetric
        return self.__mul__(lhs)

    def __floor__(self):
        # arithmetic right shift rounds toward negative infinity
        return self.bits >> self.frac_wid
246
247
@dataclass
class GoldschmidtDivState:
    """mutable working state for one run of the goldschmidt division
    algorithm.

    `n`, `d` and `f` follow the naming of algorithm 2 in the paper: they hold
    `N_prime[i]`, `D_prime[i]` and `F_prime[i]` respectively.
    """

    # original numerator
    orig_n: int
    # original denominator
    orig_d: int
    # numerator -- N_prime[i] in the paper's algorithm 2
    n: FixedPoint
    # denominator -- D_prime[i] in the paper's algorithm 2
    d: FixedPoint
    # current factor -- F_prime[i] in the paper's algorithm 2
    f: "FixedPoint | None" = None
    # final quotient
    quotient: "int | None" = None
    # final remainder
    remainder: "int | None" = None
    # amount the numerator needs to be left-shifted at the end of the
    # algorithm.
    n_shift: "int | None" = None
275
276
class ParamsNotAccurateEnough(Exception):
    """signals that a candidate set of goldschmidt division parameters can't
    be proven accurate enough for the algorithm to produce correct
    results."""
280
281
def _assert_accuracy(condition, msg="not accurate enough"):
    """raise `ParamsNotAccurateEnough(msg)` unless `condition` is truthy;
    otherwise return `None`."""
    if not condition:
        raise ParamsNotAccurateEnough(msg)
286
287
@dataclass(unsafe_hash=True, frozen=True)
class GoldschmidtDivParamsBase:
    """the independent (non-derived) parameters of a goldschmidt division
    algorithm. Instances are immutable and hashable, so they can be used
    as cache keys.
    """

    # bit-width of the input divisor and the result; the input numerator is
    # `2 * io_width`-bits wide.
    io_width: int
    # number of bits of additional precision used inside the algorithm.
    extra_precision: int
    # the number of address bits used in the lookup-table.
    table_addr_bits: int
    # the number of data bits used in the lookup-table.
    table_data_bits: int
    # the total number of iterations of the division algorithm's loop
    iter_count: int
310
311
@dataclass(frozen=True, unsafe_hash=True)
class GoldschmidtDivParams(GoldschmidtDivParamsBase):
    """parameters for a Goldschmidt division algorithm.
    Use `GoldschmidtDivParams.get` to find an efficient set of parameters.

    the lookup-table and the list of operations are derived from the base
    parameters in `__post_init__`. Construction raises
    `ParamsNotAccurateEnough` if the error analysis can't show that the
    parameters are accurate enough.
    """

    # tuple to be immutable, repr=False so repr() works for debugging even when
    # __post_init__ hasn't finished running yet
    table: "tuple[FixedPoint, ...]" = field(init=False, repr=False)
    """the lookup-table"""

    ops: "tuple[GoldschmidtDivOp, ...]" = field(init=False, repr=False)
    """the operations needed to perform the goldschmidt division algorithm."""

    def _shrink_bound(self, bound, round_dir):
        """prevent fractions from having huge numerators/denominators by
        rounding to a `FixedPoint` and converting back to a `Fraction`.

        This is intended only for values used to compute bounds, and not for
        values that end up in the hardware.
        """
        assert isinstance(bound, (Fraction, int))
        assert round_dir is RoundDir.DOWN or round_dir is RoundDir.UP, \
            "you shouldn't use that round_dir on bounds"
        frac_wid = self.io_width * 4 + 100  # should be enough precision
        fixed = FixedPoint.with_frac_wid(bound, frac_wid, round_dir)
        return fixed.as_fraction()

    def _shrink_min(self, min_bound):
        """prevent fractions used as minimum bounds from having huge
        numerators/denominators by rounding down to a `FixedPoint` and
        converting back to a `Fraction`.

        This is intended only for values used to compute bounds, and not for
        values that end up in the hardware.
        """
        return self._shrink_bound(min_bound, RoundDir.DOWN)

    def _shrink_max(self, max_bound):
        """prevent fractions used as maximum bounds from having huge
        numerators/denominators by rounding up to a `FixedPoint` and
        converting back to a `Fraction`.

        This is intended only for values used to compute bounds, and not for
        values that end up in the hardware.
        """
        return self._shrink_bound(max_bound, RoundDir.UP)

    @property
    def table_addr_count(self):
        """number of distinct addresses in the lookup-table."""
        # used while computing self.table, so can't just do len(self.table)
        return 1 << self.table_addr_bits

    def table_input_exact_range(self, addr):
        """return the range of inputs as `Fraction`s used for the table entry
        with address `addr`."""
        assert isinstance(addr, int)
        assert 0 <= addr < self.table_addr_count
        _assert_accuracy(self.io_width >= self.table_addr_bits)
        addr_shift = self.io_width - self.table_addr_bits
        min_numerator = (1 << self.io_width) + (addr << addr_shift)
        denominator = 1 << self.io_width
        values_per_table_entry = 1 << addr_shift
        max_numerator = min_numerator + values_per_table_entry - 1
        min_input = Fraction(min_numerator, denominator)
        max_input = Fraction(max_numerator, denominator)
        min_input = self._shrink_min(min_input)
        max_input = self._shrink_max(max_input)
        assert 1 <= min_input <= max_input < 2
        return min_input, max_input

    def table_value_exact_range(self, addr):
        """return the range of values as `Fraction`s used for the table entry
        with address `addr`."""
        min_input, max_input = self.table_input_exact_range(addr)
        # division swaps min/max
        min_value = 1 / max_input
        max_value = 1 / min_input
        min_value = self._shrink_min(min_value)
        max_value = self._shrink_max(max_value)
        assert 0.5 < min_value <= max_value <= 1
        return min_value, max_value

    def table_exact_value(self, index):
        """return the exact (unrounded) value for table entry `index`."""
        min_value, _max_value = self.table_value_exact_range(index)
        # we round down
        return min_value

    def __post_init__(self):
        # called by the autogenerated __init__
        _assert_accuracy(self.io_width >= 1, "io_width out of range")
        _assert_accuracy(self.extra_precision >= 0,
                         "extra_precision out of range")
        _assert_accuracy(self.table_addr_bits >= 1,
                         "table_addr_bits out of range")
        _assert_accuracy(self.table_data_bits >= 1,
                         "table_data_bits out of range")
        _assert_accuracy(self.iter_count >= 1, "iter_count out of range")
        table = tuple(
            FixedPoint.with_frac_wid(self.table_exact_value(addr),
                                     self.table_data_bits,
                                     RoundDir.DOWN)
            for addr in range(self.table_addr_count))
        # we have to use object.__setattr__ since frozen=True.
        # `table` must be set before `__make_ops` runs since it reads it.
        object.__setattr__(self, "table", table)
        object.__setattr__(self, "ops", tuple(self.__make_ops()))

    @property
    def expanded_width(self):
        """the total number of bits of precision used inside the algorithm."""
        return self.io_width + self.extra_precision

    @cache_on_self
    def max_neps(self, i):
        """maximum value of `neps[i]`.
        `neps[i]` is defined to be `n[i] * N_prime[i - 1] * F_prime[i - 1]`.
        """
        assert isinstance(i, int) and 0 <= i < self.iter_count
        return Fraction(1, 1 << self.expanded_width)

    @cache_on_self
    def max_deps(self, i):
        """maximum value of `deps[i]`.
        `deps[i]` is defined to be `d[i] * D_prime[i - 1] * F_prime[i - 1]`.
        """
        assert isinstance(i, int) and 0 <= i < self.iter_count
        return Fraction(1, 1 << self.expanded_width)

    @cache_on_self
    def max_feps(self, i):
        """maximum value of `feps[i]`.
        `feps[i]` is defined to be `f[i] * (2 - D_prime[i - 1])`.
        """
        assert isinstance(i, int) and 0 <= i < self.iter_count
        # zero, because the computation of `F_prime[i]` in
        # `GoldschmidtDivOp.MulDByF.run(...)` is exact.
        return Fraction(0)

    @cached_property
    def e0_range(self):
        """minimum and maximum values of `e[0]`
        (the relative error in `F_prime[-1]`)
        """
        min_e0 = Fraction(0)
        max_e0 = Fraction(0)
        for addr in range(self.table_addr_count):
            # `F_prime[-1] = (1 - e[0]) / B`
            # => `e[0] = 1 - B * F_prime[-1]`
            min_b, max_b = self.table_input_exact_range(addr)
            f_prime_m1 = self.table[addr].as_fraction()
            assert min_b >= 0 and f_prime_m1 >= 0, \
                "only positive quadrant of interval multiplication implemented"
            min_product = min_b * f_prime_m1
            max_product = max_b * f_prime_m1
            # negation swaps min/max
            cur_min_e0 = 1 - max_product
            cur_max_e0 = 1 - min_product
            min_e0 = min(min_e0, cur_min_e0)
            max_e0 = max(max_e0, cur_max_e0)
        min_e0 = self._shrink_min(min_e0)
        max_e0 = self._shrink_max(max_e0)
        return min_e0, max_e0

    @cached_property
    def min_e0(self):
        """minimum value of `e[0]` (the relative error in `F_prime[-1]`)
        """
        min_e0, max_e0 = self.e0_range
        return min_e0

    @cached_property
    def max_e0(self):
        """maximum value of `e[0]` (the relative error in `F_prime[-1]`)
        """
        min_e0, max_e0 = self.e0_range
        return max_e0

    @cached_property
    def max_abs_e0(self):
        """maximum value of `abs(e[0])`."""
        return max(abs(self.min_e0), abs(self.max_e0))

    @cached_property
    def min_abs_e0(self):
        """minimum value of `abs(e[0])`."""
        return Fraction(0)

    @cache_on_self
    def max_n(self, i):
        """maximum value of `n[i]` (the relative error in `N_prime[i]`
        relative to the previous iteration)
        """
        assert isinstance(i, int) and 0 <= i < self.iter_count
        if i == 0:
            # from Claim 10
            # `n[0] = neps[0] / ((1 - e[0]) * (A / B))`
            # `n[0] <= 2 * neps[0] / (1 - e[0])`

            assert self.max_e0 < 1 and self.max_neps(0) >= 0, \
                "only one quadrant of interval division implemented"
            retval = 2 * self.max_neps(0) / (1 - self.max_e0)
        elif i == 1:
            # from Claim 10
            # `n[1] <= neps[1] / ((1 - f[0]) * (1 - pi[0] - delta[0]))`
            min_mpd = 1 - self.max_pi(0) - self.max_delta(0)
            assert self.max_f(0) <= 1 and min_mpd >= 0, \
                "only one quadrant of interval multiplication implemented"
            prod = (1 - self.max_f(0)) * min_mpd
            assert self.max_neps(1) >= 0 and prod > 0, \
                "only one quadrant of interval division implemented"
            retval = self.max_neps(1) / prod
        else:
            # from Claim 6
            # `0 <= n[i] <= 2 * max_neps[i] / (1 - pi[i - 1] - delta[i - 1])`
            # NOTE(review): the formula above has a factor of 2 that the code
            # below doesn't -- confirm against Claim 6 of the paper.
            min_mpd = 1 - self.max_pi(i - 1) - self.max_delta(i - 1)
            assert self.max_neps(i) >= 0 and min_mpd > 0, \
                "only one quadrant of interval division implemented"
            retval = self.max_neps(i) / min_mpd

        return self._shrink_max(retval)

    @cache_on_self
    def max_d(self, i):
        """maximum value of `d[i]` (the relative error in `D_prime[i]`
        relative to the previous iteration)
        """
        assert isinstance(i, int) and 0 <= i < self.iter_count
        if i == 0:
            # from Claim 10
            # `d[0] = deps[0] / (1 - e[0])`

            assert self.max_e0 < 1 and self.max_deps(0) >= 0, \
                "only one quadrant of interval division implemented"
            retval = self.max_deps(0) / (1 - self.max_e0)
        elif i == 1:
            # from Claim 10
            # `d[1] <= deps[1] / ((1 - f[0]) * (1 - delta[0] ** 2))`
            assert self.max_f(0) <= 1 and self.max_delta(0) <= 1, \
                "only one quadrant of interval multiplication implemented"
            divisor = (1 - self.max_f(0)) * (1 - self.max_delta(0) ** 2)
            assert self.max_deps(1) >= 0 and divisor > 0, \
                "only one quadrant of interval division implemented"
            retval = self.max_deps(1) / divisor
        else:
            # from Claim 6
            # `0 <= d[i] <= max_deps[i] / (1 - delta[i - 1])`
            assert self.max_deps(i) >= 0 and self.max_delta(i - 1) < 1, \
                "only one quadrant of interval division implemented"
            retval = self.max_deps(i) / (1 - self.max_delta(i - 1))

        return self._shrink_max(retval)

    @cache_on_self
    def max_f(self, i):
        """maximum value of `f[i]` (the relative error in `F_prime[i]`
        relative to the previous iteration)
        """
        assert isinstance(i, int) and 0 <= i < self.iter_count
        if i == 0:
            # from Claim 10
            # `f[0] = feps[0] / (1 - delta[0])`

            assert self.max_delta(0) < 1 and self.max_feps(0) >= 0, \
                "only one quadrant of interval division implemented"
            retval = self.max_feps(0) / (1 - self.max_delta(0))
        elif i == 1:
            # from Claim 10
            # `f[1] = feps[1]`
            retval = self.max_feps(1)
        else:
            # from Claim 6
            # `f[i] <= max_feps[i]`
            retval = self.max_feps(i)

        return self._shrink_max(retval)

    @cache_on_self
    def max_delta(self, i):
        """ maximum value of `delta[i]`.
        `delta[i]` is defined in Definition 4 of paper.
        """
        assert isinstance(i, int) and 0 <= i < self.iter_count
        if i == 0:
            # `delta[0] = abs(e[0]) + 3 * d[0] / 2`
            retval = self.max_abs_e0 + Fraction(3, 2) * self.max_d(0)
        else:
            # `delta[i] = delta[i - 1] ** 2 + f[i - 1]`
            prev_max_delta = self.max_delta(i - 1)
            assert prev_max_delta >= 0
            retval = prev_max_delta ** 2 + self.max_f(i - 1)

        # `delta[i]` has to be smaller than one otherwise errors would go off
        # to infinity
        _assert_accuracy(retval < 1)

        return self._shrink_max(retval)

    @cache_on_self
    def max_pi(self, i):
        """ maximum value of `pi[i]`.
        `pi[i]` is defined right below Theorem 5 of paper.
        """
        assert isinstance(i, int) and 0 <= i < self.iter_count
        # `pi[i] = 1 - (1 - n[i]) * prod`
        # where `prod` is the product of,
        # for `j` in `0 <= j < i`, `(1 - n[j]) / (1 + d[j])`
        min_prod = Fraction(1)
        for j in range(i):
            max_n_j = self.max_n(j)
            max_d_j = self.max_d(j)
            assert max_n_j <= 1 and max_d_j > -1, \
                "only one quadrant of interval division implemented"
            min_prod *= (1 - max_n_j) / (1 + max_d_j)
        max_n_i = self.max_n(i)
        assert max_n_i <= 1 and min_prod >= 0, \
            "only one quadrant of interval multiplication implemented"
        retval = 1 - (1 - max_n_i) * min_prod
        return self._shrink_max(retval)

    @cached_property
    def max_n_shift(self):
        """ maximum value of `state.n_shift`.
        """
        # input numerator is `2*io_width`-bits
        max_n = (1 << (self.io_width * 2)) - 1
        # the number of halvings needed to normalize `max_n` so that
        # `1 <= n < 2` is the index of its most-significant set bit
        # (`max_n >= 3` since `io_width >= 1` is checked in `__post_init__`)
        return max_n.bit_length() - 1

    @cached_property
    def n_hat(self):
        """ maximum value of, for all `i`, `max_n(i)` and `max_d(i)`
        """
        n_hat = Fraction(0)
        for i in range(self.iter_count):
            n_hat = max(n_hat, self.max_n(i), self.max_d(i))
        return self._shrink_max(n_hat)

    def __make_ops(self):
        """ Goldschmidt division algorithm.

        based on:
        Even, G., Seidel, P. M., & Ferguson, W. E. (2003).
        A Parametric Error Analysis of Goldschmidt's Division Algorithm.
        https://citeseerx.ist.psu.edu/viewdoc/download?doi=10.1.1.90.1238&rep=rep1&type=pdf

        yields: GoldschmidtDivOp
            the operations needed to perform the division.
        raises: ParamsNotAccurateEnough
            if any accuracy requirement of the error analysis isn't met.
        """
        # establish assumptions of the paper's error analysis (section 3.1):

        # 1. normalize so A (numerator) and B (denominator) are in [1, 2)
        yield GoldschmidtDivOp.Normalize

        # 2. ensure all relative errors from directed rounding are <= 1 / 4.
        # the assumption is met by multipliers with > 4-bits precision
        _assert_accuracy(self.expanded_width > 4)

        # 3. require `abs(e[0]) + 3 * d[0] / 2 + f[0] < 1 / 2`.
        _assert_accuracy(self.max_abs_e0 + 3 * self.max_d(0) / 2
                         + self.max_f(0) < Fraction(1, 2))

        # 4. the initial approximation F'[-1] of 1/B is in [1/2, 1].
        # (B is the denominator)

        for addr in range(self.table_addr_count):
            f_prime_m1 = self.table[addr]
            _assert_accuracy(0.5 <= f_prime_m1 <= 1)

        yield GoldschmidtDivOp.FEqTableLookup

        # we use Setting I (section 4.1 of the paper):
        # Require `n[i] <= n_hat` and `d[i] <= n_hat` and `f[i] = 0`:
        # the conditions on n_hat are satisfied by construction.
        for i in range(self.iter_count):
            _assert_accuracy(self.max_f(i) == 0)
            yield GoldschmidtDivOp.MulNByF
            if i != self.iter_count - 1:
                yield GoldschmidtDivOp.MulDByF
                yield GoldschmidtDivOp.FEq2MinusD

        # relative approximation error `p(N_prime[i])`:
        # `p(N_prime[i]) = (A / B - N_prime[i]) / (A / B)`
        # `0 <= p(N_prime[i])`
        # `p(N_prime[i]) <= (2 * i) * n_hat \`
        # ` + (abs(e[0]) + 3 * n_hat / 2) ** (2 ** i)`
        i = self.iter_count - 1  # last used `i`
        # compute power manually to prevent huge intermediate values
        power = self._shrink_max(self.max_abs_e0 + 3 * self.n_hat / 2)
        for _ in range(i):
            power = self._shrink_max(power * power)

        max_rel_error = (2 * i) * self.n_hat + power

        min_a_over_b = Fraction(1, 2)
        max_a_over_b = Fraction(2)
        max_allowed_abs_error = max_a_over_b / (1 << self.max_n_shift)
        max_allowed_rel_error = max_allowed_abs_error / min_a_over_b

        _assert_accuracy(max_rel_error < max_allowed_rel_error,
                         f"not accurate enough: max_rel_error={max_rel_error}"
                         f" max_allowed_rel_error={max_allowed_rel_error}")

        yield GoldschmidtDivOp.CalcResult

    @cache_on_self
    def default_cost_fn(self):
        """ calculate the estimated cost on an arbitrary scale of implementing
        goldschmidt division with the specified parameters. larger cost
        values mean worse parameters.

        This is the default cost function for `GoldschmidtDivParams.get`.

        returns: float
        """
        rom_cells = self.table_data_bits << self.table_addr_bits
        cost = float(rom_cells)
        for op in self.ops:
            if op in (GoldschmidtDivOp.MulNByF, GoldschmidtDivOp.MulDByF):
                mul_cost = self.expanded_width ** 2
                mul_cost *= self.expanded_width.bit_length()
                cost += mul_cost
        cost += 5e7 * self.iter_count
        return cost

    @staticmethod
    @lru_cache(maxsize=1 << 16)
    def __cached_new(base_params):
        """construct `GoldschmidtDivParams` from `base_params`, returning
        `(params, None)` on success or `(None, error)` if the parameters
        aren't accurate enough. results are memoized by `lru_cache`."""
        assert isinstance(base_params, GoldschmidtDivParamsBase)
        # can't use dataclasses.asdict, since it's recursive and will also give
        # child class fields too, which we don't want.
        kwargs = {base_field.name: getattr(base_params, base_field.name)
                  for base_field in fields(GoldschmidtDivParamsBase)}
        try:
            return GoldschmidtDivParams(**kwargs), None
        except ParamsNotAccurateEnough as e:
            return None, e

    @staticmethod
    def __raise(e):  # type: (ParamsNotAccurateEnough) -> Any
        raise e

    @staticmethod
    def cached_new(base_params, handle_error=None):
        """return the (memoized) `GoldschmidtDivParams` for `base_params`.

        arguments:
        base_params: GoldschmidtDivParamsBase
            the independent parameters.
        handle_error: Callable[[ParamsNotAccurateEnough], Any] | None
            called with the error when `base_params` aren't accurate enough;
            its return value is returned. If `None` (the default), the error
            is raised.
        """
        assert isinstance(base_params, GoldschmidtDivParamsBase)
        params, error = GoldschmidtDivParams.__cached_new(base_params)
        if error is None:
            return params
        if handle_error is None:
            # look the helper up through the class: the raw `staticmethod`
            # object visible while the class body executes isn't callable
            # before Python 3.10, so it can't be used as a default value.
            handle_error = GoldschmidtDivParams.__raise
        return handle_error(error)

    @staticmethod
    def get(io_width, cost_fn=default_cost_fn, max_table_addr_bits=12):
        """ find efficient parameters for a goldschmidt division algorithm
        with `params.io_width == io_width`.

        arguments:
        io_width: int
            bit-width of the input divisor and the result.
            the input numerator is `2 * io_width`-bits wide.
        cost_fn: Callable[[GoldschmidtDivParams], float]
            return the estimated cost on an arbitrary scale of implementing
            goldschmidt division with the specified parameters. larger cost
            values mean worse parameters.
        max_table_addr_bits: int
            maximum allowable value of `table_addr_bits`

        returns: GoldschmidtDivParams
        raises: ValueError
            if the large initial parameters the search starts from are
            themselves not accurate enough.
        """
        assert isinstance(io_width, int) and io_width >= 1
        assert callable(cost_fn)

        last_error = None
        last_error_params = None

        def cached_new(base_params):
            def handle_error(e):
                nonlocal last_error, last_error_params
                last_error = e
                last_error_params = base_params
                return None

            retval = GoldschmidtDivParams.cached_new(base_params, handle_error)
            # lazy %-formatting: the dataclass repr is only built when debug
            # logging is actually enabled
            if retval is None:
                logging.debug("GoldschmidtDivParams.get: err: %s",
                              base_params)
            else:
                logging.debug("GoldschmidtDivParams.get: ok: %s",
                              base_params)
            return retval

        @lru_cache(maxsize=None)
        def get_cost(base_params):
            params = cached_new(base_params)
            if params is None:
                return math.inf
            retval = cost_fn(params)
            logging.debug("GoldschmidtDivParams.get: cost=%s: %s",
                          retval, params)
            return retval

        # start with parameters big enough to always work.
        initial_extra_precision = io_width * 2 + 4
        initial_params = GoldschmidtDivParamsBase(
            io_width=io_width,
            extra_precision=initial_extra_precision,
            table_addr_bits=min(max_table_addr_bits, io_width),
            table_data_bits=io_width + initial_extra_precision,
            iter_count=1 + io_width.bit_length())

        if cached_new(initial_params) is None:
            raise ValueError(f"initial goldschmidt division algorithm "
                             f"parameters are invalid: {initial_params}"
                             ) from last_error

        # find good initial `iter_count`
        params = initial_params
        for iter_count in range(1, initial_params.iter_count):
            trial_params = replace(params, iter_count=iter_count)
            if cached_new(trial_params) is not None:
                params = trial_params
                break

        # now find `table_addr_bits`
        cost = get_cost(params)
        for table_addr_bits in range(1, max_table_addr_bits):
            trial_params = replace(params, table_addr_bits=table_addr_bits)
            trial_cost = get_cost(trial_params)
            if trial_cost < cost:
                params = trial_params
                cost = trial_cost
                break

        # check one higher `iter_count` to see if it has lower cost
        for table_addr_bits in range(1, max_table_addr_bits + 1):
            trial_params = replace(params,
                                   table_addr_bits=table_addr_bits,
                                   iter_count=params.iter_count + 1)
            trial_cost = get_cost(trial_params)
            if trial_cost < cost:
                params = trial_params
                cost = trial_cost
                break

        # now shrink `table_data_bits`
        while True:
            trial_params = replace(params,
                                   table_data_bits=params.table_data_bits - 1)
            trial_cost = get_cost(trial_params)
            if trial_cost < cost:
                params = trial_params
                cost = trial_cost
            else:
                break

        # and shrink `extra_precision`
        while True:
            trial_params = replace(params,
                                   extra_precision=params.extra_precision - 1)
            trial_cost = get_cost(trial_params)
            if trial_cost < cost:
                params = trial_params
                cost = trial_cost
            else:
                break

        return cached_new(params)
880
881
@enum.unique
class GoldschmidtDivOp(enum.Enum):
    """a single step of the goldschmidt division algorithm.

    each member's value is a pseudo-code description of the step; `run`
    performs the step by updating a `GoldschmidtDivState` in place.
    """

    Normalize = "n, d, n_shift = normalize(n, d)"
    FEqTableLookup = "f = table_lookup(d)"
    MulNByF = "n *= f"
    MulDByF = "d *= f"
    FEq2MinusD = "f = 2 - d"
    CalcResult = "result = unnormalize_and_round(n)"

    def run(self, params, state):
        """perform this step on `state` (mutated in place) using `params`.

        - `Normalize`: doubles `n` and `d` together until `1 <= d`, then
          halves `n` until `n < 2`, counting the halvings in `n_shift`.
        - `FEqTableLookup`: sets `f` from `params.table`. The index is
          `d - 1` rounded down to `table_addr_bits` fractional bits.
        - `MulNByF` / `MulDByF`: multiply `n` / `d` by `f`, rounding down /
          up to `expanded_width` fractional bits.
        - `FEq2MinusD`: sets `f = 2 - d`.
        - `CalcResult`: scales `n` by `2 ** n_shift` and truncates it to get
          the quotient. If the remainder is still at least the original
          denominator, the quotient is bumped up by one.
        """
        assert isinstance(params, GoldschmidtDivParams)
        assert isinstance(state, GoldschmidtDivState)
        width = params.expanded_width
        addr_bits = params.table_addr_bits
        Op = GoldschmidtDivOp

        if self is Op.Normalize:
            # in hardware this is count-leading-zeros plus a left shift;
            # scaling both `n` and `d` leaves the quotient unchanged
            while state.d < 1:
                state.n = (state.n * 2).to_frac_wid(width)
                state.d = (state.d * 2).to_frac_wid(width)

            state.n_shift = 0
            # bring the numerator into [1, 2), remembering how far we moved
            while state.n >= 2:
                state.n = (state.n * 0.5).to_frac_wid(width)
                state.n_shift += 1
        elif self is Op.FEqTableLookup:
            index = (state.d - 1).to_frac_wid(addr_bits, RoundDir.DOWN).bits
            assert 0 <= index < (1 << params.table_addr_bits)
            state.f = params.table[index]
        elif self is Op.MulNByF:
            assert state.f is not None
            product = state.n * state.f
            state.n = product.to_frac_wid(width, round_dir=RoundDir.DOWN)
        elif self is Op.MulDByF:
            assert state.f is not None
            product = state.d * state.f
            state.d = product.to_frac_wid(width, round_dir=RoundDir.UP)
        elif self is Op.FEq2MinusD:
            state.f = (2 - state.d).to_frac_wid(width)
        elif self is Op.CalcResult:
            assert state.n_shift is not None
            # undo the numerator normalization, then truncate
            quotient = math.floor(state.n * (1 << state.n_shift))
            remainder = state.orig_n - quotient * state.orig_d
            # `n` was rounded down, so the quotient can be one too small
            if remainder >= state.orig_d:
                quotient += 1
                remainder -= state.orig_d
            state.quotient = quotient
            state.remainder = remainder
        else:
            assert False, f"unimplemented GoldschmidtDivOp: {self}"
936
937
def goldschmidt_div(n, d, params):
    """ Goldschmidt division algorithm.

    based on:
    Even, G., Seidel, P. M., & Ferguson, W. E. (2003).
    A Parametric Error Analysis of Goldschmidt's Division Algorithm.
    https://citeseerx.ist.psu.edu/viewdoc/download?doi=10.1.1.90.1238&rep=rep1&type=pdf

    arguments:
    n: int
        numerator. a `2*params.io_width`-bit unsigned integer.
        must be less than `d << params.io_width`, otherwise the quotient
        wouldn't fit in `params.io_width` bits.
    d: int
        denominator. a `params.io_width`-bit unsigned integer. must not be
        zero.
    params: GoldschmidtDivParams
        the algorithm parameters, e.g. as returned by
        `GoldschmidtDivParams.get`. `params.io_width` is the bit-width of
        the inputs/outputs.

    returns: tuple[int, int]
        the quotient and remainder. a tuple of two `params.io_width`-bit
        unsigned integers.
    """
    assert isinstance(params, GoldschmidtDivParams)
    assert isinstance(d, int) and 0 < d < (1 << params.io_width)
    assert isinstance(n, int) and 0 <= n < (d << params.io_width)

    # this whole algorithm is done with fixed-point arithmetic. The inputs
    # start with `params.io_width` fractional bits; scaling `n` and `d` by
    # the same factor leaves the quotient unchanged.
    state = GoldschmidtDivState(
        orig_n=n,
        orig_d=d,
        n=FixedPoint(n, params.io_width),
        d=FixedPoint(d, params.io_width),
    )

    for op in params.ops:
        op.run(params, state)

    assert state.quotient is not None
    assert state.remainder is not None

    return state.quotient, state.remainder