From f9cc8a5bd4220f3ed4b739fbe460c5b94234e3bb Mon Sep 17 00:00:00 2001 From: Jacob Lifshay Date: Mon, 15 Aug 2022 23:43:13 -0700 Subject: [PATCH] change goldschmidt_div_sqrt to use nmutil.plain_data rather than dataclasses --- .../fu/div/experiment/goldschmidt_div_sqrt.py | 227 ++++++++++-------- .../test/test_goldschmidt_div_sqrt.py | 10 +- 2 files changed, 136 insertions(+), 101 deletions(-) diff --git a/src/soc/fu/div/experiment/goldschmidt_div_sqrt.py b/src/soc/fu/div/experiment/goldschmidt_div_sqrt.py index 6f739c33..3f7c2480 100644 --- a/src/soc/fu/div/experiment/goldschmidt_div_sqrt.py +++ b/src/soc/fu/div/experiment/goldschmidt_div_sqrt.py @@ -5,17 +5,17 @@ # of Horizon 2020 EU Programme 957073. from collections import defaultdict -from dataclasses import dataclass, field, fields, replace import logging import math import enum from fractions import Fraction from types import FunctionType from functools import lru_cache -from nmigen.hdl.ast import Signal, unsigned, signed, Const, Cat +from nmigen.hdl.ast import Signal, unsigned, signed, Const from nmigen.hdl.dsl import Module, Elaboratable from nmigen.hdl.mem import Memory from nmutil.clz import CLZ +from nmutil.plain_data import plain_data, fields, replace try: from functools import cached_property @@ -65,13 +65,13 @@ class RoundDir(enum.Enum): ERROR_IF_INEXACT = enum.auto() -@dataclass(frozen=True) +@plain_data(frozen=True, eq=False, repr=False) class FixedPoint: - bits: int - frac_wid: int + __slots__ = "bits", "frac_wid" - def __post_init__(self): - # called by the autogenerated __init__ + def __init__(self, bits, frac_wid): + self.bits = bits + self.frac_wid = frac_wid assert isinstance(self.bits, int) assert isinstance(self.frac_wid, int) and self.frac_wid >= 0 @@ -332,43 +332,47 @@ def _assert_accuracy(condition, msg="not accurate enough"): raise ParamsNotAccurateEnough(msg) -@dataclass(frozen=True, unsafe_hash=True) +@plain_data(frozen=True, unsafe_hash=True) class GoldschmidtDivParamsBase: """parameters for a Goldschmidt division algorithm, excluding derived parameters. """ - io_width: int - """bit-width of the input divisor and the result. - the input numerator is `2 * io_width`-bits wide. - """ + __slots__ = ("io_width", "extra_precision", "table_addr_bits", + "table_data_bits", "iter_count") + + def __init__(self, io_width, extra_precision, table_addr_bits, + table_data_bits, iter_count): + assert isinstance(io_width, int) + assert isinstance(extra_precision, int) + assert isinstance(table_addr_bits, int) + assert isinstance(table_data_bits, int) + assert isinstance(iter_count, int) + self.io_width = io_width + """bit-width of the input divisor and the result. + the input numerator is `2 * io_width`-bits wide. + """ - extra_precision: int - """number of bits of additional precision used inside the algorithm.""" + self.extra_precision = extra_precision + """number of bits of additional precision used inside the algorithm.""" - table_addr_bits: int - """the number of address bits used in the lookup-table.""" + self.table_addr_bits = table_addr_bits + """the number of address bits used in the lookup-table.""" - table_data_bits: int - """the number of data bits used in the lookup-table.""" + self.table_data_bits = table_data_bits + """the number of data bits used in the lookup-table.""" - iter_count: int - """the total number of iterations of the division algorithm's loop""" + self.iter_count = iter_count + """the total number of iterations of the division algorithm's loop""" -@dataclass(frozen=True, unsafe_hash=True) +@plain_data(frozen=True, unsafe_hash=True) class GoldschmidtDivParams(GoldschmidtDivParamsBase): """parameters for a Goldschmidt division algorithm. Use `GoldschmidtDivParams.get` to find a efficient set of parameters. """ - # tuple to be immutable, repr=False so repr() works for debugging even when - # __post_init__ hasn't finished running yet - table: "tuple[FixedPoint, ...]" = field(init=False, repr=False) - """the lookup-table""" - - ops: "tuple[GoldschmidtDivOp, ...]" = field(init=False, repr=False) - """the operations needed to perform the goldschmidt division algorithm.""" + __slots__ = "table", "ops" def _shrink_bound(self, bound, round_dir): """prevent fractions from having huge numerators/denominators by @@ -445,8 +449,13 @@ class GoldschmidtDivParams(GoldschmidtDivParamsBase): # we round down return min_value - def __post_init__(self): - # called by the autogenerated __init__ + def __init__(self, io_width, extra_precision, table_addr_bits, + table_data_bits, iter_count): + super().__init__(io_width=io_width, + extra_precision=extra_precision, + table_addr_bits=table_addr_bits, + table_data_bits=table_data_bits, + iter_count=iter_count) _assert_accuracy(self.io_width >= 1, "io_width out of range") _assert_accuracy(self.extra_precision >= 0, "extra_precision out of range") @@ -460,9 +469,14 @@ class GoldschmidtDivParams(GoldschmidtDivParamsBase): table.append(FixedPoint.with_frac_wid(self.table_exact_value(addr), self.table_data_bits, RoundDir.DOWN)) - # we have to use object.__setattr__ since frozen=True - object.__setattr__(self, "table", tuple(table)) - object.__setattr__(self, "ops", tuple(self.__make_ops())) + + self.table = tuple(table) + """ the lookup-table. + type: tuple[FixedPoint, ...] + """ + + self.ops = tuple(self.__make_ops()) + "the operations needed to perform the goldschmidt division algorithm." @property def expanded_width(self): @@ -800,11 +814,9 @@ class GoldschmidtDivParams(GoldschmidtDivParamsBase): @lru_cache(maxsize=1 << 16) def __cached_new(base_params): assert isinstance(base_params, GoldschmidtDivParamsBase) - # can't use dataclasses.asdict, since it's recursive and will also give - # child class fields too, which we don't want. kwargs = {} for field in fields(GoldschmidtDivParamsBase): - kwargs[field.name] = getattr(base_params, field.name) + kwargs[field] = getattr(base_params, field) try: return GoldschmidtDivParams(**kwargs), None except ParamsNotAccurateEnough as e: @@ -1139,44 +1151,57 @@ class GoldschmidtDivOp(enum.Enum): assert False, f"unimplemented GoldschmidtDivOp: {self}" -@dataclass +@plain_data(repr=False) class GoldschmidtDivState: - orig_n: int - """original numerator""" - - orig_d: int - """original denominator""" - - n: FixedPoint - """numerator -- N_prime[i] in the paper's algorithm 2""" - - d: FixedPoint - """denominator -- D_prime[i] in the paper's algorithm 2""" - - f: "FixedPoint | None" = None - """current factor -- F_prime[i] in the paper's algorithm 2""" - - quotient: "int | None" = None - """final quotient""" - - remainder: "int | None" = None - """final remainder""" - - n_shift: "int | None" = None - """amount the numerator needs to be left-shifted at the end of the - algorithm. - """ + __slots__ = ("orig_n", "orig_d", "n", "d", + "f", "quotient", "remainder", "n_shift") + + def __init__(self, orig_n, orig_d, n, d, + f=None, quotient=None, remainder=None, n_shift=None): + assert isinstance(orig_n, int) + assert isinstance(orig_d, int) + assert isinstance(n, FixedPoint) + assert isinstance(d, FixedPoint) + assert f is None or isinstance(f, FixedPoint) + assert quotient is None or isinstance(quotient, int) + assert remainder is None or isinstance(remainder, int) + assert n_shift is None or isinstance(n_shift, int) + self.orig_n = orig_n + """original numerator""" + + self.orig_d = orig_d + """original denominator""" + + self.n = n + """numerator -- N_prime[i] in the paper's algorithm 2""" + + self.d = d + """denominator -- D_prime[i] in the paper's algorithm 2""" + + self.f = f + """current factor -- F_prime[i] in the paper's algorithm 2""" + + self.quotient = quotient + """final quotient""" + + self.remainder = remainder + """final remainder""" + + self.n_shift = n_shift + """amount the numerator needs to be left-shifted at the end of the + algorithm. + """ def __repr__(self): fields_str = [] for field in fields(GoldschmidtDivState): - value = getattr(self, field.name) + value = getattr(self, field) if value is None: continue - if isinstance(value, int) and field.name != "n_shift": - fields_str.append(f"{field.name}={hex(value)}") + if isinstance(value, int) and field != "n_shift": + fields_str.append(f"{field}={hex(value)}") else: - fields_str.append(f"{field.name}={value!r}") + fields_str.append(f"{field}={value!r}") return f"GoldschmidtDivState({', '.join(fields_str)})" @@ -1230,45 +1255,55 @@ def goldschmidt_div(n, d, params, trace=lambda state: None): return state.quotient, state.remainder -@dataclass(eq=False) +@plain_data(eq=False) class GoldschmidtDivHDLState: - m: Module - """The HDL Module""" + __slots__ = ("m", "orig_n", "orig_d", "n", "d", + "f", "quotient", "remainder", "n_shift") - orig_n: Signal - """original numerator""" + __signal_name_prefix = "state_" - orig_d: Signal - """original denominator""" + def __init__(self, m, orig_n, orig_d, n, d, + f=None, quotient=None, remainder=None, n_shift=None): + assert isinstance(m, Module) + assert isinstance(orig_n, Signal) + assert isinstance(orig_d, Signal) + assert isinstance(n, Signal) + assert isinstance(d, Signal) + assert f is None or isinstance(f, Signal) + assert quotient is None or isinstance(quotient, Signal) + assert remainder is None or isinstance(remainder, Signal) + assert n_shift is None or isinstance(n_shift, Signal) - n: Signal - """numerator -- N_prime[i] in the paper's algorithm 2""" + self.m = m + """The HDL Module""" - d: Signal - """denominator -- D_prime[i] in the paper's algorithm 2""" + self.orig_n = orig_n + """original numerator""" - f: "Signal | None" = None - """current factor -- F_prime[i] in the paper's algorithm 2""" + self.orig_d = orig_d + """original denominator""" - quotient: "Signal | None" = None - """final quotient""" + self.n = n + """numerator -- N_prime[i] in the paper's algorithm 2""" - remainder: "Signal | None" = None - """final remainder""" + self.d = d + """denominator -- D_prime[i] in the paper's algorithm 2""" - n_shift: "Signal | None" = None - """amount the numerator needs to be left-shifted at the end of the - algorithm. - """ + self.f = f + """current factor -- F_prime[i] in the paper's algorithm 2""" - old_signals: "defaultdict[str, list[Signal]]" = field(repr=False, - init=False) + self.quotient = quotient + """final quotient""" - __signal_name_prefix: "str" = field(default="state_", repr=False, - init=False) + self.remainder = remainder + """final remainder""" + + self.n_shift = n_shift + """amount the numerator needs to be left-shifted at the end of the + algorithm. + """ - def __post_init__(self): - # called by the autogenerated __init__ + # old_signals must be set last self.old_signals = defaultdict(list) def __setattr__(self, name, value): @@ -1293,14 +1328,14 @@ class GoldschmidtDivHDLState: old_prefix = self.__signal_name_prefix try: for field in fields(GoldschmidtDivHDLState): - if field.name.startswith("_") or field.name == "m": + if field.startswith("_") or field == "m": continue - old_sig = getattr(self, field.name, None) + old_sig = getattr(self, field, None) if old_sig is None: continue assert isinstance(old_sig, Signal) new_sig = Signal.like(old_sig) - setattr(self, field.name, new_sig) + setattr(self, field, new_sig) self.m.d.sync += new_sig.eq(old_sig) finally: self.__signal_name_prefix = old_prefix diff --git a/src/soc/fu/div/experiment/test/test_goldschmidt_div_sqrt.py b/src/soc/fu/div/experiment/test/test_goldschmidt_div_sqrt.py index bf999bd8..28e795f4 100644 --- a/src/soc/fu/div/experiment/test/test_goldschmidt_div_sqrt.py +++ b/src/soc/fu/div/experiment/test/test_goldschmidt_div_sqrt.py @@ -4,7 +4,7 @@ # Funded by NLnet Assure Programme 2021-02-052, https://nlnet.nl/assure part # of Horizon 2020 EU Programme 957073. -from dataclasses import fields, replace +from nmutil.plain_data import fields, replace import math import unittest from nmutil.formaltest import FHDLTestCase @@ -13,7 +13,7 @@ from nmigen.sim import Tick, Delay from nmigen.hdl.ast import Signal from nmigen.hdl.dsl import Module from soc.fu.div.experiment.goldschmidt_div_sqrt import ( - GoldschmidtDivHDL, GoldschmidtDivHDLState, GoldschmidtDivOp, GoldschmidtDivParams, + GoldschmidtDivHDL, GoldschmidtDivHDLState, GoldschmidtDivParams, GoldschmidtDivState, ParamsNotAccurateEnough, goldschmidt_div, FixedPoint, RoundDir, goldschmidt_sqrt_rsqrt) @@ -156,15 +156,15 @@ class TestGoldschmidtDiv(FHDLTestCase): ref_state=repr(ref_state), last_op=str(last_op)): for field in fields(GoldschmidtDivHDLState): - sig = getattr(state, field.name) + sig = getattr(state, field) if not isinstance(sig, Signal): continue - ref_value = getattr(ref_state, field.name) + ref_value = getattr(ref_state, field) ref_value_str = repr(ref_value) if isinstance(ref_value, int): ref_value_str = hex(ref_value) value = yield sig - with self.subTest(field_name=field.name, + with self.subTest(field_name=field, sig=repr(sig), sig_shape=repr(sig.shape()), value=hex(value), -- 2.30.2