HDL works for io_width=5
[soc.git] / src / soc / fu / div / experiment / goldschmidt_div_sqrt.py
index ea0ddda0763c6847336589ec38cc8b5426013ba5..a86fa78d111e3957b31e45a645b0ba8ac5e85174 100644 (file)
@@ -12,7 +12,7 @@ import enum
 from fractions import Fraction
 from types import FunctionType
 from functools import lru_cache
-from nmigen.hdl.ast import Signal, unsigned, Mux, signed
+from nmigen.hdl.ast import Signal, unsigned, signed, Const, Cat
 from nmigen.hdl.dsl import Module, Elaboratable
 from nmigen.hdl.mem import Memory
 from nmutil.clz import CLZ
@@ -937,7 +937,9 @@ class GoldschmidtDivParams(GoldschmidtDivParamsBase):
             else:
                 break
 
-        return cached_new(params)
+        retval = cached_new(params)
+        assert isinstance(retval, GoldschmidtDivParams)
+        return retval
 
 
 def clz(v, wid):
@@ -979,6 +981,8 @@ class GoldschmidtDivOp(enum.Enum):
             d_m_1 = d_m_1.to_frac_wid(table_addr_bits, RoundDir.DOWN)
             assert 0 <= d_m_1.bits < (1 << params.table_addr_bits)
             state.f = params.table[d_m_1.bits]
+            state.f = state.f.to_frac_wid(expanded_width,
+                                          round_dir=RoundDir.DOWN)
         elif self == GoldschmidtDivOp.MulNByF:
             assert state.f is not None
             n = state.n * state.f
@@ -1003,7 +1007,6 @@ class GoldschmidtDivOp(enum.Enum):
             assert False, f"unimplemented GoldschmidtDivOp: {self}"
 
     def gen_hdl(self, params, state, sync_rom):
-        # FIXME: finish getting hdl/simulation to work
         """generate the hdl for this operation.
 
         arguments:
@@ -1018,8 +1021,6 @@ class GoldschmidtDivOp(enum.Enum):
         assert isinstance(params, GoldschmidtDivParams)
         assert isinstance(state, GoldschmidtDivHDLState)
         m = state.m
-        expanded_width = params.expanded_width
-        table_addr_bits = params.table_addr_bits
         if self == GoldschmidtDivOp.Normalize:
             # normalize so 1 <= d < 2
             assert state.d.width == params.io_width
@@ -1029,27 +1030,41 @@ class GoldschmidtDivOp(enum.Enum):
             m.d.comb += d_leading_zeros.sig_in.eq(state.d)
             d_shift_out = Signal.like(state.d)
             m.d.comb += d_shift_out.eq(state.d << d_leading_zeros.lz)
-            state.d = Signal(params.n_d_f_total_wid)
-            m.d.comb += state.d.eq(d_shift_out << (params.extra_precision
-                                                   + params.n_d_f_int_wid))
+            d = Signal(params.n_d_f_total_wid)
+            m.d.comb += d.eq((d_shift_out << (1 + params.expanded_width))
+                             >> state.d.width)
 
             # normalize so 1 <= n < 2
             n_leading_zeros = CLZ(2 * params.io_width)
             m.submodules.n_leading_zeros = n_leading_zeros
             m.d.comb += n_leading_zeros.sig_in.eq(state.n)
-            n_shift_s_v = (params.io_width + d_leading_zeros.lz
+            signed_zero = Const(0, signed(1))  # force subtraction to be signed
+            n_shift_s_v = (params.io_width + signed_zero + d_leading_zeros.lz
                            - n_leading_zeros.lz)
             n_shift_s = Signal.like(n_shift_s_v)
-            state.n_shift = Signal(d_leading_zeros.lz.width)
+            n_shift_n_lz_out = Signal.like(state.n)
+            n_shift_d_lz_out = Signal.like(state.n << d_leading_zeros.lz)
             m.d.comb += [
                 n_shift_s.eq(n_shift_s_v),
-                state.n_shift.eq(Mux(n_shift_s < 0, 0, n_shift_s)),
+                n_shift_d_lz_out.eq(state.n << d_leading_zeros.lz),
+                n_shift_n_lz_out.eq(state.n << n_leading_zeros.lz),
             ]
+            state.n_shift = Signal(d_leading_zeros.lz.width)
             n = Signal(params.n_d_f_total_wid)
-            shifted_n = state.n << state.n_shift
-            fixed_shift = params.expanded_width - state.n.width
-            m.d.comb += n.eq(shifted_n << fixed_shift)
+            with m.If(n_shift_s < 0):
+                m.d.comb += [
+                    state.n_shift.eq(0),
+                    n.eq((n_shift_d_lz_out << (1 + params.expanded_width))
+                         >> state.d.width),
+                ]
+            with m.Else():
+                m.d.comb += [
+                    state.n_shift.eq(n_shift_s),
+                    n.eq((n_shift_n_lz_out << (1 + params.expanded_width))
+                         >> state.n.width),
+                ]
             state.n = n
+            state.d = d
         elif self == GoldschmidtDivOp.FEqTableLookup:
             assert state.d.width == params.n_d_f_total_wid, "invalid d width"
             # compute initial f by table lookup
@@ -1058,7 +1073,7 @@ class GoldschmidtDivOp(enum.Enum):
             table_width = 1 + params.table_data_bits
             table = Memory(width=table_width, depth=len(params.table),
                            init=[i.bits for i in params.table])
-            addr = state.d[:-params.n_d_f_int_wid][-table_addr_bits:]
+            addr = state.d[:-params.n_d_f_int_wid][-params.table_addr_bits:]
             if sync_rom:
                 table_read = table.read_port()
                 m.d.comb += table_read.addr.eq(addr)
@@ -1068,8 +1083,7 @@ class GoldschmidtDivOp(enum.Enum):
                 m.d.comb += table_read.addr.eq(addr)
             m.submodules.table_read = table_read
             state.f = Signal(params.n_d_f_int_wid + params.expanded_width)
-            data_shift = (table_width - params.table_data_bits
-                          + params.expanded_width)
+            data_shift = params.expanded_width - params.table_data_bits
             m.d.comb += state.f.eq(table_read.data << data_shift)
         elif self == GoldschmidtDivOp.MulNByF:
             assert state.n.width == params.n_d_f_total_wid, "invalid n width"
@@ -1150,8 +1164,20 @@ class GoldschmidtDivState:
     algorithm.
     """
 
