working on goldschmidt division algorithm
[soc.git] / src / soc / fu / div / experiment / goldschmidt_div_sqrt.py
1 # SPDX-License-Identifier: LGPL-3-or-later
2 # Copyright 2022 Jacob Lifshay programmerjake@gmail.com
3
4 # Funded by NLnet Assure Programme 2021-02-052, https://nlnet.nl/assure part
5 # of Horizon 2020 EU Programme 957073.
6
7 from dataclasses import dataclass, field
8 import math
9 import enum
10 from fractions import Fraction
11
12
@enum.unique
class RoundDir(enum.Enum):
    """rounding direction used when a value has to be converted to a
    fixed-point number with fewer fractional bits than it needs."""

    # round towards negative infinity
    DOWN = enum.auto()
    # round towards positive infinity
    UP = enum.auto()
    # round to the nearest representable value, ties go towards
    # positive infinity
    NEAREST_TIES_UP = enum.auto()
    # don't round at all -- raise an error if the value isn't exactly
    # representable
    ERROR_IF_INEXACT = enum.auto()
19
20
@dataclass(frozen=True)
class FixedPoint:
    """immutable signed fixed-point number.

    the represented value is `bits / (1 << frac_wid)`.

    comparisons (`==`, `<`, ...) and `hash` work on the represented value,
    not on the `(bits, frac_wid)` pair, so `FixedPoint(1, 0)` and
    `FixedPoint(2, 1)` are equal and hash to the same value.
    """
    bits: int
    """the value scaled by `1 << frac_wid`. may be negative."""
    frac_wid: int
    """the number of fractional bits. must be non-negative."""

    def __post_init__(self):
        assert isinstance(self.bits, int)
        assert isinstance(self.frac_wid, int) and self.frac_wid >= 0

    @staticmethod
    def cast(value):
        """convert `value` to a fixed-point number with enough fractional
        bits to preserve its value.

        accepted inputs:
        * `FixedPoint` -- returned unchanged.
        * `int` -- converted with zero fractional bits.
        * `str` -- an optional sign followed by either a hexadecimal number
          containing a `.` (e.g. `"0x1.8"`, 4 fractional bits per digit
          after the `.`, `_` separators are ignored) or anything accepted
          by `int(value, base=0)`.
        * `float` -- converted exactly.

        raises `ValueError` for malformed strings and `TypeError` for any
        other type.
        """
        if isinstance(value, FixedPoint):
            return value
        if isinstance(value, int):
            return FixedPoint(value, 0)
        if isinstance(value, str):
            value = value.strip()
            neg = value.startswith("-")
            if neg or value.startswith("+"):
                value = value[1:]
            if value.startswith(("0x", "0X")) and "." in value:
                value = value[2:]
                got_dot = False
                bits = 0
                frac_wid = 0
                for digit in value:
                    if digit == "_":
                        continue
                    if got_dot:
                        if digit == ".":
                            raise ValueError("too many `.` in string")
                        # every hex digit after the `.` is 4 more
                        # fractional bits
                        frac_wid += 4
                    if digit == ".":
                        got_dot = True
                        continue
                    if not digit.isalnum():
                        raise ValueError("invalid hexadecimal digit")
                    bits <<= 4
                    # int() raises ValueError for letters past `f`
                    bits |= int("0x" + digit, base=16)
            else:
                bits = int(value, base=0)
                frac_wid = 0
            if neg:
                bits = -bits
            return FixedPoint(bits, frac_wid)

        if isinstance(value, float):
            n, d = value.as_integer_ratio()
            log2_d = d.bit_length() - 1
            assert d == 1 << log2_d, ("d isn't a power of 2 -- won't ever "
                                      "fail with float being IEEE 754")
            return FixedPoint(n, log2_d)
        raise TypeError("can't convert type to FixedPoint")

    @staticmethod
    def with_frac_wid(value, frac_wid, round_dir=RoundDir.ERROR_IF_INEXACT):
        """convert `value` to the nearest fixed-point number with `frac_wid`
        fractional bits, rounding according to `round_dir`.

        `value` is either a `Fraction` or anything accepted by
        `FixedPoint.cast`.

        raises `ValueError` if `round_dir` is `RoundDir.ERROR_IF_INEXACT`
        and `value` can't be represented exactly.
        """
        assert isinstance(frac_wid, int) and frac_wid >= 0
        assert isinstance(round_dir, RoundDir)
        if isinstance(value, Fraction):
            numerator = value.numerator
            denominator = value.denominator
        else:
            value = FixedPoint.cast(value)
            # compute number of bits that should be removed from value
            del_bits = value.frac_wid - frac_wid
            if del_bits == 0:
                return value
            if del_bits < 0:  # add bits
                return FixedPoint(value.bits << -del_bits,
                                  frac_wid)
            numerator = value.bits
            denominator = 1 << value.frac_wid
        # denominator is always positive here: `Fraction` normalizes the
        # sign into the numerator and `1 << n` is positive, so divmod gives
        # a floor quotient and a remainder in `range(denominator)`.
        bits, remainder = divmod(numerator << frac_wid, denominator)
        if round_dir == RoundDir.DOWN:
            pass
        elif round_dir == RoundDir.UP:
            if remainder != 0:
                bits += 1
        elif round_dir == RoundDir.NEAREST_TIES_UP:
            if remainder * 2 >= denominator:
                bits += 1
        elif round_dir == RoundDir.ERROR_IF_INEXACT:
            if remainder != 0:
                raise ValueError("inexact conversion")
        else:
            assert False, "unimplemented round_dir"
        return FixedPoint(bits, frac_wid)

    def to_frac_wid(self, frac_wid, round_dir=RoundDir.ERROR_IF_INEXACT):
        """convert to the nearest fixed-point number with `frac_wid`
        fractional bits, rounding according to `round_dir`."""
        return FixedPoint.with_frac_wid(self, frac_wid, round_dir)

    def __float__(self):
        # use truediv to get correct result even when bits
        # and frac_wid are huge
        return float(self.bits / (1 << self.frac_wid))

    def as_fraction(self):
        """return the exact value of `self` as a `Fraction`."""
        return Fraction(self.bits, 1 << self.frac_wid)

    def cmp(self, rhs):
        """compare self with rhs, returning a positive integer if self is
        greater than rhs, zero if self is equal to rhs, and a negative integer
        if self is less than rhs."""
        rhs = FixedPoint.cast(rhs)
        # widening to the larger frac_wid is always exact
        common_frac_wid = max(self.frac_wid, rhs.frac_wid)
        lhs = self.to_frac_wid(common_frac_wid)
        rhs = rhs.to_frac_wid(common_frac_wid)
        return lhs.bits - rhs.bits

    def __eq__(self, rhs):
        return self.cmp(rhs) == 0

    def __ne__(self, rhs):
        return self.cmp(rhs) != 0

    def __hash__(self):
        # `__eq__` compares by value, so the hash must too. without this
        # the dataclass machinery generates a hash of `(bits, frac_wid)`
        # and equal values with different `frac_wid` hash differently.
        # hashing the `Fraction` also matches the hash of equal ints/floats.
        return hash(self.as_fraction())

    def __gt__(self, rhs):
        return self.cmp(rhs) > 0

    def __lt__(self, rhs):
        return self.cmp(rhs) < 0

    def __ge__(self, rhs):
        return self.cmp(rhs) >= 0

    def __le__(self, rhs):
        return self.cmp(rhs) <= 0

    def fract(self):
        """return the fractional part of `self`.
        that is `self - math.floor(self)`.
        """
        fract_mask = (1 << self.frac_wid) - 1
        return FixedPoint(self.bits & fract_mask, self.frac_wid)

    def __str__(self):
        """format as a signed hexadecimal number with a `.`, in a form that
        `FixedPoint.cast` can parse back."""
        if self < 0:
            return "-" + str(-self)
        digit_bits = 4
        # round the number of fractional bits up to a whole hex digit
        frac_digit_count = (self.frac_wid + digit_bits - 1) // digit_bits
        fract = self.fract().to_frac_wid(frac_digit_count * digit_bits)
        frac_str = hex(fract.bits)[2:].zfill(frac_digit_count)
        return hex(math.floor(self)) + "." + frac_str

    def __repr__(self):
        return f"FixedPoint.with_frac_wid({str(self)!r}, {self.frac_wid})"

    def __add__(self, rhs):
        rhs = FixedPoint.cast(rhs)
        common_frac_wid = max(self.frac_wid, rhs.frac_wid)
        lhs = self.to_frac_wid(common_frac_wid)
        rhs = rhs.to_frac_wid(common_frac_wid)
        return FixedPoint(lhs.bits + rhs.bits, common_frac_wid)

    def __radd__(self, lhs):
        # symmetric
        return self.__add__(lhs)

    def __neg__(self):
        return FixedPoint(-self.bits, self.frac_wid)

    def __sub__(self, rhs):
        rhs = FixedPoint.cast(rhs)
        common_frac_wid = max(self.frac_wid, rhs.frac_wid)
        lhs = self.to_frac_wid(common_frac_wid)
        rhs = rhs.to_frac_wid(common_frac_wid)
        return FixedPoint(lhs.bits - rhs.bits, common_frac_wid)

    def __rsub__(self, lhs):
        # a - b == -(b - a)
        return -self.__sub__(lhs)

    def __mul__(self, rhs):
        # the product is exact: the fractional widths add up
        rhs = FixedPoint.cast(rhs)
        return FixedPoint(self.bits * rhs.bits, self.frac_wid + rhs.frac_wid)

    def __rmul__(self, lhs):
        # symmetric
        return self.__mul__(lhs)

    def __floor__(self):
        # arithmetic right shift rounds towards negative infinity
        return self.bits >> self.frac_wid
