slp: support complex FMA and complex FMA conjugate
authorTamar Christina <tamar.christina@arm.com>
Thu, 14 Jan 2021 20:58:12 +0000 (20:58 +0000)
committerTamar Christina <tamar.christina@arm.com>
Thu, 14 Jan 2021 20:58:12 +0000 (20:58 +0000)
This adds support for FMA and FMA conjugated to the slp pattern matcher.

Example of instructions matched:

#include <stdio.h>
#include <complex.h>

#define N 200
#define ROT
#define TYPE float
#define TYPE2 float

void g (TYPE2 complex a[restrict N], TYPE complex b[restrict N], TYPE complex c[restrict N])
{
  for (int i=0; i < N; i++)
    {
      c[i] +=  a[i] * (b[i] ROT);
    }
}

void g_f1 (TYPE2 complex a[restrict N], TYPE complex b[restrict N], TYPE complex c[restrict N])
{
  for (int i=0; i < N; i++)
    {
      c[i] +=  conjf (a[i]) * (b[i] ROT);
    }
}

void g_s1 (TYPE2 complex a[restrict N], TYPE complex b[restrict N], TYPE complex c[restrict N])
{
  for (int i=0; i < N; i++)
    {
      c[i] +=  a[i] * conjf (b[i] ROT);
    }
}

void caxpy_add(double complex * restrict y, double complex * restrict x, size_t N, double complex f) {
  for (size_t i = 0; i < N; ++i)
    y[i] += x[i]* f;
}

gcc/ChangeLog:

* internal-fn.def (COMPLEX_FMA, COMPLEX_FMA_CONJ): New.
* optabs.def (cmla_optab, cmla_conj_optab): New.
* doc/md.texi: Document them.
* tree-vect-slp-patterns.c (vect_match_call_p,
class complex_fma_pattern, vect_slp_reset_pattern,
complex_fma_pattern::matches, complex_fma_pattern::recognize,
complex_fma_pattern::build): New.

gcc/doc/md.texi
gcc/internal-fn.def
gcc/optabs.def
gcc/tree-vect-slp-patterns.c

index 60e8c94810a3f27a4fa8d59367b0710323504e9c..49a1ce045b19a2a1d49a1fedbacd81876e317c04 100644 (file)
@@ -6202,6 +6202,51 @@ The operation is only supported for vector modes @var{m}.
 
 This pattern is not allowed to @code{FAIL}.
 
+@cindex @code{cmla@var{m}4} instruction pattern
+@item @samp{cmla@var{m}4}
+Perform a vector multiply and accumulate that is semantically the same as
+a multiply and accumulate of complex numbers.
+
+@smallexample
+  complex TYPE c[N];
+  complex TYPE a[N];
+  complex TYPE b[N];
+  for (int i = 0; i < N; i += 1)
+    @{
+      c[i] += a[i] * b[i];
+    @}
+@end smallexample
+
+In GCC lane ordering the real part of the number must be in the even lanes with
+the imaginary part in the odd lanes.
+
+The operation is only supported for vector modes @var{m}.
+
+This pattern is not allowed to @code{FAIL}.
+
+@cindex @code{cmla_conj@var{m}4} instruction pattern
+@item @samp{cmla_conj@var{m}4}
+Perform a vector multiply by conjugate and accumulate that is semantically
+the same as a multiply and accumulate of complex numbers where the second
+multiply arguments is conjugated.
+
+@smallexample
+  complex TYPE c[N];
+  complex TYPE a[N];
+  complex TYPE b[N];
+  for (int i = 0; i < N; i += 1)
+    @{
+      c[i] += a[i] * conj (b[i]);
+    @}
+@end smallexample
+
+In GCC lane ordering the real part of the number must be in the even lanes with
+the imaginary part in the odd lanes.
+
+The operation is only supported for vector modes @var{m}.
+
+This pattern is not allowed to @code{FAIL}.
+
 @cindex @code{cmul@var{m}4} instruction pattern
 @item @samp{cmul@var{m}4}
 Perform a vector multiply that is semantically the same as multiply of
index e3e4fe5ebadb446408b68b7408f4b86f42695b8d..020b586bc656b2691c6151d727f58601d4eec80f 100644 (file)
@@ -288,6 +288,8 @@ DEF_INTERNAL_FLT_FN (LDEXP, ECF_CONST, ldexp, binary)
 
 /* Ternary math functions.  */
 DEF_INTERNAL_FLT_FLOATN_FN (FMA, ECF_CONST, fma, ternary)
