split out div/sqrt/rsqrt trials to separate module
authorLuke Kenneth Casson Leighton <lkcl@lkcl.net>
Tue, 23 Jul 2019 11:45:38 +0000 (12:45 +0100)
committerLuke Kenneth Casson Leighton <lkcl@lkcl.net>
Tue, 23 Jul 2019 11:45:38 +0000 (12:45 +0100)
src/ieee754/div_rem_sqrt_rsqrt/core.py
src/ieee754/div_rem_sqrt_rsqrt/test_core.py

index 8e49ea4da0b159cc3b28bc19d1dd9ec446788a97..e4602194fb84e5937f4888c1b8df75136e9acbf4 100644 (file)
@@ -256,6 +256,85 @@ class DivPipeCoreSetupStage(Elaboratable):
         return m
 
 
+class Trial(Elaboratable):
+    def __init__(self, core_config, trial_bits, current_shift, log2_radix):
+        self.core_config = core_config
+        self.trial_bits = trial_bits
+        self.current_shift = current_shift
+        self.log2_radix = log2_radix
+        bw = core_config.bit_width
+        self.divisor_radicand = Signal(bw, reset_less=True)
+        self.quotient_root = Signal(bw, reset_less=True)
+        self.root_times_radicand = Signal(bw * 2, reset_less=True)
+        self.compare_rhs = Signal(bw * 3, reset_less=True)
+        self.trial_compare_rhs = Signal(bw * 3, reset_less=True)
+        self.operation = DP.create_signal(reset_less=True)
+
+    def elaborate(self, platform):
+
+        m = Module()
+
+        dr = self.divisor_radicand
+        qr = self.quotient_root
+        rr = self.root_times_radicand
+
+        trial_bits_sig = Const(self.trial_bits, self.log2_radix)
+        trial_bits_sqrd_sig = Const(self.trial_bits * self.trial_bits,
+                                    self.log2_radix * 2)
+
+        tblen = self.core_config.bit_width+self.log2_radix
+        tblen2 = self.core_config.bit_width+self.log2_radix*2
+        dr_times_trial_bits_sqrd = Signal(tblen2, reset_less=True)
+        m.d.comb += dr_times_trial_bits_sqrd.eq(dr * trial_bits_sqrd_sig)
+
+        # UDivRem
+        with m.If(self.operation == int(DP.UDivRem)):
+            dr_times_trial_bits = Signal(tblen, reset_less=True)
+            m.d.comb += dr_times_trial_bits.eq(dr * trial_bits_sig)
+            div_rhs = self.compare_rhs
+
+            div_term1 = dr_times_trial_bits
+            div_term1_shift = self.core_config.fract_width
+            div_term1_shift += self.current_shift
+            div_rhs += div_term1 << div_term1_shift
+
+            m.d.comb += self.trial_compare_rhs.eq(div_rhs)
+
+        # SqrtRem
+        with m.Elif(self.operation == int(DP.SqrtRem)):
+            qr_times_trial_bits = Signal((tblen+1)*2, reset_less=True)
+            m.d.comb += qr_times_trial_bits.eq(qr * trial_bits_sig)
+            sqrt_rhs = self.compare_rhs
+
+            sqrt_term1 = qr_times_trial_bits
+            sqrt_term1_shift = self.core_config.fract_width
+            sqrt_term1_shift += self.current_shift + 1
+            sqrt_rhs += sqrt_term1 << sqrt_term1_shift
+            sqrt_term2 = trial_bits_sqrd_sig
+            sqrt_term2_shift = self.core_config.fract_width
+            sqrt_term2_shift += self.current_shift * 2
+            sqrt_rhs += sqrt_term2 << sqrt_term2_shift
+
+            m.d.comb += self.trial_compare_rhs.eq(sqrt_rhs)
+
+        # RSqrtRem
+        with m.Else():
+            rr_times_trial_bits = Signal((tblen+1)*3, reset_less=True)
+            m.d.comb += rr_times_trial_bits.eq(rr * trial_bits_sig)
+            rsqrt_rhs = self.compare_rhs
+
+            rsqrt_term1 = rr_times_trial_bits
+            rsqrt_term1_shift = self.current_shift + 1
+            rsqrt_rhs += rsqrt_term1 << rsqrt_term1_shift
+            rsqrt_term2 = dr_times_trial_bits_sqrd
+            rsqrt_term2_shift = self.current_shift * 2
+            rsqrt_rhs += rsqrt_term2 << rsqrt_term2_shift
+
+            m.d.comb += self.trial_compare_rhs.eq(rsqrt_rhs)
+
+        return m
+
+
 class DivPipeCoreCalculateStage(Elaboratable):
     """ Calculate Stage of the core of the div/rem/sqrt/rsqrt pipeline. """
 
