From cc2c51cf2cd7b22f7d18871fbfb73e0b7707d91a Mon Sep 17 00:00:00 2001 From: Jacob Lifshay Date: Fri, 22 Apr 2022 00:58:17 -0700 Subject: [PATCH] add WIP goldschmidt division algorithm --- src/soc/fu/div/experiment/__init__.py | 0 .../fu/div/experiment/goldschmidt_div_sqrt.py | 266 ++++++++++++++++++ src/soc/fu/div/experiment/test/__test__.py | 0 .../test/test_goldschmidt_div_sqrt.py | 42 +++ 4 files changed, 308 insertions(+) create mode 100644 src/soc/fu/div/experiment/__init__.py create mode 100644 src/soc/fu/div/experiment/goldschmidt_div_sqrt.py create mode 100644 src/soc/fu/div/experiment/test/__test__.py create mode 100644 src/soc/fu/div/experiment/test/test_goldschmidt_div_sqrt.py diff --git a/src/soc/fu/div/experiment/__init__.py b/src/soc/fu/div/experiment/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/src/soc/fu/div/experiment/goldschmidt_div_sqrt.py b/src/soc/fu/div/experiment/goldschmidt_div_sqrt.py new file mode 100644 index 00000000..12b9f81f --- /dev/null +++ b/src/soc/fu/div/experiment/goldschmidt_div_sqrt.py @@ -0,0 +1,266 @@ +# SPDX-License-Identifier: LGPL-3-or-later +# Copyright 2022 Jacob Lifshay programmerjake@gmail.com + +# Funded by NLnet Assure Programme 2021-02-052, https://nlnet.nl/assure part +# of Horizon 2020 EU Programme 957073. + +from dataclasses import dataclass +import math +import enum + + +@enum.unique +class RoundDir(enum.Enum): + DOWN = enum.auto() + UP = enum.auto() + NEAREST_TIES_UP = enum.auto() + ERROR_IF_INEXACT = enum.auto() + + +@dataclass(frozen=True) +class FixedPoint: + bits: int + frac_wid: int + + def __post_init__(self): + assert isinstance(self.bits, int) + assert isinstance(self.frac_wid, int) and self.frac_wid >= 0 + + @staticmethod + def cast(value): + """convert `value` to a fixed-point number with enough fractional + bits to preserve its value.""" + if isinstance(value, FixedPoint): + return value + if isinstance(value, int): + return FixedPoint(value, 0) + if isinstance(value, str): + value = value.strip() + neg = value.startswith("-") + if neg or value.startswith("+"): + value = value[1:] + if value.startswith(("0x", "0X")) and "." in value: + value = value[2:] + got_dot = False + bits = 0 + frac_wid = 0 + for digit in value: + if digit == "_": + continue + if got_dot: + if digit == ".": + raise ValueError("too many `.` in string") + frac_wid += 4 + if digit == ".": + got_dot = True + continue + if not digit.isalnum(): + raise ValueError("invalid hexadecimal digit") + bits <<= 4 + bits |= int("0x" + digit, base=16) + else: + bits = int(value, base=0) + frac_wid = 0 + if neg: + bits = -bits + return FixedPoint(bits, frac_wid) + + if isinstance(value, float): + n, d = value.as_integer_ratio() + log2_d = d.bit_length() - 1 + assert d == 1 << log2_d, ("d isn't a power of 2 -- won't ever " + "fail with float being IEEE 754") + return FixedPoint(n, log2_d) + raise TypeError("can't convert type to FixedPoint") + + @staticmethod + def with_frac_wid(value, frac_wid, round_dir=RoundDir.ERROR_IF_INEXACT): + """convert `value` to the nearest fixed-point number with `frac_wid` + fractional bits, rounding according to `round_dir`.""" + value = FixedPoint.cast(value) + assert isinstance(frac_wid, int) and frac_wid >= 0 + assert isinstance(round_dir, RoundDir) + # compute number of bits that should be removed from value + del_bits = value.frac_wid - frac_wid + if del_bits == 0: + return value + if del_bits < 0: # add bits + return FixedPoint(value.bits << -del_bits, + frac_wid) + if round_dir == RoundDir.DOWN: + bits = value.bits >> del_bits + elif round_dir == RoundDir.UP: + bits = -((-value.bits) >> del_bits) + elif round_dir == RoundDir.NEAREST_TIES_UP: + bits = value.bits >> (del_bits - 1) + bits += 1 + bits >>= 1 + elif round_dir == RoundDir.ERROR_IF_INEXACT: + bits = value.bits >> del_bits + if bits << del_bits != value.bits: + raise ValueError("inexact conversion") + else: + assert False, "unimplemented round_dir" + return FixedPoint(bits, frac_wid) + + def to_frac_wid(self, frac_wid, round_dir=RoundDir.ERROR_IF_INEXACT): + """convert to the nearest fixed-point number with `frac_wid` + fractional bits, rounding according to `round_dir`.""" + return FixedPoint.with_frac_wid(self, frac_wid, round_dir) + + def __float__(self): + return self.bits * 2.0 ** -self.frac_wid + + def cmp(self, rhs): + """compare self with rhs, returning a positive integer if self is + greater than rhs, zero if self is equal to rhs, and a negative integer + if self is less than rhs.""" + rhs = FixedPoint.cast(rhs) + common_frac_wid = max(self.frac_wid, rhs.frac_wid) + lhs = self.to_frac_wid(common_frac_wid) + rhs = rhs.to_frac_wid(common_frac_wid) + return lhs.bits - rhs.bits + + def __eq__(self, rhs): + return self.cmp(rhs) == 0 + + def __ne__(self, rhs): + return self.cmp(rhs) != 0 + + def __gt__(self, rhs): + return self.cmp(rhs) > 0 + + def __lt__(self, rhs): + return self.cmp(rhs) < 0 + + def __ge__(self, rhs): + return self.cmp(rhs) >= 0 + + def __le__(self, rhs): + return self.cmp(rhs) <= 0 + + def fract(self): + """return the fractional part of `self`. + that is `self - math.floor(self)`. + """ + fract_mask = (1 << self.frac_wid) - 1 + return FixedPoint(self.bits & fract_mask, self.frac_wid) + + def __str__(self): + if self < 0: + return "-" + str(-self) + digit_bits = 4 + frac_digit_count = (self.frac_wid + digit_bits - 1) // digit_bits + fract = self.fract().to_frac_wid(frac_digit_count * digit_bits) + frac_str = hex(fract.bits)[2:].zfill(frac_digit_count) + return hex(math.floor(self)) + "." + frac_str + + def __repr__(self): + return f"FixedPoint.with_frac_wid({str(self)!r}, {self.frac_wid})" + + def __add__(self, rhs): + rhs = FixedPoint.cast(rhs) + common_frac_wid = max(self.frac_wid, rhs.frac_wid) + lhs = self.to_frac_wid(common_frac_wid) + rhs = rhs.to_frac_wid(common_frac_wid) + return FixedPoint(lhs.bits + rhs.bits, common_frac_wid) + + def __neg__(self): + return FixedPoint(-self.bits, self.frac_wid) + + def __sub__(self, rhs): + rhs = FixedPoint.cast(rhs) + common_frac_wid = max(self.frac_wid, rhs.frac_wid) + lhs = self.to_frac_wid(common_frac_wid) + rhs = rhs.to_frac_wid(common_frac_wid) + return FixedPoint(lhs.bits - rhs.bits, common_frac_wid) + + def __mul__(self, rhs): + rhs = FixedPoint.cast(rhs) + return FixedPoint(self.bits * rhs.bits, self.frac_wid + rhs.frac_wid) + + def __floor__(self): + return self.bits >> self.frac_wid + + +def goldschmidt_div(n, d, width): + """ Goldschmidt division algorithm. + + based on: + Even, G., Seidel, P. M., & Ferguson, W. E. (2003). + A Parametric Error Analysis of Goldschmidt's Division Algorithm. + https://citeseerx.ist.psu.edu/viewdoc/download?doi=10.1.1.90.1238&rep=rep1&type=pdf + + arguments: + n: int + numerator. a `2*width`-bit unsigned integer. + must be less than `d << width`, otherwise the quotient wouldn't + fit in `width` bits. + d: int + denominator. a `width`-bit unsigned integer. must not be zero. + width: int + the bit-width of the inputs/outputs. must be a positive integer. + + returns: int + the quotient. a `width`-bit unsigned integer. + """ + assert isinstance(width, int) and width >= 1 + assert isinstance(d, int) and 0 < d < (1 << width) + assert isinstance(n, int) and 0 <= n < (d << width) + + # FIXME: calculate best values for extra_precision, table_addr_bits, and + # table_data_bits -- these are wrong + extra_precision = width + 3 + table_addr_bits = 4 + table_data_bits = 8 + + width += extra_precision + + table = [] + for i in range(1 << table_addr_bits): + value = 1 / (1 + i * 2 ** -table_addr_bits) + table.append(FixedPoint.with_frac_wid(value, table_data_bits, + RoundDir.DOWN)) + + # this whole algorithm is done with fixed-point arithmetic where values + # have `width` fractional bits + + n = FixedPoint(n, width) + d = FixedPoint(d, width) + + # normalize so 1 <= d < 2 + # can easily be done with count-leading-zeros and left shift + while d < 1: + n = (n * 2).to_frac_wid(width) + d = (d * 2).to_frac_wid(width) + + n_shift = 0 + # normalize so 1 <= n < 2 + while n >= 2: + n = (n * 0.5).to_frac_wid(width) + n_shift += 1 + + # compute initial f by table lookup + f = table[(d - 1).to_frac_wid(table_addr_bits, RoundDir.DOWN).bits] + + min_bits_of_precision = 1 + while min_bits_of_precision < width * 2: + # multiply both n and d by f + n *= f + d *= f + n = n.to_frac_wid(width, round_dir=RoundDir.DOWN) + d = d.to_frac_wid(width, round_dir=RoundDir.UP) + + # slightly less than 2 to make the computation just a bitwise not + nearly_two = FixedPoint.with_frac_wid(2, width) + nearly_two = FixedPoint(nearly_two.bits - 1, width) + f = (nearly_two - d).to_frac_wid(width) + + min_bits_of_precision *= 2 + + # scale to correct value + n *= 1 << n_shift + + # avoid incorrectly rounding down + n = n.to_frac_wid(width - extra_precision, round_dir=RoundDir.UP) + return math.floor(n) diff --git a/src/soc/fu/div/experiment/test/__test__.py b/src/soc/fu/div/experiment/test/__test__.py new file mode 100644 index 00000000..e69de29b diff --git a/src/soc/fu/div/experiment/test/test_goldschmidt_div_sqrt.py b/src/soc/fu/div/experiment/test/test_goldschmidt_div_sqrt.py new file mode 100644 index 00000000..84ef7b44 --- /dev/null +++ b/src/soc/fu/div/experiment/test/test_goldschmidt_div_sqrt.py @@ -0,0 +1,42 @@ +# SPDX-License-Identifier: LGPL-3-or-later +# Copyright 2022 Jacob Lifshay programmerjake@gmail.com + +# Funded by NLnet Assure Programme 2021-02-052, https://nlnet.nl/assure part +# of Horizon 2020 EU Programme 957073. + +import unittest +from nmutil.formaltest import FHDLTestCase +from soc.fu.div.experiment.goldschmidt_div_sqrt import goldschmidt_div, FixedPoint + + +class TestFixedPoint(FHDLTestCase): + def test_str_roundtrip(self): + for frac_wid in range(8): + for bits in range(-1 << 9, 1 << 9): + with self.subTest(bits=hex(bits), frac_wid=frac_wid): + value = FixedPoint(bits, frac_wid) + round_trip_value = FixedPoint.cast(str(value)) + self.assertEqual(value, round_trip_value) + + +class TestGoldschmidtDiv(FHDLTestCase): + def tst(self, width): + assert isinstance(width, int) + for d in range(1, 1 << width): + for n in range(d << width): + expected = n // d + with self.subTest(width=width, n=hex(n), d=hex(d), + expected=hex(expected)): + result = goldschmidt_div(n, d, width) + self.assertEqual(result, expected, f"result={hex(result)}") + + def test_1_through_5(self): + for width in range(1, 5 + 1): + self.tst(width) + + def test_6(self): + self.tst(6) + + +if __name__ == "__main__": + unittest.main() -- 2.30.2