working on goldschmidt_div_sqrt.py
[soc.git] / src / soc / fu / div / experiment / goldschmidt_div_sqrt.py
1 # SPDX-License-Identifier: LGPL-3-or-later
2 # Copyright 2022 Jacob Lifshay programmerjake@gmail.com
3
4 # Funded by NLnet Assure Programme 2021-02-052, https://nlnet.nl/assure part
5 # of Horizon 2020 EU Programme 957073.
6
7 from dataclasses import dataclass, field
8 import math
9 import enum
10 from fractions import Fraction
11 from types import FunctionType
12
13 try:
14 from functools import cached_property
15 except ImportError:
16 from cached_property import cached_property
17
18 # fix broken IDE type detection for cached_property
19 from typing import TYPE_CHECKING
20 if TYPE_CHECKING:
21 from functools import cached_property
22
23
# sentinel distinguishing "nothing cached" from a cached `None`
_NOT_FOUND = object()


def cache_on_self(func):
    """like `functools.cached_property`, except for methods. unlike
    `lru_cache` the cache is per-class instance rather than a global cache
    per-method.

    the cache is a `dict` stored in the instance's `__dict__` under the name
    `<method name>__cache`, so instances must have a `__dict__`.
    the dict is keyed by the call's arguments, so all arguments must be
    hashable.

    arguments:
    func: FunctionType
        the plain method to wrap. its first parameter must be `self`.

    returns:
        a wrapper with the same name, docstring and signature metadata as
        `func`, that calls `func` at most once per instance and distinct set
        of arguments.
    """
    # imported locally so this decorator only relies on names it brings in
    # itself, in addition to the module's existing `FunctionType` import
    import functools

    assert isinstance(func, FunctionType), \
        "non-plain methods are not supported"

    cache_name = func.__name__ + "__cache"

    # `wraps` keeps `__name__`, `__qualname__`, `__doc__`, etc. of `func`
    # so debuggers, `help()` and tracebacks show the real method
    @functools.wraps(func)
    def wrapper(self, *args, **kwargs):
        # specifically access through `__dict__` to bypass frozen=True
        cache = self.__dict__.get(cache_name, _NOT_FOUND)
        if cache is _NOT_FOUND:
            self.__dict__[cache_name] = cache = {}
        # keyword names are unique, so sorting only ever compares the names.
        # sorting makes `f(a=1, b=2)` and `f(b=2, a=1)` share a cache entry.
        key = (args, *sorted(kwargs.items()))
        retval = cache.get(key, _NOT_FOUND)
        if retval is _NOT_FOUND:
            retval = func(self, *args, **kwargs)
            cache[key] = retval
        return retval

    return wrapper
51
52
@enum.unique
class RoundDir(enum.Enum):
    """rounding modes understood by `FixedPoint.with_frac_wid`."""

    # keep the floored quotient -- rounds towards negative infinity
    DOWN = enum.auto()

    # bump the floored quotient whenever anything was discarded
    # -- rounds towards positive infinity
    UP = enum.auto()

    # round to the nearest representable value; exact ties are rounded up
    NEAREST_TIES_UP = enum.auto()

    # refuse to round: a conversion that would lose bits raises `ValueError`
    ERROR_IF_INEXACT = enum.auto()