211
212
@dataclass
class GoldschmidtDivState:
    """mutable working state that the goldschmidt division operations
    read and update while a single division is being computed."""

    # numerator -- N_prime[i] in the paper's algorithm 2
    n: FixedPoint

    # denominator -- D_prime[i] in the paper's algorithm 2
    d: FixedPoint

    # current factor -- F_prime[i] in the paper's algorithm 2
    f: "FixedPoint | None" = None

    # final result
    result: "int | None" = None

    # amount the numerator needs to be left-shifted at the end of the
    # algorithm.
    n_shift: "int | None" = None
227
228
class ParamsNotAccurateEnough(Exception):
    """error signalling that a set of parameters is too inaccurate for
    goldschmidt division to work."""
232
233
234 def _assert_accuracy(condition, msg="not accurate enough"):
235 if condition:
236 return
237 raise ParamsNotAccurateEnough(msg)
238
239
@dataclass(frozen=True, unsafe_hash=True)
class GoldschmidtDivParams:
    """parameters for a Goldschmidt division algorithm.
    Use `GoldschmidtDivParams.get` to find an efficient set of parameters.

    constructing an instance also computes the lookup-table and the list of
    operations. construction raises `ParamsNotAccurateEnough` when the
    requested parameters can't be used.
    """
    io_width: int
    """bit-width of the input divisor and the result.
    the input numerator is `2 * io_width`-bits wide.
    """
    extra_precision: int
    """number of bits of additional precision used inside the algorithm."""
    table_addr_bits: int
    """the number of address bits used in the lookup-table."""
    table_data_bits: int
    """the number of data bits used in the lookup-table."""
    # tuple to be immutable
    table: "tuple[FixedPoint, ...]" = field(init=False)
    """the lookup-table"""
    ops: "tuple[GoldschmidtDivOp, ...]" = field(init=False)
    """the operations needed to perform the goldschmidt division algorithm."""

    @property
    def table_addr_count(self):
        """number of distinct addresses in the lookup-table."""
        # used while computing self.table, so can't just do len(self.table)
        return 1 << self.table_addr_bits

    def table_input_exact_range(self, addr):
        """return the range of inputs as `Fraction`s used for the table entry
        with address `addr`.

        returns `(min_input, max_input)` where
        `min_input == 1 + addr / table_addr_count` is the smallest input
        mapped to `addr` and
        `max_input == 1 + (addr + 1) / table_addr_count` is the exclusive
        upper bound of the inputs mapped to `addr`.

        raises `ParamsNotAccurateEnough` if `io_width` is too small to
        provide `table_addr_bits` address bits.
        """
        assert isinstance(addr, int)
        assert 0 <= addr < self.table_addr_count
        if self.io_width < self.table_addr_bits:
            # raised instead of asserted so that the parameter search in
            # `get` can move on to the next candidate
            raise ParamsNotAccurateEnough(
                "io_width is smaller than table_addr_bits")
        # both numerators have to use the same scale as `denominator`:
        # inputs have `io_width` fractional bits and the address is taken
        # from the most-significant `table_addr_bits` of them.
        addr_shift = self.io_width - self.table_addr_bits
        denominator = 1 << self.io_width
        values_per_table_entry = 1 << addr_shift
        min_numerator = denominator + (addr << addr_shift)
        max_numerator = min_numerator + values_per_table_entry
        min_input = Fraction(min_numerator, denominator)
        max_input = Fraction(max_numerator, denominator)
        return min_input, max_input

    def table_value_exact_range(self, addr):
        """return the range of values as `Fraction`s used for the table entry
        with address `addr`."""
        min_value, max_value = self.table_input_exact_range(addr)
        # division swaps min/max
        return 1 / max_value, 1 / min_value

    def table_exact_value(self, index):
        """return the exact (not yet rounded to `table_data_bits`) value
        for the table entry with address `index`."""
        min_value, _ = self.table_value_exact_range(index)
        # we round down
        return min_value

    def __post_init__(self):
        # called by the autogenerated __init__
        assert self.io_width >= 1
        assert self.extra_precision >= 0
        assert self.table_addr_bits >= 1
        assert self.table_data_bits >= 1
        table = tuple(
            FixedPoint.with_frac_wid(self.table_exact_value(addr),
                                     self.table_data_bits,
                                     RoundDir.DOWN)
            for addr in range(self.table_addr_count))
        # we have to use object.__setattr__ since frozen=True
        object.__setattr__(self, "table", table)
        # needs self.table, so has to come last
        object.__setattr__(self, "ops", tuple(_goldschmidt_div_ops(self)))

    @staticmethod
    def get(io_width):
        """ find efficient parameters for a goldschmidt division algorithm
        with `params.io_width == io_width`.

        candidates are tried in order of increasing `extra_precision` and
        then increasing `table_addr_bits`; the first one that is accurate
        enough is returned.

        raises `ValueError` if no candidate works.
        """
        assert isinstance(io_width, int) and io_width >= 1
        for extra_precision in range(io_width * 2):
            table_data_bits = io_width + extra_precision
            for table_addr_bits in range(3, 7 + 1):
                try:
                    return GoldschmidtDivParams(
                        io_width=io_width,
                        extra_precision=extra_precision,
                        table_addr_bits=table_addr_bits,
                        table_data_bits=table_data_bits)
                except ParamsNotAccurateEnough:
                    # try the next candidate
                    pass
        raise ValueError(f"can't find working parameters for a goldschmidt "
                         f"division algorithm with io_width={io_width}")

    @property
    def expanded_width(self):
        """the total number of bits of precision used inside the algorithm."""
        return self.io_width + self.extra_precision
