Goldschmidt division works! It still needs better parameter selection, though.
[soc.git] / src / soc / fu / div / experiment / goldschmidt_div_sqrt.py
1 # SPDX-License-Identifier: LGPL-3-or-later
2 # Copyright 2022 Jacob Lifshay programmerjake@gmail.com
3
4 # Funded by NLnet Assure Programme 2021-02-052, https://nlnet.nl/assure part
5 # of Horizon 2020 EU Programme 957073.
6
7 from dataclasses import dataclass, field
8 import math
9 import enum
10 from fractions import Fraction
11 from types import FunctionType
12
13 try:
14 from functools import cached_property
15 except ImportError:
16 from cached_property import cached_property
17
18 # fix broken IDE type detection for cached_property
19 from typing import TYPE_CHECKING
20 if TYPE_CHECKING:
21 from functools import cached_property
22
23
# sentinel returned by `dict.get` for a missing entry; unlike `None` it can't
# collide with a real cached return value
_NOT_FOUND = object()


def cache_on_self(func):
    """like `functools.cached_property`, except for methods. unlike
    `lru_cache` the cache is per-class instance rather than a global cache
    per-method.

    The cache is a dict stored in the instance's `__dict__` under
    `func.__name__ + "__cache"`, keyed by the call's arguments, so all
    arguments must be hashable. Keyword arguments are keyed independently of
    the order they are passed in.
    """
    import functools

    assert isinstance(func, FunctionType), \
        "non-plain methods are not supported"

    cache_name = func.__name__ + "__cache"

    @functools.wraps(func)  # keep __name__/__qualname__/__doc__/__wrapped__
    def wrapper(self, *args, **kwargs):
        # specifically access through `__dict__` to bypass frozen=True
        cache = self.__dict__.get(cache_name, _NOT_FOUND)
        if cache is _NOT_FOUND:
            self.__dict__[cache_name] = cache = {}
        # sort keyword arguments so `f(a=1, b=2)` and `f(b=2, a=1)` share a
        # cache entry; keyword names are unique so values are never compared
        key = (args, *sorted(kwargs.items()))
        retval = cache.get(key, _NOT_FOUND)
        if retval is _NOT_FOUND:
            retval = func(self, *args, **kwargs)
            cache[key] = retval
        return retval

    return wrapper
51
52
@enum.unique
class RoundDir(enum.Enum):
    """how to round when converting a value to a `FixedPoint` with a given
    number of fractional bits (see `FixedPoint.with_frac_wid`)."""

    # round toward negative infinity (floor)
    DOWN = 1
    # round toward positive infinity (ceiling)
    UP = 2
    # round to the nearest representable value; exact halfway cases round
    # toward positive infinity
    NEAREST_TIES_UP = 3
    # raise `ValueError` unless the conversion is exact
    ERROR_IF_INEXACT = 4