+    def __repr__(self):
+        fields_str = []
+        for field in fields(GoldschmidtDivState):
+            value = getattr(self, field.name)
+            if value is None:
+                continue
+            if isinstance(value, int) and field.name != "n_shift":
+                fields_str.append(f"{field.name}={hex(value)}")
+            else:
+                fields_str.append(f"{field.name}={value!r}")
+        return f"GoldschmidtDivState({', '.join(fields_str)})"
+
 
-def goldschmidt_div(n, d, params):
+def goldschmidt_div(n, d, params, trace=lambda state: None):
     """ Goldschmidt division algorithm.
 
         based on:
@@ -1168,6 +1194,9 @@ def goldschmidt_div(n, d, params):
             denominator. a `width`-bit unsigned integer. must not be zero.
         width: int
             the bit-width of the inputs/outputs. must be a positive integer.
+        trace: Function[[GoldschmidtDivState], None]
+            called with the initial state and the state after executing each
+            operation in `params.ops`.
 
         returns: tuple[int, int]
             the quotient and remainder. a tuple of two `width`-bit unsigned
@@ -1187,8 +1216,10 @@ def goldschmidt_div(n, d, params):
         d=FixedPoint(d, params.io_width),
     )
 
+    trace(state)
     for op in params.ops:
         op.run(params, state)
+        trace(state)
 
     assert state.quotient is not None
     assert state.remainder is not None
@@ -1273,7 +1304,6 @@ class GoldschmidtDivHDLState:
 
 
 class GoldschmidtDivHDL(Elaboratable):
-    # FIXME: finish getting hdl/simulation to work
     """ Goldschmidt division algorithm.
 
         based on:
@@ -1303,6 +1333,9 @@ class GoldschmidtDivHDL(Elaboratable):
             output quotient. only valid when `n < (d << params.io_width)`.
         r: Signal(unsigned(params.io_width))
             output remainder. only valid when `n < (d << params.io_width)`.
+        trace: list[GoldschmidtDivHDLState]
+            list of the initial state and the state after executing each
+            operation in `params.ops`.
     """
 
     @property
@@ -1321,15 +1354,16 @@ class GoldschmidtDivHDL(Elaboratable):
         self.q = Signal(unsigned(params.io_width))
         self.r = Signal(unsigned(params.io_width))
 
-    def elaborate(self, platform):
-        m = Module()
+        # in constructor so we get trace without needing to call elaborate
         state = GoldschmidtDivHDLState(
-            m=m,
+            m=Module(),
             orig_n=self.n,
             orig_d=self.d,
             n=self.n,
             d=self.d)
 
+        self.trace = [replace(state)]
+
         # copy and reverse
         pipe_reg_indexes = list(reversed(self.pipe_reg_indexes))
 
@@ -1339,14 +1373,19 @@ class GoldschmidtDivHDL(Elaboratable):
                 pipe_reg_indexes.pop()
                 state.insert_pipeline_register()
             op.gen_hdl(self.params, state, self.sync_rom)
+            self.trace.append(replace(state))
 
         while len(pipe_reg_indexes) > 0:
             pipe_reg_indexes.pop()
             state.insert_pipeline_register()
 
-        m.d.comb += self.q.eq(state.quotient)
-        m.d.comb += self.r.eq(state.remainder)
-        return m
+        state.m.d.comb += [
+            self.q.eq(state.quotient),
+            self.r.eq(state.remainder),
+        ]
+
+    def elaborate(self, platform):
+        return self.trace[0].m
 
 
 GOLDSCHMIDT_SQRT_RSQRT_TABLE_ADDR_INT_WID = 2