add rest of DivPipeCore
author Jacob Lifshay <programmerjake@gmail.com>
Fri, 5 Jul 2019 12:01:40 +0000 (05:01 -0700)
committer Jacob Lifshay <programmerjake@gmail.com>
Fri, 5 Jul 2019 12:01:40 +0000 (05:01 -0700)
src/ieee754/div_rem_sqrt_rsqrt/core.py

index dd1fdf9143caa359c2e2895ee170fb21e3a8edcf..b52c89488311768838d6a2a0f944a83500cce35d 100644 (file)
@@ -18,7 +18,7 @@ Formulas solved are:
 The remainder is the left-hand-side of the comparison minus the
 right-hand-side of the comparison in the above formulas.
 """
-from nmigen import (Elaboratable, Module, Signal)
+from nmigen import (Elaboratable, Module, Signal, Const, Mux)
 import enum
 
 # TODO
@@ -47,6 +47,11 @@ class DivPipeCoreConfig:
         return f"DivPipeCoreConfig({self.bit_width}, " \
             + f"{self.fract_width}, {self.log2_radix})"
 
+    @property
+    def num_calculate_stages(self):
+        """ Number of ``DivPipeCoreCalculateStage`` needed (rounded up). """
+        return (self.bit_width + self.log2_radix - 1) // self.log2_radix
+
 
 class DivPipeCoreOperation(enum.IntEnum):
     """ Operation for ``DivPipeCore``.
@@ -94,20 +99,19 @@ class DivPipeCoreInputData:
         # FIXME: this goes into (is replaced by) self.ctx.op
         self.operation = DivPipeCoreOperation.create_signal(reset_less=True)
 
-        return # TODO: needs a width argument and a pspec
+        return  # TODO: needs a width argument and a pspec
         self.z = FPNumBaseRecord(width, False)
         self.out_do_z = Signal(reset_less=True)
         self.oz = Signal(width, reset_less=True)
 
-        self.ctx = FPPipeContext(width, pspec) # context: muxid, operator etc.
+        self.ctx = FPPipeContext(width, pspec)  # context: muxid, operator etc.
         self.muxid = self.ctx.muxid             # annoying. complicated.
 
-
     def __iter__(self):
         """ Get member signals. """
         yield self.dividend
         yield self.divisor_radicand
-        yield self.operation # FIXME: delete.  already covered by self.ctx
+        yield self.operation  # FIXME: delete.  already covered by self.ctx
         return
         yield self.z
         yield self.out_do_z
@@ -118,13 +122,12 @@ class DivPipeCoreInputData:
         """ Assign member signals. """
         return [self.dividend.eq(rhs.dividend),
                 self.divisor_radicand.eq(rhs.divisor_radicand),
-                self.operation.eq(rhs.operation)] # FIXME: delete.
+                self.operation.eq(rhs.operation)]  # FIXME: delete.
         # TODO: and these
         return [self.out_do_z.eq(i.out_do_z), self.oz.eq(i.oz),
                 self.ctx.eq(i.ctx)]
 
 
-
 class DivPipeCoreInterstageData:
     """ interstage data type for ``DivPipeCore``.
 
@@ -161,18 +164,18 @@ class DivPipeCoreInterstageData:
                                           reset_less=True)
         self.compare_lhs = Signal(core_config.bit_width * 3, reset_less=True)
         self.compare_rhs = Signal(core_config.bit_width * 3, reset_less=True)
-        return # TODO: needs a width argument and a pspec
+        return  # TODO: needs a width argument and a pspec
         self.z = FPNumBaseRecord(width, False)
         self.out_do_z = Signal(reset_less=True)
         self.oz = Signal(width, reset_less=True)
 
-        self.ctx = FPPipeContext(width, pspec) # context: muxid, operator etc.
+        self.ctx = FPPipeContext(width, pspec)  # context: muxid, operator etc.
         self.muxid = self.ctx.muxid             # annoying. complicated.
 
     def __iter__(self):
         """ Get member signals. """
         yield self.divisor_radicand
-        yield self.operation # XXX FIXME: delete.  already in self.ctx.op
+        yield self.operation  # XXX FIXME: delete.  already in self.ctx.op
         yield self.quotient_root
         yield self.root_times_radicand
         yield self.compare_lhs
@@ -186,7 +189,7 @@ class DivPipeCoreInterstageData:
     def eq(self, rhs):
         """ Assign member signals. """
         return [self.divisor_radicand.eq(rhs.divisor_radicand),
-                self.operation.eq(rhs.operation), # FIXME: delete.
+                self.operation.eq(rhs.operation),  # FIXME: delete.
                 self.quotient_root.eq(rhs.quotient_root),
                 self.root_times_radicand.eq(rhs.root_times_radicand),
                 self.compare_lhs.eq(rhs.compare_lhs),
@@ -196,10 +199,40 @@ class DivPipeCoreInterstageData:
                 self.ctx.eq(i.ctx)]
 
 
-class DivPipeCoreSetupStage(Elaboratable):
-    """ Setup Stage of the core of the div/rem/sqrt/rsqrt pipeline.
+class DivPipeCoreOutputData:
+    """ Output data type for ``DivPipeCore``.
+
+    :attribute core_config: ``DivPipeCoreConfig`` instance describing the
+        configuration to be used.
+    :attribute quotient_root: the quotient or root part of the result of the
+        operation. Signal with a bit-width of ``core_config.bit_width`` and a
+        fract-width of ``core_config.fract_width`` bits.
+    :attribute remainder: the remainder part of the result of the operation.
+        Signal with a bit-width of ``core_config.bit_width * 3`` and a
+        fract-width of ``core_config.fract_width * 3`` bits.
     """
 
