diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala
index 102721616500..1e2371d2664b 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala
@@ -19,6 +19,7 @@ package org.apache.spark.sql.catalyst.trees
 
 import java.io.Writer
 import java.util.UUID
+import java.util.concurrent.atomic.AtomicBoolean
 
 import scala.collection.Map
 import scala.reflect.ClassTag
@@ -29,6 +30,8 @@ import org.json4s.JsonAST._
 import org.json4s.JsonDSL._
 import org.json4s.jackson.JsonMethods._
 
+import org.apache.spark.SparkEnv
+import org.apache.spark.internal.Logging
 import org.apache.spark.sql.catalyst.FunctionIdentifier
 import org.apache.spark.sql.catalyst.ScalaReflection._
 import org.apache.spark.sql.catalyst.TableIdentifier
@@ -75,7 +78,7 @@ object CurrentOrigin {
 }
 
 // scalastyle:off
-abstract class TreeNode[BaseType <: TreeNode[BaseType]] extends Product {
+abstract class TreeNode[BaseType <: TreeNode[BaseType]] extends Product with Logging {
 // scalastyle:on
   self: BaseType =>
 
@@ -484,7 +487,15 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] extends Product {
       writer: Writer,
       verbose: Boolean,
       addSuffix: Boolean): Unit = {
-    generateTreeString(0, Nil, writer, verbose, "", addSuffix)
+    treeString(writer, verbose, addSuffix, TreeNode.maxTreeToStringDepth)
+  }
+
+  def treeString(
+      writer: Writer,
+      verbose: Boolean,
+      addSuffix: Boolean,
+      maxDepth: Int): Unit = {
+    generateTreeString(0, Nil, writer, verbose, "", addSuffix, maxDepth)
   }
 
   /**
@@ -550,7 +561,8 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] extends Product {
       writer: Writer,
       verbose: Boolean,
       prefix: String = "",
-      addSuffix: Boolean = false): Unit = {
+      addSuffix: Boolean = false,
+      maxDepth: Int = TreeNode.maxTreeToStringDepth): Unit = {
 
     if (depth > 0) {
       lastChildren.init.foreach { isLast =>
@@ -559,30 +571,42 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] extends Product {
       writer.write(if (lastChildren.last) "+- " else ":- ")
     }
 
-    val str = if (verbose) {
-      if (addSuffix) verboseStringWithSuffix else verboseString
-    } else {
-      simpleString
-    }
-    writer.write(prefix)
-    writer.write(str)
-    writer.write("\n")
-
-    if (innerChildren.nonEmpty) {
-      innerChildren.init.foreach(_.generateTreeString(
-        depth + 2, lastChildren :+ children.isEmpty :+ false, writer, verbose,
-        addSuffix = addSuffix))
-      innerChildren.last.generateTreeString(
-        depth + 2, lastChildren :+ children.isEmpty :+ true, writer, verbose,
-        addSuffix = addSuffix)
-    }
-
-    if (children.nonEmpty) {
-      children.init.foreach(_.generateTreeString(
-        depth + 1, lastChildren :+ false, writer, verbose, prefix, addSuffix))
-      children.last.generateTreeString(
-        depth + 1, lastChildren :+ true, writer, verbose, prefix, addSuffix)
-    }
+    if (depth < maxDepth) {
+      val str = if (verbose) {
+        if (addSuffix) verboseStringWithSuffix else verboseString
+      } else {
+        simpleString
+      }
+      writer.write(prefix)
+      writer.write(str)
+      writer.write("\n")
+
+      if (innerChildren.nonEmpty) {
+        innerChildren.init.foreach(_.generateTreeString(
+          depth + 2, lastChildren :+ children.isEmpty :+ false, writer, verbose,
+          addSuffix = addSuffix, maxDepth = maxDepth))
+        innerChildren.last.generateTreeString(
+          depth + 2, lastChildren :+ children.isEmpty :+ true, writer, verbose,
+          addSuffix = addSuffix, maxDepth = maxDepth)
+      }
+
+      if (children.nonEmpty) {
+        children.init.foreach(_.generateTreeString(
+          depth + 1, lastChildren :+ false, writer, verbose, prefix, addSuffix, maxDepth))
+        children.last.generateTreeString(
+          depth + 1, lastChildren :+ true, writer, verbose, prefix, addSuffix, maxDepth)
+      }
+    }
+    else {
+      if (TreeNode.treeDepthWarningPrinted.compareAndSet(false, true)) {
+        logWarning(
+          "Truncated the string representation of a plan since it was nested too deeply. " +
+          "This behavior can be adjusted by setting 'spark.debug.maxToStringTreeDepth' in " +
+          "SparkEnv.conf.")
+      }
+      writer.write(prefix)
+      writer.write("...\n")
+    }
   }
 
   /**
@@ -701,3 +725,23 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] extends Product {
     case _ => false
   }
 }
+
+object TreeNode {
+  /**
+   * The string representation of a large, deeply nested query plan can get extremely long.
+   * To limit the impact, we truncate the tree string to its top layers when it gets too deep.
+   * This can be overridden by setting the 'spark.debug.maxToStringTreeDepth' conf in SparkEnv.
+   */
+  val DEFAULT_MAX_TO_STRING_TREE_DEPTH = 15
+
+  def maxTreeToStringDepth: Int = {
+    if (SparkEnv.get != null) {
+      SparkEnv.get.conf.getInt("spark.debug.maxToStringTreeDepth", DEFAULT_MAX_TO_STRING_TREE_DEPTH)
+    } else {
+      DEFAULT_MAX_TO_STRING_TREE_DEPTH
+    }
+  }
+
+  /** Whether we have warned about plan string truncation yet. */
+  private val treeDepthWarningPrinted = new AtomicBoolean(false)
+}
\ No newline at end of file
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala
index 29bcbcae366c..5cc1a86a7bb9 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala
@@ -32,6 +32,7 @@ import org.apache.spark.sql.catalyst.expressions.codegen._
 import org.apache.spark.sql.catalyst.expressions.codegen.Block._
 import org.apache.spark.sql.catalyst.plans.physical.Partitioning
 import org.apache.spark.sql.catalyst.rules.Rule
+import org.apache.spark.sql.catalyst.trees.TreeNode
 import org.apache.spark.sql.execution.aggregate.HashAggregateExec
 import org.apache.spark.sql.execution.joins.{BroadcastHashJoinExec, SortMergeJoinExec}
 import org.apache.spark.sql.execution.metric.SQLMetrics
@@ -454,8 +455,9 @@ case class InputAdapter(child: SparkPlan) extends UnaryExecNode with CodegenSupp
       writer: Writer,
       verbose: Boolean,
       prefix: String = "",
-      addSuffix: Boolean = false): Unit = {
-    child.generateTreeString(depth, lastChildren, writer, verbose, prefix = "", addSuffix = false)
+      addSuffix: Boolean = false,
+      maxDepth: Int = TreeNode.maxTreeToStringDepth): Unit = {
+    child.generateTreeString(depth, lastChildren, writer, verbose, prefix, addSuffix, maxDepth)
   }
 
   override def needCopyResult: Boolean = false
@@ -730,8 +732,16 @@ case class WholeStageCodegenExec(child: SparkPlan)(val codegenStageId: Int)
       writer: Writer,
      verbose: Boolean,
       prefix: String = "",
-      addSuffix: Boolean = false): Unit = {
-    child.generateTreeString(depth, lastChildren, writer, verbose, s"*($codegenStageId) ", false)
+      addSuffix: Boolean = false,
+      maxDepth: Int = TreeNode.maxTreeToStringDepth): Unit = {
+    child.generateTreeString(
+      depth,
+      lastChildren,
+      writer,
+      verbose,
+      s"*($codegenStageId) ",
+      false,
+      maxDepth)
   }
 
   override def needStopCheck: Boolean = true
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/QueryExecutionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/QueryExecutionSuite.scala
index a5922d7c825d..b35e809c881d 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/QueryExecutionSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/QueryExecutionSuite.scala
@@ -22,6 +22,8 @@ import org.apache.spark.sql.AnalysisException
 import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, OneRowRelation}
 import org.apache.spark.sql.test.SharedSQLContext
 
+case class Simple(a: String, b: Int)
+
 class QueryExecutionSuite extends SharedSQLContext {
   def checkDumpedPlans(path: String, expected: Int): Unit = {
     assert(Source.fromFile(path).getLines.toList
@@ -108,4 +110,16 @@ class QueryExecutionSuite extends SharedSQLContext {
     val error = intercept[Error](qe.toString)
     assert(error.getMessage.contains("error"))
   }
+
+  test("toString() tree depth") {
+    import testImplicits._
+
+    val s = Seq(Simple("a", 1), Simple("b", 3), Simple("c", 4))
+    val ds = (1 until 30).foldLeft(s.toDF()) { case (newDs, _) =>
+      newDs.join(s.toDF(), "a")
+    }
+
+    val nLines = ds.queryExecution.optimizedPlan.toString.split("\n").length
+    assert(nLines <= 31)
+  }
 }
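For reviewers who want to exercise the truncation by hand, here is a minimal standalone sketch, not part of the patch. The conf key 'spark.debug.maxToStringTreeDepth' and the self-join fold come from the patch and its test; the local master, app name, `TreeDepthDemo` object, and the depth limit of 5 are illustrative assumptions.

// Sketch only: drives the new truncation path with a deliberately deep plan.
import org.apache.spark.sql.SparkSession

object TreeDepthDemo {
  case class Simple(a: String, b: Int)

  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder()
      .master("local[2]")                                // assumption: local run
      .appName("tree-depth-demo")                        // assumption: demo name
      // Read by TreeNode.maxTreeToStringDepth via SparkEnv; 5 forces early truncation.
      .config("spark.debug.maxToStringTreeDepth", "5")
      .getOrCreate()
    import spark.implicits._

    val s = Seq(Simple("a", 1), Simple("b", 3), Simple("c", 4))
    // 29 self-joins produce an optimized plan much deeper than the configured limit.
    val ds = (1 until 30).foldLeft(s.toDF()) { case (df, _) => df.join(s.toDF(), "a") }

    // Nodes at or below the depth limit render as "..." and a warning is logged once.
    println(ds.queryExecution.optimizedPlan.toString)
    spark.stop()
  }
}

Because the conf is read through SparkEnv, it must be set before the SparkContext starts (as above via the session builder); changing it on a live session has no effect.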