# Funded by NLnet Assure Programme 2021-02-052, https://nlnet.nl/assure part
# of Horizon 2020 EU Programme 957073.
+import math
import unittest
from nmutil.formaltest import FHDLTestCase
from soc.fu.div.experiment.goldschmidt_div_sqrt import (
- GoldschmidtDivParams, ParamsNotAccurateEnough, goldschmidt_div, FixedPoint)
+ GoldschmidtDivParams, ParamsNotAccurateEnough, goldschmidt_div,
+ FixedPoint, RoundDir, goldschmidt_sqrt_rsqrt)
class TestFixedPoint(FHDLTestCase):
round_trip_value = FixedPoint.cast(str(value))
self.assertEqual(value, round_trip_value)
+ @staticmethod
+ def trap(f):
+ try:
+ return f(), None
+ except (ValueError, ZeroDivisionError) as e:
+ return None, e.__class__.__name__
+
+ def test_sqrt(self):
+ for frac_wid in range(8):
+ for bits in range(1 << 9):
+ for round_dir in RoundDir:
+ radicand = FixedPoint(bits, frac_wid)
+ expected_f = math.sqrt(float(radicand))
+ expected = self.trap(lambda: FixedPoint.with_frac_wid(
+ expected_f, frac_wid, round_dir))
+ with self.subTest(radicand=repr(radicand),
+ round_dir=str(round_dir),
+ expected=repr(expected)):
+ result = self.trap(lambda: radicand.sqrt(round_dir))
+ self.assertEqual(result, expected)
+
+ def test_rsqrt(self):
+ for frac_wid in range(8):
+ for bits in range(1, 1 << 9):
+ for round_dir in RoundDir:
+ radicand = FixedPoint(bits, frac_wid)
+ expected_f = 1 / math.sqrt(float(radicand))
+ expected = self.trap(lambda: FixedPoint.with_frac_wid(
+ expected_f, frac_wid, round_dir))
+ with self.subTest(radicand=repr(radicand),
+ round_dir=str(round_dir),
+ expected=repr(expected)):
+ result = self.trap(lambda: radicand.rsqrt(round_dir))
+ self.assertEqual(result, expected)
+
class TestGoldschmidtDiv(FHDLTestCase):
def test_case1(self):
self.tst_params(64)
+class TestGoldschmidtSqrtRSqrt(FHDLTestCase):
+ def tst(self, io_width, frac_wid, extra_precision,
+ table_addr_bits, table_data_bits, iter_count):
+ assert isinstance(io_width, int)
+ assert isinstance(frac_wid, int)
+ assert isinstance(extra_precision, int)
+ assert isinstance(table_addr_bits, int)
+ assert isinstance(table_data_bits, int)
+ assert isinstance(iter_count, int)
+ with self.subTest(io_width=io_width, frac_wid=frac_wid,
+ extra_precision=extra_precision,
+ table_addr_bits=table_addr_bits,
+ table_data_bits=table_data_bits,
+ iter_count=iter_count):
+ for bits in range(1 << io_width):
+ radicand = FixedPoint(bits, frac_wid)
+ expected_sqrt = radicand.sqrt(RoundDir.DOWN)
+ expected_rsqrt = FixedPoint(0, frac_wid)
+ if radicand > 0:
+ expected_rsqrt = radicand.rsqrt(RoundDir.DOWN)
+ with self.subTest(radicand=repr(radicand),
+ expected_sqrt=repr(expected_sqrt),
+ expected_rsqrt=repr(expected_rsqrt)):
+ sqrt, rsqrt = goldschmidt_sqrt_rsqrt(
+ radicand=radicand, io_width=io_width,
+ frac_wid=frac_wid,
+ extra_precision=extra_precision,
+ table_addr_bits=table_addr_bits,
+ table_data_bits=table_data_bits,
+ iter_count=iter_count)
+ with self.subTest(sqrt=repr(sqrt), rsqrt=repr(rsqrt)):
+ self.assertEqual((sqrt, rsqrt),
+ (expected_sqrt, expected_rsqrt))
+
+ def test1(self):
+ self.tst(io_width=16, frac_wid=8, extra_precision=20,
+ table_addr_bits=4, table_data_bits=28, iter_count=4)
+
+
if __name__ == "__main__":
unittest.main()