59
60
@dataclass(frozen=True)
class FixedPoint:
    """exact binary fixed-point number.

    the represented value is `bits / (1 << frac_wid)`. instances are
    immutable; all arithmetic returns new instances and never rounds unless
    a `RoundDir` is passed explicitly to `with_frac_wid`/`to_frac_wid`.

    comparisons (and hashing) are by numeric value, so
    `FixedPoint(1, 0) == FixedPoint(2, 1)`.
    """

    # the value scaled by `2 ** frac_wid`; may be negative
    bits: int
    # number of fractional bits; never negative
    frac_wid: int

    def __post_init__(self):
        assert isinstance(self.bits, int)
        assert isinstance(self.frac_wid, int) and self.frac_wid >= 0

    @staticmethod
    def cast(value):
        """convert `value` to a fixed-point number with enough fractional
        bits to preserve its value.

        accepts `FixedPoint` (returned as-is), `int`, `float`, and `str`.
        strings may have a leading sign; strings starting with `0x`/`0X` that
        contain a `.` are parsed as hexadecimal with 4 fractional bits per
        digit after the `.` (`_` separators are skipped), all other strings
        are parsed by `int(value, base=0)`.

        raises `ValueError` for malformed strings and `TypeError` for
        unsupported types.
        """
        if isinstance(value, FixedPoint):
            return value
        if isinstance(value, int):
            return FixedPoint(value, 0)
        if isinstance(value, str):
            value = value.strip()
            neg = value.startswith("-")
            if neg or value.startswith("+"):
                value = value[1:]
            if value.startswith(("0x", "0X")) and "." in value:
                value = value[2:]
                got_dot = False
                bits = 0
                frac_wid = 0
                for digit in value:
                    if digit == "_":
                        continue
                    if got_dot:
                        if digit == ".":
                            raise ValueError("too many `.` in string")
                        # every hex digit after the `.` adds 4 fraction bits
                        frac_wid += 4
                    if digit == ".":
                        got_dot = True
                        continue
                    if not digit.isalnum():
                        raise ValueError("invalid hexadecimal digit")
                    bits <<= 4
                    # `int` raises ValueError for non-hex alphanumerics
                    bits |= int("0x" + digit, base=16)
            else:
                bits = int(value, base=0)
                frac_wid = 0
            if neg:
                bits = -bits
            return FixedPoint(bits, frac_wid)

        if isinstance(value, float):
            n, d = value.as_integer_ratio()
            log2_d = d.bit_length() - 1
            assert d == 1 << log2_d, ("d isn't a power of 2 -- won't ever "
                                      "fail with float being IEEE 754")
            return FixedPoint(n, log2_d)
        raise TypeError("can't convert type to FixedPoint")

    @staticmethod
    def with_frac_wid(value, frac_wid, round_dir=RoundDir.ERROR_IF_INEXACT):
        """convert `value` to the nearest fixed-point number with `frac_wid`
        fractional bits, rounding according to `round_dir`.

        `value` may be a `Fraction` or anything accepted by `cast`.
        raises `ValueError` if `round_dir` is `RoundDir.ERROR_IF_INEXACT` and
        the conversion would lose information.
        """
        assert isinstance(frac_wid, int) and frac_wid >= 0
        assert isinstance(round_dir, RoundDir)
        if isinstance(value, Fraction):
            numerator = value.numerator
            denominator = value.denominator
        else:
            value = FixedPoint.cast(value)
            # compute number of bits that should be removed from value
            del_bits = value.frac_wid - frac_wid
            if del_bits == 0:
                return value
            if del_bits < 0:  # add bits
                return FixedPoint(value.bits << -del_bits,
                                  frac_wid)
            numerator = value.bits
            denominator = 1 << value.frac_wid
        if denominator < 0:
            numerator = -numerator
            denominator = -denominator
        # `divmod` floors, so `bits` is rounded down and `remainder >= 0`
        bits, remainder = divmod(numerator << frac_wid, denominator)
        if round_dir == RoundDir.DOWN:
            pass
        elif round_dir == RoundDir.UP:
            if remainder != 0:
                bits += 1
        elif round_dir == RoundDir.NEAREST_TIES_UP:
            if remainder * 2 >= denominator:
                bits += 1
        elif round_dir == RoundDir.ERROR_IF_INEXACT:
            if remainder != 0:
                raise ValueError("inexact conversion")
        else:
            assert False, "unimplemented round_dir"
        return FixedPoint(bits, frac_wid)

    def to_frac_wid(self, frac_wid, round_dir=RoundDir.ERROR_IF_INEXACT):
        """convert to the nearest fixed-point number with `frac_wid`
        fractional bits, rounding according to `round_dir`."""
        return FixedPoint.with_frac_wid(self, frac_wid, round_dir)

    def __float__(self):
        # use truediv to get correct result even when bits
        # and frac_wid are huge
        return float(self.bits / (1 << self.frac_wid))

    def as_fraction(self):
        """return the exact value of `self` as a `Fraction`."""
        return Fraction(self.bits, 1 << self.frac_wid)

    def cmp(self, rhs):
        """compare self with rhs, returning a positive integer if self is
        greater than rhs, zero if self is equal to rhs, and a negative integer
        if self is less than rhs."""
        rhs = FixedPoint.cast(rhs)
        # widening to the larger `frac_wid` is always exact
        common_frac_wid = max(self.frac_wid, rhs.frac_wid)
        lhs = self.to_frac_wid(common_frac_wid)
        rhs = rhs.to_frac_wid(common_frac_wid)
        return lhs.bits - rhs.bits

    def __eq__(self, rhs):
        return self.cmp(rhs) == 0

    def __ne__(self, rhs):
        return self.cmp(rhs) != 0

    def __hash__(self):
        # `__eq__` compares by numeric value rather than by fields, so the
        # hash must too -- otherwise equal values with different `frac_wid`
        # would land in different `dict`/`set` buckets.
        # `Fraction`'s hash also agrees with equal `int`s and `float`s.
        return hash(self.as_fraction())

    def __gt__(self, rhs):
        return self.cmp(rhs) > 0

    def __lt__(self, rhs):
        return self.cmp(rhs) < 0

    def __ge__(self, rhs):
        return self.cmp(rhs) >= 0

    def __le__(self, rhs):
        return self.cmp(rhs) <= 0

    def fract(self):
        """return the fractional part of `self`.
        that is `self - math.floor(self)`.
        """
        fract_mask = (1 << self.frac_wid) - 1
        return FixedPoint(self.bits & fract_mask, self.frac_wid)

    def __str__(self):
        """format as signed hexadecimal with a `.`, e.g. `0x1.8`."""
        if self < 0:
            return "-" + str(-self)
        digit_bits = 4
        # round the fraction width up to a whole number of hex digits
        frac_digit_count = (self.frac_wid + digit_bits - 1) // digit_bits
        fract = self.fract().to_frac_wid(frac_digit_count * digit_bits)
        frac_str = hex(fract.bits)[2:].zfill(frac_digit_count)
        return hex(math.floor(self)) + "." + frac_str

    def __repr__(self):
        return f"FixedPoint.with_frac_wid({str(self)!r}, {self.frac_wid})"

    def __add__(self, rhs):
        rhs = FixedPoint.cast(rhs)
        common_frac_wid = max(self.frac_wid, rhs.frac_wid)
        lhs = self.to_frac_wid(common_frac_wid)
        rhs = rhs.to_frac_wid(common_frac_wid)
        return FixedPoint(lhs.bits + rhs.bits, common_frac_wid)

    def __radd__(self, lhs):
        # symmetric
        return self.__add__(lhs)

    def __neg__(self):
        return FixedPoint(-self.bits, self.frac_wid)

    def __sub__(self, rhs):
        rhs = FixedPoint.cast(rhs)
        common_frac_wid = max(self.frac_wid, rhs.frac_wid)
        lhs = self.to_frac_wid(common_frac_wid)
        rhs = rhs.to_frac_wid(common_frac_wid)
        return FixedPoint(lhs.bits - rhs.bits, common_frac_wid)

    def __rsub__(self, lhs):
        # a - b == -(b - a)
        return -self.__sub__(lhs)

    def __mul__(self, rhs):
        rhs = FixedPoint.cast(rhs)
        # the exact product needs the sum of both fraction widths
        return FixedPoint(self.bits * rhs.bits, self.frac_wid + rhs.frac_wid)

    def __rmul__(self, lhs):
        # symmetric
        return self.__mul__(lhs)

    def __floor__(self):
        # arithmetic right shift floors, also for negative `bits`
        return self.bits >> self.frac_wid
