change goldschmidt_div_sqrt to use nmutil.plain_data rather than dataclasses
authorJacob Lifshay <programmerjake@gmail.com>
Tue, 16 Aug 2022 06:43:13 +0000 (23:43 -0700)
committerJacob Lifshay <programmerjake@gmail.com>
Tue, 16 Aug 2022 06:43:13 +0000 (23:43 -0700)
src/soc/fu/div/experiment/goldschmidt_div_sqrt.py
src/soc/fu/div/experiment/test/test_goldschmidt_div_sqrt.py

index 6f739c33db8d6a97f55015b795709c5b4ee34b2e..3f7c2480742d6913859461da120099385f99d18a 100644 (file)
@@ -5,17 +5,17 @@
 # of Horizon 2020 EU Programme 957073.
 
 from collections import defaultdict
 # of Horizon 2020 EU Programme 957073.
 
 from collections import defaultdict
-from dataclasses import dataclass, field, fields, replace
 import logging
 import math
 import enum
 from fractions import Fraction
 from types import FunctionType
 from functools import lru_cache
 import logging
 import math
 import enum
 from fractions import Fraction
 from types import FunctionType
 from functools import lru_cache
-from nmigen.hdl.ast import Signal, unsigned, signed, Const, Cat
+from nmigen.hdl.ast import Signal, unsigned, signed, Const
 from nmigen.hdl.dsl import Module, Elaboratable
 from nmigen.hdl.mem import Memory
 from nmutil.clz import CLZ
 from nmigen.hdl.dsl import Module, Elaboratable
 from nmigen.hdl.mem import Memory
 from nmutil.clz import CLZ
+from nmutil.plain_data import plain_data, fields, replace
 
 try:
     from functools import cached_property
 
 try:
     from functools import cached_property
@@ -65,13 +65,13 @@ class RoundDir(enum.Enum):
     ERROR_IF_INEXACT = enum.auto()
 
 
     ERROR_IF_INEXACT = enum.auto()
 
 
-@dataclass(frozen=True)
+@plain_data(frozen=True, eq=False, repr=False)
 class FixedPoint:
 class FixedPoint:
-    bits: int
-    frac_wid: int
+    __slots__ = "bits", "frac_wid"
 
 
-    def __post_init__(self):
-        # called by the autogenerated __init__
+    def __init__(self, bits, frac_wid):
+        self.bits = bits
+        self.frac_wid = frac_wid
         assert isinstance(self.bits, int)
         assert isinstance(self.frac_wid, int) and self.frac_wid >= 0
 
         assert isinstance(self.bits, int)
         assert isinstance(self.frac_wid, int) and self.frac_wid >= 0
 
@@ -332,43 +332,47 @@ def _assert_accuracy(condition, msg="not accurate enough"):
     raise ParamsNotAccurateEnough(msg)
 
 
     raise ParamsNotAccurateEnough(msg)
 
 
-@dataclass(frozen=True, unsafe_hash=True)
+@plain_data(frozen=True, unsafe_hash=True)
 class GoldschmidtDivParamsBase:
     """parameters for a Goldschmidt division algorithm, excluding derived
     parameters.
     """
 
 class GoldschmidtDivParamsBase:
     """parameters for a Goldschmidt division algorithm, excluding derived
     parameters.
     """
 
