implement fixed_sqrt
[ieee754fpu.git] / src / ieee754 / div_rem_sqrt_rsqrt / algorithm.py
1 # SPDX-License-Identifier: LGPL-2.1-or-later
2 # See Notices.txt for copyright information
3
4 """ Algorithms for div/rem/sqrt/rsqrt.
5
6 code for simulating/testing the various algorithms
7 """
8
9 from nmigen.hdl.ast import Const
10 import math
11
12
def div_rem(dividend, divisor, bit_width, signed):
    """ Compute the quotient/remainder following the RISC-V M extension.

    NOT the same as the // or % operators: the quotient is rounded towards
    zero, the remainder takes the sign of the dividend, and dividing by zero
    gives a quotient of -1 (all bits set) together with the dividend as the
    remainder.

    :param dividend: the dividend/numerator
    :param divisor: the divisor/denominator
    :param bit_width: the bit width of the inputs/outputs
    :param signed: if the inputs/outputs are signed instead of unsigned
    :returns: the tuple ``(quotient, remainder)``, both normalized to
        ``bit_width`` bits with the requested signedness.
    """
    shape = (bit_width, signed)
    numerator = Const.normalize(dividend, shape)
    denominator = Const.normalize(divisor, shape)
    if denominator == 0:
        # division by zero: all-ones quotient, remainder is the dividend
        return Const.normalize(-1, shape), Const.normalize(numerator, shape)
    # divide the magnitudes, then put the signs back
    quotient_magnitude, remainder_magnitude = divmod(abs(numerator),
                                                     abs(denominator))
    signs_differ = (numerator < 0) != (denominator < 0)
    raw_quotient = -quotient_magnitude if signs_differ else quotient_magnitude
    if numerator < 0:
        raw_remainder = -remainder_magnitude
    else:
        raw_remainder = remainder_magnitude
    # normalizing wraps the one overflowing case (most-negative / -1)
    return (Const.normalize(raw_quotient, shape),
            Const.normalize(raw_remainder, shape))
33
34
class UnsignedDivRem:
    """ Unsigned integer division/remainder following the RISC-V M extension.

    NOT the same as the // or % operators

    The division is done a few quotient bits at a time, starting from the
    most-significant end; each call to ``calculate_stage`` produces up to
    ``log2_radix`` more bits.

    :attribute remainder: the remainder and/or dividend
    :attribute divisor: the divisor
    :attribute bit_width: the bit width of the inputs/outputs
    :attribute log2_radix: the base-2 log of the division radix. The number of
        bits of quotient that are calculated per pipeline stage.
    :attribute quotient: the quotient
    :attribute current_shift: the current bit index
    """

    def __init__(self, dividend, divisor, bit_width, log2_radix=3):
        """ Create an UnsignedDivRem.

        :param dividend: the dividend/numerator
        :param divisor: the divisor/denominator
        :param bit_width: the bit width of the inputs/outputs
        :param log2_radix: the base-2 log of the division radix. The number of
            bits of quotient that are calculated per pipeline stage.
        """
        unsigned_shape = (bit_width, False)
        # the running remainder starts out as the (unsigned) dividend
        self.remainder = Const.normalize(dividend, unsigned_shape)
        self.divisor = Const.normalize(divisor, unsigned_shape)
        self.bit_width = bit_width
        self.log2_radix = log2_radix
        self.quotient = 0
        # no quotient bits produced yet: start above the top bit
        self.current_shift = bit_width

    def calculate_stage(self):
        """ Calculate the next pipeline stage of the division.

        :returns bool: True if this is the last pipeline stage.
        """
        if self.current_shift == 0:
            return True
        # the final stage may have fewer than log2_radix bits left to do
        bits_this_stage = min(self.log2_radix, self.current_shift)
        assert bits_this_stage > 0
        self.current_shift -= bits_this_stage
        shift = self.current_shift
        # digit 0 subtracts nothing, so it is the fallback choice
        best_digit = 0
        best_remainder = self.remainder
        # keep the largest digit that still leaves a non-negative remainder
        for digit in range(1, 1 << bits_this_stage):
            trial = self.remainder - ((self.divisor * digit) << shift)
            if trial >= 0:
                best_digit = digit
                best_remainder = trial
        self.remainder = best_remainder
        self.quotient |= best_digit << shift
        return shift == 0

    def calculate(self):
        """ Calculate the results of the division.

        :returns: self
        """
        finished = self.calculate_stage()
        while not finished:
            finished = self.calculate_stage()
        return self
