test_core.py doesn't crash anymore
[ieee754fpu.git] / src / ieee754 / div_rem_sqrt_rsqrt / core.py
index 3c76f59b7078449d24ea5f50062fc14242f9c7ed..141deb7365456d727952588c6e8c848c91652625 100644 (file)
@@ -18,12 +18,9 @@ Formulas solved are:
 The remainder is the left-hand-side of the comparison minus the
 right-hand-side of the comparison in the above formulas.
 """
-from nmigen import (Elaboratable, Module, Signal, Const, Mux)
+from nmigen import (Elaboratable, Module, Signal, Const, Mux, Cat)
 import enum
 
-# TODO, move to new (suitable) location
-#from ieee754.fpcommon.getop import FPPipeContext
-
 
 class DivPipeCoreConfig:
     """ Configuration for core of the div/rem/sqrt/rsqrt pipeline.
@@ -52,7 +49,7 @@ class DivPipeCoreConfig:
         return (self.bit_width + self.log2_radix - 1) // self.log2_radix
 
 
-class DivPipeCoreOperation(enum.IntEnum):
+class DivPipeCoreOperation(enum.Enum):
     """ Operation for ``DivPipeCore``.
 
     :attribute UDivRem: unsigned divide/remainder.
@@ -64,41 +61,20 @@ class DivPipeCoreOperation(enum.IntEnum):
     SqrtRem = 1
     RSqrtRem = 2
 
+    def __int__(self):
+        """ Convert to int. """
+        return self.value
+
     @classmethod
     def create_signal(cls, *, src_loc_at=0, **kwargs):
         """ Create a signal that can contain a ``DivPipeCoreOperation``. """
-        return Signal(min=int(min(cls)),
-                      max=int(max(cls)),
+        return Signal(min=min(map(int, cls)),
+                      max=max(map(int, cls)),
                       src_loc_at=(src_loc_at + 1),
-                      decoder=cls,
+                      decoder=lambda v: str(cls(v)),
                       **kwargs)
 
 
-# TODO: move to suitable location
-class DivPipeBaseData:
-    """ input data base type for ``DivPipe``.
-    """
-
-    def __init__(self, width, pspec):
-        """ Create a ``DivPipeBaseData`` instance. """
-        self.out_do_z = Signal(reset_less=True)
-        self.oz = Signal(width, reset_less=True)
-
-        self.ctx = FPPipeContext(width, pspec)  # context: muxid, operator etc.
-        self.muxid = self.ctx.muxid             # annoying. complicated.
-
-    def __iter__(self):
-        """ Get member signals. """
-        yield self.out_do_z
-        yield self.oz
-        yield from self.ctx
-
-    def eq(self, rhs):
-        """ Assign member signals. """
-        return [self.out_do_z.eq(i.out_do_z), self.oz.eq(i.oz),
-                self.ctx.eq(i.ctx)]
-
-
 class DivPipeCoreInputData:
     """ input data type for ``DivPipeCore``.
 
@@ -113,59 +89,30 @@ class DivPipeCoreInputData:
     :attribute operation: the ``DivPipeCoreOperation`` to be computed.
     """
 
-    def __init__(self, core_config):
+    def __init__(self, core_config, reset_less=True):
         """ Create a ``DivPipeCoreInputData`` instance. """
         self.core_config = core_config
         self.dividend = Signal(core_config.bit_width + core_config.fract_width,
-                               reset_less=True)
-        self.divisor_radicand = Signal(core_config.bit_width, reset_less=True)
+                               reset_less=reset_less)
+        self.divisor_radicand = Signal(core_config.bit_width,
+                                       reset_less=reset_less)
 
         # FIXME: this goes into (is replaced by) self.ctx.op
-        self.operation = DivPipeCoreOperation.create_signal(reset_less=True)
+        self.operation = \
+            DivPipeCoreOperation.create_signal(reset_less=reset_less)
 
     def __iter__(self):
         """ Get member signals. """
         yield self.dividend
         yield self.divisor_radicand
         yield self.operation  # FIXME: delete.  already covered by self.ctx
-        return
-        yield self.z
-        yield self.out_do_z
-        yield self.oz
-        yield from self.ctx
 
     def eq(self, rhs):
         """ Assign member signals. """
         return [self.dividend.eq(rhs.dividend),
                 self.divisor_radicand.eq(rhs.divisor_radicand),
-                self.operation.eq(rhs.operation)]  # FIXME: delete.
-
-
-# TODO: move to suitable location
-class DivPipeInputData(DivPipeCoreInputData, DivPipeBaseData):
-    """ input data type for ``DivPipe``.
-    """
-
-    def __init__(self, core_config):
-        """ Create a ``DivPipeInputData`` instance. """
-        DivPipeCoreInputData.__init__(self, core_config)
-        DivPipeBaseData.__init__(self, width, pspec) # XXX TODO args
-        self.out_do_z = Signal(reset_less=True)
-        self.oz = Signal(width, reset_less=True)
-
-        self.ctx = FPPipeContext(width, pspec)  # context: muxid, operator etc.
-        self.muxid = self.ctx.muxid             # annoying. complicated.
-
-    def __iter__(self):
-        """ Get member signals. """
-        yield from DivPipeCoreInputData.__iter__(self)
-        yield from DivPipeBaseData.__iter__(self)
-
-    def eq(self, rhs):
-        """ Assign member signals. """
-        return DivPipeBaseData.eq(self, rhs) + \
-               DivPipeCoreInputData.eq(self, rhs)
-
+                self.operation.eq(rhs.operation),  # FIXME: delete.
+                ]
 
 
 class DivPipeCoreInterstageData:
@@ -193,22 +140,27 @@ class DivPipeCoreInterstageData:
         ``core_config.fract_width * 3`` bits.
     """
 
