Shift-out skipped mask bits for both crpred and intpred
author    Cesar Strauss <cestrauss@gmail.com>
          Sun, 25 Apr 2021 19:11:19 +0000 (16:11 -0300)
committer Cesar Strauss <cestrauss@gmail.com>
          Sun, 25 Apr 2021 20:20:08 +0000 (17:20 -0300)
If the src/dest step is not zero, we need to shift out the skipped mask
bits. We already did this for intpred, and for crpred it is exactly the
same.
Move the shifting logic into a new final state, shared by both the
intpred and crpred paths.
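
As a rough illustration (plain Python, not the issuer.py code, and the
helper name is made up): the stored mask is pre-shifted so that bit 0
always describes the *current* element, letting the issue loop test
bit 0 directly instead of re-applying the step offset each time.

    # hedged sketch: plain-Python model of the shift-out behaviour;
    # the 64-bit width matches the masks in issuer.py
    def shift_out_skipped(new_mask: int, step: int) -> int:
        """Drop the mask bits of elements already executed."""
        return (new_mask >> step) & ((1 << 64) - 1)

    # e.g. a fetched source predicate 0b1011 resumed at srcstep=2
    # leaves 0b10, so bit 0 now refers to element 2
    assert shift_out_skipped(0b1011, 2) == 0b10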

src/soc/simple/issuer.py

index 3969a4b42fcc2206b9f76cd97c8982498ec37463..9d50a075179e2ec77a2a77bd42862dee0b60f02b 100644 (file)
@@ -422,6 +422,12 @@ class TestIssuerInternal(Elaboratable):
         sidx, scrinvert = get_predcr(m, srcpred, 's')
         didx, dcrinvert = get_predcr(m, dstpred, 'd')
 
+        # store fetched masks, for either intpred or crpred
+        # when src/dst step is not zero, the skipped mask bits need to be
+        # shifted-out, before actually storing them in src/dest mask
+        new_srcmask = Signal(64, reset_less=True)
+        new_dstmask = Signal(64, reset_less=True)
+
         with m.FSM(name="fetch_predicate"):
 
             with m.State("FETCH_PRED_IDLE"):
@@ -430,7 +436,7 @@ class TestIssuerInternal(Elaboratable):
                     with m.If(predmode == SVP64PredMode.INT):
                         # skip fetching destination mask register, when zero
                         with m.If(dall1s):
-                            sync += self.dstmask.eq(-1)
+                            sync += new_dstmask.eq(-1)
                             # directly go to fetch source mask register
                             # guaranteed not to be zero (otherwise predmode
                             # would be SVP64PredMode.ALWAYS, not INT)
@@ -444,8 +450,8 @@ class TestIssuerInternal(Elaboratable):
                             m.next = "INT_DST_READ"
                     with m.Elif(predmode == SVP64PredMode.CR):
                         # go fetch masks from the CR register file
-                        sync += self.srcmask.eq(0)
-                        sync += self.dstmask.eq(0)
+                        sync += new_srcmask.eq(0)
+                        sync += new_dstmask.eq(0)
                         m.next = "CR_READ"
                     with m.Else():
                         sync += self.srcmask.eq(-1)
@@ -455,21 +461,18 @@ class TestIssuerInternal(Elaboratable):
             with m.State("INT_DST_READ"):
                 # store destination mask
                 inv = Repl(dinvert, 64)
-                new_dstmask = Signal(64)
                 with m.If(dunary):
                     # set selected mask bit for 1<<r3 mode
                     dst_shift = Signal(range(64))
                     comb += dst_shift.eq(self.int_pred.data_o & 0b111111)
-                    comb += new_dstmask.eq(1 << dst_shift)
+                    sync += new_dstmask.eq(1 << dst_shift)
                 with m.Else():
                     # invert mask if requested
-                    comb += new_dstmask.eq(self.int_pred.data_o ^ inv)
-                # shift-out already used mask bits
-                sync += self.dstmask.eq(new_dstmask >> dststep)
+                    sync += new_dstmask.eq(self.int_pred.data_o ^ inv)
                 # skip fetching source mask register, when zero
                 with m.If(sall1s):
-                    sync += self.srcmask.eq(-1)
-                    m.next = "FETCH_PRED_DONE"
+                    sync += new_srcmask.eq(-1)
+                    m.next = "FETCH_PRED_SHIFT_MASK"
                 # fetch source predicate register
                 with m.Else():
                     comb += int_pred.addr.eq(sregread)
@@ -479,18 +482,15 @@ class TestIssuerInternal(Elaboratable):
             with m.State("INT_SRC_READ"):
                 # store source mask
                 inv = Repl(sinvert, 64)
-                new_srcmask = Signal(64)
                 with m.If(sunary):
                     # set selected mask bit for 1<<r3 mode
                     src_shift = Signal(range(64))
                     comb += src_shift.eq(self.int_pred.data_o & 0b111111)
-                    comb += new_srcmask.eq(1 << src_shift)
+                    sync += new_srcmask.eq(1 << src_shift)
                 with m.Else():
                     # invert mask if requested
-                    comb += new_srcmask.eq(self.int_pred.data_o ^ inv)
-                # shift-out already used mask bits
-                sync += self.srcmask.eq(new_srcmask >> srcstep)
-                m.next = "FETCH_PRED_DONE"
+                    sync += new_srcmask.eq(self.int_pred.data_o ^ inv)
+                m.next = "FETCH_PRED_SHIFT_MASK"
 
             # fetch masks from the CR register file
             # implements the following loop:
@@ -523,7 +523,7 @@ class TestIssuerInternal(Elaboratable):
                     # exit on loop end
                     sync += cr_read.eq(0)
                     sync += cr_idx.eq(0)
-                    m.next = "FETCH_PRED_DONE"
+                    m.next = "FETCH_PRED_SHIFT_MASK"
                 with m.If(cr_read):
                     # compensate for the one cycle delay on the regfile
                     cur_cr_idx = Signal.like(cur_vl)
@@ -539,9 +539,15 @@ class TestIssuerInternal(Elaboratable):
                     bit_to_set = Signal.like(self.srcmask)
                     comb += bit_to_set.eq(1 << cur_cr_idx)
                     with m.If(scr_bit):
-                        sync += self.srcmask.eq(self.srcmask | bit_to_set)
+                        sync += new_srcmask.eq(new_srcmask | bit_to_set)
                     with m.If(dcr_bit):
-                        sync += self.dstmask.eq(self.dstmask | bit_to_set)
+                        sync += new_dstmask.eq(new_dstmask | bit_to_set)
+
+            with m.State("FETCH_PRED_SHIFT_MASK"):
+                # shift-out skipped mask bits
+                sync += self.srcmask.eq(new_srcmask >> srcstep)
+                sync += self.dstmask.eq(new_dstmask >> dststep)
+                m.next = "FETCH_PRED_DONE"
 
             with m.State("FETCH_PRED_DONE"):
                 comb += pred_mask_valid_o.eq(1)
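
For reference, a minimal nmigen sketch of what the shared final stage does,
with illustrative signal names and widths (this is not the real
TestIssuerInternal interface): both the INT and CR fetch paths accumulate
into new_srcmask/new_dstmask and then converge on a single registered
shift-out before the masks are handed to the issue loop.

    from nmigen import Elaboratable, Module, Signal

    class PredMaskShiftOut(Elaboratable):
        # hedged sketch, standing in for the FETCH_PRED_SHIFT_MASK state
        def __init__(self):
            self.new_srcmask = Signal(64)  # mask as fetched (INT or CR path)
            self.new_dstmask = Signal(64)
            self.srcstep = Signal(7)       # elements already completed
            self.dststep = Signal(7)
            self.srcmask = Signal(64)      # masks seen by the issue loop
            self.dstmask = Signal(64)

        def elaborate(self, platform):
            m = Module()
            # one registered stage: drop bits of already-executed elements
            m.d.sync += self.srcmask.eq(self.new_srcmask >> self.srcstep)
            m.d.sync += self.dstmask.eq(self.new_dstmask >> self.dststep)
            return m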