1 # SPDX-License-Identifier: LGPL-3-or-later
2 # Copyright 2022 Jacob Lifshay programmerjake@gmail.com
4 # Funded by NLnet Assure Programme 2021-02-052, https://nlnet.nl/assure part
5 # of Horizon 2020 EU Programme 957073.
7 from collections
import defaultdict
8 from dataclasses
import dataclass
, field
, fields
, replace
12 from fractions
import Fraction
13 from types
import FunctionType
14 from functools
import lru_cache
15 from nmigen
.hdl
.ast
import Signal
, unsigned
, signed
, Const
, Cat
16 from nmigen
.hdl
.dsl
import Module
, Elaboratable
17 from nmigen
.hdl
.mem
import Memory
18 from nmutil
.clz
import CLZ
21 from functools
import cached_property
23 from cached_property
import cached_property
25 # fix broken IDE type detection for cached_property
26 from typing
import TYPE_CHECKING
, Any
28 from functools
import cached_property
34 def cache_on_self(func
):
35 """like `functools.cached_property`, except for methods. unlike
36 `lru_cache` the cache is per-class instance rather than a global cache
39 assert isinstance(func
, FunctionType
), \
40 "non-plain methods are not supported"
42 cache_name
= func
.__name
__ + "__cache"
44 def wrapper(self
, *args
, **kwargs
):
45 # specifically access through `__dict__` to bypass frozen=True
46 cache
= self
.__dict
__.get(cache_name
, _NOT_FOUND
)
47 if cache
is _NOT_FOUND
:
48 self
.__dict
__[cache_name
] = cache
= {}
49 key
= (args
, *kwargs
.items())
50 retval
= cache
.get(key
, _NOT_FOUND
)
51 if retval
is _NOT_FOUND
:
52 retval
= func(self
, *args
, **kwargs
)
56 wrapper
.__doc
__ = func
.__doc
__
61 class RoundDir(enum
.Enum
):
64 NEAREST_TIES_UP
= enum
.auto()
65 ERROR_IF_INEXACT
= enum
.auto()
68 @dataclass(frozen
=True)
73 def __post_init__(self
):
74 # called by the autogenerated __init__
75 assert isinstance(self
.bits
, int)
76 assert isinstance(self
.frac_wid
, int) and self
.frac_wid
>= 0
80 """convert `value` to a fixed-point number with enough fractional
81 bits to preserve its value."""
82 if isinstance(value
, FixedPoint
):
84 if isinstance(value
, int):
85 return FixedPoint(value
, 0)
86 if isinstance(value
, str):
88 neg
= value
.startswith("-")
89 if neg
or value
.startswith("+"):
91 if value
.startswith(("0x", "0X")) and "." in value
:
101 raise ValueError("too many `.` in string")
106 if not digit
.isalnum():
107 raise ValueError("invalid hexadecimal digit")
109 bits |
= int("0x" + digit
, base
=16)
111 bits
= int(value
, base
=0)
115 return FixedPoint(bits
, frac_wid
)
117 if isinstance(value
, float):
118 n
, d
= value
.as_integer_ratio()
119 log2_d
= d
.bit_length() - 1
120 assert d
== 1 << log2_d
, ("d isn't a power of 2 -- won't ever "
121 "fail with float being IEEE 754")
122 return FixedPoint(n
, log2_d
)
123 raise TypeError("can't convert type to FixedPoint")
126 def with_frac_wid(value
, frac_wid
, round_dir
=RoundDir
.ERROR_IF_INEXACT
):
127 """convert `value` to the nearest fixed-point number with `frac_wid`
128 fractional bits, rounding according to `round_dir`."""
129 assert isinstance(frac_wid
, int) and frac_wid
>= 0
130 assert isinstance(round_dir
, RoundDir
)
131 if isinstance(value
, Fraction
):
132 numerator
= value
.numerator
133 denominator
= value
.denominator
135 value
= FixedPoint
.cast(value
)
136 numerator
= value
.bits
137 denominator
= 1 << value
.frac_wid
139 numerator
= -numerator
140 denominator
= -denominator
141 bits
, remainder
= divmod(numerator
<< frac_wid
, denominator
)
142 if round_dir
== RoundDir
.DOWN
:
144 elif round_dir
== RoundDir
.UP
:
147 elif round_dir
== RoundDir
.NEAREST_TIES_UP
:
148 if remainder
* 2 >= denominator
:
150 elif round_dir
== RoundDir
.ERROR_IF_INEXACT
:
152 raise ValueError("inexact conversion")
154 assert False, "unimplemented round_dir"
155 return FixedPoint(bits
, frac_wid
)
def to_frac_wid(self, frac_wid, round_dir=RoundDir.ERROR_IF_INEXACT):
    """return `self` re-expressed with `frac_wid` fractional bits.

    thin wrapper: passes `self` as the value to
    `FixedPoint.with_frac_wid`, which rounds according to `round_dir`
    (default `RoundDir.ERROR_IF_INEXACT`).
    """
    converted = FixedPoint.with_frac_wid(self, frac_wid, round_dir)
    return converted
163 # use truediv to get correct result even when bits
164 # and frac_wid are huge
165 return float(self
.bits
/ (1 << self
.frac_wid
))
def as_fraction(self):
    """return the exact value of `self` as a `Fraction`.

    the value is `bits / 2 ** frac_wid`.
    """
    numerator = self.bits
    denominator = 1 << self.frac_wid
    return Fraction(numerator, denominator)
171 """compare self with rhs, returning a positive integer if self is
172 greater than rhs, zero if self is equal to rhs, and a negative integer
173 if self is less than rhs."""
174 rhs
= FixedPoint
.cast(rhs
)
175 common_frac_wid
= max(self
.frac_wid
, rhs
.frac_wid
)
176 lhs
= self
.to_frac_wid(common_frac_wid
)
177 rhs
= rhs
.to_frac_wid(common_frac_wid
)
178 return lhs
.bits
- rhs
.bits
def __eq__(self, rhs):
    """return whether `self` and `rhs` have the same numeric value.

    `rhs` is first converted using `FixedPoint.cast`. if `rhs` has a type
    that `FixedPoint.cast` rejects with `TypeError`, `NotImplemented` is
    returned so python falls back to its default handling (the
    comparison evaluates to `False`) instead of `==` raising.
    """
    try:
        rhs = FixedPoint.cast(rhs)
    except TypeError:
        # unsupported operand type -- let python try the reflected
        # comparison rather than propagating the conversion failure
        return NotImplemented
    return self.cmp(rhs) == 0
def __ne__(self, rhs):
    """return whether `self` and `rhs` have different numeric values.

    `rhs` is first converted using `FixedPoint.cast`. if `rhs` has a type
    that `FixedPoint.cast` rejects with `TypeError`, `NotImplemented` is
    returned so python falls back to its default handling (the
    comparison evaluates to `True`) instead of `!=` raising.
    """
    try:
        rhs = FixedPoint.cast(rhs)
    except TypeError:
        # unsupported operand type -- let python try the reflected
        # comparison rather than propagating the conversion failure
        return NotImplemented
    return self.cmp(rhs) != 0
186 def __gt__(self
, rhs
):
187 return self
.cmp(rhs
) > 0
189 def __lt__(self
, rhs
):
190 return self
.cmp(rhs
) < 0
192 def __ge__(self
, rhs
):
193 return self
.cmp(rhs
) >= 0
195 def __le__(self
, rhs
):
196 return self
.cmp(rhs
) <= 0
199 """return the fractional part of `self`.
200 that is `self - math.floor(self)`.
202 fract_mask
= (1 << self
.frac_wid
) - 1
203 return FixedPoint(self
.bits
& fract_mask
, self
.frac_wid
)
207 return "-" + str(-self
)
209 frac_digit_count
= (self
.frac_wid
+ digit_bits
- 1) // digit_bits
210 fract
= self
.fract().to_frac_wid(frac_digit_count
* digit_bits
)
211 frac_str
= hex(fract
.bits
)[2:].zfill(frac_digit_count
)
212 return hex(math
.floor(self
)) + "." + frac_str
215 return f
"FixedPoint.with_frac_wid({str(self)!r}, {self.frac_wid})"
def __add__(self, rhs):
    """return `self + rhs` as a `FixedPoint`.

    `rhs` is converted with `FixedPoint.cast`; both operands are then
    brought to the larger of the two `frac_wid`s via `to_frac_wid` before
    their `bits` are added, so the result has that larger `frac_wid`.
    """
    other = FixedPoint.cast(rhs)
    wid = max(self.frac_wid, other.frac_wid)
    left = self.to_frac_wid(wid)
    right = other.to_frac_wid(wid)
    total_bits = left.bits + right.bits
    return FixedPoint(total_bits, wid)
