bdce6fd950e65157c1b2657deb127a58409f18b7
[soc.git] / src / soc / fu / div / experiment / goldschmidt_div_sqrt.py
1 # SPDX-License-Identifier: LGPL-3-or-later
2 # Copyright 2022 Jacob Lifshay programmerjake@gmail.com
3
4 # Funded by NLnet Assure Programme 2021-02-052, https://nlnet.nl/assure part
5 # of Horizon 2020 EU Programme 957073.
6
7 from dataclasses import dataclass, field
8 import math
9 import enum
10 from fractions import Fraction
11 from types import FunctionType
12
13 try:
14 from functools import cached_property
15 except ImportError:
16 from cached_property import cached_property
17
18 # fix broken IDE type detection for cached_property
19 from typing import TYPE_CHECKING
20 if TYPE_CHECKING:
21 from functools import cached_property
22
23
# sentinel used to tell "nothing cached yet" apart from a cached `None`
_NOT_FOUND = object()


def cache_on_self(func):
    """like `functools.cached_property`, except for methods. unlike
    `lru_cache` the cache is per-class instance rather than a global cache
    per-method.

    the cache is a dict stored in the instance's `__dict__` under the name
    `<method name>__cache`, keyed by the call's arguments, so all arguments
    must be hashable.

    arguments:
    func: FunctionType
        the plain function (method) to wrap.

    returns: FunctionType
        a wrapper with the same name, docstring and other metadata as
        `func`, that calls `func` at most once per instance and distinct
        set of arguments.
    """
    # imported locally so this decorator doesn't depend on the module-level
    # imports being changed
    from functools import wraps

    assert isinstance(func, FunctionType), \
        "non-plain methods are not supported"

    cache_name = func.__name__ + "__cache"

    @wraps(func)
    def wrapper(self, *args, **kwargs):
        # specifically access through `__dict__` to bypass frozen=True
        cache = self.__dict__.get(cache_name, _NOT_FOUND)
        if cache is _NOT_FOUND:
            self.__dict__[cache_name] = cache = {}
        # sort keyword arguments by name (names are unique, so the values
        # are never compared) so the order they were passed in doesn't
        # create separate cache entries for the same call
        key = (args, *sorted(kwargs.items()))
        retval = cache.get(key, _NOT_FOUND)
        if retval is _NOT_FOUND:
            retval = func(self, *args, **kwargs)
            cache[key] = retval
        return retval

    return wrapper
51
52
@enum.unique
class RoundDir(enum.Enum):
    """rounding modes accepted by `FixedPoint.with_frac_wid` and
    `FixedPoint.to_frac_wid`."""

    # round towards negative infinity
    DOWN = enum.auto()
    # round towards positive infinity
    UP = enum.auto()
    # round to the nearest representable value, ties go towards positive
    # infinity
    NEAREST_TIES_UP = enum.auto()
    # don't round -- raise `ValueError` if the value can't be represented
    # exactly
    ERROR_IF_INEXACT = enum.auto()