59
60
@dataclass(frozen=True)
class FixedPoint:
    """an exact fixed-point number whose value is `bits / 2 ** frac_wid`.

    equality, ordering, and hashing are by numeric value, so e.g.
    `FixedPoint(1, 0) == FixedPoint(2, 1) == 1` and all three hash the same.
    operands of comparisons and arithmetic are converted with
    `FixedPoint.cast` (comparisons additionally accept `Fraction`s);
    unsupported operand types make the operators return `NotImplemented`.
    """

    # the value scaled by `2 ** frac_wid`
    bits: int
    # number of fractional bits; must be non-negative
    frac_wid: int

    def __post_init__(self):
        assert isinstance(self.bits, int)
        assert isinstance(self.frac_wid, int) and self.frac_wid >= 0

    @staticmethod
    def cast(value):
        """convert `value` to a fixed-point number with enough fractional
        bits to preserve its value.

        accepts `FixedPoint` (returned unchanged), `int`, `float`, and `str`.
        strings are either anything `int(s, base=0)` accepts, or hexadecimal
        with a `.`, like `"-0x1.8"`.

        raises `TypeError` for unsupported types and `ValueError` for
        malformed strings.
        """
        if isinstance(value, FixedPoint):
            return value
        if isinstance(value, int):
            return FixedPoint(value, 0)
        if isinstance(value, str):
            value = value.strip()
            neg = value.startswith("-")
            if neg or value.startswith("+"):
                value = value[1:]
            if value.startswith(("0x", "0X")) and "." in value:
                value = value[2:]
                got_dot = False
                bits = 0
                frac_wid = 0
                for digit in value:
                    if digit == "_":
                        continue
                    if got_dot:
                        if digit == ".":
                            raise ValueError("too many `.` in string")
                        # every hex digit after the `.` is 4 fractional bits
                        frac_wid += 4
                    if digit == ".":
                        got_dot = True
                        continue
                    if not digit.isalnum():
                        raise ValueError("invalid hexadecimal digit")
                    bits <<= 4
                    bits |= int("0x" + digit, base=16)
            else:
                bits = int(value, base=0)
                frac_wid = 0
            if neg:
                bits = -bits
            return FixedPoint(bits, frac_wid)

        if isinstance(value, float):
            n, d = value.as_integer_ratio()
            log2_d = d.bit_length() - 1
            assert d == 1 << log2_d, ("d isn't a power of 2 -- won't ever "
                                      "fail with float being IEEE 754")
            return FixedPoint(n, log2_d)
        raise TypeError("can't convert type to FixedPoint")

    @staticmethod
    def _cast_operand(value):
        """like `FixedPoint.cast`, except it returns `NotImplemented` rather
        than raising `TypeError` for unsupported types, so operator methods
        let Python try the other operand's reflected method."""
        try:
            return FixedPoint.cast(value)
        except TypeError:
            return NotImplemented

    @staticmethod
    def with_frac_wid(value, frac_wid, round_dir=RoundDir.ERROR_IF_INEXACT):
        """convert `value` to the nearest fixed-point number with `frac_wid`
        fractional bits, rounding according to `round_dir`.

        `value` can be a `Fraction` or anything `FixedPoint.cast` accepts.
        raises `ValueError` if `round_dir` is `RoundDir.ERROR_IF_INEXACT` and
        the conversion isn't exact.
        """
        assert isinstance(frac_wid, int) and frac_wid >= 0
        assert isinstance(round_dir, RoundDir)
        if isinstance(value, Fraction):
            numerator = value.numerator
            denominator = value.denominator
        else:
            value = FixedPoint.cast(value)
            numerator = value.bits
            denominator = 1 << value.frac_wid
        # `denominator` is always positive here (`Fraction` keeps the sign in
        # its numerator), so `divmod` rounds toward negative infinity and
        # `0 <= remainder < denominator`.
        bits, remainder = divmod(numerator << frac_wid, denominator)
        if round_dir == RoundDir.DOWN:
            pass
        elif round_dir == RoundDir.UP:
            if remainder != 0:
                bits += 1
        elif round_dir == RoundDir.NEAREST_TIES_UP:
            if remainder * 2 >= denominator:
                bits += 1
        elif round_dir == RoundDir.ERROR_IF_INEXACT:
            if remainder != 0:
                raise ValueError("inexact conversion")
        else:
            assert False, "unimplemented round_dir"
        return FixedPoint(bits, frac_wid)

    def to_frac_wid(self, frac_wid, round_dir=RoundDir.ERROR_IF_INEXACT):
        """convert to the nearest fixed-point number with `frac_wid`
        fractional bits, rounding according to `round_dir`."""
        return FixedPoint.with_frac_wid(self, frac_wid, round_dir)

    def __float__(self):
        # use truediv to get correct result even when bits
        # and frac_wid are huge
        return float(self.bits / (1 << self.frac_wid))

    def as_fraction(self):
        """return the exact value as a `Fraction`."""
        return Fraction(self.bits, 1 << self.frac_wid)

    def cmp(self, rhs):
        """compare self with rhs, returning a positive integer if self is
        greater than rhs, zero if self is equal to rhs, and a negative integer
        if self is less than rhs.

        `rhs` can be a `Fraction` or anything `FixedPoint.cast` accepts.
        """
        if isinstance(rhs, Fraction):
            # a `Fraction` may not be representable as a `FixedPoint`, so
            # compare exactly using rational arithmetic instead
            diff = self.as_fraction() - rhs
            return (diff > 0) - (diff < 0)
        rhs = FixedPoint.cast(rhs)
        common_frac_wid = max(self.frac_wid, rhs.frac_wid)
        lhs = self.to_frac_wid(common_frac_wid)
        rhs = rhs.to_frac_wid(common_frac_wid)
        return lhs.bits - rhs.bits

    def _try_cmp(self, rhs):
        """like `cmp`, but returns `NotImplemented` for unsupported `rhs`
        types instead of raising `TypeError`."""
        if not isinstance(rhs, Fraction):
            rhs = FixedPoint._cast_operand(rhs)
            if rhs is NotImplemented:
                return NotImplemented
        return self.cmp(rhs)

    def __eq__(self, rhs):
        c = self._try_cmp(rhs)
        return c if c is NotImplemented else c == 0

    def __ne__(self, rhs):
        c = self._try_cmp(rhs)
        return c if c is NotImplemented else c != 0

    def __gt__(self, rhs):
        c = self._try_cmp(rhs)
        return c if c is NotImplemented else c > 0

    def __lt__(self, rhs):
        c = self._try_cmp(rhs)
        return c if c is NotImplemented else c < 0

    def __ge__(self, rhs):
        c = self._try_cmp(rhs)
        return c if c is NotImplemented else c >= 0

    def __le__(self, rhs):
        c = self._try_cmp(rhs)
        return c if c is NotImplemented else c <= 0

    def __hash__(self):
        # must agree with `__eq__`, which compares numeric values (e.g.
        # `FixedPoint(1, 0) == FixedPoint(2, 1) == 1 == 1.0`). A `Fraction`
        # hashes the same as any `int`/`float`/`Fraction` of equal value, so
        # hashing the exact value keeps equal objects' hashes equal. Without
        # this, the frozen dataclass would hash the raw `(bits, frac_wid)`
        # fields, giving equal values different hashes.
        return hash(self.as_fraction())

    def fract(self):
        """return the fractional part of `self`.
        that is `self - math.floor(self)`.
        """
        fract_mask = (1 << self.frac_wid) - 1
        return FixedPoint(self.bits & fract_mask, self.frac_wid)

    def __str__(self):
        if self < 0:
            return "-" + str(-self)
        digit_bits = 4
        frac_digit_count = (self.frac_wid + digit_bits - 1) // digit_bits
        fract = self.fract().to_frac_wid(frac_digit_count * digit_bits)
        frac_str = hex(fract.bits)[2:].zfill(frac_digit_count)
        return hex(math.floor(self)) + "." + frac_str

    def __repr__(self):
        return f"FixedPoint.with_frac_wid({str(self)!r}, {self.frac_wid})"

    def __add__(self, rhs):
        rhs = FixedPoint._cast_operand(rhs)
        if rhs is NotImplemented:
            return NotImplemented
        common_frac_wid = max(self.frac_wid, rhs.frac_wid)
        lhs = self.to_frac_wid(common_frac_wid)
        rhs = rhs.to_frac_wid(common_frac_wid)
        return FixedPoint(lhs.bits + rhs.bits, common_frac_wid)

    def __radd__(self, lhs):
        # symmetric
        return self.__add__(lhs)

    def __neg__(self):
        return FixedPoint(-self.bits, self.frac_wid)

    def __sub__(self, rhs):
        rhs = FixedPoint._cast_operand(rhs)
        if rhs is NotImplemented:
            return NotImplemented
        common_frac_wid = max(self.frac_wid, rhs.frac_wid)
        lhs = self.to_frac_wid(common_frac_wid)
        rhs = rhs.to_frac_wid(common_frac_wid)
        return FixedPoint(lhs.bits - rhs.bits, common_frac_wid)

    def __rsub__(self, lhs):
        # a - b == -(b - a)
        diff = self.__sub__(lhs)
        if diff is NotImplemented:
            return NotImplemented
        return -diff

    def __mul__(self, rhs):
        rhs = FixedPoint._cast_operand(rhs)
        if rhs is NotImplemented:
            return NotImplemented
        return FixedPoint(self.bits * rhs.bits, self.frac_wid + rhs.frac_wid)

    def __rmul__(self, lhs):
        # symmetric
        return self.__mul__(lhs)

    def __floor__(self):
        return self.bits >> self.frac_wid
