From a36539e789f37f6f385ef719684fbf04c0036d91 Mon Sep 17 00:00:00 2001 From: Luke Kenneth Casson Leighton Date: Sun, 27 Feb 2022 19:25:54 +0000 Subject: [PATCH] convert shift_rot pipeline to XLEN=32/64 --- src/soc/fu/shift_rot/main_stage.py | 7 +++-- src/soc/fu/shift_rot/pipe_data.py | 37 +++++++++++++--------- src/soc/fu/shift_rot/rotator.py | 50 ++++++++++++++++-------------- 3 files changed, 54 insertions(+), 40 deletions(-) diff --git a/src/soc/fu/shift_rot/main_stage.py b/src/soc/fu/shift_rot/main_stage.py index e249d54d..df8b17c2 100644 --- a/src/soc/fu/shift_rot/main_stage.py +++ b/src/soc/fu/shift_rot/main_stage.py @@ -34,6 +34,7 @@ class ShiftRotMainStage(PipeModBase): return ShiftRotOutputData(self.pspec) def elaborate(self, platform): + XLEN = self.pspec.XLEN m = Module() comb = m.d.comb op = self.i.ctx.op @@ -42,13 +43,13 @@ class ShiftRotMainStage(PipeModBase): bitwise_lut = None grev = None if self.draft_bitmanip: - bitwise_lut = BitwiseLut(input_count=3, width=64) + bitwise_lut = BitwiseLut(input_count=3, width=XLEN) m.submodules.bitwise_lut = bitwise_lut comb += bitwise_lut.inputs[0].eq(self.i.rb) comb += bitwise_lut.inputs[1].eq(self.i.ra) comb += bitwise_lut.inputs[2].eq(self.i.rc) # 6 == log2(64) because we have 64-bit values - grev = GRev(log2_width=6) + grev = GRev(log2_width=XLEN.bit_length()) m.submodules.grev = grev with m.If(op.is_32bit): # 32-bit, so input is lower 32-bits zero-extended @@ -74,7 +75,7 @@ class ShiftRotMainStage(PipeModBase): comb += mb_extra.eq(md_fields['mb'][0:-1][0]) # set up microwatt rotator module - m.submodules.rotator = rotator = Rotator() + m.submodules.rotator = rotator = Rotator(XLEN) comb += [ rotator.me.eq(me), rotator.mb.eq(mb), diff --git a/src/soc/fu/shift_rot/pipe_data.py b/src/soc/fu/shift_rot/pipe_data.py index 276320db..d783d017 100644 --- a/src/soc/fu/shift_rot/pipe_data.py +++ b/src/soc/fu/shift_rot/pipe_data.py @@ -4,42 +4,51 @@ from soc.fu.alu.pipe_data import ALUOutputData class ShiftRotInputData(FUBaseData): - regspec = [('INT', 'ra', '0:63'), # RA - ('INT', 'rb', '0:63'), # RB - ('INT', 'rc', '0:63'), # RS - ('XER', 'xer_so', '32'), # XER bit 32: SO - ('XER', 'xer_ca', '34,45')] # XER bit 34/45: CA/CA32 def __init__(self, pspec): super().__init__(pspec, False) # convenience self.a, self.b, self.rs = self.ra, self.rb, self.rc + @property + def regspec(self): + return [('INT', 'ra', self.intrange), # RA + ('INT', 'rb', self.intrange), # RB/immediate + ('INT', 'rc', self.intrange), # RB/immediate + ('XER', 'xer_so', '32'), # XER bit 32: SO + ('XER', 'xer_ca', '34,45')] # XER bit 34/45: CA/CA32 + # input to shiftrot final stage (common output) class ShiftRotOutputData(FUBaseData): - regspec = [('INT', 'o', '0:63'), # RT - ('CR', 'cr_a', '0:3'), - ('XER', 'xer_so', '32'), # bit0: so - ('XER', 'xer_ca', '34,45'), # XER bit 34/45: CA/CA32 - ] def __init__(self, pspec): super().__init__(pspec, True) # convenience self.cr0 = self.cr_a + @property + def regspec(self): + return [('INT', 'o', self.intrange), + ('CR', 'cr_a', '0:3'), + ('XER', 'xer_so', '32'), # bit0: so + ('XER', 'xer_ca', '34,45'), # XER bit 34/45: CA/CA32 + ] + # output from shiftrot final stage (common output) - note that XER.so # is *not* included (the only reason it's in the input is because of CR0) class ShiftRotOutputDataFinal(FUBaseData): - regspec = [('INT', 'o', '0:63'), # RT - ('CR', 'cr_a', '0:3'), - ('XER', 'xer_ca', '34,45'), # XER bit 34/45: CA/CA32 - ] def __init__(self, pspec): super().__init__(pspec, True) # convenience self.cr0 = self.cr_a + @property + def regspec(self): + return [('INT', 'o', self.intrange), + ('CR', 'cr_a', '0:3'), + ('XER', 'xer_ca', '34,45'), # XER bit 34/45: CA/CA32 + ] + class ShiftRotPipeSpec(CommonPipeSpec): regspecklses = (ShiftRotInputData, ShiftRotOutputDataFinal) diff --git a/src/soc/fu/shift_rot/rotator.py b/src/soc/fu/shift_rot/rotator.py index 7c3d811c..eac042fe 100644 --- a/src/soc/fu/shift_rot/rotator.py +++ b/src/soc/fu/shift_rot/rotator.py @@ -11,18 +11,18 @@ from nmutil.mask import Mask # note BE bit numbering -def right_mask(m, mask_begin): - ret = Signal(64, name="right_mask", reset_less=True) - with m.If(mask_begin <= 64): - m.d.comb += ret.eq((1 << (64-mask_begin)) - 1) +def right_mask(m, mask_begin, width): + ret = Signal(width, name="right_mask", reset_less=True) + with m.If(mask_begin <= width): + m.d.comb += ret.eq((1 << (width-mask_begin)) - 1) with m.Else(): m.d.comb += ret.eq(0) return ret -def left_mask(m, mask_end): - ret = Signal(64, name="left_mask", reset_less=True) - m.d.comb += ret.eq(~((1 << (63-mask_end)) - 1)) +def left_mask(m, mask_end, width): + ret = Signal(width, name="left_mask", reset_less=True) + m.d.comb += ret.eq(~((1 << (width-1-mask_end)) - 1)) return ret @@ -45,14 +45,15 @@ class Rotator(Elaboratable): * clear_right = 1 when insn_type is OP_RLC or OP_RLCR """ - def __init__(self): + def __init__(self, width): + self.width = width # input self.me = Signal(5, reset_less=True) # ME field self.mb = Signal(5, reset_less=True) # MB field # extra bit of mb in MD-form self.mb_extra = Signal(1, reset_less=True) - self.ra = Signal(64, reset_less=True) # RA - self.rs = Signal(64, reset_less=True) # RS + self.ra = Signal(width, reset_less=True) # RA + self.rs = Signal(width, reset_less=True) # RS self.shift = Signal(7, reset_less=True) # RB[0:7] self.is_32bit = Signal(reset_less=True) self.right_shift = Signal(reset_less=True) @@ -61,10 +62,11 @@ class Rotator(Elaboratable): self.clear_right = Signal(reset_less=True) self.sign_ext_rs = Signal(reset_less=True) # output - self.result_o = Signal(64, reset_less=True) + self.result_o = Signal(width, reset_less=True) self.carry_out_o = Signal(reset_less=True) def elaborate(self, platform): + width = self.width m = Module() comb = m.d.comb ra, rs = self.ra, self.rs @@ -75,11 +77,11 @@ class Rotator(Elaboratable): sh = Signal(7, reset_less=True) mb = Signal(7, reset_less=True) me = Signal(7, reset_less=True) - mr = Signal(64, reset_less=True) - ml = Signal(64, reset_less=True) + mr = Signal(width, reset_less=True) + ml = Signal(width, reset_less=True) output_mode = Signal(2, reset_less=True) hi32 = Signal(32, reset_less=True) - repl32 = Signal(64, reset_less=True) + repl32 = Signal(width, reset_less=True) # First replicate bottom 32 bits to both halves if 32-bit with m.If(self.is_32bit): @@ -88,7 +90,8 @@ class Rotator(Elaboratable): # sign-extend bottom 32 bits comb += hi32.eq(Repl(rs[31], 32)) with m.Else(): - comb += hi32.eq(rs[32:64]) + if width == 64: + comb += hi32.eq(rs[32:64]) comb += repl32.eq(Cat(rs[0:32], hi32)) shift_signed = Signal(signed(6)) @@ -101,7 +104,7 @@ class Rotator(Elaboratable): comb += rot_count.eq(self.shift[0:6]) # ROTL submodule - m.submodules.rotl = rotl = ROTL(64) + m.submodules.rotl = rotl = ROTL(width) comb += rotl.a.eq(repl32) comb += rotl.b.eq(rot_count) comb += rot.eq(rotl.o) @@ -139,16 +142,16 @@ class Rotator(Elaboratable): comb += me.eq(Cat(~sh[0:6], sh[6])) # Calculate left and right masks - m.submodules.right_mask = right_mask = Mask(64) - with m.If(mb <= 64): - comb += right_mask.shift.eq(64-mb) + m.submodules.right_mask = right_mask = Mask(width) + with m.If(mb <= width): + comb += right_mask.shift.eq(width-mb) comb += mr.eq(right_mask.mask) with m.Else(): comb += mr.eq(0) #comb += mr.eq(right_mask(m, mb)) - m.submodules.left_mask = left_mask = Mask(64) - comb += left_mask.shift.eq(63-me) + m.submodules.left_mask = left_mask = Mask(width) + comb += left_mask.shift.eq(width-1-me) comb += ml.eq(~left_mask.mask) #comb += ml.eq(left_mask(m, me)) @@ -159,7 +162,8 @@ class Rotator(Elaboratable): # 10 for rldicl, sr[wd] # 1z for sra[wd][i], z = 1 if rs is negative with m.If((self.clear_left & ~self.clear_right) | self.right_shift): - comb += output_mode.eq(Cat(self.arith & repl32[63], Const(1, 1))) + comb += output_mode.eq(Cat(self.arith & + repl32[width-1], Const(1, 1))) with m.Else(): mbgt = self.clear_right & (mb[0:6] > me[0:6]) comb += output_mode.eq(Cat(mbgt, Const(0, 1))) @@ -186,7 +190,7 @@ if __name__ == '__main__': comb = m.d.comb mr = Signal(64) mb = Signal(6) - comb += mr.eq(left_mask(m, mb)) + comb += mr.eq(left_mask(m, mb, 64)) def loop(): for i in range(64): -- 2.30.2