96
97
class DivRem:
    """ integer division/remainder following the RISC-V M extension.

    NOT the same as the // or % operators

    The magnitudes of the operands are divided by an ``UnsignedDivRem``;
    the signs are applied to the results once that divider has finished.

    :attribute dividend: the dividend
    :attribute divisor: the divisor
    :attribute signed: if the inputs/outputs are signed instead of unsigned
    :attribute quotient: the quotient
    :attribute remainder: the remainder
    :attribute divider: the base UnsignedDivRem
    """

    def __init__(self, dividend, divisor, bit_width, signed, log2_radix=3):
        """ Create a DivRem.

        :param dividend: the dividend/numerator
        :param divisor: the divisor/denominator
        :param bit_width: the bit width of the inputs/outputs
        :param signed: if the inputs/outputs are signed instead of unsigned
        :param log2_radix: the base-2 log of the division radix. The number of
            bits of quotient that are calculated per pipeline stage.
        """
        self.dividend = Const.normalize(dividend, (bit_width, signed))
        self.divisor = Const.normalize(divisor, (bit_width, signed))
        self.signed = signed
        self.quotient = 0
        self.remainder = 0
        # Take the magnitudes of the *normalized* operands: the raw arguments
        # may be out-of-range bit patterns (e.g. 0xFF for -1 at 8 bits), whose
        # abs() is not the magnitude of the value actually being divided.
        self.divider = UnsignedDivRem(abs(self.dividend), abs(self.divisor),
                                      bit_width, log2_radix)

    def calculate_stage(self):
        """ Calculate the next pipeline stage of the division.

        ``quotient`` and ``remainder`` are only updated by the call that
        finishes the underlying unsigned division.

        :returns bool: True if this is the last pipeline stage.
        """
        if not self.divider.calculate_stage():
            return False
        divisor_sign = self.divisor < 0
        dividend_sign = self.dividend < 0
        # for division by zero the unsigned quotient is kept as-is
        if self.divisor != 0 and divisor_sign != dividend_sign:
            quotient = -self.divider.quotient
        else:
            quotient = self.divider.quotient
        # the remainder takes the sign of the dividend
        if dividend_sign:
            remainder = -self.divider.remainder
        else:
            remainder = self.divider.remainder
        bit_width = self.divider.bit_width
        self.quotient = Const.normalize(quotient, (bit_width, self.signed))
        self.remainder = Const.normalize(remainder, (bit_width, self.signed))
        return True

    def calculate(self):
        """ Calculate the results of the division.

        :returns: self
        """
        while not self.calculate_stage():
            pass
        return self