+DEF_INTERNAL_OPTAB_FN (COMPLEX_FMA, ECF_CONST, cmla, ternary)
+DEF_INTERNAL_OPTAB_FN (COMPLEX_FMA_CONJ, ECF_CONST, cmla_conj, ternary)
 
 /* Unary integer ops.  */
 DEF_INTERNAL_INT_FN (CLRSB, ECF_CONST | ECF_NOTHROW, clrsb, unary)
index fcc27d00dbadb8dd0f6793c12d45e8c5a5ab509e..cecd1b61a1f1f7ae58e182eb789e1a3551a429ba 100644 (file)
@@ -294,6 +294,8 @@ OPTAB_D (cadd90_optab, "cadd90$a3")
 OPTAB_D (cadd270_optab, "cadd270$a3")
 OPTAB_D (cmul_optab, "cmul$a3")
 OPTAB_D (cmul_conj_optab, "cmul_conj$a3")
+OPTAB_D (cmla_optab, "cmla$a4")
+OPTAB_D (cmla_conj_optab, "cmla_conj$a4")
 OPTAB_D (cos_optab, "cos$a2")
 OPTAB_D (cosh_optab, "cosh$a2")
 OPTAB_D (exp10_optab, "exp10$a2")
index dc96be51dfe2f8176621183f7e8f61da0252c770..bd632e01fb851761ad9e43ffaa1aeaf152176c47 100644 (file)
@@ -325,6 +325,24 @@ vect_match_expression_p (slp_tree node, tree_code code)
   return true;
 }
 
+/* Checks to see if the expression represented by NODE is a call to the internal
+   function FN.  */
+
+static inline bool
+vect_match_call_p (slp_tree node, internal_fn fn)
+{
+  if (!node
+      || !SLP_TREE_REPRESENTATIVE (node))
+    return false;
+
+  gimple* expr = STMT_VINFO_STMT (SLP_TREE_REPRESENTATIVE (node));
+  if (!expr
+      || !gimple_call_internal_p (expr, fn))
+    return false;
+
+   return true;
+}
+
 /* Check if the given lane permute in PERMUTES matches an alternating sequence
    of {even odd even odd ...}.  This to account for unrolled loops.  Further
    mode there resulting permute must be linear.   */
@@ -1085,6 +1103,168 @@ complex_mul_pattern::build (vec_info *vinfo)
   complex_pattern::build (vinfo);
 }
 