251
252
@dataclass
class GoldschmidtDivState:
    """mutable working state threaded through the `GoldschmidtDivOp`s of a
    single division."""

    orig_n: int
    """the numerator exactly as supplied by the caller"""

    orig_d: int
    """the denominator exactly as supplied by the caller"""

    n: FixedPoint
    """running numerator -- `N_prime[i]` in algorithm 2 of the paper"""

    d: FixedPoint
    """running denominator -- `D_prime[i]` in algorithm 2 of the paper"""

    f: "FixedPoint | None" = None
    """running factor -- `F_prime[i]` in algorithm 2 of the paper.
    `None` until the first factor has been computed.
    """

    quotient: "int | None" = None
    """the resulting quotient; `None` until the result was calculated"""

    remainder: "int | None" = None
    """the resulting remainder; `None` until the result was calculated"""

    n_shift: "int | None" = None
    """how far the numerator has to be shifted left once the algorithm is
    done; `None` until the inputs were normalized.
    """
280
281
class ParamsNotAccurateEnough(Exception):
    """the chosen parameters don't give goldschmidt division enough accuracy
    to work.

    raised by `_assert_accuracy` when one of the accuracy requirements
    isn't met.
    """
285
286
287 def _assert_accuracy(condition, msg="not accurate enough"):
288 if condition:
289 return
290 raise ParamsNotAccurateEnough(msg)
291
292
@dataclass(frozen=True, unsafe_hash=True)
class GoldschmidtDivParams:
    """parameters for a Goldschmidt division algorithm.
    Use `GoldschmidtDivParams.get` to find an efficient set of parameters.

    constructing an instance builds the lookup-table and the list of
    operations, which runs the error analysis; `ParamsNotAccurateEnough` is
    raised if the parameters can't guarantee a correct result.
    """

    io_width: int
    """bit-width of the input divisor and the result.
    the input numerator is `2 * io_width`-bits wide.
    """

    extra_precision: int
    """number of bits of additional precision used inside the algorithm."""

    table_addr_bits: int
    """the number of address bits used in the lookup-table."""

    table_data_bits: int
    """the number of data bits used in the lookup-table."""

    iter_count: int
    """the total number of iterations of the division algorithm's loop"""

    # tuple to be immutable
    table: "tuple[FixedPoint, ...]" = field(init=False)
    """the lookup-table"""

    ops: "tuple[GoldschmidtDivOp, ...]" = field(init=False)
    """the operations needed to perform the goldschmidt division algorithm."""

    @property
    def table_addr_count(self):
        """number of distinct addresses in the lookup-table."""
        # used while computing self.table, so can't just do len(self.table)
        return 1 << self.table_addr_bits

    def table_input_exact_range(self, addr):
        """return the range of inputs as `Fraction`s used for the table entry
        with address `addr`.

        inputs are normalized denominators in `[1, 2)` with `io_width`
        fractional bits; the entry is selected by their top
        `table_addr_bits` fractional bits. both returned bounds are
        inclusive.
        """
        assert isinstance(addr, int)
        assert 0 <= addr < self.table_addr_count
        _assert_accuracy(self.io_width >= self.table_addr_bits)
        # everything is expressed in units of `2 ** -io_width` so the
        # per-entry step and the entry's base address use the same scale
        addr_shift = self.io_width - self.table_addr_bits
        denominator = 1 << self.io_width
        min_numerator = (1 << self.io_width) + (addr << addr_shift)
        values_per_table_entry = 1 << addr_shift
        # `- 1` since the next value already belongs to the next entry
        max_numerator = min_numerator + values_per_table_entry - 1
        min_input = Fraction(min_numerator, denominator)
        max_input = Fraction(max_numerator, denominator)
        return min_input, max_input

    def table_value_exact_range(self, addr):
        """return the range of values as `Fraction`s used for the table entry
        with address `addr`."""
        min_value, max_value = self.table_input_exact_range(addr)
        # division swaps min/max
        return 1 / max_value, 1 / min_value

    def table_exact_value(self, index):
        """return the exact (unrounded) value for the table entry with
        address `index`."""
        min_value, _ = self.table_value_exact_range(index)
        # we round down
        return min_value

    def __post_init__(self):
        # called by the autogenerated __init__
        assert self.io_width >= 1
        assert self.extra_precision >= 0
        assert self.table_addr_bits >= 1
        assert self.table_data_bits >= 1
        assert self.iter_count >= 1
        table = []
        for addr in range(1 << self.table_addr_bits):
            table.append(FixedPoint.with_frac_wid(self.table_exact_value(addr),
                                                  self.table_data_bits,
                                                  RoundDir.DOWN))
        # we have to use object.__setattr__ since frozen=True
        object.__setattr__(self, "table", tuple(table))
        # building the ops also runs the accuracy checks
        object.__setattr__(self, "ops", tuple(_goldschmidt_div_ops(self)))

    @staticmethod
    def get(io_width):
        """ find efficient parameters for a goldschmidt division algorithm
        with `params.io_width == io_width`.

        returns the first parameter set that passes the accuracy checks;
        raises `ValueError` if none of the tried sets does.
        """
        assert isinstance(io_width, int) and io_width >= 1
        for extra_precision in range(io_width * 2 + 4):
            for table_addr_bits in range(1, 7 + 1):
                table_data_bits = io_width + extra_precision
                for iter_count in range(1, 2 * io_width.bit_length()):
                    try:
                        return GoldschmidtDivParams(
                            io_width=io_width,
                            extra_precision=extra_precision,
                            table_addr_bits=table_addr_bits,
                            table_data_bits=table_data_bits,
                            iter_count=iter_count)
                    except ParamsNotAccurateEnough:
                        # expected while searching -- try the next candidate
                        pass
        raise ValueError(f"can't find working parameters for a goldschmidt "
                         f"division algorithm with io_width={io_width}")

    @property
    def expanded_width(self):
        """the total number of bits of precision used inside the algorithm."""
        return self.io_width + self.extra_precision

    @cache_on_self
    def max_neps(self, i):
        """maximum value of `neps[i]`.
        `neps[i]` is defined to be `n[i] * N_prime[i - 1] * F_prime[i - 1]`.
        """
        assert isinstance(i, int) and 0 <= i < self.iter_count
        return Fraction(1, 1 << self.expanded_width)

    @cache_on_self
    def max_deps(self, i):
        """maximum value of `deps[i]`.
        `deps[i]` is defined to be `d[i] * D_prime[i - 1] * F_prime[i - 1]`.
        """
        assert isinstance(i, int) and 0 <= i < self.iter_count
        return Fraction(1, 1 << self.expanded_width)

    @cache_on_self
    def max_feps(self, i):
        """maximum value of `feps[i]`.
        `feps[i]` is defined to be `f[i] * (2 - D_prime[i - 1])`.
        """
        assert isinstance(i, int) and 0 <= i < self.iter_count
        # zero, because the computation of `F_prime[i]` in
        # `GoldschmidtDivOp.MulDByF.run(...)` is exact.
        return Fraction(0)

    @cached_property
    def e0_range(self):
        """minimum and maximum values of `e[0]`
        (the relative error in `F_prime[-1]`)
        """
        min_e0 = Fraction(0)
        max_e0 = Fraction(0)
        for addr in range(self.table_addr_count):
            # `F_prime[-1] = (1 - e[0]) / B`
            # => `e[0] = 1 - B * F_prime[-1]`
            min_b, max_b = self.table_input_exact_range(addr)
            f_prime_m1 = self.table[addr].as_fraction()
            assert min_b >= 0 and f_prime_m1 >= 0, \
                "only positive quadrant of interval multiplication implemented"
            min_product = min_b * f_prime_m1
            max_product = max_b * f_prime_m1
            # negation swaps min/max
            cur_min_e0 = 1 - max_product
            cur_max_e0 = 1 - min_product
            min_e0 = min(min_e0, cur_min_e0)
            max_e0 = max(max_e0, cur_max_e0)
        return min_e0, max_e0

    @cached_property
    def min_e0(self):
        """minimum value of `e[0]` (the relative error in `F_prime[-1]`)
        """
        min_e0, _ = self.e0_range
        return min_e0

    @cached_property
    def max_e0(self):
        """maximum value of `e[0]` (the relative error in `F_prime[-1]`)
        """
        _, max_e0 = self.e0_range
        return max_e0

    @cached_property
    def max_abs_e0(self):
        """maximum value of `abs(e[0])`."""
        return max(abs(self.min_e0), abs(self.max_e0))

    @cached_property
    def min_abs_e0(self):
        """minimum value of `abs(e[0])`."""
        return Fraction(0)

    @cache_on_self
    def max_n(self, i):
        """maximum value of `n[i]` (the relative error in `N_prime[i]`
        relative to the previous iteration)
        """
        assert isinstance(i, int) and 0 <= i < self.iter_count
        if i == 0:
            # from Claim 10
            # `n[0] = neps[0] / ((1 - e[0]) * (A / B))`
            # `n[0] <= 2 * neps[0] / (1 - e[0])`

            assert self.max_e0 < 1 and self.max_neps(0) >= 0, \
                "only one quadrant of interval division implemented"
            retval = 2 * self.max_neps(0) / (1 - self.max_e0)
        elif i == 1:
            # from Claim 10
            # `n[1] <= neps[1] / ((1 - f[0]) * (1 - pi[0] - delta[0]))`
            min_mpd = 1 - self.max_pi(0) - self.max_delta(0)
            assert self.max_f(0) <= 1 and min_mpd >= 0, \
                "only one quadrant of interval multiplication implemented"
            prod = (1 - self.max_f(0)) * min_mpd
            assert self.max_neps(1) >= 0 and prod > 0, \
                "only one quadrant of interval division implemented"
            retval = self.max_neps(1) / prod
        else:
            # from Claim 6
            # `0 <= n[i] <= 2 * max_neps[i] / (1 - pi[i - 1] - delta[i - 1])`
            # NOTE(review): the bound quoted above has a factor of 2 that
            # the code below doesn't apply -- confirm against the paper.
            min_mpd = 1 - self.max_pi(i - 1) - self.max_delta(i - 1)
            assert self.max_neps(i) >= 0 and min_mpd > 0, \
                "only one quadrant of interval division implemented"
            retval = self.max_neps(i) / min_mpd

        # we need Fraction to avoid using float by accident
        # -- it also hints to the IDE to give the correct type
        return Fraction(retval)

    @cache_on_self
    def max_d(self, i):
        """maximum value of `d[i]` (the relative error in `D_prime[i]`
        relative to the previous iteration)
        """
        assert isinstance(i, int) and 0 <= i < self.iter_count
        if i == 0:
            # from Claim 10
            # `d[0] = deps[0] / (1 - e[0])`

            assert self.max_e0 < 1 and self.max_deps(0) >= 0, \
                "only one quadrant of interval division implemented"
            retval = self.max_deps(0) / (1 - self.max_e0)
        elif i == 1:
            # from Claim 10
            # `d[1] <= deps[1] / ((1 - f[0]) * (1 - delta[0] ** 2))`
            assert self.max_f(0) <= 1 and self.max_delta(0) <= 1, \
                "only one quadrant of interval multiplication implemented"
            divisor = (1 - self.max_f(0)) * (1 - self.max_delta(0) ** 2)
            assert self.max_deps(1) >= 0 and divisor > 0, \
                "only one quadrant of interval division implemented"
            retval = self.max_deps(1) / divisor
        else:
            # from Claim 6
            # `0 <= d[i] <= max_deps[i] / (1 - delta[i - 1])`
            assert self.max_deps(i) >= 0 and self.max_delta(i - 1) < 1, \
                "only one quadrant of interval division implemented"
            retval = self.max_deps(i) / (1 - self.max_delta(i - 1))

        # we need Fraction to avoid using float by accident
        # -- it also hints to the IDE to give the correct type
        return Fraction(retval)

    @cache_on_self
    def max_f(self, i):
        """maximum value of `f[i]` (the relative error in `F_prime[i]`
        relative to the previous iteration)
        """
        assert isinstance(i, int) and 0 <= i < self.iter_count
        if i == 0:
            # from Claim 10
            # `f[0] = feps[0] / (1 - delta[0])`

            assert self.max_delta(0) < 1 and self.max_feps(0) >= 0, \
                "only one quadrant of interval division implemented"
            retval = self.max_feps(0) / (1 - self.max_delta(0))
        elif i == 1:
            # from Claim 10
            # `f[1] = feps[1]`
            retval = self.max_feps(1)
        else:
            # from Claim 6
            # `f[i] <= max_feps[i]`
            retval = self.max_feps(i)

        # we need Fraction to avoid using float by accident
        # -- it also hints to the IDE to give the correct type
        return Fraction(retval)

    @cache_on_self
    def max_delta(self, i):
        """ maximum value of `delta[i]`.
        `delta[i]` is defined in Definition 4 of paper.
        """
        assert isinstance(i, int) and 0 <= i < self.iter_count
        if i == 0:
            # `delta[0] = abs(e[0]) + 3 * d[0] / 2`
            retval = self.max_abs_e0 + Fraction(3, 2) * self.max_d(0)
        else:
            # `delta[i] = delta[i - 1] ** 2 + f[i - 1]`
            prev_max_delta = self.max_delta(i - 1)
            assert prev_max_delta >= 0
            retval = prev_max_delta ** 2 + self.max_f(i - 1)

        # we need Fraction to avoid using float by accident
        # -- it also hints to the IDE to give the correct type
        return Fraction(retval)

    @cache_on_self
    def max_pi(self, i):
        """ maximum value of `pi[i]`.
        `pi[i]` is defined right below Theorem 5 of paper.
        """
        assert isinstance(i, int) and 0 <= i < self.iter_count
        # `pi[i] = 1 - (1 - n[i]) * prod`
        # where `prod` is the product of,
        # for `j` in `0 <= j < i`, `(1 - n[j]) / (1 + d[j])`

        # the empty product is 1 -- starting from 0 would make every
        # product 0 and hence every `max_pi` exactly 1
        min_prod = Fraction(1)
        for j in range(i):
            max_n_j = self.max_n(j)
            max_d_j = self.max_d(j)
            assert max_n_j <= 1 and max_d_j > -1, \
                "only one quadrant of interval division implemented"
            min_prod *= (1 - max_n_j) / (1 + max_d_j)
        max_n_i = self.max_n(i)
        assert max_n_i <= 1 and min_prod >= 0, \
            "only one quadrant of interval multiplication implemented"
        retval = 1 - (1 - max_n_i) * min_prod

        # we need Fraction to avoid using float by accident
        # -- it also hints to the IDE to give the correct type
        return Fraction(retval)

    @cached_property
    def max_n_shift(self):
        """ maximum value of `state.n_shift`.
        """
        # input numerator is `2*io_width`-bits
        max_n = (1 << (self.io_width * 2)) - 1
        max_n_shift = 0
        # normalize so 1 <= n < 2
        while max_n >= 2:
            max_n >>= 1
            max_n_shift += 1
        return max_n_shift