150
151
class Fixed:
    """ Fixed-point number.

    the value is bits * 2 ** -fract_width

    :attribute bits: the bits of the fixed-point number
    :attribute fract_width: the number of bits in the fractional portion
    :attribute bit_width: the total number of bits
    :attribute signed: if the type is signed
    """

    @staticmethod
    def from_bits(bits, fract_width, bit_width, signed):
        """ Create a new Fixed.

        :param bits: the bits of the fixed-point number
        :param fract_width: the number of bits in the fractional portion
        :param bit_width: the total number of bits
        :param signed: if the type is signed
        :returns Fixed: the new Fixed, with ``bits`` normalized to
            ``bit_width`` bits.
        """
        retval = Fixed(0, fract_width, bit_width, signed)
        retval.bits = Const.normalize(bits, (bit_width, signed))
        return retval

    def __init__(self, value, fract_width, bit_width, signed):
        """ Create a new Fixed.

        Note: ``value`` is not the same as ``bits``. To put a particular number
        in ``bits``, use ``Fixed.from_bits``.

        :param value: the value of the fixed-point number. May be another
            ``Fixed``, an ``int``, or anything that can be multiplied by an
            ``int`` and passed to ``math.floor``.
        :param fract_width: the number of bits in the fractional portion
        :param bit_width: the total number of bits
        :param signed: if the type is signed
        """
        assert fract_width >= 0
        assert bit_width > 0
        if isinstance(value, Fixed):
            # re-align the binary point; dropping fraction bits rounds down
            if fract_width < value.fract_width:
                bits = value.bits >> (value.fract_width - fract_width)
            else:
                bits = value.bits << (fract_width - value.fract_width)
        elif isinstance(value, int):
            bits = value << fract_width
        else:
            bits = math.floor(value * 2 ** fract_width)
        self.bits = Const.normalize(bits, (bit_width, signed))
        self.fract_width = fract_width
        self.bit_width = bit_width
        self.signed = signed

    def with_bits(self, bits):
        """ Create a new Fixed with the specified bits.

        :param bits: the new bits.
        :returns Fixed: the new Fixed.
        """
        return self.from_bits(bits,
                              self.fract_width,
                              self.bit_width,
                              self.signed)

    def with_value(self, value):
        """ Create a new Fixed with the specified value.

        :param value: the new value.
        :returns Fixed: the new Fixed.
        """
        return Fixed(value,
                     self.fract_width,
                     self.bit_width,
                     self.signed)

    def __repr__(self):
        """ Get representation."""
        retval = f"Fixed.from_bits({self.bits}, {self.fract_width}, "
        return retval + f"{self.bit_width}, {self.signed})"

    def __trunc__(self):
        """ Truncate to integer."""
        # rounding towards zero is ceil for negatives, floor otherwise
        if self.bits < 0:
            return self.__ceil__()
        return self.__floor__()

    def __int__(self):
        """ Truncate to integer."""
        return self.__trunc__()

    def __float__(self):
        """ Convert to float."""
        return self.bits * 2.0 ** -self.fract_width

    def __floor__(self):
        """ Floor to integer."""
        return self.bits >> self.fract_width

    def __ceil__(self):
        """ Ceil to integer."""
        # ceil(x) == -floor(-x)
        return -((-self.bits) >> self.fract_width)

    def __neg__(self):
        """ Negate."""
        return self.from_bits(-self.bits, self.fract_width,
                              self.bit_width, self.signed)

    def __pos__(self):
        """ Unary Positive."""
        return self

    def __abs__(self):
        """ Absolute Value."""
        return self.from_bits(abs(self.bits), self.fract_width,
                              self.bit_width, self.signed)

    def __invert__(self):
        """ Inverse."""
        return self.from_bits(~self.bits, self.fract_width,
                              self.bit_width, self.signed)

    def _binary_op(self, rhs, operation, full=False):
        """ Handle binary arithmetic operators.

        Both operands are aligned to the larger of the two fraction widths
        before ``operation`` is called. The result's integer width is the
        larger of the two operands' integer widths (``self``'s when ``rhs``
        is an ``int``).

        :param rhs: the other operand, an ``int`` or a ``Fixed``.
        :param operation: called as
            ``operation(lhs_bits, rhs_bits, fract_width)``, or, when ``full``
            is set, as
            ``operation(lhs_bits, rhs_bits, fract_width, bit_width, signed)``.
        :param full: if set, the value returned by ``operation`` is returned
            unchanged instead of being used as the bits of a new ``Fixed``.
        :returns: the result, or ``NotImplemented`` if ``rhs`` is neither an
            ``int`` nor a ``Fixed``.
        :raises TypeError: if ``rhs`` is a ``Fixed`` of different signedness.
        """
        if isinstance(rhs, int):
            rhs_fract_width = 0
            rhs_bits = rhs
            int_width = self.bit_width - self.fract_width
        elif isinstance(rhs, Fixed):
            if self.signed != rhs.signed:
                raise TypeError("signedness must match")
            rhs_fract_width = rhs.fract_width
            rhs_bits = rhs.bits
            int_width = max(self.bit_width - self.fract_width,
                            rhs.bit_width - rhs.fract_width)
        else:
            return NotImplemented
        fract_width = max(self.fract_width, rhs_fract_width)
        rhs_bits <<= fract_width - rhs_fract_width
        lhs_bits = self.bits << (fract_width - self.fract_width)
        bit_width = int_width + fract_width
        if full:
            return operation(lhs_bits, rhs_bits,
                             fract_width, bit_width, self.signed)
        bits = operation(lhs_bits, rhs_bits,
                         fract_width)
        return self.from_bits(bits, fract_width, bit_width, self.signed)

    def __add__(self, rhs):
        """ Addition."""
        return self._binary_op(rhs, lambda lhs, rhs, fract_width: lhs + rhs)

    def __radd__(self, lhs):
        """ Reverse Addition."""
        return self.__add__(lhs)

    def __sub__(self, rhs):
        """ Subtraction."""
        return self._binary_op(rhs, lambda lhs, rhs, fract_width: lhs - rhs)

    def __rsub__(self, lhs):
        """ Reverse Subtraction."""
        # note swapped argument and parameter order
        return self._binary_op(lhs, lambda rhs, lhs, fract_width: lhs - rhs)

    def __and__(self, rhs):
        """ Bitwise And."""
        return self._binary_op(rhs, lambda lhs, rhs, fract_width: lhs & rhs)

    def __rand__(self, lhs):
        """ Reverse Bitwise And."""
        return self.__and__(lhs)

    def __or__(self, rhs):
        """ Bitwise Or."""
        return self._binary_op(rhs, lambda lhs, rhs, fract_width: lhs | rhs)

    def __ror__(self, lhs):
        """ Reverse Bitwise Or."""
        return self.__or__(lhs)

    def __xor__(self, rhs):
        """ Bitwise Xor."""
        return self._binary_op(rhs, lambda lhs, rhs, fract_width: lhs ^ rhs)

    def __rxor__(self, lhs):
        """ Reverse Bitwise Xor."""
        return self.__xor__(lhs)

    def __mul__(self, rhs):
        """ Multiplication.

        The fraction widths of the operands are added; for a ``Fixed``
        ``rhs`` the integer widths are added as well.

        :param rhs: the other operand, an ``int`` or a ``Fixed``.
        :returns: the product, or ``NotImplemented`` if ``rhs`` is neither an
            ``int`` nor a ``Fixed``.
        :raises TypeError: if ``rhs`` is a ``Fixed`` of different signedness.
        """
        if isinstance(rhs, int):
            rhs_fract_width = 0
            rhs_bits = rhs
            int_width = self.bit_width - self.fract_width
        elif isinstance(rhs, Fixed):
            if self.signed != rhs.signed:
                raise TypeError("signedness must match")
            rhs_fract_width = rhs.fract_width
            rhs_bits = rhs.bits
            int_width = (self.bit_width - self.fract_width
                         + rhs.bit_width - rhs.fract_width)
        else:
            return NotImplemented
        fract_width = self.fract_width + rhs_fract_width
        bit_width = int_width + fract_width
        bits = self.bits * rhs_bits
        return self.from_bits(bits, fract_width, bit_width, self.signed)

    def __rmul__(self, rhs):
        """ Reverse Multiplication. """
        return self.__mul__(rhs)

    @staticmethod
    def _cmp_impl(lhs, rhs, fract_width, bit_width, signed):
        """ Three-way compare of two aligned bit patterns.

        Only ``lhs`` and ``rhs`` are used; the remaining parameters exist so
        this can be passed to ``_binary_op`` with ``full=True``.

        :returns int: -1, 0 or 1.
        """
        if lhs < rhs:
            return -1
        elif lhs == rhs:
            return 0
        return 1

    def cmp(self, rhs):
        """ Compare self with rhs.

        :returns int: returns -1 if self is less than rhs, 0 if they're equal,
            and 1 for greater than.
            Returns NotImplemented for unimplemented cases
        :raises TypeError: if ``rhs`` is a ``Fixed`` of different signedness.
        """
        return self._binary_op(rhs, self._cmp_impl, full=True)

    def _rich_compare(self, rhs, predicate):
        """ Apply ``predicate`` to ``self.cmp(rhs)``.

        ``NotImplemented`` is passed through untouched so that Python can try
        the reflected operation (or raise ``TypeError`` itself) instead of
        the sentinel being compared against 0.
        """
        ordering = self.cmp(rhs)
        if ordering is NotImplemented:
            return NotImplemented
        return predicate(ordering)

    def __lt__(self, rhs):
        """ Less Than."""
        return self._rich_compare(rhs, lambda ordering: ordering < 0)

    def __le__(self, rhs):
        """ Less Than or Equal."""
        return self._rich_compare(rhs, lambda ordering: ordering <= 0)

    def __eq__(self, rhs):
        """ Equal."""
        return self._rich_compare(rhs, lambda ordering: ordering == 0)

    def __ne__(self, rhs):
        """ Not Equal."""
        return self._rich_compare(rhs, lambda ordering: ordering != 0)

    def __gt__(self, rhs):
        """ Greater Than."""
        return self._rich_compare(rhs, lambda ordering: ordering > 0)

    def __ge__(self, rhs):
        """ Greater Than or Equal."""
        return self._rich_compare(rhs, lambda ordering: ordering >= 0)

    def __bool__(self):
        """ Convert to bool."""
        return bool(self.bits)

    def __str__(self):
        """ Get text representation."""
        # don't just use self.__float__() in order to work with numbers more
        # than 53 bits wide
        retval = "fixed:"
        bits = self.bits
        if bits < 0:
            retval += "-"
            bits = -bits
        int_part = bits >> self.fract_width
        fract_part = bits & ~(-1 << self.fract_width)
        # round up fract_width to nearest multiple of 4
        fract_width = (self.fract_width + 3) & ~3
        fract_part <<= (fract_width - self.fract_width)
        fract_width_in_hex_digits = fract_width // 4
        retval += f"0x{int_part:x}."
        if fract_width_in_hex_digits != 0:
            retval += f"{fract_part:x}".zfill(fract_width_in_hex_digits)
        return retval
