switch to exact version of cython
[ieee754fpu.git] / src / ieee754 / div_rem_sqrt_rsqrt / algorithm.py
1 # SPDX-License-Identifier: LGPL-2.1-or-later
2 # See Notices.txt for copyright information
3
4 """ Algorithms for div/rem/sqrt/rsqrt.
5
6 code for simulating/testing the various algorithms
7 """
8
9 from nmigen.hdl.ast import Const
10 import math
11 import enum
12
13
def div_rem(dividend, divisor, bit_width, signed):
    """ Compute the quotient/remainder following the RISC-V M extension.

    NOT the same as the // or % operators

    The quotient is computed on the operand magnitudes and then given the
    sign of ``dividend * divisor`` (i.e. it truncates towards zero); the
    remainder takes the sign of the dividend. A zero divisor yields a
    quotient of ``-1`` (all ones) and the dividend as the remainder. Both
    results are normalized to ``bit_width`` bits with ``Const.normalize``.

    :param dividend: the dividend/numerator
    :param divisor: the divisor/denominator
    :param bit_width: the bit width of the inputs/outputs
    :param signed: if the inputs/outputs are signed instead of unsigned
    :returns: tuple ``(quotient, remainder)``
    """
    shape = (bit_width, signed)
    n = Const.normalize(dividend, shape)
    d = Const.normalize(divisor, shape)
    if d == 0:
        q, r = -1, n
    else:
        # divide magnitudes, then restore signs: the quotient is negative
        # when exactly one operand is negative, the remainder follows the
        # dividend.
        q, r = divmod(abs(n), abs(d))
        if (n < 0) != (d < 0):
            q = -q
        if n < 0:
            r = -r
    return Const.normalize(q, shape), Const.normalize(r, shape)
34
35
class UnsignedDivRem:
    """ Unsigned integer division/remainder following the RISC-V M extension.

    NOT the same as the // or % operators

    Each call to ``calculate_stage`` decides the next ``log2_radix`` bits of
    the quotient, most significant first, by picking the largest trial digit
    whose ``quotient * divisor`` does not exceed the dividend.

    :attribute dividend: the dividend
    :attribute remainder: the remainder, kept equal to
        ``dividend - quotient_times_divisor`` after every stage; it is the
        final remainder once ``calculate_stage`` has returned True.
    :attribute divisor: the divisor
    :attribute bit_width: the bit width of the inputs/outputs
    :attribute log2_radix: the base-2 log of the division radix. The number of
        bits of quotient that are calculated per pipeline stage.
    :attribute quotient: the quotient
    :attribute quotient_times_divisor: ``quotient * divisor``
    :attribute current_shift: the current bit index
    """

    def __init__(self, dividend, divisor, bit_width, log2_radix=3):
        """ Create an UnsignedDivRem.

        :param dividend: the dividend/numerator
        :param divisor: the divisor/denominator
        :param bit_width: the bit width of the inputs/outputs
        :param log2_radix: the base-2 log of the division radix. The number of
            bits of quotient that are calculated per pipeline stage.
        """
        self.dividend = Const.normalize(dividend, (bit_width, False))
        self.divisor = Const.normalize(divisor, (bit_width, False))
        self.bit_width = bit_width
        self.log2_radix = log2_radix
        self.quotient = 0
        self.quotient_times_divisor = self.quotient * self.divisor
        # ``remainder`` is a documented attribute: give it its value for the
        # current (zero) quotient so it exists before the last stage, and
        # also when bit_width == 0, where no stage ever assigns it.
        self.remainder = self.dividend - self.quotient_times_divisor
        self.current_shift = bit_width

    def calculate_stage(self):
        """ Calculate the next pipeline stage of the division.

        :returns bool: True if this is the last pipeline stage.
        """
        if self.current_shift == 0:
            return True
        log2_radix = min(self.log2_radix, self.current_shift)
        assert log2_radix > 0
        self.current_shift -= log2_radix
        radix = 1 << log2_radix
        # candidate values of ``quotient * divisor`` for each possible digit
        trial_values = [
            self.quotient_times_divisor
            + ((self.divisor * i) << self.current_shift)
            for i in range(radix)
        ]
        quotient_bits = 0
        next_product = self.quotient_times_divisor
        # keep the largest digit that does not overshoot the dividend
        for i, trial_value in enumerate(trial_values):
            if self.dividend >= trial_value:
                quotient_bits = i
                next_product = trial_value
        self.quotient_times_divisor = next_product
        self.quotient |= quotient_bits << self.current_shift
        self.remainder = self.dividend - self.quotient_times_divisor
        return self.current_shift == 0

    def calculate(self):
        """ Calculate the results of the division.

        :returns: self
        """
        while not self.calculate_stage():
            pass
        return self