-    io_width: int
-    """bit-width of the input divisor and the result.
-    the input numerator is `2 * io_width`-bits wide.
-    """
+    __slots__ = ("io_width", "extra_precision", "table_addr_bits",
+                 "table_data_bits", "iter_count")
+
+    def __init__(self, io_width, extra_precision, table_addr_bits,
+                 table_data_bits, iter_count):
+        assert isinstance(io_width, int)
+        assert isinstance(extra_precision, int)
+        assert isinstance(table_addr_bits, int)
+        assert isinstance(table_data_bits, int)
+        assert isinstance(iter_count, int)
+        self.io_width = io_width
+        """bit-width of the input divisor and the result.
+        the input numerator is `2 * io_width`-bits wide.
+        """
 
 
-    extra_precision: int
-    """number of bits of additional precision used inside the algorithm."""
+        self.extra_precision = extra_precision
+        """number of bits of additional precision used inside the algorithm."""
 
 
-    table_addr_bits: int
-    """the number of address bits used in the lookup-table."""
+        self.table_addr_bits = table_addr_bits
+        """the number of address bits used in the lookup-table."""
 
 
-    table_data_bits: int
-    """the number of data bits used in the lookup-table."""
+        self.table_data_bits = table_data_bits
+        """the number of data bits used in the lookup-table."""
 
 
-    iter_count: int
-    """the total number of iterations of the division algorithm's loop"""
+        self.iter_count = iter_count
+        """the total number of iterations of the division algorithm's loop"""
 
 
 
 
-@dataclass(frozen=True, unsafe_hash=True)
+@plain_data(frozen=True, unsafe_hash=True)
 class GoldschmidtDivParams(GoldschmidtDivParamsBase):
     """parameters for a Goldschmidt division algorithm.
     Use `GoldschmidtDivParams.get` to find a efficient set of parameters.
     """
 
 class GoldschmidtDivParams(GoldschmidtDivParamsBase):
     """parameters for a Goldschmidt division algorithm.
     Use `GoldschmidtDivParams.get` to find a efficient set of parameters.
     """
 
-    # tuple to be immutable, repr=False so repr() works for debugging even when
-    # __post_init__ hasn't finished running yet
-    table: "tuple[FixedPoint, ...]" = field(init=False, repr=False)
-    """the lookup-table"""
-
-    ops: "tuple[GoldschmidtDivOp, ...]" = field(init=False, repr=False)
-    """the operations needed to perform the goldschmidt division algorithm."""
+    __slots__ = "table", "ops"
 
     def _shrink_bound(self, bound, round_dir):
         """prevent fractions from having huge numerators/denominators by
 
     def _shrink_bound(self, bound, round_dir):
         """prevent fractions from having huge numerators/denominators by
@@ -445,8 +449,13 @@ class GoldschmidtDivParams(GoldschmidtDivParamsBase):
         # we round down
         return min_value
 
         # we round down
         return min_value
 
-    def __post_init__(self):
-        # called by the autogenerated __init__
+    def __init__(self, io_width, extra_precision, table_addr_bits,
+                 table_data_bits, iter_count):
+        super().__init__(io_width=io_width,
+                         extra_precision=extra_precision,
+                         table_addr_bits=table_addr_bits,
+                         table_data_bits=table_data_bits,
+                         iter_count=iter_count)
         _assert_accuracy(self.io_width >= 1, "io_width out of range")
         _assert_accuracy(self.extra_precision >= 0,
                          "extra_precision out of range")
         _assert_accuracy(self.io_width >= 1, "io_width out of range")
         _assert_accuracy(self.extra_precision >= 0,
                          "extra_precision out of range")
@@ -460,9 +469,14 @@ class GoldschmidtDivParams(GoldschmidtDivParamsBase):
             table.append(FixedPoint.with_frac_wid(self.table_exact_value(addr),
                                                   self.table_data_bits,
                                                   RoundDir.DOWN))
             table.append(FixedPoint.with_frac_wid(self.table_exact_value(addr),
                                                   self.table_data_bits,
                                                   RoundDir.DOWN))
-        # we have to use object.__setattr__ since frozen=True
-        object.__setattr__(self, "table", tuple(table))
-        object.__setattr__(self, "ops", tuple(self.__make_ops()))
+
+        self.table = tuple(table)
+        """ the lookup-table.
+        type: tuple[FixedPoint, ...]
+        """
+
+        self.ops = tuple(self.__make_ops())
+        "the operations needed to perform the goldschmidt division algorithm."
 
     @property
     def expanded_width(self):
 
     @property
     def expanded_width(self):