@dataclass(frozen=True)
class FixedPoint:
    """immutable binary fixed-point number with the value
    `bits / (1 << frac_wid)`.

    comparisons (and hashing) are by numeric value, so two instances with
    different `frac_wid` compare equal when they represent the same number.
    """

    bits: int
    frac_wid: int

    def __post_init__(self):
        assert isinstance(self.bits, int)
        assert isinstance(self.frac_wid, int) and self.frac_wid >= 0

    @staticmethod
    def cast(value):
        """convert `value` to a fixed-point number with enough fractional
        bits to preserve its value.

        accepts `FixedPoint` (returned as-is), `int`, `float`, and `str`.
        strings may have a leading sign; strings starting with `0x`/`0X` that
        contain a `.` are parsed as hexadecimal with a fractional part (`_`
        is ignored), all other strings are parsed by `int(value, base=0)`.

        raises `ValueError` for malformed strings and `TypeError` for
        unsupported types.
        """
        if isinstance(value, FixedPoint):
            return value
        if isinstance(value, int):
            return FixedPoint(value, 0)
        if isinstance(value, str):
            value = value.strip()
            neg = value.startswith("-")
            if neg or value.startswith("+"):
                value = value[1:]
            if value.startswith(("0x", "0X")) and "." in value:
                value = value[2:]
                got_dot = False
                bits = 0
                frac_wid = 0
                for digit in value:
                    if digit == "_":
                        continue
                    if got_dot:
                        if digit == ".":
                            raise ValueError("too many `.` in string")
                        # every hex digit after the `.` adds 4 fraction bits
                        frac_wid += 4
                    if digit == ".":
                        got_dot = True
                        continue
                    if not digit.isalnum():
                        raise ValueError("invalid hexadecimal digit")
                    bits <<= 4
                    # `int` raises ValueError for non-hexadecimal digits
                    bits |= int("0x" + digit, base=16)
            else:
                bits = int(value, base=0)
                frac_wid = 0
            if neg:
                bits = -bits
            return FixedPoint(bits, frac_wid)

        if isinstance(value, float):
            n, d = value.as_integer_ratio()
            log2_d = d.bit_length() - 1
            assert d == 1 << log2_d, ("d isn't a power of 2 -- won't ever "
                                      "fail with float being IEEE 754")
            return FixedPoint(n, log2_d)
        raise TypeError("can't convert type to FixedPoint")

    @staticmethod
    def with_frac_wid(value, frac_wid, round_dir=RoundDir.ERROR_IF_INEXACT):
        """convert `value` to the nearest fixed-point number with `frac_wid`
        fractional bits, rounding according to `round_dir`.

        `value` is either a `Fraction` or anything `FixedPoint.cast` accepts.
        raises `ValueError` if `round_dir` is `RoundDir.ERROR_IF_INEXACT` and
        the conversion isn't exact.
        """
        assert isinstance(frac_wid, int) and frac_wid >= 0
        assert isinstance(round_dir, RoundDir)
        if isinstance(value, Fraction):
            numerator = value.numerator
            denominator = value.denominator
        else:
            value = FixedPoint.cast(value)
            numerator = value.bits
            denominator = 1 << value.frac_wid
        if denominator < 0:
            numerator = -numerator
            denominator = -denominator
        # floor division, so `bits` is rounded down and `remainder >= 0`
        bits, remainder = divmod(numerator << frac_wid, denominator)
        if round_dir == RoundDir.DOWN:
            pass
        elif round_dir == RoundDir.UP:
            if remainder != 0:
                bits += 1
        elif round_dir == RoundDir.NEAREST_TIES_UP:
            if remainder * 2 >= denominator:
                bits += 1
        elif round_dir == RoundDir.ERROR_IF_INEXACT:
            if remainder != 0:
                raise ValueError("inexact conversion")
        else:
            assert False, "unimplemented round_dir"
        return FixedPoint(bits, frac_wid)

    def to_frac_wid(self, frac_wid, round_dir=RoundDir.ERROR_IF_INEXACT):
        """convert to the nearest fixed-point number with `frac_wid`
        fractional bits, rounding according to `round_dir`."""
        return FixedPoint.with_frac_wid(self, frac_wid, round_dir)

    def __float__(self):
        # use truediv to get correct result even when bits
        # and frac_wid are huge
        return float(self.bits / (1 << self.frac_wid))

    def as_fraction(self):
        """return the exact value of `self` as a `Fraction`."""
        return Fraction(self.bits, 1 << self.frac_wid)

    def cmp(self, rhs):
        """compare self with rhs, returning a positive integer if self is
        greater than rhs, zero if self is equal to rhs, and a negative integer
        if self is less than rhs."""
        rhs = FixedPoint.cast(rhs)
        # widening to the larger `frac_wid` is always exact
        common_frac_wid = max(self.frac_wid, rhs.frac_wid)
        lhs = self.to_frac_wid(common_frac_wid)
        rhs = rhs.to_frac_wid(common_frac_wid)
        return lhs.bits - rhs.bits

    def __eq__(self, rhs):
        return self.cmp(rhs) == 0

    def __ne__(self, rhs):
        return self.cmp(rhs) != 0

    def __gt__(self, rhs):
        return self.cmp(rhs) > 0

    def __lt__(self, rhs):
        return self.cmp(rhs) < 0

    def __ge__(self, rhs):
        return self.cmp(rhs) >= 0

    def __le__(self, rhs):
        return self.cmp(rhs) <= 0

    def __hash__(self):
        # `__eq__` compares by numeric value rather than by fields, so the
        # hash has to depend only on the numeric value too. without this,
        # `dataclass(frozen=True)` would generate a hash of
        # `(bits, frac_wid)`, giving equal values with different `frac_wid`
        # different hashes. hashing the `Fraction` also makes integral values
        # hash the same as the `int` they compare equal to.
        return hash(self.as_fraction())

    def fract(self):
        """return the fractional part of `self`.
        that is `self - math.floor(self)`.
        """
        fract_mask = (1 << self.frac_wid) - 1
        return FixedPoint(self.bits & fract_mask, self.frac_wid)

    def __str__(self):
        """format as hexadecimal with a fractional part, e.g. `0x1.8`."""
        if self < 0:
            return "-" + str(-self)
        digit_bits = 4
        # round the number of fraction bits up to a whole number of hex digits
        frac_digit_count = (self.frac_wid + digit_bits - 1) // digit_bits
        fract = self.fract().to_frac_wid(frac_digit_count * digit_bits)
        frac_str = hex(fract.bits)[2:].zfill(frac_digit_count)
        return hex(math.floor(self)) + "." + frac_str

    def __repr__(self):
        return f"FixedPoint.with_frac_wid({str(self)!r}, {self.frac_wid})"

    def __add__(self, rhs):
        rhs = FixedPoint.cast(rhs)
        common_frac_wid = max(self.frac_wid, rhs.frac_wid)
        lhs = self.to_frac_wid(common_frac_wid)
        rhs = rhs.to_frac_wid(common_frac_wid)
        return FixedPoint(lhs.bits + rhs.bits, common_frac_wid)

    def __radd__(self, lhs):
        # symmetric
        return self.__add__(lhs)

    def __neg__(self):
        return FixedPoint(-self.bits, self.frac_wid)

    def __sub__(self, rhs):
        rhs = FixedPoint.cast(rhs)
        common_frac_wid = max(self.frac_wid, rhs.frac_wid)
        lhs = self.to_frac_wid(common_frac_wid)
        rhs = rhs.to_frac_wid(common_frac_wid)
        return FixedPoint(lhs.bits - rhs.bits, common_frac_wid)

    def __rsub__(self, lhs):
        # a - b == -(b - a)
        return -self.__sub__(lhs)

    def __mul__(self, rhs):
        # exact: the result has the sum of both operands' fraction bits
        rhs = FixedPoint.cast(rhs)
        return FixedPoint(self.bits * rhs.bits, self.frac_wid + rhs.frac_wid)

    def __rmul__(self, lhs):
        # symmetric
        return self.__mul__(lhs)

    def __floor__(self):
        # arithmetic right shift rounds towards negative infinity
        return self.bits >> self.frac_wid