224 def __radd__(self
, lhs
):
226 return self
.__add
__(lhs
)
229 return FixedPoint(-self
.bits
, self
.frac_wid
)
def __sub__(self, rhs):
    """return `self - rhs` as a `FixedPoint`.

    `rhs` is converted with `FixedPoint.cast`; both operands are then
    brought to the larger of the two `frac_wid`s via `to_frac_wid` before
    their `bits` are subtracted, so the result has that larger `frac_wid`.
    """
    other = FixedPoint.cast(rhs)
    wid = max(self.frac_wid, other.frac_wid)
    left = self.to_frac_wid(wid)
    right = other.to_frac_wid(wid)
    difference_bits = left.bits - right.bits
    return FixedPoint(difference_bits, wid)
238 def __rsub__(self
, lhs
):
240 return -self
.__sub
__(lhs
)
def __mul__(self, rhs):
    """return `self * rhs` as a `FixedPoint`.

    `rhs` is converted with `FixedPoint.cast`. the operands' `bits` are
    multiplied and their `frac_wid`s are added, so no rounding happens.
    """
    other = FixedPoint.cast(rhs)
    product_bits = self.bits * other.bits
    product_frac_wid = self.frac_wid + other.frac_wid
    return FixedPoint(product_bits, product_frac_wid)
