# SPDX-License-Identifier: LGPL-3.0-or-later
2 # Copyright 2022 Jacob Lifshay programmerjake@gmail.com
4 # Funded by NLnet Assure Programme 2021-02-052, https://nlnet.nl/assure part
5 # of Horizon 2020 EU Programme 957073.
7 from dataclasses
import dataclass
, field
10 from fractions
import Fraction
14 class RoundDir(enum
.Enum
):
17 NEAREST_TIES_UP
= enum
.auto()
18 ERROR_IF_INEXACT
= enum
.auto()
21 @dataclass(frozen
=True)
26 def __post_init__(self
):
27 assert isinstance(self
.bits
, int)
28 assert isinstance(self
.frac_wid
, int) and self
.frac_wid
>= 0
32 """convert `value` to a fixed-point number with enough fractional
33 bits to preserve its value."""
34 if isinstance(value
, FixedPoint
):
36 if isinstance(value
, int):
37 return FixedPoint(value
, 0)
38 if isinstance(value
, str):
40 neg
= value
.startswith("-")
41 if neg
or value
.startswith("+"):
43 if value
.startswith(("0x", "0X")) and "." in value
:
53 raise ValueError("too many `.` in string")
58 if not digit
.isalnum():
59 raise ValueError("invalid hexadecimal digit")
61 bits |
= int("0x" + digit
, base
=16)
63 bits
= int(value
, base
=0)
67 return FixedPoint(bits
, frac_wid
)
69 if isinstance(value
, float):
70 n
, d
= value
.as_integer_ratio()
71 log2_d
= d
.bit_length() - 1
72 assert d
== 1 << log2_d
, ("d isn't a power of 2 -- won't ever "
73 "fail with float being IEEE 754")
74 return FixedPoint(n
, log2_d
)
75 raise TypeError("can't convert type to FixedPoint")
78 def with_frac_wid(value
, frac_wid
, round_dir
=RoundDir
.ERROR_IF_INEXACT
):
79 """convert `value` to the nearest fixed-point number with `frac_wid`
80 fractional bits, rounding according to `round_dir`."""
81 assert isinstance(frac_wid
, int) and frac_wid
>= 0
82 assert isinstance(round_dir
, RoundDir
)
83 if isinstance(value
, Fraction
):
84 numerator
= value
.numerator
85 denominator
= value
.denominator
87 value
= FixedPoint
.cast(value
)
88 # compute number of bits that should be removed from value
89 del_bits
= value
.frac_wid
- frac_wid
92 if del_bits
< 0: # add bits
93 return FixedPoint(value
.bits
<< -del_bits
,
95 numerator
= value
.bits
96 denominator
= 1 << value
.frac_wid
98 numerator
= -numerator
99 denominator
= -denominator
100 bits
, remainder
= divmod(numerator
<< frac_wid
, denominator
)
101 if round_dir
== RoundDir
.DOWN
:
103 elif round_dir
== RoundDir
.UP
:
106 elif round_dir
== RoundDir
.NEAREST_TIES_UP
:
107 if remainder
* 2 >= denominator
:
109 elif round_dir
== RoundDir
.ERROR_IF_INEXACT
:
111 raise ValueError("inexact conversion")
113 assert False, "unimplemented round_dir"
114 return FixedPoint(bits
, frac_wid
)
def to_frac_wid(self, frac_wid, round_dir=RoundDir.ERROR_IF_INEXACT):
    """return `self` converted to a fixed-point number with exactly
    `frac_wid` fractional bits.

    This is a method-style convenience wrapper around
    `FixedPoint.with_frac_wid`; `round_dir` selects how any fractional bits
    that don't fit are handled (the default is meant to reject inexact
    conversions).
    """
    converted = FixedPoint.with_frac_wid(self, frac_wid=frac_wid,
                                         round_dir=round_dir)
    return converted
122 # use truediv to get correct result even when bits
123 # and frac_wid are huge
124 return float(self
.bits
/ (1 << self
.frac_wid
))
def as_fraction(self):
    """return the exact value of `self` as a `Fraction`, i.e.
    `bits / 2 ** frac_wid`."""
    scale = 1 << self.frac_wid
    return Fraction(self.bits, scale)
130 """compare self with rhs, returning a positive integer if self is
131 greater than rhs, zero if self is equal to rhs, and a negative integer
132 if self is less than rhs."""
133 rhs
= FixedPoint
.cast(rhs
)
134 common_frac_wid
= max(self
.frac_wid
, rhs
.frac_wid
)
135 lhs
= self
.to_frac_wid(common_frac_wid
)
136 rhs
= rhs
.to_frac_wid(common_frac_wid
)
137 return lhs
.bits
- rhs
.bits
def __eq__(self, rhs):
    """return whether `self` and `rhs` represent the same value.

    `rhs` may be anything `FixedPoint.cast` accepts; numbers with different
    `frac_wid` compare equal when their values match. For operand types
    `FixedPoint.cast` rejects (it raises `TypeError`), `NotImplemented` is
    returned so Python falls back to its default handling -- e.g.
    `x == None` is simply `False` instead of raising.
    """
    try:
        return self.cmp(rhs) == 0
    except TypeError:
        return NotImplemented


def __ne__(self, rhs):
    """return whether `self` and `rhs` represent different values.

    Unsupported operand types give `NotImplemented`, mirroring `__eq__`.
    """
    try:
        return self.cmp(rhs) != 0
    except TypeError:
        return NotImplemented


def __hash__(self):
    """return a hash consistent with `__eq__`.

    Without an explicit `__hash__`, `@dataclass(frozen=True)` generates one
    from the dataclass fields, but `__eq__` compares *values*: e.g.
    `FixedPoint(1, 0) == FixedPoint(2, 1)` even though their bits differ.
    Hashing the exact rational value keeps equal values' hashes equal, and
    also agrees with `hash()` of equal `int`s and `float`s (which `__eq__`
    accepts via `FixedPoint.cast`).
    """
    return hash(Fraction(self.bits, 1 << self.frac_wid))


def __gt__(self, rhs):
    """return whether `self > rhs` (see `cmp`)."""
    return self.cmp(rhs) > 0


def __lt__(self, rhs):
    """return whether `self < rhs` (see `cmp`)."""
    return self.cmp(rhs) < 0


def __ge__(self, rhs):
    """return whether `self >= rhs` (see `cmp`)."""
    return self.cmp(rhs) >= 0


def __le__(self, rhs):
    """return whether `self <= rhs` (see `cmp`)."""
    return self.cmp(rhs) <= 0
158 """return the fractional part of `self`.
159 that is `self - math.floor(self)`.
161 fract_mask
= (1 << self
.frac_wid
) - 1
162 return FixedPoint(self
.bits
& fract_mask
, self
.frac_wid
)
166 return "-" + str(-self
)
168 frac_digit_count
= (self
.frac_wid
+ digit_bits
- 1) // digit_bits
169 fract
= self
.fract().to_frac_wid(frac_digit_count
* digit_bits
)
170 frac_str
= hex(fract
.bits
)[2:].zfill(frac_digit_count
)
171 return hex(math
.floor(self
)) + "." + frac_str
174 return f
"FixedPoint.with_frac_wid({str(self)!r}, {self.frac_wid})"
def __add__(self, rhs):
    """return `self + rhs` exactly.

    `rhs` is converted with `FixedPoint.cast`. Both operands are brought to
    the larger of their two fractional widths, which is at least as wide as
    either operand, so the conversion is exact. The sum carries that width.
    """
    addend = FixedPoint.cast(rhs)
    wid = max(self.frac_wid, addend.frac_wid)
    augend_bits = self.to_frac_wid(wid).bits
    addend_bits = addend.to_frac_wid(wid).bits
    return FixedPoint(augend_bits + addend_bits, wid)
183 def __radd__(self
, lhs
):
185 return self
.__add
__(lhs
)
188 return FixedPoint(-self
.bits
, self
.frac_wid
)
def __sub__(self, rhs):
    """return `self - rhs` exactly.

    `rhs` is converted with `FixedPoint.cast`. Both operands are widened to
    the larger of their fractional widths before subtracting, so the
    difference carries that many fractional bits and no rounding occurs.
    """
    subtrahend = FixedPoint.cast(rhs)
    wid = max(self.frac_wid, subtrahend.frac_wid)
    minuend_bits = self.to_frac_wid(wid).bits
    subtrahend_bits = subtrahend.to_frac_wid(wid).bits
    return FixedPoint(minuend_bits - subtrahend_bits, wid)
197 def __rsub__(self
, lhs
):
199 return -self
.__sub
__(lhs
)
def __mul__(self, rhs):
    """return `self * rhs` exactly.

    `rhs` is converted with `FixedPoint.cast`. Multiplying the raw bits
    multiplies the scale factors too, so the product's fractional width is
    the sum of the operands' widths; nothing is rounded.
    """
    factor = FixedPoint.cast(rhs)
    product_bits = self.bits * factor.bits
    product_frac_wid = self.frac_wid + factor.frac_wid
    return FixedPoint(product_bits, product_frac_wid)
205 def __rmul__(self
, lhs
):
207 return self
.__mul
__(lhs
)
210 return self
.bits
>> self
.frac_wid
214 class GoldschmidtDivState
:
216 """numerator -- N_prime[i] in the paper's algorithm 2"""
218 """denominator -- D_prime[i] in the paper's algorithm 2"""
219 f
: "FixedPoint | None" = None
220 """current factor -- F_prime[i] in the paper's algorithm 2"""
221 result
: "int | None" = None
223 n_shift
: "int | None" = None
224 """amount the numerator needs to be left-shifted at the end of the
229 class ParamsNotAccurateEnough(Exception):
230 """raised when the parameters aren't accurate enough to have goldschmidt
234 def _assert_accuracy(condition
, msg
="not accurate enough"):
237 raise ParamsNotAccurateEnough(msg
)
240 @dataclass(frozen
=True, unsafe_hash
=True)
241 class GoldschmidtDivParams
:
242 """parameters for a Goldschmidt division algorithm.
243 Use `GoldschmidtDivParams.get` to find a efficient set of parameters.
246 """bit-width of the input divisor and the result.
247 the input numerator is `2 * io_width`-bits wide.
250 """number of bits of additional precision used inside the algorithm."""
252 """the number of address bits used in the lookup-table."""
254 """the number of data bits used in the lookup-table."""
255 # tuple to be immutable
256 table
: "tuple[FixedPoint, ...]" = field(init
=False)
257 """the lookup-table"""
258 ops
: "tuple[GoldschmidtDivOp, ...]" = field(init
=False)
259 """the operations needed to perform the goldschmidt division algorithm."""
def table_addr_count(self):
    """number of distinct addresses in the lookup-table,
    i.e. `2 ** table_addr_bits`."""
    # derived from table_addr_bits instead of len(self.table) because it is
    # needed while self.table is still being built
    addr_bits = self.table_addr_bits
    return 1 << addr_bits
def table_input_exact_range(self, addr):
    """return the range of inputs as `Fraction`s used for the table entry
    with address `addr`.

    The table is indexed by the top `table_addr_bits` fractional bits of a
    divisor normalized into `[1, 2)`, whose value fits in `io_width`
    fractional bits. Entry `addr` therefore covers the
    `2 ** (io_width - table_addr_bits)` representable inputs starting at
    `1 + addr / 2 ** table_addr_bits`.

    Returns `(min_input, max_input)`: the smallest and the largest of those
    inputs, both inclusive.

    Raises `ParamsNotAccurateEnough` when `table_addr_bits > io_width`, so
    the parameter search in `GoldschmidtDivParams.get` skips that
    combination instead of aborting on an assertion.
    """
    assert isinstance(addr, int)
    assert 0 <= addr < self.table_addr_count
    if self.io_width < self.table_addr_bits:
        raise ParamsNotAccurateEnough(
            "table_addr_bits must not be larger than io_width")
    addr_shift = self.io_width - self.table_addr_bits
    values_per_table_entry = 1 << addr_shift
    # work at `io_width` fractional bits so the entry's width
    # (`values_per_table_entry` steps of 2 ** -io_width) is on the same
    # scale as its start
    denominator = 1 << self.io_width
    # `1 + addr / 2 ** table_addr_bits`, scaled by `denominator`
    min_numerator = denominator + (addr << addr_shift)
    # last representable input before the next entry starts
    max_numerator = min_numerator + values_per_table_entry - 1
    min_input = Fraction(min_numerator, denominator)
    max_input = Fraction(max_numerator, denominator)
    return min_input, max_input
def table_value_exact_range(self, addr):
    """return the range of values as `Fraction`s used for the table entry
    with address `addr`.

    The values are the reciprocals of the entry's input range from
    `table_input_exact_range`, returned as `(min_value, max_value)`.
    """
    lowest_input, highest_input = self.table_input_exact_range(addr)
    # 1/x is decreasing for positive x: the largest input gives the
    # smallest value and vice versa
    lowest_value = 1 / highest_input
    highest_value = 1 / lowest_input
    return lowest_value, highest_value
288 def table_exact_value(self
, index
):
289 min_value
, max_value
= self
.table_value_exact_range(index
)
293 def __post_init__(self
):
294 # called by the autogenerated __init__
295 assert self
.io_width
>= 1
296 assert self
.extra_precision
>= 0
297 assert self
.table_addr_bits
>= 1
298 assert self
.table_data_bits
>= 1
300 for addr
in range(1 << self
.table_addr_bits
):
301 table
.append(FixedPoint
.with_frac_wid(self
.table_exact_value(addr
),
302 self
.table_data_bits
,
304 # we have to use object.__setattr__ since frozen=True
305 object.__setattr
__(self
, "table", tuple(table
))
306 object.__setattr
__(self
, "ops", tuple(_goldschmidt_div_ops(self
)))
310 """ find efficient parameters for a goldschmidt division algorithm
311 with `params.io_width == io_width`.
313 assert isinstance(io_width
, int) and io_width
>= 1
314 for extra_precision
in range(io_width
* 2):
315 for table_addr_bits
in range(3, 7 + 1):
316 table_data_bits
= io_width
+ extra_precision
318 return GoldschmidtDivParams(
320 extra_precision
=extra_precision
,
321 table_addr_bits
=table_addr_bits
,
322 table_data_bits
=table_data_bits
)
323 except ParamsNotAccurateEnough
:
325 raise ValueError(f
"can't find working parameters for a goldschmidt "
326 f
"division algorithm with io_width={io_width}")
def expanded_width(self):
    """the total number of bits of precision used inside the algorithm:
    `io_width` plus the `extra_precision` additional bits."""
    width = self.io_width
    width += self.extra_precision
    return width
335 class GoldschmidtDivOp(enum
.Enum
):
336 Normalize
= "n, d, n_shift = normalize(n, d)"
337 FEqTableLookup
= "f = table_lookup(d)"
340 FEq2MinusD
= "f = 2 - d"
341 CalcResult
= "result = unnormalize_and_round(n)"
343 def run(self
, params
, state
):
344 assert isinstance(params
, GoldschmidtDivParams
)
345 assert isinstance(state
, GoldschmidtDivState
)
346 expanded_width
= params
.expanded_width
347 table_addr_bits
= params
.table_addr_bits
348 if self
== GoldschmidtDivOp
.Normalize
:
349 # normalize so 1 <= d < 2
350 # can easily be done with count-leading-zeros and left shift
352 state
.n
= (state
.n
* 2).to_frac_wid(expanded_width
)
353 state
.d
= (state
.d
* 2).to_frac_wid(expanded_width
)
356 # normalize so 1 <= n < 2
358 state
.n
= (state
.n
* 0.5).to_frac_wid(expanded_width
)
360 elif self
== GoldschmidtDivOp
.FEqTableLookup
:
361 # compute initial f by table lookup
363 d_m_1
= d_m_1
.to_frac_wid(table_addr_bits
, RoundDir
.DOWN
)
364 assert 0 <= d_m_1
.bits
< (1 << params
.table_addr_bits
)
365 state
.f
= params
.table
[d_m_1
.bits
]
366 elif self
== GoldschmidtDivOp
.MulNByF
:
367 assert state
.f
is not None
368 n
= state
.n
* state
.f
369 state
.n
= n
.to_frac_wid(expanded_width
, round_dir
=RoundDir
.DOWN
)
370 elif self
== GoldschmidtDivOp
.MulDByF
:
371 assert state
.f
is not None
372 d
= state
.d
* state
.f
373 state
.d
= d
.to_frac_wid(expanded_width
, round_dir
=RoundDir
.UP
)
374 elif self
== GoldschmidtDivOp
.FEq2MinusD
:
375 state
.f
= (2 - state
.d
).to_frac_wid(expanded_width
)
376 elif self
== GoldschmidtDivOp
.CalcResult
:
377 assert state
.n_shift
is not None
378 # scale to correct value
379 n
= state
.n
* (1 << state
.n_shift
)
381 # avoid incorrectly rounding down
382 n
= n
.to_frac_wid(params
.io_width
, round_dir
=RoundDir
.UP
)
383 state
.result
= math
.floor(n
)
385 assert False, f
"unimplemented GoldschmidtDivOp: {self}"
388 def _goldschmidt_div_ops(params
):
389 """ Goldschmidt division algorithm.
392 Even, G., Seidel, P. M., & Ferguson, W. E. (2003).
393 A Parametric Error Analysis of Goldschmidt's Division Algorithm.
394 https://citeseerx.ist.psu.edu/viewdoc/download?doi=10.1.1.90.1238&rep=rep1&type=pdf
397 params: GoldschmidtDivParams
398 the parameters for the algorithm
400 yields: GoldschmidtDivOp
401 the operations needed to perform the division.
403 assert isinstance(params
, GoldschmidtDivParams
)
405 # establish assumptions of the paper's error analysis (section 3.1):
407 # 1. normalize so A (numerator) and B (denominator) are in [1, 2)
408 yield GoldschmidtDivOp
.Normalize
410 # 2. ensure all relative errors from directed rounding are <= 1 / 4.
411 # the assumption is met by multipliers with > 4-bits precision
412 _assert_accuracy(params
.expanded_width
> 4)
414 # 3. require `abs(e[0]) + 3 * d[0] / 2 + f[0] < 1 / 2`.
416 # maximum `abs(e[0])`
420 # `f[i] = 0` for all `i`
422 for addr
in range(params
.table_addr_count
):
423 # `F_prime[-1] = (1 - e[0]) / B`
424 # => `e[0] = 1 - B * F_prime[-1]`
425 min_b
, max_b
= params
.table_input_exact_range(addr
)
426 f_prime_m1
= params
.table
[addr
].as_fraction()
427 assert min_b
>= 0 and f_prime_m1
>= 0, \
428 "only positive quadrant of interval multiplication implemented"
429 min_product
= min_b
* f_prime_m1
430 max_product
= max_b
* f_prime_m1
431 # negation swaps min/max
432 min_e0
= 1 - max_product
433 max_e0
= 1 - min_product
434 max_abs_e0
= max(max_abs_e0
, abs(min_e0
), abs(max_e0
))
436 # `D_prime[0] = (1 + d[0]) * B * F_prime[-1]`
437 # `D_prime[0] = abs_round_err + B * F_prime[-1]`
438 # => `d[0] = abs_round_err / (B * F_prime[-1])`
439 max_abs_round_err
= Fraction(1, 1 << params
.expanded_width
)
440 assert min_product
> 0 and max_abs_round_err
>= 0, \
441 "only positive quadrant of interval division implemented"
442 # division swaps divisor's min/max
443 max_d0
= max(max_d0
, max_abs_round_err
/ min_product
)
445 _assert_accuracy(max_abs_e0
+ 3 * max_d0
/ 2 + fi
< Fraction(1, 2))
447 # 4. the initial approximation F'[-1] of 1/B is in [1/2, 1].
448 # (B is the denominator)
450 for addr
in range(params
.table_addr_count
):
451 f_prime_m1
= params
.table
[addr
]
452 _assert_accuracy(0.5 <= f_prime_m1
<= 1)
454 yield GoldschmidtDivOp
.FEqTableLookup
456 # we use Setting I (section 4.1 of the paper)
458 min_bits_of_precision
= 1
459 # FIXME: calculate error and check if it's small enough
460 while min_bits_of_precision
< params
.io_width
* 2:
461 yield GoldschmidtDivOp
.MulNByF
462 yield GoldschmidtDivOp
.MulDByF
463 yield GoldschmidtDivOp
.FEq2MinusD
465 min_bits_of_precision
*= 2
467 yield GoldschmidtDivOp
.CalcResult
470 def goldschmidt_div(n
, d
, params
):
471 """ Goldschmidt division algorithm.
474 Even, G., Seidel, P. M., & Ferguson, W. E. (2003).
475 A Parametric Error Analysis of Goldschmidt's Division Algorithm.
476 https://citeseerx.ist.psu.edu/viewdoc/download?doi=10.1.1.90.1238&rep=rep1&type=pdf
480 numerator. a `2*width`-bit unsigned integer.
481 must be less than `d << width`, otherwise the quotient wouldn't
484 denominator. a `width`-bit unsigned integer. must not be zero.
486 the bit-width of the inputs/outputs. must be a positive integer.
489 the quotient. a `width`-bit unsigned integer.
491 assert isinstance(params
, GoldschmidtDivParams
)
492 assert isinstance(d
, int) and 0 < d
< (1 << params
.io_width
)
493 assert isinstance(n
, int) and 0 <= n
< (d
<< params
.io_width
)
495 # this whole algorithm is done with fixed-point arithmetic where values
496 # have `width` fractional bits
498 state
= GoldschmidtDivState(
499 n
=FixedPoint(n
, params
.io_width
),
500 d
=FixedPoint(d
, params
.io_width
),
503 for op
in params
.ops
:
504 op
.run(params
, state
)
506 assert state
.result
is not None