change goldschmidt_div_sqrt to use nmutil.plain_data rather than dataclasses
[soc.git] / src / soc / fu / div / experiment / goldschmidt_div_sqrt.py
1 # SPDX-License-Identifier: LGPL-3-or-later
2 # Copyright 2022 Jacob Lifshay programmerjake@gmail.com
3
4 # Funded by NLnet Assure Programme 2021-02-052, https://nlnet.nl/assure part
5 # of Horizon 2020 EU Programme 957073.
6
7 from collections import defaultdict
8 import logging
9 import math
10 import enum
11 from fractions import Fraction
12 from types import FunctionType
13 from functools import lru_cache
14 from nmigen.hdl.ast import Signal, unsigned, signed, Const
15 from nmigen.hdl.dsl import Module, Elaboratable
16 from nmigen.hdl.mem import Memory
17 from nmutil.clz import CLZ
18 from nmutil.plain_data import plain_data, fields, replace
19
20 try:
21 from functools import cached_property
22 except ImportError:
23 from cached_property import cached_property
24
25 # fix broken IDE type detection for cached_property
26 from typing import TYPE_CHECKING, Any
27 if TYPE_CHECKING:
28 from functools import cached_property
29
30
_NOT_FOUND = object()


def cache_on_self(func):
    """like `functools.cached_property`, except for methods. unlike
    `lru_cache` the cache is per-class instance rather than a global cache
    per-method.

    the cache is a `dict` stored in the instance's `__dict__` under the name
    `<method name>__cache`, so instances need a `__dict__` and all arguments
    need to be hashable. keyword arguments are keyed by name, independent of
    the order they were passed in.

    func: FunctionType
        the plain (undecorated) method to wrap.
    returns: FunctionType
        the caching wrapper, carrying `func`'s name and docstring.
    """
    # imported here so this helper doesn't depend on the module's import list
    from functools import wraps

    assert isinstance(func, FunctionType), \
        "non-plain methods are not supported"

    cache_name = func.__name__ + "__cache"

    @wraps(func)
    def wrapper(self, *args, **kwargs):
        # specifically access through `__dict__` to bypass frozen=True
        cache = self.__dict__.get(cache_name, _NOT_FOUND)
        if cache is _NOT_FOUND:
            self.__dict__[cache_name] = cache = {}
        # sort by argument name so `f(a=1, b=2)` and `f(b=2, a=1)` share a
        # cache entry -- names are unique, so values are never compared
        key = (args, *sorted(kwargs.items(), key=lambda item: item[0]))
        retval = cache.get(key, _NOT_FOUND)
        if retval is _NOT_FOUND:
            retval = func(self, *args, **kwargs)
            cache[key] = retval
        return retval

    return wrapper
58
59
@enum.unique
class RoundDir(enum.Enum):
    """how to round a value that isn't exactly representable."""

    # members are numbered consecutively from 1 in declaration order

    # round toward negative infinity
    DOWN = 1
    # round toward positive infinity
    UP = 2
    # round to the nearest value, exact ties round toward positive infinity
    NEAREST_TIES_UP = 3
    # don't round at all -- inexact results raise `ValueError` instead
    ERROR_IF_INEXACT = 4
