convert shift_rot pipeline to XLEN=32/64
authorLuke Kenneth Casson Leighton <lkcl@lkcl.net>
Sun, 27 Feb 2022 19:25:54 +0000 (19:25 +0000)
committerLuke Kenneth Casson Leighton <lkcl@lkcl.net>
Sun, 27 Feb 2022 19:25:54 +0000 (19:25 +0000)
src/soc/fu/shift_rot/main_stage.py
src/soc/fu/shift_rot/pipe_data.py
src/soc/fu/shift_rot/rotator.py

index e249d54d3f126806707becd964c63c565ba6a862..df8b17c23eacd2f62d9696ac37b1c83ca9104edf 100644 (file)
@@ -34,6 +34,7 @@ class ShiftRotMainStage(PipeModBase):
         return ShiftRotOutputData(self.pspec)
 
     def elaborate(self, platform):
+        XLEN = self.pspec.XLEN
         m = Module()
         comb = m.d.comb
         op = self.i.ctx.op
@@ -42,13 +43,13 @@ class ShiftRotMainStage(PipeModBase):
         bitwise_lut = None
         grev = None
         if self.draft_bitmanip:
-            bitwise_lut = BitwiseLut(input_count=3, width=64)
+            bitwise_lut = BitwiseLut(input_count=3, width=XLEN)
             m.submodules.bitwise_lut = bitwise_lut
             comb += bitwise_lut.inputs[0].eq(self.i.rb)
             comb += bitwise_lut.inputs[1].eq(self.i.ra)
             comb += bitwise_lut.inputs[2].eq(self.i.rc)
             # 6 == log2(64) because we have 64-bit values
-            grev = GRev(log2_width=6)
+            grev = GRev(log2_width=XLEN.bit_length())
             m.submodules.grev = grev
             with m.If(op.is_32bit):
                 # 32-bit, so input is lower 32-bits zero-extended
@@ -74,7 +75,7 @@ class ShiftRotMainStage(PipeModBase):
         comb += mb_extra.eq(md_fields['mb'][0:-1][0])
 
         # set up microwatt rotator module