@@ -800,11 +814,9 @@ class GoldschmidtDivParams(GoldschmidtDivParamsBase):
     @lru_cache(maxsize=1 << 16)
     def __cached_new(base_params):
         assert isinstance(base_params, GoldschmidtDivParamsBase)
     @lru_cache(maxsize=1 << 16)
     def __cached_new(base_params):
         assert isinstance(base_params, GoldschmidtDivParamsBase)
-        # can't use dataclasses.asdict, since it's recursive and will also give
-        # child class fields too, which we don't want.
         kwargs = {}
         for field in fields(GoldschmidtDivParamsBase):
         kwargs = {}
         for field in fields(GoldschmidtDivParamsBase):
-            kwargs[field.name] = getattr(base_params, field.name)
+            kwargs[field] = getattr(base_params, field)
         try:
             return GoldschmidtDivParams(**kwargs), None
         except ParamsNotAccurateEnough as e:
         try:
             return GoldschmidtDivParams(**kwargs), None
         except ParamsNotAccurateEnough as e:
@@ -1139,44 +1151,57 @@ class GoldschmidtDivOp(enum.Enum):
             assert False, f"unimplemented GoldschmidtDivOp: {self}"
 
 
             assert False, f"unimplemented GoldschmidtDivOp: {self}"
 
 
-@dataclass
+@plain_data(repr=False)
 class GoldschmidtDivState:
 class GoldschmidtDivState:
-    orig_n: int
-    """original numerator"""
-
-    orig_d: int
-    """original denominator"""
-
-    n: FixedPoint
-    """numerator -- N_prime[i] in the paper's algorithm 2"""
-
-    d: FixedPoint
-    """denominator -- D_prime[i] in the paper's algorithm 2"""
-
-    f: "FixedPoint | None" = None
-    """current factor -- F_prime[i] in the paper's algorithm 2"""
-
-    quotient: "int | None" = None
-    """final quotient"""
-
-    remainder: "int | None" = None
-    """final remainder"""
-
-    n_shift: "int | None" = None
-    """amount the numerator needs to be left-shifted at the end of the
-    algorithm.
-    """
+    __slots__ = ("orig_n", "orig_d", "n", "d",
+                 "f", "quotient", "remainder", "n_shift")
+
+    def __init__(self, orig_n, orig_d, n, d,
+                 f=None, quotient=None, remainder=None, n_shift=None):
+        assert isinstance(orig_n, int)
+        assert isinstance(orig_d, int)
+        assert isinstance(n, FixedPoint)
+        assert isinstance(d, FixedPoint)
+        assert f is None or isinstance(f, FixedPoint)
+        assert quotient is None or isinstance(quotient, int)
+        assert remainder is None or isinstance(remainder, int)
+        assert n_shift is None or isinstance(n_shift, int)
+        self.orig_n = orig_n
+        """original numerator"""
+
+        self.orig_d = orig_d
+        """original denominator"""
+
+        self.n = n
+        """numerator -- N_prime[i] in the paper's algorithm 2"""
+
+        self.d = d
+        """denominator -- D_prime[i] in the paper's algorithm 2"""
+
+        self.f = f
+        """current factor -- F_prime[i] in the paper's algorithm 2"""
+
+        self.quotient = quotient
+        """final quotient"""
+
+        self.remainder = remainder
+        """final remainder"""
+
+        self.n_shift = n_shift
+        """amount the numerator needs to be left-shifted at the end of the
+        algorithm.
+        """
 
     def __repr__(self):
         fields_str = []
         for field in fields(GoldschmidtDivState):
 
     def __repr__(self):
         fields_str = []
         for field in fields(GoldschmidtDivState):
-            value = getattr(self, field.name)
+            value = getattr(self, field)
             if value is None:
                 continue
             if value is None:
                 continue