618
619
@enum.unique
class GoldschmidtDivOp(enum.Enum):
    """one step of the goldschmidt division algorithm.

    each member's value is pseudo-code describing what `run` does for it.
    """

    Normalize = "n, d, n_shift = normalize(n, d)"
    FEqTableLookup = "f = table_lookup(d)"
    MulNByF = "n *= f"
    MulDByF = "d *= f"
    FEq2MinusD = "f = 2 - d"
    CalcResult = "result = unnormalize_and_round(n)"

    def run(self, params, state):
        """execute this operation, updating `state` in place.

        arguments:
        params: GoldschmidtDivParams
            the parameters of the algorithm
        state: GoldschmidtDivState
            the state to read from and write to
        """
        assert isinstance(params, GoldschmidtDivParams)
        assert isinstance(state, GoldschmidtDivState)
        width = params.expanded_width
        addr_bits = params.table_addr_bits

        if self is GoldschmidtDivOp.Normalize:
            # scale n and d together until 1 <= d < 2
            # (hardware can do this with count-leading-zeros + left shift)
            while state.d < 1:
                state.n = (state.n * 2).to_frac_wid(width)
                state.d = (state.d * 2).to_frac_wid(width)

            # then scale n alone until 1 <= n < 2, remembering how far
            state.n_shift = 0
            while state.n >= 2:
                state.n = (state.n * 0.5).to_frac_wid(width)
                state.n_shift += 1
            return

        if self is GoldschmidtDivOp.FEqTableLookup:
            # the table is addressed by the top fraction bits of `d - 1`
            addr = (state.d - 1).to_frac_wid(addr_bits, RoundDir.DOWN)
            assert 0 <= addr.bits < (1 << params.table_addr_bits)
            state.f = params.table[addr.bits]
            return

        if self is GoldschmidtDivOp.MulNByF:
            assert state.f is not None
            product = state.n * state.f
            # the numerator is rounded down...
            state.n = product.to_frac_wid(width, round_dir=RoundDir.DOWN)
            return

        if self is GoldschmidtDivOp.MulDByF:
            assert state.f is not None
            product = state.d * state.f
            # ...whereas the denominator is rounded up
            state.d = product.to_frac_wid(width, round_dir=RoundDir.UP)
            return

        if self is GoldschmidtDivOp.FEq2MinusD:
            state.f = (2 - state.d).to_frac_wid(width)
            return

        if self is GoldschmidtDivOp.CalcResult:
            assert state.n_shift is not None
            # undo the numerator normalization to get the real magnitude
            scaled_n = state.n * (1 << state.n_shift)

            quotient = math.floor(scaled_n)
            remainder = state.orig_n - quotient * state.orig_d
            # the approximation may be one too small -- correct it
            if remainder >= state.orig_d:
                quotient += 1
                remainder -= state.orig_d
            state.quotient = quotient
            state.remainder = remainder
            return

        assert False, f"unimplemented GoldschmidtDivOp: {self}"
