split out logical ops into PartitionedBase
[ieee754fpu.git] / src / ieee754 / part_bits / xor.py
index 1c0110df525908f26587f73ac1210e170422824e..85129df0896e83f539f346c1781269ca24543dd5 100644 (file)
@@ -20,50 +20,16 @@ from nmutil.ripple import RippleLSB
 
 from ieee754.part_mul_add.partpoints import PartitionPoints
 from ieee754.part_cmp.experiments.eq_combiner import XORCombiner
+from ieee754.part_bits.base import PartitionedBase
 
 
-class PartitionedXOR(Elaboratable):
+
+class PartitionedXOR(PartitionedBase):
 
     def __init__(self, width, partition_points):
         """Create a ``PartitionedXOR`` operator
         """
-        self.width = width
-        self.a = Signal(width, reset_less=True)
-        self.partition_points = PartitionPoints(partition_points)
-        self.mwidth = len(self.partition_points)+1
-        self.output = Signal(self.mwidth, reset_less=True)
-        if not self.partition_points.fits_in_width(width):
-            raise ValueError("partition_points doesn't fit in width")
-
-    def elaborate(self, platform):
-        m = Module()
-        comb = m.d.comb
-        m.submodules.xorc = xorc = XORCombiner(self.mwidth)
-
-        # make a series of "xor", splitting a and b into partition chunks
-        xors = Signal(self.mwidth, reset_less=True)
-        xorl = []
-        keys = list(self.partition_points.keys()) + [self.width]
-        start = 0
-        for i in range(len(keys)):
-            end = keys[i]
-            xorl.append(self.a[start:end].xor())
-            start = end # for next time round loop
-        comb += xors.eq(Cat(*xorl))
-
-        # put the partial results through the combiner
-        comb += xorc.gates.eq(self.partition_points.as_sig())
-        comb += xorc.neqs.eq(xors)
-
-        m.submodules.ripple = ripple = RippleLSB(self.mwidth)
-        comb += ripple.results_in.eq(xorc.outputs)
-        comb += ripple.gates.eq(self.partition_points.as_sig())
-        comb += self.output.eq(~ripple.output)
-
-        return m
-
-    def ports(self):
-        return [self.a, self.output]
+        super().__init__(width, partition_points, XORCombiner, "xor")
 
 
 if __name__ == "__main__":