-    def __init__(self, core_config):
+    def __init__(self, core_config, reset_less=True):
         """ Create a ``DivPipeCoreInterstageData`` instance. """
         self.core_config = core_config
-        self.divisor_radicand = Signal(core_config.bit_width, reset_less=True)
-        # XXX FIXME: delete.  already covered by self.ctx.op
-        self.operation = DivPipeCoreOperation.create_signal(reset_less=True)
-        self.quotient_root = Signal(core_config.bit_width, reset_less=True)
+        self.divisor_radicand = Signal(core_config.bit_width,
+                                       reset_less=reset_less)
+        # FIXME: delete self.operation.  already covered by self.ctx.op
+        self.operation = \
+            DivPipeCoreOperation.create_signal(reset_less=reset_less)
+        self.quotient_root = Signal(core_config.bit_width,
+                                    reset_less=reset_less)
         self.root_times_radicand = Signal(core_config.bit_width * 2,
-                                          reset_less=True)
-        self.compare_lhs = Signal(core_config.bit_width * 3, reset_less=True)
-        self.compare_rhs = Signal(core_config.bit_width * 3, reset_less=True)
+                                          reset_less=reset_less)
+        self.compare_lhs = Signal(core_config.bit_width * 3,
+                                  reset_less=reset_less)
+        self.compare_rhs = Signal(core_config.bit_width * 3,
+                                  reset_less=reset_less)
 
     def __iter__(self):
         """ Get member signals. """
         yield self.divisor_radicand
-        yield self.operation  # XXX FIXME: delete.  already in self.ctx.op
+        yield self.operation  # FIXME: delete.  already in self.ctx.op
         yield self.quotient_root
         yield self.root_times_radicand
         yield self.compare_lhs
@@ -224,30 +176,6 @@ class DivPipeCoreInterstageData:
                 self.compare_rhs.eq(rhs.compare_rhs)]
 
 
-# TODO: move to suitable location
-class DivPipeInterstageData(DivPipeCoreInterstageData, DivPipeBaseData):
-    """ interstage data type for ``DivPipe``.
-
-    :attribute core_config: ``DivPipeCoreConfig`` instance describing the
-        configuration to be used.
-    """
-
-    def __init__(self, core_config):
-        """ Create a ``DivPipeCoreInterstageData`` instance. """
-        DivPipeCoreInterstageData.__init__(self, core_config)
-        DivPipeBaseData.__init__(self, width, pspec) # XXX TODO args
-
-    def __iter__(self):
-        """ Get member signals. """
-        yield from DivPipeInterstageData.__iter__(self)
-        yield from DivPipeBaseData.__iter__(self)
-
-    def eq(self, rhs):
-        """ Assign member signals. """
-        return DivPipeBaseData.eq(self, rhs) + \
-               DivPipeCoreInterstageData.eq(self, rhs)
-
-
 class DivPipeCoreOutputData:
     """ output data type for ``DivPipeCore``.
 
@@ -261,11 +189,13 @@ class DivPipeCoreOutputData:
         fract-width of ``core_config.fract_width * 3`` bits.
     """
 
-    def __init__(self, core_config):
+    def __init__(self, core_config, reset_less=True):
         """ Create a ``DivPipeCoreOutputData`` instance. """
         self.core_config = core_config