244
245
@dataclass
class GoldschmidtDivState:
    """the mutable state that the `GoldschmidtDivOp` steps update while
    performing one Goldschmidt division.

    primed names refer to the variables of algorithm 2 in the paper by
    Even, Seidel & Ferguson.
    """

    # original numerator
    orig_n: int
    # original denominator
    orig_d: int
    # numerator -- N_prime[i] in the paper's algorithm 2
    n: FixedPoint
    # denominator -- D_prime[i] in the paper's algorithm 2
    d: FixedPoint
    # current factor -- F_prime[i] in the paper's algorithm 2
    f: "FixedPoint | None" = None
    # final quotient, filled in by the last step
    quotient: "int | None" = None
    # final remainder, filled in by the last step
    remainder: "int | None" = None
    # amount the numerator needs to be left-shifted at the end of the
    # algorithm
    n_shift: "int | None" = None
273
274
class ParamsNotAccurateEnough(Exception):
    """the error analysis couldn't show that a set of parameters is
    accurate enough for goldschmidt division to work."""


def _assert_accuracy(condition, msg="not accurate enough"):
    """raise `ParamsNotAccurateEnough(msg)` unless `condition` is truthy."""
    if not condition:
        raise ParamsNotAccurateEnough(msg)
284
285
@dataclass(frozen=True, unsafe_hash=True)
class GoldschmidtDivParams:
    """parameters for a Goldschmidt division algorithm.
    Use `GoldschmidtDivParams.get` to find an efficient set of parameters.

    Constructing an instance computes the lookup-table and the list of
    operations, raising `ParamsNotAccurateEnough` if the error analysis
    can't show that the parameters are accurate enough.
    """

    io_width: int
    """bit-width of the input divisor and the result.
    the input numerator is `2 * io_width`-bits wide.
    """

    extra_precision: int
    """number of bits of additional precision used inside the algorithm."""

    table_addr_bits: int
    """the number of address bits used in the lookup-table."""

    table_data_bits: int
    """the number of data bits used in the lookup-table."""

    iter_count: int
    """the total number of iterations of the division algorithm's loop"""

    # tuple to be immutable, default so repr() works for debugging even when
    # __post_init__ hasn't finished running yet
    table: "tuple[FixedPoint, ...]" = field(init=False, default=NotImplemented)
    """the lookup-table"""

    ops: "tuple[GoldschmidtDivOp, ...]" = field(init=False,
                                                default=NotImplemented)
    """the operations needed to perform the goldschmidt division algorithm."""

    def _shrink_bound(self, bound, round_dir):
        """prevent fractions from having huge numerators/denominators by
        rounding to a `FixedPoint` and converting back to a `Fraction`.

        This is intended only for values used to compute bounds, and not for
        values that end up in the hardware.
        """
        assert isinstance(bound, (Fraction, int))
        assert round_dir is RoundDir.DOWN or round_dir is RoundDir.UP, \
            "you shouldn't use that round_dir on bounds"
        frac_wid = self.io_width * 4 + 100  # should be enough precision
        fixed = FixedPoint.with_frac_wid(bound, frac_wid, round_dir)
        return fixed.as_fraction()

    def _shrink_min(self, min_bound):
        """prevent fractions used as minimum bounds from having huge
        numerators/denominators by rounding down to a `FixedPoint` and
        converting back to a `Fraction`.

        This is intended only for values used to compute bounds, and not for
        values that end up in the hardware.
        """
        return self._shrink_bound(min_bound, RoundDir.DOWN)

    def _shrink_max(self, max_bound):
        """prevent fractions used as maximum bounds from having huge
        numerators/denominators by rounding up to a `FixedPoint` and
        converting back to a `Fraction`.

        This is intended only for values used to compute bounds, and not for
        values that end up in the hardware.
        """
        return self._shrink_bound(max_bound, RoundDir.UP)

    @property
    def table_addr_count(self):
        """number of distinct addresses in the lookup-table."""
        # used while computing self.table, so can't just do len(self.table)
        return 1 << self.table_addr_bits

    def table_input_exact_range(self, addr):
        """return the range of inputs as `Fraction`s used for the table entry
        with address `addr`."""
        assert isinstance(addr, int)
        assert 0 <= addr < self.table_addr_count
        _assert_accuracy(self.io_width >= self.table_addr_bits)
        addr_shift = self.io_width - self.table_addr_bits
        min_numerator = (1 << self.io_width) + (addr << addr_shift)
        denominator = 1 << self.io_width
        values_per_table_entry = 1 << addr_shift
        max_numerator = min_numerator + values_per_table_entry - 1
        min_input = Fraction(min_numerator, denominator)
        max_input = Fraction(max_numerator, denominator)
        min_input = self._shrink_min(min_input)
        max_input = self._shrink_max(max_input)
        assert 1 <= min_input <= max_input < 2
        return min_input, max_input

    def table_value_exact_range(self, addr):
        """return the range of values as `Fraction`s used for the table entry
        with address `addr`."""
        min_input, max_input = self.table_input_exact_range(addr)
        # division swaps min/max
        min_value = 1 / max_input
        max_value = 1 / min_input
        min_value = self._shrink_min(min_value)
        max_value = self._shrink_max(max_value)
        assert 0.5 < min_value <= max_value <= 1
        return min_value, max_value

    def table_exact_value(self, index):
        """return the unrounded value (a `Fraction`) for the table entry with
        address `index`: the smallest reciprocal of that entry's inputs, so
        the entry never exceeds `1 / B` for any input `B` it covers."""
        min_value, _ = self.table_value_exact_range(index)
        # we round down
        return min_value

    def __post_init__(self):
        # called by the autogenerated __init__
        assert self.io_width >= 1
        assert self.extra_precision >= 0
        assert self.table_addr_bits >= 1
        assert self.table_data_bits >= 1
        assert self.iter_count >= 1
        table = tuple(
            FixedPoint.with_frac_wid(self.table_exact_value(addr),
                                     self.table_data_bits,
                                     RoundDir.DOWN)
            for addr in range(self.table_addr_count))
        # we have to use object.__setattr__ since frozen=True
        object.__setattr__(self, "table", table)
        object.__setattr__(self, "ops", tuple(_goldschmidt_div_ops(self)))

    @staticmethod
    def get(io_width):
        """ find efficient parameters for a goldschmidt division algorithm
        with `params.io_width == io_width`.

        raises `ValueError` if no working parameters are found.
        """
        assert isinstance(io_width, int) and io_width >= 1
        last_params = None
        last_error = None
        for extra_precision in range(io_width * 2 + 4):
            # only depends on `extra_precision`, so compute it once here
            table_data_bits = io_width + extra_precision
            for table_addr_bits in range(1, 7 + 1):
                for iter_count in range(1, 2 * io_width.bit_length()):
                    try:
                        return GoldschmidtDivParams(
                            io_width=io_width,
                            extra_precision=extra_precision,
                            table_addr_bits=table_addr_bits,
                            table_data_bits=table_data_bits,
                            iter_count=iter_count)
                    except ParamsNotAccurateEnough as e:
                        last_params = (f"GoldschmidtDivParams("
                                       f"io_width={io_width!r}, "
                                       f"extra_precision={extra_precision!r}, "
                                       f"table_addr_bits={table_addr_bits!r}, "
                                       f"table_data_bits={table_data_bits!r}, "
                                       f"iter_count={iter_count!r})")
                        last_error = e
        raise ValueError(f"can't find working parameters for a goldschmidt "
                         f"division algorithm: last params: {last_params}"
                         ) from last_error

    @property
    def expanded_width(self):
        """the total number of bits of precision used inside the algorithm."""
        return self.io_width + self.extra_precision

    @cache_on_self
    def max_neps(self, i):
        """maximum value of `neps[i]`.
        `neps[i]` is defined to be `n[i] * N_prime[i - 1] * F_prime[i - 1]`.
        """
        assert isinstance(i, int) and 0 <= i < self.iter_count
        return Fraction(1, 1 << self.expanded_width)

    @cache_on_self
    def max_deps(self, i):
        """maximum value of `deps[i]`.
        `deps[i]` is defined to be `d[i] * D_prime[i - 1] * F_prime[i - 1]`.
        """
        assert isinstance(i, int) and 0 <= i < self.iter_count
        return Fraction(1, 1 << self.expanded_width)

    @cache_on_self
    def max_feps(self, i):
        """maximum value of `feps[i]`.
        `feps[i]` is defined to be `f[i] * (2 - D_prime[i - 1])`.
        """
        assert isinstance(i, int) and 0 <= i < self.iter_count
        # zero, because the computation of `F_prime[i]` in
        # `GoldschmidtDivOp.MulDByF.run(...)` is exact.
        return Fraction(0)

    @cached_property
    def e0_range(self):
        """minimum and maximum values of `e[0]`
        (the relative error in `F_prime[-1]`)
        """
        min_e0 = Fraction(0)
        max_e0 = Fraction(0)
        for addr in range(self.table_addr_count):
            # `F_prime[-1] = (1 - e[0]) / B`
            # => `e[0] = 1 - B * F_prime[-1]`
            min_b, max_b = self.table_input_exact_range(addr)
            f_prime_m1 = self.table[addr].as_fraction()
            assert min_b >= 0 and f_prime_m1 >= 0, \
                "only positive quadrant of interval multiplication implemented"
            min_product = min_b * f_prime_m1
            max_product = max_b * f_prime_m1
            # negation swaps min/max
            cur_min_e0 = 1 - max_product
            cur_max_e0 = 1 - min_product
            min_e0 = min(min_e0, cur_min_e0)
            max_e0 = max(max_e0, cur_max_e0)
        min_e0 = self._shrink_min(min_e0)
        max_e0 = self._shrink_max(max_e0)
        return min_e0, max_e0

    @cached_property
    def min_e0(self):
        """minimum value of `e[0]` (the relative error in `F_prime[-1]`)
        """
        min_e0, max_e0 = self.e0_range
        return min_e0

    @cached_property
    def max_e0(self):
        """maximum value of `e[0]` (the relative error in `F_prime[-1]`)
        """
        min_e0, max_e0 = self.e0_range
        return max_e0

    @cached_property
    def max_abs_e0(self):
        """maximum value of `abs(e[0])`."""
        return max(abs(self.min_e0), abs(self.max_e0))

    @cached_property
    def min_abs_e0(self):
        """minimum value of `abs(e[0])`."""
        return Fraction(0)

    @cache_on_self
    def max_n(self, i):
        """maximum value of `n[i]` (the relative error in `N_prime[i]`
        relative to the previous iteration)
        """
        assert isinstance(i, int) and 0 <= i < self.iter_count
        if i == 0:
            # from Claim 10
            # `n[0] = neps[0] / ((1 - e[0]) * (A / B))`
            # `n[0] <= 2 * neps[0] / (1 - e[0])`

            assert self.max_e0 < 1 and self.max_neps(0) >= 0, \
                "only one quadrant of interval division implemented"
            retval = 2 * self.max_neps(0) / (1 - self.max_e0)
        elif i == 1:
            # from Claim 10
            # `n[1] <= neps[1] / ((1 - f[0]) * (1 - pi[0] - delta[0]))`
            # NOTE(review): unlike the `i == 0` and `i >= 2` cases this has
            # no factor of 2 for `A / B` possibly being as small as `1 / 2`
            # -- confirm against the paper that it isn't needed here.
            min_mpd = 1 - self.max_pi(0) - self.max_delta(0)
            assert self.max_f(0) <= 1 and min_mpd >= 0, \
                "only one quadrant of interval multiplication implemented"
            prod = (1 - self.max_f(0)) * min_mpd
            assert self.max_neps(1) >= 0 and prod > 0, \
                "only one quadrant of interval division implemented"
            retval = self.max_neps(1) / prod
        else:
            # from Claim 6
            # `0 <= n[i] <= 2 * max_neps[i] / (1 - pi[i - 1] - delta[i - 1])`
            # the factor of 2 (as in the `i == 0` case) presumably accounts
            # for `A / B` being as small as `1 / 2`; leaving it out, as
            # this code previously did, would underestimate the bound.
            min_mpd = 1 - self.max_pi(i - 1) - self.max_delta(i - 1)
            assert self.max_neps(i) >= 0 and min_mpd > 0, \
                "only one quadrant of interval division implemented"
            retval = 2 * self.max_neps(i) / min_mpd

        return self._shrink_max(retval)

    @cache_on_self
    def max_d(self, i):
        """maximum value of `d[i]` (the relative error in `D_prime[i]`
        relative to the previous iteration)
        """
        assert isinstance(i, int) and 0 <= i < self.iter_count
        if i == 0:
            # from Claim 10
            # `d[0] = deps[0] / (1 - e[0])`

            assert self.max_e0 < 1 and self.max_deps(0) >= 0, \
                "only one quadrant of interval division implemented"
            retval = self.max_deps(0) / (1 - self.max_e0)
        elif i == 1:
            # from Claim 10
            # `d[1] <= deps[1] / ((1 - f[0]) * (1 - delta[0] ** 2))`
            assert self.max_f(0) <= 1 and self.max_delta(0) <= 1, \
                "only one quadrant of interval multiplication implemented"
            divisor = (1 - self.max_f(0)) * (1 - self.max_delta(0) ** 2)
            assert self.max_deps(1) >= 0 and divisor > 0, \
                "only one quadrant of interval division implemented"
            retval = self.max_deps(1) / divisor
        else:
            # from Claim 6
            # `0 <= d[i] <= max_deps[i] / (1 - delta[i - 1])`
            assert self.max_deps(i) >= 0 and self.max_delta(i - 1) < 1, \
                "only one quadrant of interval division implemented"
            retval = self.max_deps(i) / (1 - self.max_delta(i - 1))

        return self._shrink_max(retval)

    @cache_on_self
    def max_f(self, i):
        """maximum value of `f[i]` (the relative error in `F_prime[i]`
        relative to the previous iteration)
        """
        assert isinstance(i, int) and 0 <= i < self.iter_count
        if i == 0:
            # from Claim 10
            # `f[0] = feps[0] / (1 - delta[0])`

            assert self.max_delta(0) < 1 and self.max_feps(0) >= 0, \
                "only one quadrant of interval division implemented"
            retval = self.max_feps(0) / (1 - self.max_delta(0))
        elif i == 1:
            # from Claim 10
            # `f[1] = feps[1]`
            retval = self.max_feps(1)
        else:
            # from Claim 6
            # `f[i] <= max_feps[i]`
            retval = self.max_feps(i)

        return self._shrink_max(retval)

    @cache_on_self
    def max_delta(self, i):
        """ maximum value of `delta[i]`.
        `delta[i]` is defined in Definition 4 of paper.
        """
        assert isinstance(i, int) and 0 <= i < self.iter_count
        if i == 0:
            # `delta[0] = abs(e[0]) + 3 * d[0] / 2`
            retval = self.max_abs_e0 + Fraction(3, 2) * self.max_d(0)
        else:
            # `delta[i] = delta[i - 1] ** 2 + f[i - 1]`
            prev_max_delta = self.max_delta(i - 1)
            assert prev_max_delta >= 0
            retval = prev_max_delta ** 2 + self.max_f(i - 1)

        # `delta[i]` has to be smaller than one otherwise errors would go off
        # to infinity
        _assert_accuracy(retval < 1)

        return self._shrink_max(retval)

    @cache_on_self
    def max_pi(self, i):
        """ maximum value of `pi[i]`.
        `pi[i]` is defined right below Theorem 5 of paper.
        """
        assert isinstance(i, int) and 0 <= i < self.iter_count
        # `pi[i] = 1 - (1 - n[i]) * prod`
        # where `prod` is the product of,
        # for `j` in `0 <= j < i`, `(1 - n[j]) / (1 + d[j])`
        min_prod = Fraction(1)
        for j in range(i):
            max_n_j = self.max_n(j)
            max_d_j = self.max_d(j)
            assert max_n_j <= 1 and max_d_j > -1, \
                "only one quadrant of interval division implemented"
            min_prod *= (1 - max_n_j) / (1 + max_d_j)
        max_n_i = self.max_n(i)
        assert max_n_i <= 1 and min_prod >= 0, \
            "only one quadrant of interval multiplication implemented"
        retval = 1 - (1 - max_n_i) * min_prod
        return self._shrink_max(retval)

    @cached_property
    def max_n_shift(self):
        """ maximum value of `state.n_shift`.
        """
        # input numerator is `2*io_width`-bits
        max_n = (1 << (self.io_width * 2)) - 1
        # number of right-shifts needed to normalize `max_n` so
        # `1 <= max_n < 2`, i.e. the position of its most significant set bit
        return max_n.bit_length() - 1
