big reorganisation to support twin-predication
authorLuke Kenneth Casson Leighton <lkcl@lkcl.net>
Thu, 4 Oct 2018 09:16:10 +0000 (10:16 +0100)
committerLuke Kenneth Casson Leighton <lkcl@lkcl.net>
Thu, 4 Oct 2018 09:16:10 +0000 (10:16 +0100)
twin predication is on certain operations like LOAD, STORE, C.MV, FCVT
where the effect is similar to VSPLAT, VINSERT, and also bitmanip
scatter/gather.

id_regs.py
riscv/insn_template_sv.cc
riscv/sv.cc
riscv/sv_decode.h

index 734f02930f7f2d420de0e35c0fa827710510e340..49f68b0ac94ed91c2e47c0eecc53cf8aad83625f 100644 (file)
@@ -54,7 +54,15 @@ allints = intpatterns + cintpatterns[2:]
 
 skip = '#define USING_NOREGS\n' \
        '#define REGS_PATTERN 0x0\n'
-def find_registers(fname):
+
+# this matches the order of the 4 predication arguments to
+drlookup = { 'rd': 0, 'frd': 0, 'rs1': 1, 'rs2': 2, 'rs3': 3,
+             'rvc_rs1': 1, 'rvc_rs1s': 1,
+             'rvc_rs2': 2, 'rvc_rs2s': 2,
+             'rvc_frs2': 2, 'rvc_frs2s': 2,
+             }
+
+def find_registers(fname, twin_predication):
     # HACK! macro-skipping of instructions too painful
     for notparallel in ['csr', 'lui', 'c_j', 'wfi', 'auipc',
                         'dret', 'uret', 'mret', 'sret',
@@ -62,9 +70,11 @@ def find_registers(fname):
         if notparallel in fname:
             return skip
     res = []
+    regs = []
     isintfloat = 0x0 + floatmask << len(allints)
     with open(fname) as f:
         f = f.read()
+        dest_reg = None
         for pattern in patterns:
             x = f.find(pattern)
             if x == -1:
@@ -88,16 +98,50 @@ def find_registers(fname):
             p = pattern
             if p.startswith('WRITE_'):
                 p = p[6:]
+                dest_reg = p
             if pattern in allints:
                 idx = allints.index(pattern)
                 isintfloat += 1 << idx
             if pattern in allfloats:
                 idx = allfloats.index(pattern)
                 isintfloat &= ~(1 << (idx+len(allints)))
+            regs.append(p)
             res.append('#define USING_REG_%s' % p)
+        if dest_reg:
+            dr = dest_reg
+            fdest = False
+            if dest_reg.startswith('RVC_F'):
+                fdest = True
+                dr = 'RVC_' + dest_reg[5:]
+            if dest_reg == 'FRD':
+                fdest = True
+                dr = 'RD'
+            dridx = drlookup[dest_reg.lower()]
+            res.append('#define DEST_REG %s' % dr.lower())
+            res.append('#define _DEST_REG _%s' % dr.lower())
+            res.append('#define DEST_PREDINT %d' % (0 if fdest else 1))
     if not res:
         return skip
     res.append('#define REGS_PATTERN 0x%x' % isintfloat)
+
+    predargs = ['dest_pred'] * 4
+    if twin_predication:
+        found = None
+        for search in ['rs1', 'rs2', 'rs3', 'rvc_rs1', 'rvc_rs1s',
+                       'rvc_rs2', 'rvc_rs2s',
+                       'frs1', 'frs2', 'frs3',
+                       'rvc_frs2', 'rvc_frs2s']:
+            if search.upper() in regs:
+                found = search
+        if found:
+            predargs[drlookup[found]] = 'src_pred'
+            fsrc = 'f' in found
+            found = found.replace('f', '')
+            res.append('#define SRC_PREDINT %d' % (0 if fsrc else 1))
+            res.append('#define SRC_REG %s' % found)
+
+    res.append('#define PRED_ARGS %s' % ','.join(predargs))
+
     return '\n'.join(res)
 
 if __name__ == '__main__':
@@ -107,8 +151,7 @@ if __name__ == '__main__':
         regsname = os.path.join(insns_dir, regsname)
         twin_predication = False
         with open(regsname, "w") as f:
-            txt = find_registers(fname)
-            txt += "\n#define INSN_%s\n" % insn.upper()
+            txt = "\n#define INSN_%s\n" % insn.upper()
             # help identify type of register
             if insn in ['beq', 'bne', 'blt', 'bltu', 'bge', 'bgeu']:
                 txt += "#define INSN_TYPE_BRANCH\n"
@@ -141,4 +184,5 @@ if __name__ == '__main__':
                 txt += "#define INSN_TYPE_FP_BRANCH\n"
             if twin_predication:
                 txt += "\n#define INSN_CATEGORY_TWINPREDICATION\n"
+            txt += find_registers(fname, twin_predication)
             f.write(txt)
index 7804625808b48e8620c1b76e4267c1b966f27a75..40d2a9b398a992e3c61c0a4c139ecec557c60a82 100644 (file)
@@ -19,44 +19,25 @@ reg_t FN(processor_t* p, insn_t s_insn, reg_t pc)
   // REGS_PATTERN is generated by id_regs.py (per opcode)
   unsigned int floatintmap = REGS_PATTERN;
   reg_t dest_pred = ~0x0;
+  bool zeroing = false;
 #ifdef INSN_CATEGORY_TWINPREDICATION
   reg_t src_pred = ~0x0;
+  bool zeroingsrc = false;
 #endif
-  sv_insn_t insn(p, bits, floatintmap,
-                 dest_pred,
-#ifdef INSN_CATEGORY_TWINPREDICATION
-// twin-predication ONLY applies to dual-op operands: MV, FCVT, LD/ST.
-// however we don't know which register any of those will use, so
-// pass src_pred to each of rs1-3 and let the instruction sort it out.
-src_pred, src_pred, src_pred
-#else
-dest_pred, dest_pred, dest_pred
-#endif
-                );
-  bool zeroing;
+  sv_insn_t insn(p, bits, floatintmap, PRED_ARGS);
+  if (vlen > 0)
+  {
+    fprintf(stderr, "pre-ex reg %s %x rd %ld rs1 %ld rs2 %ld vlen %d\n",
+            xstr(INSN), INSNCODE, s_insn.rd(), s_insn.rs1(), s_insn.rs2(),
+            vlen);
 #ifdef INSN_CATEGORY_TWINPREDICATION
-#ifdef USING_REG_RS1
-  #define SRCREG s_insn.rs1()
-#endif
-#ifdef USING_REG_RS2
-  #define SRCREG s_insn.rs2()
-#endif
-#ifdef USING_REG_RS3
-  #define SRCREG s_insn.rs3()
-#endif
-#if (defined(USING_REG_RVC_RS1) || defined(USING_REG_RVC_RS1S))
-  #define SRCREG s_insn.rvc_rs1()
+    src_pred = insn.predicate(s_insn.SRC_REG(), SRC_PREDINT, zeroingsrc);
 #endif
-#if (defined(USING_REG_RVC_RS2) || defined(USING_REG_RVC_RS2S))
-  #define SRCREG s_insn.rvc_rs2()
-#endif
-  src_pred = insn.predicate(SRCREG, floatintmap & (REG_RS1|REG_RS2|REG_RS3),
-                            zeroing);
-#endif
-#if defined(USING_REG_RD) || defined(USING_REG_FRD)
+#ifdef DEST_PREDINT
   // use the ORIGINAL, i.e. NON-REDIRECTED, register here
-  dest_pred = insn.predicate(s_insn.rd(), floatintmap & REG_RD, zeroing);
+    dest_pred = insn.predicate(s_insn.DEST_REG(), DEST_PREDINT, zeroing);
 #endif
+  }
   // identify which regs have had their CSR entries set as vectorised.
   // really could do with a macro for-loop here... oh well...
   // integer ops, RD, RS1, RS2, RS3 (use sv_int_tb)
@@ -74,18 +55,55 @@ dest_pred, dest_pred, dest_pred
   }
   for (int voffs=0; voffs < vlen; voffs++)
   {
-      insn.reset_vloop_check();
+    insn.reset_vloop_check();
+#ifdef INSN_CATEGORY_TWINPREDICATION
+    int srcoffs = insn.rs_offs();
+    if (!zeroingsrc)
+    {
+      while ((src_pred & (1<<srcoffs)) == 0) {
+          srcoffs = insn.rs_offs_inc();
+          if (srcoffs == vlen) {
+              break;
+          }
+      }
+    }
+    int destoffs = insn.rd_offs();
+    if (!zeroing)
+    {
+      while ((dest_pred & (1<<destoffs)) == 0) {
+          destoffs = insn.rd_offs_inc();
+          if (destoffs == vlen) {
+              break;
+          }
+      }
+    }
+    if (srcoffs == vlen || destoffs == vlen) {
+        break; // end vector loop if either src or dest pred reaches end
+    }
+    if (vlen > 1)
+    {
+        fprintf(stderr, "twin reg %s src %d dest %d pred %lx %lx\n",
+            xstr(INSN), srcoffs, destoffs, src_pred, dest_pred);
+    }
+#endif
+#ifdef INSN_C_MV
+    fprintf(stderr, "pre loop reg %s %x vloop %d " \
+                      "vlen %d stop %d pred %lx rd%lx rvc2%d\n",
+                xstr(INSN), INSNCODE, voffs, vlen, insn.stop_vloop(),
+                dest_pred & (1<<voffs), READ_REG(insn._rd()), insn.rvc_rs2());
+#endif
       #include INCLUDEFILE
-#if defined(USING_REG_RD) || defined(USING_REG_FRD)
+#ifdef DEST_PREDINT
       // don't check inversion here as dest_pred has already been inverted
-      if (zeroing && ((dest_pred & (1<<voffs)) == 0))
+      if (zeroing && ((dest_pred & (1<<insn.rd_offs())) == 0))
       {
           // insn._rd() would be predicated: have to use insn._rd() here
-          WRITE_REG(insn._rd(), 0);
+          WRITE_REG(insn._DEST_REG(), 0);
       }
 #endif
       if (vlen > 1)
       {
+        insn.reset_caches(); // ready to increment offsets
 #if defined(USING_REG_RD)
         fprintf(stderr, "reg %s %x vloop %d vlen %d stop %d pred %lx rd%lx\n",
                 xstr(INSN), INSNCODE, voffs, vlen, insn.stop_vloop(),
@@ -98,7 +116,6 @@ dest_pred, dest_pred, dest_pred
                 (READ_FREG(insn._rd())));
 #endif
       }
-      insn.reset_caches(); // ready to increment offsets in next iteration
       if (insn.stop_vloop())
       {
         break;
index 90ef35426581b2cafdfce6c78591b98811c61bcc..7f7846ae8d248f4165a93705ce83d8aaa135992b 100644 (file)
@@ -165,3 +165,9 @@ uint64_t sv_insn_t::predicated(uint64_t reg, int offs, uint64_t pred)
     return 0;
 }
 
+bool sv_insn_t::stop_vloop(void)
+{
+    return (p->get_state()->vl == 0) || !vloop_continue;
+}
+
+
index 5d2b2fc7e171c1888d3c3ac6f4ed1cd0fbb9d8c2..aabbe79faaf070d5277fa04179615cb2bb1ebba1 100644 (file)
@@ -25,36 +25,34 @@ public:
             insn_t(bits), p(pr), vloop_continue(false), fimap(f),
             cached_rd(0xff), cached_rs1(0xff),
             cached_rs2(0xff), cached_rs3(0xff),
-            offs_rd(0), offs_rs1(0),
-            offs_rs2(0), offs_rs3(0),
-            new_offs_rd(0), new_offs_rs1(0),
-            new_offs_rs2(0), new_offs_rs3(0),
+            offs_rd(0), offs_rs(0),
+            new_offs_rd(0), new_offs_rs(0),
             prd(p_rd), prs1(p_rs1), prs2(p_rs2), prs3(p_rs3) {}
-  uint64_t rd () { return predicated(_rd (), offs_rd , prd); }
-  uint64_t rs1() { return predicated(_rs1(), offs_rs1, prs1); }
-  uint64_t rs2() { return predicated(_rs2(), offs_rs2, prs2); }
-  uint64_t rs3() { return predicated(_rs3(), offs_rs3, prs3); }
-  uint64_t rvc_rs1 () { return predicated(_rvc_rs1 (), offs_rs1, prs1); }
-  uint64_t rvc_rs1s() { return predicated(_rvc_rs1s(), offs_rs1, prs1); }
-  uint64_t rvc_rs2 () { return predicated(_rvc_rs2 (), offs_rs2, prs2); }
-  uint64_t rvc_rs2s() { return predicated(_rvc_rs2s(), offs_rs2, prs2); }
+  uint64_t rd () { return predicated(_rd (), offs_rd, prd); }
+  uint64_t rs1() { return predicated(_rs1(), offs_rs, prs1); }
+  uint64_t rs2() { return predicated(_rs2(), offs_rs, prs2); }
+  uint64_t rs3() { return predicated(_rs3(), offs_rs, prs3); }
+  uint64_t rvc_rs1 () { return predicated(_rvc_rs1 (), offs_rs, prs1); }
+  uint64_t rvc_rs1s() { return predicated(_rvc_rs1s(), offs_rs, prs1); }
+  uint64_t rvc_rs2 () { return predicated(_rvc_rs2 (), offs_rs, prs2); }
+  uint64_t rvc_rs2s() { return predicated(_rvc_rs2s(), offs_rs, prs2); }
 
   uint64_t _rd () { return _remap(insn_t::rd (), fimap & REG_RD ,
                                   offs_rd , cached_rd, new_offs_rd); }
   uint64_t _rs1() { return _remap(insn_t::rs1(), fimap & REG_RS1,
-                                  offs_rs1, cached_rs1, new_offs_rs1); }
+                                  offs_rs, cached_rs1, new_offs_rs); }
   uint64_t _rs2() { return _remap(insn_t::rs2(), fimap & REG_RS2,
-                                  offs_rs2, cached_rs2, new_offs_rs2); }
+                                  offs_rs, cached_rs2, new_offs_rs); }
   uint64_t _rs3() { return _remap(insn_t::rs3(), fimap & REG_RS3,
-                                  offs_rs3, cached_rs3, new_offs_rs3); }
+                                  offs_rs, cached_rs3, new_offs_rs); }
   uint64_t _rvc_rs1 () { return _remap(insn_t::rvc_rs1(), fimap & REG_RVC_RS1,
-                                       offs_rs1, cached_rs1, new_offs_rs1); }
+                                       offs_rs, cached_rs1, new_offs_rs); }
   uint64_t _rvc_rs1s() { return _remap(insn_t::rvc_rs1s(), fimap & REG_RVC_RS1S,
-                                       offs_rs1, cached_rs1, new_offs_rs1); }
+                                       offs_rs, cached_rs1, new_offs_rs); }
   uint64_t _rvc_rs2 () { return _remap(insn_t::rvc_rs2(), fimap & REG_RVC_RS2,
-                                       offs_rs2, cached_rs2, new_offs_rs2); }
+                                       offs_rs, cached_rs2, new_offs_rs); }
   uint64_t _rvc_rs2s() { return _remap(insn_t::rvc_rs2s(), fimap & REG_RVC_RS2S,
-                                       offs_rs2, cached_rs2, new_offs_rs2); }
+                                       offs_rs, cached_rs2, new_offs_rs); }
 
   void reset_caches(void)
   {
@@ -63,9 +61,7 @@ public:
     cached_rs2 = 0xff;
     cached_rs3 = 0xff;
     offs_rd = new_offs_rd;
-    offs_rs1 = new_offs_rs1;
-    offs_rs2 = new_offs_rs2;
-    offs_rs3 = new_offs_rs3;
+    offs_rs = new_offs_rs;
   }
 
   bool sv_check_reg(bool intreg, uint64_t reg);
@@ -74,10 +70,15 @@ public:
   reg_t predicate(uint64_t reg, bool isint, bool &zeroing);
 
   void reset_vloop_check(void) { vloop_continue = false; }
-  bool stop_vloop(void) { return !vloop_continue; }
+  bool stop_vloop(void);
+
+  int rd_offs(void) { return offs_rd; }
+  int rs_offs(void) { return offs_rs; }
+  int rd_offs_inc(void) { offs_rd += 1; return offs_rd; }
+  int rs_offs_inc(void) { offs_rs += 1; return offs_rs; }
 
-private:
   processor_t *p;
+private:
   bool vloop_continue;
   unsigned int fimap;
   uint64_t cached_rd;
@@ -85,13 +86,9 @@ private:
   uint64_t cached_rs2;
   uint64_t cached_rs3;
   int offs_rd;
-  int offs_rs1;
-  int offs_rs2;
-  int offs_rs3;
+  int offs_rs;
   int new_offs_rd;
-  int new_offs_rs1;
-  int new_offs_rs2;
-  int new_offs_rs3;
+  int new_offs_rs;
   uint64_t &prd;
   uint64_t &prs1;
   uint64_t &prs2;