add PartitionedSignal.all() and unit test, currently failing
authorLuke Kenneth Casson Leighton <lkcl@lkcl.net>
Sat, 2 Oct 2021 17:26:41 +0000 (18:26 +0100)
committerLuke Kenneth Casson Leighton <lkcl@lkcl.net>
Sat, 2 Oct 2021 17:26:41 +0000 (18:26 +0100)
src/ieee754/part/partsig.py
src/ieee754/part/test/test_partsig.py
src/ieee754/part_cmp/experiments/eq_combiner.py

index 79cc6c348d17b43ccd21893e92d359c21ce0c414..e6939422ecc57270f9a369d7c5d191edfce9baa8 100644 (file)
@@ -20,6 +20,7 @@ from ieee754.part_mul_add.adder import PartitionedAdder
 from ieee754.part_cmp.eq_gt_ge import PartitionedEqGtGe
 from ieee754.part_bits.xor import PartitionedXOR
 from ieee754.part_bits.bool import PartitionedBool
+from ieee754.part_bits.all import PartitionedAll
 from ieee754.part_shift.part_shift_dynamic import PartitionedDynamicShift
 from ieee754.part_shift.part_shift_scalar import PartitionedScalarShift
 from ieee754.part_mul_add.partpoints import make_partition2, PartitionPoints
@@ -50,7 +51,7 @@ global modnames
 modnames = {}
 # for sub-modules to be created on-demand. Mux is done slightly
 # differently (has its own global)
-for name in ['add', 'eq', 'gt', 'ge', 'ls', 'xor', 'bool']:
+for name in ['add', 'eq', 'gt', 'ge', 'ls', 'xor', 'bool', 'all']:
     modnames[name] = 0
 
 
@@ -355,7 +356,11 @@ class PartitionedSignal(UserValue):
         Value, out
             ``1`` if all bits are set, ``0`` otherwise.
         """
-        return self == Const(-1) # leverage the __eq__ operator here
+        width = len(self.sig)
+        pa = PartitionedAll(width, self.partpoints)
+        setattr(self.m.submodules, self.get_modname("all"), pa)
+        self.m.d.comb += pa.a.eq(self.sig)
+        return pa.output
 
     def xor(self):
         """Compute pairwise exclusive-or of every bit.
index 6937b07ee6c818bdf874ee5523821fefd6557867..d7dd50c74f3e07d010114decc80a03e6a2aab7e3 100644 (file)
@@ -204,6 +204,7 @@ class TestAddMod(Elaboratable):
         self.neg_output = Signal(width)
         self.xor_output = Signal(len(partpoints)+1)
         self.bool_output = Signal(len(partpoints)+1)
+        self.all_output = Signal(len(partpoints)+1)
 
     def elaborate(self, platform):
         m = Module()
@@ -233,6 +234,7 @@ class TestAddMod(Elaboratable):
         # horizontal operators
         comb += self.xor_output.eq(self.a.xor())
         comb += self.bool_output.eq(self.a.bool())
+        comb += self.all_output.eq(self.a.all())
         # left shift
         comb += self.ls_output.eq(self.a << self.b)
         # right shift
@@ -581,48 +583,50 @@ class TestPartitionedSignal(unittest.TestCase):
 
             def test_bool_fn(a, mask):
                 test = (a & mask)
-                result = 0
-                while test != 0:
-                    bit = (test & 1)
-                    result |= bit
-                    test >>= 1
-                return result
+                return test != 0
+
+            def test_all_fn(a, mask):
+                # slightly different: all bits masked must be 1
+                test = (a & mask)
+                return test == mask
 
             def test_horizop(msg_prefix, test_fn, mod_attr, *maskbit_list):
-                randomvals = []
-                for i in range(10):
-                    randomvals.append(randint(0, 65535))
-                for a in [0x0000,
-                             0x1234,
-                             0xABCD,
-                             0xFFFF,
-                             0x8000,
-                             0xBEEF, 0xFEED,
-                                ]+randomvals:
-                    yield module.a.lower().eq(a)
-                    yield Delay(0.1e-6)
-                    # convert to mask_list
-                    mask_list = []
-                    for mb in maskbit_list:
-                        v = 0
-                        for i in range(4):
-                            if mb & (1 << i):
-                                v |= 0xf << (i*4)
-                        mask_list.append(v)
-                    y = 0
-                    # do the partitioned tests
-                    for i, mask in enumerate(mask_list):
-                        if test_fn(a, mask):
-                            # OR y with the lowest set bit in the mask
-                            y |= maskbit_list[i]
-                    # check the result
-                    outval = (yield getattr(module, "%s_output" % mod_attr))
-                    msg = f"{msg_prefix}: {mod_attr} 0x{a:X} " + \
-                        f" => 0x{y:X} != 0x{outval:X}, masklist %s"
-                    print((msg % str(maskbit_list)).format(locals()))
-                    self.assertEqual(y, outval, msg % str(maskbit_list))
+                with self.subTest(msg_prefix):
+                    randomvals = []
+                    for i in range(10):
+                        randomvals.append(randint(0, 65535))
+                    for a in [0x0000,
+                                 0x1234,
+                                 0xABCD,
+                                 0xFFFF,
+                                 0x8000,
+                                 0xBEEF, 0xFEED,
+                                    ]+randomvals:
+                        yield module.a.lower().eq(a)
+                        yield Delay(0.1e-6)
+                        # convert to mask_list
+                        mask_list = []
+                        for mb in maskbit_list:
+                            v = 0
+                            for i in range(4):
+                                if mb & (1 << i):
+                                    v |= 0xf << (i*4)
+                            mask_list.append(v)
+                        y = 0
+                        # do the partitioned tests
+                        for i, mask in enumerate(mask_list):
+                            if test_fn(a, mask):
+                                # OR y with the lowest set bit in the mask
+                                y |= maskbit_list[i]
+                        # check the result
+                        outval = (yield getattr(module, "%s_output" % mod_attr))
+                        msg = f"{msg_prefix}: {mod_attr} 0x{a:X} " + \
+                            f" => 0x{y:X} != 0x{outval:X}, masklist %s"
+                        print((msg % str(maskbit_list)).format(locals()))
+                        self.assertEqual(y, outval, msg % str(maskbit_list))
 
             for (test_fn, mod_attr) in ((test_xor_fn, "xor"),
+                                        (test_all_fn, "all"),
                                         (test_bool_fn, "bool"),
                                         #(test_ne_fn, "ne"),
                                         ):
index 7ee6cf8b5cd87c7fb2fd9a65106b7708b03ae055..e10948d34175a89eb0bb4918bb70dfbc446ec698 100644 (file)
@@ -68,6 +68,7 @@ class AllCombiner(Combiner):
     def __init__(self, width):
         Combiner.__init__(self, operator.and_, width)
 
+
 class XORCombiner(Combiner):
     def __init__(self, width):
         Combiner.__init__(self, operator.xor, width)