664
665
@enum.unique
class GoldschmidtDivOp(enum.Enum):
    """one step of the Goldschmidt division algorithm.

    each member's value is a pseudo-code description of the step; `run`
    performs the step by updating a `GoldschmidtDivState` in place.
    """

    Normalize = "n, d, n_shift = normalize(n, d)"
    FEqTableLookup = "f = table_lookup(d)"
    MulNByF = "n *= f"
    MulDByF = "d *= f"
    FEq2MinusD = "f = 2 - d"
    CalcResult = "result = unnormalize_and_round(n)"

    def run(self, params, state):
        """perform this step on `state` (modified in place) using `params`.
        """
        assert isinstance(params, GoldschmidtDivParams)
        assert isinstance(state, GoldschmidtDivState)
        expanded_width = params.expanded_width
        table_addr_bits = params.table_addr_bits
        if self == GoldschmidtDivOp.Normalize:
            # normalize so 1 <= d < 2
            # can easily be done with count-leading-zeros and left shift
            while state.d < 1:
                state.n = (state.n * 2).to_frac_wid(expanded_width)
                state.d = (state.d * 2).to_frac_wid(expanded_width)

            state.n_shift = 0
            # normalize so 1 <= n < 2
            # NOTE: halving uses the default `RoundDir.ERROR_IF_INEXACT`, so
            # this raises `ValueError` if `expanded_width` has too few
            # fractional bits to hold the shifted numerator exactly.
            while state.n >= 2:
                state.n = (state.n * 0.5).to_frac_wid(expanded_width)
                state.n_shift += 1
        elif self == GoldschmidtDivOp.FEqTableLookup:
            # compute initial f by table lookup
            d_m_1 = state.d - 1
            d_m_1 = d_m_1.to_frac_wid(table_addr_bits, RoundDir.DOWN)
            assert 0 <= d_m_1.bits < (1 << params.table_addr_bits)
            state.f = params.table[d_m_1.bits]
        elif self == GoldschmidtDivOp.MulNByF:
            assert state.f is not None
            n = state.n * state.f
            state.n = n.to_frac_wid(expanded_width, round_dir=RoundDir.DOWN)
        elif self == GoldschmidtDivOp.MulDByF:
            assert state.f is not None
            d = state.d * state.f
            state.d = d.to_frac_wid(expanded_width, round_dir=RoundDir.UP)
        elif self == GoldschmidtDivOp.FEq2MinusD:
            state.f = (2 - state.d).to_frac_wid(expanded_width)
        elif self == GoldschmidtDivOp.CalcResult:
            assert state.n_shift is not None
            # scale to correct value
            n = state.n * (1 << state.n_shift)

            state.quotient = math.floor(n)
            state.remainder = state.orig_n - state.quotient * state.orig_d
            # relies on the error analysis: the approximate quotient is
            # never too big and is too small by at most one, so a single
            # correction step is enough.
            if state.remainder >= state.orig_d:
                state.quotient += 1
                state.remainder -= state.orig_d
            # fail loudly, rather than silently returning a wrong result, if
            # the parameters' error bounds don't actually hold.
            assert 0 <= state.remainder < state.orig_d, \
                "goldschmidt division produced an incorrect result"
        else:
            assert False, f"unimplemented GoldschmidtDivOp: {self}"