66
67
@plain_data(frozen=True, eq=False, repr=False)
class FixedPoint:
    """exact binary fixed-point number with the value
    `bits / (1 << frac_wid)`.

    `bits` is an arbitrary-size signed `int` and `frac_wid` is the
    non-negative number of fractional bits. comparisons and hashing are by
    numeric value, so the same value with different `frac_wid` is equal.
    """
    __slots__ = "bits", "frac_wid"

    def __init__(self, bits, frac_wid):
        self.bits = bits
        self.frac_wid = frac_wid
        assert isinstance(self.bits, int)
        assert isinstance(self.frac_wid, int) and self.frac_wid >= 0

    @staticmethod
    def cast(value):
        """convert `value` to a fixed-point number with enough fractional
        bits to preserve its value.

        accepted types:
        * `FixedPoint` -- returned unchanged.
        * `int` -- converted with zero fractional bits.
        * `str` -- optional `+`/`-` sign, then either `0x`/`0X` followed by
          hexadecimal digits containing a `.` (`_` is ignored, each digit
          after the `.` adds 4 fractional bits), or anything accepted by
          `int(value, base=0)`.
        * `float` -- converted exactly.

        raises `ValueError` for malformed strings and `TypeError` for any
        other type.
        """
        if isinstance(value, FixedPoint):
            return value
        if isinstance(value, int):
            return FixedPoint(value, 0)
        if isinstance(value, str):
            value = value.strip()
            neg = value.startswith("-")
            if neg or value.startswith("+"):
                value = value[1:]
            if value.startswith(("0x", "0X")) and "." in value:
                value = value[2:]
                got_dot = False
                bits = 0
                frac_wid = 0
                for digit in value:
                    if digit == "_":
                        continue
                    if digit == ".":
                        if got_dot:
                            raise ValueError("too many `.` in string")
                        got_dot = True
                        continue
                    try:
                        digit_value = int(digit, base=16)
                    except ValueError:
                        # one message for every kind of invalid digit
                        raise ValueError(
                            "invalid hexadecimal digit") from None
                    if got_dot:
                        frac_wid += 4
                    bits <<= 4
                    bits |= digit_value
            else:
                bits = int(value, base=0)
                frac_wid = 0
            if neg:
                bits = -bits
            return FixedPoint(bits, frac_wid)

        if isinstance(value, float):
            n, d = value.as_integer_ratio()
            log2_d = d.bit_length() - 1
            assert d == 1 << log2_d, ("d isn't a power of 2 -- won't ever "
                                      "fail with float being IEEE 754")
            return FixedPoint(n, log2_d)
        raise TypeError("can't convert type to FixedPoint")

    @staticmethod
    def with_frac_wid(value, frac_wid, round_dir=RoundDir.ERROR_IF_INEXACT):
        """convert `value` to the nearest fixed-point number with `frac_wid`
        fractional bits, rounding according to `round_dir`.

        `value` is a `Fraction` or anything accepted by `FixedPoint.cast`.
        raises `ValueError` if `round_dir` is `RoundDir.ERROR_IF_INEXACT` and
        `value` isn't exactly representable.
        """
        assert isinstance(frac_wid, int) and frac_wid >= 0
        assert isinstance(round_dir, RoundDir)
        if isinstance(value, Fraction):
            numerator = value.numerator
            denominator = value.denominator
        else:
            value = FixedPoint.cast(value)
            numerator = value.bits
            denominator = 1 << value.frac_wid
        # `denominator` is always positive: `Fraction` keeps its denominator
        # positive and `1 << frac_wid` can't be negative. so `divmod` rounds
        # `bits` toward negative infinity and `0 <= remainder < denominator`.
        bits, remainder = divmod(numerator << frac_wid, denominator)
        if round_dir == RoundDir.DOWN:
            pass
        elif round_dir == RoundDir.UP:
            if remainder != 0:
                bits += 1
        elif round_dir == RoundDir.NEAREST_TIES_UP:
            if remainder * 2 >= denominator:
                bits += 1
        elif round_dir == RoundDir.ERROR_IF_INEXACT:
            if remainder != 0:
                raise ValueError("inexact conversion")
        else:
            assert False, "unimplemented round_dir"
        return FixedPoint(bits, frac_wid)

    def to_frac_wid(self, frac_wid, round_dir=RoundDir.ERROR_IF_INEXACT):
        """convert to the nearest fixed-point number with `frac_wid`
        fractional bits, rounding according to `round_dir`."""
        return FixedPoint.with_frac_wid(self, frac_wid, round_dir)

    def __float__(self):
        # use truediv to get correct result even when bits
        # and frac_wid are huge
        return float(self.bits / (1 << self.frac_wid))

    def as_fraction(self):
        """return the exact value of `self` as a `Fraction`."""
        return Fraction(self.bits, 1 << self.frac_wid)

    def cmp(self, rhs):
        """compare self with rhs, returning a positive integer if self is
        greater than rhs, zero if self is equal to rhs, and a negative integer
        if self is less than rhs.

        `rhs` is a `Fraction` or anything accepted by `FixedPoint.cast`.
        """
        if isinstance(rhs, Fraction):
            # a `Fraction` may not be representable as a `FixedPoint`, so
            # compare exactly as fractions instead of casting
            lhs = self.as_fraction()
            return (lhs > rhs) - (lhs < rhs)
        rhs = FixedPoint.cast(rhs)
        common_frac_wid = max(self.frac_wid, rhs.frac_wid)
        lhs = self.to_frac_wid(common_frac_wid)
        rhs = rhs.to_frac_wid(common_frac_wid)
        return lhs.bits - rhs.bits

    def __eq__(self, rhs):
        try:
            return self.cmp(rhs) == 0
        except TypeError:
            # unsupported operand type: let python try the reflected
            # comparison and then fall back to identity
            return NotImplemented

    def __ne__(self, rhs):
        try:
            return self.cmp(rhs) != 0
        except TypeError:
            return NotImplemented

    def __hash__(self):
        # hash by value so equal numbers hash equally regardless of
        # `frac_wid`. python's numeric hash makes this agree with equal
        # `int`, `float`, and `Fraction` values too.
        # NOTE(review): strings that compare equal via `cast` will not
        # hash equally.
        return hash(self.as_fraction())

    def __gt__(self, rhs):
        return self.cmp(rhs) > 0

    def __lt__(self, rhs):
        return self.cmp(rhs) < 0

    def __ge__(self, rhs):
        return self.cmp(rhs) >= 0

    def __le__(self, rhs):
        return self.cmp(rhs) <= 0

    def fract(self):
        """return the fractional part of `self`.
        that is `self - math.floor(self)`.
        """
        fract_mask = (1 << self.frac_wid) - 1
        return FixedPoint(self.bits & fract_mask, self.frac_wid)

    def __str__(self):
        """format as a signed hexadecimal number that always has a `.`"""
        if self < 0:
            return "-" + str(-self)
        digit_bits = 4
        # round the fraction width up to a whole number of hex digits
        frac_digit_count = (self.frac_wid + digit_bits - 1) // digit_bits
        fract = self.fract().to_frac_wid(frac_digit_count * digit_bits)
        frac_str = hex(fract.bits)[2:].zfill(frac_digit_count)
        return hex(math.floor(self)) + "." + frac_str

    def __repr__(self):
        return f"FixedPoint.with_frac_wid({str(self)!r}, {self.frac_wid})"

    def __add__(self, rhs):
        rhs = FixedPoint.cast(rhs)
        common_frac_wid = max(self.frac_wid, rhs.frac_wid)
        lhs = self.to_frac_wid(common_frac_wid)
        rhs = rhs.to_frac_wid(common_frac_wid)
        return FixedPoint(lhs.bits + rhs.bits, common_frac_wid)

    def __radd__(self, lhs):
        # symmetric
        return self.__add__(lhs)

    def __neg__(self):
        return FixedPoint(-self.bits, self.frac_wid)

    def __sub__(self, rhs):
        rhs = FixedPoint.cast(rhs)
        common_frac_wid = max(self.frac_wid, rhs.frac_wid)
        lhs = self.to_frac_wid(common_frac_wid)
        rhs = rhs.to_frac_wid(common_frac_wid)
        return FixedPoint(lhs.bits - rhs.bits, common_frac_wid)

    def __rsub__(self, lhs):
        # a - b == -(b - a)
        return -self.__sub__(lhs)

    def __mul__(self, rhs):
        # exact: the product's fraction width is the sum of the operands'
        rhs = FixedPoint.cast(rhs)
        return FixedPoint(self.bits * rhs.bits, self.frac_wid + rhs.frac_wid)

    def __rmul__(self, lhs):
        # symmetric
        return self.__mul__(lhs)

    def __floor__(self):
        # arithmetic right shift rounds toward negative infinity
        return self.bits >> self.frac_wid

    def div(self, rhs, frac_wid, round_dir=RoundDir.ERROR_IF_INEXACT):
        """compute `self / rhs` as a fixed-point number with `frac_wid`
        fractional bits, rounding according to `round_dir`.

        raises `ZeroDivisionError` if `rhs` is zero.
        """
        assert isinstance(frac_wid, int) and frac_wid >= 0
        assert isinstance(round_dir, RoundDir)
        rhs = FixedPoint.cast(rhs)
        return FixedPoint.with_frac_wid(self.as_fraction()
                                        / rhs.as_fraction(),
                                        frac_wid, round_dir)

    def sqrt(self, round_dir=RoundDir.ERROR_IF_INEXACT):
        """compute the square-root of `self`, with the same `frac_wid` as
        `self`, rounding according to `round_dir`.

        raises `ValueError` if `self` is negative, or if the result is
        inexact and `round_dir` is `RoundDir.ERROR_IF_INEXACT`.
        """
        assert isinstance(round_dir, RoundDir)
        if self < 0:
            raise ValueError("can't compute sqrt of negative number")
        if self == 0:
            return self
        retval = FixedPoint(0, self.frac_wid)
        int_part_wid = self.bits.bit_length() - self.frac_wid
        first_bit_index = -(-int_part_wid // 2)  # division rounds up
        last_bit_index = -self.frac_wid
        # build the result one bit at a time from the top, keeping each bit
        # only if the square stays `<= self` -- this rounds down
        for bit_index in range(first_bit_index, last_bit_index - 1, -1):
            trial = retval + FixedPoint(1 << (bit_index + self.frac_wid),
                                        self.frac_wid)
            if trial * trial <= self:
                retval = trial
        if round_dir == RoundDir.DOWN:
            pass
        elif round_dir == RoundDir.UP:
            if retval * retval < self:
                retval += FixedPoint(1, self.frac_wid)
        elif round_dir == RoundDir.NEAREST_TIES_UP:
            # half an ULP above the rounded-down result
            half_way = retval + FixedPoint(1, self.frac_wid + 1)
            if half_way * half_way <= self:
                retval += FixedPoint(1, self.frac_wid)
        elif round_dir == RoundDir.ERROR_IF_INEXACT:
            if retval * retval != self:
                raise ValueError("inexact sqrt")
        else:
            assert False, "unimplemented round_dir"
        return retval

    def rsqrt(self, round_dir=RoundDir.ERROR_IF_INEXACT):
        """compute the reciprocal-sqrt of `self`, with the same `frac_wid` as
        `self`, rounding according to `round_dir`.

        raises `ValueError` if `self` is negative, `ZeroDivisionError` if
        `self` is zero, and `ValueError` if the result is inexact and
        `round_dir` is `RoundDir.ERROR_IF_INEXACT`.
        """
        assert isinstance(round_dir, RoundDir)
        if self < 0:
            raise ValueError("can't compute rsqrt of negative number")
        if self == 0:
            raise ZeroDivisionError("can't compute rsqrt of zero")
        retval = FixedPoint(0, self.frac_wid)
        first_bit_index = -(-self.frac_wid // 2)  # division rounds up
        last_bit_index = -self.frac_wid
        # build the result one bit at a time from the top, keeping each bit
        # only if `retval ** 2 * self` stays `<= 1` -- this rounds down
        for bit_index in range(first_bit_index, last_bit_index - 1, -1):
            trial = retval + FixedPoint(1 << (bit_index + self.frac_wid),
                                        self.frac_wid)
            if trial * trial * self <= 1:
                retval = trial
        if round_dir == RoundDir.DOWN:
            pass
        elif round_dir == RoundDir.UP:
            if retval * retval * self < 1:
                retval += FixedPoint(1, self.frac_wid)
        elif round_dir == RoundDir.NEAREST_TIES_UP:
            # half an ULP above the rounded-down result
            half_way = retval + FixedPoint(1, self.frac_wid + 1)
            if half_way * half_way * self <= 1:
                retval += FixedPoint(1, self.frac_wid)
        elif round_dir == RoundDir.ERROR_IF_INEXACT:
            if retval * retval * self != 1:
                raise ValueError("inexact rsqrt")
        else:
            assert False, "unimplemented round_dir"
        return retval
322
323
class ParamsNotAccurateEnough(Exception):
    """error signalling that a set of parameters doesn't provide enough
    accuracy for goldschmidt division to work.
    """
327
328
329 def _assert_accuracy(condition, msg="not accurate enough"):
330 if condition:
331 return
332 raise ParamsNotAccurateEnough(msg)
333
334
@plain_data(frozen=True, unsafe_hash=True)
class GoldschmidtDivParamsBase:
    """the independent parameters of a Goldschmidt division algorithm.
    everything that is derived from them lives elsewhere.
    """

    __slots__ = ("io_width", "extra_precision", "table_addr_bits",
                 "table_data_bits", "iter_count")

    def __init__(self, io_width, extra_precision, table_addr_bits,
                 table_data_bits, iter_count):
        # every parameter has to be an integer -- check them all, in
        # declaration order, before storing anything
        for arg in (io_width, extra_precision, table_addr_bits,
                    table_data_bits, iter_count):
            assert isinstance(arg, int)

        self.io_width = io_width
        """bit-width of the input divisor and the result.
        the input numerator is `2 * io_width`-bits wide.
        """

        self.extra_precision = extra_precision
        """number of bits of additional precision used inside the algorithm."""

        self.table_addr_bits = table_addr_bits
        """the number of address bits used in the lookup-table."""

        self.table_data_bits = table_data_bits
        """the number of data bits used in the lookup-table."""

        self.iter_count = iter_count
        """the total number of iterations of the division algorithm's loop"""
367
368
@plain_data(frozen=True, unsafe_hash=True)
class GoldschmidtDivParams(GoldschmidtDivParamsBase):
    """parameters for a Goldschmidt division algorithm.
    Use `GoldschmidtDivParams.get` to find an efficient set of parameters.

    construction builds the lookup-table and runs the error analysis, raising
    `ParamsNotAccurateEnough` if the parameters can't give correct results.
    """

    __slots__ = "table", "ops"

    def _shrink_bound(self, bound, round_dir):
        """prevent fractions from having huge numerators/denominators by
        rounding to a `FixedPoint` and converting back to a `Fraction`.

        This is intended only for values used to compute bounds, and not for
        values that end up in the hardware.
        """
        assert isinstance(bound, (Fraction, int))
        assert round_dir is RoundDir.DOWN or round_dir is RoundDir.UP, \
            "you shouldn't use that round_dir on bounds"
        frac_wid = self.io_width * 4 + 100  # should be enough precision
        fixed = FixedPoint.with_frac_wid(bound, frac_wid, round_dir)
        return fixed.as_fraction()

    def _shrink_min(self, min_bound):
        """prevent fractions used as minimum bounds from having huge
        numerators/denominators by rounding down to a `FixedPoint` and
        converting back to a `Fraction`.

        This is intended only for values used to compute bounds, and not for
        values that end up in the hardware.
        """
        return self._shrink_bound(min_bound, RoundDir.DOWN)

    def _shrink_max(self, max_bound):
        """prevent fractions used as maximum bounds from having huge
        numerators/denominators by rounding up to a `FixedPoint` and
        converting back to a `Fraction`.

        This is intended only for values used to compute bounds, and not for
        values that end up in the hardware.
        """
        return self._shrink_bound(max_bound, RoundDir.UP)

    @property
    def table_addr_count(self):
        """number of distinct addresses in the lookup-table."""
        # used while computing self.table, so can't just do len(self.table)
        return 1 << self.table_addr_bits

    def table_input_exact_range(self, addr):
        """return the range of inputs as `Fraction`s used for the table entry
        with address `addr`.

        raises `ParamsNotAccurateEnough` if `table_addr_bits > io_width`.
        """
        assert isinstance(addr, int)
        assert 0 <= addr < self.table_addr_count
        _assert_accuracy(self.io_width >= self.table_addr_bits)
        addr_shift = self.io_width - self.table_addr_bits
        min_numerator = (1 << self.io_width) + (addr << addr_shift)
        denominator = 1 << self.io_width
        values_per_table_entry = 1 << addr_shift
        max_numerator = min_numerator + values_per_table_entry - 1
        min_input = Fraction(min_numerator, denominator)
        max_input = Fraction(max_numerator, denominator)
        min_input = self._shrink_min(min_input)
        max_input = self._shrink_max(max_input)
        assert 1 <= min_input <= max_input < 2
        return min_input, max_input

    def table_value_exact_range(self, addr):
        """return the range of values as `Fraction`s used for the table entry
        with address `addr`."""
        min_input, max_input = self.table_input_exact_range(addr)
        # division swaps min/max
        min_value = 1 / max_input
        max_value = 1 / min_input
        min_value = self._shrink_min(min_value)
        max_value = self._shrink_max(max_value)
        assert 0.5 < min_value <= max_value <= 1
        return min_value, max_value

    def table_exact_value(self, index):
        """return the value as a `Fraction` that the table entry with address
        `index` approximates."""
        min_value, _ = self.table_value_exact_range(index)
        # we round down
        return min_value

    def __init__(self, io_width, extra_precision, table_addr_bits,
                 table_data_bits, iter_count):
        super().__init__(io_width=io_width,
                         extra_precision=extra_precision,
                         table_addr_bits=table_addr_bits,
                         table_data_bits=table_data_bits,
                         iter_count=iter_count)
        _assert_accuracy(self.io_width >= 1, "io_width out of range")
        _assert_accuracy(self.extra_precision >= 0,
                         "extra_precision out of range")
        _assert_accuracy(self.table_addr_bits >= 1,
                         "table_addr_bits out of range")
        _assert_accuracy(self.table_data_bits >= 1,
                         "table_data_bits out of range")
        _assert_accuracy(self.iter_count >= 1, "iter_count out of range")

        self.table = tuple(
            FixedPoint.with_frac_wid(self.table_exact_value(addr),
                                     self.table_data_bits,
                                     RoundDir.DOWN)
            for addr in range(self.table_addr_count))
        """ the lookup-table.
        type: tuple[FixedPoint, ...]
        """

        # generating the ops also runs the error analysis, which needs
        # `self.table`
        self.ops = tuple(self.__make_ops())
        "the operations needed to perform the goldschmidt division algorithm."

    @property
    def expanded_width(self):
        """the total number of bits of precision used inside the algorithm."""
        return self.io_width + self.extra_precision

    @property
    def n_d_f_int_wid(self):
        """the number of bits in the integer part of `state.n`, `state.d`, and
        `state.f` during the main iteration loop.
        """
        return 2

    @property
    def n_d_f_total_wid(self):
        """the total number of bits (both integer and fraction bits) in
        `state.n`, `state.d`, and `state.f` during the main iteration loop.
        """
        return self.n_d_f_int_wid + self.expanded_width

    @cache_on_self
    def max_neps(self, i):
        """maximum value of `neps[i]`.
        `neps[i]` is defined to be `n[i] * N_prime[i - 1] * F_prime[i - 1]`.
        """
        assert isinstance(i, int) and 0 <= i < self.iter_count
        return Fraction(1, 1 << self.expanded_width)

    @cache_on_self
    def max_deps(self, i):
        """maximum value of `deps[i]`.
        `deps[i]` is defined to be `d[i] * D_prime[i - 1] * F_prime[i - 1]`.
        """
        assert isinstance(i, int) and 0 <= i < self.iter_count
        return Fraction(1, 1 << self.expanded_width)

    @cache_on_self
    def max_feps(self, i):
        """maximum value of `feps[i]`.
        `feps[i]` is defined to be `f[i] * (2 - D_prime[i - 1])`.
        """
        assert isinstance(i, int) and 0 <= i < self.iter_count
        # zero, because the computation of `F_prime[i]` in
        # `GoldschmidtDivOp.MulDByF.run(...)` is exact.
        return Fraction(0)

    @cached_property
    def e0_range(self):
        """minimum and maximum values of `e[0]`
        (the relative error in `F_prime[-1]`)
        """
        min_e0 = Fraction(0)
        max_e0 = Fraction(0)
        for addr in range(self.table_addr_count):
            # `F_prime[-1] = (1 - e[0]) / B`
            # => `e[0] = 1 - B * F_prime[-1]`
            min_b, max_b = self.table_input_exact_range(addr)
            f_prime_m1 = self.table[addr].as_fraction()
            assert min_b >= 0 and f_prime_m1 >= 0, \
                "only positive quadrant of interval multiplication implemented"
            min_product = min_b * f_prime_m1
            max_product = max_b * f_prime_m1
            # negation swaps min/max
            cur_min_e0 = 1 - max_product
            cur_max_e0 = 1 - min_product
            min_e0 = min(min_e0, cur_min_e0)
            max_e0 = max(max_e0, cur_max_e0)
        min_e0 = self._shrink_min(min_e0)
        max_e0 = self._shrink_max(max_e0)
        return min_e0, max_e0

    @cached_property
    def min_e0(self):
        """minimum value of `e[0]` (the relative error in `F_prime[-1]`)
        """
        return self.e0_range[0]

    @cached_property
    def max_e0(self):
        """maximum value of `e[0]` (the relative error in `F_prime[-1]`)
        """
        return self.e0_range[1]

    @cached_property
    def max_abs_e0(self):
        """maximum value of `abs(e[0])`."""
        return max(abs(self.min_e0), abs(self.max_e0))

    @cached_property
    def min_abs_e0(self):
        """minimum value of `abs(e[0])`."""
        return Fraction(0)

    @cache_on_self
    def max_n(self, i):
        """maximum value of `n[i]` (the relative error in `N_prime[i]`
        relative to the previous iteration)
        """
        assert isinstance(i, int) and 0 <= i < self.iter_count
        if i == 0:
            # from Claim 10
            # `n[0] = neps[0] / ((1 - e[0]) * (A / B))`
            # `n[0] <= 2 * neps[0] / (1 - e[0])`

            assert self.max_e0 < 1 and self.max_neps(0) >= 0, \
                "only one quadrant of interval division implemented"
            retval = 2 * self.max_neps(0) / (1 - self.max_e0)
        elif i == 1:
            # from Claim 10
            # `n[1] <= neps[1] / ((1 - f[0]) * (1 - pi[0] - delta[0]))`
            min_mpd = 1 - self.max_pi(0) - self.max_delta(0)
            assert self.max_f(0) <= 1 and min_mpd >= 0, \
                "only one quadrant of interval multiplication implemented"
            prod = (1 - self.max_f(0)) * min_mpd
            assert self.max_neps(1) >= 0 and prod > 0, \
                "only one quadrant of interval division implemented"
            retval = self.max_neps(1) / prod
        else:
            # from Claim 6
            # `0 <= n[i] <= 2 * max_neps[i] / (1 - pi[i - 1] - delta[i - 1])`
            # NOTE(review): the bound quoted above has a factor of 2 that
            # the expression below doesn't apply -- confirm against the
            # paper which of the two is intended.
            min_mpd = 1 - self.max_pi(i - 1) - self.max_delta(i - 1)
            assert self.max_neps(i) >= 0 and min_mpd > 0, \
                "only one quadrant of interval division implemented"
            retval = self.max_neps(i) / min_mpd

        return self._shrink_max(retval)

    @cache_on_self
    def max_d(self, i):
        """maximum value of `d[i]` (the relative error in `D_prime[i]`
        relative to the previous iteration)
        """
        assert isinstance(i, int) and 0 <= i < self.iter_count
        if i == 0:
            # from Claim 10
            # `d[0] = deps[0] / (1 - e[0])`

            assert self.max_e0 < 1 and self.max_deps(0) >= 0, \
                "only one quadrant of interval division implemented"
            retval = self.max_deps(0) / (1 - self.max_e0)
        elif i == 1:
            # from Claim 10
            # `d[1] <= deps[1] / ((1 - f[0]) * (1 - delta[0] ** 2))`
            assert self.max_f(0) <= 1 and self.max_delta(0) <= 1, \
                "only one quadrant of interval multiplication implemented"
            divisor = (1 - self.max_f(0)) * (1 - self.max_delta(0) ** 2)
            assert self.max_deps(1) >= 0 and divisor > 0, \
                "only one quadrant of interval division implemented"
            retval = self.max_deps(1) / divisor
        else:
            # from Claim 6
            # `0 <= d[i] <= max_deps[i] / (1 - delta[i - 1])`
            assert self.max_deps(i) >= 0 and self.max_delta(i - 1) < 1, \
                "only one quadrant of interval division implemented"
            retval = self.max_deps(i) / (1 - self.max_delta(i - 1))

        return self._shrink_max(retval)

    @cache_on_self
    def max_f(self, i):
        """maximum value of `f[i]` (the relative error in `F_prime[i]`
        relative to the previous iteration)
        """
        assert isinstance(i, int) and 0 <= i < self.iter_count
        if i == 0:
            # from Claim 10
            # `f[0] = feps[0] / (1 - delta[0])`

            assert self.max_delta(0) < 1 and self.max_feps(0) >= 0, \
                "only one quadrant of interval division implemented"
            retval = self.max_feps(0) / (1 - self.max_delta(0))
        else:
            # for `i == 1`, from Claim 10: `f[1] = feps[1]`
            # for `i >= 2`, from Claim 6: `f[i] <= max_feps[i]`
            retval = self.max_feps(i)

        return self._shrink_max(retval)

    @cache_on_self
    def max_delta(self, i):
        """ maximum value of `delta[i]`.
        `delta[i]` is defined in Definition 4 of paper.

        raises `ParamsNotAccurateEnough` if the bound isn't less than one.
        """
        assert isinstance(i, int) and 0 <= i < self.iter_count
        if i == 0:
            # `delta[0] = abs(e[0]) + 3 * d[0] / 2`
            retval = self.max_abs_e0 + Fraction(3, 2) * self.max_d(0)
        else:
            # `delta[i] = delta[i - 1] ** 2 + f[i - 1]`
            prev_max_delta = self.max_delta(i - 1)
            assert prev_max_delta >= 0
            retval = prev_max_delta ** 2 + self.max_f(i - 1)

        # `delta[i]` has to be smaller than one otherwise errors would go off
        # to infinity
        _assert_accuracy(retval < 1)

        return self._shrink_max(retval)

    @cache_on_self
    def max_pi(self, i):
        """ maximum value of `pi[i]`.
        `pi[i]` is defined right below Theorem 5 of paper.
        """
        assert isinstance(i, int) and 0 <= i < self.iter_count
        # `pi[i] = 1 - (1 - n[i]) * prod`
        # where `prod` is the product of,
        # for `j` in `0 <= j < i`, `(1 - n[j]) / (1 + d[j])`
        min_prod = Fraction(1)
        for j in range(i):
            max_n_j = self.max_n(j)
            max_d_j = self.max_d(j)
            assert max_n_j <= 1 and max_d_j > -1, \
                "only one quadrant of interval division implemented"
            min_prod *= (1 - max_n_j) / (1 + max_d_j)
        max_n_i = self.max_n(i)
        assert max_n_i <= 1 and min_prod >= 0, \
            "only one quadrant of interval multiplication implemented"
        retval = 1 - (1 - max_n_i) * min_prod
        return self._shrink_max(retval)

    @cached_property
    def max_n_shift(self):
        """ maximum value of `state.n_shift`.
        """
        # numerator must be less than `denominator << self.io_width`, so
        # `n_shift` is at most `self.io_width`
        return self.io_width

    @cached_property
    def n_hat(self):
        """ maximum value of, for all `i`, `max_n(i)` and `max_d(i)`
        """
        n_hat = Fraction(0)
        for i in range(self.iter_count):
            n_hat = max(n_hat, self.max_n(i), self.max_d(i))
        return self._shrink_max(n_hat)

    def __make_ops(self):
        """ Goldschmidt division algorithm.

            based on:
            Even, G., Seidel, P. M., & Ferguson, W. E. (2003).
            A Parametric Error Analysis of Goldschmidt's Division Algorithm.
            https://citeseerx.ist.psu.edu/viewdoc/download?doi=10.1.1.90.1238&rep=rep1&type=pdf

            yields: GoldschmidtDivOp
                the operations needed to perform the division.

            raises `ParamsNotAccurateEnough` if any of the accuracy
            requirements checked along the way isn't met.
        """
        # establish assumptions of the paper's error analysis (section 3.1):

        # 1. normalize so A (numerator) and B (denominator) are in [1, 2)
        yield GoldschmidtDivOp.Normalize

        # 2. ensure all relative errors from directed rounding are <= 1 / 4.
        #    the assumption is met by multipliers with > 4-bits precision
        _assert_accuracy(self.expanded_width > 4)

        # 3. require `abs(e[0]) + 3 * d[0] / 2 + f[0] < 1 / 2`.
        _assert_accuracy(self.max_abs_e0 + 3 * self.max_d(0) / 2
                         + self.max_f(0) < Fraction(1, 2))

        # 4. the initial approximation F'[-1] of 1/B is in [1/2, 1].
        #    (B is the denominator)

        for addr in range(self.table_addr_count):
            f_prime_m1 = self.table[addr]
            _assert_accuracy(0.5 <= f_prime_m1 <= 1)

        yield GoldschmidtDivOp.FEqTableLookup

        # we use Setting I (section 4.1 of the paper):
        # Require `n[i] <= n_hat` and `d[i] <= n_hat` and `f[i] = 0`:
        # the conditions on n_hat are satisfied by construction.
        for i in range(self.iter_count):
            _assert_accuracy(self.max_f(i) == 0)
            yield GoldschmidtDivOp.MulNByF
            # the last iteration only needs the final `n`
            if i != self.iter_count - 1:
                yield GoldschmidtDivOp.MulDByF
                yield GoldschmidtDivOp.FEq2MinusD

        # relative approximation error `p(N_prime[i])`:
        # `p(N_prime[i]) = (A / B - N_prime[i]) / (A / B)`
        # `0 <= p(N_prime[i])`
        # `p(N_prime[i]) <= (2 * i) * n_hat \`
        # ` + (abs(e[0]) + 3 * n_hat / 2) ** (2 ** i)`
        i = self.iter_count - 1  # last used `i`
        # compute power manually to prevent huge intermediate values
        power = self._shrink_max(self.max_abs_e0 + 3 * self.n_hat / 2)
        for _ in range(i):
            power = self._shrink_max(power * power)

        max_rel_error = (2 * i) * self.n_hat + power

        min_a_over_b = Fraction(1, 2)
        min_abs_error_for_correctness = min_a_over_b / (1 << self.max_n_shift)
        min_rel_error_for_correctness = (min_abs_error_for_correctness
                                         / min_a_over_b)

        _assert_accuracy(
            max_rel_error < min_rel_error_for_correctness,
            f"not accurate enough: max_rel_error={max_rel_error}"
            f" min_rel_error_for_correctness={min_rel_error_for_correctness}")

        yield GoldschmidtDivOp.CalcResult

    @cache_on_self
    def default_cost_fn(self):
        """ calculate the estimated cost on an arbitrary scale of implementing
        goldschmidt division with the specified parameters. larger cost
        values mean worse parameters.

        This is the default cost function for `GoldschmidtDivParams.get`.

        returns: float
        """
        rom_cells = self.table_data_bits << self.table_addr_bits
        cost = float(rom_cells)
        # every multiply costs the same, so compute that cost just once
        mul_cost = self.expanded_width ** 2
        mul_cost *= self.expanded_width.bit_length()
        for op in self.ops:
            if op in (GoldschmidtDivOp.MulNByF, GoldschmidtDivOp.MulDByF):
                cost += mul_cost
        cost += 5e7 * self.iter_count
        return cost

    @staticmethod
    @lru_cache(maxsize=1 << 16)
    def __cached_new(base_params):
        """construct the `GoldschmidtDivParams` for `base_params`, caching
        failures as well as successes.

        returns: tuple[GoldschmidtDivParams | None,
                       ParamsNotAccurateEnough | None]
            `(params, None)` on success, `(None, error)` on failure.
        """
        assert isinstance(base_params, GoldschmidtDivParamsBase)
        kwargs = {field: getattr(base_params, field)
                  for field in fields(GoldschmidtDivParamsBase)}
        try:
            return GoldschmidtDivParams(**kwargs), None
        except ParamsNotAccurateEnough as e:
            return None, e

    @staticmethod
    def __raise(e):  # type: (ParamsNotAccurateEnough) -> Any
        raise e

    # the default has to be the underlying function: at this point in the
    # class body `__raise` is a `staticmethod` object, which isn't callable
    # before python 3.10
    @staticmethod
    def cached_new(base_params, handle_error=__raise.__func__):
        """return the cached `GoldschmidtDivParams` for `base_params`.

        base_params: GoldschmidtDivParamsBase
            the parameters to construct from.
        handle_error: Callable[[ParamsNotAccurateEnough], Any]
            called with the error if `base_params` isn't accurate enough,
            its return value is returned. the default re-raises the error.
        """
        assert isinstance(base_params, GoldschmidtDivParamsBase)
        params, error = GoldschmidtDivParams.__cached_new(base_params)
        if error is None:
            return params
        else:
            return handle_error(error)

    @staticmethod
    def get(io_width, cost_fn=default_cost_fn, max_table_addr_bits=12):
        """ find efficient parameters for a goldschmidt division algorithm
        with `params.io_width == io_width`.

        arguments:
        io_width: int
            bit-width of the input divisor and the result.
            the input numerator is `2 * io_width`-bits wide.
        cost_fn: Callable[[GoldschmidtDivParams], float]
            return the estimated cost on an arbitrary scale of implementing
            goldschmidt division with the specified parameters. larger cost
            values mean worse parameters.
        max_table_addr_bits: int
            maximum allowable value of `table_addr_bits`

        returns: GoldschmidtDivParams
            the parameters found by the greedy search below.

        raises `ValueError` if even the initial, most generous, parameters
        aren't accurate enough.
        """
        assert isinstance(io_width, int) and io_width >= 1
        assert callable(cost_fn)

        last_error = None

        def cached_new(base_params):
            def handle_error(e):
                nonlocal last_error
                last_error = e
                return None

            retval = GoldschmidtDivParams.cached_new(base_params, handle_error)
            # pass the params as logging arguments so their (large) text
            # form is only generated when debug logging is enabled
            if retval is None:
                logging.debug("GoldschmidtDivParams.get: err: %s",
                              base_params)
            else:
                logging.debug("GoldschmidtDivParams.get: ok: %s",
                              base_params)
            return retval

        @lru_cache(maxsize=None)
        def get_cost(base_params):
            params = cached_new(base_params)
            if params is None:
                # parameters that don't work are infinitely bad
                return math.inf
            retval = cost_fn(params)
            logging.debug("GoldschmidtDivParams.get: cost=%s: %s",
                          retval, params)
            return retval

        # start with parameters big enough to always work.
        initial_extra_precision = io_width * 2 + 4
        initial_params = GoldschmidtDivParamsBase(
            io_width=io_width,
            extra_precision=initial_extra_precision,
            table_addr_bits=min(max_table_addr_bits, io_width),
            table_data_bits=io_width + initial_extra_precision,
            iter_count=1 + io_width.bit_length())

        if cached_new(initial_params) is None:
            raise ValueError(f"initial goldschmidt division algorithm "
                             f"parameters are invalid: {initial_params}"
                             ) from last_error

        # find good initial `iter_count`
        params = initial_params
        for iter_count in range(1, initial_params.iter_count):
            trial_params = replace(params, iter_count=iter_count)
            if cached_new(trial_params) is not None:
                params = trial_params
                break

        # now find `table_addr_bits`
        cost = get_cost(params)
        for table_addr_bits in range(1, max_table_addr_bits):
            trial_params = replace(params, table_addr_bits=table_addr_bits)
            trial_cost = get_cost(trial_params)
            if trial_cost < cost:
                params = trial_params
                cost = trial_cost
                break

        # check one higher `iter_count` to see if it has lower cost
        for table_addr_bits in range(1, max_table_addr_bits + 1):
            trial_params = replace(params,
                                   table_addr_bits=table_addr_bits,
                                   iter_count=params.iter_count + 1)
            trial_cost = get_cost(trial_params)
            if trial_cost < cost:
                params = trial_params
                cost = trial_cost
                break

        # now shrink `table_data_bits`
        # (stops by itself: out-of-range values have infinite cost)
        while True:
            trial_params = replace(params,
                                   table_data_bits=params.table_data_bits - 1)
            trial_cost = get_cost(trial_params)
            if trial_cost < cost:
                params = trial_params
                cost = trial_cost
            else:
                break

        # and shrink `extra_precision`
        # (stops by itself: out-of-range values have infinite cost)
        while True:
            trial_params = replace(params,
                                   extra_precision=params.extra_precision - 1)
            trial_cost = get_cost(trial_params)
            if trial_cost < cost:
                params = trial_params
                cost = trial_cost
            else:
                break

        retval = cached_new(params)
        assert isinstance(retval, GoldschmidtDivParams)
        return retval
951
952
def clz(v, wid):
    """count leading zeros -- handy for debugging.

    arguments:
    v: int
        the value to count the leading zeros of. must satisfy
        `0 <= v < (1 << wid)`.
    wid: int
        the width in bits of the word that `v` is stored in.

    returns: int
        the number of zero bits above the most-significant set bit of `v`
        when `v` is viewed as a `wid`-bit word. returns `wid` for `v == 0`.
    """
    assert isinstance(wid, int)
    assert isinstance(v, int) and 0 <= v < (1 << wid)
    # `(1 << wid).bit_length()` is `wid + 1`, which over-counts every result
    # by one, so subtract from `wid` directly.
    return wid - v.bit_length()
958
959
@enum.unique
class GoldschmidtDivOp(enum.Enum):
    """one step of the Goldschmidt division algorithm.

    each member's value is a short pseudo-code description of the step.
    `run` executes the step on a software `GoldschmidtDivState`, `gen_hdl`
    emits the matching hardware into a `GoldschmidtDivHDLState`.
    """
    Normalize = "n, d, n_shift = normalize(n, d)"
    FEqTableLookup = "f = table_lookup(d)"
    MulNByF = "n *= f"
    MulDByF = "d *= f"
    FEq2MinusD = "f = 2 - d"
    CalcResult = "result = unnormalize_and_round(n)"

    def run(self, params, state):
        """run this operation on the software model, updating `state` in
        place.

        arguments:
        params: GoldschmidtDivParams
            the goldschmidt division parameters.
        state: GoldschmidtDivState
            the input/output state
        """
        assert isinstance(params, GoldschmidtDivParams)
        assert isinstance(state, GoldschmidtDivState)
        expanded_width = params.expanded_width
        table_addr_bits = params.table_addr_bits
        if self == GoldschmidtDivOp.Normalize:
            # normalize so 1 <= d < 2
            # can easily be done with count-leading-zeros and left shift
            # n is scaled by the same factor so n / d is unchanged
            while state.d < 1:
                state.n = (state.n * 2).to_frac_wid(expanded_width)
                state.d = (state.d * 2).to_frac_wid(expanded_width)

            state.n_shift = 0
            # normalize so 1 <= n < 2
            # every halving is recorded in n_shift so CalcResult can undo it
            while state.n >= 2:
                state.n = (state.n * 0.5).to_frac_wid(expanded_width,
                                                      round_dir=RoundDir.DOWN)
                state.n_shift += 1
        elif self == GoldschmidtDivOp.FEqTableLookup:
            # compute initial f by table lookup
            # the table is indexed by the top `table_addr_bits` fraction
            # bits of `d - 1`
            d_m_1 = state.d - 1
            d_m_1 = d_m_1.to_frac_wid(table_addr_bits, RoundDir.DOWN)
            assert 0 <= d_m_1.bits < (1 << params.table_addr_bits)
            state.f = params.table[d_m_1.bits]
            state.f = state.f.to_frac_wid(expanded_width,
                                          round_dir=RoundDir.DOWN)
        elif self == GoldschmidtDivOp.MulNByF:
            assert state.f is not None
            n = state.n * state.f
            # the numerator product is rounded down
            state.n = n.to_frac_wid(expanded_width, round_dir=RoundDir.DOWN)
        elif self == GoldschmidtDivOp.MulDByF:
            assert state.f is not None
            d = state.d * state.f
            # the denominator product is rounded up
            state.d = d.to_frac_wid(expanded_width, round_dir=RoundDir.UP)
        elif self == GoldschmidtDivOp.FEq2MinusD:
            state.f = (2 - state.d).to_frac_wid(expanded_width)
        elif self == GoldschmidtDivOp.CalcResult:
            assert state.n_shift is not None
            # scale to correct value
            n = state.n * (1 << state.n_shift)

            # take the floor as the quotient estimate, then apply a single
            # +1 correction if the resulting remainder is still >= orig_d
            state.quotient = math.floor(n)
            state.remainder = state.orig_n - state.quotient * state.orig_d
            if state.remainder >= state.orig_d:
                state.quotient += 1
                state.remainder -= state.orig_d
        else:
            assert False, f"unimplemented GoldschmidtDivOp: {self}"

    def gen_hdl(self, params, state, sync_rom):
        """generate the hdl for this operation.

        arguments:
        params: GoldschmidtDivParams
            the goldschmidt division parameters.
        state: GoldschmidtDivHDLState
            the input/output state
        sync_rom: bool
            true if the rom should be read synchronously rather than
            combinatorially, incurring an extra clock cycle of latency.
        """
        # NOTE(review): nmigen infers Signal names from the variable a new
        # Signal is assigned to, so the local names below presumably show up
        # in the generated HDL -- confirm before renaming any of them.
        assert isinstance(params, GoldschmidtDivParams)
        assert isinstance(state, GoldschmidtDivHDLState)
        m = state.m
        if self == GoldschmidtDivOp.Normalize:
            # normalize so 1 <= d < 2
            assert state.d.width == params.io_width
            assert state.n.width == 2 * params.io_width
            d_leading_zeros = CLZ(params.io_width)
            m.submodules.d_leading_zeros = d_leading_zeros
            m.d.comb += d_leading_zeros.sig_in.eq(state.d)
            d_shift_out = Signal.like(state.d)
            m.d.comb += d_shift_out.eq(state.d << d_leading_zeros.lz)
            # re-align the shifted value from an `io_width`-bit word to the
            # `n_d_f_total_wid`-bit fixed-point format
            d = Signal(params.n_d_f_total_wid)
            m.d.comb += d.eq((d_shift_out << (1 + params.expanded_width))
                             >> state.d.width)

            # normalize so 1 <= n < 2
            n_leading_zeros = CLZ(2 * params.io_width)
            m.submodules.n_leading_zeros = n_leading_zeros
            m.d.comb += n_leading_zeros.sig_in.eq(state.n)
            signed_zero = Const(0, signed(1))  # force subtraction to be signed
            n_shift_s_v = (params.io_width + signed_zero + d_leading_zeros.lz
                           - n_leading_zeros.lz)
            n_shift_s = Signal.like(n_shift_s_v)
            n_shift_n_lz_out = Signal.like(state.n)
            n_shift_d_lz_out = Signal.like(state.n << d_leading_zeros.lz)
            m.d.comb += [
                n_shift_s.eq(n_shift_s_v),
                n_shift_d_lz_out.eq(state.n << d_leading_zeros.lz),
                n_shift_n_lz_out.eq(state.n << n_leading_zeros.lz),
            ]
            state.n_shift = Signal(d_leading_zeros.lz.width)
            n = Signal(params.n_d_f_total_wid)
            # a negative signed shift is clamped to zero; in that case n is
            # shifted by d's leading-zero count instead of its own
            with m.If(n_shift_s < 0):
                m.d.comb += [
                    state.n_shift.eq(0),
                    n.eq((n_shift_d_lz_out << (1 + params.expanded_width))
                         >> state.d.width),
                ]
            with m.Else():
                m.d.comb += [
                    state.n_shift.eq(n_shift_s),
                    n.eq((n_shift_n_lz_out << (1 + params.expanded_width))
                         >> state.n.width),
                ]
            state.n = n
            state.d = d
        elif self == GoldschmidtDivOp.FEqTableLookup:
            assert state.d.width == params.n_d_f_total_wid, "invalid d width"
            # compute initial f by table lookup

            # extra bit for table entries == 1.0
            table_width = 1 + params.table_data_bits
            table = Memory(width=table_width, depth=len(params.table),
                           init=[i.bits for i in params.table])
            # drop the top `n_d_f_int_wid` bits of d, then take the top
            # `table_addr_bits` of what remains as the address
            addr = state.d[:-params.n_d_f_int_wid][-params.table_addr_bits:]
            if sync_rom:
                table_read = table.read_port()
                m.d.comb += table_read.addr.eq(addr)
                # presumably keeps the rest of the state aligned with the
                # one-cycle latency of the synchronous read -- TODO confirm
                state.insert_pipeline_register()
            else:
                table_read = table.read_port(domain="comb")
                m.d.comb += table_read.addr.eq(addr)
            m.submodules.table_read = table_read
            state.f = Signal(params.n_d_f_int_wid + params.expanded_width)
            # widen the table data from `table_data_bits` to `expanded_width`
            # fraction bits
            data_shift = params.expanded_width - params.table_data_bits
            m.d.comb += state.f.eq(table_read.data << data_shift)
        elif self == GoldschmidtDivOp.MulNByF:
            assert state.n.width == params.n_d_f_total_wid, "invalid n width"
            assert state.f is not None
            assert state.f.width == params.n_d_f_total_wid, "invalid f width"
            n = Signal.like(state.n)
            # dropping the low `expanded_width` bits rounds the product down
            m.d.comb += n.eq((state.n * state.f) >> params.expanded_width)
            state.n = n
        elif self == GoldschmidtDivOp.MulDByF:
            assert state.d.width == params.n_d_f_total_wid, "invalid d width"
            assert state.f is not None
            assert state.f.width == params.n_d_f_total_wid, "invalid f width"
            d = Signal.like(state.d)
            d_times_f = Signal.like(state.d * state.f)
            m.d.comb += [
                d_times_f.eq(state.d * state.f),
                # round the multiplication up
                d.eq((d_times_f >> params.expanded_width)
                     + (d_times_f[:params.expanded_width] != 0)),
            ]
            state.d = d
        elif self == GoldschmidtDivOp.FEq2MinusD:
            assert state.d.width == params.n_d_f_total_wid, "invalid d width"
            f = Signal.like(state.d)
            # `2 << expanded_width` is the constant 2 in the fixed-point
            # format with `expanded_width` fraction bits
            m.d.comb += f.eq((2 << params.expanded_width) - state.d)
            state.f = f
        elif self == GoldschmidtDivOp.CalcResult:
            assert state.n.width == params.n_d_f_total_wid, "invalid n width"
            assert state.n_shift is not None
            # scale to correct value
            # NOTE(review): `n` is not read again in this branch; `q_approx`
            # below recomputes the shift itself.
            n = state.n * (1 << state.n_shift)
            q_approx = Signal(params.io_width)
            # extra bit for if it's bigger than orig_d
            r_approx = Signal(params.io_width + 1)
            adjusted_r = Signal(signed(1 + params.io_width))
            m.d.comb += [
                q_approx.eq((state.n << state.n_shift)
                            >> params.expanded_width),
                r_approx.eq(state.orig_n - q_approx * state.orig_d),
                adjusted_r.eq(r_approx - state.orig_d),
            ]
            state.quotient = Signal(params.io_width)
            state.remainder = Signal(params.io_width)

            # if the approximate remainder is still >= orig_d, the quotient
            # estimate was one too small
            with m.If(adjusted_r >= 0):
                m.d.comb += [
                    state.quotient.eq(q_approx + 1),
                    state.remainder.eq(adjusted_r),
                ]
            with m.Else():
                m.d.comb += [
                    state.quotient.eq(q_approx),
                    state.remainder.eq(r_approx),
                ]
        else:
            assert False, f"unimplemented GoldschmidtDivOp: {self}"
1152
1153
@plain_data(repr=False)
class GoldschmidtDivState:
    """mutable working state of the software Goldschmidt division model.

    the optional fields start out as `None` and are filled in as the
    operations run.
    """
    __slots__ = ("orig_n", "orig_d", "n", "d",
                 "f", "quotient", "remainder", "n_shift")

    def __init__(self, orig_n, orig_d, n, d,
                 f=None, quotient=None, remainder=None, n_shift=None):
        # validate in field order: the two original operands, the two
        # fixed-point working values, then the optional fields.
        for required_int in (orig_n, orig_d):
            assert isinstance(required_int, int)
        for required_fixed in (n, d):
            assert isinstance(required_fixed, FixedPoint)
        assert f is None or isinstance(f, FixedPoint)
        for optional_int in (quotient, remainder, n_shift):
            assert optional_int is None or isinstance(optional_int, int)

        self.orig_n = orig_n
        """original numerator"""

        self.orig_d = orig_d
        """original denominator"""

        self.n = n
        """numerator -- N_prime[i] in the paper's algorithm 2"""

        self.d = d
        """denominator -- D_prime[i] in the paper's algorithm 2"""

        self.f = f
        """current factor -- F_prime[i] in the paper's algorithm 2"""

        self.quotient = quotient
        """final quotient"""

        self.remainder = remainder
        """final remainder"""

        self.n_shift = n_shift
        """amount the numerator needs to be left-shifted at the end of the
        algorithm.
        """

    def __repr__(self):
        """show only the fields that are set; integer fields other than
        `n_shift` are printed in hexadecimal."""
        shown = []
        for name in fields(GoldschmidtDivState):
            current = getattr(self, name)
            if current is None:
                continue
            as_hex = isinstance(current, int) and name != "n_shift"
            text = hex(current) if as_hex else repr(current)
            shown.append(f"{name}={text}")
        return f"GoldschmidtDivState({', '.join(shown)})"
1206
1207
def goldschmidt_div(n, d, params, trace=lambda state: None):
    """ Goldschmidt division algorithm (software model).

        based on:
        Even, G., Seidel, P. M., & Ferguson, W. E. (2003).
        A Parametric Error Analysis of Goldschmidt's Division Algorithm.
        https://citeseerx.ist.psu.edu/viewdoc/download?doi=10.1.1.90.1238&rep=rep1&type=pdf

        arguments:
        n: int
            numerator. a `2 * params.io_width`-bit unsigned integer.
            must be less than `d << params.io_width`, otherwise the quotient
            wouldn't fit in `params.io_width` bits.
        d: int
            denominator. a `params.io_width`-bit unsigned integer. must not
            be zero.
        params: GoldschmidtDivParams
            the goldschmidt division algorithm parameters. supplies the
            input/output width (`params.io_width`) and the list of
            operations to run (`params.ops`).
        trace: Function[[GoldschmidtDivState], None]
            called with the initial state and the state after executing each
            operation in `params.ops`.

        returns: tuple[int, int]
            the quotient and remainder. a tuple of two
            `params.io_width`-bit unsigned integers.
    """
    assert isinstance(params, GoldschmidtDivParams)
    io_width = params.io_width
    assert isinstance(d, int) and 0 < d < (1 << io_width)
    assert isinstance(n, int) and 0 <= n < (d << io_width)

    # the inputs enter the algorithm as fixed-point values with `io_width`
    # fractional bits; the original integers are kept alongside them.
    state = GoldschmidtDivState(
        orig_n=n,
        orig_d=d,
        n=FixedPoint(n, io_width),
        d=FixedPoint(d, io_width),
    )

    # report the initial state, then the state after every operation
    trace(state)
    for step in params.ops:
        step.run(params, state)
        trace(state)

    quotient, remainder = state.quotient, state.remainder
    assert quotient is not None
    assert remainder is not None
    return quotient, remainder
1256
1257
@plain_data(eq=False)
class GoldschmidtDivHDLState:
    """HDL counterpart of `GoldschmidtDivState`: every field other than `m`
    holds a `Signal` (or `None` while not yet produced).

    once construction has finished, assigning a new `Signal` to a field
    renames that signal and records the signal it replaces in `old_signals`.
    """
    # NOTE(review): `old_signals` (assigned in `__init__`) and the
    # instance-level write to `__signal_name_prefix` in
    # `insert_pipeline_register` are not listed in `__slots__` -- confirm
    # that `plain_data` leaves instances with a `__dict__`.
    __slots__ = ("m", "orig_n", "orig_d", "n", "d",
                 "f", "quotient", "remainder", "n_shift")

    # prefix used when renaming signals in `__setattr__`
    __signal_name_prefix = "state_"

    def __init__(self, m, orig_n, orig_d, n, d,
                 f=None, quotient=None, remainder=None, n_shift=None):
        assert isinstance(m, Module)
        assert isinstance(orig_n, Signal)
        assert isinstance(orig_d, Signal)
        assert isinstance(n, Signal)
        assert isinstance(d, Signal)
        assert f is None or isinstance(f, Signal)
        assert quotient is None or isinstance(quotient, Signal)
        assert remainder is None or isinstance(remainder, Signal)
        assert n_shift is None or isinstance(n_shift, Signal)

        self.m = m
        """The HDL Module"""

        self.orig_n = orig_n
        """original numerator"""

        self.orig_d = orig_d
        """original denominator"""

        self.n = n
        """numerator -- N_prime[i] in the paper's algorithm 2"""

        self.d = d
        """denominator -- D_prime[i] in the paper's algorithm 2"""

        self.f = f
        """current factor -- F_prime[i] in the paper's algorithm 2"""

        self.quotient = quotient
        """final quotient"""

        self.remainder = remainder
        """final remainder"""

        self.n_shift = n_shift
        """amount the numerator needs to be left-shifted at the end of the
        algorithm.
        """

        # old_signals must be set last
        # (until it exists, `__setattr__` stores values without renaming or
        # recording them)
        self.old_signals = defaultdict(list)

    def __setattr__(self, name, value):
        """store `value` in field `name`; after construction, also rename
        the new signal and remember the signal it replaces."""
        assert isinstance(name, str)
        # names starting with an underscore bypass the tracking entirely
        if name.startswith("_"):
            return super().__setattr__(name, value)
        try:
            old_signals = self.old_signals[name]
        except AttributeError:
            # haven't yet finished __init__
            return super().__setattr__(name, value)
        assert name != "m" and name != "old_signals", f"can't write to {name}"
        assert isinstance(value, Signal)
        # the numeric suffix is the number of signals already recorded as
        # replaced for this field
        value.name = f"{self.__signal_name_prefix}{name}_{len(old_signals)}"
        old_signal = getattr(self, name, None)
        if old_signal is not None:
            assert isinstance(old_signal, Signal)
            old_signals.append(old_signal)
        return super().__setattr__(name, value)

    def insert_pipeline_register(self):
        """replace every currently-set signal field (all fields except `m`)
        with a new same-shaped signal that is assigned from the old one in
        the `sync` domain."""
        # NOTE(review): the prefix is saved and restored, but nothing inside
        # the `try` changes it -- confirm whether the try/finally is still
        # needed.
        old_prefix = self.__signal_name_prefix
        try:
            # presumably `fields` yields the field names as strings, since
            # `startswith` is called on them -- verify against plain_data
            for field in fields(GoldschmidtDivHDLState):
                if field.startswith("_") or field == "m":
                    continue
                old_sig = getattr(self, field, None)
                if old_sig is None:
                    continue
                assert isinstance(old_sig, Signal)
                new_sig = Signal.like(old_sig)
                # goes through __setattr__, which renames `new_sig` and
                # records `old_sig`
                setattr(self, field, new_sig)
                self.m.d.sync += new_sig.eq(old_sig)
        finally:
            self.__signal_name_prefix = old_prefix
1342
1343
class GoldschmidtDivHDL(Elaboratable):
    """ Hardware implementation of the Goldschmidt division algorithm.

        based on:
        Even, G., Seidel, P. M., & Ferguson, W. E. (2003).
        A Parametric Error Analysis of Goldschmidt's Division Algorithm.
        https://citeseerx.ist.psu.edu/viewdoc/download?doi=10.1.1.90.1238&rep=rep1&type=pdf

        attributes:
        params: GoldschmidtDivParams
            the goldschmidt division algorithm parameters.
        pipe_reg_indexes: list[int]
            the operation indexes where pipeline registers should be inserted.
            duplicate values mean multiple registers should be inserted for
            that operation index -- this is useful to allow yosys to spread a
            multiplication across those multiple pipeline stages.
        sync_rom: bool
            true if the rom should be read synchronously rather than
            combinatorially, incurring an extra clock cycle of latency.
        n: Signal(unsigned(2 * params.io_width))
            input numerator. a `2 * params.io_width`-bit unsigned integer.
            must be less than `d << params.io_width`, otherwise the quotient
            wouldn't fit in `params.io_width` bits.
        d: Signal(unsigned(params.io_width))
            input denominator. a `params.io_width`-bit unsigned integer.
            must not be zero.
        q: Signal(unsigned(params.io_width))
            output quotient. only valid when `n < (d << params.io_width)`.
        r: Signal(unsigned(params.io_width))
            output remainder. only valid when `n < (d << params.io_width)`.
        trace: list[GoldschmidtDivHDLState]
            list of the initial state and the state after executing each
            operation in `params.ops`.
    """

    def __init__(self, params, pipe_reg_indexes=(), sync_rom=False):
        assert isinstance(params, GoldschmidtDivParams)
        assert isinstance(sync_rom, bool)
        self.params = params
        self.pipe_reg_indexes = sorted(int(i) for i in pipe_reg_indexes)
        self.sync_rom = sync_rom
        self.n = Signal(unsigned(2 * params.io_width))
        self.d = Signal(unsigned(params.io_width))
        self.q = Signal(unsigned(params.io_width))
        self.r = Signal(unsigned(params.io_width))

        # everything is built here rather than in `elaborate` so that
        # `trace` is available straight after construction
        state = GoldschmidtDivHDLState(
            m=Module(),
            orig_n=self.n,
            orig_d=self.d,
            n=self.n,
            d=self.d)

        self.trace = [replace(state)]

        # walk the ascending register indexes with a cursor: before each
        # operation, insert one register per index <= that operation's index
        reg_indexes = self.pipe_reg_indexes
        cursor = 0
        for op_index, op in enumerate(params.ops):
            while cursor < len(reg_indexes) \
                    and reg_indexes[cursor] <= op_index:
                cursor += 1
                state.insert_pipeline_register()
            op.gen_hdl(params, state, sync_rom)
            self.trace.append(replace(state))

        # any indexes beyond the last operation become trailing registers
        for _ in reg_indexes[cursor:]:
            state.insert_pipeline_register()

        state.m.d.comb += [
            self.q.eq(state.quotient),
            self.r.eq(state.remainder),
        ]

    @property
    def total_pipeline_registers(self):
        """the total number of pipeline registers"""
        return len(self.pipe_reg_indexes) + self.sync_rom

    def elaborate(self, platform):
        # the module was fully populated by `__init__`
        return self.trace[0].m
1427
1428
# number of integer bits in the sqrt/rsqrt look-up table address; the other
# `table_addr_bits - GOLDSCHMIDT_SQRT_RSQRT_TABLE_ADDR_INT_WID` address bits
# are fraction bits.
GOLDSCHMIDT_SQRT_RSQRT_TABLE_ADDR_INT_WID = 2
1430
1431
@lru_cache()
def goldschmidt_sqrt_rsqrt_table(table_addr_bits, table_data_bits):
    """Generate the look-up table needed for Goldschmidt's square-root and
    reciprocal-square-root algorithm.

    arguments:
    table_addr_bits: int
        the number of address bits for the look-up table. must be at least
        `GOLDSCHMIDT_SQRT_RSQRT_TABLE_ADDR_INT_WID`.
    table_data_bits: int
        the number of data bits for the look-up table. must be at least 1.

    returns: tuple[FixedPoint | None, ...]
        `1 << table_addr_bits` entries. entry 0 is zero, entries whose
        address corresponds to an input below 1 are `None` (unused), and
        every other entry is the reciprocal-square-root of the upper end of
        that address's input range, rounded down to `table_data_bits`
        fraction bits.
    """
    assert isinstance(table_addr_bits, int) and \
        table_addr_bits >= GOLDSCHMIDT_SQRT_RSQRT_TABLE_ADDR_INT_WID
    assert isinstance(table_data_bits, int) and table_data_bits >= 1
    # the address is a fixed-point number with this many fraction bits.
    # computed once and actually used, instead of hard-coding `2` below.
    table_addr_frac_wid = table_addr_bits
    table_addr_frac_wid -= GOLDSCHMIDT_SQRT_RSQRT_TABLE_ADDR_INT_WID
    table = []
    table_len = 1 << table_addr_bits
    for addr in range(table_len):
        if addr == 0:
            value = FixedPoint(0, table_data_bits)
        elif (addr << GOLDSCHMIDT_SQRT_RSQRT_TABLE_ADDR_INT_WID) < table_len:
            # address corresponds to an input value below 1
            value = None  # table entries should be unused
        else:
            # `addr + 1` is the upper end of the inputs mapping to `addr`
            max_input_value = FixedPoint(addr + 1, table_addr_frac_wid)
            max_frac_wid = max(max_input_value.frac_wid, table_data_bits)
            value = max_input_value.to_frac_wid(max_frac_wid)
            value = value.rsqrt(RoundDir.DOWN)
            value = value.to_frac_wid(table_data_bits, RoundDir.DOWN)
        table.append(value)

    # tuple for immutability
    return tuple(table)
1465
1466 # FIXME: add code to calculate error bounds and check that the algorithm will
1467 # actually work (like in the goldschmidt division algorithm).
1468 # FIXME: add code to calculate a good set of parameters based on the error
1469 # bounds checking.
1470
1471
def goldschmidt_sqrt_rsqrt(radicand, io_width, frac_wid, extra_precision,
                           table_addr_bits, table_data_bits, iter_count):
    """Goldschmidt's square-root and reciprocal-square-root algorithm.

    uses algorithm based on second method at:
    https://en.wikipedia.org/wiki/Methods_of_computing_square_roots#Goldschmidt%E2%80%99s_algorithm

    arguments:
    radicand: FixedPoint(frac_wid=frac_wid)
        the input value to take the square-root and reciprocal-square-root of.
    io_width: int
        the number of bits in the input (`radicand`) and output values.
    frac_wid: int
        the number of fraction bits in the input (`radicand`) and output
        values.
    extra_precision: int
        the number of bits of internal extra precision.
    table_addr_bits: int
        the number of address bits for the look-up table.
    table_data_bits: int
        the number of data bits for the look-up table.
    iter_count: int
        the number of refinement iterations run after the initial table
        lookup. must be non-negative.

    returns: tuple[FixedPoint, FixedPoint]
        the square-root and reciprocal-square-root, rounded down to the
        nearest representable value. If `radicand == 0`, then the
        reciprocal-square-root value returned is zero.
    """
    assert (isinstance(radicand, FixedPoint)
            and radicand.frac_wid == frac_wid
            and 0 <= radicand.bits < (1 << io_width))
    assert isinstance(io_width, int) and io_width >= 1
    assert isinstance(frac_wid, int) and 0 <= frac_wid < io_width
    assert isinstance(extra_precision, int) and extra_precision >= io_width
    assert isinstance(table_addr_bits, int) and table_addr_bits >= 1
    assert isinstance(table_data_bits, int) and table_data_bits >= 1
    assert isinstance(iter_count, int) and iter_count >= 0

    wide_frac_wid = frac_wid + extra_precision

    def round_down(value):
        # convert to the widened internal format, rounding down
        return value.to_frac_wid(wide_frac_wid, RoundDir.DOWN)

    # scale the radicand by powers of 4 into 1 <= s < 4 (zero is left
    # alone), adjusting the final right-shifts of the two results to match
    s = radicand.to_frac_wid(wide_frac_wid)
    sqrt_rshift = rsqrt_rshift = extra_precision
    while s != 0 and s < 1:
        s = (s * 4).to_frac_wid(wide_frac_wid)
        sqrt_rshift += 1
        rsqrt_rshift -= 1
    while s >= 4:
        s = s.div(4, wide_frac_wid)
        sqrt_rshift -= 1
        rsqrt_rshift += 1

    table = goldschmidt_sqrt_rsqrt_table(table_addr_bits=table_addr_bits,
                                         table_data_bits=table_data_bits)

    # core goldschmidt sqrt/rsqrt algorithm:
    # initial setup: look up the first factor from the table
    addr_frac_wid = (table_addr_bits
                     - GOLDSCHMIDT_SQRT_RSQRT_TABLE_ADDR_INT_WID)
    table_index = s.to_frac_wid(addr_frac_wid, RoundDir.DOWN)
    assert 0 <= table_index.bits < (1 << table_addr_bits), \
        "table addr out of range"
    f = table[table_index.bits]
    assert f is not None, "accessed invalid table entry"
    # use with_frac_wid to fix IDE type deduction
    f = FixedPoint.with_frac_wid(f, wide_frac_wid, RoundDir.DOWN)
    x = round_down(s * f)
    h = round_down(f * 0.5)

    # iteration steps:
    for _ in range(iter_count):
        f = round_down(1.5 - x * h)
        x = round_down(x * f)
        h = round_down(h * f)
    r = 2 * h
    # now `x` is approximately `sqrt(s)` and `r` is approximately `rsqrt(s)`

    # undo the scaling and drop the extra precision
    sqrt = FixedPoint(x.bits >> sqrt_rshift, frac_wid)
    rsqrt = FixedPoint(r.bits >> rsqrt_rshift, frac_wid)

    # each result may be one unit in the last place too small; step up by
    # one when the larger candidate is still not too big
    sqrt_candidate = FixedPoint(sqrt.bits + 1, frac_wid)
    if sqrt_candidate * sqrt_candidate <= radicand:
        sqrt = sqrt_candidate

    rsqrt_candidate = FixedPoint(rsqrt.bits + 1, frac_wid)
    if rsqrt_candidate * rsqrt_candidate * radicand <= 1 and radicand != 0:
        rsqrt = rsqrt_candidate
    return sqrt, rsqrt