@@ -302,65 +381,19 @@ class DivPipeCoreCalculateStage(Elaboratable):
         trial_compare_rhs_values = []
         pass_flags = []
         for trial_bits in range(radix):
-            trial_bits_sig = Const(trial_bits, log2_radix)
-            trial_bits_sqrd_sig = Const(trial_bits * trial_bits,
-                                        log2_radix * 2)
-
-            dr_times_trial_bits = self.i.divisor_radicand * trial_bits_sig
-            dr_times_trial_bits_sqrd = self.i.divisor_radicand \
-                * trial_bits_sqrd_sig
-            qr_times_trial_bits = self.i.quotient_root * trial_bits_sig
-            rr_times_trial_bits = self.i.root_times_radicand * trial_bits_sig
-
-            trial_compare_rhs = Signal.like(
-                self.o.compare_rhs, name=f"trial_compare_rhs_{trial_bits}",
-                reset_less=True)
-            m.d.comb += trial_compare_rhs.eq(self.i.compare_rhs)
-
-            if trial_bits != 0:  # no point adding multiply by zero
-                # UDivRem
-                with m.If(self.i.operation == int(DP.UDivRem)):
-                    div_rhs = self.i.compare_rhs
-
-                    div_term1 = dr_times_trial_bits
-                    div_term1_shift = self.core_config.fract_width
-                    div_term1_shift += current_shift
-                    div_rhs += div_term1 << div_term1_shift
-
-                    m.d.comb += trial_compare_rhs.eq(div_rhs)
-
-                # SqrtRem
-                with m.Elif(self.i.operation == int(DP.SqrtRem)):
-                    sqrt_rhs = self.i.compare_rhs
-
-                    sqrt_term1 = qr_times_trial_bits
-                    sqrt_term1_shift = self.core_config.fract_width
-                    sqrt_term1_shift += current_shift + 1
-                    sqrt_rhs += sqrt_term1 << sqrt_term1_shift
-                    sqrt_term2 = trial_bits_sqrd_sig
-                    sqrt_term2_shift = self.core_config.fract_width
-                    sqrt_term2_shift += current_shift * 2
-                    sqrt_rhs += sqrt_term2 << sqrt_term2_shift
-
-                    m.d.comb += trial_compare_rhs.eq(sqrt_rhs)
-
-                # RSqrtRem
-                with m.Else():
-                    rsqrt_rhs = self.i.compare_rhs
-
-                    rsqrt_term1 = rr_times_trial_bits
-                    rsqrt_term1_shift = current_shift + 1
-                    rsqrt_rhs += rsqrt_term1 << rsqrt_term1_shift
-                    rsqrt_term2 = dr_times_trial_bits_sqrd
-                    rsqrt_term2_shift = current_shift * 2
-                    rsqrt_rhs += rsqrt_term2 << rsqrt_term2_shift
-
-                    m.d.comb += trial_compare_rhs.eq(rsqrt_rhs)
-
-            trial_compare_rhs_values.append(trial_compare_rhs)
+            t = Trial(self.core_config, trial_bits,
+                          current_shift, log2_radix)
+            setattr(m.submodules, "trial%d" % trial_bits, t)
+            m.d.comb += t.divisor_radicand.eq(self.i.divisor_radicand)
+            m.d.comb += t.quotient_root.eq(self.i.quotient_root)
+            m.d.comb += t.root_times_radicand.eq(self.i.root_times_radicand)
+            m.d.comb += t.compare_rhs.eq(self.i.compare_rhs)
+            m.d.comb += t.operation.eq(self.i.operation)
+
+            trial_compare_rhs_values.append(t.trial_compare_rhs)
 
             pass_flag = Signal(name=f"pass_flag_{trial_bits}", reset_less=True)
-            m.d.comb += pass_flag.eq(self.i.compare_lhs >= trial_compare_rhs)
+            m.d.comb += pass_flag.eq(self.i.compare_lhs >= t.trial_compare_rhs)
             pass_flags.append(pass_flag)
 
         # convert pass_flags to next_bits.
index fc42829ba24a3cbe51b0ce4b5e3a3a3a2fd03356..95292f9afbf95e6198dc05e7a1864bcd420f8d54 100755 (executable)
@@ -257,8 +257,10 @@ class TestDivPipeCore(unittest.TestCase):
                     remainder = (yield dut.o.remainder)
                     with self.subTest(test_case=str(test_case)):
                         self.assertEqual(quotient_root,
-                                         test_case.quotient_root)
-                        self.assertEqual(remainder, test_case.remainder)
+                                         test_case.quotient_root,
+                                         str(test_case))
+                        self.assertEqual(remainder, test_case.remainder,
+                                         str(test_case))
             sim.add_clock(2e-6)
             sim.add_sync_process(generate_process)
             sim.add_sync_process(check_process)