fix so HDL works for 5, 8, 16, 32, and 64-bits.
[soc.git] / src / soc / fu / div / experiment / goldschmidt_div_sqrt.py
1 # SPDX-License-Identifier: LGPL-3-or-later
2 # Copyright 2022 Jacob Lifshay programmerjake@gmail.com
3
4 # Funded by NLnet Assure Programme 2021-02-052, https://nlnet.nl/assure part
5 # of Horizon 2020 EU Programme 957073.
6
7 from collections import defaultdict
8 from dataclasses import dataclass, field, fields, replace
9 import logging
10 import math
11 import enum
12 from fractions import Fraction
13 from types import FunctionType
14 from functools import lru_cache
15 from nmigen.hdl.ast import Signal, unsigned, signed, Const, Cat
16 from nmigen.hdl.dsl import Module, Elaboratable
17 from nmigen.hdl.mem import Memory
18 from nmutil.clz import CLZ
19
20 try:
21 from functools import cached_property
22 except ImportError:
23 from cached_property import cached_property
24
25 # fix broken IDE type detection for cached_property
26 from typing import TYPE_CHECKING, Any
27 if TYPE_CHECKING:
28 from functools import cached_property
29
30
# unique sentinel meaning "no entry", so that `None` can be a cached value
_NOT_FOUND = object()


def cache_on_self(func):
    """memoizing decorator for methods, in the spirit of
    `functools.cached_property`. the memo table lives in the instance's own
    `__dict__` (under the key `<method name>__cache`), so -- unlike
    `lru_cache` -- every instance gets its own cache rather than all
    instances sharing one global per-method cache.

    all positional and keyword arguments must be hashable.
    """

    assert isinstance(func, FunctionType), \
        "non-plain methods are not supported"

    slot = func.__name__ + "__cache"

    def wrapper(self, *args, **kwargs):
        # go through `__dict__` directly so this also works on
        # `frozen=True` dataclasses, where plain attribute assignment raises
        instance_dict = self.__dict__
        memo = instance_dict.get(slot, _NOT_FOUND)
        if memo is _NOT_FOUND:
            memo = {}
            instance_dict[slot] = memo
        lookup = (args, *kwargs.items())
        hit = memo.get(lookup, _NOT_FOUND)
        if hit is not _NOT_FOUND:
            return hit
        # cache miss: compute, remember, and hand back the result
        result = func(self, *args, **kwargs)
        memo[lookup] = result
        return result

    wrapper.__doc__ = func.__doc__
    return wrapper
58
59
@enum.unique
class RoundDir(enum.Enum):
    """rounding direction used when a value has to be fitted into a
    `FixedPoint` with a given number of fractional bits.
    """
    # keep the floor of the scaled value (round towards negative infinity)
    DOWN = enum.auto()
    # add one unit in the last place whenever anything was discarded
    UP = enum.auto()
    # round to the nearest representable value, exact ties go up
    NEAREST_TIES_UP = enum.auto()
    # don't round at all: raise `ValueError` if anything would be discarded
    ERROR_IF_INEXACT = enum.auto()