106
107
class DivRem:
    """ integer division/remainder following the RISC-V M extension.

    NOT the same as the // or % operators

    The magnitudes of the normalized operands are divided by an
    ``UnsignedDivRem``. The signs are then applied so the quotient truncates
    towards zero and the remainder takes the sign of the dividend.

    :attribute dividend: the dividend
    :attribute divisor: the divisor
    :attribute signed: if the inputs/outputs are signed instead of unsigned
    :attribute quotient: the quotient (0 until the last stage completes)
    :attribute remainder: the remainder (0 until the last stage completes)
    :attribute divider: the base UnsignedDivRem
    """

    def __init__(self, dividend, divisor, bit_width, signed, log2_radix=3):
        """ Create a DivRem.

        :param dividend: the dividend/numerator
        :param divisor: the divisor/denominator
        :param bit_width: the bit width of the inputs/outputs
        :param signed: if the inputs/outputs are signed instead of unsigned
        :param log2_radix: the base-2 log of the division radix. The number of
            bits of quotient that are calculated per pipeline stage.
        """
        self.dividend = Const.normalize(dividend, (bit_width, signed))
        self.divisor = Const.normalize(divisor, (bit_width, signed))
        self.signed = signed
        self.quotient = 0
        self.remainder = 0
        # Take the magnitudes of the *normalized* operands. The raw arguments
        # may lie outside the ``bit_width`` range (e.g. 0xFF meaning -1 for an
        # 8-bit signed divide, or -1 meaning 0xFF for an unsigned one), and
        # their magnitudes would then differ from the values being divided.
        self.divider = UnsignedDivRem(abs(self.dividend), abs(self.divisor),
                                      bit_width, log2_radix)

    def calculate_stage(self):
        """ Calculate the next pipeline stage of the division.

        :returns bool: True if this is the last pipeline stage.
        """
        if not self.divider.calculate_stage():
            return False
        divisor_sign = self.divisor < 0
        dividend_sign = self.dividend < 0
        # With a zero divisor the unsigned quotient is all ones; it must not
        # be negated so that it normalizes to -1 as RISC-V requires.
        if self.divisor != 0 and divisor_sign != dividend_sign:
            quotient = -self.divider.quotient
        else:
            quotient = self.divider.quotient
        if dividend_sign:
            remainder = -self.divider.remainder
        else:
            remainder = self.divider.remainder
        bit_width = self.divider.bit_width
        self.quotient = Const.normalize(quotient, (bit_width, self.signed))
        self.remainder = Const.normalize(remainder, (bit_width, self.signed))
        return True

    def calculate(self):
        """ Calculate the results of the division.

        :returns: self
        """
        while not self.calculate_stage():
            pass
        return self
