switch algorithm in UnsignedDivRem to match FixedUDivRemSqrtRSqrt
authorJacob Lifshay <programmerjake@gmail.com>
Sun, 7 Jul 2019 07:19:12 +0000 (00:19 -0700)
committerJacob Lifshay <programmerjake@gmail.com>
Sun, 7 Jul 2019 07:20:01 +0000 (00:20 -0700)
src/ieee754/div_rem_sqrt_rsqrt/algorithm.py
src/ieee754/div_rem_sqrt_rsqrt/test_algorithm.py

index 6ba28311401d0165d0eaa890531b827b506deea3..84ea1d4c78965778529854f2f35bed9cb344b0a0 100644 (file)
@@ -38,12 +38,14 @@ class UnsignedDivRem:
 
     NOT the same as the // or % operators
 
-    :attribute remainder: the remainder and/or dividend
+    :attribute dividend: the dividend
+    :attribute remainder: the remainder
     :attribute divisor: the divisor
     :attribute bit_width: the bit width of the inputs/outputs
     :attribute log2_radix: the base-2 log of the division radix. The number of
         bits of quotient that are calculated per pipeline stage.
     :attribute quotient: the quotient
+    :attribute quotient_times_divisor: ``quotient * divisor``
     :attribute current_shift: the current bit index
     """
 
@@ -56,11 +58,12 @@ class UnsignedDivRem:
         :param log2_radix: the base-2 log of the division radix. The number of
             bits of quotient that are calculated per pipeline stage.
         """
-        self.remainder = Const.normalize(dividend, (bit_width, False))
+        self.dividend = Const.normalize(dividend, (bit_width, False))
         self.divisor = Const.normalize(divisor, (bit_width, False))
         self.bit_width = bit_width
         self.log2_radix = log2_radix
         self.quotient = 0
+        self.quotient_times_divisor = self.quotient * self.divisor
         self.current_shift = bit_width
 
     def calculate_stage(self):
@@ -74,17 +77,23 @@ class UnsignedDivRem:
         assert log2_radix > 0
         self.current_shift -= log2_radix
         radix = 1 << log2_radix
-        remainders = []
+        trial_values = []
         for i in range(radix):
-            v = (self.divisor * i) << self.current_shift
-            remainders.append(self.remainder - v)
+            v = self.quotient_times_divisor
+            v += (self.divisor * i) << self.current_shift
+            trial_values.append(v)
         quotient_bits = 0
+        next_product = self.quotient_times_divisor
         for i in range(radix):
-            if remainders[i] >= 0:
+            if self.dividend >= trial_values[i]:
                 quotient_bits = i
-        self.remainder = remainders[quotient_bits]
+                next_product = trial_values[i]
+        self.quotient_times_divisor = next_product
         self.quotient |= quotient_bits << self.current_shift
-        return self.current_shift == 0
+        if self.current_shift == 0:
+            self.remainder = self.dividend - self.quotient_times_divisor
+            return True
+        return False
 
     def calculate(self):
         """ Calculate the results of the division.
index c5c3e7b3dba11b5741ed06f38be69aa8586df06d..7d6b201350cd95f3f74ae20b0d16864c975880e1 100644 (file)
@@ -296,14 +296,22 @@ class TestUnsignedDivRem(unittest.TestCase):
                 with self.subTest(n=n, d=d, q=q, r=r):
                     udr = UnsignedDivRem(n, d, bit_width, log2_radix)
                     for _ in range(250 * bit_width):
-                        self.assertEqual(n, udr.quotient * udr.divisor
-                                         + udr.remainder)
+                        self.assertEqual(udr.dividend, n)
+                        self.assertEqual(udr.divisor, d)
+                        self.assertEqual(udr.quotient_times_divisor,
+                                         udr.quotient * udr.divisor)
+                        self.assertGreaterEqual(udr.dividend,
+                                                udr.quotient_times_divisor)
                         if udr.calculate_stage():
                             break
                     else:
                         self.fail("infinite loop")
-                    self.assertEqual(n, udr.quotient * udr.divisor
-                                     + udr.remainder)
+                    self.assertEqual(udr.dividend, n)
+                    self.assertEqual(udr.divisor, d)
+                    self.assertEqual(udr.quotient_times_divisor,
+                                     udr.quotient * udr.divisor)
+                    self.assertGreaterEqual(udr.dividend,
+                                            udr.quotient_times_divisor)
                     self.assertEqual(udr.quotient, q)
                     self.assertEqual(udr.remainder, r)