674
675
def _goldschmidt_div_ops(params):
    """ Goldschmidt division algorithm, as a sequence of operations.

    based on:
    Even, G., Seidel, P. M., & Ferguson, W. E. (2003).
    A Parametric Error Analysis of Goldschmidt's Division Algorithm.
    https://citeseerx.ist.psu.edu/viewdoc/download?doi=10.1.1.90.1238&rep=rep1&type=pdf

    the accuracy requirements are checked while generating, so consuming
    the generator raises `ParamsNotAccurateEnough` if `params` can't
    guarantee a correct result.

    arguments:
    params: GoldschmidtDivParams
        the parameters for the algorithm

    yields: GoldschmidtDivOp
        the operations needed to perform the division.
    """
    assert isinstance(params, GoldschmidtDivParams)

    # the paper's error analysis (section 3.1) relies on four assumptions:

    # assumption 1: A (numerator) and B (denominator) are both in [1, 2)
    yield GoldschmidtDivOp.Normalize

    # assumption 2: every relative error from directed rounding is <= 1 / 4,
    # which holds for multipliers with > 4-bits precision
    _assert_accuracy(params.expanded_width > 4)

    # assumption 3: `abs(e[0]) + 3 * d[0] / 2 + f[0] < 1 / 2`.
    initial_error = (params.max_abs_e0 + 3 * params.max_d(0) / 2
                     + params.max_f(0))
    _assert_accuracy(initial_error < Fraction(1, 2))

    # assumption 4: the initial approximation F'[-1] of 1/B is in [1/2, 1].
    # (B is the denominator)
    _assert_accuracy(all(0.5 <= params.table[addr] <= 1
                         for addr in range(params.table_addr_count)))

    yield GoldschmidtDivOp.FEqTableLookup

    # we use Setting I (section 4.1 of the paper):
    # Require `n[i] <= n_hat` and `d[i] <= n_hat` and `f[i] = 0`
    last_i = params.iter_count - 1
    n_hat = Fraction(0)
    for i in range(params.iter_count):
        _assert_accuracy(params.max_f(i) == 0)
        n_hat = max(n_hat, params.max_n(i), params.max_d(i))
        yield GoldschmidtDivOp.MulNByF
        if i == last_i:
            # d and f aren't needed anymore after the last iteration
            continue
        yield GoldschmidtDivOp.MulDByF
        yield GoldschmidtDivOp.FEq2MinusD

    # relative approximation error `p(N_prime[i])`:
    # `p(N_prime[i]) = (A / B - N_prime[i]) / (A / B)`
    # `0 <= p(N_prime[i])`
    # `p(N_prime[i]) <= (2 * i) * n_hat \`
    # ` + (abs(e[0]) + 3 * n_hat / 2) ** (2 ** i)`
    # evaluated for the last used `i`
    linear_term = (2 * last_i) * n_hat
    power_term = (params.max_abs_e0 + 3 * n_hat / 2) ** (2 ** last_i)
    max_rel_error = linear_term + power_term

    min_a_over_b = Fraction(1, 2)
    max_a_over_b = Fraction(2)
    max_allowed_abs_error = max_a_over_b / (1 << params.max_n_shift)
    max_allowed_rel_error = max_allowed_abs_error / min_a_over_b

    _assert_accuracy(max_rel_error < max_allowed_rel_error)

    yield GoldschmidtDivOp.CalcResult