160
161
class Fixed:
    """ Fixed-point number.

    the value is bits * 2 ** -fract_width

    ``bits`` is always passed through ``Const.normalize`` for
    ``(bit_width, signed)``, so results of arithmetic are reduced to the
    declared width instead of growing without bound.

    :attribute bits: the bits of the fixed-point number
    :attribute fract_width: the number of bits in the fractional portion
    :attribute bit_width: the total number of bits
    :attribute signed: if the type is signed
    """

    @staticmethod
    def from_bits(bits, fract_width, bit_width, signed):
        """ Create a new Fixed.

        :param bits: the bits of the fixed-point number
        :param fract_width: the number of bits in the fractional portion
        :param bit_width: the total number of bits
        :param signed: if the type is signed
        """
        retval = Fixed(0, fract_width, bit_width, signed)
        retval.bits = Const.normalize(bits, (bit_width, signed))
        return retval

    def __init__(self, value, fract_width, bit_width, signed):
        """ Create a new Fixed.

        Note: ``value`` is not the same as ``bits``. To put a particular number
        in ``bits``, use ``Fixed.from_bits``.

        :param value: the value of the fixed-point number. A ``Fixed`` is
            rescaled to ``fract_width`` (extra fraction bits are dropped with
            an arithmetic right shift), an ``int`` is shifted into place, and
            anything else is scaled and passed to ``math.floor``.
        :param fract_width: the number of bits in the fractional portion
        :param bit_width: the total number of bits
        :param signed: if the type is signed
        """
        assert fract_width >= 0
        assert bit_width > 0
        if isinstance(value, Fixed):
            if fract_width < value.fract_width:
                bits = value.bits >> (value.fract_width - fract_width)
            else:
                bits = value.bits << (fract_width - value.fract_width)
        elif isinstance(value, int):
            bits = value << fract_width
        else:
            bits = math.floor(value * 2 ** fract_width)
        self.bits = Const.normalize(bits, (bit_width, signed))
        self.fract_width = fract_width
        self.bit_width = bit_width
        self.signed = signed

    def with_bits(self, bits):
        """ Create a new Fixed with the specified bits.

        :param bits: the new bits.
        :returns Fixed: the new Fixed.
        """
        return self.from_bits(bits,
                              self.fract_width,
                              self.bit_width,
                              self.signed)

    def with_value(self, value):
        """ Create a new Fixed with the specified value.

        :param value: the new value.
        :returns Fixed: the new Fixed.
        """
        return Fixed(value,
                     self.fract_width,
                     self.bit_width,
                     self.signed)

    def __repr__(self):
        """ Get representation."""
        retval = f"Fixed.from_bits({self.bits}, {self.fract_width}, "
        return retval + f"{self.bit_width}, {self.signed})"

    def __trunc__(self):
        """ Truncate to integer (round towards zero)."""
        if self.bits < 0:
            return self.__ceil__()
        return self.__floor__()

    def __int__(self):
        """ Truncate to integer."""
        return self.__trunc__()

    def __float__(self):
        """ Convert to float."""
        return self.bits * 2.0 ** -self.fract_width

    def __floor__(self):
        """ Floor to integer."""
        # arithmetic right shift rounds towards negative infinity
        return self.bits >> self.fract_width

    def __ceil__(self):
        """ Ceil to integer."""
        # ceil(x) == -floor(-x)
        return -((-self.bits) >> self.fract_width)

    def __neg__(self):
        """ Negate."""
        return self.from_bits(-self.bits, self.fract_width,
                              self.bit_width, self.signed)

    def __pos__(self):
        """ Unary Positive."""
        return self

    def __abs__(self):
        """ Absolute Value."""
        return self.from_bits(abs(self.bits), self.fract_width,
                              self.bit_width, self.signed)

    def __invert__(self):
        """ Inverse."""
        return self.from_bits(~self.bits, self.fract_width,
                              self.bit_width, self.signed)

    def _binary_op(self, rhs, operation, full=False):
        """ Handle binary arithmetic operators.

        Both operands are aligned to the larger ``fract_width``, and the
        result uses the larger integer width. ``operation`` is called as
        ``operation(lhs_bits, rhs_bits, fract_width)``. When ``full`` is
        True it is instead called with ``bit_width, signed`` appended, and
        its return value is passed through unchanged rather than wrapped in
        a new Fixed.

        :raises TypeError: if ``rhs`` is a Fixed of different signedness.
        :returns: ``NotImplemented`` if ``rhs`` is neither int nor Fixed.
        """
        if isinstance(rhs, int):
            rhs_fract_width = 0
            rhs_bits = rhs
            int_width = self.bit_width - self.fract_width
        elif isinstance(rhs, Fixed):
            if self.signed != rhs.signed:
                # raise (not return) so callers cannot mistake the error
                # object for a result
                raise TypeError("signedness must match")
            rhs_fract_width = rhs.fract_width
            rhs_bits = rhs.bits
            int_width = max(self.bit_width - self.fract_width,
                            rhs.bit_width - rhs.fract_width)
        else:
            return NotImplemented
        fract_width = max(self.fract_width, rhs_fract_width)
        rhs_bits <<= fract_width - rhs_fract_width
        lhs_bits = self.bits << (fract_width - self.fract_width)
        bit_width = int_width + fract_width
        if full:
            return operation(lhs_bits, rhs_bits,
                             fract_width, bit_width, self.signed)
        bits = operation(lhs_bits, rhs_bits,
                         fract_width)
        return self.from_bits(bits, fract_width, bit_width, self.signed)

    def __add__(self, rhs):
        """ Addition."""
        return self._binary_op(rhs, lambda lhs, rhs, fract_width: lhs + rhs)

    def __radd__(self, lhs):
        """ Reverse Addition."""
        return self.__add__(lhs)

    def __sub__(self, rhs):
        """ Subtraction."""
        return self._binary_op(rhs, lambda lhs, rhs, fract_width: lhs - rhs)

    def __rsub__(self, lhs):
        """ Reverse Subtraction."""
        # note swapped argument and parameter order
        return self._binary_op(lhs, lambda rhs, lhs, fract_width: lhs - rhs)

    def __and__(self, rhs):
        """ Bitwise And."""
        return self._binary_op(rhs, lambda lhs, rhs, fract_width: lhs & rhs)

    def __rand__(self, lhs):
        """ Reverse Bitwise And."""
        return self.__and__(lhs)

    def __or__(self, rhs):
        """ Bitwise Or."""
        return self._binary_op(rhs, lambda lhs, rhs, fract_width: lhs | rhs)

    def __ror__(self, lhs):
        """ Reverse Bitwise Or."""
        return self.__or__(lhs)

    def __xor__(self, rhs):
        """ Bitwise Xor."""
        return self._binary_op(rhs, lambda lhs, rhs, fract_width: lhs ^ rhs)

    def __rxor__(self, lhs):
        """ Reverse Bitwise Xor."""
        return self.__xor__(lhs)

    def __mul__(self, rhs):
        """ Multiplication.

        The result's fract_width is the sum of both fract_widths. Its integer
        width is the sum of both integer widths for a Fixed ``rhs``, or
        ``self``'s integer width for an int ``rhs``.

        :raises TypeError: if ``rhs`` is a Fixed of different signedness.
        """
        if isinstance(rhs, int):
            rhs_fract_width = 0
            rhs_bits = rhs
            int_width = self.bit_width - self.fract_width
        elif isinstance(rhs, Fixed):
            if self.signed != rhs.signed:
                raise TypeError("signedness must match")
            rhs_fract_width = rhs.fract_width
            rhs_bits = rhs.bits
            int_width = (self.bit_width - self.fract_width
                         + rhs.bit_width - rhs.fract_width)
        else:
            return NotImplemented
        fract_width = self.fract_width + rhs_fract_width
        bit_width = int_width + fract_width
        bits = self.bits * rhs_bits
        return self.from_bits(bits, fract_width, bit_width, self.signed)

    def __rmul__(self, rhs):
        """ Reverse Multiplication. """
        return self.__mul__(rhs)

    @staticmethod
    def _cmp_impl(lhs, rhs, fract_width, bit_width, signed):
        if lhs < rhs:
            return -1
        elif lhs == rhs:
            return 0
        return 1

    def cmp(self, rhs):
        """ Compare self with rhs.

        :returns int: returns -1 if self is less than rhs, 0 if they're equal,
            and 1 for greater than.
            Returns NotImplemented for unimplemented cases
        :raises TypeError: if ``rhs`` is a Fixed of different signedness.
        """
        return self._binary_op(rhs, self._cmp_impl, full=True)

    def _rich_compare(self, rhs, predicate):
        """ Shared body of the rich-comparison operators.

        Propagates ``NotImplemented`` from ``cmp`` instead of comparing it
        against 0, so Python can try the reflected operation (or fall back
        to its default) for unsupported operand types.
        """
        result = self.cmp(rhs)
        if result is NotImplemented:
            return NotImplemented
        return predicate(result)

    def __lt__(self, rhs):
        """ Less Than."""
        return self._rich_compare(rhs, lambda c: c < 0)

    def __le__(self, rhs):
        """ Less Than or Equal."""
        return self._rich_compare(rhs, lambda c: c <= 0)

    def __eq__(self, rhs):
        """ Equal."""
        return self._rich_compare(rhs, lambda c: c == 0)

    def __ne__(self, rhs):
        """ Not Equal."""
        return self._rich_compare(rhs, lambda c: c != 0)

    def __gt__(self, rhs):
        """ Greater Than."""
        return self._rich_compare(rhs, lambda c: c > 0)

    def __ge__(self, rhs):
        """ Greater Than or Equal."""
        return self._rich_compare(rhs, lambda c: c >= 0)

    def __bool__(self):
        """ Convert to bool."""
        return bool(self.bits)

    def __str__(self):
        """ Get text representation."""
        # don't just use self.__float__() in order to work with numbers more
        # than 53 bits wide
        retval = "fixed:"
        bits = self.bits
        if bits < 0:
            retval += "-"
            bits = -bits
        int_part = bits >> self.fract_width
        fract_part = bits & ~(-1 << self.fract_width)
        # round up fract_width to nearest multiple of 4
        fract_width = (self.fract_width + 3) & ~3
        fract_part <<= (fract_width - self.fract_width)
        fract_width_in_hex_digits = fract_width // 4
        retval += f"0x{int_part:x}."
        if fract_width_in_hex_digits != 0:
            retval += f"{fract_part:x}".zfill(fract_width_in_hex_digits)
        return retval