+    def __init__(self, core_config):
+        """ Create a ``DivPipeCoreOutputData`` for ``core_config``. """
+        self.core_config = core_config
+        self.quotient_root = Signal(core_config.bit_width, reset_less=True)
+        self.remainder = Signal(core_config.bit_width * 3, reset_less=True)
+
+    def __iter__(self):
+        """ Yield member signals: ``quotient_root``, then ``remainder``. """
+        yield self.quotient_root
+        yield self.remainder
+        return
+
+    def eq(self, rhs):
+        """ Get the list of assignments of ``rhs``'s members to self's. """
+        return [self.quotient_root.eq(rhs.quotient_root),
+                self.remainder.eq(rhs.remainder)]
+
+
+class DivPipeCoreSetupStage(Elaboratable):
+    """ Setup Stage of the core of the div/rem/sqrt/rsqrt pipeline. """
+
     def __init__(self, core_config):
         """ Create a ``DivPipeCoreSetupStage`` instance."""
         self.core_config = core_config
@@ -251,3 +284,152 @@ class DivPipeCoreSetupStage(Elaboratable):
         m.d.comb += self.o.out_do_z.eq(self.i.out_do_z)
         m.d.comb += self.o.ctx.eq(self.i.ctx)
 
+
+class DivPipeCoreCalculateStage(Elaboratable):
+    """ Calculate Stage of the core of the div/rem/sqrt/rsqrt pipeline. """
+
+    def __init__(self, core_config, stage_index):
+        """ Create a ``DivPipeCoreCalculateStage`` instance. """
+        self.core_config = core_config
+        assert stage_index in range(core_config.num_calculate_stages)
+        self.stage_index = stage_index
+        self.i = self.ispec()
+        self.o = self.ospec()
+
+    def ispec(self):
+        """ Get the input spec for this pipeline stage. """
+        return DivPipeCoreInterstageData(self.core_config)
+
+    def ospec(self):
+        """ Get the output spec for this pipeline stage. """
+        return DivPipeCoreInterstageData(self.core_config)
+
+    def setup(self, m, i):
+        """ Pipeline stage setup: add self to ``m``, drive self.i from i. """
+        setattr(m.submodules,
+                f"div_pipe_core_calculate_{self.stage_index}",
+                self)
+        m.d.comb += self.i.eq(i)
+
+    def process(self, i):
+        """ Pipeline stage process: get the output (``i`` is ignored). """
+        return self.o
+
+    def elaborate(self, platform):
+        """ Elaborate into ``Module``. """
+        m = Module()
+        m.d.comb += self.o.divisor_radicand.eq(self.i.divisor_radicand)
+        m.d.comb += self.o.operation.eq(self.i.operation)
+        m.d.comb += self.o.compare_lhs.eq(self.i.compare_lhs)
+        log2_radix = self.core_config.log2_radix
+        current_shift = self.core_config.bit_width
+        current_shift -= self.stage_index * log2_radix
+        log2_radix = min(log2_radix, current_shift)  # last stage may be short
+        assert log2_radix > 0
+        current_shift -= log2_radix  # bit position of this stage's LSB
+        radix = 1 << log2_radix
+        trial_compare_rhs_values = []
+        pass_flags = []
+        for trial_bits in range(radix):
+            shifted_trial_bits = Const(trial_bits, log2_radix) << current_shift
+            shifted_trial_bits_sqrd = shifted_trial_bits * shifted_trial_bits
+
+            # UDivRem: rhs + ((divisor * trial) << fract_width)
+            div_rhs = self.i.compare_rhs
+            div_factor1 = self.i.divisor_radicand * shifted_trial_bits
+            div_rhs += div_factor1 << self.core_config.fract_width
+
+            # SqrtRem: rhs + ((2 * root * trial + trial ** 2) << fract_width)
+            sqrt_rhs = self.i.compare_rhs
+            sqrt_factor1 = self.i.quotient_root * (shifted_trial_bits << 1)
+            sqrt_rhs += sqrt_factor1 << self.core_config.fract_width
+            sqrt_factor2 = shifted_trial_bits_sqrd
+            sqrt_rhs += sqrt_factor2 << self.core_config.fract_width
+
+            # RSqrtRem (no fract_width shift, unlike the two cases above)
+            rsqrt_rhs = self.i.compare_rhs
+            rsqrt_rhs += self.i.root_times_radicand * (shifted_trial_bits << 1)
+            rsqrt_rhs += self.i.divisor_radicand * shifted_trial_bits_sqrd
+
+            trial_compare_rhs = Signal.like(
+                self.o.compare_rhs, name=f"trial_compare_rhs_{trial_bits}")
+
+            with m.If(self.i.operation == DivPipeCoreOperation.UDivRem):
+                m.d.comb += trial_compare_rhs.eq(div_rhs)
+            with m.Elif(self.i.operation == DivPipeCoreOperation.SqrtRem):
+                m.d.comb += trial_compare_rhs.eq(sqrt_rhs)
+            with m.Else():  # DivPipeCoreOperation.RSqrtRem
+                m.d.comb += trial_compare_rhs.eq(rsqrt_rhs)
+            trial_compare_rhs_values.append(trial_compare_rhs)
+
+            pass_flag = Signal(name=f"pass_flag_{trial_bits}")
+            m.d.comb += pass_flag.eq(self.i.compare_lhs >= trial_compare_rhs)
+            pass_flags.append(pass_flag)
+
+        # convert pass_flags to next_bits.
+        #
+        # Assumes that for each set bit in pass_flags, all previous bits are
+        # also set.
+        #
+        # Assumes that pass_flags[0] is always set (since
+        # compare_lhs >= compare_rhs is a pipeline invariant).
+
+        next_bits = Signal(log2_radix)
+        for i in range(log2_radix):
+            bit_value = 1
+            for j in range(0, radix, 1 << i):
+                bit_value ^= pass_flags[j]
+            m.d.comb += next_bits[i].eq(bit_value)
+
+        next_compare_rhs = 0  # select the rhs of the highest passing trial
+        for i in range(radix):
+            next_flag = pass_flags[i + 1] if i + 1 < radix else 0
+            next_compare_rhs |= Mux(pass_flags[i] & ~next_flag,
+                                    trial_compare_rhs_values[i],
+                                    0)
+
+        m.d.comb += self.o.compare_rhs.eq(next_compare_rhs)
+        m.d.comb += self.o.root_times_radicand.eq(self.i.root_times_radicand
+                                                  + ((self.i.divisor_radicand
+                                                      * next_bits)
+                                                     << current_shift))
+        m.d.comb += self.o.quotient_root.eq(self.i.quotient_root
+                                            | (next_bits << current_shift))
+        return m
+
+
+class DivPipeCoreFinalStage(Elaboratable):
+    """ Final Stage of the core of the div/rem/sqrt/rsqrt pipeline. """
+
+    def __init__(self, core_config):
+        """ Create a ``DivPipeCoreFinalStage`` instance."""
+        self.core_config = core_config
+        self.i = self.ispec()
+        self.o = self.ospec()
+
+    def ispec(self):
+        """ Get the input spec for this pipeline stage."""
+        return DivPipeCoreInterstageData(self.core_config)
+
+    def ospec(self):
+        """ Get the output spec for this pipeline stage."""
+        return DivPipeCoreOutputData(self.core_config)
+
+    def setup(self, m, i):
+        """ Pipeline stage setup: add self to ``m``, drive self.i from i. """
+        m.submodules.div_pipe_core_final = self
+        m.d.comb += self.i.eq(i)
+
+    def process(self, i):
+        """ Pipeline stage process. """
+        return self.o  # return processed data (ignore i)
+
+    def elaborate(self, platform):
+        """ Elaborate into ``Module``: remainder is compare_lhs - rhs. """
+        m = Module()
+
+        m.d.comb += self.o.quotient_root.eq(self.i.quotient_root)
+        m.d.comb += self.o.remainder.eq(self.i.compare_lhs
+                                        - self.i.compare_rhs)
+
+        return m