slp: Support optimizing load distribution
author Tamar Christina <tamar.christina@arm.com>
Thu, 14 Jan 2021 20:50:57 +0000 (20:50 +0000)
committer Tamar Christina <tamar.christina@arm.com>
Thu, 14 Jan 2021 20:54:31 +0000 (20:54 +0000)
This introduces a post-processing step for the pattern matcher that flattens
the permutes introduced by the complex multiplication patterns.

The blend is performed early so that SLP is not cancelled by the LOAD_LANES
permute.  This is a temporary workaround for the fact that loads are not CSEd
during SLP build and is required to produce efficient code.
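
As a rough illustration (not taken from this patch), the permutes in question
come from complex arithmetic kernels such as the hypothetical complex multiply
below; the kernel and its types are only meant to show the shape of code that
triggers the pattern:

    /* Hypothetical kernel hitting the SLP complex multiplication pattern.  */
    void
    cmul (float _Complex *restrict c, float _Complex *restrict a,
          float _Complex *restrict b, int n)
    {
      for (int i = 0; i < n; i++)
        c[i] = a[i] * b[i];
    }

For such a loop the pattern matcher introduces VEC_PERM_EXPR nodes that blend
lanes read from the same data reference; the new post-processing step rewrites
such a blend into a single load node so that SLP is not cancelled by the
LOAD_LANES permute.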

gcc/ChangeLog:

* tree-vect-slp.c (optimize_load_redistribution_1): New.
(optimize_load_redistribution, vect_is_slp_load_node): New.
(vect_match_slp_patterns): Remove unused bst_map parameter.
(vect_analyze_slp): Use optimize_load_redistribution.

gcc/tree-vect-slp.c

index f7f656a481024310cf2ac51121887779d1d779ff..6b6c9ccc0a021b118052a3a8d562ea0cfdc68335 100644
@@ -2258,6 +2258,123 @@ calculate_unrolling_factor (poly_uint64 nunits, unsigned int group_size)
   return exact_div (common_multiple (nunits, group_size), group_size);
 }
 
+/* Helper that checks to see if a node is a load node.  */
+
+static inline bool
+vect_is_slp_load_node (slp_tree root)
+{
+  return SLP_TREE_DEF_TYPE (root) == vect_internal_def
+        && STMT_VINFO_GROUPED_ACCESS (SLP_TREE_REPRESENTATIVE (root))
+        && DR_IS_READ (STMT_VINFO_DATA_REF (SLP_TREE_REPRESENTATIVE (root)));
+}
+
+
+/* Helper function of optimize_load_redistribution that performs the operation
+   recursively.  */
+
+static slp_tree
+optimize_load_redistribution_1 (scalar_stmts_to_slp_tree_map_t *bst_map,
+                               vec_info *vinfo, unsigned int group_size,
+                               hash_map<slp_tree, slp_tree> *load_map,
+                               slp_tree root)
+{
+  if (slp_tree *leader = load_map->get (root))
+    return *leader;
+
+  load_map->put (root, NULL);
+
+  slp_tree node;
+  unsigned i;
+
+  /* For now, we don't know anything about externals so do not do anything.  */
+  if (SLP_TREE_DEF_TYPE (root) != vect_internal_def)
+    return NULL;
+  else if (SLP_TREE_CODE (root) == VEC_PERM_EXPR)
+    {
+      /* First convert this node into a load node, add it to the leaves
+        list, and flatten the permute from a lane permute into a load
+        permute.  If it's unneeded it will be elided later.  */
+      vec<stmt_vec_info> stmts;
+      stmts.create (SLP_TREE_LANES (root));
+      lane_permutation_t lane_perm = SLP_TREE_LANE_PERMUTATION (root);
+      for (unsigned j = 0; j < lane_perm.length (); j++)
+       {
+         std::pair<unsigned, unsigned> perm = lane_perm[j];
+         node = SLP_TREE_CHILDREN (root)[perm.first];
+
+         if (!vect_is_slp_load_node (node)
+             || SLP_TREE_CHILDREN (node).exists ())
+           {
+             stmts.release ();
+             goto next;
+           }
+
+         stmts.quick_push (SLP_TREE_SCALAR_STMTS (node)[perm.second]);
+       }
+
+      if (dump_enabled_p ())
+       dump_printf_loc (MSG_NOTE, vect_location,
+                        "converting stmts on permute node %p\n", root);
+
+      bool *matches = XALLOCAVEC (bool, group_size);
+      poly_uint64 max_nunits = 1;
+      unsigned tree_size = 0, limit = 1;
+      node = vect_build_slp_tree (vinfo, stmts, group_size, &max_nunits,
+                                 matches, &limit, &tree_size, bst_map);
+      if (!node)
+       stmts.release ();
+
+      load_map->put (root, node);
+      return node;
+    }
+
+next:
+  FOR_EACH_VEC_ELT (SLP_TREE_CHILDREN (root), i, node)
+    {
+      slp_tree value
+       = optimize_load_redistribution_1 (bst_map, vinfo, group_size, load_map,
+                                         node);
+      if (value)
+       {
+         SLP_TREE_REF_COUNT (value)++;
+         SLP_TREE_CHILDREN (root)[i] = value;
+         vect_free_slp_tree (node);
+       }
+    }
+
+  return NULL;
+}
+
+/* Temporary workaround for loads not being CSEd during SLP build.  This
+   function traverses the SLP tree rooted in ROOT and finds VEC_PERM nodes
+   that blend vectors from multiple nodes which all read from the same DR,
+   such that the combined operation is equivalent to a permuted load.  Such
+   nodes are then converted directly into load nodes.  The resulting nodes
+   are CSEd using BST_MAP.  */
+
+static void
+optimize_load_redistribution (scalar_stmts_to_slp_tree_map_t *bst_map,
+                             vec_info *vinfo, unsigned int group_size,
+                             hash_map<slp_tree, slp_tree> *load_map,
+                             slp_tree root)
+{
+  slp_tree node;
+  unsigned i;
+
+  FOR_EACH_VEC_ELT (SLP_TREE_CHILDREN (root), i, node)
+    {
+      slp_tree value
+       = optimize_load_redistribution_1 (bst_map, vinfo, group_size, load_map,
+                                         node);
+      if (value)
+       {
+         SLP_TREE_REF_COUNT (value)++;
+         SLP_TREE_CHILDREN (root)[i] = value;
+         vect_free_slp_tree (node);
+       }
+    }
+}
+
 /* Helper function of vect_match_slp_patterns.
 
    Attempts to match patterns against the slp tree rooted in REF_NODE using
@@ -2305,8 +2422,7 @@ vect_match_slp_patterns_2 (slp_tree *ref_node, vec_info *vinfo,
 static bool
 vect_match_slp_patterns (slp_instance instance, vec_info *vinfo,
                         hash_set<slp_tree> *visited,
-                        slp_tree_to_load_perm_map_t *perm_cache,
-                        scalar_stmts_to_slp_tree_map_t * /* bst_map */)
+                        slp_tree_to_load_perm_map_t *perm_cache)
 {
   DUMP_VECT_SCOPE ("vect_match_slp_patterns");
   slp_tree *ref_node = &SLP_INSTANCE_TREE (instance);
@@ -2316,20 +2432,7 @@ vect_match_slp_patterns (slp_instance instance, vec_info *vinfo,
                     "Analyzing SLP tree %p for patterns\n",
                     SLP_INSTANCE_TREE (instance));
 
-  bool found_p
-    = vect_match_slp_patterns_2 (ref_node, vinfo, perm_cache, visited);
-
-  if (found_p)
-    {
-      if (dump_enabled_p ())
-       {
-         dump_printf_loc (MSG_NOTE, vect_location,
-                          "Pattern matched SLP tree\n");
-         vect_print_slp_graph (MSG_NOTE, vect_location, *ref_node);
-       }
-    }
-
-  return found_p;
+  return vect_match_slp_patterns_2 (ref_node, vinfo, perm_cache, visited);
 }
 
 /* Analyze an SLP instance starting from a group of grouped stores.  Call
@@ -2768,10 +2871,23 @@ vect_analyze_slp (vec_info *vinfo, unsigned max_tree_size)
 
   hash_set<slp_tree> visited_patterns;
   slp_tree_to_load_perm_map_t perm_cache;
+  hash_map<slp_tree, slp_tree> load_map;
+
   /* See if any patterns can be found in the SLP tree.  */
   FOR_EACH_VEC_ELT (LOOP_VINFO_SLP_INSTANCES (vinfo), i, instance)
-    vect_match_slp_patterns (instance, vinfo, &visited_patterns, &perm_cache,
-                            bst_map);
+    if (vect_match_slp_patterns (instance, vinfo, &visited_patterns,
+                                &perm_cache))
+      {
+       slp_tree root = SLP_INSTANCE_TREE (instance);
+       optimize_load_redistribution (bst_map, vinfo, SLP_TREE_LANES (root),
+                                     &load_map, root);
+       if (dump_enabled_p ())
+         {
+           dump_printf_loc (MSG_NOTE, vect_location,
+                            "Pattern matched SLP tree\n");
+           vect_print_slp_graph (MSG_NOTE, vect_location, root);
+         }
+      }
 
   /* The map keeps a reference on SLP nodes built, release that.  */
   for (scalar_stmts_to_slp_tree_map_t::iterator it = bst_map->begin ();