244
245
@dataclass
class GoldschmidtDivState:
    """mutable working state of one run of the Goldschmidt division
    algorithm. it is created by `goldschmidt_div` and updated in place by
    each `GoldschmidtDivOp.run` call.
    """

    orig_n: int
    """the numerator exactly as passed in by the caller"""

    orig_d: int
    """the denominator exactly as passed in by the caller"""

    n: FixedPoint
    """working numerator -- N_prime[i] in the paper's algorithm 2"""

    d: FixedPoint
    """working denominator -- D_prime[i] in the paper's algorithm 2"""

    f: "FixedPoint | None" = None
    """working factor -- F_prime[i] in the paper's algorithm 2.
    `None` until the table lookup has run.
    """

    quotient: "int | None" = None
    """the final quotient. `None` until the result has been calculated."""

    remainder: "int | None" = None
    """the final remainder. `None` until the result has been calculated."""

    n_shift: "int | None" = None
    """how far the numerator has to be left-shifted at the end of the
    algorithm. `None` until normalization has run.
    """
272 """
273
274
class ParamsNotAccurateEnough(Exception):
    """signals that a set of parameters isn't accurate enough for
    goldschmidt division to work."""


def _assert_accuracy(condition, msg="not accurate enough"):
    """raise `ParamsNotAccurateEnough(msg)` unless `condition` is true.

    unlike a plain `assert`, this isn't removed by `python -O` and raises an
    exception type callers can catch specifically.
    """
    if not condition:
        raise ParamsNotAccurateEnough(msg)
284
285
@dataclass(frozen=True, unsafe_hash=True)
class GoldschmidtDivParamsBase:
    """the independently-chosen parameters of a Goldschmidt division
    algorithm. anything derived from them is left out of this class.
    """

    io_width: int
    """width in bits of the input divisor and of the result.
    the input numerator is twice as wide: `2 * io_width` bits.
    """

    extra_precision: int
    """how many bits of precision the algorithm uses internally in addition
    to `io_width`."""

    table_addr_bits: int
    """width in bits of a lookup-table address."""

    table_data_bits: int
    """width in bits of a lookup-table entry."""

    iter_count: int
    """how many times in total the division algorithm's loop runs"""