436
437
class RootRemainder:
    """ Result pair of a root-finding operation.

    :attribute root: the polynomial root.
    :attribute remainder: the remainder.
    """

    def __init__(self, root, remainder):
        """ Store a root together with its remainder.

        :param root: the polynomial root.
        :param remainder: the remainder.
        """
        self.root, self.remainder = root, remainder

    def __repr__(self):
        """ Return the ``repr`` of both fields wrapped in the class name. """
        return f"RootRemainder({self.root!r}, {self.remainder!r})"

    def __str__(self):
        """ Return the ``str`` of both fields wrapped in the class name. """
        return f"RootRemainder({self.root!s}, {self.remainder!s})"
461
462
def fixed_sqrt(radicand):
    """ Compute the Square Root and Remainder.

    Solves the polynomial ``radicand - x * x == 0``. The root is built one
    bit at a time from the most significant bit down, keeping each bit only
    if the square of the candidate does not exceed ``radicand``.

    :param radicand: the ``Fixed`` (or ``int``) to take the square root of.
    :returns RootRemainder: the root and ``radicand - root * root``. Both
        are converted with ``int()`` when ``radicand`` was an ``int``.
        ``None`` if ``radicand`` is negative.
    :raises TypeError: if ``radicand`` is neither an ``int`` nor a ``Fixed``.
    """
    # Written for correctness, not speed
    if radicand < 0:
        return None
    was_int = isinstance(radicand, int)
    if was_int:
        # one extra bit so the value is non-negative as a signed Fixed
        radicand = Fixed(radicand, 0, radicand.bit_length() + 1, True)
    elif not isinstance(radicand, Fixed):
        raise TypeError()

    root = radicand.with_bits(0)
    for bit_index in range(root.bit_width - 1, -1, -1):
        candidate = root.with_bits(root.bits | (1 << bit_index))
        # a candidate with the sign bit set is negative: never accept it
        if candidate >= 0 and radicand >= candidate * candidate:
            root = candidate
    remainder = radicand - root * root
    if was_int:
        return RootRemainder(int(root), int(remainder))
    return RootRemainder(root, remainder)
