fix syntax errors for test_partsig
authorLuke Kenneth Casson Leighton <lkcl@lkcl.net>
Fri, 7 Feb 2020 14:47:07 +0000 (14:47 +0000)
committerLuke Kenneth Casson Leighton <lkcl@lkcl.net>
Fri, 7 Feb 2020 14:47:09 +0000 (14:47 +0000)
src/ieee754/part/partsig.py
src/ieee754/part/test/test_partsig.py
src/ieee754/part_mux/part_mux.py

index b7e6eabbb9249241327241bcde386e3b142cac0f..9e1c3f72d0f6662c255949dd563fbcf95aed9df3 100644 (file)
@@ -49,6 +49,8 @@ class PartitionedSignal:
         return "%s%d" % (category, self.modnames[category])
 
     def eq(self, val):
+        if isinstance(val, PartitionedSignal):
+            return self.sig.eq(val.sig)
         return self.sig.eq(val)
 
     # unary  ops that require partitioning
index eb12b8a601cad98a12badbf151a0b45d68247a95..c9dfcb5216e608194ab4a835fc6ae407d2c32b1c 100644 (file)
@@ -26,6 +26,7 @@ def create_simulator(module, traces, test_name):
 
 class TestAddMod(Elaboratable):
     def __init__(self, width, partpoints):
+        self.partpoints = partpoints
         self.a = PartitionedSignal(partpoints, width)
         self.b = PartitionedSignal(partpoints, width)
         self.add_output = Signal(width)
@@ -49,7 +50,8 @@ class TestAddMod(Elaboratable):
         m.d.comb += self.eq_output.eq(self.a == self.b)
         m.d.comb += self.ge_output.eq(self.a >= self.b)
         m.d.comb += self.add_output.eq(self.a + self.b)
-        m.d.comb += self.mux_out.eq(PMux(m, self.a, self.b, self.mux_sel))
+        ppts = self.partpoints
+        m.d.comb += self.mux_out.eq(PMux(m, ppts, self.mux_sel, self.a, self.b))
 
         return m
 
index 54f34812d4cd06bc3195ea1aca0b198d2707c43d..b0cd90b4744d6bb80175a2cad0b63a797699bd72 100644 (file)
@@ -15,13 +15,17 @@ See:
 
 from nmigen import Signal, Module, Elaboratable, Mux
 from ieee754.part_mul_add.partpoints import PartitionPoints
+from ieee754.part_mul_add.partpoints import make_partition
 
 modcount = 0 # global for now
-def PMux(m, sel, a, b):
+def PMux(m, mask, sel, a, b):
+    global modcount
     modcount += 1
-    pm = PartitionedMux(a.shape()[0])
-    m.d.comb += pm.a.eq(a)
-    m.d.comb += pm.b.eq(b)
+    width = a.sig.shape()[0] # get width
+    part_pts = make_partition(mask, width) # create partition points
+    pm = PartitionedMux(width, part_pts)
+    m.d.comb += pm.a.eq(a.sig)
+    m.d.comb += pm.b.eq(b.sig)
     m.d.comb += pm.sel.eq(sel)
     setattr(m.submodules, "pmux%d" % modcount, pm)
     return pm.output
@@ -35,7 +39,7 @@ class PartitionedMux(Elaboratable):
     consequently the incoming selector (sel) can completely
     ignore what the *actual* partition bits are.
     """
-    def __init__(self, width):
+    def __init__(self, width, partition_points):
         self.width = width
         self.partition_points = PartitionPoints(partition_points)
         self.mwidth = len(self.partition_points)+1
@@ -43,8 +47,8 @@ class PartitionedMux(Elaboratable):
         self.b = Signal(width, reset_less=True)
         self.sel = Signal(self.mwidth, reset_less=True)
         self.output = Signal(width, reset_less=True)
-        assert (self.partition_points.fits_in_width(width),
-                    "partition_points doesn't fit in width")
+        assert self.partition_points.fits_in_width(width), \
+                    "partition_points doesn't fit in width"
 
     def elaborate(self, platform):
         m = Module()
@@ -56,12 +60,12 @@ class PartitionedMux(Elaboratable):
         start = 0
         for i in range(len(keys)):
             end = keys[i]
-            mux = output[start:end]
-            mux.append(self.a[start:end] == self.b[start:end])
+            mux = self.output[start:end]
+            comb += mux.eq(self.a[start:end] == self.b[start:end])
             start = end  # for next time round loop
 
         return m
 
     def ports(self):
-        return [self.a, self.b, self.sel, self.output]
+        return [self.a.sig, self.b.sig, self.sel, self.output]