307 """the total number of iterations of the division algorithm's loop"""
308
309
@dataclass(frozen=True, unsafe_hash=True)
class GoldschmidtDivParams(GoldschmidtDivParamsBase):
    """full parameter set for a Goldschmidt division algorithm: the fields
    of `GoldschmidtDivParamsBase` plus the lookup-table and operation list
    derived from them. constructing an instance raises
    `ParamsNotAccurateEnough` if the parameters fail the accuracy checks.
    Use `GoldschmidtDivParams.get` to find a efficient set of parameters.
    """

    # tuple to be immutable, default so repr() works for debugging even when
    # __post_init__ hasn't finished running yet
    table: "tuple[FixedPoint, ...]" = field(init=False,
                                            default=NotImplemented)
    """the lookup-table, one entry per address"""

    ops: "tuple[GoldschmidtDivOp, ...]" = field(init=False,
                                                default=NotImplemented)
    """the sequence of operations that performs the goldschmidt division
    algorithm."""

    def _shrink_bound(self, bound, round_dir):
        """round `bound` to a `FixedPoint` with `io_width * 4 + 100`
        fraction bits and convert it back to a `Fraction`, so fractions
        don't accumulate huge numerators/denominators.

        This is intended only for values used to compute bounds, and not for
        values that end up in the hardware.
        """
        assert isinstance(bound, (Fraction, int))
        assert round_dir in (RoundDir.DOWN, RoundDir.UP), \
            "you shouldn't use that round_dir on bounds"
        precision = self.io_width * 4 + 100  # should be enough precision
        rounded = FixedPoint.with_frac_wid(bound, precision, round_dir)
        return rounded.as_fraction()

    def _shrink_min(self, min_bound):
        """`_shrink_bound` for lower bounds: rounds down so the result is
        still a valid lower bound.

        This is intended only for values used to compute bounds, and not for
        values that end up in the hardware.
        """
        return self._shrink_bound(min_bound, RoundDir.DOWN)

    def _shrink_max(self, max_bound):
        """`_shrink_bound` for upper bounds: rounds up so the result is
        still a valid upper bound.

        This is intended only for values used to compute bounds, and not for
        values that end up in the hardware.
        """
        return self._shrink_bound(max_bound, RoundDir.UP)

    @property
    def table_addr_count(self):
        """how many distinct addresses the lookup-table has."""
        # used while computing self.table, so can't just do len(self.table)
        return 1 << self.table_addr_bits

    def table_input_exact_range(self, addr):
        """return `(min_input, max_input)`, the inclusive range of inputs
        (as `Fraction`s in `[1, 2)`) that map to the table entry at address
        `addr`.

        raises `ParamsNotAccurateEnough` if the table has more address bits
        than `io_width`.
        """
        assert isinstance(addr, int)
        assert 0 <= addr < self.table_addr_count
        _assert_accuracy(self.io_width >= self.table_addr_bits)
        addr_shift = self.io_width - self.table_addr_bits
        # `1 << io_width` is both the numerator of 1.0 and the denominator
        # shared by all inputs
        one = 1 << self.io_width
        lowest = one + (addr << addr_shift)
        # every table entry covers `1 << addr_shift` consecutive inputs
        highest = lowest + (1 << addr_shift) - 1
        min_input = self._shrink_min(Fraction(lowest, one))
        max_input = self._shrink_max(Fraction(highest, one))
        assert 1 <= min_input <= max_input < 2
        return min_input, max_input

    def table_value_exact_range(self, addr):
        """return `(min_value, max_value)`, the range of reciprocals (as
        `Fraction`s) of the inputs that map to the table entry at address
        `addr`."""
        min_input, max_input = self.table_input_exact_range(addr)
        # division swaps min/max
        min_value = self._shrink_min(1 / max_input)
        max_value = self._shrink_max(1 / min_input)
        assert 0.5 < min_value <= max_value <= 1
        return min_value, max_value

    def table_exact_value(self, index):
        """return the not-yet-quantized value for the table entry at address
        `index`: the lower end of `table_value_exact_range(index)`."""
        # we round down
        return self.table_value_exact_range(index)[0]

    def __post_init__(self):
        # called by the autogenerated __init__
        assert self.io_width >= 1
        assert self.extra_precision >= 0
        assert self.table_addr_bits >= 1
        assert self.table_data_bits >= 1
        assert self.iter_count >= 1
        table = tuple(
            FixedPoint.with_frac_wid(self.table_exact_value(addr),
                                     self.table_data_bits,
                                     RoundDir.DOWN)
            for addr in range(1 << self.table_addr_bits))
        # we have to use object.__setattr__ since frozen=True
        object.__setattr__(self, "table", table)
        # `table` has to be set first -- the checks done while generating
        # `ops` read it
        object.__setattr__(self, "ops", tuple(self.__make_ops()))

    @property
    def expanded_width(self):
        """total bits of precision used inside the algorithm:
        `io_width + extra_precision`."""
        return self.io_width + self.extra_precision

    @cache_on_self
    def max_neps(self, i):
        """upper bound on `neps[i]`: `2 ** -expanded_width` for every `i`.
        `neps[i]` is defined to be `n[i] * N_prime[i - 1] * F_prime[i - 1]`.
        """
        assert isinstance(i, int) and 0 <= i < self.iter_count
        return Fraction(1, 1 << self.expanded_width)

    @cache_on_self
    def max_deps(self, i):
        """upper bound on `deps[i]`: `2 ** -expanded_width` for every `i`.
        `deps[i]` is defined to be `d[i] * D_prime[i - 1] * F_prime[i - 1]`.
        """
        assert isinstance(i, int) and 0 <= i < self.iter_count
        return Fraction(1, 1 << self.expanded_width)

    @cache_on_self
    def max_feps(self, i):
        """upper bound on `feps[i]`, which is always zero.
        `feps[i]` is defined to be `f[i] * (2 - D_prime[i - 1])`.
        """
        assert isinstance(i, int) and 0 <= i < self.iter_count
        # zero, because the computation of `F_prime[i]` in
        # `GoldschmidtDivOp.MulDByF.run(...)` is exact.
        return Fraction(0)

    @cached_property
    def e0_range(self):
        """`(min_e0, max_e0)`: bounds on `e[0]` (the relative error in
        `F_prime[-1]`) over all table entries. the range always includes
        zero.
        """
        lowest = Fraction(0)
        highest = Fraction(0)
        for addr in range(self.table_addr_count):
            # `F_prime[-1] = (1 - e[0]) / B`
            # => `e[0] = 1 - B * F_prime[-1]`
            min_b, max_b = self.table_input_exact_range(addr)
            f_prime_m1 = self.table[addr].as_fraction()
            assert min_b >= 0 and f_prime_m1 >= 0, \
                "only positive quadrant of interval multiplication implemented"
            # negation swaps min/max, so the largest product gives the
            # smallest `e[0]` and vice versa
            lowest = min(lowest, 1 - max_b * f_prime_m1)
            highest = max(highest, 1 - min_b * f_prime_m1)
        return self._shrink_min(lowest), self._shrink_max(highest)

    @cached_property
    def min_e0(self):
        """lower bound on `e[0]` (the relative error in `F_prime[-1]`)
        """
        return self.e0_range[0]

    @cached_property
    def max_e0(self):
        """upper bound on `e[0]` (the relative error in `F_prime[-1]`)
        """
        return self.e0_range[1]

    @cached_property
    def max_abs_e0(self):
        """upper bound on `abs(e[0])`."""
        return max(abs(self.min_e0), abs(self.max_e0))

    @cached_property
    def min_abs_e0(self):
        """lower bound on `abs(e[0])`, which is always zero."""
        return Fraction(0)

    @cache_on_self
    def max_n(self, i):
        """upper bound on `n[i]` (the relative error in `N_prime[i]`
        relative to the previous iteration)
        """
        assert isinstance(i, int) and 0 <= i < self.iter_count
        if i == 0:
            # from Claim 10
            # `n[0] = neps[0] / ((1 - e[0]) * (A / B))`
            # `n[0] <= 2 * neps[0] / (1 - e[0])`
            max_e0 = self.max_e0
            max_neps = self.max_neps(0)
            assert max_e0 < 1 and max_neps >= 0, \
                "only one quadrant of interval division implemented"
            bound = 2 * max_neps / (1 - max_e0)
        elif i == 1:
            # from Claim 10
            # `n[1] <= neps[1] / ((1 - f[0]) * (1 - pi[0] - delta[0]))`
            min_mpd = 1 - self.max_pi(0) - self.max_delta(0)
            max_f0 = self.max_f(0)
            assert max_f0 <= 1 and min_mpd >= 0, \
                "only one quadrant of interval multiplication implemented"
            divisor = (1 - max_f0) * min_mpd
            max_neps = self.max_neps(1)
            assert max_neps >= 0 and divisor > 0, \
                "only one quadrant of interval division implemented"
            bound = max_neps / divisor
        else:
            # from Claim 6
            # `0 <= n[i] <= 2 * max_neps[i] / (1 - pi[i - 1] - delta[i - 1])`
            # NOTE(review): the bound quoted above has a factor of 2 that
            # the expression below doesn't apply -- confirm against the paper
            min_mpd = 1 - self.max_pi(i - 1) - self.max_delta(i - 1)
            max_neps = self.max_neps(i)
            assert max_neps >= 0 and min_mpd > 0, \
                "only one quadrant of interval division implemented"
            bound = max_neps / min_mpd

        return self._shrink_max(bound)

    @cache_on_self
    def max_d(self, i):
        """upper bound on `d[i]` (the relative error in `D_prime[i]`
        relative to the previous iteration)
        """
        assert isinstance(i, int) and 0 <= i < self.iter_count
        if i == 0:
            # from Claim 10
            # `d[0] = deps[0] / (1 - e[0])`
            max_e0 = self.max_e0
            max_deps = self.max_deps(0)
            assert max_e0 < 1 and max_deps >= 0, \
                "only one quadrant of interval division implemented"
            bound = max_deps / (1 - max_e0)
        elif i == 1:
            # from Claim 10
            # `d[1] <= deps[1] / ((1 - f[0]) * (1 - delta[0] ** 2))`
            max_f0 = self.max_f(0)
            max_delta0 = self.max_delta(0)
            assert max_f0 <= 1 and max_delta0 <= 1, \
                "only one quadrant of interval multiplication implemented"
            divisor = (1 - max_f0) * (1 - max_delta0 ** 2)
            max_deps = self.max_deps(1)
            assert max_deps >= 0 and divisor > 0, \
                "only one quadrant of interval division implemented"
            bound = max_deps / divisor
        else:
            # from Claim 6
            # `0 <= d[i] <= max_deps[i] / (1 - delta[i - 1])`
            max_deps = self.max_deps(i)
            prev_max_delta = self.max_delta(i - 1)
            assert max_deps >= 0 and prev_max_delta < 1, \
                "only one quadrant of interval division implemented"
            bound = max_deps / (1 - prev_max_delta)

        return self._shrink_max(bound)

    @cache_on_self
    def max_f(self, i):
        """upper bound on `f[i]` (the relative error in `F_prime[i]`
        relative to the previous iteration)
        """
        assert isinstance(i, int) and 0 <= i < self.iter_count
        if i == 0:
            # from Claim 10
            # `f[0] = feps[0] / (1 - delta[0])`
            max_delta0 = self.max_delta(0)
            max_feps = self.max_feps(0)
            assert max_delta0 < 1 and max_feps >= 0, \
                "only one quadrant of interval division implemented"
            bound = max_feps / (1 - max_delta0)
        else:
            # for `i == 1`, from Claim 10: `f[1] = feps[1]`
            # for `i >= 2`, from Claim 6: `f[i] <= max_feps[i]`
            bound = self.max_feps(i)

        return self._shrink_max(bound)

    @cache_on_self
    def max_delta(self, i):
        """ upper bound on `delta[i]`.
        `delta[i]` is defined in Definition 4 of paper.

        raises `ParamsNotAccurateEnough` if the bound isn't less than one.
        """
        assert isinstance(i, int) and 0 <= i < self.iter_count
        if i == 0:
            # `delta[0] = abs(e[0]) + 3 * d[0] / 2`
            bound = self.max_abs_e0 + Fraction(3, 2) * self.max_d(0)
        else:
            # `delta[i] = delta[i - 1] ** 2 + f[i - 1]`
            prev_max_delta = self.max_delta(i - 1)
            assert prev_max_delta >= 0
            bound = prev_max_delta ** 2 + self.max_f(i - 1)

        # `delta[i]` has to be smaller than one otherwise errors would go off
        # to infinity
        _assert_accuracy(bound < 1)

        return self._shrink_max(bound)

    @cache_on_self
    def max_pi(self, i):
        """ upper bound on `pi[i]`.
        `pi[i]` is defined right below Theorem 5 of paper.
        """
        assert isinstance(i, int) and 0 <= i < self.iter_count
        # `pi[i] = 1 - (1 - n[i]) * prod`
        # where `prod` is the product of,
        # for `j` in `0 <= j < i`, `(1 - n[j]) / (1 + d[j])`
        #
        # the largest `pi[i]` comes from the smallest `prod`
        min_prod = Fraction(1)
        for j in range(i):
            max_n_j = self.max_n(j)
            max_d_j = self.max_d(j)
            assert max_n_j <= 1 and max_d_j > -1, \
                "only one quadrant of interval division implemented"
            min_prod *= (1 - max_n_j) / (1 + max_d_j)
        max_n_i = self.max_n(i)
        assert max_n_i <= 1 and min_prod >= 0, \
            "only one quadrant of interval multiplication implemented"
        return self._shrink_max(1 - (1 - max_n_i) * min_prod)

    @cached_property
    def max_n_shift(self):
        """ upper bound on `state.n_shift`: the number of right-shifts needed
        to bring the largest possible input numerator into `1 <= n < 2`.
        """
        # input numerator is `2*io_width`-bits
        max_n = (1 << (self.io_width * 2)) - 1
        # shifting right until only the top set bit is left takes
        # `bit_length() - 1` steps (and no steps at all for zero)
        return max(max_n.bit_length() - 1, 0)

    @cached_property
    def n_hat(self):
        """ the largest of `max_n(i)` and `max_d(i)` over all iterations
        `i`
        """
        largest = Fraction(0)
        for i in range(self.iter_count):
            largest = max(largest, self.max_n(i), self.max_d(i))
        return self._shrink_max(largest)

    def __make_ops(self):
        """ generate the operations of the Goldschmidt division algorithm,
        checking along the way that the parameters are accurate enough.

        based on:
        Even, G., Seidel, P. M., & Ferguson, W. E. (2003).
        A Parametric Error Analysis of Goldschmidt's Division Algorithm.
        https://citeseerx.ist.psu.edu/viewdoc/download?doi=10.1.1.90.1238&rep=rep1&type=pdf

        yields: GoldschmidtDivOp
            the operations needed to perform the division.

        raises: ParamsNotAccurateEnough
            if any of the accuracy checks fail.
        """
        # establish assumptions of the paper's error analysis (section 3.1):

        # 1. normalize so A (numerator) and B (denominator) are in [1, 2)
        yield GoldschmidtDivOp.Normalize

        # 2. ensure all relative errors from directed rounding are <= 1 / 4.
        # the assumption is met by multipliers with > 4-bits precision
        _assert_accuracy(self.expanded_width > 4)

        # 3. require `abs(e[0]) + 3 * d[0] / 2 + f[0] < 1 / 2`.
        _assert_accuracy(self.max_abs_e0 + 3 * self.max_d(0) / 2
                         + self.max_f(0) < Fraction(1, 2))

        # 4. the initial approximation F'[-1] of 1/B is in [1/2, 1].
        # (B is the denominator)

        for addr in range(self.table_addr_count):
            f_prime_m1 = self.table[addr]
            _assert_accuracy(0.5 <= f_prime_m1 <= 1)

        yield GoldschmidtDivOp.FEqTableLookup

        # we use Setting I (section 4.1 of the paper):
        # Require `n[i] <= n_hat` and `d[i] <= n_hat` and `f[i] = 0`:
        # the conditions on n_hat are satisfied by construction.
        last_i = self.iter_count - 1
        for i in range(self.iter_count):
            _assert_accuracy(self.max_f(i) == 0)
            yield GoldschmidtDivOp.MulNByF
            # the last iteration's `d` and `f` are never used, so skip them
            if i != last_i:
                yield GoldschmidtDivOp.MulDByF
                yield GoldschmidtDivOp.FEq2MinusD

        # relative approximation error `p(N_prime[i])`:
        # `p(N_prime[i]) = (A / B - N_prime[i]) / (A / B)`
        # `0 <= p(N_prime[i])`
        # `p(N_prime[i]) <= (2 * i) * n_hat \`
        # ` + (abs(e[0]) + 3 * n_hat / 2) ** (2 ** i)`
        # where `i` is the last used `i`, that is `last_i`

        # compute power manually to prevent huge intermediate values
        power = self._shrink_max(self.max_abs_e0 + 3 * self.n_hat / 2)
        for _ in range(last_i):
            power = self._shrink_max(power * power)

        max_rel_error = (2 * last_i) * self.n_hat + power

        min_a_over_b = Fraction(1, 2)
        max_a_over_b = Fraction(2)
        max_allowed_abs_error = max_a_over_b / (1 << self.max_n_shift)
        max_allowed_rel_error = max_allowed_abs_error / min_a_over_b

        _assert_accuracy(max_rel_error < max_allowed_rel_error,
                         f"not accurate enough: max_rel_error={max_rel_error}"
                         f" max_allowed_rel_error={max_allowed_rel_error}")

        yield GoldschmidtDivOp.CalcResult

    def default_cost_fn(self):
        """ estimate, on an arbitrary scale, how expensive it is to
        implement goldschmidt division with these parameters. larger cost
        values mean worse parameters.

        the estimate is the number of ROM cells in the lookup-table, plus a
        cost for every multiplication in `self.ops`, plus a large penalty
        for every iteration.

        This is described as the default cost function for
        `GoldschmidtDivParams.get`.

        returns: float
        """
        cost = float(self.table_data_bits << self.table_addr_bits)
        multiplies = (GoldschmidtDivOp.MulNByF, GoldschmidtDivOp.MulDByF)
        for op in self.ops:
            if op in multiplies:
                width = self.expanded_width
                cost += width ** 2 * width.bit_length()
        cost += 1e6 * self.iter_count
        return cost

    @staticmethod
    def get(io_width):
        """ find efficient parameters for a goldschmidt division algorithm
        with `params.io_width == io_width`.

        candidates are tried in order of increasing `extra_precision`, then
        `table_addr_bits`, then `iter_count`, and the first candidate that
        passes the accuracy checks is returned.

        raises: ValueError
            if none of the candidates are accurate enough.
        """
        assert isinstance(io_width, int) and io_width >= 1
        last_params = None
        last_error = None
        for extra_precision in range(io_width * 2 + 4):
            table_data_bits = io_width + extra_precision
            for table_addr_bits in range(1, 7 + 1):
                for iter_count in range(1, 2 * io_width.bit_length()):
                    candidate = dict(io_width=io_width,
                                     extra_precision=extra_precision,
                                     table_addr_bits=table_addr_bits,
                                     table_data_bits=table_data_bits,
                                     iter_count=iter_count)
                    try:
                        return GoldschmidtDivParams(**candidate)
                    except ParamsNotAccurateEnough as e:
                        # remember the most recent failure for the final
                        # error message
                        fields_text = ", ".join(
                            f"{name}={value!r}"
                            for name, value in candidate.items())
                        last_params = f"GoldschmidtDivParams({fields_text})"
                        last_error = e
        raise ValueError(f"can't find working parameters for a goldschmidt "
                         f"division algorithm: last params: {last_params}"
                         ) from last_error
