Fix order of results from PartitionedEqGtGe
authorMichael Nolan <mtnolan2640@gmail.com>
Wed, 5 Feb 2020 23:44:00 +0000 (18:44 -0500)
committerMichael Nolan <mtnolan2640@gmail.com>
Wed, 5 Feb 2020 23:46:56 +0000 (18:46 -0500)
src/ieee754/part_cmp/eq_gt_ge.py
src/ieee754/part_cmp/formal/proof_eq_gt_ge.py
src/ieee754/part_cmp/reorder_results.py [new file with mode: 0644]

index cb6d117d6b2c1627bc39611107d5f8ea285bf7ef..e1f933fd185638b76e402cd41d1de8a25e1b7b6a 100644 (file)
@@ -15,10 +15,11 @@ See:
 
 from nmigen import Signal, Module, Elaboratable, Cat, C, Mux, Repl
 from nmigen.back.pysim import Simulator, Delay, Settle
-from nmigen.cli import main
+from nmigen.cli import main, rtlil
 
 from ieee754.part_mul_add.partpoints import PartitionPoints
 from ieee754.part_cmp.gt_combiner import GTCombiner
+from ieee754.part_cmp.reorder_results import ReorderResults
 
 
 class PartitionedEqGtGe(Elaboratable):
@@ -50,6 +51,8 @@ class PartitionedEqGtGe(Elaboratable):
         comb = m.d.comb
         m.submodules.gtc = gtc = GTCombiner(self.mwidth)
 
+        m.submodules.reorder = reorder = ReorderResults(self.mwidth)
+
         # make a series of "eqs" and "gts", splitting a and b into
         # partition chunks
         eqs = Signal(self.mwidth, reset_less=True)
@@ -63,7 +66,7 @@ class PartitionedEqGtGe(Elaboratable):
             end = keys[i]
             eql.append(self.a[start:end] == self.b[start:end])
             gtl.append(self.a[start:end] > self.b[start:end])
-            start = end # for next time round loop
+            start = end  # for next time round loop
         comb += eqs.eq(Cat(*eql))
         comb += gts.eq(Cat(*gtl))
 
@@ -84,12 +87,18 @@ class PartitionedEqGtGe(Elaboratable):
                 comb += aux_input.eq(1)
                 comb += gt_en.eq(1)
 
+        results = Signal(self.mwidth, reset_less=True)
         comb += gtc.gates.eq(self.partition_points.as_sig())
         comb += gtc.eqs.eq(eqs)
         comb += gtc.gts.eq(gts)
         comb += gtc.aux_input.eq(aux_input)
         comb += gtc.gt_en.eq(gt_en)
-        comb += self.output.eq(gtc.outputs)
+        comb += results.eq(gtc.outputs)
+
+        comb += reorder.results_in.eq(results)
+        comb += reorder.gates.eq(self.partition_points.as_sig())
+
+        comb += self.output.eq(reorder.output)
 
         return m
 
@@ -110,9 +119,10 @@ if __name__ == "__main__":
         yield mask.eq(0b010)
         yield egg.a.eq(0xf000)
         yield egg.b.eq(0)
+        yield egg.opcode.eq(0b00)
         yield Delay(1e-6)
         out = yield egg.output
-        print ("out", bin(out))
+        print("out", bin(out))
         yield mask.eq(0b111)
         yield egg.a.eq(0x0000)
         yield egg.b.eq(0)
@@ -122,7 +132,7 @@ if __name__ == "__main__":
         yield egg.b.eq(0)
         yield Delay(1e-6)
         out = yield egg.output
-        print ("out", bin(out))
+        print("out", bin(out))
 
     sim.add_process(process)
     with sim.write_vcd("eq_gt_ge.vcd", "eq_gt_ge.gtkw", traces=egg.ports()):
index bfad73a80385bffeefc70217a54f790e981ccd95..1118bd444b59f643f182ed960d445a11b0c121ca 100644 (file)
@@ -57,15 +57,15 @@ class EqualsDriver(Elaboratable):
         with m.If(opcode == 0b00):
             with m.Switch(gates):
                 with m.Case(0b00):
-                    comb += Assert(out[-1] == (a == b))
+                    comb += Assert(out[0] == (a == b))
                 with m.Case(0b01):