-            if isinstance(value, int) and field.name != "n_shift":
-                fields_str.append(f"{field.name}={hex(value)}")
+            if isinstance(value, int) and field != "n_shift":
+                fields_str.append(f"{field}={hex(value)}")
             else:
             else:
-                fields_str.append(f"{field.name}={value!r}")
+                fields_str.append(f"{field}={value!r}")
         return f"GoldschmidtDivState({', '.join(fields_str)})"
 
 
         return f"GoldschmidtDivState({', '.join(fields_str)})"
 
 
@@ -1230,45 +1255,55 @@ def goldschmidt_div(n, d, params, trace=lambda state: None):
     return state.quotient, state.remainder
 
 
     return state.quotient, state.remainder
 
 
-@dataclass(eq=False)
+@plain_data(eq=False)
 class GoldschmidtDivHDLState:
 class GoldschmidtDivHDLState:
-    m: Module
-    """The HDL Module"""
+    __slots__ = ("m", "orig_n", "orig_d", "n", "d",
+                 "f", "quotient", "remainder", "n_shift")
 
 
-    orig_n: Signal
-    """original numerator"""
+    __signal_name_prefix = "state_"
 
 
-    orig_d: Signal
-    """original denominator"""
+    def __init__(self, m, orig_n, orig_d, n, d,
+                 f=None, quotient=None, remainder=None, n_shift=None):
+        assert isinstance(m, Module)
+        assert isinstance(orig_n, Signal)
+        assert isinstance(orig_d, Signal)
+        assert isinstance(n, Signal)
+        assert isinstance(d, Signal)
+        assert f is None or isinstance(f, Signal)
+        assert quotient is None or isinstance(quotient, Signal)
+        assert remainder is None or isinstance(remainder, Signal)
+        assert n_shift is None or isinstance(n_shift, Signal)
 
 
-    n: Signal
-    """numerator -- N_prime[i] in the paper's algorithm 2"""
+        self.m = m
+        """The HDL Module"""
 
 
-    d: Signal
-    """denominator -- D_prime[i] in the paper's algorithm 2"""
+        self.orig_n = orig_n
+        """original numerator"""
 
 
-    f: "Signal | None" = None
-    """current factor -- F_prime[i] in the paper's algorithm 2"""
+        self.orig_d = orig_d
+        """original denominator"""
 
 
-    quotient: "Signal | None" = None
-    """final quotient"""
+        self.n = n
+        """numerator -- N_prime[i] in the paper's algorithm 2"""
 
 
-    remainder: "Signal | None" = None
-    """final remainder"""
+        self.d = d
+        """denominator -- D_prime[i] in the paper's algorithm 2"""
 
 
-    n_shift: "Signal | None" = None
-    """amount the numerator needs to be left-shifted at the end of the
-    algorithm.
-    """
+        self.f = f
+        """current factor -- F_prime[i] in the paper's algorithm 2"""
 
 
-    old_signals: "defaultdict[str, list[Signal]]" = field(repr=False,
-                                                          init=False)
+        self.quotient = quotient
+        """final quotient"""
 
 
-    __signal_name_prefix: "str" = field(default="state_", repr=False,
-                                        init=False)
+        self.remainder = remainder
+        """final remainder"""
+
+        self.n_shift = n_shift
+        """amount the numerator needs to be left-shifted at the end of the
+        algorithm.
+        """
 
 
-    def __post_init__(self):
-        # called by the autogenerated __init__
+        # old_signals must be set last
         self.old_signals = defaultdict(list)
 
     def __setattr__(self, name, value):
         self.old_signals = defaultdict(list)
 
     def __setattr__(self, name, value):
@@ -1293,14 +1328,14 @@ class GoldschmidtDivHDLState:
         old_prefix = self.__signal_name_prefix
         try:
             for field in fields(GoldschmidtDivHDLState):
         old_prefix = self.__signal_name_prefix
         try:
             for field in fields(GoldschmidtDivHDLState):
