X-Git-Url: https://git.libre-soc.org/?a=blobdiff_plain;f=src%2Fsoc%2Ffu%2Fdiv%2Fexperiment%2Ftest%2Ftest_goldschmidt_div_sqrt.py;h=b4d4fb85cd232b12b8eb28f40226e0679b23855a;hb=2d9fb70cf873f26b77a90cd938e9d656afc4e1d1;hp=5b4c89ad037e0963006dd96b6ea6f87fc7e253c5;hpb=f8868d2339dbfa251007a06563678ced98c49d90;p=soc.git diff --git a/src/soc/fu/div/experiment/test/test_goldschmidt_div_sqrt.py b/src/soc/fu/div/experiment/test/test_goldschmidt_div_sqrt.py index 5b4c89ad..b4d4fb85 100644 --- a/src/soc/fu/div/experiment/test/test_goldschmidt_div_sqrt.py +++ b/src/soc/fu/div/experiment/test/test_goldschmidt_div_sqrt.py @@ -4,6 +4,7 @@ # Funded by NLnet Assure Programme 2021-02-052, https://nlnet.nl/assure part # of Horizon 2020 EU Programme 957073. +from dataclasses import fields, replace import math import unittest from nmutil.formaltest import FHDLTestCase @@ -12,8 +13,9 @@ from nmigen.sim import Tick, Delay from nmigen.hdl.ast import Signal from nmigen.hdl.dsl import Module from soc.fu.div.experiment.goldschmidt_div_sqrt import ( - GoldschmidtDivHDL, GoldschmidtDivParams, ParamsNotAccurateEnough, - goldschmidt_div, FixedPoint, RoundDir, goldschmidt_sqrt_rsqrt) + GoldschmidtDivHDL, GoldschmidtDivHDLState, GoldschmidtDivOp, GoldschmidtDivParams, + GoldschmidtDivState, ParamsNotAccurateEnough, goldschmidt_div, + FixedPoint, RoundDir, goldschmidt_sqrt_rsqrt) class TestFixedPoint(FHDLTestCase): @@ -88,10 +90,8 @@ class TestGoldschmidtDiv(FHDLTestCase): with self.subTest(q=hex(q), r=hex(r)): self.assertEqual((q, r), (expected_q, expected_r)) - @unittest.skip("hdl/simulation currently broken") def tst_sim(self, io_width, cases=None, pipe_reg_indexes=(), sync_rom=False): - # FIXME: finish getting hdl/simulation to work assert isinstance(io_width, int) params = GoldschmidtDivParams.get(io_width) m = Module() @@ -103,7 +103,12 @@ class TestGoldschmidtDiv(FHDLTestCase): def iter_cases(): if cases is not None: - yield from cases + for n, d in cases: + assert isinstance(d, int) \ + and 0 < d < (1 << params.io_width), "invalid case" + assert isinstance(n, int) \ + and 0 <= n < (d << params.io_width), "invalid case" + yield (n, d) return for d in range(1, 1 << io_width): for n in range(d << io_width): @@ -116,6 +121,45 @@ class TestGoldschmidtDiv(FHDLTestCase): yield dut.d.eq(d) yield Tick() + def check_interals(n, d): + # check internals only if dut is completely combinatorial + # so we don't have to figure out how to read values in + # previous clock cycles + if dut.total_pipeline_registers != 0: + return + ref_trace = [] + + def ref_trace_fn(state): + assert isinstance(state, GoldschmidtDivState) + ref_trace.append((replace(state))) + goldschmidt_div(n=n, d=d, params=params, trace=ref_trace_fn) + self.assertEqual(len(dut.trace), len(ref_trace)) + for index, state in enumerate(dut.trace): + ref_state = ref_trace[index] + last_op = None if index == 0 else params.ops[index - 1] + with self.subTest(index=index, state=repr(state), + ref_state=repr(ref_state), + last_op=str(last_op)): + for field in fields(GoldschmidtDivHDLState): + sig = getattr(state, field.name) + if not isinstance(sig, Signal): + continue + ref_value = getattr(ref_state, field.name) + ref_value_str = repr(ref_value) + if isinstance(ref_value, int): + ref_value_str = hex(ref_value) + value = yield sig + with self.subTest(field_name=field.name, + sig=repr(sig), + sig_shape=repr(sig.shape()), + value=hex(value), + ref_value=ref_value_str): + if isinstance(ref_value, int): + self.assertEqual(value, ref_value) + else: + assert isinstance(ref_value, FixedPoint) + self.assertEqual(value, ref_value.bits) + def check_outputs(): yield Tick() for _ in range(dut.total_pipeline_registers): @@ -130,6 +174,8 @@ class TestGoldschmidtDiv(FHDLTestCase): r = yield dut.r with self.subTest(q=hex(q), r=hex(r)): self.assertEqual((q, r), (expected_q, expected_r)) + yield from check_interals(n, d) + yield Tick() with self.subTest(params=str(params)):