720
721
def _goldschmidt_div_ops(params):
    """ Goldschmidt division algorithm.

    based on:
    Even, G., Seidel, P. M., & Ferguson, W. E. (2003).
    A Parametric Error Analysis of Goldschmidt's Division Algorithm.
    https://citeseerx.ist.psu.edu/viewdoc/download?doi=10.1.1.90.1238&rep=rep1&type=pdf

    arguments:
    params: GoldschmidtDivParams
        the parameters for the algorithm

    yields: GoldschmidtDivOp
        the operations needed to perform the division.

    raises: ParamsNotAccurateEnough
        (while being iterated) if the error analysis can't show that `params`
        are accurate enough.
    """
    assert isinstance(params, GoldschmidtDivParams)

    # establish assumptions of the paper's error analysis (section 3.1):

    # 1. normalize so A (numerator) and B (denominator) are in [1, 2)
    yield GoldschmidtDivOp.Normalize

    # 2. ensure all relative errors from directed rounding are <= 1 / 4.
    # the assumption is met by multipliers with > 4-bits precision
    _assert_accuracy(params.expanded_width > 4)

    # 3. require `abs(e[0]) + 3 * d[0] / 2 + f[0] < 1 / 2`.
    _assert_accuracy(params.max_abs_e0 + 3 * params.max_d(0) / 2
                     + params.max_f(0) < Fraction(1, 2))

    # 4. the initial approximation F'[-1] of 1/B is in [1/2, 1].
    # (B is the denominator)
    for f_prime_m1 in params.table:
        _assert_accuracy(0.5 <= f_prime_m1 <= 1)

    yield GoldschmidtDivOp.FEqTableLookup

    # we use Setting I (section 4.1 of the paper):
    # Require `n[i] <= n_hat` and `d[i] <= n_hat` and `f[i] = 0`
    n_hat = Fraction(0)
    for i in range(params.iter_count):
        _assert_accuracy(params.max_f(i) == 0)
        n_hat = max(n_hat, params.max_n(i), params.max_d(i))
        yield GoldschmidtDivOp.MulNByF
        # the last iteration only needs the final numerator
        if i != params.iter_count - 1:
            yield GoldschmidtDivOp.MulDByF
            yield GoldschmidtDivOp.FEq2MinusD

    # relative approximation error `p(N_prime[i])`:
    # `p(N_prime[i]) = (A / B - N_prime[i]) / (A / B)`
    # `0 <= p(N_prime[i])`
    # `p(N_prime[i]) <= (2 * i) * n_hat \`
    # ` + (abs(e[0]) + 3 * n_hat / 2) ** (2 ** i)`
    i = params.iter_count - 1  # last used `i`
    # compute power manually to prevent huge intermediate values
    power = params._shrink_max(params.max_abs_e0 + 3 * n_hat / 2)
    for _ in range(i):
        power = params._shrink_max(power * power)

    max_rel_error = (2 * i) * n_hat + power

    # the result is `floor(N_prime * 2 ** n_shift)` followed by at most one
    # `+1` correction, so the absolute error of `N_prime` has to stay below
    # `max_allowed_abs_error`. Since `abs_error = (A / B) * rel_error`, the
    # absolute error for a given relative error is largest when `A / B` is
    # largest. So the relative-error budget must be derived from the
    # *maximum* of `A / B`; dividing by the minimum (`1 / 2`) would admit 4x
    # too much error.
    max_a_over_b = Fraction(2)
    max_allowed_abs_error = max_a_over_b / (1 << params.max_n_shift)
    max_allowed_rel_error = max_allowed_abs_error / max_a_over_b

    _assert_accuracy(max_rel_error < max_allowed_rel_error,
                     f"not accurate enough: max_rel_error={max_rel_error} "
                     f"max_allowed_rel_error={max_allowed_rel_error}")

    yield GoldschmidtDivOp.CalcResult
