change goldschmidt_div_sqrt to use nmutil.plain_data rather than dataclasses
[soc.git] / src / soc / fu / div / experiment / test / test_goldschmidt_div_sqrt.py
index 66345fe20c77f9505b3dfade84d45bff15d40d0e..28e795f4e8bdd54b93a3ec071ebb93599690a1b3 100644 (file)
@@ -4,7 +4,7 @@
 # Funded by NLnet Assure Programme 2021-02-052, https://nlnet.nl/assure part
 # of Horizon 2020 EU Programme 957073.
 
-from dataclasses import fields, replace
+from nmutil.plain_data import fields, replace
 import math
 import unittest
 from nmutil.formaltest import FHDLTestCase
@@ -13,7 +13,7 @@ from nmigen.sim import Tick, Delay
 from nmigen.hdl.ast import Signal
 from nmigen.hdl.dsl import Module
 from soc.fu.div.experiment.goldschmidt_div_sqrt import (
-    GoldschmidtDivHDL, GoldschmidtDivHDLState, GoldschmidtDivOp, GoldschmidtDivParams,
+    GoldschmidtDivHDL, GoldschmidtDivHDLState, GoldschmidtDivParams,
     GoldschmidtDivState, ParamsNotAccurateEnough, goldschmidt_div,
     FixedPoint, RoundDir, goldschmidt_sqrt_rsqrt)
 
@@ -109,8 +109,13 @@ class TestGoldschmidtDiv(FHDLTestCase):
                 with self.subTest(n=hex(n), d=hex(d),
                                   expected_q=hex(expected_q),
                                   expected_r=hex(expected_r)):
-                    q, r = goldschmidt_div(n, d, params)
-                    with self.subTest(q=hex(q), r=hex(r)):
+                    trace = []
+
+                    def trace_fn(state):
+                        assert isinstance(state, GoldschmidtDivState)
+                        trace.append((replace(state)))
+                    q, r = goldschmidt_div(n, d, params, trace=trace_fn)
+                    with self.subTest(q=hex(q), r=hex(r), trace=repr(trace)):
                         self.assertEqual((q, r), (expected_q, expected_r))
 
     def tst_sim(self, io_width, cases=None, pipe_reg_indexes=(),
@@ -151,15 +156,15 @@ class TestGoldschmidtDiv(FHDLTestCase):
                                   ref_state=repr(ref_state),
                                   last_op=str(last_op)):
                     for field in fields(GoldschmidtDivHDLState):
-                        sig = getattr(state, field.name)
+                        sig = getattr(state, field)
                         if not isinstance(sig, Signal):
                             continue
-                        ref_value = getattr(ref_state, field.name)
+                        ref_value = getattr(ref_state, field)
                         ref_value_str = repr(ref_value)
                         if isinstance(ref_value, int):
                             ref_value_str = hex(ref_value)
                         value = yield sig
-                        with self.subTest(field_name=field.name,
+                        with self.subTest(field_name=field,
                                           sig=repr(sig),
                                           sig_shape=repr(sig.shape()),
                                           value=hex(value),