+/*******************************************************************************
+ * complex_fma_pattern class
+ ******************************************************************************/
+
+class complex_fma_pattern : public complex_pattern
+{
+  protected:
+    complex_fma_pattern (slp_tree *node, vec<slp_tree> *m_ops, internal_fn ifn)
+      : complex_pattern (node, m_ops, ifn)
+    {
+      this->m_num_args = 3;
+    }
+
+  public:
+    void build (vec_info *);
+    static internal_fn
+    matches (complex_operation_t op, slp_tree_to_load_perm_map_t *, slp_tree *,
+            vec<slp_tree> *);
+
+    static vect_pattern*
+    recognize (slp_tree_to_load_perm_map_t *, slp_tree *);
+
+    static vect_pattern*
+    mkInstance (slp_tree *node, vec<slp_tree> *m_ops, internal_fn ifn)
+    {
+      return new complex_fma_pattern (node, m_ops, ifn);
+    }
+};
+
+/* Helper function to "reset" a previously matched node and undo the changes
+   made enough so that the node is treated as an irrelevant node.  */
+
+static inline void
+vect_slp_reset_pattern (slp_tree node)
+{
+  stmt_vec_info stmt_info = vect_orig_stmt (SLP_TREE_REPRESENTATIVE (node));
+  STMT_VINFO_IN_PATTERN_P (stmt_info) = false;
+  STMT_SLP_TYPE (stmt_info) = pure_slp;
+  SLP_TREE_REPRESENTATIVE (node) = stmt_info;
+}
+
+/* Pattern matcher for trying to match complex multiply and accumulate
+   and multiply and subtract patterns in SLP tree.
+   If the operation matches then IFN is set to the operation it matched and
+   the arguments to the two replacement statements are put in m_ops.
+
+   If no match is found then IFN is set to IFN_LAST and m_ops is unchanged.
+
+   This function matches the patterns shaped as:
+
+   double ax = (b[i+1] * a[i]) + (b[i] * a[i]);
+   double bx = (a[i+1] * b[i]) - (a[i+1] * b[i+1]);
+
+   c[i] = c[i] - ax;
+   c[i+1] = c[i+1] + bx;
+
+   If a match occurred then TRUE is returned, else FALSE.  The match is
+   performed after COMPLEX_MUL which would have done the majority of the work.
+   This function merely matches an ADD with a COMPLEX_MUL IFN.  The initial
+   match is expected to be in OP1 and the initial match operands in args0.  */
+
+internal_fn
+complex_fma_pattern::matches (complex_operation_t op,
+                             slp_tree_to_load_perm_map_t * /* perm_cache */,
+                             slp_tree *ref_node, vec<slp_tree> *ops)
+{
+  internal_fn ifn = IFN_LAST;
+
+  /* Find the two components.  We match Complex MUL first which reduces the
+     amount of work this pattern has to do.  After that we just match the
+     head node and we're done.:
+
+     * FMA: + +.
+
+     We need to ignore the two_operands nodes that may also match.
+     For that we can check if they have any scalar statements and also
+     check that it's not a permute node as we're looking for a normal
+     PLUS_EXPR operation.  */
+  if (op != CMPLX_NONE)
+    return IFN_LAST;
+
+  /* Find the two components.  We match Complex MUL first which reduces the
+     amount of work this pattern has to do.  After that we just match the
+     head node and we're done.:
+
+   * FMA: + + on a non-two_operands node.  */
+  slp_tree vnode = *ref_node;
+  if (SLP_TREE_LANE_PERMUTATION (vnode).exists ()
+      || !SLP_TREE_CHILDREN (vnode).exists ()
+      || !vect_match_expression_p (vnode, PLUS_EXPR))
+    return IFN_LAST;
+
+  slp_tree node = SLP_TREE_CHILDREN (vnode)[1];
+
+  if (vect_match_call_p (node, IFN_COMPLEX_MUL))
+    ifn = IFN_COMPLEX_FMA;
+  else if (vect_match_call_p (node, IFN_COMPLEX_MUL_CONJ))
+    ifn = IFN_COMPLEX_FMA_CONJ;
+  else
+    return IFN_LAST;
+
+  if (!vect_pattern_validate_optab (ifn, vnode))
+    return IFN_LAST;
+
+  /* FMA matched ADD + CMUL.  During the matching of CMUL the
+     stmt that starts the pattern is marked as being in a pattern,
+     namely the CMUL.  When replacing this with a CFMA we have to
+     unmark this statement as being in a pattern.  This is because
+     vect_mark_pattern_stmts will only mark the current stmt as being
+     in a pattern.  Later on when the scalar stmts are examined the
+     old statement which is supposed to be irrelevant will point to
+     CMUL unless we undo the pattern relationship here.  */
+  vect_slp_reset_pattern (node);
+  ops->truncate (0);
+  ops->create (3);
+
+  if (ifn == IFN_COMPLEX_FMA)
+    {
+      ops->quick_push (SLP_TREE_CHILDREN (vnode)[0]);
+      ops->quick_push (SLP_TREE_CHILDREN (node)[1]);
+      ops->quick_push (SLP_TREE_CHILDREN (node)[0]);
+    }
+  else
+    {
+      ops->quick_push (SLP_TREE_CHILDREN (vnode)[0]);
+      ops->quick_push (SLP_TREE_CHILDREN (node)[0]);
+      ops->quick_push (SLP_TREE_CHILDREN (node)[1]);
+    }
+
+  return ifn;
+}
+
+/* Attempt to recognize a complex mul pattern.  */
+
+vect_pattern*
+complex_fma_pattern::recognize (slp_tree_to_load_perm_map_t *perm_cache,
+                               slp_tree *node)
+{
+  auto_vec<slp_tree> ops;
+  complex_operation_t op
+    = vect_detect_pair_op (*node, true, &ops);
+  internal_fn ifn
+    = complex_fma_pattern::matches (op, perm_cache, node, &ops);
+  if (ifn == IFN_LAST)
+    return NULL;
+
+  return new complex_fma_pattern (node, &ops, ifn);
+}
+
+/* Perform a replacement of the detected complex mul pattern with the new
+   instruction sequences.  */
+
+void
+complex_fma_pattern::build (vec_info *vinfo)
+{
+  SLP_TREE_CHILDREN (*this->m_node).release ();
+  SLP_TREE_CHILDREN (*this->m_node).create (3);
+  SLP_TREE_CHILDREN (*this->m_node).safe_splice (this->m_ops);
+
+  complex_pattern::build (vinfo);
+}
+
 /*******************************************************************************
  * Pattern matching definitions
  ******************************************************************************/