426
427
class RootRemainder:
    """ A polynomial root and remainder.

    A plain pair of values; nothing is computed or validated here.

    :attribute root: the polynomial root.
    :attribute remainder: the remainder.
    """

    def __init__(self, root, remainder):
        """ Create a new RootRemainder.

        :param root: the polynomial root.
        :param remainder: the remainder.
        """
        self.root, self.remainder = root, remainder

    def __repr__(self):
        """ Get the representation as a string. """
        # built from the repr() of both fields
        return f"RootRemainder({self.root!r}, {self.remainder!r})"

    def __str__(self):
        """ Convert to a string. """
        # built from the str() of both fields
        return f"RootRemainder({self.root!s}, {self.remainder!s})"
451
452
def fixed_sqrt(radicand):
    """ Compute the Square Root and Remainder.

    Solves the polynomial ``radicand - x * x == 0``

    The root is built one bit at a time, from the most-significant bit down:
    a bit is kept whenever the square of the resulting root does not exceed
    the radicand.

    :param radicand: the ``Fixed`` (or ``int``) to take the square root of.
    :returns RootRemainder: the root and ``radicand - root * root``. Both are
        ``int`` when ``radicand`` is an ``int``. Returns ``None`` when
        ``radicand`` is negative.
    :raises TypeError: if ``radicand`` is neither an ``int`` nor a ``Fixed``.
    """
    # Written for correctness, not speed
    is_int = isinstance(radicand, int)
    # validate the type before looking at the sign, so an unsupported
    # negative value is reported instead of silently giving None
    if not is_int and not isinstance(radicand, Fixed):
        raise TypeError("radicand must be an int or a Fixed, not "
                        + type(radicand).__name__)
    if radicand < 0:
        return None
    if is_int:
        # one extra bit so the non-negative value fits in a signed Fixed
        radicand = Fixed(radicand, 0, radicand.bit_length() + 1, True)

    root = radicand.with_bits(0)
    for i in reversed(range(root.bit_width)):
        new_root = root.with_bits(root.bits | (1 << i))
        if new_root < 0:  # skip sign bit
            continue
        # keep the bit only while the remainder stays non-negative
        if radicand >= new_root * new_root:
            root = new_root
    remainder = radicand - root * root
    if is_int:
        root = int(root)
        remainder = int(remainder)
    return RootRemainder(root, remainder)
485
486
class FixedSqrt:
    """ Unfinished placeholder; no behavior is implemented yet. """
    # FIXME: finish
490
491
def fixed_rsqrt(radicand):
    """ Unfinished placeholder; always raises.

    :param radicand: currently unused.
    :raises NotImplementedError: always.
    """
    # FIXME: finish
    raise NotImplementedError
495
496
class FixedRSqrt:
    """ Unfinished placeholder; no behavior is implemented yet. """
    # FIXME: finish