495
496
class FixedSqrt:
    """ Fixed-point Square-Root/Remainder.

    :attribute radicand: the radicand
    :attribute root: the square root
    :attribute root_squared: the square of ``root``
    :attribute remainder: the remainder, kept equal to
        ``radicand - root_squared`` after construction and after every stage
    :attribute log2_radix: the base-2 log of the operation radix. The number of
        bits of root that are calculated per pipeline stage.
    :attribute current_shift: the current bit index
    """

    def __init__(self, radicand, log2_radix=3):
        """ Create an FixedSqrt.

        :param radicand: the radicand. Must be an unsigned ``Fixed``.
        :param log2_radix: the base-2 log of the operation radix. The number of
            bits of root that are calculated per pipeline stage.
        """
        assert isinstance(radicand, Fixed)
        assert radicand.signed is False
        self.radicand = radicand
        self.root = radicand.with_bits(0)
        self.root_squared = self.root * self.root
        # With root == 0 the remainder is the whole radicand; this has the
        # same width/fract_width as the previous ``radicand.with_bits(0) -
        # root_squared``, which wrongly evaluated to 0.
        self.remainder = radicand - self.root_squared
        self.log2_radix = log2_radix
        self.current_shift = self.root.bit_width

    def calculate_stage(self):
        """ Calculate the next pipeline stage of the operation.

        Picks the largest next ``log2_radix``-bit digit ``t`` of the root
        such that ``(root + t) ** 2 <= radicand``, using
        ``(root + t) ** 2 == root_squared + 2 * root * t + t * t``.

        :returns bool: True if this is the last pipeline stage.
        """
        if self.current_shift == 0:
            return True
        log2_radix = min(self.log2_radix, self.current_shift)
        assert log2_radix > 0
        self.current_shift -= log2_radix
        radix = 1 << log2_radix
        trial_squares = []
        for i in range(radix):
            v = self.root_squared
            # 2 * t, where t == i << current_shift
            factor1 = Fixed.from_bits(i << (self.current_shift + 1),
                                      self.root.fract_width,
                                      self.root.bit_width + 1 + log2_radix,
                                      False)
            v += self.root * factor1
            # t
            factor2 = Fixed.from_bits(i << self.current_shift,
                                      self.root.fract_width,
                                      self.root.bit_width + log2_radix,
                                      False)
            v += factor2 * factor2
            trial_squares.append(self.root_squared.with_value(v))
        root_bits = 0
        new_root_squared = self.root_squared
        for i in range(radix):
            if self.radicand >= trial_squares[i]:
                root_bits = i
                new_root_squared = trial_squares[i]
        self.root |= Fixed.from_bits(root_bits << self.current_shift,
                                     self.root.fract_width,
                                     self.root.bit_width + log2_radix,
                                     False)
        self.root_squared = new_root_squared
        self.remainder = self.radicand - self.root_squared
        return self.current_shift == 0

    def calculate(self):
        """ Calculate the results of the square root.

        :returns: self
        """
        while not self.calculate_stage():
            pass
        return self
574
575
def fixed_rsqrt(radicand):
    """ Compute the Reciprocal Square Root and Remainder.

    Solves the polynomial ``1 - x * x * radicand == 0``. The root is built
    one bit at a time from the most significant bit down, keeping each bit
    only if ``root * root * radicand`` stays at or below 1.

    :param radicand: the ``Fixed`` to take the reciprocal square root of.
    :returns RootRemainder: the root and ``1 - root * root * radicand``, or
        ``None`` if ``radicand`` is not positive.
    :raises TypeError: if ``radicand`` is not a ``Fixed``.
    """
    # Written for correctness, not speed
    if radicand <= 0:
        return None
    if not isinstance(radicand, Fixed):
        raise TypeError()

    root = radicand.with_bits(0)
    for bit_index in range(root.bit_width - 1, -1, -1):
        candidate = root.with_bits(root.bits | (1 << bit_index))
        # a candidate with the sign bit set is negative: never accept it
        if candidate >= 0 and 1 >= candidate * candidate * radicand:
            root = candidate
    return RootRemainder(root, 1 - root * root * radicand)
602
603
class FixedRSqrt:
    """ Fixed-point Reciprocal-Square-Root/Remainder.

    :attribute radicand: the radicand
    :attribute root: the reciprocal square root
    :attribute radicand_root: ``radicand * root``
    :attribute radicand_root_squared: ``radicand * root * root``
    :attribute remainder: the remainder, kept equal to
        ``1 - radicand_root_squared`` after construction and after every stage
    :attribute log2_radix: the base-2 log of the operation radix. The number of
        bits of root that are calculated per pipeline stage.
    :attribute current_shift: the current bit index
    """

    def __init__(self, radicand, log2_radix=3):
        """ Create an FixedRSqrt.

        :param radicand: the radicand. Must be an unsigned ``Fixed``.
        :param log2_radix: the base-2 log of the operation radix. The number of
            bits of root that are calculated per pipeline stage.
        """
        assert isinstance(radicand, Fixed)
        assert radicand.signed is False
        self.radicand = radicand
        self.root = radicand.with_bits(0)
        self.radicand_root = radicand.with_bits(0) * self.root
        self.radicand_root_squared = self.radicand_root * self.root
        # Same expression as the final remainder in ``calculate_stage``;
        # previously this was ``radicand.with_bits(0) - ...`` which is 0
        # rather than the remainder 1 for root == 0.
        self.remainder = 1 - self.radicand_root_squared
        self.log2_radix = log2_radix
        self.current_shift = self.root.bit_width

    def calculate_stage(self):
        """ Calculate the next pipeline stage of the operation.

        Picks the largest next ``log2_radix``-bit digit ``t`` of the root
        such that ``radicand * (root + t) ** 2 <= 1``, using
        ``radicand * (root + t) ** 2 == radicand_root_squared
        + 2 * radicand_root * t + radicand * t * t``.

        :returns bool: True if this is the last pipeline stage.
        """
        if self.current_shift == 0:
            return True
        log2_radix = min(self.log2_radix, self.current_shift)
        assert log2_radix > 0
        self.current_shift -= log2_radix
        radix = 1 << log2_radix
        trial_values = []
        for i in range(radix):
            v = self.radicand_root_squared
            # 2 * t, where t == i << current_shift
            factor1 = Fixed.from_bits(i << (self.current_shift + 1),
                                      self.root.fract_width,
                                      self.root.bit_width + 1 + log2_radix,
                                      False)
            v += self.radicand_root * factor1
            # t
            factor2 = Fixed.from_bits(i << self.current_shift,
                                      self.root.fract_width,
                                      self.root.bit_width + log2_radix,
                                      False)
            v += self.radicand * factor2 * factor2
            trial_values.append(self.radicand_root_squared.with_value(v))
        root_bits = 0
        new_radicand_root_squared = self.radicand_root_squared
        for i in range(radix):
            if 1 >= trial_values[i]:
                root_bits = i
                new_radicand_root_squared = trial_values[i]
        # radicand_root += radicand * (chosen digit)
        v = self.radicand_root
        v += self.radicand * Fixed.from_bits(root_bits << self.current_shift,
                                             self.root.fract_width,
                                             self.root.bit_width + log2_radix,
                                             False)
        self.radicand_root = self.radicand_root.with_value(v)
        self.root |= Fixed.from_bits(root_bits << self.current_shift,
                                     self.root.fract_width,
                                     self.root.bit_width + log2_radix,
                                     False)
        self.radicand_root_squared = new_radicand_root_squared
        self.remainder = 1 - self.radicand_root_squared
        return self.current_shift == 0

    def calculate(self):
        """ Calculate the results of the reciprocal square root.

        :returns: self
        """
        while not self.calculate_stage():
            pass
        return self