-        self.quotient_root = Signal(core_config.bit_width, reset_less=True)
-        self.remainder = Signal(core_config.bit_width * 3, reset_less=True)
+        self.quotient_root = Signal(core_config.bit_width,
+                                    reset_less=reset_less)
+        self.remainder = Signal(core_config.bit_width * 3,
+                                reset_less=reset_less)
 
     def __iter__(self):
         """ Get member signals. """
@@ -279,30 +209,6 @@ class DivPipeCoreOutputData:
                 self.remainder.eq(rhs.remainder)]
 
 
-# TODO: move to suitable location
-class DivPipeOutputData(DivPipeCoreOutputData, DivPipeBaseData):
-    """ interstage data type for ``DivPipe``.
-
-    :attribute core_config: ``DivPipeCoreConfig`` instance describing the
-        configuration to be used.
-    """
-
-    def __init__(self, core_config):
-        """ Create a ``DivPipeCoreOutputData`` instance. """
-        DivPipeCoreOutputData.__init__(self, core_config)
-        DivPipeBaseData.__init__(self, width, pspec) # XXX TODO args
-
-    def __iter__(self):
-        """ Get member signals. """
-        yield from DivPipeOutputData.__iter__(self)
-        yield from DivPipeBaseData.__iter__(self)
-
-    def eq(self, rhs):
-        """ Assign member signals. """
-        return DivPipeBaseData.eq(self, rhs) + \
-               DivPipeCoreOutputData.eq(self, rhs)
-
-
 class DivPipeCoreSetupStage(Elaboratable):
     """ Setup Stage of the core of the div/rem/sqrt/rsqrt pipeline. """
 
@@ -337,10 +243,10 @@ class DivPipeCoreSetupStage(Elaboratable):
         m.d.comb += self.o.quotient_root.eq(0)
         m.d.comb += self.o.root_times_radicand.eq(0)
 
-        with m.If(self.i.operation == DivPipeCoreOperation.UDivRem):
+        with m.If(self.i.operation == int(DivPipeCoreOperation.UDivRem)):
             m.d.comb += self.o.compare_lhs.eq(self.i.dividend
                                               << self.core_config.fract_width)
-        with m.Elif(self.i.operation == DivPipeCoreOperation.SqrtRem):
+        with m.Elif(self.i.operation == int(DivPipeCoreOperation.SqrtRem)):
             m.d.comb += self.o.compare_lhs.eq(
                 self.i.divisor_radicand << (self.core_config.fract_width * 2))
         with m.Else():  # DivPipeCoreOperation.RSqrtRem
@@ -352,11 +258,6 @@ class DivPipeCoreSetupStage(Elaboratable):
 
         return m
 
-        # TODO: these as well
-        m.d.comb += self.o.oz.eq(self.i.oz)
-        m.d.comb += self.o.out_do_z.eq(self.i.out_do_z)
-        m.d.comb += self.o.ctx.eq(self.i.ctx)
-
 
 class DivPipeCoreCalculateStage(Elaboratable):
     """ Calculate Stage of the core of the div/rem/sqrt/rsqrt pipeline. """
@@ -404,38 +305,45 @@ class DivPipeCoreCalculateStage(Elaboratable):
         trial_compare_rhs_values = []
         pass_flags = []
         for trial_bits in range(radix):
-            shifted_trial_bits = Const(trial_bits, log2_radix) << current_shift
-            shifted_trial_bits_sqrd = shifted_trial_bits * shifted_trial_bits
+            tb = trial_bits << current_shift
+            tb_width = log2_radix + current_shift
+            shifted_trial_bits = Const(tb, tb_width)
+            shifted_trial_bits2 = Const(tb*2, tb_width+1)
+            shifted_trial_bits_sqrd = Const(tb * tb, tb_width * 2)
 
             # UDivRem
             div_rhs = self.i.compare_rhs
-            div_factor1 = self.i.divisor_radicand * shifted_trial_bits
-            div_rhs += div_factor1 << self.core_config.fract_width
+            if tb != 0:  # no point adding stuff that's multiplied by zero
+                div_factor1 = self.i.divisor_radicand * shifted_trial_bits2
+                div_rhs += div_factor1 << self.core_config.fract_width
 
             # SqrtRem
             sqrt_rhs = self.i.compare_rhs