246 def __rmul__(self
, lhs
):
248 return self
.__mul
__(lhs
)
251 return self
.bits
>> self
.frac_wid
253 def div(self
, rhs
, frac_wid
, round_dir
=RoundDir
.ERROR_IF_INEXACT
):
254 assert isinstance(frac_wid
, int) and frac_wid
>= 0
255 assert isinstance(round_dir
, RoundDir
)
256 rhs
= FixedPoint
.cast(rhs
)
257 return FixedPoint
.with_frac_wid(self
.as_fraction()
261 def sqrt(self
, round_dir
=RoundDir
.ERROR_IF_INEXACT
):
262 assert isinstance(round_dir
, RoundDir
)
264 raise ValueError("can't compute sqrt of negative number")
267 retval
= FixedPoint(0, self
.frac_wid
)
268 int_part_wid
= self
.bits
.bit_length() - self
.frac_wid
269 first_bit_index
= -(-int_part_wid
// 2) # division rounds up
270 last_bit_index
= -self
.frac_wid
271 for bit_index
in range(first_bit_index
, last_bit_index
- 1, -1):
272 trial
= retval
+ FixedPoint(1 << (bit_index
+ self
.frac_wid
),
274 if trial
* trial
<= self
:
276 if round_dir
== RoundDir
.DOWN
:
278 elif round_dir
== RoundDir
.UP
:
279 if retval
* retval
< self
:
280 retval
+= FixedPoint(1, self
.frac_wid
)
281 elif round_dir
== RoundDir
.NEAREST_TIES_UP
:
282 half_way
= retval
+ FixedPoint(1, self
.frac_wid
+ 1)
283 if half_way
* half_way
<= self
:
284 retval
+= FixedPoint(1, self
.frac_wid
)
285 elif round_dir
== RoundDir
.ERROR_IF_INEXACT
:
286 if retval
* retval
!= self
:
287 raise ValueError("inexact sqrt")
289 assert False, "unimplemented round_dir"
292 def rsqrt(self
, round_dir
=RoundDir
.ERROR_IF_INEXACT
):
293 """compute the reciprocal-sqrt of `self`"""
294 assert isinstance(round_dir
, RoundDir
)
296 raise ValueError("can't compute rsqrt of negative number")
298 raise ZeroDivisionError("can't compute rsqrt of zero")
299 retval
= FixedPoint(0, self
.frac_wid
)
300 first_bit_index
= -(-self
.frac_wid
// 2) # division rounds up
301 last_bit_index
= -self
.frac_wid
302 for bit_index
in range(first_bit_index
, last_bit_index
- 1, -1):
303 trial
= retval
+ FixedPoint(1 << (bit_index
+ self
.frac_wid
),
305 if trial
* trial
* self
<= 1:
307 if round_dir
== RoundDir
.DOWN
:
309 elif round_dir
== RoundDir
.UP
:
310 if retval
* retval
* self
< 1:
311 retval
+= FixedPoint(1, self
.frac_wid
)
312 elif round_dir
== RoundDir
.NEAREST_TIES_UP
:
313 half_way
= retval
+ FixedPoint(1, self
.frac_wid
+ 1)
314 if half_way
* half_way
* self
<= 1:
315 retval
+= FixedPoint(1, self
.frac_wid
)
316 elif round_dir
== RoundDir
.ERROR_IF_INEXACT
:
317 if retval
* retval
* self
!= 1:
318 raise ValueError("inexact rsqrt")
320 assert False, "unimplemented round_dir"
324 class ParamsNotAccurateEnough(Exception):
325 """raised when the parameters aren't accurate enough to have goldschmidt
329 def _assert_accuracy(condition
, msg
="not accurate enough"):
332 raise ParamsNotAccurateEnough(msg
)
335 @dataclass(frozen
=True, unsafe_hash
=True)
336 class GoldschmidtDivParamsBase
:
337 """parameters for a Goldschmidt division algorithm, excluding derived
342 """bit-width of the input divisor and the result.
343 the input numerator is `2 * io_width`-bits wide.
347 """number of bits of additional precision used inside the algorithm."""
350 """the number of address bits used in the lookup-table."""
353 """the number of data bits used in the lookup-table."""
356 """the total number of iterations of the division algorithm's loop"""
359 @dataclass(frozen
=True, unsafe_hash
=True)
360 class GoldschmidtDivParams(GoldschmidtDivParamsBase
):
361 """parameters for a Goldschmidt division algorithm.
362 Use `GoldschmidtDivParams.get` to find a efficient set of parameters.
365 # tuple to be immutable, repr=False so repr() works for debugging even when
366 # __post_init__ hasn't finished running yet
367 table
: "tuple[FixedPoint, ...]" = field(init
=False, repr=False)
368 """the lookup-table"""
370 ops
: "tuple[GoldschmidtDivOp, ...]" = field(init
=False, repr=False)
371 """the operations needed to perform the goldschmidt division algorithm."""
def _shrink_bound(self, bound, round_dir):
    """round `bound` to a `FixedPoint` and return it as a `Fraction`.

    this keeps the numerators/denominators of fractions from growing
    huge. the rounding uses `io_width * 4 + 100` fractional bits.

    only meant for values used to compute bounds, not for values that end
    up in the hardware.

    `bound` must be a `Fraction` or `int`; `round_dir` must be
    `RoundDir.DOWN` or `RoundDir.UP`.
    """
    assert isinstance(bound, (Fraction, int))
    assert round_dir is RoundDir.DOWN or round_dir is RoundDir.UP, \
        "you shouldn't use that round_dir on bounds"
    # should be enough precision
    precision = self.io_width * 4 + 100
    return FixedPoint.with_frac_wid(bound, precision, round_dir).as_fraction()
def _shrink_min(self, min_bound):
    """shrink a minimum bound by rounding it down.

    delegates to `self._shrink_bound` with `RoundDir.DOWN`, which rounds
    to a `FixedPoint` and converts back to a `Fraction` so numerators and
    denominators don't become huge.

    only meant for values used to compute bounds, not for values that end
    up in the hardware.
    """
    rounded_down = self._shrink_bound(min_bound, RoundDir.DOWN)
    return rounded_down
def _shrink_max(self, max_bound):
    """shrink a maximum bound by rounding it up.

    delegates to `self._shrink_bound` with `RoundDir.UP`, which rounds to
    a `FixedPoint` and converts back to a `Fraction` so numerators and
    denominators don't become huge.

    only meant for values used to compute bounds, not for values that end
    up in the hardware.
    """
    rounded_up = self._shrink_bound(max_bound, RoundDir.UP)
    return rounded_up
def table_addr_count(self):
    """number of distinct lookup-table addresses: `2 ** table_addr_bits`."""
    # computed from `table_addr_bits` rather than `len(self.table)` since
    # this is used while `self.table` is still being computed
    addr_bits = self.table_addr_bits
    return 1 << addr_bits
def table_input_exact_range(self, addr):
    """return `(min_input, max_input)` for the table entry at `addr`.

    both are `Fraction`s. with `shift = io_width - table_addr_bits`, the
    numerators run from `(1 << io_width) + (addr << shift)` up to that
    plus `(1 << shift) - 1`, over the denominator `1 << io_width`. the
    results are passed through `_shrink_min` / `_shrink_max` and asserted
    to lie in `[1, 2)`.

    `addr` must be an `int` in `range(self.table_addr_count)`.
    `_assert_accuracy` is used to require `io_width >= table_addr_bits`
    (presumably raising `ParamsNotAccurateEnough` otherwise).
    """
    assert isinstance(addr, int)
    assert 0 <= addr < self.table_addr_count
    io_width = self.io_width
    addr_bits = self.table_addr_bits
    _assert_accuracy(io_width >= addr_bits)
    shift = io_width - addr_bits
    scale = 1 << io_width
    lowest_numerator = scale + (addr << shift)
    # each table entry covers `1 << shift` consecutive input values
    highest_numerator = lowest_numerator + (1 << shift) - 1
    lowest = self._shrink_min(Fraction(lowest_numerator, scale))
    highest = self._shrink_max(Fraction(highest_numerator, scale))
    assert 1 <= lowest <= highest < 2
    return lowest, highest
def table_value_exact_range(self, addr):
    """return `(min_value, max_value)` for the table entry at `addr`.

    the values are the reciprocals of the input range given by
    `self.table_input_exact_range(addr)`, passed through `_shrink_min` /
    `_shrink_max` and asserted to satisfy
    `0.5 < min_value <= max_value <= 1`.
    """
    smallest_input, largest_input = self.table_input_exact_range(addr)
    # taking the reciprocal swaps which end is the minimum
    reciprocal_of_largest = 1 / largest_input
    reciprocal_of_smallest = 1 / smallest_input
    lowest = self._shrink_min(reciprocal_of_largest)
    highest = self._shrink_max(reciprocal_of_smallest)
    assert 0.5 < lowest <= highest <= 1
    return lowest, highest
443 def table_exact_value(self
, index
):
444 min_value
, max_value
= self
.table_value_exact_range(index
)
448 def __post_init__(self
):
449 # called by the autogenerated __init__
450 _assert_accuracy(self
.io_width
>= 1, "io_width out of range")
451 _assert_accuracy(self
.extra_precision
>= 0,
452 "extra_precision out of range")
453 _assert_accuracy(self
.table_addr_bits
>= 1,
454 "table_addr_bits out of range")
455 _assert_accuracy(self
.table_data_bits
>= 1,
456 "table_data_bits out of range")
457 _assert_accuracy(self
.iter_count
>= 1, "iter_count out of range")
459 for addr
in range(1 << self
.table_addr_bits
):
460 table
.append(FixedPoint
.with_frac_wid(self
.table_exact_value(addr
),
461 self
.table_data_bits
,
463 # we have to use object.__setattr__ since frozen=True
464 object.__setattr
__(self
, "table", tuple(table
))
465 object.__setattr
__(self
, "ops", tuple(self
.__make
_ops
()))
def expanded_width(self):
    """total bits of precision used inside the algorithm.

    this is `io_width + extra_precision`.
    """
    base_width = self.io_width
    additional = self.extra_precision
    return base_width + additional
473 def n_d_f_int_wid(self
):
474 """the number of bits in the integer part of `state.n`, `state.d`, and
475 `state.f` during the main iteration loop.
def n_d_f_total_wid(self):
    """total bit count (integer plus fraction bits) of `state.n`,
    `state.d`, and `state.f` during the main iteration loop.

    this is `n_d_f_int_wid + expanded_width`.
    """
    int_bits = self.n_d_f_int_wid
    frac_bits = self.expanded_width
    return int_bits + frac_bits
def max_neps(self, i):
    """upper bound on `neps[i]`.

    `neps[i]` is defined to be `n[i] * N_prime[i - 1] * F_prime[i - 1]`.
    the bound returned is the `Fraction` `1 / 2 ** expanded_width`,
    independent of `i`.

    `i` must be an `int` in `range(self.iter_count)`.
    """
    assert isinstance(i, int) and 0 <= i < self.iter_count
    denominator = 1 << self.expanded_width
    return Fraction(1, denominator)
def max_deps(self, i):
    """upper bound on `deps[i]`.

    `deps[i]` is defined to be `d[i] * D_prime[i - 1] * F_prime[i - 1]`.
    the bound returned is the `Fraction` `1 / 2 ** expanded_width`,
    independent of `i`.

    `i` must be an `int` in `range(self.iter_count)`.
    """
    assert isinstance(i, int) and 0 <= i < self.iter_count
    denominator = 1 << self.expanded_width
    return Fraction(1, denominator)
503 def max_feps(self
, i
):
504 """maximum value of `feps[i]`.
505 `feps[i]` is defined to be `f[i] * (2 - D_prime[i - 1])`.
507 assert isinstance(i
, int) and 0 <= i
< self
.iter_count
508 # zero, because the computation of `F_prime[i]` in
509 # `GoldschmidtDivOp.MulDByF.run(...)` is exact.
514 """minimum and maximum values of `e[0]`
515 (the relative error in `F_prime[-1]`)
519 for addr
in range(self
.table_addr_count
):
520 # `F_prime[-1] = (1 - e[0]) / B`
521 # => `e[0] = 1 - B * F_prime[-1]`
522 min_b
, max_b
= self
.table_input_exact_range(addr
)
523 f_prime_m1
= self
.table
[addr
].as_fraction()
524 assert min_b
>= 0 and f_prime_m1
>= 0, \
525 "only positive quadrant of interval multiplication implemented"
526 min_product
= min_b
* f_prime_m1
527 max_product
= max_b
* f_prime_m1
528 # negation swaps min/max
529 cur_min_e0
= 1 - max_product
530 cur_max_e0
= 1 - min_product
531 min_e0
= min(min_e0
, cur_min_e0
)
532 max_e0
= max(max_e0
, cur_max_e0
)
533 min_e0
= self
._shrink
_min
(min_e0
)
534 max_e0
= self
._shrink
_max
(max_e0
)
535 return min_e0
, max_e0
539 """minimum value of `e[0]` (the relative error in `F_prime[-1]`)
541 min_e0
, max_e0
= self
.e0_range
546 """maximum value of `e[0]` (the relative error in `F_prime[-1]`)
548 min_e0
, max_e0
= self
.e0_range
def max_abs_e0(self):
    """upper bound on `abs(e[0])`.

    this is the larger of `abs(self.min_e0)` and `abs(self.max_e0)`.
    """
    magnitude_at_min = abs(self.min_e0)
    magnitude_at_max = abs(self.max_e0)
    return max(magnitude_at_min, magnitude_at_max)
557 def min_abs_e0(self
):
558 """minimum value of `abs(e[0])`."""
563 """maximum value of `n[i]` (the relative error in `N_prime[i]`
564 relative to the previous iteration)
566 assert isinstance(i
, int) and 0 <= i
< self
.iter_count
569 # `n[0] = neps[0] / ((1 - e[0]) * (A / B))`
570 # `n[0] <= 2 * neps[0] / (1 - e[0])`
572 assert self
.max_e0
< 1 and self
.max_neps(0) >= 0, \
573 "only one quadrant of interval division implemented"
574 retval
= 2 * self
.max_neps(0) / (1 - self
.max_e0
)
577 # `n[1] <= neps[1] / ((1 - f[0]) * (1 - pi[0] - delta[0]))`
578 min_mpd
= 1 - self
.max_pi(0) - self
.max_delta(0)
579 assert self
.max_f(0) <= 1 and min_mpd
>= 0, \
580 "only one quadrant of interval multiplication implemented"
581 prod
= (1 - self
.max_f(0)) * min_mpd
582 assert self
.max_neps(1) >= 0 and prod
> 0, \
583 "only one quadrant of interval division implemented"
584 retval
= self
.max_neps(1) / prod
587 # `0 <= n[i] <= 2 * max_neps[i] / (1 - pi[i - 1] - delta[i - 1])`
588 min_mpd
= 1 - self
.max_pi(i
- 1) - self
.max_delta(i
- 1)
589 assert self
.max_neps(i
) >= 0 and min_mpd
> 0, \
590 "only one quadrant of interval division implemented"
591 retval
= self
.max_neps(i
) / min_mpd
593 return self
._shrink
_max
(retval
)
597 """maximum value of `d[i]` (the relative error in `D_prime[i]`
598 relative to the previous iteration)
600 assert isinstance(i
, int) and 0 <= i
< self
.iter_count
603 # `d[0] = deps[0] / (1 - e[0])`
605 assert self
.max_e0
< 1 and self
.max_deps(0) >= 0, \
606 "only one quadrant of interval division implemented"
607 retval
= self
.max_deps(0) / (1 - self
.max_e0
)
610 # `d[1] <= deps[1] / ((1 - f[0]) * (1 - delta[0] ** 2))`
611 assert self
.max_f(0) <= 1 and self
.max_delta(0) <= 1, \
612 "only one quadrant of interval multiplication implemented"
613 divisor
= (1 - self
.max_f(0)) * (1 - self
.max_delta(0) ** 2)
614 assert self
.max_deps(1) >= 0 and divisor
> 0, \
615 "only one quadrant of interval division implemented"
616 retval
= self
.max_deps(1) / divisor
619 # `0 <= d[i] <= max_deps[i] / (1 - delta[i - 1])`
620 assert self
.max_deps(i
) >= 0 and self
.max_delta(i
- 1) < 1, \
621 "only one quadrant of interval division implemented"
622 retval
= self
.max_deps(i
) / (1 - self
.max_delta(i
- 1))
624 return self
._shrink
_max
(retval
)
628 """maximum value of `f[i]` (the relative error in `F_prime[i]`
629 relative to the previous iteration)
631 assert isinstance(i
, int) and 0 <= i
< self
.iter_count
634 # `f[0] = feps[0] / (1 - delta[0])`
636 assert self
.max_delta(0) < 1 and self
.max_feps(0) >= 0, \
637 "only one quadrant of interval division implemented"
638 retval
= self
.max_feps(0) / (1 - self
.max_delta(0))
642 retval
= self
.max_feps(1)
645 # `f[i] <= max_feps[i]`
646 retval
= self
.max_feps(i
)
648 return self
._shrink
_max
(retval
)
651 def max_delta(self
, i
):
652 """ maximum value of `delta[i]`.
653 `delta[i]` is defined in Definition 4 of paper.
655 assert isinstance(i
, int) and 0 <= i
< self
.iter_count
657 # `delta[0] = abs(e[0]) + 3 * d[0] / 2`
658 retval
= self
.max_abs_e0
+ Fraction(3, 2) * self
.max_d(0)
660 # `delta[i] = delta[i - 1] ** 2 + f[i - 1]`
661 prev_max_delta
= self
.max_delta(i
- 1)
662 assert prev_max_delta
>= 0
663 retval
= prev_max_delta
** 2 + self
.max_f(i
- 1)
665 # `delta[i]` has to be smaller than one otherwise errors would go off
667 _assert_accuracy(retval
< 1)
669 return self
._shrink
_max
(retval
)
673 """ maximum value of `pi[i]`.
674 `pi[i]` is defined right below Theorem 5 of paper.
676 assert isinstance(i
, int) and 0 <= i
< self
.iter_count
677 # `pi[i] = 1 - (1 - n[i]) * prod`
678 # where `prod` is the product of,
679 # for `j` in `0 <= j < i`, `(1 - n[j]) / (1 + d[j])`
680 min_prod
= Fraction(1)
682 max_n_j
= self
.max_n(j
)
683 max_d_j
= self
.max_d(j
)
684 assert max_n_j
<= 1 and max_d_j
> -1, \
685 "only one quadrant of interval division implemented"
686 min_prod
*= (1 - max_n_j
) / (1 + max_d_j
)
687 max_n_i
= self
.max_n(i
)
688 assert max_n_i
<= 1 and min_prod
>= 0, \
689 "only one quadrant of interval multiplication implemented"
690 retval
= 1 - (1 - max_n_i
) * min_prod
691 return self
._shrink
_max
(retval
)
694 def max_n_shift(self
):
695 """ maximum value of `state.n_shift`.
697 # numerator must be less than `denominator << self.io_width`, so
698 # `n_shift` is at most `self.io_width`
703 """ maximum value of, for all `i`, `max_n(i)` and `max_d(i)`
706 for i
in range(self
.iter_count
):
707 n_hat
= max(n_hat
, self
.max_n(i
), self
.max_d(i
))
708 return self
._shrink
_max
(n_hat
)
710 def __make_ops(self
):
711 """ Goldschmidt division algorithm.
714 Even, G., Seidel, P. M., & Ferguson, W. E. (2003).
715 A Parametric Error Analysis of Goldschmidt's Division Algorithm.
716 https://citeseerx.ist.psu.edu/viewdoc/download?doi=10.1.1.90.1238&rep=rep1&type=pdf
718 yields: GoldschmidtDivOp
719 the operations needed to perform the division.
721 # establish assumptions of the paper's error analysis (section 3.1):
723 # 1. normalize so A (numerator) and B (denominator) are in [1, 2)
724 yield GoldschmidtDivOp
.Normalize
726 # 2. ensure all relative errors from directed rounding are <= 1 / 4.
727 # the assumption is met by multipliers with > 4-bits precision
728 _assert_accuracy(self
.expanded_width
> 4)
730 # 3. require `abs(e[0]) + 3 * d[0] / 2 + f[0] < 1 / 2`.
731 _assert_accuracy(self
.max_abs_e0
+ 3 * self
.max_d(0) / 2
732 + self
.max_f(0) < Fraction(1, 2))
734 # 4. the initial approximation F'[-1] of 1/B is in [1/2, 1].
735 # (B is the denominator)
737 for addr
in range(self
.table_addr_count
):
738 f_prime_m1
= self
.table
[addr
]
739 _assert_accuracy(0.5 <= f_prime_m1
<= 1)
741 yield GoldschmidtDivOp
.FEqTableLookup
743 # we use Setting I (section 4.1 of the paper):
744 # Require `n[i] <= n_hat` and `d[i] <= n_hat` and `f[i] = 0`:
745 # the conditions on n_hat are satisfied by construction.
746 for i
in range(self
.iter_count
):
747 _assert_accuracy(self
.max_f(i
) == 0)
748 yield GoldschmidtDivOp
.MulNByF
749 if i
!= self
.iter_count
- 1:
750 yield GoldschmidtDivOp
.MulDByF
751 yield GoldschmidtDivOp
.FEq2MinusD
753 # relative approximation error `p(N_prime[i])`:
754 # `p(N_prime[i]) = (A / B - N_prime[i]) / (A / B)`
755 # `0 <= p(N_prime[i])`
756 # `p(N_prime[i]) <= (2 * i) * n_hat \`
757 # ` + (abs(e[0]) + 3 * n_hat / 2) ** (2 ** i)`
758 i
= self
.iter_count
- 1 # last used `i`
759 # compute power manually to prevent huge intermediate values
760 power
= self
._shrink
_max
(self
.max_abs_e0
+ 3 * self
.n_hat
/ 2)
762 power
= self
._shrink
_max
(power
* power
)
764 max_rel_error
= (2 * i
) * self
.n_hat
+ power
766 min_a_over_b
= Fraction(1, 2)
767 min_abs_error_for_correctness
= min_a_over_b
/ (1 << self
.max_n_shift
)
768 min_rel_error_for_correctness
= (min_abs_error_for_correctness
772 max_rel_error
< min_rel_error_for_correctness
,
773 f
"not accurate enough: max_rel_error={max_rel_error}"
774 f
" min_rel_error_for_correctness={min_rel_error_for_correctness}")
776 yield GoldschmidtDivOp
.CalcResult
779 def default_cost_fn(self
):
780 """ calculate the estimated cost on an arbitrary scale of implementing
781 goldschmidt division with the specified parameters. larger cost
782 values mean worse parameters.
784 This is the default cost function for `GoldschmidtDivParams.get`.
788 rom_cells
= self
.table_data_bits
<< self
.table_addr_bits
789 cost
= float(rom_cells
)
791 if op
== GoldschmidtDivOp
.MulNByF \
792 or op
== GoldschmidtDivOp
.MulDByF
:
793 mul_cost
= self
.expanded_width
** 2
794 mul_cost
*= self
.expanded_width
.bit_length()
796 cost
+= 5e7
* self
.iter_count
800 @lru_cache(maxsize
=1 << 16)
801 def __cached_new(base_params
):
802 assert isinstance(base_params
, GoldschmidtDivParamsBase
)
803 # can't use dataclasses.asdict, since it's recursive and will also give
804 # child class fields too, which we don't want.
806 for field
in fields(GoldschmidtDivParamsBase
):
807 kwargs
[field
.name
] = getattr(base_params
, field
.name
)
809 return GoldschmidtDivParams(**kwargs
), None
810 except ParamsNotAccurateEnough
as e
:
814 def __raise(e
): # type: (ParamsNotAccurateEnough) -> Any
818 def cached_new(base_params
, handle_error
=__raise
):
819 assert isinstance(base_params
, GoldschmidtDivParamsBase
)
820 params
, error
= GoldschmidtDivParams
.__cached
_new
(base_params
)
824 return handle_error(error
)
827 def get(io_width
, cost_fn
=default_cost_fn
, max_table_addr_bits
=12):
828 """ find efficient parameters for a goldschmidt division algorithm
829 with `params.io_width == io_width`.
833 bit-width of the input divisor and the result.
834 the input numerator is `2 * io_width`-bits wide.
835 cost_fn: Callable[[GoldschmidtDivParams], float]
836 return the estimated cost on an arbitrary scale of implementing
837 goldschmidt division with the specified parameters. larger cost
838 values mean worse parameters.
839 max_table_addr_bits: int
840 maximum allowable value of `table_addr_bits`
842 assert isinstance(io_width
, int) and io_width
>= 1
843 assert callable(cost_fn
)
846 last_error_params
= None
848 def cached_new(base_params
):
850 nonlocal last_error
, last_error_params
852 last_error_params
= base_params
855 retval
= GoldschmidtDivParams
.cached_new(base_params
, handle_error
)
857 logging
.debug(f
"GoldschmidtDivParams.get: err: {base_params}")
859 logging
.debug(f
"GoldschmidtDivParams.get: ok: {base_params}")
862 @lru_cache(maxsize
=None)
863 def get_cost(base_params
):
864 params
= cached_new(base_params
)
867 retval
= cost_fn(params
)
868 logging
.debug(f
"GoldschmidtDivParams.get: cost={retval}: {params}")
871 # start with parameters big enough to always work.
872 initial_extra_precision
= io_width
* 2 + 4
873 initial_params
= GoldschmidtDivParamsBase(
875 extra_precision
=initial_extra_precision
,
876 table_addr_bits
=min(max_table_addr_bits
, io_width
),
877 table_data_bits
=io_width
+ initial_extra_precision
,
878 iter_count
=1 + io_width
.bit_length())
880 if cached_new(initial_params
) is None:
881 raise ValueError(f
"initial goldschmidt division algorithm "
882 f
"parameters are invalid: {initial_params}"
885 # find good initial `iter_count`
886 params
= initial_params
887 for iter_count
in range(1, initial_params
.iter_count
):
888 trial_params
= replace(params
, iter_count
=iter_count
)
889 if cached_new(trial_params
) is not None:
890 params
= trial_params
893 # now find `table_addr_bits`
894 cost
= get_cost(params
)
895 for table_addr_bits
in range(1, max_table_addr_bits
):
896 trial_params
= replace(params
, table_addr_bits
=table_addr_bits
)
897 trial_cost
= get_cost(trial_params
)
898 if trial_cost
< cost
:
899 params
= trial_params
903 # check one higher `iter_count` to see if it has lower cost
904 for table_addr_bits
in range(1, max_table_addr_bits
+ 1):
905 trial_params
= replace(params
,
906 table_addr_bits
=table_addr_bits
,
907 iter_count
=params
.iter_count
+ 1)
908 trial_cost
= get_cost(trial_params
)
909 if trial_cost
< cost
:
910 params
= trial_params
914 # now shrink `table_data_bits`
916 trial_params
= replace(params
,
917 table_data_bits
=params
.table_data_bits
- 1)
918 trial_cost
= get_cost(trial_params
)
919 if trial_cost
< cost
:
920 params
= trial_params
925 # and shrink `extra_precision`
927 trial_params
= replace(params
,
928 extra_precision
=params
.extra_precision
- 1)
929 trial_cost
= get_cost(trial_params
)
930 if trial_cost
< cost
:
931 params
= trial_params
936 retval
= cached_new(params
)
937 assert isinstance(retval
, GoldschmidtDivParams
)
942 """count leading zeros -- handy for debugging."""
943 assert isinstance(wid
, int)
944 assert isinstance(v
, int) and 0 <= v
< (1 << wid
)
945 return (1 << wid
).bit_length() - v
.bit_length()
949 class GoldschmidtDivOp(enum
.Enum
):
950 Normalize
= "n, d, n_shift = normalize(n, d)"
951 FEqTableLookup
= "f = table_lookup(d)"
954 FEq2MinusD
= "f = 2 - d"
955 CalcResult
= "result = unnormalize_and_round(n)"
957 def run(self
, params
, state
):
958 assert isinstance(params
, GoldschmidtDivParams
)
959 assert isinstance(state
, GoldschmidtDivState
)
960 expanded_width
= params
.expanded_width
961 table_addr_bits
= params
.table_addr_bits
962 if self
== GoldschmidtDivOp
.Normalize
:
963 # normalize so 1 <= d < 2
964 # can easily be done with count-leading-zeros and left shift
966 state
.n
= (state
.n
* 2).to_frac_wid(expanded_width
)
967 state
.d
= (state
.d
* 2).to_frac_wid(expanded_width
)
970 # normalize so 1 <= n < 2
972 state
.n
= (state
.n
* 0.5).to_frac_wid(expanded_width
,
973 round_dir
=RoundDir
.DOWN
)
975 elif self
== GoldschmidtDivOp
.FEqTableLookup
:
976 # compute initial f by table lookup
978 d_m_1
= d_m_1
.to_frac_wid(table_addr_bits
, RoundDir
.DOWN
)
979 assert 0 <= d_m_1
.bits
< (1 << params
.table_addr_bits
)
980 state
.f
= params
.table
[d_m_1
.bits
]
981 state
.f
= state
.f
.to_frac_wid(expanded_width
,
982 round_dir
=RoundDir
.DOWN
)
983 elif self
== GoldschmidtDivOp
.MulNByF
:
984 assert state
.f
is not None
985 n
= state
.n
* state
.f
986 state
.n
= n
.to_frac_wid(expanded_width
, round_dir
=RoundDir
.DOWN
)
987 elif self
== GoldschmidtDivOp
.MulDByF
:
988 assert state
.f
is not None
989 d
= state
.d
* state
.f
990 state
.d
= d
.to_frac_wid(expanded_width
, round_dir
=RoundDir
.UP
)
991 elif self
== GoldschmidtDivOp
.FEq2MinusD
:
992 state
.f
= (2 - state
.d
).to_frac_wid(expanded_width
)
993 elif self
== GoldschmidtDivOp
.CalcResult
:
994 assert state
.n_shift
is not None
995 # scale to correct value
996 n
= state
.n
* (1 << state
.n_shift
)
998 state
.quotient
= math
.floor(n
)
999 state
.remainder
= state
.orig_n
- state
.quotient
* state
.orig_d
1000 if state
.remainder
>= state
.orig_d
:
1002 state
.remainder
-= state
.orig_d
1004 assert False, f
"unimplemented GoldschmidtDivOp: {self}"
def gen_hdl(self, params, state, sync_rom):
    """generate the hdl for this operation.

    arguments:
    params: GoldschmidtDivParams
        the goldschmidt division parameters.
    state: GoldschmidtDivHDLState
        the input/output state. the signals this operation consumes are read
        from it, and the signals it produces are assigned back to it.
    sync_rom: bool
        true if the rom should be read synchronously rather than
        combinatorially, incurring an extra clock cycle of latency.

    raises: AssertionError
        if an argument has the wrong type, a state signal has an unexpected
        width, or `self` is not one of the operations handled here.
    """
    assert isinstance(params, GoldschmidtDivParams)
    assert isinstance(state, GoldschmidtDivHDLState)
    m = state.m
    if self == GoldschmidtDivOp.Normalize:
        # normalize so 1 <= d < 2
        assert state.d.width == params.io_width
        assert state.n.width == 2 * params.io_width
        d_leading_zeros = CLZ(params.io_width)
        m.submodules.d_leading_zeros = d_leading_zeros
        m.d.comb += d_leading_zeros.sig_in.eq(state.d)
        d_shift_out = Signal.like(state.d)
        m.d.comb += d_shift_out.eq(state.d << d_leading_zeros.lz)
        d = Signal(params.n_d_f_total_wid)
        # NOTE(review): assumes `lz` is the leading-zero count of `sig_in`,
        # so `d_shift_out` has its top bit set -- confirm against CLZ.
        m.d.comb += d.eq((d_shift_out << (1 + params.expanded_width))
                         >> state.d.width)

        # normalize so 1 <= n < 2
        n_leading_zeros = CLZ(2 * params.io_width)
        m.submodules.n_leading_zeros = n_leading_zeros
        m.d.comb += n_leading_zeros.sig_in.eq(state.n)
        signed_zero = Const(0, signed(1))  # force subtraction to be signed
        n_shift_s_v = (params.io_width + signed_zero + d_leading_zeros.lz
                       - n_leading_zeros.lz)
        n_shift_s = Signal.like(n_shift_s_v)
        n_shift_n_lz_out = Signal.like(state.n)
        n_shift_d_lz_out = Signal.like(state.n << d_leading_zeros.lz)
        m.d.comb += [
            n_shift_s.eq(n_shift_s_v),
            n_shift_d_lz_out.eq(state.n << d_leading_zeros.lz),
            n_shift_n_lz_out.eq(state.n << n_leading_zeros.lz),
        ]
        state.n_shift = Signal(d_leading_zeros.lz.width)
        n = Signal(params.n_d_f_total_wid)
        with m.If(n_shift_s < 0):
            # n is already below 2 after applying d's shift, so no final
            # left-shift is needed
            m.d.comb += [
                state.n_shift.eq(0),
                n.eq((n_shift_d_lz_out << (1 + params.expanded_width))
                     >> state.d.width),
            ]
        with m.Else():
            # use n shifted by its own leading-zero count and remember how
            # far the result has to be shifted back at the end
            m.d.comb += [
                state.n_shift.eq(n_shift_s),
                n.eq((n_shift_n_lz_out << (1 + params.expanded_width))
                     >> state.n.width),
            ]
        state.n = n
        state.d = d
    elif self == GoldschmidtDivOp.FEqTableLookup:
        assert state.d.width == params.n_d_f_total_wid, "invalid d width"
        # compute initial f by table lookup

        # extra bit for table entries == 1.0
        table_width = 1 + params.table_data_bits
        table = Memory(width=table_width, depth=len(params.table),
                       init=[i.bits for i in params.table])
        # drop the integer bits of d, then take the top `table_addr_bits`
        # of what remains
        addr = state.d[:-params.n_d_f_int_wid][-params.table_addr_bits:]
        if sync_rom:
            table_read = table.read_port()
            m.d.comb += table_read.addr.eq(addr)
            # delay the rest of the state to line up with the read data
            state.insert_pipeline_register()
        else:
            table_read = table.read_port(domain="comb")
            m.d.comb += table_read.addr.eq(addr)
        m.submodules.table_read = table_read
        state.f = Signal(params.n_d_f_int_wid + params.expanded_width)
        data_shift = params.expanded_width - params.table_data_bits
        m.d.comb += state.f.eq(table_read.data << data_shift)
    elif self == GoldschmidtDivOp.MulNByF:
        assert state.n.width == params.n_d_f_total_wid, "invalid n width"
        assert state.f is not None
        assert state.f.width == params.n_d_f_total_wid, "invalid f width"
        n = Signal.like(state.n)
        # the right shift discards the low bits, truncating the product
        m.d.comb += n.eq((state.n * state.f) >> params.expanded_width)
        state.n = n
    elif self == GoldschmidtDivOp.MulDByF:
        assert state.d.width == params.n_d_f_total_wid, "invalid d width"
        assert state.f is not None
        assert state.f.width == params.n_d_f_total_wid, "invalid f width"
        d = Signal.like(state.d)
        d_times_f = Signal.like(state.d * state.f)
        m.d.comb += [
            d_times_f.eq(state.d * state.f),
            # round the multiplication up
            d.eq((d_times_f >> params.expanded_width)
                 + (d_times_f[:params.expanded_width] != 0)),
        ]
        state.d = d
    elif self == GoldschmidtDivOp.FEq2MinusD:
        assert state.d.width == params.n_d_f_total_wid, "invalid d width"
        f = Signal.like(state.d)
        # `2 << expanded_width` is 2.0 in this fixed-point format
        m.d.comb += f.eq((2 << params.expanded_width) - state.d)
        state.f = f
    elif self == GoldschmidtDivOp.CalcResult:
        assert state.n.width == params.n_d_f_total_wid, "invalid n width"
        assert state.n_shift is not None
        q_approx = Signal(params.io_width)
        # extra bit for if it's bigger than orig_d
        r_approx = Signal(params.io_width + 1)
        adjusted_r = Signal(signed(1 + params.io_width))
        m.d.comb += [
            # scale to correct value and drop the fraction bits
            q_approx.eq((state.n << state.n_shift)
                        >> params.expanded_width),
            r_approx.eq(state.orig_n - q_approx * state.orig_d),
            adjusted_r.eq(r_approx - state.orig_d),
        ]
        state.quotient = Signal(params.io_width)
        state.remainder = Signal(params.io_width)

        with m.If(adjusted_r >= 0):
            # remainder still holds a whole orig_d: bump the quotient by one
            m.d.comb += [
                state.quotient.eq(q_approx + 1),
                state.remainder.eq(adjusted_r),
            ]
        with m.Else():
            m.d.comb += [
                state.quotient.eq(q_approx),
                state.remainder.eq(r_approx),
            ]
    else:
        # raised explicitly so the check survives `python -O`
        raise AssertionError(f"unimplemented GoldschmidtDivOp: {self}")
@dataclass
class GoldschmidtDivState:
    """mutable state of the software model of the goldschmidt division
    algorithm, updated in place by each operation."""

    orig_n: int
    """original numerator"""

    orig_d: int
    """original denominator"""

    n: "FixedPoint"
    """numerator -- N_prime[i] in the paper's algorithm 2"""

    d: "FixedPoint"
    """denominator -- D_prime[i] in the paper's algorithm 2"""

    f: "FixedPoint | None" = None
    """current factor -- F_prime[i] in the paper's algorithm 2"""

    quotient: "int | None" = None
    """final quotient"""

    remainder: "int | None" = None
    """final remainder"""

    n_shift: "int | None" = None
    """amount the numerator needs to be left-shifted at the end of the
    algorithm.
    """

    def __repr__(self):
        """list the fields that are set, skipping those that are `None`.

        integer fields are shown in hex, except `n_shift` which is shown
        with its plain `repr`.
        """
        def render(name, value):
            # n_shift is a small shift count, so hex would only obscure it
            show_hex = isinstance(value, int) and name != "n_shift"
            text = hex(value) if show_hex else repr(value)
            return f"{name}={text}"

        named_values = ((entry.name, getattr(self, entry.name))
                        for entry in fields(GoldschmidtDivState))
        shown = ", ".join(render(name, value)
                          for name, value in named_values
                          if value is not None)
        return f"GoldschmidtDivState({shown})"
def goldschmidt_div(n, d, params, trace=lambda state: None):
    """ Goldschmidt division algorithm.

        based on:
        Even, G., Seidel, P. M., & Ferguson, W. E. (2003).
        A Parametric Error Analysis of Goldschmidt's Division Algorithm.
        https://citeseerx.ist.psu.edu/viewdoc/download?doi=10.1.1.90.1238&rep=rep1&type=pdf

        arguments:
        n: int
            numerator. a `2 * params.io_width`-bit unsigned integer.
            must be less than `d << params.io_width`, otherwise the quotient
            wouldn't fit in `params.io_width` bits.
        d: int
            denominator. a `params.io_width`-bit unsigned integer.
            must not be zero.
        params: GoldschmidtDivParams
            the algorithm parameters. supplies `io_width`, the bit-width of
            the inputs/outputs, and `ops`, the operations to run in order.
        trace: Function[[GoldschmidtDivState], None]
            called with the initial state and the state after executing each
            operation in `params.ops`.

        returns: tuple[int, int]
            the quotient and remainder, as taken from the final state.
    """
    assert isinstance(params, GoldschmidtDivParams)
    io_width = params.io_width
    assert isinstance(d, int) and 0 < d < (1 << io_width)
    assert isinstance(n, int) and 0 <= n < (d << io_width)

    # this whole algorithm is done with fixed-point arithmetic; both inputs
    # start out as fixed-point values with `io_width` fractional bits
    initial_n = FixedPoint(n, io_width)
    initial_d = FixedPoint(d, io_width)
    state = GoldschmidtDivState(orig_n=n, orig_d=d,
                                n=initial_n, d=initial_d)

    # report the initial state, then the state after every operation
    trace(state)
    for step in params.ops:
        step.run(params, state)
        trace(state)

    quotient, remainder = state.quotient, state.remainder
    assert quotient is not None
    assert remainder is not None

    return quotient, remainder
@dataclass(eq=False)
class GoldschmidtDivHDLState:
    """state threaded through the HDL generation of each operation.

    assigning a `Signal` to one of the public fields (after construction
    has finished) renames that signal to `<prefix><field>_<index>` and
    records the signal it replaces in `old_signals`.
    """

    m: "Module"
    """The HDL Module"""

    orig_n: "Signal"
    """original numerator"""

    orig_d: "Signal"
    """original denominator"""

    n: "Signal"
    """numerator -- N_prime[i] in the paper's algorithm 2"""

    d: "Signal"
    """denominator -- D_prime[i] in the paper's algorithm 2"""

    f: "Signal | None" = None
    """current factor -- F_prime[i] in the paper's algorithm 2"""

    quotient: "Signal | None" = None
    """final quotient"""

    remainder: "Signal | None" = None
    """final remainder"""

    n_shift: "Signal | None" = None
    """amount the numerator needs to be left-shifted at the end of the
    algorithm.
    """

    old_signals: "defaultdict[str, list[Signal]]" = field(repr=False,
                                                          init=False)

    __signal_name_prefix: "str" = field(default="state_", repr=False,
                                        init=False)

    def __post_init__(self):
        # called by the autogenerated __init__
        self.old_signals = defaultdict(list)

    def __setattr__(self, name, value):
        """set attribute `name`, tracking signal replacement.

        private (`_`-prefixed) names, and any write made before
        `__post_init__` has created `old_signals`, are stored as-is.
        afterwards only `Signal` values may be written, and never to `m` or
        `old_signals`.
        """
        assert isinstance(name, str)
        if name.startswith("_"):
            return super().__setattr__(name, value)
        try:
            old_signals = self.old_signals[name]
        except AttributeError:
            # haven't yet finished __post_init__
            return super().__setattr__(name, value)
        assert name != "m" and name != "old_signals", f"can't write to {name}"
        assert isinstance(value, Signal)
        # the index is the number of signals this field has held before
        value.name = f"{self.__signal_name_prefix}{name}_{len(old_signals)}"
        old_signal = getattr(self, name, None)
        if old_signal is not None:
            assert isinstance(old_signal, Signal)
            old_signals.append(old_signal)
        return super().__setattr__(name, value)

    def insert_pipeline_register(self):
        """replace every signal field that is set with a new signal driven
        from the old one in the `sync` domain."""
        old_prefix = self.__signal_name_prefix
        try:
            for state_field in fields(GoldschmidtDivHDLState):
                name = state_field.name
                # `m` and `old_signals` are bookkeeping, not signals; without
                # skipping `old_signals` the isinstance assert below fails on
                # its defaultdict
                if name.startswith("_") or name in ("m", "old_signals"):
                    continue
                old_sig = getattr(self, name, None)
                if old_sig is None:
                    continue
                assert isinstance(old_sig, Signal)
                new_sig = Signal.like(old_sig)
                setattr(self, name, new_sig)
                self.m.d.sync += new_sig.eq(old_sig)
        finally:
            # NOTE(review): the prefix is never changed above, so this
            # restore is currently a no-op -- confirm whether registers were
            # meant to get their own prefix.
            self.__signal_name_prefix = old_prefix
class GoldschmidtDivHDL(Elaboratable):
    """ Goldschmidt division algorithm.

        based on:
        Even, G., Seidel, P. M., & Ferguson, W. E. (2003).
        A Parametric Error Analysis of Goldschmidt's Division Algorithm.
        https://citeseerx.ist.psu.edu/viewdoc/download?doi=10.1.1.90.1238&rep=rep1&type=pdf

        attributes:
        params: GoldschmidtDivParams
            the goldschmidt division algorithm parameters.
        pipe_reg_indexes: list[int]
            the operation indexes where pipeline registers should be inserted.
            duplicate values mean multiple registers should be inserted for
            that operation index -- this is useful to allow yosys to spread a
            multiplication across those multiple pipeline stages.
        sync_rom: bool
            true if the rom should be read synchronously rather than
            combinatorially, incurring an extra clock cycle of latency.
        n: Signal(unsigned(2 * params.io_width))
            input numerator. a `2 * params.io_width`-bit unsigned integer.
            must be less than `d << params.io_width`, otherwise the quotient
            wouldn't fit in `params.io_width` bits.
        d: Signal(unsigned(params.io_width))
            input denominator. a `params.io_width`-bit unsigned integer.
        q: Signal(unsigned(params.io_width))
            output quotient. only valid when `n < (d << params.io_width)`.
        r: Signal(unsigned(params.io_width))
            output remainder. only valid when `n < (d << params.io_width)`.
        trace: list[GoldschmidtDivHDLState]
            list of the initial state and the state after executing each
            operation in `params.ops`.
    """

    @property
    def total_pipeline_registers(self):
        """the total number of pipeline registers"""
        # `sync_rom` is a bool, so it counts as 0 or 1 here
        return len(self.pipe_reg_indexes) + self.sync_rom

    def __init__(self, params, pipe_reg_indexes=(), sync_rom=False):
        assert isinstance(params, GoldschmidtDivParams)
        assert isinstance(sync_rom, bool)
        self.params = params
        self.pipe_reg_indexes = sorted(int(i) for i in pipe_reg_indexes)
        self.sync_rom = sync_rom
        io_width = params.io_width
        self.n = Signal(unsigned(2 * io_width))
        self.d = Signal(unsigned(io_width))
        self.q = Signal(unsigned(io_width))
        self.r = Signal(unsigned(io_width))

        # in constructor so we get trace without needing to call elaborate
        state = GoldschmidtDivHDLState(m=Module(),
                                       orig_n=self.n, orig_d=self.d,
                                       n=self.n, d=self.d)

        self.trace = [replace(state)]

        # walk the ascending register schedule with a cursor; each register
        # is inserted before the first operation whose index reaches it
        schedule = self.pipe_reg_indexes
        cursor = 0
        for op_index, op in enumerate(params.ops):
            while cursor < len(schedule) and schedule[cursor] <= op_index:
                cursor += 1
                state.insert_pipeline_register()
            op.gen_hdl(params, state, sync_rom)
            self.trace.append(replace(state))

        # registers scheduled past the last operation go after everything
        for _ in schedule[cursor:]:
            state.insert_pipeline_register()

        outputs = [
            self.q.eq(state.quotient),
            self.r.eq(state.remainder),
        ]
        state.m.d.comb += outputs

    def elaborate(self, platform):
        """return the module that was filled in by the constructor."""
        initial_state = self.trace[0]
        return initial_state.m
# number of integer bits in the address used to index the look-up table built
# by `goldschmidt_sqrt_rsqrt_table`; the remaining address bits are fraction
# bits.
GOLDSCHMIDT_SQRT_RSQRT_TABLE_ADDR_INT_WID = 2
@lru_cache()
def goldschmidt_sqrt_rsqrt_table(table_addr_bits, table_data_bits):
    """Generate the look-up table needed for Goldschmidt's square-root and
    reciprocal-square-root algorithm.

    the result only depends on the arguments and is immutable, so it is
    cached.

    arguments:
    table_addr_bits: int
        the number of address bits for the look-up table. must be at least
        `GOLDSCHMIDT_SQRT_RSQRT_TABLE_ADDR_INT_WID`.
    table_data_bits: int
        the number of data bits for the look-up table. must be at least 1.

    returns: tuple[FixedPoint | None, ...]
        one entry per address. entry 0 is zero, and entries for the other
        addresses below 1.0 are `None` since they should be unused.
    """
    assert isinstance(table_addr_bits, int) and \
        table_addr_bits >= GOLDSCHMIDT_SQRT_RSQRT_TABLE_ADDR_INT_WID
    assert isinstance(table_data_bits, int) and table_data_bits >= 1
    # the address is a fixed-point number with this many fraction bits
    table_addr_frac_wid = table_addr_bits
    table_addr_frac_wid -= GOLDSCHMIDT_SQRT_RSQRT_TABLE_ADDR_INT_WID
    # address of the value 1.0
    first_used_addr = 1 << table_addr_frac_wid
    table = []
    table_len = 1 << table_addr_bits
    for addr in range(table_len):
        if addr == 0:
            value = FixedPoint(0, table_data_bits)
        elif addr < first_used_addr:
            value = None  # table entries should be unused
        else:
            # upper end of the input range covered by this address
            max_input_value = FixedPoint(addr + 1, table_addr_frac_wid)
            max_frac_wid = max(max_input_value.frac_wid, table_data_bits)
            value = max_input_value.to_frac_wid(max_frac_wid)
            value = value.rsqrt(RoundDir.DOWN)
            value = value.to_frac_wid(table_data_bits, RoundDir.DOWN)
        table.append(value)

    # tuple for immutability
    return tuple(table)
# FIXME: add code to calculate error bounds and check that the algorithm will
# actually work (like in the goldschmidt division algorithm).
# FIXME: add code to calculate a good set of parameters based on the error
# bounds.
1437 def goldschmidt_sqrt_rsqrt(radicand
, io_width
, frac_wid
, extra_precision
,
1438 table_addr_bits
, table_data_bits
, iter_count
):
1439 """Goldschmidt's square-root and reciprocal-square-root algorithm.
1441 uses algorithm based on second method at:
1442 https://en.wikipedia.org/wiki/Methods_of_computing_square_roots#Goldschmidt%E2%80%99s_algorithm
1445 radicand: FixedPoint(frac_wid=frac_wid)
1446 the input value to take the square-root and reciprocal-square-root of.
1448 the number of bits in the input (`radicand`) and output values.
1450 the number of fraction bits in the input (`radicand`) and output
1452 extra_precision: int
1453 the number of bits of internal extra precision.
1454 table_addr_bits: int
1455 the number of address bits for the look-up table.
1456 table_data_bits: int
1457 the number of data bits for the look-up table.
1459 returns: tuple[FixedPoint, FixedPoint]
1460 the square-root and reciprocal-square-root, rounded down to the
1461 nearest representable value. If `radicand == 0`, then the
1462 reciprocal-square-root value returned is zero.
1464 assert (isinstance(radicand
, FixedPoint
)
1465 and radicand
.frac_wid
== frac_wid
1466 and 0 <= radicand
.bits
< (1 << io_width
))
1467 assert isinstance(io_width
, int) and io_width
>= 1
1468 assert isinstance(frac_wid
, int) and 0 <= frac_wid
< io_width
1469 assert isinstance(extra_precision
, int) and extra_precision
>= io_width
1470 assert isinstance(table_addr_bits
, int) and table_addr_bits
>= 1
1471 assert isinstance(table_data_bits
, int) and table_data_bits
>= 1
1472 assert isinstance(iter_count
, int) and iter_count
>= 0
1473 expanded_frac_wid
= frac_wid
+ extra_precision
1474 s
= radicand
.to_frac_wid(expanded_frac_wid
)
1475 sqrt_rshift
= extra_precision
1476 rsqrt_rshift
= extra_precision
1477 while s
!= 0 and s
< 1:
1478 s
= (s
* 4).to_frac_wid(expanded_frac_wid
)
1482 s
= s
.div(4, expanded_frac_wid
)
1485 table
= goldschmidt_sqrt_rsqrt_table(table_addr_bits
=table_addr_bits
,
1486 table_data_bits
=table_data_bits
)
1487 # core goldschmidt sqrt/rsqrt algorithm:
1489 table_addr_frac_wid
= table_addr_bits
1490 table_addr_frac_wid
-= GOLDSCHMIDT_SQRT_RSQRT_TABLE_ADDR_INT_WID
1491 addr
= s
.to_frac_wid(table_addr_frac_wid
, RoundDir
.DOWN
)
1492 assert 0 <= addr
.bits
< (1 << table_addr_bits
), "table addr out of range"
1493 f
= table
[addr
.bits
]
1494 assert f
is not None, "accessed invalid table entry"
1495 # use with_frac_wid to fix IDE type deduction
1496 f
= FixedPoint
.with_frac_wid(f
, expanded_frac_wid
, RoundDir
.DOWN
)
1497 x
= (s
* f
).to_frac_wid(expanded_frac_wid
, RoundDir
.DOWN
)
1498 h
= (f
* 0.5).to_frac_wid(expanded_frac_wid
, RoundDir
.DOWN
)
1499 for _
in range(iter_count
):
1501 f
= (1.5 - x
* h
).to_frac_wid(expanded_frac_wid
, RoundDir
.DOWN
)
1502 x
= (x
* f
).to_frac_wid(expanded_frac_wid
, RoundDir
.DOWN
)
1503 h
= (h
* f
).to_frac_wid(expanded_frac_wid
, RoundDir
.DOWN
)
1505 # now `x` is approximately `sqrt(s)` and `r` is approximately `rsqrt(s)`
1507 sqrt
= FixedPoint(x
.bits
>> sqrt_rshift
, frac_wid
)
1508 rsqrt
= FixedPoint(r
.bits
>> rsqrt_rshift
, frac_wid
)
1510 next_sqrt
= FixedPoint(sqrt
.bits
+ 1, frac_wid
)
1511 if next_sqrt
* next_sqrt
<= radicand
:
1514 next_rsqrt
= FixedPoint(rsqrt
.bits
+ 1, frac_wid
)
1515 if next_rsqrt
* next_rsqrt
* radicand
<= 1 and radicand
!= 0: