integer division algorithm works
authorJacob Lifshay <programmerjake@gmail.com>
Sat, 29 Jun 2019 01:31:26 +0000 (18:31 -0700)
committerJacob Lifshay <programmerjake@gmail.com>
Sat, 29 Jun 2019 01:33:27 +0000 (18:33 -0700)
.gitignore
src/ieee754/div_rem_sqrt_rsqrt/__init__.py [new file with mode: 0644]
src/ieee754/div_rem_sqrt_rsqrt/algorithm.py [new file with mode: 0644]
src/ieee754/div_rem_sqrt_rsqrt/algorithm.pyi [new file with mode: 0644]
src/ieee754/div_rem_sqrt_rsqrt/test_algorithm.py [new file with mode: 0644]

index e77dcf4f47f850ae62256eb18b8b93415b52ac38..31fdbbe5574abf7346eea73e2ff0f0d1e0033800 100644 (file)
@@ -1,6 +1,9 @@
 *.vcd
 *.py?
+!*.pyi
 .*.sw?
 __pycache__
 *.v
 *.il
+.eggs
+*.egg-info
diff --git a/src/ieee754/div_rem_sqrt_rsqrt/__init__.py b/src/ieee754/div_rem_sqrt_rsqrt/__init__.py
new file mode 100644 (file)
index 0000000..e69de29
diff --git a/src/ieee754/div_rem_sqrt_rsqrt/algorithm.py b/src/ieee754/div_rem_sqrt_rsqrt/algorithm.py
new file mode 100644 (file)
index 0000000..69de47d
--- /dev/null
@@ -0,0 +1,148 @@
+# SPDX-License-Identifier: LGPL-2.1-or-later
+# See Notices.txt for copyright information
+
+""" Algorithms for div/rem/sqrt/rsqrt.
+
+code for simulating/testing the various algorithms
+"""
+
+from nmigen.hdl.ast import Const
+
+
+def div_rem(dividend, divisor, bit_width, signed):
+    """ Compute the quotient/remainder following the RISC-V M extension.
+
+    NOT the same as the // or % operators
+    """
+    dividend = Const.normalize(dividend, (bit_width, signed))
+    divisor = Const.normalize(divisor, (bit_width, signed))
+    if divisor == 0:
+        quotient = -1
+        remainder = dividend
+    else:
+        quotient = abs(dividend) // abs(divisor)
+        remainder = abs(dividend) % abs(divisor)
+        if (dividend < 0) != (divisor < 0):
+            quotient = -quotient
+        if dividend < 0:
+            remainder = -remainder
+    quotient = Const.normalize(quotient, (bit_width, signed))
+    remainder = Const.normalize(remainder, (bit_width, signed))
+    return quotient, remainder
+
+
+class UnsignedDivRem:
+    """ Unsigned integer division/remainder following the RISC-V M extension.
+
+    NOT the same as the // or % operators
+
+    :attribute remainder: the remainder and/or dividend
+    :attribute divisor: the divisor
+    :attribute bit_width: the bit width of the inputs/outputs
+    :attribute log2_radix: the base-2 log of the division radix. The number of
+        bits of quotient that are calculated per pipeline stage.
+    :attribute quotient: the quotient
+    :attribute current_shift: the current bit index
+    """
+
+    def __init__(self, dividend, divisor, bit_width, log2_radix=3):
+        """ Create an UnsignedDivRem.
+
+        :param dividend: the dividend/numerator
+        :param divisor: the divisor/denominator
+        :param bit_width: the bit width of the inputs/outputs
+        :param log2_radix: the base-2 log of the division radix. The number of
+            bits of quotient that are calculated per pipeline stage.
+        """
+        self.remainder = Const.normalize(dividend, (bit_width, False))
+        self.divisor = Const.normalize(divisor, (bit_width, False))
+        self.bit_width = bit_width
+        self.log2_radix = log2_radix
+        self.quotient = 0
+        self.current_shift = bit_width
+
+    def calculate_stage(self):
+        """ Calculate the next pipeline stage of the division.
+
+        :returns bool: True if this is the last pipeline stage.
+        """
+        if self.current_shift == 0:
+            return True
+        log2_radix = min(self.log2_radix, self.current_shift)
+        assert log2_radix > 0
+        self.current_shift -= log2_radix
+        radix = 1 << log2_radix
+        remainders = []
+        for i in range(radix):
+            v = (self.divisor * i) << self.current_shift
+            remainders.append(self.remainder - v)
+        quotient_bits = 0
+        for i in range(radix):
+            if remainders[i] >= 0:
+                quotient_bits = i
+        self.remainder = remainders[quotient_bits]
+        self.quotient |= quotient_bits << self.current_shift
+        return self.current_shift == 0
+
+    def calculate(self):
+        """ Calculate the results of the division.
+
+        :returns: self
+        """
+        while not self.calculate_stage():
+            pass
+        return self
+
+
+class DivRem:
+    """ integer division/remainder following the RISC-V M extension.
+
+    NOT the same as the // or % operators
+
+    :attribute dividend: the dividend
+    :attribute divisor: the divisor
+    :attribute signed: if the inputs/outputs are signed instead of unsigned
+    :attribute quotient: the quotient
+    :attribute remainder: the remainder
+    :attribute divider: the base UnsignedDivRem
+    """
+
+    def __init__(self, dividend, divisor, bit_width, signed, log2_radix=3):
+        """ Create a DivRem.
+
+        :param dividend: the dividend/numerator
+        :param divisor: the divisor/denominator
+        :param bit_width: the bit width of the inputs/outputs
+        :param signed: if the inputs/outputs are signed instead of unsigned
+        :param log2_radix: the base-2 log of the division radix. The number of
+            bits of quotient that are calculated per pipeline stage.
+        """
+        self.dividend = Const.normalize(dividend, (bit_width, signed))
+        self.divisor = Const.normalize(divisor, (bit_width, signed))
+        self.signed = signed
+        self.quotient = 0
+        self.remainder = 0
+        self.divider = UnsignedDivRem(abs(dividend), abs(divisor),
+                                      bit_width, log2_radix)
+
+    def calculate_stage(self):
+        """ Calculate the next pipeline stage of the division.
+
+        :returns bool: True if this is the last pipeline stage.
+        """
+        if not self.divider.calculate_stage():
+            return False
+        divisor_sign = self.divisor < 0
+        dividend_sign = self.dividend < 0
+        if self.divisor != 0 and divisor_sign != dividend_sign:
+            quotient = -self.divider.quotient
+        else:
+            quotient = self.divider.quotient
+        if dividend_sign:
+            remainder = -self.divider.remainder
+        else:
+            remainder = self.divider.remainder
+        bit_width = self.divider.bit_width
+        self.quotient = Const.normalize(quotient, (bit_width, self.signed))
+        self.remainder = Const.normalize(remainder, (bit_width, self.signed))
+        return True
diff --git a/src/ieee754/div_rem_sqrt_rsqrt/algorithm.pyi b/src/ieee754/div_rem_sqrt_rsqrt/algorithm.pyi
new file mode 100644 (file)
index 0000000..d7bd43e
--- /dev/null
@@ -0,0 +1,33 @@
+# SPDX-License-Identifier: LGPL-2.1-or-later
+# See Notices.txt for copyright information
+
+from typing import Tuple
+
+
+def div_rem(dividend: int,
+            divisor: int,
+            bit_width: int,
+            signed: int) -> Tuple[int, int]:
+    ...
+
+
+class UnsignedDivRem:
+    remainder: int
+    divisor: int
+    bit_width: int
+    log2_radix: int
+    quotient: int
+    current_shift: int
+
+    def __init__(self,
+                 dividend: int,
+                 divisor: int,
+                 bit_width: int,
+                 log2_radix: int = 3):
+        ...
+
+    def calculate_stage(self) -> bool:
+        ...
+
+    def calculate(self) -> 'UnsignedDivRem':
+        ...
diff --git a/src/ieee754/div_rem_sqrt_rsqrt/test_algorithm.py b/src/ieee754/div_rem_sqrt_rsqrt/test_algorithm.py
new file mode 100644 (file)
index 0000000..ea14c96
--- /dev/null
@@ -0,0 +1,348 @@
+# SPDX-License-Identifier: LGPL-2.1-or-later
+# See Notices.txt for copyright information
+
+from nmigen.hdl.ast import Const
+from .algorithm import div_rem, UnsignedDivRem, DivRem
+import unittest
+
+
+class TestDivRemFn(unittest.TestCase):
+    def test_signed(self):
+        test_cases = [
+            # numerator, denominator, quotient, remainder
+            (-8, -8, 1, 0),
+            (-7, -8, 0, -7),
+            (-6, -8, 0, -6),
+            (-5, -8, 0, -5),
+            (-4, -8, 0, -4),
+            (-3, -8, 0, -3),
+            (-2, -8, 0, -2),
+            (-1, -8, 0, -1),
+            (0, -8, 0, 0),
+            (1, -8, 0, 1),
+            (2, -8, 0, 2),
+            (3, -8, 0, 3),
+            (4, -8, 0, 4),
+            (5, -8, 0, 5),
+            (6, -8, 0, 6),
+            (7, -8, 0, 7),
+            (-8, -7, 1, -1),
+            (-7, -7, 1, 0),
+            (-6, -7, 0, -6),
+            (-5, -7, 0, -5),
+            (-4, -7, 0, -4),
+            (-3, -7, 0, -3),
+            (-2, -7, 0, -2),
+            (-1, -7, 0, -1),
+            (0, -7, 0, 0),
+            (1, -7, 0, 1),
+            (2, -7, 0, 2),
+            (3, -7, 0, 3),
+            (4, -7, 0, 4),
+            (5, -7, 0, 5),
+            (6, -7, 0, 6),
+            (7, -7, -1, 0),
+            (-8, -6, 1, -2),
+            (-7, -6, 1, -1),
+            (-6, -6, 1, 0),
+            (-5, -6, 0, -5),
+            (-4, -6, 0, -4),
+            (-3, -6, 0, -3),
+            (-2, -6, 0, -2),
+            (-1, -6, 0, -1),
+            (0, -6, 0, 0),
+            (1, -6, 0, 1),
+            (2, -6, 0, 2),
+            (3, -6, 0, 3),
+            (4, -6, 0, 4),
+            (5, -6, 0, 5),
+            (6, -6, -1, 0),
+            (7, -6, -1, 1),
+            (-8, -5, 1, -3),
+            (-7, -5, 1, -2),
+            (-6, -5, 1, -1),
+            (-5, -5, 1, 0),
+            (-4, -5, 0, -4),
+            (-3, -5, 0, -3),
+            (-2, -5, 0, -2),
+            (-1, -5, 0, -1),
+            (0, -5, 0, 0),
+            (1, -5, 0, 1),
+            (2, -5, 0, 2),
+            (3, -5, 0, 3),
+            (4, -5, 0, 4),
+            (5, -5, -1, 0),
+            (6, -5, -1, 1),
+            (7, -5, -1, 2),
+            (-8, -4, 2, 0),
+            (-7, -4, 1, -3),
+            (-6, -4, 1, -2),
+            (-5, -4, 1, -1),
+            (-4, -4, 1, 0),
+            (-3, -4, 0, -3),
+            (-2, -4, 0, -2),
+            (-1, -4, 0, -1),
+            (0, -4, 0, 0),
+            (1, -4, 0, 1),
+            (2, -4, 0, 2),
+            (3, -4, 0, 3),
+            (4, -4, -1, 0),
+            (5, -4, -1, 1),
+            (6, -4, -1, 2),
+            (7, -4, -1, 3),
+            (-8, -3, 2, -2),
+            (-7, -3, 2, -1),
+            (-6, -3, 2, 0),
+            (-5, -3, 1, -2),
+            (-4, -3, 1, -1),
+            (-3, -3, 1, 0),
+            (-2, -3, 0, -2),
+            (-1, -3, 0, -1),
+            (0, -3, 0, 0),
+            (1, -3, 0, 1),
+            (2, -3, 0, 2),
+            (3, -3, -1, 0),
+            (4, -3, -1, 1),
+            (5, -3, -1, 2),
+            (6, -3, -2, 0),
+            (7, -3, -2, 1),
+            (-8, -2, 4, 0),
+            (-7, -2, 3, -1),
+            (-6, -2, 3, 0),
+            (-5, -2, 2, -1),
+            (-4, -2, 2, 0),
+            (-3, -2, 1, -1),
+            (-2, -2, 1, 0),
+            (-1, -2, 0, -1),
+            (0, -2, 0, 0),
+            (1, -2, 0, 1),
+            (2, -2, -1, 0),
+            (3, -2, -1, 1),
+            (4, -2, -2, 0),
+            (5, -2, -2, 1),
+            (6, -2, -3, 0),
+            (7, -2, -3, 1),
+            (-8, -1, -8, 0),  # overflows and wraps around
+            (-7, -1, 7, 0),
+            (-6, -1, 6, 0),
+            (-5, -1, 5, 0),
+            (-4, -1, 4, 0),
+            (-3, -1, 3, 0),
+            (-2, -1, 2, 0),
+            (-1, -1, 1, 0),
+            (0, -1, 0, 0),
+            (1, -1, -1, 0),
+            (2, -1, -2, 0),
+            (3, -1, -3, 0),
+            (4, -1, -4, 0),
+            (5, -1, -5, 0),
+            (6, -1, -6, 0),
+            (7, -1, -7, 0),
+            (-8, 0, -1, -8),
+            (-7, 0, -1, -7),
+            (-6, 0, -1, -6),
+            (-5, 0, -1, -5),
+            (-4, 0, -1, -4),
+            (-3, 0, -1, -3),
+            (-2, 0, -1, -2),
+            (-1, 0, -1, -1),
+            (0, 0, -1, 0),
+            (1, 0, -1, 1),
+            (2, 0, -1, 2),
+            (3, 0, -1, 3),
+            (4, 0, -1, 4),
+            (5, 0, -1, 5),
+            (6, 0, -1, 6),
+            (7, 0, -1, 7),
+            (-8, 1, -8, 0),
+            (-7, 1, -7, 0),
+            (-6, 1, -6, 0),
+            (-5, 1, -5, 0),
+            (-4, 1, -4, 0),
+            (-3, 1, -3, 0),
+            (-2, 1, -2, 0),
+            (-1, 1, -1, 0),
+            (0, 1, 0, 0),
+            (1, 1, 1, 0),
+            (2, 1, 2, 0),
+            (3, 1, 3, 0),
+            (4, 1, 4, 0),
+            (5, 1, 5, 0),
+            (6, 1, 6, 0),
+            (7, 1, 7, 0),
+            (-8, 2, -4, 0),
+            (-7, 2, -3, -1),
+            (-6, 2, -3, 0),
+            (-5, 2, -2, -1),
+            (-4, 2, -2, 0),
+            (-3, 2, -1, -1),
+            (-2, 2, -1, 0),
+            (-1, 2, 0, -1),
+            (0, 2, 0, 0),
+            (1, 2, 0, 1),
+            (2, 2, 1, 0),
+            (3, 2, 1, 1),
+            (4, 2, 2, 0),
+            (5, 2, 2, 1),
+            (6, 2, 3, 0),
+            (7, 2, 3, 1),
+            (-8, 3, -2, -2),
+            (-7, 3, -2, -1),
+            (-6, 3, -2, 0),
+            (-5, 3, -1, -2),
+            (-4, 3, -1, -1),
+            (-3, 3, -1, 0),
+            (-2, 3, 0, -2),
+            (-1, 3, 0, -1),
+            (0, 3, 0, 0),
+            (1, 3, 0, 1),
+            (2, 3, 0, 2),
+            (3, 3, 1, 0),
+            (4, 3, 1, 1),
+            (5, 3, 1, 2),
+            (6, 3, 2, 0),
+            (7, 3, 2, 1),
+            (-8, 4, -2, 0),
+            (-7, 4, -1, -3),
+            (-6, 4, -1, -2),
+            (-5, 4, -1, -1),
+            (-4, 4, -1, 0),
+            (-3, 4, 0, -3),
+            (-2, 4, 0, -2),
+            (-1, 4, 0, -1),
+            (0, 4, 0, 0),
+            (1, 4, 0, 1),
+            (2, 4, 0, 2),
+            (3, 4, 0, 3),
+            (4, 4, 1, 0),
+            (5, 4, 1, 1),
+            (6, 4, 1, 2),
+            (7, 4, 1, 3),
+            (-8, 5, -1, -3),
+            (-7, 5, -1, -2),
+            (-6, 5, -1, -1),
+            (-5, 5, -1, 0),
+            (-4, 5, 0, -4),
+            (-3, 5, 0, -3),
+            (-2, 5, 0, -2),
+            (-1, 5, 0, -1),
+            (0, 5, 0, 0),
+            (1, 5, 0, 1),
+            (2, 5, 0, 2),
+            (3, 5, 0, 3),
+            (4, 5, 0, 4),
+            (5, 5, 1, 0),
+            (6, 5, 1, 1),
+            (7, 5, 1, 2),
+            (-8, 6, -1, -2),
+            (-7, 6, -1, -1),
+            (-6, 6, -1, 0),
+            (-5, 6, 0, -5),
+            (-4, 6, 0, -4),
+            (-3, 6, 0, -3),
+            (-2, 6, 0, -2),
+            (-1, 6, 0, -1),
+            (0, 6, 0, 0),
+            (1, 6, 0, 1),
+            (2, 6, 0, 2),
+            (3, 6, 0, 3),
+            (4, 6, 0, 4),
+            (5, 6, 0, 5),
+            (6, 6, 1, 0),
+            (7, 6, 1, 1),
+            (-8, 7, -1, -1),
+            (-7, 7, -1, 0),
+            (-6, 7, 0, -6),
+            (-5, 7, 0, -5),
+            (-4, 7, 0, -4),
+            (-3, 7, 0, -3),
+            (-2, 7, 0, -2),
+            (-1, 7, 0, -1),
+            (0, 7, 0, 0),
+            (1, 7, 0, 1),
+            (2, 7, 0, 2),
+            (3, 7, 0, 3),
+            (4, 7, 0, 4),
+            (5, 7, 0, 5),
+            (6, 7, 0, 6),
+            (7, 7, 1, 0),
+        ]
+        for (n, d, q, r) in test_cases:
+            self.assertEqual(div_rem(n, d, 4, True), (q, r))
+
+    def test_unsigned(self):
+        for n in range(16):
+            for d in range(16):
+                if d == 0:
+                    q = 16 - 1
+                    r = n
+                else:
+                    # div_rem matches // and % for unsigned integers
+                    q = n // d
+                    r = n % d
+                self.assertEqual(div_rem(n, d, 4, False), (q, r))
+
+
+class TestUnsignedDivRem(unittest.TestCase):
+    def helper(self, log2_radix):
+        bit_width = 4
+        for n in range(1 << bit_width):
+            for d in range(1 << bit_width):
+                q, r = div_rem(n, d, bit_width, False)
+                with self.subTest(n=n, d=d, q=q, r=r):
+                    udr = UnsignedDivRem(n, d, bit_width, log2_radix)
+                    for _ in range(250 * bit_width):
+                        self.assertEqual(n, udr.quotient * udr.divisor
+                                         + udr.remainder)
+                        if udr.calculate_stage():
+                            break
+                    else:
+                        self.fail("infinite loop")
+                    self.assertEqual(n, udr.quotient * udr.divisor
+                                     + udr.remainder)
+                    self.assertEqual(udr.quotient, q)
+                    self.assertEqual(udr.remainder, r)
+
+    def test_radix_2(self):
+        self.helper(1)
+
+    def test_radix_4(self):
+        self.helper(2)
+
+    def test_radix_8(self):
+        self.helper(3)
+
+    def test_radix_16(self):
+        self.helper(4)
+
+
+class TestDivRem(unittest.TestCase):
+    def helper(self, log2_radix):
+        bit_width = 4
+        for n in range(1 << bit_width):
+            for d in range(1 << bit_width):
+                for signed in False, True:
+                    n = Const.normalize(n, (bit_width, signed))
+                    d = Const.normalize(d, (bit_width, signed))
+                    q, r = div_rem(n, d, bit_width, signed)
+                    with self.subTest(n=n, d=d, q=q, r=r, signed=signed):
+                        dr = DivRem(n, d, bit_width, signed, log2_radix)
+                        for _ in range(250 * bit_width):
+                            if dr.calculate_stage():
+                                break
+                        else:
+                            self.fail("infinite loop")
+                        self.assertEqual(dr.quotient, q)
+                        self.assertEqual(dr.remainder, r)
+
+    def test_radix_2(self):
+        self.helper(1)
+
+    def test_radix_4(self):
+        self.helper(2)
+
+    def test_radix_8(self):
+        self.helper(3)
+
+    def test_radix_16(self):
+        self.helper(4)