add a MultiShift class for generating single-cycle bit-shifters
authorLuke Kenneth Casson Leighton <lkcl@lkcl.net>
Sun, 17 Feb 2019 10:03:51 +0000 (10:03 +0000)
committerLuke Kenneth Casson Leighton <lkcl@lkcl.net>
Sun, 17 Feb 2019 10:03:51 +0000 (10:03 +0000)
src/add/fpbase.py
src/add/test_multishift.py [new file with mode: 0644]

index 6e1644f5fd3c86fddfd2a0778ff495265461c483..8c0b54b0800ac9fcbddc67d95361955fbe297c36 100644 (file)
@@ -2,7 +2,37 @@
 # Copyright (C) Jonathan P Dawson 2013
 # 2013-12-12
 
-from nmigen import Signal, Cat, Const
+from nmigen import Signal, Cat, Const, Mux
+from math import log
+
+class MultiShift:
+    """ Generates variable-length single-cycle shifter from a series
+        of conditional tests on each bit of the left/right shift operand.
+        Each bit tested produces output shifted by that number of bits,
+        in a binary fashion: bit 1 if set shifts by 1 bit, bit 2 if set
+        shifts by 2 bits, each partial result cascading to the next Mux.
+
+        Could be adapted to do arithmetic shift by taking copies of the
+        MSB instead of zeros.
+    """
+
+    def __init__(self, width):
+        self.width = width
+        self.smax = int(log(width) / log(2))
+
+    def lshift(self, op, s):
+        res = op
+        for i in range(self.smax):
+            zeros = [0] * (1<<i)
+            res = Mux(s & (1<<i), Cat(zeros, res[0:-(1<<i)]), res)
+        return res
+
+    def rshift(self, op, s):
+        res = op
+        for i in range(self.smax):
+            zeros = [0] * (1<<i)
+            res = Mux(s & (1<<i), Cat(res[(1<<i):], zeros), res)
+        return res
 
 
 class FPNum:
diff --git a/src/add/test_multishift.py b/src/add/test_multishift.py
new file mode 100644 (file)
index 0000000..0486d33
--- /dev/null
@@ -0,0 +1,73 @@
+from random import randint
+from nmigen import Module, Signal
+from nmigen.compat.sim import run_simulation
+
+from fpbase import MultiShift
+
+class MultiShiftModL:
+    def __init__(self, width):
+        self.ms = MultiShift(width)
+        self.a = Signal(width)
+        self.b = Signal(self.ms.smax)
+        self.x = Signal(width)
+
+    def get_fragment(self, platform=None):
+
+        m = Module()
+        m.d.comb += self.x.eq(self.ms.lshift(self.a, self.b))
+
+        return m
+
+class MultiShiftModR:
+    def __init__(self, width):
+        self.ms = MultiShift(width)
+        self.a = Signal(width)
+        self.b = Signal(self.ms.smax)
+        self.x = Signal(width)
+
+    def get_fragment(self, platform=None):
+
+        m = Module()
+        m.d.comb += self.x.eq(self.ms.rshift(self.a, self.b))
+
+        return m
+
+def check_case(dut, width, a, b):
+    yield dut.a.eq(a)
+    yield dut.b.eq(b)
+    yield
+
+    x = (a << b) & ((1<<width)-1)
+
+    out_x = yield dut.x
+    assert out_x == x, "Output x 0x%x not equal to expected 0x%x" % (out_x, x)
+
+def check_caser(dut, width, a, b):
+    yield dut.a.eq(a)
+    yield dut.b.eq(b)
+    yield
+
+    x = (a >> b) & ((1<<width)-1)
+
+    out_x = yield dut.x
+    assert out_x == x, "Output x 0x%x not equal to expected 0x%x" % (out_x, x)
+
+def testbench(dut):
+    for i in range(32):
+        for j in range(1000):
+            a = randint(0, (1<<32)-1)
+            yield from check_case(dut, 32, a, i)
+
+def testbenchr(dut):
+    for i in range(32):
+        for j in range(1000):
+            a = randint(0, (1<<32)-1)
+            yield from check_caser(dut, 32, a, i)
+
+if __name__ == '__main__':
+    dut = MultiShiftModR(width=32)
+    run_simulation(dut, testbenchr(dut), vcd_name="test_multishift.vcd")
+
+    dut = MultiShiftModL(width=32)
+    run_simulation(dut, testbench(dut), vcd_name="test_multishift.vcd")
+