-            sqrt_factor1 = self.i.quotient_root * (shifted_trial_bits << 1)
-            sqrt_rhs += sqrt_factor1 << self.core_config.fract_width
-            sqrt_factor2 = shifted_trial_bits_sqrd
-            sqrt_rhs += sqrt_factor2 << self.core_config.fract_width
+            if tb != 0:  # no point adding stuff that's multiplied by zero
+                sqrt_factor1 = self.i.quotient_root * shifted_trial_bits2
+                sqrt_rhs += sqrt_factor1 << self.core_config.fract_width
+                sqrt_factor2 = shifted_trial_bits_sqrd
+                sqrt_rhs += sqrt_factor2 << self.core_config.fract_width
 
             # RSqrtRem
             rsqrt_rhs = self.i.compare_rhs
-            rsqrt_rhs += self.i.root_times_radicand * (shifted_trial_bits << 1)
-            rsqrt_rhs += self.i.divisor_radicand * shifted_trial_bits_sqrd
+            if tb != 0:  # no point adding stuff that's multiplied by zero
+                rsqrt_rhs += self.i.root_times_radicand * shifted_trial_bits2
+                rsqrt_rhs += self.i.divisor_radicand * shifted_trial_bits_sqrd
 
-            trial_compare_rhs = self.o.compare_rhs.like(
-                name=f"trial_compare_rhs_{trial_bits}")
+            trial_compare_rhs = Signal.like(
+                self.o.compare_rhs, name=f"trial_compare_rhs_{trial_bits}",
+                reset_less=True)
 
-            with m.If(self.i.operation == DivPipeCoreOperation.UDivRem):
+            with m.If(self.i.operation == int(DivPipeCoreOperation.UDivRem)):
                 m.d.comb += trial_compare_rhs.eq(div_rhs)
-            with m.Elif(self.i.operation == DivPipeCoreOperation.SqrtRem):
+            with m.Elif(self.i.operation == int(DivPipeCoreOperation.SqrtRem)):
                 m.d.comb += trial_compare_rhs.eq(sqrt_rhs)
             with m.Else():  # DivPipeCoreOperation.RSqrtRem
                 m.d.comb += trial_compare_rhs.eq(rsqrt_rhs)
             trial_compare_rhs_values.append(trial_compare_rhs)
 
-            pass_flag = Signal(name=f"pass_flag_{trial_bits}")
+            pass_flag = Signal(name=f"pass_flag_{trial_bits}", reset_less=True)
             m.d.comb += pass_flag.eq(self.i.compare_lhs >= trial_compare_rhs)
             pass_flags.append(pass_flag)
 
@@ -447,21 +355,26 @@ class DivPipeCoreCalculateStage(Elaboratable):
         # Assumes that pass_flag[0] is always set (since
         # compare_lhs >= compare_rhs is a pipeline invariant).
 
-        next_bits = Signal(log2_radix)
+        next_bits = Signal(log2_radix, reset_less=True)
         for i in range(log2_radix):
             bit_value = 1
             for j in range(0, radix, 1 << i):
                 bit_value ^= pass_flags[j]
             m.d.comb += next_bits.part(i, 1).eq(bit_value)
 
-        next_compare_rhs = 0
+        next_compare_rhs = Signal(radix, reset_less=True)
+        l = []
         for i in range(radix):
-            next_flag = pass_flags[i + 1] if i + 1 < radix else 0
-            next_compare_rhs |= Mux(pass_flags[i] & ~next_flag,
-                                    trial_compare_rhs_values[i],
-                                    0)
-
-        m.d.comb += self.o.compare_rhs.eq(next_compare_rhs)
+            next_flag = pass_flags[i + 1] if (i + 1 < radix) else Const(0)
+            flag = Signal(reset_less=True, name=f"flag{i}")
+            test = Signal(reset_less=True, name=f"test{i}")
+            # XXX TODO: flag is 1 bit wide, so the Mux result is truncated
+            m.d.comb += test.eq((pass_flags[i] & ~next_flag))
+            m.d.comb += flag.eq(Mux(test, trial_compare_rhs_values[i], 0))
+            l.append(flag)
+
+        m.d.comb += next_compare_rhs.eq(Cat(*l))
+        m.d.comb += self.o.compare_rhs.eq(next_compare_rhs.bool())
         m.d.comb += self.o.root_times_radicand.eq(self.i.root_times_radicand
                                                   + ((self.i.divisor_radicand
                                                       * next_bits)
@@ -490,7 +403,7 @@ class DivPipeCoreFinalStage(Elaboratable):
 
     def setup(self, m, i):
         """ Pipeline stage setup. """
-        m.submodules.div_pipe_core_setup = self
+        m.submodules.div_pipe_core_final = self
         m.d.comb += self.i.eq(i)
 
     def process(self, i):