-                if field.name.startswith("_") or field.name == "m":
+                if field.startswith("_") or field == "m":
                     continue
                     continue
-                old_sig = getattr(self, field.name, None)
+                old_sig = getattr(self, field, None)
                 if old_sig is None:
                     continue
                 assert isinstance(old_sig, Signal)
                 new_sig = Signal.like(old_sig)
                 if old_sig is None:
                     continue
                 assert isinstance(old_sig, Signal)
                 new_sig = Signal.like(old_sig)
-                setattr(self, field.name, new_sig)
+                setattr(self, field, new_sig)
                 self.m.d.sync += new_sig.eq(old_sig)
         finally:
             self.__signal_name_prefix = old_prefix
                 self.m.d.sync += new_sig.eq(old_sig)
         finally:
             self.__signal_name_prefix = old_prefix
index bf999bd850237387120691f16bcda7eb308e21df..28e795f4e8bdd54b93a3ec071ebb93599690a1b3 100644 (file)
@@ -4,7 +4,7 @@
 # Funded by NLnet Assure Programme 2021-02-052, https://nlnet.nl/assure part
 # of Horizon 2020 EU Programme 957073.
 
 # Funded by NLnet Assure Programme 2021-02-052, https://nlnet.nl/assure part
 # of Horizon 2020 EU Programme 957073.
 
-from dataclasses import fields, replace
+from nmutil.plain_data import fields, replace
 import math
 import unittest
 from nmutil.formaltest import FHDLTestCase
 import math
 import unittest
 from nmutil.formaltest import FHDLTestCase
@@ -13,7 +13,7 @@ from nmigen.sim import Tick, Delay
 from nmigen.hdl.ast import Signal
 from nmigen.hdl.dsl import Module
 from soc.fu.div.experiment.goldschmidt_div_sqrt import (
 from nmigen.hdl.ast import Signal
 from nmigen.hdl.dsl import Module
 from soc.fu.div.experiment.goldschmidt_div_sqrt import (
-    GoldschmidtDivHDL, GoldschmidtDivHDLState, GoldschmidtDivOp, GoldschmidtDivParams,
+    GoldschmidtDivHDL, GoldschmidtDivHDLState, GoldschmidtDivParams,
     GoldschmidtDivState, ParamsNotAccurateEnough, goldschmidt_div,
     FixedPoint, RoundDir, goldschmidt_sqrt_rsqrt)
 
     GoldschmidtDivState, ParamsNotAccurateEnough, goldschmidt_div,
     FixedPoint, RoundDir, goldschmidt_sqrt_rsqrt)
 
@@ -156,15 +156,15 @@ class TestGoldschmidtDiv(FHDLTestCase):
                                   ref_state=repr(ref_state),
                                   last_op=str(last_op)):
                     for field in fields(GoldschmidtDivHDLState):
                                   ref_state=repr(ref_state),
                                   last_op=str(last_op)):
                     for field in fields(GoldschmidtDivHDLState):
-                        sig = getattr(state, field.name)
+                        sig = getattr(state, field)
                         if not isinstance(sig, Signal):
                             continue
                         if not isinstance(sig, Signal):
                             continue
-                        ref_value = getattr(ref_state, field.name)
+                        ref_value = getattr(ref_state, field)
                         ref_value_str = repr(ref_value)
                         if isinstance(ref_value, int):
                             ref_value_str = hex(ref_value)
                         value = yield sig
                         ref_value_str = repr(ref_value)
                         if isinstance(ref_value, int):
                             ref_value_str = hex(ref_value)
                         value = yield sig
-                        with self.subTest(field_name=field.name,
+                        with self.subTest(field_name=field,
                                           sig=repr(sig),
                                           sig_shape=repr(sig.shape()),
                                           value=hex(value),
                                           sig=repr(sig),
                                           sig_shape=repr(sig.shape()),
                                           value=hex(value),