test shift against scalar b input
authorLuke Kenneth Casson Leighton <lkcl@lkcl.net>
Wed, 19 Feb 2020 14:33:23 +0000 (14:33 +0000)
committerLuke Kenneth Casson Leighton <lkcl@lkcl.net>
Wed, 19 Feb 2020 14:33:23 +0000 (14:33 +0000)
src/ieee754/part/partsig.py
src/ieee754/part/test/test_partsig.py

index b08e9105b60ad4be95358cb0ac8e118f1d709fb4..bf0e85bc06bd2c4b2b1a49b7c748243c26d96d9f 100644 (file)
@@ -95,20 +95,22 @@ class PartitionedSignal:
     def ls_op(self, op1, op2, carry):
         op1 = getsig(op1)
         if isinstance(op2, Const) or isinstance(op2, Signal):
+            scalar = True
             shape = op1.shape()
             pa = PartitionedScalarShift(shape[0], self.partpoints)
         else:
+            scalar = False
             op2 = getsig(op2)
             shape = op1.shape()
             pa = PartitionedDynamicShift(shape[0], self.partpoints)
         setattr(self.m.submodules, self.get_modname('ls'), pa)
         comb = self.m.d.comb
-        if isinstance(op2, Const) or isinstance(op2, Signal):
-            comb += pa.a.eq(op1)
-            comb += pa.b.eq(op2)
-        else:
+        if scalar:
             comb += pa.data.eq(op1)
             comb += pa.shifter.eq(op2)
+        else:
+            comb += pa.a.eq(op1)
+            comb += pa.b.eq(op2)
         # XXX TODO: carry-in, carry-out
         #comb += pa.carry_in.eq(carry)
         return (pa.output, 0)
index 435292141b13af5da658a4c00dd335544b78e890..2fbabb13768a12e86fd84eb3de5a95846d3ea0aa 100644 (file)
@@ -52,8 +52,10 @@ class TestAddMod(Elaboratable):
         self.partpoints = partpoints
         self.a = PartitionedSignal(partpoints, width)
         self.b = PartitionedSignal(partpoints, width)
+        self.bsig = Signal(width)
         self.add_output = Signal(width)
         self.ls_output = Signal(width) # left shift
+        self.ls_scal_output = Signal(width) # left shift
         self.sub_output = Signal(width)
         self.eq_output = Signal(len(partpoints)+1)
         self.gt_output = Signal(len(partpoints)+1)
@@ -96,6 +98,9 @@ class TestAddMod(Elaboratable):
         comb += self.ls_output.eq(self.a << self.b)
         ppts = self.partpoints
         comb += self.mux_out.eq(PMux(m, ppts, self.mux_sel, self.a, self.b))
+        # scalar left shift
+        comb += self.bsig.eq(self.b.sig)
+        comb += self.ls_scal_output.eq(self.a << self.bsig)
 
         return m
 
@@ -116,6 +121,23 @@ class TestPartitionPoints(unittest.TestCase):
 
         def async_process():
 
+            def test_ls_scal_fn(carry_in, a, b, mask):
+                # reduce range of b
+                bits = count_bits(mask)
+                newb = b & ((bits-1))
+                print ("%x %x %x bits %d trunc %x" % \
+                        (a, b, mask, bits, newb))
+                b = newb
+                # TODO: carry
+                carry_in = 0
+                lsb = mask & ~(mask-1) if carry_in else 0
+                sum = ((a & mask) << b)
+                result = mask & sum
+                carry = (sum & mask) != sum
+                carry = 0
+                print("res", hex(a), hex(b), hex(sum), hex(mask), hex(result))
+                return result, carry
+
             def test_ls_fn(carry_in, a, b, mask):
                 # reduce range of b
                 bits = count_bits(mask)
@@ -133,7 +155,7 @@ class TestPartitionPoints(unittest.TestCase):
                 result = mask & sum
                 carry = (sum & mask) != sum
                 carry = 0
-                print("result", hex(a), hex(b), hex(sum), hex(mask), hex(result))
+                print("res", hex(a), hex(b), hex(sum), hex(mask), hex(result))
                 return result, carry
 
             def test_add_fn(carry_in, a, b, mask):
@@ -194,6 +216,7 @@ class TestPartitionPoints(unittest.TestCase):
                         self.assertEqual(carry_result, c_outval, msg)
 
             for (test_fn, mod_attr) in (
+                                        (test_ls_scal_fn, "ls_scal"),
                                         (test_ls_fn, "ls"),
                                         (test_add_fn, "add"),
                                         (test_sub_fn, "sub"),