689
690
class Operation(enum.Enum):
    """ Selects what ``FixedUDivRemSqrtRSqrt`` computes.

    Each member's value is a descriptive string naming the operation.
    """

    # quotient and remainder of an unsigned division
    UDivRem = "unsigned-divide/remainder"
    # square root and its remainder
    SqrtRem = "square-root/remainder"
    # reciprocal square root and its remainder
    RSqrtRem = "reciprocal-square-root/remainder"
697
698
class FixedUDivRemSqrtRSqrt:
    """ Combined class for computing fixed-point unsigned div/rem/sqrt/rsqrt.

    Algorithm based on ``UnsignedDivRem``, ``FixedSqrt``, and ``FixedRSqrt``.

    Formulas solved are:
    * div/rem:
        ``dividend == quotient_root * divisor_radicand``
    * sqrt/rem:
        ``divisor_radicand == quotient_root * quotient_root``
    * rsqrt/rem:
        ``1 == quotient_root * quotient_root * divisor_radicand``

    The remainder is the left-hand-side of the comparison minus the
    right-hand-side of the comparison in the above formulas.

    All values are plain ints holding the raw bits; the radix point of each
    is implied by its fract-width.

    Important: not all variables have the same bit-width or fract-width. For
    instance, ``dividend`` has a bit-width of ``bit_width + fract_width``
    and a fract-width of ``2 * fract_width`` bits.

    :attribute dividend: dividend for div/rem. Variable with a bit-width of
        ``bit_width + fract_width`` and a fract-width of ``fract_width * 2``
        bits.
    :attribute divisor_radicand: divisor for div/rem and radicand for
        sqrt/rsqrt. Variable with a bit-width of ``bit_width`` and a
        fract-width of ``fract_width`` bits.
    :attribute operation: the ``Operation`` to be computed.
    :attribute quotient_root: the quotient or root part of the result of the
        operation. Variable with a bit-width of ``bit_width`` and a fract-width
        of ``fract_width`` bits.
    :attribute remainder: the remainder part of the result of the operation.
        Variable with a bit-width of ``bit_width * 3`` and a fract-width
        of ``fract_width * 3`` bits.
    :attribute root_times_radicand: ``quotient_root * divisor_radicand``.
        Variable with a bit-width of ``bit_width * 2`` and a fract-width of
        ``fract_width * 2`` bits.
    :attribute compare_lhs: The left-hand-side of the comparison in the
        equation to be solved. Variable with a bit-width of ``bit_width * 3``
        and a fract-width of ``fract_width * 3`` bits.
    :attribute compare_rhs: The right-hand-side of the comparison in the
        equation to be solved. Variable with a bit-width of ``bit_width * 3``
        and a fract-width of ``fract_width * 3`` bits.
    :attribute bit_width: base bit-width. Constant int.
    :attribute fract_width: base fract-width. Specifies location of base-2
        radix point. Constant int.
    :attribute log2_radix: number of bits of ``quotient_root`` that should be
        computed per pipeline stage (invocation of ``calculate_stage``).
        Constant int.
    :attribute current_shift: the current bit index. Variable int.
    """

    def __init__(self,
                 dividend,
                 divisor_radicand,
                 operation,
                 bit_width,
                 fract_width,
                 log2_radix):
        """ Create a new ``FixedUDivRemSqrtRSqrt``.

        :param dividend: ``dividend`` attribute's initializer.
        :param divisor_radicand: ``divisor_radicand`` attribute's initializer.
        :param operation: ``operation`` attribute's initializer.
        :param bit_width: ``bit_width`` attribute's initializer.
        :param fract_width: ``fract_width`` attribute's initializer.
        :param log2_radix: ``log2_radix`` attribute's initializer.
        """
        assert bit_width > 0
        assert fract_width >= 0
        assert fract_width <= bit_width
        assert log2_radix > 0
        self.dividend = Const.normalize(dividend,
                                        (bit_width + fract_width, False))
        self.divisor_radicand = Const.normalize(divisor_radicand,
                                                (bit_width, False))
        self.quotient_root = 0
        self.root_times_radicand = 0
        # every left-hand side is aligned to a fract-width of 3 * fract_width
        if operation is Operation.UDivRem:
            lhs = self.dividend << fract_width
        elif operation is Operation.SqrtRem:
            lhs = self.divisor_radicand << (fract_width * 2)
        else:
            assert operation is Operation.RSqrtRem
            lhs = 1 << (fract_width * 3)
        self.compare_lhs = lhs
        self.compare_rhs = 0
        # nothing has been subtracted yet
        self.remainder = lhs
        self.operation = operation
        self.bit_width = bit_width
        self.fract_width = fract_width
        self.log2_radix = log2_radix
        self.current_shift = bit_width

    def _trial_compare_rhs(self, shifted_trial_bits):
        """ Return ``compare_rhs`` as it would become if ``shifted_trial_bits``
        were added into ``quotient_root`` (those bits are still zero there).
        """
        rhs = self.compare_rhs
        if self.operation is Operation.UDivRem:
            # divisor * (q + t) == divisor * q + divisor * t
            product = self.divisor_radicand * shifted_trial_bits
            return rhs + (product << self.fract_width)
        if self.operation is Operation.SqrtRem:
            # (q + t) ** 2 == q ** 2 + 2 * q * t + t ** 2
            cross = self.quotient_root * (shifted_trial_bits << 1)
            square = shifted_trial_bits * shifted_trial_bits
            return (rhs + (cross << self.fract_width)
                    + (square << self.fract_width))
        assert self.operation is Operation.RSqrtRem
        # r * (q + t) ** 2 == r * q ** 2 + 2 * (r * q) * t + r * t ** 2
        cross = self.root_times_radicand * (shifted_trial_bits << 1)
        square = self.divisor_radicand * (shifted_trial_bits
                                          * shifted_trial_bits)
        return rhs + cross + square

    def calculate_stage(self):
        """ Calculate the next pipeline stage of the operation.

        :returns bool: True if this is the last pipeline stage.
        """
        if self.current_shift == 0:
            return True
        log2_radix = min(self.log2_radix, self.current_shift)
        assert log2_radix > 0
        self.current_shift -= log2_radix
        shift = self.current_shift
        # all trials are evaluated against the state before this stage
        trials = [self._trial_compare_rhs(digit << shift)
                  for digit in range(1 << log2_radix)]
        shifted_next_bits = 0
        next_compare_rhs = trials[0]
        # the largest digit whose right-hand side does not exceed the lhs
        for digit in reversed(range(len(trials))):
            if self.compare_lhs >= trials[digit]:
                shifted_next_bits = digit << shift
                next_compare_rhs = trials[digit]
                break
        self.root_times_radicand += self.divisor_radicand * shifted_next_bits
        self.compare_rhs = next_compare_rhs
        self.quotient_root |= shifted_next_bits
        self.remainder = self.compare_lhs - self.compare_rhs
        return shift == 0

    def calculate(self):
        """ Run every remaining pipeline stage.

        :returns: self
        """
        while not self.calculate_stage():
            pass
        return self