66
67
@dataclass(frozen=True)
class FixedPoint:
    """an exact binary fixed-point number with the value
    `bits / (1 << frac_wid)`.

    `bits` is an arbitrary-size (possibly negative) Python `int` and
    `frac_wid` is the non-negative number of fractional bits.

    comparisons and hashing are by numeric value, not by representation,
    so `FixedPoint(1, 0) == FixedPoint(2, 1)` and both hash the same.
    """
    bits: int
    frac_wid: int

    def __post_init__(self):
        # called by the autogenerated __init__
        assert isinstance(self.bits, int)
        assert isinstance(self.frac_wid, int) and self.frac_wid >= 0

    @staticmethod
    def cast(value):
        """convert `value` to a fixed-point number with enough fractional
        bits to preserve its value.

        accepts `FixedPoint` (returned unchanged), `int`, `float`, and
        `str`. strings may have a leading sign; strings starting with
        `0x`/`0X` and containing a `.` are parsed as hexadecimal with a
        fractional part (4 fractional bits per digit after the `.`, `_`
        is ignored), all other strings are parsed by `int(value, base=0)`.

        raises `ValueError` for malformed strings and `TypeError` for
        unsupported types.
        """
        if isinstance(value, FixedPoint):
            return value
        if isinstance(value, int):
            return FixedPoint(value, 0)
        if isinstance(value, str):
            value = value.strip()
            neg = value.startswith("-")
            if neg or value.startswith("+"):
                value = value[1:]
            if value.startswith(("0x", "0X")) and "." in value:
                value = value[2:]
                got_dot = False
                bits = 0
                frac_wid = 0
                for digit in value:
                    if digit == "_":
                        continue
                    if got_dot:
                        if digit == ".":
                            raise ValueError("too many `.` in string")
                        # every hex digit after the `.` adds 4 fraction bits
                        frac_wid += 4
                    if digit == ".":
                        got_dot = True
                        continue
                    if not digit.isalnum():
                        raise ValueError("invalid hexadecimal digit")
                    bits <<= 4
                    bits |= int("0x" + digit, base=16)
            else:
                bits = int(value, base=0)
                frac_wid = 0
            if neg:
                bits = -bits
            return FixedPoint(bits, frac_wid)

        if isinstance(value, float):
            n, d = value.as_integer_ratio()
            log2_d = d.bit_length() - 1
            assert d == 1 << log2_d, ("d isn't a power of 2 -- won't ever "
                                      "fail with float being IEEE 754")
            return FixedPoint(n, log2_d)
        raise TypeError("can't convert type to FixedPoint")

    @staticmethod
    def with_frac_wid(value, frac_wid, round_dir=RoundDir.ERROR_IF_INEXACT):
        """convert `value` to the nearest fixed-point number with `frac_wid`
        fractional bits, rounding according to `round_dir`.

        `value` may be a `Fraction` or anything accepted by
        `FixedPoint.cast`. raises `ValueError` when `round_dir` is
        `RoundDir.ERROR_IF_INEXACT` and the value can't be represented
        exactly.
        """
        assert isinstance(frac_wid, int) and frac_wid >= 0
        assert isinstance(round_dir, RoundDir)
        if isinstance(value, Fraction):
            numerator = value.numerator
            denominator = value.denominator
        else:
            value = FixedPoint.cast(value)
            numerator = value.bits
            denominator = 1 << value.frac_wid
        if denominator < 0:
            numerator = -numerator
            denominator = -denominator
        # divmod floors, so `bits` is the round-down result and
        # `0 <= remainder < denominator`
        bits, remainder = divmod(numerator << frac_wid, denominator)
        if round_dir == RoundDir.DOWN:
            pass
        elif round_dir == RoundDir.UP:
            if remainder != 0:
                bits += 1
        elif round_dir == RoundDir.NEAREST_TIES_UP:
            if remainder * 2 >= denominator:
                bits += 1
        elif round_dir == RoundDir.ERROR_IF_INEXACT:
            if remainder != 0:
                raise ValueError("inexact conversion")
        else:
            assert False, "unimplemented round_dir"
        return FixedPoint(bits, frac_wid)

    def to_frac_wid(self, frac_wid, round_dir=RoundDir.ERROR_IF_INEXACT):
        """convert to the nearest fixed-point number with `frac_wid`
        fractional bits, rounding according to `round_dir`."""
        return FixedPoint.with_frac_wid(self, frac_wid, round_dir)

    def __float__(self):
        # use truediv to get correct result even when bits
        # and frac_wid are huge
        return float(self.bits / (1 << self.frac_wid))

    def as_fraction(self):
        """return the exact value of `self` as a `Fraction`."""
        return Fraction(self.bits, 1 << self.frac_wid)

    def cmp(self, rhs):
        """compare self with rhs, returning a positive integer if self is
        greater than rhs, zero if self is equal to rhs, and a negative integer
        if self is less than rhs."""
        rhs = FixedPoint.cast(rhs)
        # widening to the larger fraction width is always exact
        common_frac_wid = max(self.frac_wid, rhs.frac_wid)
        lhs = self.to_frac_wid(common_frac_wid)
        rhs = rhs.to_frac_wid(common_frac_wid)
        return lhs.bits - rhs.bits

    def __eq__(self, rhs):
        return self.cmp(rhs) == 0

    def __hash__(self):
        # `__eq__` compares by numeric value, so the hash must too.
        # without this, `@dataclass(frozen=True)` would generate a hash of
        # `(bits, frac_wid)`, giving equal values with different `frac_wid`
        # different hashes. hashing the reduced fraction also agrees with
        # the hashes of equal `int`, `float`, and `Fraction` values.
        return hash(self.as_fraction())

    def __ne__(self, rhs):
        return self.cmp(rhs) != 0

    def __gt__(self, rhs):
        return self.cmp(rhs) > 0

    def __lt__(self, rhs):
        return self.cmp(rhs) < 0

    def __ge__(self, rhs):
        return self.cmp(rhs) >= 0

    def __le__(self, rhs):
        return self.cmp(rhs) <= 0

    def fract(self):
        """return the fractional part of `self`.
        that is `self - math.floor(self)`.
        """
        fract_mask = (1 << self.frac_wid) - 1
        return FixedPoint(self.bits & fract_mask, self.frac_wid)

    def __str__(self):
        """format as a hexadecimal number with a fractional part,
        e.g. `0x1.8` or `-0x1.8`."""
        if self < 0:
            return "-" + str(-self)
        digit_bits = 4
        # round the fraction width up to a whole number of hex digits
        frac_digit_count = (self.frac_wid + digit_bits - 1) // digit_bits
        fract = self.fract().to_frac_wid(frac_digit_count * digit_bits)
        frac_str = hex(fract.bits)[2:].zfill(frac_digit_count)
        return hex(math.floor(self)) + "." + frac_str

    def __repr__(self):
        return f"FixedPoint.with_frac_wid({str(self)!r}, {self.frac_wid})"

    def __add__(self, rhs):
        rhs = FixedPoint.cast(rhs)
        common_frac_wid = max(self.frac_wid, rhs.frac_wid)
        lhs = self.to_frac_wid(common_frac_wid)
        rhs = rhs.to_frac_wid(common_frac_wid)
        return FixedPoint(lhs.bits + rhs.bits, common_frac_wid)

    def __radd__(self, lhs):
        # symmetric
        return self.__add__(lhs)

    def __neg__(self):
        return FixedPoint(-self.bits, self.frac_wid)

    def __sub__(self, rhs):
        rhs = FixedPoint.cast(rhs)
        common_frac_wid = max(self.frac_wid, rhs.frac_wid)
        lhs = self.to_frac_wid(common_frac_wid)
        rhs = rhs.to_frac_wid(common_frac_wid)
        return FixedPoint(lhs.bits - rhs.bits, common_frac_wid)

    def __rsub__(self, lhs):
        # a - b == -(b - a)
        return -self.__sub__(lhs)

    def __mul__(self, rhs):
        # exact: the fraction widths add
        rhs = FixedPoint.cast(rhs)
        return FixedPoint(self.bits * rhs.bits, self.frac_wid + rhs.frac_wid)

    def __rmul__(self, lhs):
        # symmetric
        return self.__mul__(lhs)

    def __floor__(self):
        # arithmetic right shift floors, also for negative `bits`
        return self.bits >> self.frac_wid

    def div(self, rhs, frac_wid, round_dir=RoundDir.ERROR_IF_INEXACT):
        """compute `self / rhs` as a `FixedPoint` with `frac_wid` fractional
        bits, rounding according to `round_dir`."""
        assert isinstance(frac_wid, int) and frac_wid >= 0
        assert isinstance(round_dir, RoundDir)
        rhs = FixedPoint.cast(rhs)
        return FixedPoint.with_frac_wid(self.as_fraction()
                                        / rhs.as_fraction(),
                                        frac_wid, round_dir)

    def sqrt(self, round_dir=RoundDir.ERROR_IF_INEXACT):
        """compute the square root of `self`, with the same `frac_wid` as
        `self`, rounding according to `round_dir`.

        raises `ValueError` if `self` is negative, or if the result is
        inexact and `round_dir` is `RoundDir.ERROR_IF_INEXACT`.
        """
        assert isinstance(round_dir, RoundDir)
        if self < 0:
            raise ValueError("can't compute sqrt of negative number")
        if self == 0:
            return self
        retval = FixedPoint(0, self.frac_wid)
        int_part_wid = self.bits.bit_length() - self.frac_wid
        first_bit_index = -(-int_part_wid // 2)  # division rounds up
        last_bit_index = -self.frac_wid
        # build the rounded-down result one bit at a time, most-significant
        # bit first, keeping each bit whose inclusion doesn't overshoot
        for bit_index in range(first_bit_index, last_bit_index - 1, -1):
            trial = retval + FixedPoint(1 << (bit_index + self.frac_wid),
                                        self.frac_wid)
            if trial * trial <= self:
                retval = trial
        if round_dir == RoundDir.DOWN:
            pass
        elif round_dir == RoundDir.UP:
            if retval * retval < self:
                retval += FixedPoint(1, self.frac_wid)
        elif round_dir == RoundDir.NEAREST_TIES_UP:
            # half of one unit in the last place above `retval`
            half_way = retval + FixedPoint(1, self.frac_wid + 1)
            if half_way * half_way <= self:
                retval += FixedPoint(1, self.frac_wid)
        elif round_dir == RoundDir.ERROR_IF_INEXACT:
            if retval * retval != self:
                raise ValueError("inexact sqrt")
        else:
            assert False, "unimplemented round_dir"
        return retval

    def rsqrt(self, round_dir=RoundDir.ERROR_IF_INEXACT):
        """compute the reciprocal-sqrt of `self`, with the same `frac_wid`
        as `self`, rounding according to `round_dir`.

        raises `ValueError` if `self` is negative, `ZeroDivisionError` if
        `self` is zero, and `ValueError` if the result is inexact and
        `round_dir` is `RoundDir.ERROR_IF_INEXACT`.
        """
        assert isinstance(round_dir, RoundDir)
        if self < 0:
            raise ValueError("can't compute rsqrt of negative number")
        if self == 0:
            raise ZeroDivisionError("can't compute rsqrt of zero")
        retval = FixedPoint(0, self.frac_wid)
        first_bit_index = -(-self.frac_wid // 2)  # division rounds up
        last_bit_index = -self.frac_wid
        # same bit-by-bit search as `sqrt`, but for `retval ** 2 * self <= 1`
        for bit_index in range(first_bit_index, last_bit_index - 1, -1):
            trial = retval + FixedPoint(1 << (bit_index + self.frac_wid),
                                        self.frac_wid)
            if trial * trial * self <= 1:
                retval = trial
        if round_dir == RoundDir.DOWN:
            pass
        elif round_dir == RoundDir.UP:
            if retval * retval * self < 1:
                retval += FixedPoint(1, self.frac_wid)
        elif round_dir == RoundDir.NEAREST_TIES_UP:
            half_way = retval + FixedPoint(1, self.frac_wid + 1)
            if half_way * half_way * self <= 1:
                retval += FixedPoint(1, self.frac_wid)
        elif round_dir == RoundDir.ERROR_IF_INEXACT:
            if retval * retval * self != 1:
                raise ValueError("inexact rsqrt")
        else:
            assert False, "unimplemented round_dir"
        return retval
322
323
class ParamsNotAccurateEnough(Exception):
    """raised when the parameters aren't accurate enough to have goldschmidt
    division work.

    this is the exception raised by `_assert_accuracy` when its condition
    is false.
    """
327
328
329 def _assert_accuracy(condition, msg="not accurate enough"):
330 if condition:
331 return
332 raise ParamsNotAccurateEnough(msg)
333
334
@dataclass(frozen=True, unsafe_hash=True)
class GoldschmidtDivParamsBase:
    """parameters for a Goldschmidt division algorithm, excluding derived
    parameters.

    instances are frozen and hashable; `GoldschmidtDivParams` extends this
    class with the derived lookup-table and operation list.
    """

    io_width: int
    """bit-width of the input divisor and the result.
    the input numerator is `2 * io_width`-bits wide.
    """

    extra_precision: int
    """number of bits of additional precision used inside the algorithm."""

    table_addr_bits: int
    """the number of address bits used in the lookup-table."""

    table_data_bits: int
    """the number of data bits used in the lookup-table."""

    iter_count: int
    """the total number of iterations of the division algorithm's loop"""
357
358
@dataclass(frozen=True, unsafe_hash=True)
class GoldschmidtDivParams(GoldschmidtDivParamsBase):
    """parameters for a Goldschmidt division algorithm.
    Use `GoldschmidtDivParams.get` to find an efficient set of parameters.

    constructing an instance computes the lookup-table and the list of
    operations, and raises `ParamsNotAccurateEnough` if the parameters
    can't be shown to give a sufficiently accurate result.
    """

    # tuple to be immutable, repr=False so repr() works for debugging even when
    # __post_init__ hasn't finished running yet
    table: "tuple[FixedPoint, ...]" = field(init=False, repr=False)
    """the lookup-table"""

    ops: "tuple[GoldschmidtDivOp, ...]" = field(init=False, repr=False)
    """the operations needed to perform the goldschmidt division algorithm."""

    def _shrink_bound(self, bound, round_dir):
        """prevent fractions from having huge numerators/denominators by
        rounding to a `FixedPoint` and converting back to a `Fraction`.

        This is intended only for values used to compute bounds, and not for
        values that end up in the hardware.
        """
        assert isinstance(bound, (Fraction, int))
        assert round_dir is RoundDir.DOWN or round_dir is RoundDir.UP, \
            "you shouldn't use that round_dir on bounds"
        frac_wid = self.io_width * 4 + 100  # should be enough precision
        fixed = FixedPoint.with_frac_wid(bound, frac_wid, round_dir)
        return fixed.as_fraction()

    def _shrink_min(self, min_bound):
        """prevent fractions used as minimum bounds from having huge
        numerators/denominators by rounding down to a `FixedPoint` and
        converting back to a `Fraction`.

        This is intended only for values used to compute bounds, and not for
        values that end up in the hardware.
        """
        return self._shrink_bound(min_bound, RoundDir.DOWN)

    def _shrink_max(self, max_bound):
        """prevent fractions used as maximum bounds from having huge
        numerators/denominators by rounding up to a `FixedPoint` and
        converting back to a `Fraction`.

        This is intended only for values used to compute bounds, and not for
        values that end up in the hardware.
        """
        return self._shrink_bound(max_bound, RoundDir.UP)

    @property
    def table_addr_count(self):
        """number of distinct addresses in the lookup-table."""
        # used while computing self.table, so can't just do len(self.table)
        return 1 << self.table_addr_bits

    def table_input_exact_range(self, addr):
        """return the range of inputs as `Fraction`s used for the table entry
        with address `addr`."""
        assert isinstance(addr, int)
        assert 0 <= addr < self.table_addr_count
        _assert_accuracy(self.io_width >= self.table_addr_bits)
        addr_shift = self.io_width - self.table_addr_bits
        min_numerator = (1 << self.io_width) + (addr << addr_shift)
        denominator = 1 << self.io_width
        values_per_table_entry = 1 << addr_shift
        max_numerator = min_numerator + values_per_table_entry - 1
        min_input = Fraction(min_numerator, denominator)
        max_input = Fraction(max_numerator, denominator)
        min_input = self._shrink_min(min_input)
        max_input = self._shrink_max(max_input)
        assert 1 <= min_input <= max_input < 2
        return min_input, max_input

    def table_value_exact_range(self, addr):
        """return the range of values as `Fraction`s used for the table entry
        with address `addr`."""
        min_input, max_input = self.table_input_exact_range(addr)
        # division swaps min/max
        min_value = 1 / max_input
        max_value = 1 / min_input
        min_value = self._shrink_min(min_value)
        max_value = self._shrink_max(max_value)
        assert 0.5 < min_value <= max_value <= 1
        return min_value, max_value

    def table_exact_value(self, index):
        """return the exact (not yet rounded to `table_data_bits`) value for
        the table entry with address `index`."""
        min_value, max_value = self.table_value_exact_range(index)
        # we round down
        return min_value

    def __post_init__(self):
        # called by the autogenerated __init__
        _assert_accuracy(self.io_width >= 1, "io_width out of range")
        _assert_accuracy(self.extra_precision >= 0,
                         "extra_precision out of range")
        _assert_accuracy(self.table_addr_bits >= 1,
                         "table_addr_bits out of range")
        _assert_accuracy(self.table_data_bits >= 1,
                         "table_data_bits out of range")
        _assert_accuracy(self.iter_count >= 1, "iter_count out of range")
        table = []
        for addr in range(1 << self.table_addr_bits):
            table.append(FixedPoint.with_frac_wid(self.table_exact_value(addr),
                                                  self.table_data_bits,
                                                  RoundDir.DOWN))
        # we have to use object.__setattr__ since frozen=True
        object.__setattr__(self, "table", tuple(table))
        # `__make_ops` reads `self.table`, so it has to run after it is set
        object.__setattr__(self, "ops", tuple(self.__make_ops()))

    @property
    def expanded_width(self):
        """the total number of bits of precision used inside the algorithm."""
        return self.io_width + self.extra_precision

    @property
    def n_d_f_int_wid(self):
        """the number of bits in the integer part of `state.n`, `state.d`, and
        `state.f` during the main iteration loop.
        """
        return 2

    @property
    def n_d_f_total_wid(self):
        """the total number of bits (both integer and fraction bits) in
        `state.n`, `state.d`, and `state.f` during the main iteration loop.
        """
        return self.n_d_f_int_wid + self.expanded_width

    @cache_on_self
    def max_neps(self, i):
        """maximum value of `neps[i]`.
        `neps[i]` is defined to be `n[i] * N_prime[i - 1] * F_prime[i - 1]`.
        """
        assert isinstance(i, int) and 0 <= i < self.iter_count
        return Fraction(1, 1 << self.expanded_width)

    @cache_on_self
    def max_deps(self, i):
        """maximum value of `deps[i]`.
        `deps[i]` is defined to be `d[i] * D_prime[i - 1] * F_prime[i - 1]`.
        """
        assert isinstance(i, int) and 0 <= i < self.iter_count
        return Fraction(1, 1 << self.expanded_width)

    @cache_on_self
    def max_feps(self, i):
        """maximum value of `feps[i]`.
        `feps[i]` is defined to be `f[i] * (2 - D_prime[i - 1])`.
        """
        assert isinstance(i, int) and 0 <= i < self.iter_count
        # zero, because the computation of `F_prime[i]` in
        # `GoldschmidtDivOp.MulDByF.run(...)` is exact.
        return Fraction(0)

    @cached_property
    def e0_range(self):
        """minimum and maximum values of `e[0]`
        (the relative error in `F_prime[-1]`)
        """
        min_e0 = Fraction(0)
        max_e0 = Fraction(0)
        for addr in range(self.table_addr_count):
            # `F_prime[-1] = (1 - e[0]) / B`
            # => `e[0] = 1 - B * F_prime[-1]`
            min_b, max_b = self.table_input_exact_range(addr)
            f_prime_m1 = self.table[addr].as_fraction()
            assert min_b >= 0 and f_prime_m1 >= 0, \
                "only positive quadrant of interval multiplication implemented"
            min_product = min_b * f_prime_m1
            max_product = max_b * f_prime_m1
            # negation swaps min/max
            cur_min_e0 = 1 - max_product
            cur_max_e0 = 1 - min_product
            min_e0 = min(min_e0, cur_min_e0)
            max_e0 = max(max_e0, cur_max_e0)
        min_e0 = self._shrink_min(min_e0)
        max_e0 = self._shrink_max(max_e0)
        return min_e0, max_e0

    @cached_property
    def min_e0(self):
        """minimum value of `e[0]` (the relative error in `F_prime[-1]`)
        """
        min_e0, max_e0 = self.e0_range
        return min_e0

    @cached_property
    def max_e0(self):
        """maximum value of `e[0]` (the relative error in `F_prime[-1]`)
        """
        min_e0, max_e0 = self.e0_range
        return max_e0

    @cached_property
    def max_abs_e0(self):
        """maximum value of `abs(e[0])`."""
        return max(abs(self.min_e0), abs(self.max_e0))

    @cached_property
    def min_abs_e0(self):
        """minimum value of `abs(e[0])`."""
        return Fraction(0)

    @cache_on_self
    def max_n(self, i):
        """maximum value of `n[i]` (the relative error in `N_prime[i]`
        relative to the previous iteration)
        """
        assert isinstance(i, int) and 0 <= i < self.iter_count
        if i == 0:
            # from Claim 10
            # `n[0] = neps[0] / ((1 - e[0]) * (A / B))`
            # `n[0] <= 2 * neps[0] / (1 - e[0])`

            assert self.max_e0 < 1 and self.max_neps(0) >= 0, \
                "only one quadrant of interval division implemented"
            retval = 2 * self.max_neps(0) / (1 - self.max_e0)
        elif i == 1:
            # from Claim 10
            # `n[1] <= neps[1] / ((1 - f[0]) * (1 - pi[0] - delta[0]))`
            min_mpd = 1 - self.max_pi(0) - self.max_delta(0)
            assert self.max_f(0) <= 1 and min_mpd >= 0, \
                "only one quadrant of interval multiplication implemented"
            prod = (1 - self.max_f(0)) * min_mpd
            assert self.max_neps(1) >= 0 and prod > 0, \
                "only one quadrant of interval division implemented"
            retval = self.max_neps(1) / prod
        else:
            # from Claim 6
            # `0 <= n[i] <= 2 * max_neps[i] / (1 - pi[i - 1] - delta[i - 1])`
            # NOTE(review): the bound quoted above has a factor of 2 that
            # the code below does not apply -- confirm against the paper
            # which of the two is intended.
            min_mpd = 1 - self.max_pi(i - 1) - self.max_delta(i - 1)
            assert self.max_neps(i) >= 0 and min_mpd > 0, \
                "only one quadrant of interval division implemented"
            retval = self.max_neps(i) / min_mpd

        return self._shrink_max(retval)

    @cache_on_self
    def max_d(self, i):
        """maximum value of `d[i]` (the relative error in `D_prime[i]`
        relative to the previous iteration)
        """
        assert isinstance(i, int) and 0 <= i < self.iter_count
        if i == 0:
            # from Claim 10
            # `d[0] = deps[0] / (1 - e[0])`

            assert self.max_e0 < 1 and self.max_deps(0) >= 0, \
                "only one quadrant of interval division implemented"
            retval = self.max_deps(0) / (1 - self.max_e0)
        elif i == 1:
            # from Claim 10
            # `d[1] <= deps[1] / ((1 - f[0]) * (1 - delta[0] ** 2))`
            assert self.max_f(0) <= 1 and self.max_delta(0) <= 1, \
                "only one quadrant of interval multiplication implemented"
            divisor = (1 - self.max_f(0)) * (1 - self.max_delta(0) ** 2)
            assert self.max_deps(1) >= 0 and divisor > 0, \
                "only one quadrant of interval division implemented"
            retval = self.max_deps(1) / divisor
        else:
            # from Claim 6
            # `0 <= d[i] <= max_deps[i] / (1 - delta[i - 1])`
            assert self.max_deps(i) >= 0 and self.max_delta(i - 1) < 1, \
                "only one quadrant of interval division implemented"
            retval = self.max_deps(i) / (1 - self.max_delta(i - 1))

        return self._shrink_max(retval)

    @cache_on_self
    def max_f(self, i):
        """maximum value of `f[i]` (the relative error in `F_prime[i]`
        relative to the previous iteration)
        """
        assert isinstance(i, int) and 0 <= i < self.iter_count
        if i == 0:
            # from Claim 10
            # `f[0] = feps[0] / (1 - delta[0])`

            assert self.max_delta(0) < 1 and self.max_feps(0) >= 0, \
                "only one quadrant of interval division implemented"
            retval = self.max_feps(0) / (1 - self.max_delta(0))
        elif i == 1:
            # from Claim 10
            # `f[1] = feps[1]`
            retval = self.max_feps(1)
        else:
            # from Claim 6
            # `f[i] <= max_feps[i]`
            retval = self.max_feps(i)

        return self._shrink_max(retval)

    @cache_on_self
    def max_delta(self, i):
        """ maximum value of `delta[i]`.
        `delta[i]` is defined in Definition 4 of paper.
        """
        assert isinstance(i, int) and 0 <= i < self.iter_count
        if i == 0:
            # `delta[0] = abs(e[0]) + 3 * d[0] / 2`
            retval = self.max_abs_e0 + Fraction(3, 2) * self.max_d(0)
        else:
            # `delta[i] = delta[i - 1] ** 2 + f[i - 1]`
            prev_max_delta = self.max_delta(i - 1)
            assert prev_max_delta >= 0
            retval = prev_max_delta ** 2 + self.max_f(i - 1)

        # `delta[i]` has to be smaller than one otherwise errors would go off
        # to infinity
        _assert_accuracy(retval < 1)

        return self._shrink_max(retval)

    @cache_on_self
    def max_pi(self, i):
        """ maximum value of `pi[i]`.
        `pi[i]` is defined right below Theorem 5 of paper.
        """
        assert isinstance(i, int) and 0 <= i < self.iter_count
        # `pi[i] = 1 - (1 - n[i]) * prod`
        # where `prod` is the product of,
        # for `j` in `0 <= j < i`, `(1 - n[j]) / (1 + d[j])`
        min_prod = Fraction(1)
        for j in range(i):
            max_n_j = self.max_n(j)
            max_d_j = self.max_d(j)
            assert max_n_j <= 1 and max_d_j > -1, \
                "only one quadrant of interval division implemented"
            min_prod *= (1 - max_n_j) / (1 + max_d_j)
        max_n_i = self.max_n(i)
        assert max_n_i <= 1 and min_prod >= 0, \
            "only one quadrant of interval multiplication implemented"
        retval = 1 - (1 - max_n_i) * min_prod
        return self._shrink_max(retval)

    @cached_property
    def max_n_shift(self):
        """ maximum value of `state.n_shift`.
        """
        # input numerator is `2*io_width`-bits
        max_n = (1 << (self.io_width * 2)) - 1
        max_n_shift = 0
        # normalize so 1 <= n < 2
        while max_n >= 2:
            max_n >>= 1
            max_n_shift += 1
        return max_n_shift

    @cached_property
    def n_hat(self):
        """ maximum value of, for all `i`, `max_n(i)` and `max_d(i)`
        """
        n_hat = Fraction(0)
        for i in range(self.iter_count):
            n_hat = max(n_hat, self.max_n(i), self.max_d(i))
        return self._shrink_max(n_hat)

    def __make_ops(self):
        """ Goldschmidt division algorithm.

            based on:
            Even, G., Seidel, P. M., & Ferguson, W. E. (2003).
            A Parametric Error Analysis of Goldschmidt's Division Algorithm.
            https://citeseerx.ist.psu.edu/viewdoc/download?doi=10.1.1.90.1238&rep=rep1&type=pdf

            yields: GoldschmidtDivOp
                the operations needed to perform the division.

            raises `ParamsNotAccurateEnough` (while being iterated) if any
            of the accuracy requirements checked below isn't met.
        """
        # establish assumptions of the paper's error analysis (section 3.1):

        # 1. normalize so A (numerator) and B (denominator) are in [1, 2)
        yield GoldschmidtDivOp.Normalize

        # 2. ensure all relative errors from directed rounding are <= 1 / 4.
        # the assumption is met by multipliers with > 4-bits precision
        _assert_accuracy(self.expanded_width > 4)

        # 3. require `abs(e[0]) + 3 * d[0] / 2 + f[0] < 1 / 2`.
        _assert_accuracy(self.max_abs_e0 + 3 * self.max_d(0) / 2
                         + self.max_f(0) < Fraction(1, 2))

        # 4. the initial approximation F'[-1] of 1/B is in [1/2, 1].
        # (B is the denominator)

        for addr in range(self.table_addr_count):
            f_prime_m1 = self.table[addr]
            _assert_accuracy(0.5 <= f_prime_m1 <= 1)

        yield GoldschmidtDivOp.FEqTableLookup

        # we use Setting I (section 4.1 of the paper):
        # Require `n[i] <= n_hat` and `d[i] <= n_hat` and `f[i] = 0`:
        # the conditions on n_hat are satisfied by construction.
        for i in range(self.iter_count):
            _assert_accuracy(self.max_f(i) == 0)
            yield GoldschmidtDivOp.MulNByF
            # the last iteration only needs the final numerator
            if i != self.iter_count - 1:
                yield GoldschmidtDivOp.MulDByF
                yield GoldschmidtDivOp.FEq2MinusD

        # relative approximation error `p(N_prime[i])`:
        # `p(N_prime[i]) = (A / B - N_prime[i]) / (A / B)`
        # `0 <= p(N_prime[i])`
        # `p(N_prime[i]) <= (2 * i) * n_hat \`
        # ` + (abs(e[0]) + 3 * n_hat / 2) ** (2 ** i)`
        i = self.iter_count - 1  # last used `i`
        # compute power manually to prevent huge intermediate values
        power = self._shrink_max(self.max_abs_e0 + 3 * self.n_hat / 2)
        for _ in range(i):
            power = self._shrink_max(power * power)

        max_rel_error = (2 * i) * self.n_hat + power

        min_a_over_b = Fraction(1, 2)
        max_a_over_b = Fraction(2)
        max_allowed_abs_error = max_a_over_b / (1 << self.max_n_shift)
        max_allowed_rel_error = max_allowed_abs_error / min_a_over_b

        _assert_accuracy(max_rel_error < max_allowed_rel_error,
                         f"not accurate enough: max_rel_error={max_rel_error}"
                         f" max_allowed_rel_error={max_allowed_rel_error}")

        yield GoldschmidtDivOp.CalcResult

    @cache_on_self
    def default_cost_fn(self):
        """ calculate the estimated cost on an arbitrary scale of implementing
        goldschmidt division with the specified parameters. larger cost
        values mean worse parameters.

        This is the default cost function for `GoldschmidtDivParams.get`.

        returns: float
        """
        rom_cells = self.table_data_bits << self.table_addr_bits
        cost = float(rom_cells)
        for op in self.ops:
            if op == GoldschmidtDivOp.MulNByF \
                    or op == GoldschmidtDivOp.MulDByF:
                mul_cost = self.expanded_width ** 2
                mul_cost *= self.expanded_width.bit_length()
                cost += mul_cost
        cost += 5e7 * self.iter_count
        return cost

    @staticmethod
    @lru_cache(maxsize=1 << 16)
    def __cached_new(base_params):
        """build a `GoldschmidtDivParams` from `base_params`, memoized.

        returns `(params, None)` on success and `(None, error)` when the
        constructor raised `ParamsNotAccurateEnough`, so failures are
        cached too.
        """
        assert isinstance(base_params, GoldschmidtDivParamsBase)
        # can't use dataclasses.asdict, since it's recursive and will also give
        # child class fields too, which we don't want.
        kwargs = {}
        for base_field in fields(GoldschmidtDivParamsBase):
            kwargs[base_field.name] = getattr(base_params, base_field.name)
        try:
            return GoldschmidtDivParams(**kwargs), None
        except ParamsNotAccurateEnough as e:
            return None, e

    @staticmethod
    def __raise(e):  # type: (ParamsNotAccurateEnough) -> Any
        raise e

    @staticmethod
    def cached_new(base_params, handle_error=__raise):
        """return the (cached) `GoldschmidtDivParams` for `base_params`.

        if the parameters aren't accurate enough, returns
        `handle_error(error)` instead; the default `handle_error` re-raises
        the `ParamsNotAccurateEnough` error.
        """
        assert isinstance(base_params, GoldschmidtDivParamsBase)
        if isinstance(handle_error, staticmethod):
            # the default value is captured while the class body is still
            # executing, so it is the raw `staticmethod` object, which is
            # only directly callable on Python >= 3.10. unwrap it so the
            # default also works on older Python versions.
            handle_error = handle_error.__func__
        params, error = GoldschmidtDivParams.__cached_new(base_params)
        if error is None:
            return params
        else:
            return handle_error(error)

    @staticmethod
    def get(io_width, cost_fn=default_cost_fn, max_table_addr_bits=12):
        """ find efficient parameters for a goldschmidt division algorithm
        with `params.io_width == io_width`.

        arguments:
        io_width: int
            bit-width of the input divisor and the result.
            the input numerator is `2 * io_width`-bits wide.
        cost_fn: Callable[[GoldschmidtDivParams], float]
            return the estimated cost on an arbitrary scale of implementing
            goldschmidt division with the specified parameters. larger cost
            values mean worse parameters.
        max_table_addr_bits: int
            maximum allowable value of `table_addr_bits`

        returns: GoldschmidtDivParams

        raises `ValueError` if even the initial, deliberately over-sized,
        parameters are rejected.
        """
        assert isinstance(io_width, int) and io_width >= 1
        assert callable(cost_fn)

        last_error = None
        last_error_params = None

        def cached_new(base_params):
            # like `GoldschmidtDivParams.cached_new`, but returns `None`
            # for rejected parameters and remembers the last error
            def handle_error(e):
                nonlocal last_error, last_error_params
                last_error = e
                last_error_params = base_params
                return None

            retval = GoldschmidtDivParams.cached_new(base_params, handle_error)
            if retval is None:
                logging.debug(f"GoldschmidtDivParams.get: err: {base_params}")
            else:
                logging.debug(f"GoldschmidtDivParams.get: ok: {base_params}")
            return retval

        @lru_cache(maxsize=None)
        def get_cost(base_params):
            # rejected parameters get infinite cost so they never win
            params = cached_new(base_params)
            if params is None:
                return math.inf
            retval = cost_fn(params)
            logging.debug(f"GoldschmidtDivParams.get: cost={retval}: {params}")
            return retval

        # start with parameters big enough to always work.
        initial_extra_precision = io_width * 2 + 4
        initial_params = GoldschmidtDivParamsBase(
            io_width=io_width,
            extra_precision=initial_extra_precision,
            table_addr_bits=min(max_table_addr_bits, io_width),
            table_data_bits=io_width + initial_extra_precision,
            iter_count=1 + io_width.bit_length())

        if cached_new(initial_params) is None:
            raise ValueError(f"initial goldschmidt division algorithm "
                             f"parameters are invalid: {initial_params}"
                             ) from last_error

        # find good initial `iter_count`
        params = initial_params
        for iter_count in range(1, initial_params.iter_count):
            trial_params = replace(params, iter_count=iter_count)
            if cached_new(trial_params) is not None:
                params = trial_params
                break

        # now find `table_addr_bits`
        cost = get_cost(params)
        for table_addr_bits in range(1, max_table_addr_bits):
            trial_params = replace(params, table_addr_bits=table_addr_bits)
            trial_cost = get_cost(trial_params)
            if trial_cost < cost:
                params = trial_params
                cost = trial_cost
                break

        # check one higher `iter_count` to see if it has lower cost
        for table_addr_bits in range(1, max_table_addr_bits + 1):
            trial_params = replace(params,
                                   table_addr_bits=table_addr_bits,
                                   iter_count=params.iter_count + 1)
            trial_cost = get_cost(trial_params)
            if trial_cost < cost:
                params = trial_params
                cost = trial_cost
                break

        # now shrink `table_data_bits`
        while True:
            trial_params = replace(params,
                                   table_data_bits=params.table_data_bits - 1)
            trial_cost = get_cost(trial_params)
            if trial_cost < cost:
                params = trial_params
                cost = trial_cost
            else:
                break

        # and shrink `extra_precision`
        while True:
            trial_params = replace(params,
                                   extra_precision=params.extra_precision - 1)
            trial_cost = get_cost(trial_params)
            if trial_cost < cost:
                params = trial_params
                cost = trial_cost
            else:
                break

        retval = cached_new(params)
        assert isinstance(retval, GoldschmidtDivParams)
        return retval
943
944
def clz(v, wid):
    """count leading zeros -- handy for debugging.

    returns the number of zero bits above the most-significant set bit of
    `v` when `v` is viewed as a `wid`-bit unsigned integer, so
    `clz(0, wid) == wid` and a value with its top bit set gives 0.
    """
    assert isinstance(wid, int)
    assert isinstance(v, int) and 0 <= v < (1 << wid)
    # `(1 << wid).bit_length()` is `wid + 1`, which over-counted by one
    return wid - v.bit_length()
950
951
@enum.unique
class GoldschmidtDivOp(enum.Enum):
    """the individual operations of the goldschmidt division algorithm.

    each member's value is a short pseudo-code description of the operation.
    `run` executes the operation on the software model
    (`GoldschmidtDivState`), `gen_hdl` emits the corresponding hardware into
    a `GoldschmidtDivHDLState`.
    """
    Normalize = "n, d, n_shift = normalize(n, d)"
    FEqTableLookup = "f = table_lookup(d)"
    MulNByF = "n *= f"
    MulDByF = "d *= f"
    FEq2MinusD = "f = 2 - d"
    CalcResult = "result = unnormalize_and_round(n)"

    def run(self, params, state):
        """run this operation on the software model, modifying `state`
        in-place.

        arguments:
        params: GoldschmidtDivParams
            the goldschmidt division parameters.
        state: GoldschmidtDivState
            the input/output state
        """
        assert isinstance(params, GoldschmidtDivParams)
        assert isinstance(state, GoldschmidtDivState)
        expanded_width = params.expanded_width
        table_addr_bits = params.table_addr_bits
        if self == GoldschmidtDivOp.Normalize:
            # normalize so 1 <= d < 2
            # can easily be done with count-leading-zeros and left shift
            while state.d < 1:
                state.n = (state.n * 2).to_frac_wid(expanded_width)
                state.d = (state.d * 2).to_frac_wid(expanded_width)

            state.n_shift = 0
            # normalize so 1 <= n < 2
            # every halving is recorded in `n_shift` so `CalcResult` can
            # scale the result back up.
            while state.n >= 2:
                state.n = (state.n * 0.5).to_frac_wid(expanded_width)
                state.n_shift += 1
        elif self == GoldschmidtDivOp.FEqTableLookup:
            # compute initial f by table lookup
            # the table is indexed by the top `table_addr_bits` fraction
            # bits of `d - 1`
            d_m_1 = state.d - 1
            d_m_1 = d_m_1.to_frac_wid(table_addr_bits, RoundDir.DOWN)
            assert 0 <= d_m_1.bits < (1 << params.table_addr_bits)
            state.f = params.table[d_m_1.bits]
            state.f = state.f.to_frac_wid(expanded_width,
                                          round_dir=RoundDir.DOWN)
        elif self == GoldschmidtDivOp.MulNByF:
            assert state.f is not None
            n = state.n * state.f
            # the numerator product is rounded down
            state.n = n.to_frac_wid(expanded_width, round_dir=RoundDir.DOWN)
        elif self == GoldschmidtDivOp.MulDByF:
            assert state.f is not None
            d = state.d * state.f
            # the denominator product is rounded up
            state.d = d.to_frac_wid(expanded_width, round_dir=RoundDir.UP)
        elif self == GoldschmidtDivOp.FEq2MinusD:
            state.f = (2 - state.d).to_frac_wid(expanded_width)
        elif self == GoldschmidtDivOp.CalcResult:
            assert state.n_shift is not None
            # scale to correct value
            n = state.n * (1 << state.n_shift)

            state.quotient = math.floor(n)
            state.remainder = state.orig_n - state.quotient * state.orig_d
            # the approximate quotient may be too small; correct it by one
            # when the remainder shows that `orig_d` still fits.
            if state.remainder >= state.orig_d:
                state.quotient += 1
                state.remainder -= state.orig_d
        else:
            assert False, f"unimplemented GoldschmidtDivOp: {self}"

    def gen_hdl(self, params, state, sync_rom):
        """generate the hdl for this operation.

        arguments:
        params: GoldschmidtDivParams
            the goldschmidt division parameters.
        state: GoldschmidtDivHDLState
            the input/output state
        sync_rom: bool
            true if the rom should be read synchronously rather than
            combinatorially, incurring an extra clock cycle of latency.
        """
        assert isinstance(params, GoldschmidtDivParams)
        assert isinstance(state, GoldschmidtDivHDLState)
        m = state.m
        if self == GoldschmidtDivOp.Normalize:
            # normalize so 1 <= d < 2
            assert state.d.width == params.io_width
            assert state.n.width == 2 * params.io_width
            d_leading_zeros = CLZ(params.io_width)
            m.submodules.d_leading_zeros = d_leading_zeros
            m.d.comb += d_leading_zeros.sig_in.eq(state.d)
            # shift out d's leading zeros, keeping d's original width
            d_shift_out = Signal.like(state.d)
            m.d.comb += d_shift_out.eq(state.d << d_leading_zeros.lz)
            # reposition into the `n_d_f_total_wid`-bit internal format
            d = Signal(params.n_d_f_total_wid)
            m.d.comb += d.eq((d_shift_out << (1 + params.expanded_width))
                             >> state.d.width)

            # normalize so 1 <= n < 2
            n_leading_zeros = CLZ(2 * params.io_width)
            m.submodules.n_leading_zeros = n_leading_zeros
            m.d.comb += n_leading_zeros.sig_in.eq(state.n)
            signed_zero = Const(0, signed(1))  # force subtraction to be signed
            n_shift_s_v = (params.io_width + signed_zero + d_leading_zeros.lz
                           - n_leading_zeros.lz)
            n_shift_s = Signal.like(n_shift_s_v)
            n_shift_n_lz_out = Signal.like(state.n)
            n_shift_d_lz_out = Signal.like(state.n << d_leading_zeros.lz)
            m.d.comb += [
                n_shift_s.eq(n_shift_s_v),
                n_shift_d_lz_out.eq(state.n << d_leading_zeros.lz),
                n_shift_n_lz_out.eq(state.n << n_leading_zeros.lz),
            ]
            state.n_shift = Signal(d_leading_zeros.lz.width)
            n = Signal(params.n_d_f_total_wid)
            # `state.d` still refers to the un-normalized input here; it is
            # only replaced at the end of this branch.
            with m.If(n_shift_s < 0):
                # a negative shift is clamped to zero: n is shifted by d's
                # leading-zero count instead of its own.
                m.d.comb += [
                    state.n_shift.eq(0),
                    n.eq((n_shift_d_lz_out << (1 + params.expanded_width))
                         >> state.d.width),
                ]
            with m.Else():
                m.d.comb += [
                    state.n_shift.eq(n_shift_s),
                    n.eq((n_shift_n_lz_out << (1 + params.expanded_width))
                         >> state.n.width),
                ]
            state.n = n
            state.d = d
        elif self == GoldschmidtDivOp.FEqTableLookup:
            assert state.d.width == params.n_d_f_total_wid, "invalid d width"
            # compute initial f by table lookup

            # extra bit for table entries == 1.0
            table_width = 1 + params.table_data_bits
            table = Memory(width=table_width, depth=len(params.table),
                           init=[i.bits for i in params.table])
            # drop the top `n_d_f_int_wid` bits of d, then use the top
            # `table_addr_bits` of what remains as the table address.
            addr = state.d[:-params.n_d_f_int_wid][-params.table_addr_bits:]
            if sync_rom:
                table_read = table.read_port()
                m.d.comb += table_read.addr.eq(addr)
                # register the rest of the state alongside the synchronous
                # read port
                state.insert_pipeline_register()
            else:
                table_read = table.read_port(domain="comb")
                m.d.comb += table_read.addr.eq(addr)
            m.submodules.table_read = table_read
            state.f = Signal(params.n_d_f_int_wid + params.expanded_width)
            # left-align the table data to `expanded_width` fraction bits
            data_shift = params.expanded_width - params.table_data_bits
            m.d.comb += state.f.eq(table_read.data << data_shift)
        elif self == GoldschmidtDivOp.MulNByF:
            assert state.n.width == params.n_d_f_total_wid, "invalid n width"
            assert state.f is not None
            assert state.f.width == params.n_d_f_total_wid, "invalid f width"
            n = Signal.like(state.n)
            # the right shift truncates, matching the round-down in `run`
            m.d.comb += n.eq((state.n * state.f) >> params.expanded_width)
            state.n = n
        elif self == GoldschmidtDivOp.MulDByF:
            assert state.d.width == params.n_d_f_total_wid, "invalid d width"
            assert state.f is not None
            assert state.f.width == params.n_d_f_total_wid, "invalid f width"
            d = Signal.like(state.d)
            d_times_f = Signal.like(state.d * state.f)
            m.d.comb += [
                d_times_f.eq(state.d * state.f),
                # round up: add one whenever any of the discarded low bits
                # is set, matching the round-up in `run`
                d.eq((d_times_f >> params.expanded_width)
                     + (d_times_f[:params.expanded_width] != 0)),
            ]
            state.d = d
        elif self == GoldschmidtDivOp.FEq2MinusD:
            assert state.d.width == params.n_d_f_total_wid, "invalid d width"
            f = Signal.like(state.d)
            # `2 << expanded_width` is 2.0 with `expanded_width` fraction bits
            m.d.comb += f.eq((2 << params.expanded_width) - state.d)
            state.f = f
        elif self == GoldschmidtDivOp.CalcResult:
            assert state.n.width == params.n_d_f_total_wid, "invalid n width"
            assert state.n_shift is not None
            q_approx = Signal(params.io_width)
            # extra bit for if it's bigger than orig_d
            r_approx = Signal(params.io_width + 1)
            adjusted_r = Signal(signed(1 + params.io_width))
            m.d.comb += [
                # scale to correct value and drop the fraction bits
                q_approx.eq((state.n << state.n_shift)
                            >> params.expanded_width),
                r_approx.eq(state.orig_n - q_approx * state.orig_d),
                adjusted_r.eq(r_approx - state.orig_d),
            ]
            state.quotient = Signal(params.io_width)
            state.remainder = Signal(params.io_width)

            # `adjusted_r >= 0` means `r_approx >= orig_d`, so the
            # approximate quotient is one too small.
            with m.If(adjusted_r >= 0):
                m.d.comb += [
                    state.quotient.eq(q_approx + 1),
                    state.remainder.eq(adjusted_r),
                ]
            with m.Else():
                m.d.comb += [
                    state.quotient.eq(q_approx),
                    state.remainder.eq(r_approx),
                ]
        else:
            assert False, f"unimplemented GoldschmidtDivOp: {self}"
1142
1143
@dataclass
class GoldschmidtDivState:
    """mutable state of the software model of the goldschmidt division
    algorithm."""

    orig_n: int
    """original numerator"""

    orig_d: int
    """original denominator"""

    n: FixedPoint
    """numerator -- N_prime[i] in the paper's algorithm 2"""

    d: FixedPoint
    """denominator -- D_prime[i] in the paper's algorithm 2"""

    f: "FixedPoint | None" = None
    """current factor -- F_prime[i] in the paper's algorithm 2"""

    quotient: "int | None" = None
    """final quotient"""

    remainder: "int | None" = None
    """final remainder"""

    n_shift: "int | None" = None
    """amount the numerator needs to be left-shifted at the end of the
    algorithm.
    """

    def __repr__(self):
        # fields that are still `None` are left out; integers are shown in
        # hexadecimal, except for `n_shift` which is a small shift amount.
        parts = []
        for fld in fields(GoldschmidtDivState):
            val = getattr(self, fld.name)
            if val is None:
                continue
            show_hex = fld.name != "n_shift" and isinstance(val, int)
            text = hex(val) if show_hex else repr(val)
            parts.append(f"{fld.name}={text}")
        return "GoldschmidtDivState(" + ", ".join(parts) + ")"
1183
1184
def goldschmidt_div(n, d, params, trace=lambda state: None):
    """ Goldschmidt division algorithm (software model).

    based on:
    Even, G., Seidel, P. M., & Ferguson, W. E. (2003).
    A Parametric Error Analysis of Goldschmidt's Division Algorithm.
    https://citeseerx.ist.psu.edu/viewdoc/download?doi=10.1.1.90.1238&rep=rep1&type=pdf

    arguments:
    n: int
        numerator. a `2 * params.io_width`-bit unsigned integer.
        must be less than `d << params.io_width`, otherwise the quotient
        wouldn't fit in `params.io_width` bits.
    d: int
        denominator. a `params.io_width`-bit unsigned integer. must not be
        zero.
    params: GoldschmidtDivParams
        the goldschmidt division parameters; supplies `io_width` and the
        list of operations (`ops`) to execute.
    trace: Function[[GoldschmidtDivState], None]
        called with the initial state and the state after executing each
        operation in `params.ops`.

    returns: tuple[int, int]
        the quotient and remainder, as left in the state by the operations.
    """
    assert isinstance(params, GoldschmidtDivParams)
    io_width = params.io_width
    assert isinstance(d, int) and 0 < d < (1 << io_width)
    assert isinstance(n, int) and 0 <= n < (d << io_width)

    # the inputs enter the algorithm as fixed-point values built with
    # `io_width` as their fraction width.
    state = GoldschmidtDivState(
        orig_n=n,
        orig_d=d,
        n=FixedPoint(n, io_width),
        d=FixedPoint(d, io_width),
    )

    # report the starting state, then the state after every operation
    trace(state)
    for step in params.ops:
        step.run(params, state)
        trace(state)

    quotient, remainder = state.quotient, state.remainder
    assert quotient is not None
    assert remainder is not None

    return quotient, remainder
1233
1234
@dataclass(eq=False)
class GoldschmidtDivHDLState:
    """state of the HDL version of the goldschmidt division algorithm.

    every public attribute other than `m` and `old_signals` holds the
    `Signal` carrying the current value of that part of the state. assigning
    a new `Signal` to such an attribute renames the signal to
    `state_<name>_<index>` and appends the previous signal to
    `old_signals[<name>]`.
    """

    m: Module
    """The HDL Module"""

    orig_n: Signal
    """original numerator"""

    orig_d: Signal
    """original denominator"""

    n: Signal
    """numerator -- N_prime[i] in the paper's algorithm 2"""

    d: Signal
    """denominator -- D_prime[i] in the paper's algorithm 2"""

    f: "Signal | None" = None
    """current factor -- F_prime[i] in the paper's algorithm 2"""

    quotient: "Signal | None" = None
    """final quotient"""

    remainder: "Signal | None" = None
    """final remainder"""

    n_shift: "Signal | None" = None
    """amount the numerator needs to be left-shifted at the end of the
    algorithm.
    """

    old_signals: "defaultdict[str, list[Signal]]" = field(repr=False,
                                                          init=False)

    __signal_name_prefix: "str" = field(default="state_", repr=False,
                                        init=False)

    def __post_init__(self):
        # called by the autogenerated __init__
        self.old_signals = defaultdict(list)

    def __setattr__(self, name, value):
        assert isinstance(name, str)
        if name.startswith("_"):
            return super().__setattr__(name, value)
        try:
            old_signals = self.old_signals[name]
        except AttributeError:
            # haven't yet finished __post_init__
            return super().__setattr__(name, value)
        assert name != "m" and name != "old_signals", f"can't write to {name}"
        assert isinstance(value, Signal)
        # the index in the name is the number of signals this attribute has
        # previously held and had recorded in `old_signals`
        value.name = f"{self.__signal_name_prefix}{name}_{len(old_signals)}"
        old_signal = getattr(self, name, None)
        if old_signal is not None:
            assert isinstance(old_signal, Signal)
            old_signals.append(old_signal)
        return super().__setattr__(name, value)

    def insert_pipeline_register(self):
        """replace every signal currently held in the state with a new
        signal that is assigned from the old one in the `sync` domain.

        this covers every public field that currently holds a `Signal`,
        including `orig_n` and `orig_d`. fields that are still `None` are
        left alone.
        """
        # the prefix is saved and restored so this method always leaves it
        # unchanged; nothing in the loop currently modifies it.
        old_prefix = self.__signal_name_prefix
        try:
            for fld in fields(GoldschmidtDivHDLState):
                # skip private fields and the fields that never hold a
                # Signal. `fields()` also returns `old_signals`, which is a
                # dict of previous signals rather than a Signal, so it must
                # not be registered.
                if fld.name.startswith("_") \
                        or fld.name in ("m", "old_signals"):
                    continue
                old_sig = getattr(self, fld.name, None)
                if old_sig is None:
                    continue
                assert isinstance(old_sig, Signal)
                new_sig = Signal.like(old_sig)
                setattr(self, fld.name, new_sig)
                self.m.d.sync += new_sig.eq(old_sig)
        finally:
            self.__signal_name_prefix = old_prefix
1308 self.__signal_name_prefix = old_prefix
1309
1310
class GoldschmidtDivHDL(Elaboratable):
    """ Goldschmidt division algorithm, as HDL.

    based on:
    Even, G., Seidel, P. M., & Ferguson, W. E. (2003).
    A Parametric Error Analysis of Goldschmidt's Division Algorithm.
    https://citeseerx.ist.psu.edu/viewdoc/download?doi=10.1.1.90.1238&rep=rep1&type=pdf

    attributes:
    params: GoldschmidtDivParams
        the goldschmidt division algorithm parameters.
    pipe_reg_indexes: list[int]
        sorted list of the operation indexes in front of which a pipeline
        register is inserted. an index may appear more than once, in which
        case that many registers are inserted back-to-back -- this is
        useful to allow yosys to spread a multiplication across those
        multiple pipeline stages. indexes past the last operation produce
        registers after the last operation.
    sync_rom: bool
        true if the rom should be read synchronously rather than
        combinatorially, incurring an extra clock cycle of latency.
    n: Signal(unsigned(2 * params.io_width))
        input numerator. must be less than `d << params.io_width`,
        otherwise the quotient wouldn't fit in `params.io_width` bits.
    d: Signal(unsigned(params.io_width))
        input denominator. must not be zero.
    q: Signal(unsigned(params.io_width))
        output quotient. only valid when `n < (d << params.io_width)`.
    r: Signal(unsigned(params.io_width))
        output remainder. only valid when `n < (d << params.io_width)`.
    trace: list[GoldschmidtDivHDLState]
        snapshots of the state: the initial state followed by the state
        after generating each operation in `params.ops`.
    """

    @property
    def total_pipeline_registers(self):
        """the total number of pipeline registers"""
        # a synchronous rom contributes one more register stage
        return self.sync_rom + len(self.pipe_reg_indexes)

    def __init__(self, params, pipe_reg_indexes=(), sync_rom=False):
        assert isinstance(params, GoldschmidtDivParams)
        assert isinstance(sync_rom, bool)
        self.params = params
        self.pipe_reg_indexes = sorted(map(int, pipe_reg_indexes))
        self.sync_rom = sync_rom
        self.n = Signal(unsigned(2 * params.io_width))
        self.d = Signal(unsigned(params.io_width))
        self.q = Signal(unsigned(params.io_width))
        self.r = Signal(unsigned(params.io_width))

        # the hardware is generated here rather than in `elaborate` so that
        # `trace` is available without elaborating
        state = GoldschmidtDivHDLState(
            m=Module(),
            orig_n=self.n,
            orig_d=self.d,
            n=self.n,
            d=self.d)

        self.trace = [replace(state)]

        # walk the (ascending) register indexes with a cursor
        pending = self.pipe_reg_indexes
        cursor = 0

        for op_index, op in enumerate(self.params.ops):
            # emit every register scheduled at or before this operation
            while cursor < len(pending) and pending[cursor] <= op_index:
                cursor += 1
                state.insert_pipeline_register()
            op.gen_hdl(self.params, state, self.sync_rom)
            self.trace.append(replace(state))

        # registers scheduled past the last operation go at the very end
        for _ in pending[cursor:]:
            state.insert_pipeline_register()

        state.m.d.comb += [
            self.q.eq(state.quotient),
            self.r.eq(state.remainder),
        ]

    def elaborate(self, platform):
        # every trace entry shares the one module built in `__init__`
        return self.trace[0].m
1394
1395
# number of integer bits in the address used to index the table built by
# `goldschmidt_sqrt_rsqrt_table`; the remaining
# `table_addr_bits - GOLDSCHMIDT_SQRT_RSQRT_TABLE_ADDR_INT_WID` address bits
# are fraction bits.
GOLDSCHMIDT_SQRT_RSQRT_TABLE_ADDR_INT_WID = 2
1397
1398
@lru_cache()
def goldschmidt_sqrt_rsqrt_table(table_addr_bits, table_data_bits):
    """Generate the look-up table needed for Goldschmidt's square-root and
    reciprocal-square-root algorithm.

    arguments:
    table_addr_bits: int
        the number of address bits for the look-up table. must be at least
        `GOLDSCHMIDT_SQRT_RSQRT_TABLE_ADDR_INT_WID`.
    table_data_bits: int
        the number of data bits for the look-up table.

    returns: tuple[FixedPoint | None, ...]
        the table, with `1 << table_addr_bits` entries. entry 0 is zero,
        entries for the other addresses below `1 << table_addr_frac_wid`
        are `None` since they should be unused, and every other entry is
        the reciprocal-square-root of `addr + 1` (taken with
        `table_addr_frac_wid` fraction bits), rounded down to
        `table_data_bits` fraction bits.
    """
    assert isinstance(table_addr_bits, int) and \
        table_addr_bits >= GOLDSCHMIDT_SQRT_RSQRT_TABLE_ADDR_INT_WID
    assert isinstance(table_data_bits, int) and table_data_bits >= 1
    # the address is a fixed-point value with
    # `GOLDSCHMIDT_SQRT_RSQRT_TABLE_ADDR_INT_WID` integer bits; the rest of
    # the address bits are fraction bits.
    table_addr_frac_wid = table_addr_bits
    table_addr_frac_wid -= GOLDSCHMIDT_SQRT_RSQRT_TABLE_ADDR_INT_WID
    # first address whose entry is computed rather than left unused
    first_valid_addr = 1 << table_addr_frac_wid
    table = []
    table_len = 1 << table_addr_bits
    for addr in range(table_len):
        if addr == 0:
            value = FixedPoint(0, table_data_bits)
        elif addr < first_valid_addr:
            value = None  # table entries should be unused
        else:
            # `addr + 1` is the next address up, i.e. the upper end of the
            # inputs that map to `addr`
            max_input_value = FixedPoint(addr + 1, table_addr_frac_wid)
            max_frac_wid = max(max_input_value.frac_wid, table_data_bits)
            value = max_input_value.to_frac_wid(max_frac_wid)
            value = value.rsqrt(RoundDir.DOWN)
            value = value.to_frac_wid(table_data_bits, RoundDir.DOWN)
        table.append(value)

    # tuple for immutability
    return tuple(table)
1432
1433 # FIXME: add code to calculate error bounds and check that the algorithm will
1434 # actually work (like in the goldschmidt division algorithm).
1435 # FIXME: add code to calculate a good set of parameters based on the error
1436 # bounds checking.
1437
1438
def goldschmidt_sqrt_rsqrt(radicand, io_width, frac_wid, extra_precision,
                           table_addr_bits, table_data_bits, iter_count):
    """Goldschmidt's square-root and reciprocal-square-root algorithm.

    uses algorithm based on second method at:
    https://en.wikipedia.org/wiki/Methods_of_computing_square_roots#Goldschmidt%E2%80%99s_algorithm

    arguments:
    radicand: FixedPoint(frac_wid=frac_wid)
        the input value to take the square-root and reciprocal-square-root of.
    io_width: int
        the number of bits in the input (`radicand`) and output values.
    frac_wid: int
        the number of fraction bits in the input (`radicand`) and output
        values.
    extra_precision: int
        the number of bits of internal extra precision. must be at least
        `io_width`.
    table_addr_bits: int
        the number of address bits for the look-up table.
    table_data_bits: int
        the number of data bits for the look-up table.
    iter_count: int
        the number of times the iteration step is run after the initial
        table lookup. must not be negative.

    returns: tuple[FixedPoint, FixedPoint]
        the square-root and reciprocal-square-root, rounded down to the
        nearest representable value. If `radicand == 0`, then the
        reciprocal-square-root value returned is zero.
    """
    assert (isinstance(radicand, FixedPoint)
            and radicand.frac_wid == frac_wid
            and 0 <= radicand.bits < (1 << io_width))
    assert isinstance(io_width, int) and io_width >= 1
    assert isinstance(frac_wid, int) and 0 <= frac_wid < io_width
    assert isinstance(extra_precision, int) and extra_precision >= io_width
    assert isinstance(table_addr_bits, int) and table_addr_bits >= 1
    assert isinstance(table_data_bits, int) and table_data_bits >= 1
    assert isinstance(iter_count, int) and iter_count >= 0
    expanded_frac_wid = frac_wid + extra_precision
    s = radicand.to_frac_wid(expanded_frac_wid)
    # right-shift amounts used at the end to convert the internal results
    # back to `frac_wid` fraction bits. they start out as the extra
    # precision and are adjusted by one for every factor of 4 that `s` is
    # scaled by below, in opposite directions for sqrt and rsqrt.
    sqrt_rshift = extra_precision
    rsqrt_rshift = extra_precision
    # scale `s` by powers of 4 until it is zero or 1 <= s < 4
    while s != 0 and s < 1:
        s = (s * 4).to_frac_wid(expanded_frac_wid)
        sqrt_rshift += 1
        rsqrt_rshift -= 1
    while s >= 4:
        s = s.div(4, expanded_frac_wid)
        sqrt_rshift -= 1
        rsqrt_rshift += 1
    table = goldschmidt_sqrt_rsqrt_table(table_addr_bits=table_addr_bits,
                                         table_data_bits=table_data_bits)
    # core goldschmidt sqrt/rsqrt algorithm:
    # initial setup:
    # the table address is `s` rounded down to `table_addr_frac_wid`
    # fraction bits
    table_addr_frac_wid = table_addr_bits
    table_addr_frac_wid -= GOLDSCHMIDT_SQRT_RSQRT_TABLE_ADDR_INT_WID
    addr = s.to_frac_wid(table_addr_frac_wid, RoundDir.DOWN)
    assert 0 <= addr.bits < (1 << table_addr_bits), "table addr out of range"
    f = table[addr.bits]
    assert f is not None, "accessed invalid table entry"
    # use with_frac_wid to fix IDE type deduction
    f = FixedPoint.with_frac_wid(f, expanded_frac_wid, RoundDir.DOWN)
    # `f` is the table's initial estimate; `x` and `h` start as `s * f` and
    # `f / 2`. every intermediate result is rounded down to
    # `expanded_frac_wid` fraction bits.
    x = (s * f).to_frac_wid(expanded_frac_wid, RoundDir.DOWN)
    h = (f * 0.5).to_frac_wid(expanded_frac_wid, RoundDir.DOWN)
    for _ in range(iter_count):
        # iteration step:
        f = (1.5 - x * h).to_frac_wid(expanded_frac_wid, RoundDir.DOWN)
        x = (x * f).to_frac_wid(expanded_frac_wid, RoundDir.DOWN)
        h = (h * f).to_frac_wid(expanded_frac_wid, RoundDir.DOWN)
    r = 2 * h
    # now `x` is approximately `sqrt(s)` and `r` is approximately `rsqrt(s)`

    # undo the scaling of `s` and drop the extra precision
    sqrt = FixedPoint(x.bits >> sqrt_rshift, frac_wid)
    rsqrt = FixedPoint(rsqrt_bits := r.bits >> rsqrt_rshift, frac_wid) \
        if False else FixedPoint(r.bits >> rsqrt_rshift, frac_wid)

    # final correction: step each result up by one unit in the last place
    # when the next representable value still does not overshoot.
    next_sqrt = FixedPoint(sqrt.bits + 1, frac_wid)
    if next_sqrt * next_sqrt <= radicand:
        sqrt = next_sqrt

    next_rsqrt = FixedPoint(rsqrt.bits + 1, frac_wid)
    if next_rsqrt * next_rsqrt * radicand <= 1 and radicand != 0:
        rsqrt = next_rsqrt
    return sqrt, rsqrt
1518 rsqrt = next_rsqrt
1519 return sqrt, rsqrt