743
744
def goldschmidt_div(n, d, params):
    """ Goldschmidt division algorithm.

    based on:
    Even, G., Seidel, P. M., & Ferguson, W. E. (2003).
    A Parametric Error Analysis of Goldschmidt's Division Algorithm.
    https://citeseerx.ist.psu.edu/viewdoc/download?doi=10.1.1.90.1238&rep=rep1&type=pdf

    arguments:
    n: int
        numerator. a `2 * params.io_width`-bit unsigned integer.
        must be less than `d << params.io_width`, otherwise the quotient
        wouldn't fit in `params.io_width` bits.
    d: int
        denominator. a `params.io_width`-bit unsigned integer.
        must not be zero.
    params: GoldschmidtDivParams
        the parameters of the algorithm, including the bit-width
        `params.io_width` of the inputs/outputs.

    returns: tuple[int, int]
        the quotient and remainder.
    """
    assert isinstance(params, GoldschmidtDivParams)
    io_width = params.io_width
    assert isinstance(d, int) and 0 < d < (1 << io_width)
    assert isinstance(n, int) and 0 <= n < (d << io_width)

    # both inputs start out as fixed-point values with `io_width`
    # fractional bits
    numerator = FixedPoint(n, io_width)
    denominator = FixedPoint(d, io_width)
    state = GoldschmidtDivState(orig_n=n, orig_d=d,
                                n=numerator, d=denominator)

    # the ops mutate `state` one after another
    for op in params.ops:
        op.run(params, state)

    quotient, remainder = state.quotient, state.remainder
    assert quotient is not None
    assert remainder is not None

    return quotient, remainder
786
787 return state.quotient, state.remainder