update comments
[ieee754fpu.git] / src / add / test_multishift.py
index 5fa649ef83f79dcbce7d3338da418d02dd9b9b2d..651e5018a35d332a5b2f1809493f9b1e747db139 100644 (file)
@@ -2,7 +2,7 @@ from random import randint
 from nmigen import Module, Signal
 from nmigen.compat.sim import run_simulation
 
-from fpbase import MultiShift, MultiShiftR
+from fpbase import MultiShift, MultiShiftR, MultiShiftRMerge
 
 class MultiShiftModL:
     def __init__(self, width):
@@ -11,7 +11,7 @@ class MultiShiftModL:
         self.b = Signal(self.ms.smax)
         self.x = Signal(width)
 
-    def get_fragment(self, platform=None):
+    def elaborate(self, platform=None):
 
         m = Module()
         m.d.comb += self.x.eq(self.ms.lshift(self.a, self.b))
@@ -25,7 +25,7 @@ class MultiShiftModR:
         self.b = Signal(self.ms.smax)
         self.x = Signal(width)
 
-    def get_fragment(self, platform=None):
+    def elaborate(self, platform=None):
 
         m = Module()
         m.d.comb += self.x.eq(self.ms.rshift(self.a, self.b))
@@ -39,7 +39,7 @@ class MultiShiftModRMod:
         self.b = Signal(self.ms.smax)
         self.x = Signal(width)
 
-    def get_fragment(self, platform=None):
+    def elaborate(self, platform=None):
 
         m = Module()
         m.submodules += self.ms
@@ -49,6 +49,24 @@ class MultiShiftModRMod:
 
         return m
 
+class MultiShiftRMergeMod:
+    def __init__(self, width):
+        self.ms = MultiShiftRMerge(width)
+        self.a = Signal(width)
+        self.b = Signal(self.ms.smax)
+        self.x = Signal(width)
+
+    def elaborate(self, platform=None):
+
+        m = Module()
+        m.submodules += self.ms
+        m.d.comb += self.ms.inp.eq(self.a)
+        m.d.comb += self.ms.diff.eq(self.b)
+        m.d.comb += self.x.eq(self.ms.m)
+
+        return m
+
+
 def check_case(dut, width, a, b):
     yield dut.a.eq(a)
     yield dut.b.eq(b)
@@ -69,6 +87,27 @@ def check_caser(dut, width, a, b):
     out_x = yield dut.x
     assert out_x == x, "Output x 0x%x not equal to expected 0x%x" % (out_x, x)
 
+
+def check_case_merge(dut, width, a, b):
+    yield dut.a.eq(a)
+    yield dut.b.eq(b)
+    yield
+
+    x = (a >> b) & ((1<<width)-1) # actual shift
+    if (a & ((2<<b)-1)) != 0: # mask for sticky bit
+        x |= 1 # set LSB
+
+    out_x = yield dut.x
+    assert out_x == x, \
+                "\nshift %d\nInput\n%+32s\nOutput x\n%+32s != \n%+32s" % \
+                        (b, bin(a), bin(out_x), bin(x))
+
+def testmerge(dut):
+    for i in range(32):
+        for j in range(1000):
+            a = randint(0, (1<<32)-1)
+            yield from check_case_merge(dut, 32, a, i)
+
 def testbench(dut):
     for i in range(32):
         for j in range(1000):
@@ -82,6 +121,8 @@ def testbenchr(dut):
             yield from check_caser(dut, 32, a, i)
 
 if __name__ == '__main__':
+    dut = MultiShiftRMergeMod(width=32)
+    run_simulation(dut, testmerge(dut), vcd_name="test_multishiftmerge.vcd")
     dut = MultiShiftModRMod(width=32)
     run_simulation(dut, testbenchr(dut), vcd_name="test_multishift.vcd")