Extend CLZ to work over even, non powers of 2
authorMichael Nolan <mtnolan2640@gmail.com>
Mon, 4 May 2020 19:56:24 +0000 (15:56 -0400)
committerMichael Nolan <mtnolan2640@gmail.com>
Mon, 4 May 2020 19:56:24 +0000 (15:56 -0400)
src/ieee754/cordic/clz.py
src/ieee754/cordic/formal/proof_clz.py
src/ieee754/cordic/test/test_clz.py

index a6dec14cbada3aeeddab7bdda9ab11a189b681f8..51239dc2f43e830bab8a251eac3cbb86d6b3a07e 100644 (file)
@@ -24,31 +24,45 @@ class CLZ(Elaboratable):
                     comb += pair_cnt.eq(1)
                 with m.Default():
                     comb += pair_cnt.eq(0)
-            pairs.append(pair_cnt)
+            pairs.append((pair_cnt, 2))  # append pair, max_value
         return pairs
 
     def combine_pairs(self, m, iteration, pairs):
         comb = m.d.comb
         length = len(pairs)
-        assert length % 2 == 0  # TODO handle non powers of 2
         ret = []
         for i in range(0, length, 2):
-            left = pairs[i+1]
-            right = pairs[i]
-            width = left.width + 1
-            print(left)
-            print(f"pair({i}, {i+1}) - cnt_{iteration}_{i}")
-            new_pair = Signal(left.width + 1, name="cnt_%d_%d" %
-                              (iteration, i))
-            with m.If(left[-1] == 1):
-                with m.If(right[-1] == 1):
-                    comb += new_pair.eq(Cat(Repl(0, width-1), 1))
-                with m.Else():
-                    comb += new_pair.eq(Cat(right[0:-1], 0b01))
-            with m.Else():
-                comb += new_pair.eq(Cat(left, 0))
+            if i+1 >= length:
+                right, mv = pairs[i]
+                width = right.width
+                print(f"single({i}) - cnt_{iteration}_{i}")
+                new_pair = Signal(width, name="cnt_%d_%d" % (iteration, i))
+                comb += new_pair.eq(Cat(right, 0))
+                ret.append((new_pair, mv))
+            else:
+                left, lv = pairs[i+1]
+                right, rv = pairs[i]
+                width = right.width + 1
+                print(left)
+                print(f"pair({left}, {right}) - cnt_{iteration}_{i}")
+                new_pair = Signal(width, name="cnt_%d_%d" %
+                                  (iteration, i))
+                if rv == lv:
+                    with m.If(left[-1] == 1):
+                        with m.If(right[-1] == 1):
+                            comb += new_pair.eq(Cat(Repl(0, width-1), 1))
+                        with m.Else():
+                            comb += new_pair.eq(Cat(right[0:-1], 0b01))
+                    with m.Else():
+                        comb += new_pair.eq(Cat(left, 0))
+                else:
+                    with m.If(left == lv):
+                        comb += new_pair.eq(right + left)
+                    with m.Else():
+                        comb += new_pair.eq(left)
+                        
 
-            ret.append(new_pair)
+                ret.append((new_pair, lv+rv))
         return ret
 
     def elaborate(self, platform):
@@ -59,9 +73,10 @@ class CLZ(Elaboratable):
         i = 2
         while len(pairs) > 1:
             pairs = self.combine_pairs(m, i, pairs)
+            print(pairs)
             i += 1
 
-        comb += self.lz.eq(pairs[0])
+        comb += self.lz.eq(pairs[0][0])
 
         return m
 
index 3eb4f5e42200b4d5b95de91f964d271cffb886c4..273fc61f05b8869be49a988639156bfe667cf472 100644 (file)
@@ -17,7 +17,7 @@ class Driver(Elaboratable):
     def elaborate(self, platform):
         m = Module()
         comb = m.d.comb
-        width = 32
+        width = 10
 
         m.submodules.dut = dut = CLZ(width)
         sig_in = Signal.like(dut.sig_in)
index 78d43e311f110b92296ca31ecbcb4b694000e69f..c3051007f6f17cbb4204c7f2dd5df678d78ceab1 100644 (file)
@@ -37,6 +37,10 @@ class CLZTestCase(FHDLTestCase):
         inputs = [0, 15, 10, 127]
         self.run_test(iter(inputs), width=8)
 
+    def test_non_power_2(self):
+        inputs = [0, 128, 512]
+        self.run_test(iter(inputs), width=10)
+
 
 if __name__ == "__main__":
     unittest.main()