767
768
@enum.unique
class GoldschmidtDivOp(enum.Enum):
    """the individual steps of the goldschmidt division algorithm. each
    member's value is a pseudo-code description of what `run` does for it.
    """

    Normalize = "n, d, n_shift = normalize(n, d)"
    FEqTableLookup = "f = table_lookup(d)"
    MulNByF = "n *= f"
    MulDByF = "d *= f"
    FEq2MinusD = "f = 2 - d"
    CalcResult = "result = unnormalize_and_round(n)"

    def run(self, params, state):
        """perform this operation, updating `state` in place.

        arguments:
        params: GoldschmidtDivParams
            the parameters of the algorithm.
        state: GoldschmidtDivState
            the state to read from and write to.
        """
        assert isinstance(params, GoldschmidtDivParams)
        assert isinstance(state, GoldschmidtDivState)
        expanded_width = params.expanded_width
        table_addr_bits = params.table_addr_bits

        if self is GoldschmidtDivOp.Normalize:
            # normalize so 1 <= d < 2
            # can easily be done with count-leading-zeros and left shift
            while state.d < 1:
                state.n = (state.n * 2).to_frac_wid(expanded_width)
                state.d = (state.d * 2).to_frac_wid(expanded_width)

            # normalize so 1 <= n < 2, counting the shifts so `CalcResult`
            # can undo them
            state.n_shift = 0
            while state.n >= 2:
                state.n = (state.n * 0.5).to_frac_wid(expanded_width)
                state.n_shift += 1
            return

        if self is GoldschmidtDivOp.FEqTableLookup:
            # compute initial f by table lookup: the address is the top
            # `table_addr_bits` fraction bits of `d`
            addr = (state.d - 1).to_frac_wid(table_addr_bits,
                                             RoundDir.DOWN).bits
            assert 0 <= addr < (1 << params.table_addr_bits)
            state.f = params.table[addr]
            return

        if self is GoldschmidtDivOp.MulNByF:
            assert state.f is not None
            product = state.n * state.f
            state.n = product.to_frac_wid(expanded_width,
                                          round_dir=RoundDir.DOWN)
            return

        if self is GoldschmidtDivOp.MulDByF:
            assert state.f is not None
            product = state.d * state.f
            state.d = product.to_frac_wid(expanded_width,
                                          round_dir=RoundDir.UP)
            return

        if self is GoldschmidtDivOp.FEq2MinusD:
            state.f = (2 - state.d).to_frac_wid(expanded_width)
            return

        if self is GoldschmidtDivOp.CalcResult:
            assert state.n_shift is not None
            # scale to correct value
            scaled_n = state.n * (1 << state.n_shift)

            quotient = math.floor(scaled_n)
            remainder = state.orig_n - quotient * state.orig_d
            # the approximate quotient may be one too small
            if remainder >= state.orig_d:
                quotient += 1
                remainder -= state.orig_d
            state.quotient = quotient
            state.remainder = remainder
            return

        assert False, f"unimplemented GoldschmidtDivOp: {self}"