-                    comb += Assert(out[2] == ((a_intervals[1] == \
+                    comb += Assert(out[1] == ((a_intervals[1] == \
                                     b_intervals[1]) &
                                               (a_intervals[2] == \
                                     b_intervals[2])))
                     comb += Assert(out[0] == (a_intervals[0] == b_intervals[0]))
                 with m.Case(0b10):
-                    comb += Assert(out[1] == ((a_intervals[0] == \
+                    comb += Assert(out[0] == ((a_intervals[0] == \
                                     b_intervals[0]) &
                                               (a_intervals[1] == \
                                     b_intervals[1])))
@@ -77,41 +77,41 @@ class EqualsDriver(Elaboratable):
         with m.If(opcode == 0b01):
             with m.Switch(gates):
                 with m.Case(0b00):
-                    comb += Assert(out[-1] == (a > b))
+                    comb += Assert(out[0] == (a > b))
                 with m.Case(0b01):
                     comb += Assert(out[0] == (a_intervals[0] > b_intervals[0]))
 
-                    comb += Assert(out[1] == 0)
-                    comb += Assert(out[2] == (Cat(*a_intervals[1:3]) > \
-                                    Cat(*b_intervals[1:3])))
+                    comb += Assert(out[1] == (Cat(*a_intervals[1:3]) > \
+                                              Cat(*b_intervals[1:3])))
+                    comb += Assert(out[2] == 0)
                 with m.Case(0b10):
-                    comb += Assert(out[0] == 0)
-                    comb += Assert(out[1] == (Cat(*a_intervals[0:2]) > \
-                                    Cat(*b_intervals[0:2])))
+                    comb += Assert(out[0] == (Cat(*a_intervals[0:2]) > \
+                                              Cat(*b_intervals[0:2])))
+                    comb += Assert(out[1] == 0)
                     comb += Assert(out[2] == (a_intervals[2] > b_intervals[2]))
                 with m.Case(0b11):
                     for i in range(mwidth-1):
                         comb += Assert(out[i] == (a_intervals[i] > \
-                                    b_intervals[i]))
+                                                  b_intervals[i]))
         with m.If(opcode == 0b10):
             with m.Switch(gates):
                 with m.Case(0b00):
-                    comb += Assert(out[-1] == (a >= b))
+                    comb += Assert(out[0] == (a >= b))
                 with m.Case(0b01):
                     comb += Assert(out[0] == (a_intervals[0] >= b_intervals[0]))
 
-                    comb += Assert(out[1] == 0)
-                    comb += Assert(out[2] == (Cat(*a_intervals[1:3]) >= \
-                                    Cat(*b_intervals[1:3])))
+                    comb += Assert(out[1] == (Cat(*a_intervals[1:3]) >= \
+                                              Cat(*b_intervals[1:3])))
+                    comb += Assert(out[2] == 0)
                 with m.Case(0b10):
-                    comb += Assert(out[0] == 0)
-                    comb += Assert(out[1] == (Cat(*a_intervals[0:2]) >= \
-                                    Cat(*b_intervals[0:2])))
+                    comb += Assert(out[0] == (Cat(*a_intervals[0:2]) >= \
+                                              Cat(*b_intervals[0:2])))
+                    comb += Assert(out[1] == 0)
                     comb += Assert(out[2] == (a_intervals[2] >= b_intervals[2]))
                 with m.Case(0b11):
                     for i in range(mwidth-1):
                         comb += Assert(out[i] == \
-                                    (a_intervals[i] >= b_intervals[i]))
+                                       (a_intervals[i] >= b_intervals[i]))
 
 
 
diff --git a/src/ieee754/part_cmp/reorder_results.py b/src/ieee754/part_cmp/reorder_results.py
new file mode 100644 (file)
index 0000000..577f170
--- /dev/null
@@ -0,0 +1,30 @@
+# gt_combiner returns results that are in the wrong order from how
+# they need to be. Specifically, if the partition gates are open, the
+# bits need to be reversed through the width of the partition. This
+# module does that
+from nmigen import Signal, Module, Elaboratable, Mux
+from ieee754.part_mul_add.partpoints import PartitionPoints
+
+class ReorderResults(Elaboratable):
+    def __init__(self, width):
+        self.width = width
+        self.results_in = Signal(width, reset_less=True)
+        self.gates = Signal(width-1, reset_less=True)
+
+        self.output = Signal(width, reset_less=True)
+
+    def elaborate(self, platform):
+        m = Module()
+        comb = m.d.comb
+        width = self.width
+
+        current_result = self.results_in[-1]
+
+        for i in range(width-2, -1, -1):  # counts down from width-1 to 0
+            cur = Signal()
+            comb += cur.eq(current_result)
+            comb += self.output[i+1].eq(cur & self.gates[i])
+            current_result = Mux(self.gates[i], self.results_in[i], cur)
+
+            comb += self.output[0].eq(current_result)
+        return m