332
333
@enum.unique
class GoldschmidtDivOp(enum.Enum):
    """one step of the goldschmidt division algorithm.

    the value of every member is a short pseudo-code description of what
    `run` does for that member.
    """
    Normalize = "n, d, n_shift = normalize(n, d)"
    FEqTableLookup = "f = table_lookup(d)"
    MulNByF = "n *= f"
    MulDByF = "d *= f"
    FEq2MinusD = "f = 2 - d"
    CalcResult = "result = unnormalize_and_round(n)"

    def run(self, params, state):
        """apply this operation to `state`, modifying it in place.

        arguments:
        params: GoldschmidtDivParams
            the parameters for the algorithm
        state: GoldschmidtDivState
            the state to read the operation's inputs from and to write
            the operation's outputs to.
        """
        assert isinstance(params, GoldschmidtDivParams)
        assert isinstance(state, GoldschmidtDivState)
        expanded_width = params.expanded_width
        table_addr_bits = params.table_addr_bits
        if self is GoldschmidtDivOp.Normalize:
            # normalize so 1 <= d < 2
            # can easily be done with count-leading-zeros and left shift
            while state.d < 1:
                # doubling never loses bits, so the conversion is exact
                state.n = (state.n * 2).to_frac_wid(expanded_width)
                state.d = (state.d * 2).to_frac_wid(expanded_width)

            state.n_shift = 0
            # normalize so 1 <= n < 2
            while state.n >= 2:
                # halving drops the least-significant bit once n needs
                # more than expanded_width fractional bits, so we have to
                # truncate -- the default ERROR_IF_INEXACT rounding would
                # raise ValueError for valid inputs.
                state.n = (state.n * 0.5).to_frac_wid(
                    expanded_width, round_dir=RoundDir.DOWN)
                state.n_shift += 1
        elif self is GoldschmidtDivOp.FEqTableLookup:
            # compute initial f by table lookup
            d_m_1 = state.d - 1
            # the table address is the top table_addr_bits of d's fraction
            d_m_1 = d_m_1.to_frac_wid(table_addr_bits, RoundDir.DOWN)
            assert 0 <= d_m_1.bits < (1 << table_addr_bits)
            state.f = params.table[d_m_1.bits]
        elif self is GoldschmidtDivOp.MulNByF:
            assert state.f is not None
            n = state.n * state.f
            state.n = n.to_frac_wid(expanded_width, round_dir=RoundDir.DOWN)
        elif self is GoldschmidtDivOp.MulDByF:
            assert state.f is not None
            d = state.d * state.f
            state.d = d.to_frac_wid(expanded_width, round_dir=RoundDir.UP)
        elif self is GoldschmidtDivOp.FEq2MinusD:
            state.f = (2 - state.d).to_frac_wid(expanded_width)
        elif self is GoldschmidtDivOp.CalcResult:
            assert state.n_shift is not None
            # scale to correct value
            n = state.n * (1 << state.n_shift)

            # avoid incorrectly rounding down
            n = n.to_frac_wid(params.io_width, round_dir=RoundDir.UP)
            state.result = math.floor(n)
        else:
            assert False, f"unimplemented GoldschmidtDivOp: {self}"
386
387
def _goldschmidt_div_ops(params):
    """ generate the sequence of operations for Goldschmidt division,
    checking on the way that `params` are accurate enough.

    based on:
    Even, G., Seidel, P. M., & Ferguson, W. E. (2003).
    A Parametric Error Analysis of Goldschmidt's Division Algorithm.
    https://citeseerx.ist.psu.edu/viewdoc/download?doi=10.1.1.90.1238&rep=rep1&type=pdf

    arguments:
    params: GoldschmidtDivParams
        the parameters for the algorithm. `params.table` must already be
        filled in.

    yields: GoldschmidtDivOp
        the operations needed to perform the division.

    raises: ParamsNotAccurateEnough
        (while being iterated) if `params` don't meet the assumptions of
        the paper's error analysis.
    """
    assert isinstance(params, GoldschmidtDivParams)

    # the assumptions of the paper's error analysis (section 3.1) are
    # established one by one:

    # assumption 1: A (numerator) and B (denominator) are normalized so
    # they are in [1, 2)
    yield GoldschmidtDivOp.Normalize

    # assumption 2: all relative errors from directed rounding
    # are <= 1 / 4.
    # multipliers with > 4-bits precision meet the assumption
    _assert_accuracy(params.expanded_width > 4)

    # assumption 3: `abs(e[0]) + 3 * d[0] / 2 + f[0] < 1 / 2`.

    # bound on the absolute error of rounding to `expanded_width` bits,
    # it's the same for every table entry
    round_err_bound = Fraction(1, 1 << params.expanded_width)
    # largest `abs(e[0])` over all table entries
    worst_abs_e0 = 0
    # largest `d[0]` over all table entries
    worst_d0 = 0
    # `f[i] = 0` for all `i`
    f_i = 0
    for addr in range(params.table_addr_count):
        # `F_prime[-1] = (1 - e[0]) / B`
        # => `e[0] = 1 - B * F_prime[-1]`
        b_lo, b_hi = params.table_input_exact_range(addr)
        table_value = params.table[addr].as_fraction()
        assert b_lo >= 0 and table_value >= 0, \
            "only positive quadrant of interval multiplication implemented"
        product_lo = b_lo * table_value
        product_hi = b_hi * table_value
        # subtracting from 1 swaps the bounds
        e0_lo = 1 - product_hi
        e0_hi = 1 - product_lo
        worst_abs_e0 = max(worst_abs_e0, abs(e0_lo), abs(e0_hi))

        # `D_prime[0] = (1 + d[0]) * B * F_prime[-1]`
        # `D_prime[0] = abs_round_err + B * F_prime[-1]`
        # => `d[0] = abs_round_err / (B * F_prime[-1])`
        assert product_lo > 0 and round_err_bound >= 0, \
            "only positive quadrant of interval division implemented"
        # the quotient is largest for the smallest divisor
        worst_d0 = max(worst_d0, round_err_bound / product_lo)

    _assert_accuracy(worst_abs_e0 + 3 * worst_d0 / 2 + f_i < Fraction(1, 2))

    # assumption 4: the initial approximation F'[-1] of 1/B
    # is in [1/2, 1].
    # (B is the denominator)
    for addr in range(params.table_addr_count):
        table_entry = params.table[addr]
        _assert_accuracy(0.5 <= table_entry <= 1)

    yield GoldschmidtDivOp.FEqTableLookup

    # we use Setting I (section 4.1 of the paper)

    # FIXME: calculate error and check if it's small enough
    precision_bits = 1
    while precision_bits < params.io_width * 2:
        yield from (GoldschmidtDivOp.MulNByF,
                    GoldschmidtDivOp.MulDByF,
                    GoldschmidtDivOp.FEq2MinusD)
        # the iteration count assumes every iteration doubles the
        # precision (not verified -- see FIXME above)
        precision_bits *= 2

    yield GoldschmidtDivOp.CalcResult