821 else:
822 assert False, f"unimplemented GoldschmidtDivOp: {self}"
823
824
def goldschmidt_div(n, d, params):
    """ divide `n` by `d` using the Goldschmidt division algorithm.

    based on:
    Even, G., Seidel, P. M., & Ferguson, W. E. (2003).
    A Parametric Error Analysis of Goldschmidt's Division Algorithm.
    https://citeseerx.ist.psu.edu/viewdoc/download?doi=10.1.1.90.1238&rep=rep1&type=pdf

    arguments:
    n: int
        numerator. a `2 * params.io_width`-bit unsigned integer.
        must be less than `d << params.io_width`, otherwise the quotient
        wouldn't fit in `params.io_width` bits.
    d: int
        denominator. a `params.io_width`-bit unsigned integer. must not be
        zero.
    params: GoldschmidtDivParams
        the parameters of the algorithm, including the bit-width of the
        inputs/outputs and the operations to run.

    returns: tuple[int, int]
        the quotient and remainder.
    """
    assert isinstance(params, GoldschmidtDivParams)
    io_width = params.io_width
    assert isinstance(d, int) and 0 < d < (1 << io_width)
    assert isinstance(n, int) and 0 <= n < (d << io_width)

    # the inputs are reinterpreted as fixed-point values with `io_width`
    # fractional bits; the operations then work in fixed-point arithmetic
    state = GoldschmidtDivState(
        orig_n=n,
        orig_d=d,
        n=FixedPoint(n, io_width),
        d=FixedPoint(d, io_width),
    )

    # each operation reads and updates `state` in place
    for op in params.ops:
        op.run(params, state)

    quotient, remainder = state.quotient, state.remainder
    assert quotient is not None
    assert remainder is not None

    return quotient, remainder
866
867 return state.quotient, state.remainder