-        m.submodules.rotator = rotator = Rotator()
+        m.submodules.rotator = rotator = Rotator(XLEN)
         comb += [
             rotator.me.eq(me),
             rotator.mb.eq(mb),
index 276320dbcaead42bdf79a08098a9207202139a79..d783d017ed3851ea2dbb55b2f56b2e20d2ef66eb 100644 (file)
@@ -4,42 +4,51 @@ from soc.fu.alu.pipe_data import ALUOutputData
 
 
 class ShiftRotInputData(FUBaseData):
-    regspec = [('INT', 'ra', '0:63'),      # RA
-               ('INT', 'rb', '0:63'),      # RB
-               ('INT', 'rc', '0:63'),      # RS
-               ('XER', 'xer_so', '32'), # XER bit 32: SO
-               ('XER', 'xer_ca', '34,45')] # XER bit 34/45: CA/CA32
     def __init__(self, pspec):
         super().__init__(pspec, False)
         # convenience
         self.a, self.b, self.rs = self.ra, self.rb, self.rc
 
+    @property
+    def regspec(self):
+        return [('INT', 'ra', self.intrange),  # RA
+               ('INT', 'rb', self.intrange),  # RB/immediate
+               ('INT', 'rc', self.intrange),  # RB/immediate
+               ('XER', 'xer_so', '32'), # XER bit 32: SO
+               ('XER', 'xer_ca', '34,45')] # XER bit 34/45: CA/CA32
+
 
 # input to shiftrot final stage (common output)
 class ShiftRotOutputData(FUBaseData):
-    regspec = [('INT', 'o', '0:63'),        # RT
-               ('CR', 'cr_a', '0:3'),
-               ('XER', 'xer_so', '32'),    # bit0: so
-               ('XER', 'xer_ca', '34,45'), # XER bit 34/45: CA/CA32
-               ]
     def __init__(self, pspec):
         super().__init__(pspec, True)
         # convenience
         self.cr0 = self.cr_a
 
+    @property
+    def regspec(self):
+        return [('INT', 'o', self.intrange),
+               ('CR', 'cr_a', '0:3'),
+               ('XER', 'xer_so', '32'),    # bit0: so
+               ('XER', 'xer_ca', '34,45'), # XER bit 34/45: CA/CA32
+               ]
+
 
 # output from shiftrot final stage (common output) - note that XER.so
 # is *not* included (the only reason it's in the input is because of CR0)
 class ShiftRotOutputDataFinal(FUBaseData):
-    regspec = [('INT', 'o', '0:63'),        # RT
-               ('CR', 'cr_a', '0:3'),
-               ('XER', 'xer_ca', '34,45'), # XER bit 34/45: CA/CA32
-               ]
     def __init__(self, pspec):
         super().__init__(pspec, True)
         # convenience
         self.cr0 = self.cr_a
 
+    @property
+    def regspec(self):
+        return [('INT', 'o', self.intrange),
+               ('CR', 'cr_a', '0:3'),
+               ('XER', 'xer_ca', '34,45'), # XER bit 34/45: CA/CA32
+               ]
+
 
 class ShiftRotPipeSpec(CommonPipeSpec):
     regspecklses = (ShiftRotInputData, ShiftRotOutputDataFinal)
index 7c3d811c8fa0402a70d5a3e1551e8ecb83873280..eac042fedcece092fec572ed75dd9759f852728e 100644 (file)
@@ -11,18 +11,18 @@ from nmutil.mask import Mask
 
 
 # note BE bit numbering
-def right_mask(m, mask_begin):
-    ret = Signal(64, name="right_mask", reset_less=True)
-    with m.If(mask_begin <= 64):
-        m.d.comb += ret.eq((1 << (64-mask_begin)) - 1)
+def right_mask(m, mask_begin, width):
+    ret = Signal(width, name="right_mask", reset_less=True)
+    with m.If(mask_begin <= width):
+        m.d.comb += ret.eq((1 << (width-mask_begin)) - 1)
     with m.Else():
         m.d.comb += ret.eq(0)
     return ret
 
 
-def left_mask(m, mask_end):
-    ret = Signal(64, name="left_mask", reset_less=True)
-    m.d.comb += ret.eq(~((1 << (63-mask_end)) - 1))
+def left_mask(m, mask_end, width):
+    ret = Signal(width, name="left_mask", reset_less=True)
+    m.d.comb += ret.eq(~((1 << (width-1-mask_end)) - 1))
     return ret
 
 
@@ -45,14 +45,15 @@ class Rotator(Elaboratable):
         * clear_right = 1 when insn_type is OP_RLC or OP_RLCR
     """
 
-    def __init__(self):
+    def __init__(self, width):
+        self.width = width
         # input
         self.me = Signal(5, reset_less=True)        # ME field
         self.mb = Signal(5, reset_less=True)        # MB field
         # extra bit of mb in MD-form
         self.mb_extra = Signal(1, reset_less=True)
-        self.ra = Signal(64, reset_less=True)       # RA
-        self.rs = Signal(64, reset_less=True)       # RS
+        self.ra = Signal(width, reset_less=True)       # RA
+        self.rs = Signal(width, reset_less=True)       # RS
         self.shift = Signal(7, reset_less=True)     # RB[0:7]
         self.is_32bit = Signal(reset_less=True)
         self.right_shift = Signal(reset_less=True)
@@ -61,10 +62,11 @@ class Rotator(Elaboratable):
         self.clear_right = Signal(reset_less=True)
         self.sign_ext_rs = Signal(reset_less=True)
         # output
-        self.result_o = Signal(64, reset_less=True)
+        self.result_o = Signal(width, reset_less=True)
         self.carry_out_o = Signal(reset_less=True)
 
     def elaborate(self, platform):
+        width = self.width
         m = Module()
         comb = m.d.comb
         ra, rs = self.ra, self.rs
@@ -75,11 +77,11 @@ class Rotator(Elaboratable):
         sh = Signal(7, reset_less=True)
         mb = Signal(7, reset_less=True)
         me = Signal(7, reset_less=True)
-        mr = Signal(64, reset_less=True)
-        ml = Signal(64, reset_less=True)
+        mr = Signal(width, reset_less=True)
+        ml = Signal(width, reset_less=True)
         output_mode = Signal(2, reset_less=True)
         hi32 = Signal(32, reset_less=True)
-        repl32 = Signal(64, reset_less=True)
+        repl32 = Signal(width, reset_less=True)
 
         # First replicate bottom 32 bits to both halves if 32-bit
         with m.If(self.is_32bit):
@@ -88,7 +90,8 @@ class Rotator(Elaboratable):
             # sign-extend bottom 32 bits
             comb += hi32.eq(Repl(rs[31], 32))
         with m.Else():
-            comb += hi32.eq(rs[32:64])
+            if width == 64:
+                comb += hi32.eq(rs[32:64])
         comb += repl32.eq(Cat(rs[0:32], hi32))
 
         shift_signed = Signal(signed(6))
@@ -101,7 +104,7 @@ class Rotator(Elaboratable):
             comb += rot_count.eq(self.shift[0:6])
 
         # ROTL submodule
-        m.submodules.rotl = rotl = ROTL(64)
+        m.submodules.rotl = rotl = ROTL(width)
         comb += rotl.a.eq(repl32)
         comb += rotl.b.eq(rot_count)
         comb += rot.eq(rotl.o)
@@ -139,16 +142,16 @@ class Rotator(Elaboratable):
             comb += me.eq(Cat(~sh[0:6], sh[6]))
 
         # Calculate left and right masks
-        m.submodules.right_mask = right_mask = Mask(64)
-        with m.If(mb <= 64):
-            comb += right_mask.shift.eq(64-mb)
+        m.submodules.right_mask = right_mask = Mask(width)
+        with m.If(mb <= width):
+            comb += right_mask.shift.eq(width-mb)
             comb += mr.eq(right_mask.mask)
         with m.Else():
             comb += mr.eq(0)
         #comb += mr.eq(right_mask(m, mb))
 
-        m.submodules.left_mask = left_mask = Mask(64)
-        comb += left_mask.shift.eq(63-me)
+        m.submodules.left_mask = left_mask = Mask(width)
+        comb += left_mask.shift.eq(width-1-me)
         comb += ml.eq(~left_mask.mask)
         #comb += ml.eq(left_mask(m, me))
 
@@ -159,7 +162,8 @@ class Rotator(Elaboratable):
         # 10 for rldicl, sr[wd]
         # 1z for sra[wd][i], z = 1 if rs is negative
         with m.If((self.clear_left & ~self.clear_right) | self.right_shift):
-            comb += output_mode.eq(Cat(self.arith & repl32[63], Const(1, 1)))
+            comb += output_mode.eq(Cat(self.arith &
+                                       repl32[width-1], Const(1, 1)))
         with m.Else():
             mbgt = self.clear_right & (mb[0:6] > me[0:6])
             comb += output_mode.eq(Cat(mbgt, Const(0, 1)))
@@ -186,7 +190,7 @@ if __name__ == '__main__':
     comb = m.d.comb
     mr = Signal(64)
     mb = Signal(6)
-    comb += mr.eq(left_mask(m, mb))
+    comb += mr.eq(left_mask(m, mb, 64))
 
     def loop():
         for i in range(64):