move local variables
authorLuke Kenneth Casson Leighton <lkcl@lkcl.net>
Sat, 17 Aug 2019 14:30:38 +0000 (15:30 +0100)
committerLuke Kenneth Casson Leighton <lkcl@lkcl.net>
Sat, 17 Aug 2019 14:30:38 +0000 (15:30 +0100)
src/ieee754/part_mul_add/multiply.py
src/ieee754/part_mul_add/test/test_multiply.py

index facbce300d031b6823f790d961f6a090ab22c52d..5176d469b25b7fbf0daeb80e9bcead943c3b7053 100644 (file)
@@ -624,21 +624,23 @@ class Mul8_16_32_64(Elaboratable):
     """
 
     def __init__(self, register_levels= ()):
+
+        # parameter(s)
+        self.register_levels = list(register_levels)
+
+        # inputs
         self.part_pts = PartitionPoints()
         for i in range(8, 64, 8):
             self.part_pts[i] = Signal(name=f"part_pts_{i}")
         self.part_ops = [Signal(2, name=f"part_ops_{i}") for i in range(8)]
         self.a = Signal(64)
         self.b = Signal(64)
-        self.output = Signal(64)
-        self.register_levels = list(register_levels)
+
+        # intermediates (needed for unit tests)
         self._intermediate_output = Signal(128)
-        self._output_64 = Signal(64)
-        self._output_32 = Signal(64)
-        self._output_16 = Signal(64)
-        self._output_8 = Signal(64)
-        self._a_signed = [Signal(name=f"_a_signed_{i}") for i in range(8)]
-        self._b_signed = [Signal(name=f"_b_signed_{i}") for i in range(8)]
+
+        # output
+        self.output = Signal(64)
 
     def _part_byte(self, index):
         if index == -1 or index == 7:
@@ -658,6 +660,14 @@ class Mul8_16_32_64(Elaboratable):
             tl.append(pb)
         m.d.comb += pbs.eq(Cat(*tl))
 
+        # local variables
+        output_64 = Signal(64)
+        output_32 = Signal(64)
+        output_16 = Signal(64)
+        output_8 = Signal(64)
+        a_signed = [Signal(name=f"_a_signed_{i}") for i in range(8)]
+        b_signed = [Signal(name=f"_b_signed_{i}") for i in range(8)]
+
         delayed_part_ops = [
             [Signal(2, name=f"_delayed_part_ops_{delay}_{i}")
              for i in range(8)]
@@ -676,10 +686,10 @@ class Mul8_16_32_64(Elaboratable):
         for mod in [part_8, part_16, part_32, part_64]:
             m.d.comb += mod.a.eq(self.a)
             m.d.comb += mod.b.eq(self.b)
-            for i in range(len(self._a_signed)):
-                m.d.comb += mod._a_signed[i].eq(self._a_signed[i])
-            for i in range(len(self._b_signed)):
-                m.d.comb += mod._b_signed[i].eq(self._b_signed[i])
+            for i in range(len(a_signed)):
+                m.d.comb += mod._a_signed[i].eq(a_signed[i])
+            for i in range(len(b_signed)):
+                m.d.comb += mod._b_signed[i].eq(b_signed[i])
             m.d.comb += mod.pbs.eq(pbs)
             nat_l.append(mod.not_a_term)
             nbt_l.append(mod.not_b_term)
@@ -700,11 +710,11 @@ class Mul8_16_32_64(Elaboratable):
                 terms.append(term)
 
         for i in range(8):
-            a_signed = self.part_ops[i] != OP_MUL_UNSIGNED_HIGH
-            b_signed = (self.part_ops[i] == OP_MUL_LOW) \
+            asig = self.part_ops[i] != OP_MUL_UNSIGNED_HIGH
+            bsig = (self.part_ops[i] == OP_MUL_LOW) \
                         | (self.part_ops[i] == OP_MUL_SIGNED_HIGH)
-            m.d.comb += self._a_signed[i].eq(a_signed)
-            m.d.comb += self._b_signed[i].eq(b_signed)
+            m.d.comb += a_signed[i].eq(asig)
+            m.d.comb += b_signed[i].eq(bsig)
 
         # it's fine to bitwise-or data together since they are never enabled
         # at the same time
@@ -737,28 +747,28 @@ class Mul8_16_32_64(Elaboratable):
         m.d.comb += io64.intermed.eq(self._intermediate_output)
         for i in range(8):
             m.d.comb += io64.delayed_part_ops[i].eq(delayed_part_ops[-1][i])
-        m.d.comb += self._output_64.eq(io64.output)
+        m.d.comb += output_64.eq(io64.output)
 
         # create _output_32
         m.submodules.io32 = io32 = IntermediateOut(32, 128, 2)
         m.d.comb += io32.intermed.eq(self._intermediate_output)
         for i in range(8):
             m.d.comb += io32.delayed_part_ops[i].eq(delayed_part_ops[-1][i])
-        m.d.comb += self._output_32.eq(io32.output)
+        m.d.comb += output_32.eq(io32.output)
 
         # create _output_16
         m.submodules.io16 = io16 = IntermediateOut(16, 128, 4)
         m.d.comb += io16.intermed.eq(self._intermediate_output)
         for i in range(8):
             m.d.comb += io16.delayed_part_ops[i].eq(delayed_part_ops[-1][i])
-        m.d.comb += self._output_16.eq(io16.output)
+        m.d.comb += output_16.eq(io16.output)
 
         # create _output_8
         m.submodules.io8 = io8 = IntermediateOut(8, 128, 8)
         m.d.comb += io8.intermed.eq(self._intermediate_output)
         for i in range(8):
             m.d.comb += io8.delayed_part_ops[i].eq(delayed_part_ops[-1][i])
-        m.d.comb += self._output_8.eq(io8.output)
+        m.d.comb += output_8.eq(io8.output)
 
         # final output
         ol = []
@@ -768,11 +778,11 @@ class Mul8_16_32_64(Elaboratable):
                 Mux(part_8.delayed_parts[-1][i]
                     | part_16.delayed_parts[-1][i // 2],
                     Mux(part_8.delayed_parts[-1][i],
-                        self._output_8.bit_select(i * 8, 8),
-                        self._output_16.bit_select(i * 8, 8)),
+                        output_8.bit_select(i * 8, 8),
+                        output_16.bit_select(i * 8, 8)),
                     Mux(part_32.delayed_parts[-1][i // 4],
-                        self._output_32.bit_select(i * 8, 8),
-                        self._output_64.bit_select(i * 8, 8))))
+                        output_32.bit_select(i * 8, 8),
+                        output_64.bit_select(i * 8, 8))))
             ol.append(op)
         m.d.comb += self.output.eq(Cat(*ol))
         return m
index 0c0b420b7f53cbeb9623207fba82eabaf8221bc1..f15e948e1b88ab0a50115edc425162932e9cdb8a 100644 (file)
@@ -525,12 +525,6 @@ class TestMul8_16_32_64(unittest.TestCase):
                  module.output]
         ports.extend(module.part_ops)
         ports.extend(module.part_pts.values())
-        ports += [module._output_64,
-                  module._output_32,
-                  module._output_16,
-                  module._output_8]
-        ports.extend(module._a_signed)
-        ports.extend(module._b_signed)
         with create_simulator(module, ports, file_name) as sim:
             def process(gen_or_check: GenOrCheck) -> AsyncProcessGenerator:
                 for a_signed in False, True: