add WIP goldschmidt division algorithm
authorJacob Lifshay <programmerjake@gmail.com>
Fri, 22 Apr 2022 07:58:17 +0000 (00:58 -0700)
committerJacob Lifshay <programmerjake@gmail.com>
Fri, 22 Apr 2022 07:58:17 +0000 (00:58 -0700)
src/soc/fu/div/experiment/__init__.py [new file with mode: 0644]
src/soc/fu/div/experiment/goldschmidt_div_sqrt.py [new file with mode: 0644]
src/soc/fu/div/experiment/test/__test__.py [new file with mode: 0644]
src/soc/fu/div/experiment/test/test_goldschmidt_div_sqrt.py [new file with mode: 0644]

diff --git a/src/soc/fu/div/experiment/__init__.py b/src/soc/fu/div/experiment/__init__.py
new file mode 100644 (file)
index 0000000..e69de29
diff --git a/src/soc/fu/div/experiment/goldschmidt_div_sqrt.py b/src/soc/fu/div/experiment/goldschmidt_div_sqrt.py
new file mode 100644 (file)
index 0000000..12b9f81
--- /dev/null
@@ -0,0 +1,266 @@
+# SPDX-License-Identifier: LGPL-3-or-later
+# Copyright 2022 Jacob Lifshay programmerjake@gmail.com
+
+# Funded by NLnet Assure Programme 2021-02-052, https://nlnet.nl/assure part
+# of Horizon 2020 EU Programme 957073.
+
+from dataclasses import dataclass
+import math
+import enum
+
+
+@enum.unique
+class RoundDir(enum.Enum):
+    DOWN = enum.auto()
+    UP = enum.auto()
+    NEAREST_TIES_UP = enum.auto()
+    ERROR_IF_INEXACT = enum.auto()
+
+
+@dataclass(frozen=True)
+class FixedPoint:
+    bits: int
+    frac_wid: int
+
+    def __post_init__(self):
+        assert isinstance(self.bits, int)
+        assert isinstance(self.frac_wid, int) and self.frac_wid >= 0
+
+    @staticmethod
+    def cast(value):
+        """convert `value` to a fixed-point number with enough fractional
+        bits to preserve its value."""
+        if isinstance(value, FixedPoint):
+            return value
+        if isinstance(value, int):
+            return FixedPoint(value, 0)
+        if isinstance(value, str):
+            value = value.strip()
+            neg = value.startswith("-")
+            if neg or value.startswith("+"):
+                value = value[1:]
+            if value.startswith(("0x", "0X")) and "." in value:
+                value = value[2:]
+                got_dot = False
+                bits = 0
+                frac_wid = 0
+                for digit in value:
+                    if digit == "_":
+                        continue
+                    if got_dot:
+                        if digit == ".":
+                            raise ValueError("too many `.` in string")
+                        frac_wid += 4
+                    if digit == ".":
+                        got_dot = True
+                        continue
+                    if not digit.isalnum():
+                        raise ValueError("invalid hexadecimal digit")
+                    bits <<= 4
+                    bits |= int("0x" + digit, base=16)
+            else:
+                bits = int(value, base=0)
+                frac_wid = 0
+            if neg:
+                bits = -bits
+            return FixedPoint(bits, frac_wid)
+
+        if isinstance(value, float):
+            n, d = value.as_integer_ratio()
+            log2_d = d.bit_length() - 1
+            assert d == 1 << log2_d, ("d isn't a power of 2 -- won't ever "
+                                      "fail with float being IEEE 754")
+            return FixedPoint(n, log2_d)
+        raise TypeError("can't convert type to FixedPoint")
+
+    @staticmethod
+    def with_frac_wid(value, frac_wid, round_dir=RoundDir.ERROR_IF_INEXACT):
+        """convert `value` to the nearest fixed-point number with `frac_wid`
+        fractional bits, rounding according to `round_dir`."""
+        value = FixedPoint.cast(value)
+        assert isinstance(frac_wid, int) and frac_wid >= 0
+        assert isinstance(round_dir, RoundDir)
+        # compute number of bits that should be removed from value
+        del_bits = value.frac_wid - frac_wid
+        if del_bits == 0:
+            return value
+        if del_bits < 0:  # add bits
+            return FixedPoint(value.bits << -del_bits,
+                              frac_wid)
+        if round_dir == RoundDir.DOWN:
+            bits = value.bits >> del_bits
+        elif round_dir == RoundDir.UP:
+            bits = -((-value.bits) >> del_bits)
+        elif round_dir == RoundDir.NEAREST_TIES_UP:
+            bits = value.bits >> (del_bits - 1)
+            bits += 1
+            bits >>= 1
+        elif round_dir == RoundDir.ERROR_IF_INEXACT:
+            bits = value.bits >> del_bits
+            if bits << del_bits != value.bits:
+                raise ValueError("inexact conversion")
+        else:
+            assert False, "unimplemented round_dir"
+        return FixedPoint(bits, frac_wid)
+
+    def to_frac_wid(self, frac_wid, round_dir=RoundDir.ERROR_IF_INEXACT):
+        """convert to the nearest fixed-point number with `frac_wid`
+        fractional bits, rounding according to `round_dir`."""
+        return FixedPoint.with_frac_wid(self, frac_wid, round_dir)
+
+    def __float__(self):
+        return self.bits * 2.0 ** -self.frac_wid
+
+    def cmp(self, rhs):
+        """compare self with rhs, returning a positive integer if self is
+        greater than rhs, zero if self is equal to rhs, and a negative integer
+        if self is less than rhs."""
+        rhs = FixedPoint.cast(rhs)
+        common_frac_wid = max(self.frac_wid, rhs.frac_wid)
+        lhs = self.to_frac_wid(common_frac_wid)
+        rhs = rhs.to_frac_wid(common_frac_wid)
+        return lhs.bits - rhs.bits
+
+    def __eq__(self, rhs):
+        return self.cmp(rhs) == 0
+
+    def __ne__(self, rhs):
+        return self.cmp(rhs) != 0
+
+    def __gt__(self, rhs):
+        return self.cmp(rhs) > 0
+
+    def __lt__(self, rhs):
+        return self.cmp(rhs) < 0
+
+    def __ge__(self, rhs):
+        return self.cmp(rhs) >= 0
+
+    def __le__(self, rhs):
+        return self.cmp(rhs) <= 0
+
+    def fract(self):
+        """return the fractional part of `self`.
+        that is `self - math.floor(self)`.
+        """
+        fract_mask = (1 << self.frac_wid) - 1
+        return FixedPoint(self.bits & fract_mask, self.frac_wid)
+
+    def __str__(self):
+        if self < 0:
+            return "-" + str(-self)
+        digit_bits = 4
+        frac_digit_count = (self.frac_wid + digit_bits - 1) // digit_bits
+        fract = self.fract().to_frac_wid(frac_digit_count * digit_bits)
+        frac_str = hex(fract.bits)[2:].zfill(frac_digit_count)
+        return hex(math.floor(self)) + "." + frac_str
+
+    def __repr__(self):
+        return f"FixedPoint.with_frac_wid({str(self)!r}, {self.frac_wid})"
+
+    def __add__(self, rhs):
+        rhs = FixedPoint.cast(rhs)
+        common_frac_wid = max(self.frac_wid, rhs.frac_wid)
+        lhs = self.to_frac_wid(common_frac_wid)
+        rhs = rhs.to_frac_wid(common_frac_wid)
+        return FixedPoint(lhs.bits + rhs.bits, common_frac_wid)
+
+    def __neg__(self):
+        return FixedPoint(-self.bits, self.frac_wid)
+
+    def __sub__(self, rhs):
+        rhs = FixedPoint.cast(rhs)
+        common_frac_wid = max(self.frac_wid, rhs.frac_wid)
+        lhs = self.to_frac_wid(common_frac_wid)
+        rhs = rhs.to_frac_wid(common_frac_wid)
+        return FixedPoint(lhs.bits - rhs.bits, common_frac_wid)
+
+    def __mul__(self, rhs):
+        rhs = FixedPoint.cast(rhs)
+        return FixedPoint(self.bits * rhs.bits, self.frac_wid + rhs.frac_wid)
+
+    def __floor__(self):
+        return self.bits >> self.frac_wid
+
+
+def goldschmidt_div(n, d, width):
+    """ Goldschmidt division algorithm.
+
+        based on:
+        Even, G., Seidel, P. M., & Ferguson, W. E. (2003).
+        A Parametric Error Analysis of Goldschmidt's Division Algorithm.
+        https://citeseerx.ist.psu.edu/viewdoc/download?doi=10.1.1.90.1238&rep=rep1&type=pdf
+
+        arguments:
+        n: int
+            numerator. a `2*width`-bit unsigned integer.
+            must be less than `d << width`, otherwise the quotient wouldn't
+            fit in `width` bits.
+        d: int
+            denominator. a `width`-bit unsigned integer. must not be zero.
+        width: int
+            the bit-width of the inputs/outputs. must be a positive integer.
+
+        returns: int
+            the quotient. a `width`-bit unsigned integer.
+    """
+    assert isinstance(width, int) and width >= 1
+    assert isinstance(d, int) and 0 < d < (1 << width)
+    assert isinstance(n, int) and 0 <= n < (d << width)
+
+    # FIXME: calculate best values for extra_precision, table_addr_bits, and
+    # table_data_bits -- these are wrong
+    extra_precision = width + 3
+    table_addr_bits = 4
+    table_data_bits = 8
+
+    width += extra_precision
+
+    table = []
+    for i in range(1 << table_addr_bits):
+        value = 1 / (1 + i * 2 ** -table_addr_bits)
+        table.append(FixedPoint.with_frac_wid(value, table_data_bits,
+                                              RoundDir.DOWN))
+
+    # this whole algorithm is done with fixed-point arithmetic where values
+    # have `width` fractional bits
+
+    n = FixedPoint(n, width)
+    d = FixedPoint(d, width)
+
+    # normalize so 1 <= d < 2
+    # can easily be done with count-leading-zeros and left shift
+    while d < 1:
+        n = (n * 2).to_frac_wid(width)
+        d = (d * 2).to_frac_wid(width)
+
+    n_shift = 0
+    # normalize so 1 <= n < 2
+    while n >= 2:
+        n = (n * 0.5).to_frac_wid(width)
+        n_shift += 1
+
+    # compute initial f by table lookup
+    f = table[(d - 1).to_frac_wid(table_addr_bits, RoundDir.DOWN).bits]
+
+    min_bits_of_precision = 1
+    while min_bits_of_precision < width * 2:
+        # multiply both n and d by f
+        n *= f
+        d *= f
+        n = n.to_frac_wid(width, round_dir=RoundDir.DOWN)
+        d = d.to_frac_wid(width, round_dir=RoundDir.UP)
+
+        # slightly less than 2 to make the computation just a bitwise not
+        nearly_two = FixedPoint.with_frac_wid(2, width)
+        nearly_two = FixedPoint(nearly_two.bits - 1, width)
+        f = (nearly_two - d).to_frac_wid(width)
+
+        min_bits_of_precision *= 2
+
+    # scale to correct value
+    n *= 1 << n_shift
+
+    # avoid incorrectly rounding down
+    n = n.to_frac_wid(width - extra_precision, round_dir=RoundDir.UP)
+    return math.floor(n)
diff --git a/src/soc/fu/div/experiment/test/__test__.py b/src/soc/fu/div/experiment/test/__test__.py
new file mode 100644 (file)
index 0000000..e69de29
diff --git a/src/soc/fu/div/experiment/test/test_goldschmidt_div_sqrt.py b/src/soc/fu/div/experiment/test/test_goldschmidt_div_sqrt.py
new file mode 100644 (file)
index 0000000..84ef7b4
--- /dev/null
@@ -0,0 +1,42 @@
+# SPDX-License-Identifier: LGPL-3-or-later
+# Copyright 2022 Jacob Lifshay programmerjake@gmail.com
+
+# Funded by NLnet Assure Programme 2021-02-052, https://nlnet.nl/assure part
+# of Horizon 2020 EU Programme 957073.
+
+import unittest
+from nmutil.formaltest import FHDLTestCase
+from soc.fu.div.experiment.goldschmidt_div_sqrt import goldschmidt_div, FixedPoint
+
+
+class TestFixedPoint(FHDLTestCase):
+    def test_str_roundtrip(self):
+        for frac_wid in range(8):
+            for bits in range(-1 << 9, 1 << 9):
+                with self.subTest(bits=hex(bits), frac_wid=frac_wid):
+                    value = FixedPoint(bits, frac_wid)
+                    round_trip_value = FixedPoint.cast(str(value))
+                    self.assertEqual(value, round_trip_value)
+
+
+class TestGoldschmidtDiv(FHDLTestCase):
+    def tst(self, width):
+        assert isinstance(width, int)
+        for d in range(1, 1 << width):
+            for n in range(d << width):
+                expected = n // d
+                with self.subTest(width=width, n=hex(n), d=hex(d),
+                                  expected=hex(expected)):
+                    result = goldschmidt_div(n, d, width)
+                    self.assertEqual(result, expected, f"result={hex(result)}")
+
+    def test_1_through_5(self):
+        for width in range(1, 5 + 1):
+            self.tst(width)
+
+    def test_6(self):
+        self.tst(6)
+
+
+if __name__ == "__main__":
+    unittest.main()