convert shift_rot pipeline to XLEN=32/64
[soc.git] / src / soc / fu / shift_rot / rotator.py
index 7c3d811c8fa0402a70d5a3e1551e8ecb83873280..eac042fedcece092fec572ed75dd9759f852728e 100644 (file)
@@ -11,18 +11,18 @@ from nmutil.mask import Mask
 
 
 # note BE bit numbering
-def right_mask(m, mask_begin):
-    ret = Signal(64, name="right_mask", reset_less=True)
-    with m.If(mask_begin <= 64):
-        m.d.comb += ret.eq((1 << (64-mask_begin)) - 1)
+def right_mask(m, mask_begin, width):
+    ret = Signal(width, name="right_mask", reset_less=True)
+    with m.If(mask_begin <= width):
+        m.d.comb += ret.eq((1 << (width-mask_begin)) - 1)
     with m.Else():
         m.d.comb += ret.eq(0)
     return ret
 
 
-def left_mask(m, mask_end):
-    ret = Signal(64, name="left_mask", reset_less=True)
-    m.d.comb += ret.eq(~((1 << (63-mask_end)) - 1))
+def left_mask(m, mask_end, width):
+    ret = Signal(width, name="left_mask", reset_less=True)
+    m.d.comb += ret.eq(~((1 << (width-1-mask_end)) - 1))
     return ret
 
 
@@ -45,14 +45,15 @@ class Rotator(Elaboratable):
         * clear_right = 1 when insn_type is OP_RLC or OP_RLCR
     """
 
-    def __init__(self):
+    def __init__(self, width):
+        self.width = width
         # input
         self.me = Signal(5, reset_less=True)        # ME field
         self.mb = Signal(5, reset_less=True)        # MB field
         # extra bit of mb in MD-form
         self.mb_extra = Signal(1, reset_less=True)
-        self.ra = Signal(64, reset_less=True)       # RA
-        self.rs = Signal(64, reset_less=True)       # RS
+        self.ra = Signal(width, reset_less=True)       # RA
+        self.rs = Signal(width, reset_less=True)       # RS
         self.shift = Signal(7, reset_less=True)     # RB[0:7]
         self.is_32bit = Signal(reset_less=True)
         self.right_shift = Signal(reset_less=True)
@@ -61,10 +62,11 @@ class Rotator(Elaboratable):
         self.clear_right = Signal(reset_less=True)
         self.sign_ext_rs = Signal(reset_less=True)
         # output
-        self.result_o = Signal(64, reset_less=True)
+        self.result_o = Signal(width, reset_less=True)
         self.carry_out_o = Signal(reset_less=True)
 
     def elaborate(self, platform):
+        width = self.width
         m = Module()
         comb = m.d.comb
         ra, rs = self.ra, self.rs
@@ -75,11 +77,11 @@ class Rotator(Elaboratable):
         sh = Signal(7, reset_less=True)
         mb = Signal(7, reset_less=True)
         me = Signal(7, reset_less=True)
-        mr = Signal(64, reset_less=True)
-        ml = Signal(64, reset_less=True)
+        mr = Signal(width, reset_less=True)
+        ml = Signal(width, reset_less=True)
         output_mode = Signal(2, reset_less=True)
         hi32 = Signal(32, reset_less=True)
-        repl32 = Signal(64, reset_less=True)
+        repl32 = Signal(width, reset_less=True)
 
         # First replicate bottom 32 bits to both halves if 32-bit
         with m.If(self.is_32bit):
@@ -88,7 +90,8 @@ class Rotator(Elaboratable):
             # sign-extend bottom 32 bits
             comb += hi32.eq(Repl(rs[31], 32))
         with m.Else():
-            comb += hi32.eq(rs[32:64])
+            if width == 64:
+                comb += hi32.eq(rs[32:64])
         comb += repl32.eq(Cat(rs[0:32], hi32))
 
         shift_signed = Signal(signed(6))
@@ -101,7 +104,7 @@ class Rotator(Elaboratable):
             comb += rot_count.eq(self.shift[0:6])
 
         # ROTL submodule
-        m.submodules.rotl = rotl = ROTL(64)
+        m.submodules.rotl = rotl = ROTL(width)
         comb += rotl.a.eq(repl32)
         comb += rotl.b.eq(rot_count)
         comb += rot.eq(rotl.o)
@@ -139,16 +142,16 @@ class Rotator(Elaboratable):
             comb += me.eq(Cat(~sh[0:6], sh[6]))
 
         # Calculate left and right masks
-        m.submodules.right_mask = right_mask = Mask(64)
-        with m.If(mb <= 64):
-            comb += right_mask.shift.eq(64-mb)
+        m.submodules.right_mask = right_mask = Mask(width)
+        with m.If(mb <= width):
+            comb += right_mask.shift.eq(width-mb)
             comb += mr.eq(right_mask.mask)
         with m.Else():
             comb += mr.eq(0)
         #comb += mr.eq(right_mask(m, mb))
 
-        m.submodules.left_mask = left_mask = Mask(64)
-        comb += left_mask.shift.eq(63-me)
+        m.submodules.left_mask = left_mask = Mask(width)
+        comb += left_mask.shift.eq(width-1-me)
         comb += ml.eq(~left_mask.mask)
         #comb += ml.eq(left_mask(m, me))
 
@@ -159,7 +162,8 @@ class Rotator(Elaboratable):
         # 10 for rldicl, sr[wd]
         # 1z for sra[wd][i], z = 1 if rs is negative
         with m.If((self.clear_left & ~self.clear_right) | self.right_shift):
-            comb += output_mode.eq(Cat(self.arith & repl32[63], Const(1, 1)))
+            comb += output_mode.eq(Cat(self.arith &
+                                       repl32[width-1], Const(1, 1)))
         with m.Else():
             mbgt = self.clear_right & (mb[0:6] > me[0:6])
             comb += output_mode.eq(Cat(mbgt, Const(0, 1)))
@@ -186,7 +190,7 @@ if __name__ == '__main__':
     comb = m.d.comb
     mr = Signal(64)
     mb = Signal(6)
-    comb += mr.eq(left_mask(m, mb))
+    comb += mr.eq(left_mask(m, mb, 64))
 
     def loop():
         for i in range(64):