[committed] openmp: Compute triangular loop number of iterations at compile time

Message ID 20200622091000.GA23754@tucnak
State New
Headers show
Series
  • [committed] openmp: Compute triangular loop number of iterations at compile time
Related show

Commit Message

Jakub Jelinek via Gcc-patches June 22, 2020, 9:10 a.m.
Hi!

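This patch computes the number of iterations of triangular
(non-rectangular) loop nests at compile time, provided that all loop
invariant expressions are constant and the innermost loop is executed
at least once for every iteration of the outer loop.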

Bootstrapped/regtested on x86_64-linux and i686-linux, committed to trunk.

2020-06-22  Jakub Jelinek  <jakub@redhat.com>

	* omp-general.c (omp_extract_for_data): For triangular loops with
	all loop invariant expressions constant where the innermost loop is
	executed at least once compute number of iterations at compile time.


	Jakub

Patch

--- gcc/omp-general.c.jj	2020-06-16 16:30:43.748400142 +0200
+++ gcc/omp-general.c	2020-06-19 16:01:46.379724361 +0200
@@ -313,6 +313,44 @@  omp_extract_for_data (gomp_for *for_stmt
     }
 
   int cnt = fd->ordered ? fd->ordered : fd->collapse;
+  int single_nonrect = -1;
+  tree single_nonrect_count = NULL_TREE;
+  enum tree_code single_nonrect_cond_code = ERROR_MARK;
+  for (i = 1; i < cnt; i++)
+    {
+      tree n1 = gimple_omp_for_initial (for_stmt, i);
+      tree n2 = gimple_omp_for_final (for_stmt, i);
+      if (TREE_CODE (n1) == TREE_VEC)
+	{
+	  if (fd->non_rect)
+	    {
+	      single_nonrect = -1;
+	      break;
+	    }
+	  for (int j = i - 1; j >= 0; j--)
+	    if (TREE_VEC_ELT (n1, 0) == gimple_omp_for_index (for_stmt, j))
+	      {
+		single_nonrect = j;
+		break;
+	      }
+	  fd->non_rect = true;
+	}
+      else if (TREE_CODE (n2) == TREE_VEC)
+	{
+	  if (fd->non_rect)
+	    {
+	      single_nonrect = -1;
+	      break;
+	    }
+	  for (int j = i - 1; j >= 0; j--)
+	    if (TREE_VEC_ELT (n2, 0) == gimple_omp_for_index (for_stmt, j))
+	      {
+		single_nonrect = j;
+		break;
+	      }
+	  fd->non_rect = true;
+	}
+    }
   for (i = 0; i < cnt; i++)
     {
       if (i == 0
@@ -444,8 +482,90 @@  omp_extract_for_data (gomp_for *for_stmt
 
       if (collapse_count && *collapse_count == NULL)
 	{
+	  if (count && integer_zerop (count))
+	    continue;
+	  tree n1first = NULL_TREE, n2first = NULL_TREE;
+	  tree n1last = NULL_TREE, n2last = NULL_TREE;
+	  tree ostep = NULL_TREE;
 	  if (loop->m1 || loop->m2)
-	    t = NULL_TREE;
+	    {
+	      if (count == NULL_TREE)
+		continue;
+	      if (single_nonrect == -1
+		  || (loop->m1 && TREE_CODE (loop->m1) != INTEGER_CST)
+		  || (loop->m2 && TREE_CODE (loop->m2) != INTEGER_CST))
+		{
+		  count = NULL_TREE;
+		  continue;
+		}
+	      tree var = gimple_omp_for_initial (for_stmt, single_nonrect);
+	      tree itype = TREE_TYPE (var);
+	      tree first = gimple_omp_for_initial (for_stmt, single_nonrect);
+	      t = gimple_omp_for_incr (for_stmt, single_nonrect);
+	      ostep = omp_get_for_step_from_incr (loc, t);
+	      t = fold_binary (MINUS_EXPR, long_long_unsigned_type_node,
+			       single_nonrect_count,
+			       build_one_cst (long_long_unsigned_type_node));
+	      t = fold_convert (itype, t);
+	      first = fold_convert (itype, first);
+	      ostep = fold_convert (itype, ostep);
+	      tree last = fold_binary (PLUS_EXPR, itype, first,
+				       fold_binary (MULT_EXPR, itype, t,
+						    ostep));
+	      if (TREE_CODE (first) != INTEGER_CST
+		  || TREE_CODE (last) != INTEGER_CST)
+		{
+		  count = NULL_TREE;
+		  continue;
+		}
+	      if (loop->m1)
+		{
+		  tree m1 = fold_convert (itype, loop->m1);
+		  tree n1 = fold_convert (itype, loop->n1);
+		  n1first = fold_binary (PLUS_EXPR, itype,
+					 fold_binary (MULT_EXPR, itype,
+						      first, m1), n1);
+		  n1last = fold_binary (PLUS_EXPR, itype,
+					fold_binary (MULT_EXPR, itype,
+						     last, m1), n1);
+		}
+	      else
+		n1first = n1last = loop->n1;
+	      if (loop->m2)
+		{
+		  tree n2 = fold_convert (itype, loop->n2);
+		  tree m2 = fold_convert (itype, loop->m2);
+		  n2first = fold_binary (PLUS_EXPR, itype,
+					 fold_binary (MULT_EXPR, itype,
+						      first, m2), n2);
+		  n2last = fold_binary (PLUS_EXPR, itype,
+					fold_binary (MULT_EXPR, itype,
+						     last, m2), n2);
+		}
+	      else
+		n2first = n2last = loop->n2;
+	      n1first = fold_convert (TREE_TYPE (loop->v), n1first);
+	      n2first = fold_convert (TREE_TYPE (loop->v), n2first);
+	      n1last = fold_convert (TREE_TYPE (loop->v), n1last);
+	      n2last = fold_convert (TREE_TYPE (loop->v), n2last);
+	      t = fold_binary (loop->cond_code, boolean_type_node,
+			       n1first, n2first);
+	      tree t2 = fold_binary (loop->cond_code, boolean_type_node,
+				     n1last, n2last);
+	      if (t && t2 && integer_nonzerop (t) && integer_nonzerop (t2))
+		/* All outer loop iterators have at least one inner loop
+		   iteration.  Try to compute the count at compile time.  */
+		t = NULL_TREE;
+	      else if (t && t2 && integer_zerop (t) && integer_zerop (t2))
+		/* No iterations of the inner loop.  count will be set to
+		   zero cst below.  */;
+	      else
+		{
+		  /* Punt (for now).  */
+		  count = NULL_TREE;
+		  continue;
+		}
+	    }
 	  else
 	    t = fold_binary (loop->cond_code, boolean_type_node,
 			     fold_convert (TREE_TYPE (loop->v), loop->n1),
@@ -454,8 +574,6 @@  omp_extract_for_data (gomp_for *for_stmt
 	    count = build_zero_cst (long_long_unsigned_type_node);
 	  else if ((i == 0 || count != NULL_TREE)
 		   && TREE_CODE (TREE_TYPE (loop->v)) == INTEGER_TYPE
-		   && loop->m1 == NULL_TREE
-		   && loop->m2 == NULL_TREE
 		   && TREE_CONSTANT (loop->n1)
 		   && TREE_CONSTANT (loop->n2)
 		   && TREE_CODE (loop->step) == INTEGER_CST)
@@ -465,31 +583,89 @@  omp_extract_for_data (gomp_for *for_stmt
 	      if (POINTER_TYPE_P (itype))
 		itype = signed_type_for (itype);
 	      t = build_int_cst (itype, (loop->cond_code == LT_EXPR ? -1 : 1));
-	      t = fold_build2_loc (loc, PLUS_EXPR, itype,
-				   fold_convert_loc (loc, itype, loop->step),
-				   t);
-	      t = fold_build2_loc (loc, PLUS_EXPR, itype, t,
-				   fold_convert_loc (loc, itype, loop->n2));
-	      t = fold_build2_loc (loc, MINUS_EXPR, itype, t,
-				   fold_convert_loc (loc, itype, loop->n1));
+	      t = fold_build2 (PLUS_EXPR, itype,
+			       fold_convert (itype, loop->step), t);
+	      tree n1 = loop->n1;
+	      tree n2 = loop->n2;
+	      if (loop->m1 || loop->m2)
+		{
+		  gcc_assert (single_nonrect != -1);
+		  if (single_nonrect_cond_code == LT_EXPR)
+		    {
+		      n1 = n1first;
+		      n2 = n2first;
+		    }
+		  else
+		    {
+		      n1 = n1last;
+		      n2 = n2last;
+		    }
+		}
+	      t = fold_build2 (PLUS_EXPR, itype, t, fold_convert (itype, n2));
+	      t = fold_build2 (MINUS_EXPR, itype, t, fold_convert (itype, n1));
+	      tree step = fold_convert_loc (loc, itype, loop->step);
 	      if (TYPE_UNSIGNED (itype) && loop->cond_code == GT_EXPR)
+		t = fold_build2 (TRUNC_DIV_EXPR, itype,
+				 fold_build1 (NEGATE_EXPR, itype, t),
+				 fold_build1 (NEGATE_EXPR, itype, step));
+	      else
+		t = fold_build2 (TRUNC_DIV_EXPR, itype, t, step);
+	      tree llutype = long_long_unsigned_type_node;
+	      t = fold_convert (llutype, t);
+	      if (loop->m1 || loop->m2)
 		{
-		  tree step = fold_convert_loc (loc, itype, loop->step);
-		  t = fold_build2_loc (loc, TRUNC_DIV_EXPR, itype,
-				       fold_build1_loc (loc, NEGATE_EXPR,
-							itype, t),
-				       fold_build1_loc (loc, NEGATE_EXPR,
-							itype, step));
-		}
-	      else
-		t = fold_build2_loc (loc, TRUNC_DIV_EXPR, itype, t,
-				     fold_convert_loc (loc, itype,
-						       loop->step));
-	      t = fold_convert_loc (loc, long_long_unsigned_type_node, t);
-	      if (count != NULL_TREE)
-		count = fold_build2_loc (loc, MULT_EXPR,
-					 long_long_unsigned_type_node,
-					 count, t);
+		  /* t is number of iterations of inner loop at either first
+		     or last value of the outer iterator (the one with fewer
+		     iterations).
+		     Compute t2 = ((m2 - m1) * ostep) / step
+		     (for single_nonrect_cond_code GT_EXPR
+		      t2 = ((m1 - m2) * ostep) / step instead)
+		     and niters = outer_count * t
+				  + t2 * ((outer_count - 1) * outer_count / 2)
+		   */
+		  tree m1 = loop->m1 ? loop->m1 : integer_zero_node;
+		  tree m2 = loop->m2 ? loop->m2 : integer_zero_node;
+		  m1 = fold_convert (itype, m1);
+		  m2 = fold_convert (itype, m2);
+		  tree t2;
+		  if (single_nonrect_cond_code == LT_EXPR)
+		    t2 = fold_build2 (MINUS_EXPR, itype, m2, m1);
+		  else
+		    t2 = fold_build2 (MINUS_EXPR, itype, m1, m2);
+		  t2 = fold_build2 (MULT_EXPR, itype, t2, ostep);
+		  if (TYPE_UNSIGNED (itype) && loop->cond_code == GT_EXPR)
+		    t2 = fold_build2 (TRUNC_DIV_EXPR, itype,
+				      fold_build1 (NEGATE_EXPR, itype, t2),
+				      fold_build1 (NEGATE_EXPR, itype, step));
+		  else
+		    t2 = fold_build2 (TRUNC_DIV_EXPR, itype, t2, step);
+		  t2 = fold_convert (llutype, t2);
+		  t = fold_build2 (MULT_EXPR, llutype, t,
+				   single_nonrect_count);
+		  tree t3 = fold_build2 (MINUS_EXPR, llutype,
+					 single_nonrect_count,
+					 build_one_cst (llutype));
+		  t3 = fold_build2 (MULT_EXPR, llutype, t3,
+				    single_nonrect_count);
+		  t3 = fold_build2 (TRUNC_DIV_EXPR, llutype, t3,
+				    build_int_cst (llutype, 2));
+		  t2 = fold_build2 (MULT_EXPR, llutype, t2, t3);
+		  t = fold_build2 (PLUS_EXPR, llutype, t, t2);
+		}
+	      if (i == single_nonrect)
+		{
+		  if (integer_zerop (t) || TREE_CODE (t) != INTEGER_CST)
+		    count = t;
+		  else
+		    {
+		      single_nonrect_count = t;
+		      single_nonrect_cond_code = loop->cond_code;
+		      if (count == NULL_TREE)
+			count = build_one_cst (llutype);
+		    }
+		}
+	      else if (count != NULL_TREE)
+		count = fold_build2 (MULT_EXPR, llutype, count, t);
 	      else
 		count = t;
 	      if (TREE_CODE (count) != INTEGER_CST)