795
796
def goldschmidt_div(n, d, params):
    """ Goldschmidt division algorithm.

    based on:
    Even, G., Seidel, P. M., & Ferguson, W. E. (2003).
    A Parametric Error Analysis of Goldschmidt's Division Algorithm.
    https://citeseerx.ist.psu.edu/viewdoc/download?doi=10.1.1.90.1238&rep=rep1&type=pdf

    arguments:
    n: int
        numerator. a `2*params.io_width`-bit unsigned integer.
        must be less than `d << params.io_width`, otherwise the quotient
        wouldn't fit in `params.io_width` bits.
    d: int
        denominator. a `params.io_width`-bit unsigned integer. must not be
        zero.
    params: GoldschmidtDivParams
        the parameters for the algorithm, e.g. from
        `GoldschmidtDivParams.get(io_width)`.

    returns: tuple[int, int]
        the quotient and remainder. a tuple of two `params.io_width`-bit
        unsigned integers.
    """
    assert isinstance(params, GoldschmidtDivParams)
    assert isinstance(d, int) and 0 < d < (1 << params.io_width)
    assert isinstance(n, int) and 0 <= n < (d << params.io_width)

    # this whole algorithm is done with fixed-point arithmetic: the inputs
    # start out with `params.io_width` fractional bits, and the operations
    # compute intermediate values with `params.expanded_width` fractional
    # bits.

    state = GoldschmidtDivState(
        orig_n=n,
        orig_d=d,
        n=FixedPoint(n, params.io_width),
        d=FixedPoint(d, params.io_width),
    )

    for op in params.ops:
        op.run(params, state)

    assert state.quotient is not None
    assert state.remainder is not None

    return state.quotient, state.remainder