468
469
def goldschmidt_div(n, d, params):
    """ divide `n` by `d` using the Goldschmidt division algorithm.

    based on:
    Even, G., Seidel, P. M., & Ferguson, W. E. (2003).
    A Parametric Error Analysis of Goldschmidt's Division Algorithm.
    https://citeseerx.ist.psu.edu/viewdoc/download?doi=10.1.1.90.1238&rep=rep1&type=pdf

    arguments:
    n: int
        numerator. a `2 * params.io_width`-bit unsigned integer.
        must be less than `d << params.io_width`, otherwise the quotient
        wouldn't fit in `params.io_width` bits.
    d: int
        denominator. a `params.io_width`-bit unsigned integer.
        must not be zero.
    params: GoldschmidtDivParams
        the parameters for the algorithm. they supply the bit-width of
        the inputs/outputs and the operations to run.

    returns: int
        the quotient. a `params.io_width`-bit unsigned integer.
    """
    assert isinstance(params, GoldschmidtDivParams)
    io_width = params.io_width
    assert isinstance(d, int) and 0 < d < (1 << io_width)
    assert isinstance(n, int) and 0 <= n < (d << io_width)

    # the inputs are reinterpreted as fixed-point values with `io_width`
    # fractional bits; all of the algorithm is fixed-point arithmetic
    numerator = FixedPoint(n, io_width)
    denominator = FixedPoint(d, io_width)
    state = GoldschmidtDivState(n=numerator, d=denominator)

    # the operations communicate only through `state`
    for op in params.ops:
        op.run(params, state)

    quotient = state.result
    assert quotient is not None
    return quotient