From 4847dbf30ea0e89675fdb3b681b699d4f41a1289 Mon Sep 17 00:00:00 2001 From: Fu Chen Date: Mon, 20 Dec 2021 15:20:50 +0800 Subject: [PATCH 1/5] watchdog support for spark-3.2 --- .../kyuubi/sql/KyuubiSparkSQLExtension.scala | 6 +- .../watchdog/ForcedMaxOutputRowsRule.scala | 126 +----- .../org/apache/spark/sql/WatchDogSuite.scala | 385 +--------------- .../kyuubi/sql/KyuubiSparkSQLExtension.scala | 8 + .../watchdog/ForcedMaxOutputRowsRule.scala | 44 ++ .../org/apache/spark/sql/WatchDogSuite.scala | 3 + .../watchdog/ForcedMaxOutputRowsBase.scala | 141 ++++++ .../watchdog/KyuubiWatchDogException.scala | 0 .../sql/watchdog/MaxPartitionStrategy.scala | 0 .../sql/PruneFileSourcePartitionHelper.scala | 19 +- .../apache/spark/sql/WatchDogSuiteBase.scala | 418 ++++++++++++++++++ 11 files changed, 621 insertions(+), 529 deletions(-) create mode 100644 dev/kyuubi-extension-spark-3-2/src/main/scala/org/apache/kyuubi/sql/watchdog/ForcedMaxOutputRowsRule.scala create mode 100644 dev/kyuubi-extension-spark-3-2/src/test/scala/org/apache/spark/sql/WatchDogSuite.scala create mode 100644 dev/kyuubi-extension-spark-common/src/main/scala/org/apache/kyuubi/sql/watchdog/ForcedMaxOutputRowsBase.scala rename dev/{kyuubi-extension-spark-3-1 => kyuubi-extension-spark-common}/src/main/scala/org/apache/kyuubi/sql/watchdog/KyuubiWatchDogException.scala (100%) rename dev/{kyuubi-extension-spark-3-1 => kyuubi-extension-spark-common}/src/main/scala/org/apache/kyuubi/sql/watchdog/MaxPartitionStrategy.scala (100%) rename dev/{kyuubi-extension-spark-3-1 => kyuubi-extension-spark-common}/src/main/scala/org/apache/spark/sql/PruneFileSourcePartitionHelper.scala (59%) create mode 100644 dev/kyuubi-extension-spark-common/src/test/scala/org/apache/spark/sql/WatchDogSuiteBase.scala diff --git a/dev/kyuubi-extension-spark-3-1/src/main/scala/org/apache/kyuubi/sql/KyuubiSparkSQLExtension.scala b/dev/kyuubi-extension-spark-3-1/src/main/scala/org/apache/kyuubi/sql/KyuubiSparkSQLExtension.scala index 3f9193731a2..f2ebeb7aec6 100644 --- a/dev/kyuubi-extension-spark-3-1/src/main/scala/org/apache/kyuubi/sql/KyuubiSparkSQLExtension.scala +++ b/dev/kyuubi-extension-spark-3-1/src/main/scala/org/apache/kyuubi/sql/KyuubiSparkSQLExtension.scala @@ -32,15 +32,15 @@ import org.apache.kyuubi.sql.watchdog.{ForcedMaxOutputRowsRule, MarkAggregateOrd class KyuubiSparkSQLExtension extends (SparkSessionExtensions => Unit) { override def apply(extensions: SparkSessionExtensions): Unit = { KyuubiSparkSQLCommonExtension.injectCommonExtensions(extensions) - // a help rule for ForcedMaxOutputRowsRule - extensions.injectResolutionRule(MarkAggregateOrderRule) extensions.injectPostHocResolutionRule(KyuubiSqlClassification) extensions.injectPostHocResolutionRule(RepartitionBeforeWritingDatasource) extensions.injectPostHocResolutionRule(RepartitionBeforeWritingHive) - extensions.injectPostHocResolutionRule(ForcedMaxOutputRowsRule) extensions.injectPostHocResolutionRule(DropIgnoreNonexistent) + // a help rule for ForcedMaxOutputRowsRule + extensions.injectResolutionRule(MarkAggregateOrderRule) + extensions.injectPostHocResolutionRule(ForcedMaxOutputRowsRule) extensions.injectPlannerStrategy(MaxPartitionStrategy) } } diff --git a/dev/kyuubi-extension-spark-3-1/src/main/scala/org/apache/kyuubi/sql/watchdog/ForcedMaxOutputRowsRule.scala b/dev/kyuubi-extension-spark-3-1/src/main/scala/org/apache/kyuubi/sql/watchdog/ForcedMaxOutputRowsRule.scala index 07ddb2cd48b..b8eb153d7ca 100644 --- 
a/dev/kyuubi-extension-spark-3-1/src/main/scala/org/apache/kyuubi/sql/watchdog/ForcedMaxOutputRowsRule.scala +++ b/dev/kyuubi-extension-spark-3-1/src/main/scala/org/apache/kyuubi/sql/watchdog/ForcedMaxOutputRowsRule.scala @@ -18,129 +18,7 @@ package org.apache.kyuubi.sql.watchdog import org.apache.spark.sql.SparkSession -import org.apache.spark.sql.catalyst.analysis.AnalysisContext -import org.apache.spark.sql.catalyst.dsl.expressions._ -import org.apache.spark.sql.catalyst.expressions.Alias -import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, Distinct, Filter, Limit, LogicalPlan, Project, RepartitionByExpression, Sort, Union} -import org.apache.spark.sql.catalyst.rules.Rule -import org.apache.spark.sql.catalyst.trees.TreeNodeTag -import org.apache.spark.sql.execution.command.DataWritingCommand -import org.apache.kyuubi.sql.KyuubiSQLConf +case class ForcedMaxOutputRowsRule(sparkSession: SparkSession) extends ForcedMaxOutputRowsBase {} -object ForcedMaxOutputRowsConstraint { - val CHILD_AGGREGATE: TreeNodeTag[String] = TreeNodeTag[String]("__kyuubi_child_agg__") - val CHILD_AGGREGATE_FLAG: String = "__kyuubi_child_agg__" -} - -/* - * Add ForcedMaxOutputRows rule for output rows limitation - * to avoid huge output rows of non_limit query unexpectedly - * mainly applied to cases as below: - * - * case 1: - * {{{ - * SELECT [c1, c2, ...] - * }}} - * - * case 2: - * {{{ - * WITH CTE AS ( - * ...) - * SELECT [c1, c2, ...] FROM CTE ... - * }}} - * - * The Logical Rule add a GlobalLimit node before root project - * */ -case class ForcedMaxOutputRowsRule(session: SparkSession) extends Rule[LogicalPlan] { - - private def isChildAggregate(a: Aggregate): Boolean = a - .aggregateExpressions.exists(p => - p.getTagValue(ForcedMaxOutputRowsConstraint.CHILD_AGGREGATE) - .contains(ForcedMaxOutputRowsConstraint.CHILD_AGGREGATE_FLAG)) - - private def isView: Boolean = { - val nestedViewDepth = AnalysisContext.get.nestedViewDepth - nestedViewDepth > 0 - } - - private def canInsertLimitInner(p: LogicalPlan): Boolean = p match { - - case Aggregate(_, Alias(_, "havingCondition") :: Nil, _) => false - case agg: Aggregate => !isChildAggregate(agg) - case _: RepartitionByExpression => true - case _: Distinct => true - case _: Filter => true - case _: Project => true - case Limit(_, _) => true - case _: Sort => true - case Union(children, _, _) => - if (children.exists(_.isInstanceOf[DataWritingCommand])) { - false - } else { - true - } - case _ => false - - } - - private def canInsertLimit(p: LogicalPlan, maxOutputRowsOpt: Option[Int]): Boolean = { - - maxOutputRowsOpt match { - case Some(forcedMaxOutputRows) => canInsertLimitInner(p) && - !p.maxRows.exists(_ <= forcedMaxOutputRows) && - !isView - case None => false - } - } - - override def apply(plan: LogicalPlan): LogicalPlan = { - val maxOutputRowsOpt = conf.getConf(KyuubiSQLConf.WATCHDOG_FORCED_MAXOUTPUTROWS) - plan match { - case p if p.resolved && canInsertLimit(p, maxOutputRowsOpt) => - Limit( - maxOutputRowsOpt.get, - plan) - case _ => plan - } - } - -} - -case class MarkAggregateOrderRule(session: SparkSession) extends Rule[LogicalPlan] { - - private def markChildAggregate(a: Aggregate): Unit = { - // mark child aggregate - a.aggregateExpressions.filter(_.resolved).foreach(_.setTagValue( - ForcedMaxOutputRowsConstraint.CHILD_AGGREGATE, - ForcedMaxOutputRowsConstraint.CHILD_AGGREGATE_FLAG)) - } - - private def findAndMarkChildAggregate(plan: LogicalPlan): LogicalPlan = plan match { - /* - * The case mainly process order not aggregate column but 
grouping column as below - * SELECT c1, COUNT(*) as cnt - * FROM t1 - * GROUP BY c1 - * ORDER BY c1 - * */ - case a: Aggregate - if a.aggregateExpressions - .exists(x => x.resolved && x.name.equals("aggOrder")) => - markChildAggregate(a) - plan - - case _ => - plan.children.foreach(_.foreach { - case agg: Aggregate => markChildAggregate(agg) - case _ => Unit - }) - plan - } - - override def apply(plan: LogicalPlan): LogicalPlan = conf.getConf( - KyuubiSQLConf.WATCHDOG_FORCED_MAXOUTPUTROWS) match { - case Some(_) => findAndMarkChildAggregate(plan) - case _ => plan - } -} +case class MarkAggregateOrderRule(sparkSession: SparkSession) extends MarkAggregateOrderBase {} diff --git a/dev/kyuubi-extension-spark-3-1/src/test/scala/org/apache/spark/sql/WatchDogSuite.scala b/dev/kyuubi-extension-spark-3-1/src/test/scala/org/apache/spark/sql/WatchDogSuite.scala index 642bd7fd773..957089340ca 100644 --- a/dev/kyuubi-extension-spark-3-1/src/test/scala/org/apache/spark/sql/WatchDogSuite.scala +++ b/dev/kyuubi-extension-spark-3-1/src/test/scala/org/apache/spark/sql/WatchDogSuite.scala @@ -17,387 +17,4 @@ package org.apache.spark.sql -import org.apache.spark.sql.catalyst.plans.logical.GlobalLimit - -import org.apache.kyuubi.sql.KyuubiSQLConf -import org.apache.kyuubi.sql.watchdog.MaxPartitionExceedException - -class WatchDogSuite extends KyuubiSparkSQLExtensionTest { - override protected def beforeAll(): Unit = { - super.beforeAll() - setupData() - } - - case class LimitAndExpected(limit: Int, expected: Int) - val limitAndExpecteds = List(LimitAndExpected(1, 1), LimitAndExpected(11, 10)) - - private def checkMaxPartition: Unit = { - withSQLConf(KyuubiSQLConf.WATCHDOG_MAX_PARTITIONS.key -> "100") { - checkAnswer(sql("SELECT count(distinct(p)) FROM test"), Row(10) :: Nil) - } - withSQLConf(KyuubiSQLConf.WATCHDOG_MAX_PARTITIONS.key -> "5") { - sql("SELECT * FROM test where p=1").queryExecution.sparkPlan - - sql(s"SELECT * FROM test WHERE p in (${Range(0, 5).toList.mkString(",")})") - .queryExecution.sparkPlan - - intercept[MaxPartitionExceedException]( - sql("SELECT * FROM test where p != 1").queryExecution.sparkPlan) - - intercept[MaxPartitionExceedException]( - sql("SELECT * FROM test").queryExecution.sparkPlan) - - intercept[MaxPartitionExceedException](sql( - s"SELECT * FROM test WHERE p in (${Range(0, 6).toList.mkString(",")})") - .queryExecution.sparkPlan) - } - } - - test("watchdog with scan maxPartitions -- hive") { - Seq("textfile", "parquet").foreach { format => - withTable("test", "temp") { - sql( - s""" - |CREATE TABLE test(i int) - |PARTITIONED BY (p int) - |STORED AS $format""".stripMargin) - spark.range(0, 10, 1).selectExpr("id as col") - .createOrReplaceTempView("temp") - - for (part <- Range(0, 10)) { - sql( - s""" - |INSERT OVERWRITE TABLE test PARTITION (p='$part') - |select col from temp""".stripMargin) - } - checkMaxPartition - } - } - } - - test("watchdog with scan maxPartitions -- data source") { - withTempDir { dir => - withTempView("test") { - spark.range(10).selectExpr("id", "id as p") - .write - .partitionBy("p") - .mode("overwrite") - .save(dir.getCanonicalPath) - spark.read.load(dir.getCanonicalPath).createOrReplaceTempView("test") - checkMaxPartition - } - } - } - - test("test watchdog: simple SELECT STATEMENT") { - - withSQLConf(KyuubiSQLConf.WATCHDOG_FORCED_MAXOUTPUTROWS.key -> "10") { - - List("", "ORDER BY c1", "ORDER BY c2").foreach { sort => - List("", " DISTINCT").foreach { distinct => - assert(sql( - s""" - |SELECT $distinct * - |FROM t1 - |$sort - 
|""".stripMargin).queryExecution.analyzed.isInstanceOf[GlobalLimit]) - } - } - - limitAndExpecteds.foreach { case LimitAndExpected(limit, expected) => - List("", "ORDER BY c1", "ORDER BY c2").foreach { sort => - List("", "DISTINCT").foreach { distinct => - assert(sql( - s""" - |SELECT $distinct * - |FROM t1 - |$sort - |LIMIT $limit - |""".stripMargin).queryExecution.analyzed.maxRows.contains(expected)) - } - } - } - } - } - - test("test watchdog: SELECT ... WITH AGGREGATE STATEMENT ") { - - withSQLConf(KyuubiSQLConf.WATCHDOG_FORCED_MAXOUTPUTROWS.key -> "10") { - - assert(!sql("SELECT count(*) FROM t1") - .queryExecution.analyzed.isInstanceOf[GlobalLimit]) - - val sorts = List("", "ORDER BY cnt", "ORDER BY c1", "ORDER BY cnt, c1", "ORDER BY c1, cnt") - val havingConditions = List("", "HAVING cnt > 1") - - havingConditions.foreach { having => - sorts.foreach { sort => - assert(sql( - s""" - |SELECT c1, COUNT(*) as cnt - |FROM t1 - |GROUP BY c1 - |$having - |$sort - |""".stripMargin).queryExecution.analyzed.isInstanceOf[GlobalLimit]) - } - } - - limitAndExpecteds.foreach { case LimitAndExpected(limit, expected) => - havingConditions.foreach { having => - sorts.foreach { sort => - assert(sql( - s""" - |SELECT c1, COUNT(*) as cnt - |FROM t1 - |GROUP BY c1 - |$having - |$sort - |LIMIT $limit - |""".stripMargin).queryExecution.analyzed.maxRows.contains(expected)) - } - } - } - } - } - - test("test watchdog: SELECT with CTE forceMaxOutputRows") { - - withSQLConf(KyuubiSQLConf.WATCHDOG_FORCED_MAXOUTPUTROWS.key -> "10") { - - val sorts = List("", "ORDER BY c1", "ORDER BY c2") - - sorts.foreach { sort => - assert(sql( - s""" - |WITH custom_cte AS ( - |SELECT * FROM t1 - |) - |SELECT * - |FROM custom_cte - |$sort - |""".stripMargin).queryExecution.analyzed.isInstanceOf[GlobalLimit]) - } - - limitAndExpecteds.foreach { case LimitAndExpected(limit, expected) => - sorts.foreach { sort => - assert(sql( - s""" - |WITH custom_cte AS ( - |SELECT * FROM t1 - |) - |SELECT * - |FROM custom_cte - |$sort - |LIMIT $limit - |""".stripMargin).queryExecution.analyzed.maxRows.contains(expected)) - } - } - } - } - - test("test watchdog: SELECT AGGREGATE WITH CTE forceMaxOutputRows") { - - withSQLConf(KyuubiSQLConf.WATCHDOG_FORCED_MAXOUTPUTROWS.key -> "10") { - - assert(!sql( - """ - |WITH custom_cte AS ( - |SELECT * FROM t1 - |) - | - |SELECT COUNT(*) - |FROM custom_cte - |""".stripMargin).queryExecution - .analyzed.isInstanceOf[GlobalLimit]) - - val sorts = List("", "ORDER BY cnt", "ORDER BY c1", "ORDER BY cnt, c1", "ORDER BY c1, cnt") - val havingConditions = List("", "HAVING cnt > 1") - - havingConditions.foreach { having => - sorts.foreach { sort => - assert(sql( - s""" - |WITH custom_cte AS ( - |SELECT * FROM t1 - |) - | - |SELECT c1, COUNT(*) as cnt - |FROM custom_cte - |GROUP BY c1 - |$having - |$sort - |""".stripMargin).queryExecution.analyzed.isInstanceOf[GlobalLimit]) - } - } - - limitAndExpecteds.foreach { case LimitAndExpected(limit, expected) => - havingConditions.foreach { having => - sorts.foreach { sort => - assert(sql( - s""" - |WITH custom_cte AS ( - |SELECT * FROM t1 - |) - | - |SELECT c1, COUNT(*) as cnt - |FROM custom_cte - |GROUP BY c1 - |$having - |$sort - |LIMIT $limit - |""".stripMargin).queryExecution.analyzed.maxRows.contains(expected)) - } - } - } - } - } - - test("test watchdog: UNION Statement for forceMaxOutputRows") { - - withSQLConf(KyuubiSQLConf.WATCHDOG_FORCED_MAXOUTPUTROWS.key -> "10") { - - List("", "ALL").foreach { x => - assert(sql( - s""" - |SELECT c1, c2 FROM t1 - |UNION $x - 
|SELECT c1, c2 FROM t2 - |UNION $x - |SELECT c1, c2 FROM t3 - |""".stripMargin) - .queryExecution.analyzed.isInstanceOf[GlobalLimit]) - } - - val sorts = List("", "ORDER BY cnt", "ORDER BY c1", "ORDER BY cnt, c1", "ORDER BY c1, cnt") - val havingConditions = List("", "HAVING cnt > 1") - - List("", "ALL").foreach { x => - havingConditions.foreach { having => - sorts.foreach { sort => - assert(sql( - s""" - |SELECT c1, count(c2) as cnt - |FROM t1 - |GROUP BY c1 - |$having - |UNION $x - |SELECT c1, COUNT(c2) as cnt - |FROM t2 - |GROUP BY c1 - |$having - |UNION $x - |SELECT c1, COUNT(c2) as cnt - |FROM t3 - |GROUP BY c1 - |$having - |$sort - |""".stripMargin) - .queryExecution.analyzed.isInstanceOf[GlobalLimit]) - } - } - } - - limitAndExpecteds.foreach { case LimitAndExpected(limit, expected) => - assert(sql( - s""" - |SELECT c1, c2 FROM t1 - |UNION - |SELECT c1, c2 FROM t2 - |UNION - |SELECT c1, c2 FROM t3 - |LIMIT $limit - |""".stripMargin) - .queryExecution.analyzed.maxRows.contains(expected)) - } - } - } - - test("test watchdog: Select View Statement for forceMaxOutputRows") { - withSQLConf(KyuubiSQLConf.WATCHDOG_FORCED_MAXOUTPUTROWS.key -> "3") { - withTable("tmp_table", "tmp_union") { - withView("tmp_view", "tmp_view2") { - sql(s"create table tmp_table (a int, b int)") - sql(s"insert into tmp_table values (1,10),(2,20),(3,30),(4,40),(5,50)") - sql(s"create table tmp_union (a int, b int)") - sql(s"insert into tmp_union values (6,60),(7,70),(8,80),(9,90),(10,100)") - sql(s"create view tmp_view2 as select * from tmp_union") - assert(!sql( - s""" - |CREATE VIEW tmp_view - |as - |SELECT * FROM - |tmp_table - |""".stripMargin) - .queryExecution.analyzed.isInstanceOf[GlobalLimit]) - - assert(sql( - s""" - |SELECT * FROM - |tmp_view - |""".stripMargin) - .queryExecution.analyzed.maxRows.contains(3)) - - assert(sql( - s""" - |SELECT * FROM - |tmp_view - |limit 11 - |""".stripMargin) - .queryExecution.analyzed.maxRows.contains(3)) - - assert(sql( - s""" - |SELECT * FROM - |(select * from tmp_view - |UNION - |select * from tmp_view2) - |ORDER BY a - |DESC - |""".stripMargin) - .collect().head.get(0).equals(10)) - } - } - } - } - - test("test watchdog: Insert Statement for forceMaxOutputRows") { - - withSQLConf(KyuubiSQLConf.WATCHDOG_FORCED_MAXOUTPUTROWS.key -> "10") { - withTable("tmp_table", "tmp_insert") { - spark.sql(s"create table tmp_table (a int, b int)") - spark.sql(s"insert into tmp_table values (1,10),(2,20),(3,30),(4,40),(5,50)") - val multiInsertTableName1: String = "tmp_tbl1" - val multiInsertTableName2: String = "tmp_tbl2" - sql(s"drop table if exists $multiInsertTableName1") - sql(s"drop table if exists $multiInsertTableName2") - sql(s"create table $multiInsertTableName1 like tmp_table") - sql(s"create table $multiInsertTableName2 like tmp_table") - assert(!sql( - s""" - |FROM tmp_table - |insert into $multiInsertTableName1 select * limit 2 - |insert into $multiInsertTableName2 select * - |""".stripMargin) - .queryExecution.analyzed.isInstanceOf[GlobalLimit]) - } - } - } - - test("test watchdog: Distribute by for forceMaxOutputRows") { - - withSQLConf(KyuubiSQLConf.WATCHDOG_FORCED_MAXOUTPUTROWS.key -> "10") { - withTable("tmp_table") { - spark.sql(s"create table tmp_table (a int, b int)") - spark.sql(s"insert into tmp_table values (1,10),(2,20),(3,30),(4,40),(5,50)") - assert(sql( - s""" - |SELECT * - |FROM tmp_table - |DISTRIBUTE BY a - |""".stripMargin) - .queryExecution.analyzed.isInstanceOf[GlobalLimit]) - } - } - } -} +class WatchDogSuite extends WatchDogSuiteBase {} diff --git 
a/dev/kyuubi-extension-spark-3-2/src/main/scala/org/apache/kyuubi/sql/KyuubiSparkSQLExtension.scala b/dev/kyuubi-extension-spark-3-2/src/main/scala/org/apache/kyuubi/sql/KyuubiSparkSQLExtension.scala index 2b8920a254f..996a7f2df81 100644 --- a/dev/kyuubi-extension-spark-3-2/src/main/scala/org/apache/kyuubi/sql/KyuubiSparkSQLExtension.scala +++ b/dev/kyuubi-extension-spark-3-2/src/main/scala/org/apache/kyuubi/sql/KyuubiSparkSQLExtension.scala @@ -19,6 +19,8 @@ package org.apache.kyuubi.sql import org.apache.spark.sql.SparkSessionExtensions +import org.apache.kyuubi.sql.watchdog.{ForcedMaxOutputRowsRule, MarkAggregateOrderRule, MaxPartitionStrategy} + // scalastyle:off line.size.limit /** * Depend on Spark SQL Extension framework, we can use this extension follow steps @@ -33,5 +35,11 @@ class KyuubiSparkSQLExtension extends (SparkSessionExtensions => Unit) { extensions.injectPostHocResolutionRule(RebalanceBeforeWritingDatasource) extensions.injectPostHocResolutionRule(RebalanceBeforeWritingHive) extensions.injectPostHocResolutionRule(DropIgnoreNonexistent) + + // watchdog extension + // a help rule for ForcedMaxOutputRowsRule + extensions.injectResolutionRule(MarkAggregateOrderRule) + extensions.injectPostHocResolutionRule(ForcedMaxOutputRowsRule) + extensions.injectPlannerStrategy(MaxPartitionStrategy) } } diff --git a/dev/kyuubi-extension-spark-3-2/src/main/scala/org/apache/kyuubi/sql/watchdog/ForcedMaxOutputRowsRule.scala b/dev/kyuubi-extension-spark-3-2/src/main/scala/org/apache/kyuubi/sql/watchdog/ForcedMaxOutputRowsRule.scala new file mode 100644 index 00000000000..79e1acc2550 --- /dev/null +++ b/dev/kyuubi-extension-spark-3-2/src/main/scala/org/apache/kyuubi/sql/watchdog/ForcedMaxOutputRowsRule.scala @@ -0,0 +1,44 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.kyuubi.sql.watchdog + +import org.apache.spark.sql.SparkSession +import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, WithCTE} + +case class ForcedMaxOutputRowsRule(sparkSession: SparkSession) extends ForcedMaxOutputRowsBase { + override protected def canInsertLimitInner(p: LogicalPlan): Boolean = p match { + case WithCTE(plan, _) => this.canInsertLimitInner(plan) + case plan: LogicalPlan => super.canInsertLimitInner(plan) + } + + override protected def canInsertLimit(p: LogicalPlan, maxOutputRowsOpt: Option[Int]): Boolean = { + p match { + case WithCTE(plan, _) => this.canInsertLimit(plan, maxOutputRowsOpt) + case _ => super.canInsertLimit(p, maxOutputRowsOpt) + } + } +} + +case class MarkAggregateOrderRule(sparkSession: SparkSession) extends MarkAggregateOrderBase { + override protected def findAndMarkChildAggregate(plan: LogicalPlan): LogicalPlan = plan match { + case withCTE @ WithCTE(plan, _) => + withCTE.copy(plan = this.findAndMarkChildAggregate(plan)) + case _ => + super.findAndMarkChildAggregate(plan) + } +} diff --git a/dev/kyuubi-extension-spark-3-2/src/test/scala/org/apache/spark/sql/WatchDogSuite.scala b/dev/kyuubi-extension-spark-3-2/src/test/scala/org/apache/spark/sql/WatchDogSuite.scala new file mode 100644 index 00000000000..df35922d162 --- /dev/null +++ b/dev/kyuubi-extension-spark-3-2/src/test/scala/org/apache/spark/sql/WatchDogSuite.scala @@ -0,0 +1,3 @@ +package org.apache.spark.sql + +class WatchDogSuite extends WatchDogSuiteBase {} diff --git a/dev/kyuubi-extension-spark-common/src/main/scala/org/apache/kyuubi/sql/watchdog/ForcedMaxOutputRowsBase.scala b/dev/kyuubi-extension-spark-common/src/main/scala/org/apache/kyuubi/sql/watchdog/ForcedMaxOutputRowsBase.scala new file mode 100644 index 00000000000..08bf6e57cfa --- /dev/null +++ b/dev/kyuubi-extension-spark-common/src/main/scala/org/apache/kyuubi/sql/watchdog/ForcedMaxOutputRowsBase.scala @@ -0,0 +1,141 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.kyuubi.sql.watchdog + +import org.apache.spark.sql.catalyst.analysis.AnalysisContext +import org.apache.spark.sql.catalyst.dsl.expressions._ +import org.apache.spark.sql.catalyst.expressions.Alias +import org.apache.spark.sql.catalyst.plans.logical._ +import org.apache.spark.sql.catalyst.rules.Rule +import org.apache.spark.sql.catalyst.trees.TreeNodeTag +import org.apache.spark.sql.execution.command.DataWritingCommand + +import org.apache.kyuubi.sql.KyuubiSQLConf + +object ForcedMaxOutputRowsConstraint { + val CHILD_AGGREGATE: TreeNodeTag[String] = TreeNodeTag[String]("__kyuubi_child_agg__") + val CHILD_AGGREGATE_FLAG: String = "__kyuubi_child_agg__" +} + +/* + * Add ForcedMaxOutputRows rule for output rows limitation + * to avoid huge output rows of non_limit query unexpectedly + * mainly applied to cases as below: + * + * case 1: + * {{{ + * SELECT [c1, c2, ...] + * }}} + * + * case 2: + * {{{ + * WITH CTE AS ( + * ...) + * SELECT [c1, c2, ...] FROM CTE ... + * }}} + * + * The Logical Rule add a GlobalLimit node before root project + * */ +trait ForcedMaxOutputRowsBase extends Rule[LogicalPlan] { + + protected def isChildAggregate(a: Aggregate): Boolean = a + .aggregateExpressions.exists(p => + p.getTagValue(ForcedMaxOutputRowsConstraint.CHILD_AGGREGATE) + .contains(ForcedMaxOutputRowsConstraint.CHILD_AGGREGATE_FLAG)) + + protected def isView: Boolean = { + val nestedViewDepth = AnalysisContext.get.nestedViewDepth + nestedViewDepth > 0 + } + + protected def canInsertLimitInner(p: LogicalPlan): Boolean = p match { + case Aggregate(_, Alias(_, "havingCondition") :: Nil, _) => false + case agg: Aggregate => !isChildAggregate(agg) + case _: RepartitionByExpression => true + case _: Distinct => true + case _: Filter => true + case _: Project => true + case Limit(_, _) => true + case _: Sort => true + case Union(children, _, _) => + if (children.exists(_.isInstanceOf[DataWritingCommand])) { + false + } else { + true + } + case _ => false + } + + protected def canInsertLimit(p: LogicalPlan, maxOutputRowsOpt: Option[Int]): Boolean = { + maxOutputRowsOpt match { + case Some(forcedMaxOutputRows) => canInsertLimitInner(p) && + !p.maxRows.exists(_ <= forcedMaxOutputRows) && + !isView + case None => false + } + } + + override def apply(plan: LogicalPlan): LogicalPlan = { + val maxOutputRowsOpt = conf.getConf(KyuubiSQLConf.WATCHDOG_FORCED_MAXOUTPUTROWS) + plan match { + case p if p.resolved && canInsertLimit(p, maxOutputRowsOpt) => + Limit( + maxOutputRowsOpt.get, + plan) + case _ => plan + } + } + +} + +trait MarkAggregateOrderBase extends Rule[LogicalPlan] { + + private def markChildAggregate(a: Aggregate): Unit = { + // mark child aggregate + a.aggregateExpressions.filter(_.resolved).foreach(_.setTagValue( + ForcedMaxOutputRowsConstraint.CHILD_AGGREGATE, + ForcedMaxOutputRowsConstraint.CHILD_AGGREGATE_FLAG)) + } + + protected def findAndMarkChildAggregate(plan: LogicalPlan): LogicalPlan = plan match { + /* + * The case mainly process order not aggregate column but grouping column as below + * SELECT c1, COUNT(*) as cnt + * FROM t1 + * GROUP BY c1 + * ORDER BY c1 + * */ + case a: Aggregate + if a.aggregateExpressions + .exists(x => x.resolved && x.name.equals("aggOrder")) => + markChildAggregate(a) + plan + case _ => + plan.children.foreach(_.foreach { + case agg: Aggregate => markChildAggregate(agg) + case _ => Unit + }) + plan + } + + override def apply(plan: LogicalPlan): LogicalPlan = conf.getConf( + KyuubiSQLConf.WATCHDOG_FORCED_MAXOUTPUTROWS) match { + case Some(_) 
=> findAndMarkChildAggregate(plan) + case _ => plan + } +} diff --git a/dev/kyuubi-extension-spark-3-1/src/main/scala/org/apache/kyuubi/sql/watchdog/KyuubiWatchDogException.scala b/dev/kyuubi-extension-spark-common/src/main/scala/org/apache/kyuubi/sql/watchdog/KyuubiWatchDogException.scala similarity index 100% rename from dev/kyuubi-extension-spark-3-1/src/main/scala/org/apache/kyuubi/sql/watchdog/KyuubiWatchDogException.scala rename to dev/kyuubi-extension-spark-common/src/main/scala/org/apache/kyuubi/sql/watchdog/KyuubiWatchDogException.scala diff --git a/dev/kyuubi-extension-spark-3-1/src/main/scala/org/apache/kyuubi/sql/watchdog/MaxPartitionStrategy.scala b/dev/kyuubi-extension-spark-common/src/main/scala/org/apache/kyuubi/sql/watchdog/MaxPartitionStrategy.scala similarity index 100% rename from dev/kyuubi-extension-spark-3-1/src/main/scala/org/apache/kyuubi/sql/watchdog/MaxPartitionStrategy.scala rename to dev/kyuubi-extension-spark-common/src/main/scala/org/apache/kyuubi/sql/watchdog/MaxPartitionStrategy.scala diff --git a/dev/kyuubi-extension-spark-3-1/src/main/scala/org/apache/spark/sql/PruneFileSourcePartitionHelper.scala b/dev/kyuubi-extension-spark-common/src/main/scala/org/apache/spark/sql/PruneFileSourcePartitionHelper.scala similarity index 59% rename from dev/kyuubi-extension-spark-3-1/src/main/scala/org/apache/spark/sql/PruneFileSourcePartitionHelper.scala rename to dev/kyuubi-extension-spark-common/src/main/scala/org/apache/spark/sql/PruneFileSourcePartitionHelper.scala index 65dd016e3d1..ad82aa42f2f 100644 --- a/dev/kyuubi-extension-spark-3-1/src/main/scala/org/apache/spark/sql/PruneFileSourcePartitionHelper.scala +++ b/dev/kyuubi-extension-spark-common/src/main/scala/org/apache/spark/sql/PruneFileSourcePartitionHelper.scala @@ -1,25 +1,8 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - package org.apache.spark.sql import org.apache.spark.sql.catalyst.expressions.{AttributeReference, AttributeSet, Expression, ExpressionSet, PredicateHelper, SubqueryExpression} import org.apache.spark.sql.catalyst.plans.logical.LeafNode -import org.apache.spark.sql.execution.datasources._ +import org.apache.spark.sql.execution.datasources.DataSourceStrategy import org.apache.spark.sql.types.StructType trait PruneFileSourcePartitionHelper extends PredicateHelper { diff --git a/dev/kyuubi-extension-spark-common/src/test/scala/org/apache/spark/sql/WatchDogSuiteBase.scala b/dev/kyuubi-extension-spark-common/src/test/scala/org/apache/spark/sql/WatchDogSuiteBase.scala new file mode 100644 index 00000000000..a604308a537 --- /dev/null +++ b/dev/kyuubi-extension-spark-common/src/test/scala/org/apache/spark/sql/WatchDogSuiteBase.scala @@ -0,0 +1,418 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql + +import org.apache.spark.sql.catalyst.plans.logical.GlobalLimit + +import org.apache.kyuubi.sql.KyuubiSQLConf +import org.apache.kyuubi.sql.watchdog.MaxPartitionExceedException + +trait WatchDogSuiteBase extends KyuubiSparkSQLExtensionTest { + override protected def beforeAll(): Unit = { + super.beforeAll() + setupData() + } + + case class LimitAndExpected(limit: Int, expected: Int) + val limitAndExpecteds = List(LimitAndExpected(1, 1), LimitAndExpected(11, 10)) + + private def checkMaxPartition: Unit = { + withSQLConf(KyuubiSQLConf.WATCHDOG_MAX_PARTITIONS.key -> "100") { + checkAnswer(sql("SELECT count(distinct(p)) FROM test"), Row(10) :: Nil) + } + withSQLConf(KyuubiSQLConf.WATCHDOG_MAX_PARTITIONS.key -> "5") { + sql("SELECT * FROM test where p=1").queryExecution.sparkPlan + + sql(s"SELECT * FROM test WHERE p in (${Range(0, 5).toList.mkString(",")})") + .queryExecution.sparkPlan + + intercept[MaxPartitionExceedException]( + sql("SELECT * FROM test where p != 1").queryExecution.sparkPlan) + + intercept[MaxPartitionExceedException]( + sql("SELECT * FROM test").queryExecution.sparkPlan) + + intercept[MaxPartitionExceedException](sql( + s"SELECT * FROM test WHERE p in (${Range(0, 6).toList.mkString(",")})") + .queryExecution.sparkPlan) + } + } + + test("watchdog with scan maxPartitions -- hive") { + Seq("textfile", "parquet").foreach { format => + withTable("test", "temp") { + sql( + s""" + |CREATE TABLE test(i int) + |PARTITIONED BY (p int) + |STORED AS $format""".stripMargin) + spark.range(0, 10, 1).selectExpr("id as col") + .createOrReplaceTempView("temp") + + for (part <- Range(0, 10)) { + sql( + s""" + |INSERT OVERWRITE TABLE test PARTITION (p='$part') + |select col from temp""".stripMargin) + } + checkMaxPartition + } + } + } + + test("watchdog with scan maxPartitions -- data source") { + withTempDir { dir => + 
withTempView("test") { + spark.range(10).selectExpr("id", "id as p") + .write + .partitionBy("p") + .mode("overwrite") + .save(dir.getCanonicalPath) + spark.read.load(dir.getCanonicalPath).createOrReplaceTempView("test") + checkMaxPartition + } + } + } + + test("test watchdog: simple SELECT STATEMENT") { + + withSQLConf(KyuubiSQLConf.WATCHDOG_FORCED_MAXOUTPUTROWS.key -> "10") { + + List("", "ORDER BY c1", "ORDER BY c2").foreach { sort => + List("", " DISTINCT").foreach { distinct => + assert(sql( + s""" + |SELECT $distinct * + |FROM t1 + |$sort + |""".stripMargin).queryExecution.analyzed.isInstanceOf[GlobalLimit]) + } + } + + limitAndExpecteds.foreach { case LimitAndExpected(limit, expected) => + List("", "ORDER BY c1", "ORDER BY c2").foreach { sort => + List("", "DISTINCT").foreach { distinct => + assert(sql( + s""" + |SELECT $distinct * + |FROM t1 + |$sort + |LIMIT $limit + |""".stripMargin).queryExecution.analyzed.maxRows.contains(expected)) + } + } + } + } + } + + test("test watchdog: SELECT ... WITH AGGREGATE STATEMENT ") { + + withSQLConf(KyuubiSQLConf.WATCHDOG_FORCED_MAXOUTPUTROWS.key -> "10") { + + assert(!sql("SELECT count(*) FROM t1") + .queryExecution.analyzed.isInstanceOf[GlobalLimit]) + + val sorts = List("", "ORDER BY cnt", "ORDER BY c1", "ORDER BY cnt, c1", "ORDER BY c1, cnt") + val havingConditions = List("", "HAVING cnt > 1") + + havingConditions.foreach { having => + sorts.foreach { sort => + assert(sql( + s""" + |SELECT c1, COUNT(*) as cnt + |FROM t1 + |GROUP BY c1 + |$having + |$sort + |""".stripMargin).queryExecution.analyzed.isInstanceOf[GlobalLimit]) + } + } + + limitAndExpecteds.foreach { case LimitAndExpected(limit, expected) => + havingConditions.foreach { having => + sorts.foreach { sort => + assert(sql( + s""" + |SELECT c1, COUNT(*) as cnt + |FROM t1 + |GROUP BY c1 + |$having + |$sort + |LIMIT $limit + |""".stripMargin).queryExecution.analyzed.maxRows.contains(expected)) + } + } + } + } + } + + test("test watchdog: SELECT with CTE forceMaxOutputRows") { + // simple CTE + val q1 = + """ + |WITH t2 AS ( + | SELECT * FROM t1 + |) + |""".stripMargin + + // nested CTE + val q2 = + """ + |WITH + | t AS (SELECT * FROM t1), + | t2 AS ( + | WITH t3 AS (SELECT * FROM t1) + | SELECT * FROM t3 + | ) + |""".stripMargin + withSQLConf(KyuubiSQLConf.WATCHDOG_FORCED_MAXOUTPUTROWS.key -> "10") { + + val sorts = List("", "ORDER BY c1", "ORDER BY c2") + + sorts.foreach { sort => + Seq(q1, q2).foreach { withQuery => + assert(sql( + s""" + |$withQuery + |SELECT * FROM t2 + |$sort + |""".stripMargin).queryExecution.analyzed.isInstanceOf[GlobalLimit]) + } + } + + limitAndExpecteds.foreach { case LimitAndExpected(limit, expected) => + sorts.foreach { sort => + Seq(q1, q2).foreach { withQuery => + assert(sql( + s""" + |$withQuery + |SELECT * FROM t2 + |$sort + |LIMIT $limit + |""".stripMargin).queryExecution.optimizedPlan.maxRows.contains(expected)) + } + } + } + } + } + + test("test watchdog: SELECT AGGREGATE WITH CTE forceMaxOutputRows") { + + withSQLConf(KyuubiSQLConf.WATCHDOG_FORCED_MAXOUTPUTROWS.key -> "10") { + + assert(!sql( + """ + |WITH custom_cte AS ( + |SELECT * FROM t1 + |) + | + |SELECT COUNT(*) + |FROM custom_cte + |""".stripMargin).queryExecution + .analyzed.isInstanceOf[GlobalLimit]) + + val sorts = List("", "ORDER BY cnt", "ORDER BY c1", "ORDER BY cnt, c1", "ORDER BY c1, cnt") + val havingConditions = List("", "HAVING cnt > 1") + + havingConditions.foreach { having => + sorts.foreach { sort => + assert(sql( + s""" + |WITH custom_cte AS ( + |SELECT * FROM t1 + |) + | + 
|SELECT c1, COUNT(*) as cnt + |FROM custom_cte + |GROUP BY c1 + |$having + |$sort + |""".stripMargin).queryExecution.analyzed.isInstanceOf[GlobalLimit]) + } + } + + limitAndExpecteds.foreach { case LimitAndExpected(limit, expected) => + havingConditions.foreach { having => + sorts.foreach { sort => + assert(sql( + s""" + |WITH custom_cte AS ( + |SELECT * FROM t1 + |) + | + |SELECT c1, COUNT(*) as cnt + |FROM custom_cte + |GROUP BY c1 + |$having + |$sort + |LIMIT $limit + |""".stripMargin).queryExecution.optimizedPlan.maxRows.contains(expected)) + } + } + } + } + } + + test("test watchdog: UNION Statement for forceMaxOutputRows") { + + withSQLConf(KyuubiSQLConf.WATCHDOG_FORCED_MAXOUTPUTROWS.key -> "10") { + + List("", "ALL").foreach { x => + assert(sql( + s""" + |SELECT c1, c2 FROM t1 + |UNION $x + |SELECT c1, c2 FROM t2 + |UNION $x + |SELECT c1, c2 FROM t3 + |""".stripMargin) + .queryExecution.analyzed.isInstanceOf[GlobalLimit]) + } + + val sorts = List("", "ORDER BY cnt", "ORDER BY c1", "ORDER BY cnt, c1", "ORDER BY c1, cnt") + val havingConditions = List("", "HAVING cnt > 1") + + List("", "ALL").foreach { x => + havingConditions.foreach { having => + sorts.foreach { sort => + assert(sql( + s""" + |SELECT c1, count(c2) as cnt + |FROM t1 + |GROUP BY c1 + |$having + |UNION $x + |SELECT c1, COUNT(c2) as cnt + |FROM t2 + |GROUP BY c1 + |$having + |UNION $x + |SELECT c1, COUNT(c2) as cnt + |FROM t3 + |GROUP BY c1 + |$having + |$sort + |""".stripMargin) + .queryExecution.analyzed.isInstanceOf[GlobalLimit]) + } + } + } + + limitAndExpecteds.foreach { case LimitAndExpected(limit, expected) => + assert(sql( + s""" + |SELECT c1, c2 FROM t1 + |UNION + |SELECT c1, c2 FROM t2 + |UNION + |SELECT c1, c2 FROM t3 + |LIMIT $limit + |""".stripMargin) + .queryExecution.analyzed.maxRows.contains(expected)) + } + } + } + + test("test watchdog: Select View Statement for forceMaxOutputRows") { + withSQLConf(KyuubiSQLConf.WATCHDOG_FORCED_MAXOUTPUTROWS.key -> "3") { + withTable("tmp_table", "tmp_union") { + withView("tmp_view", "tmp_view2") { + sql(s"create table tmp_table (a int, b int)") + sql(s"insert into tmp_table values (1,10),(2,20),(3,30),(4,40),(5,50)") + sql(s"create table tmp_union (a int, b int)") + sql(s"insert into tmp_union values (6,60),(7,70),(8,80),(9,90),(10,100)") + sql(s"create view tmp_view2 as select * from tmp_union") + assert(!sql( + s""" + |CREATE VIEW tmp_view + |as + |SELECT * FROM + |tmp_table + |""".stripMargin) + .queryExecution.analyzed.isInstanceOf[GlobalLimit]) + + assert(sql( + s""" + |SELECT * FROM + |tmp_view + |""".stripMargin) + .queryExecution.analyzed.maxRows.contains(3)) + + assert(sql( + s""" + |SELECT * FROM + |tmp_view + |limit 11 + |""".stripMargin) + .queryExecution.analyzed.maxRows.contains(3)) + + assert(sql( + s""" + |SELECT * FROM + |(select * from tmp_view + |UNION + |select * from tmp_view2) + |ORDER BY a + |DESC + |""".stripMargin) + .collect().head.get(0).equals(10)) + } + } + } + } + + test("test watchdog: Insert Statement for forceMaxOutputRows") { + + withSQLConf(KyuubiSQLConf.WATCHDOG_FORCED_MAXOUTPUTROWS.key -> "10") { + withTable("tmp_table", "tmp_insert") { + spark.sql(s"create table tmp_table (a int, b int)") + spark.sql(s"insert into tmp_table values (1,10),(2,20),(3,30),(4,40),(5,50)") + val multiInsertTableName1: String = "tmp_tbl1" + val multiInsertTableName2: String = "tmp_tbl2" + sql(s"drop table if exists $multiInsertTableName1") + sql(s"drop table if exists $multiInsertTableName2") + sql(s"create table $multiInsertTableName1 like tmp_table") + 
sql(s"create table $multiInsertTableName2 like tmp_table") + assert(!sql( + s""" + |FROM tmp_table + |insert into $multiInsertTableName1 select * limit 2 + |insert into $multiInsertTableName2 select * + |""".stripMargin) + .queryExecution.analyzed.isInstanceOf[GlobalLimit]) + } + } + } + + test("test watchdog: Distribute by for forceMaxOutputRows") { + + withSQLConf(KyuubiSQLConf.WATCHDOG_FORCED_MAXOUTPUTROWS.key -> "10") { + withTable("tmp_table") { + spark.sql(s"create table tmp_table (a int, b int)") + spark.sql(s"insert into tmp_table values (1,10),(2,20),(3,30),(4,40),(5,50)") + assert(sql( + s""" + |SELECT * + |FROM tmp_table + |DISTRIBUTE BY a + |""".stripMargin) + .queryExecution.analyzed.isInstanceOf[GlobalLimit]) + } + } + } +} From 44726deef35d0ffefa446fa3d8dfaca9c87cb909 Mon Sep 17 00:00:00 2001 From: Fu Chen Date: Mon, 20 Dec 2021 15:37:40 +0800 Subject: [PATCH 2/5] fix style --- .../kyuubi/sql/KyuubiSparkSQLExtension.scala | 1 + .../sql/PruneFileSourcePartitionHelper.scala | 17 +++++++++++++++++ 2 files changed, 18 insertions(+) diff --git a/dev/kyuubi-extension-spark-3-1/src/main/scala/org/apache/kyuubi/sql/KyuubiSparkSQLExtension.scala b/dev/kyuubi-extension-spark-3-1/src/main/scala/org/apache/kyuubi/sql/KyuubiSparkSQLExtension.scala index f2ebeb7aec6..f6b1ef0f754 100644 --- a/dev/kyuubi-extension-spark-3-1/src/main/scala/org/apache/kyuubi/sql/KyuubiSparkSQLExtension.scala +++ b/dev/kyuubi-extension-spark-3-1/src/main/scala/org/apache/kyuubi/sql/KyuubiSparkSQLExtension.scala @@ -38,6 +38,7 @@ class KyuubiSparkSQLExtension extends (SparkSessionExtensions => Unit) { extensions.injectPostHocResolutionRule(RepartitionBeforeWritingHive) extensions.injectPostHocResolutionRule(DropIgnoreNonexistent) + // watchdog extension // a help rule for ForcedMaxOutputRowsRule extensions.injectResolutionRule(MarkAggregateOrderRule) extensions.injectPostHocResolutionRule(ForcedMaxOutputRowsRule) diff --git a/dev/kyuubi-extension-spark-common/src/main/scala/org/apache/spark/sql/PruneFileSourcePartitionHelper.scala b/dev/kyuubi-extension-spark-common/src/main/scala/org/apache/spark/sql/PruneFileSourcePartitionHelper.scala index ad82aa42f2f..ce496eb474c 100644 --- a/dev/kyuubi-extension-spark-common/src/main/scala/org/apache/spark/sql/PruneFileSourcePartitionHelper.scala +++ b/dev/kyuubi-extension-spark-common/src/main/scala/org/apache/spark/sql/PruneFileSourcePartitionHelper.scala @@ -1,3 +1,20 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + package org.apache.spark.sql import org.apache.spark.sql.catalyst.expressions.{AttributeReference, AttributeSet, Expression, ExpressionSet, PredicateHelper, SubqueryExpression} From 364fc26e6a28f3586764de388bf59744e1f93fa7 Mon Sep 17 00:00:00 2001 From: Fu Chen Date: Mon, 20 Dec 2021 15:51:01 +0800 Subject: [PATCH 3/5] add license header --- .../org/apache/spark/sql/WatchDogSuite.scala | 17 +++++++++++++++++ 1 file changed, 17 insertions(+) diff --git a/dev/kyuubi-extension-spark-3-2/src/test/scala/org/apache/spark/sql/WatchDogSuite.scala b/dev/kyuubi-extension-spark-3-2/src/test/scala/org/apache/spark/sql/WatchDogSuite.scala index df35922d162..957089340ca 100644 --- a/dev/kyuubi-extension-spark-3-2/src/test/scala/org/apache/spark/sql/WatchDogSuite.scala +++ b/dev/kyuubi-extension-spark-3-2/src/test/scala/org/apache/spark/sql/WatchDogSuite.scala @@ -1,3 +1,20 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + package org.apache.spark.sql class WatchDogSuite extends WatchDogSuiteBase {} From 0ce83ba5a3c57ba28df81a8cb800db9807a7b66e Mon Sep 17 00:00:00 2001 From: Fu Chen Date: Tue, 21 Dec 2021 10:10:09 +0800 Subject: [PATCH 4/5] remove MarkAggregateOrderRule --- .../watchdog/ForcedMaxOutputRowsRule.scala | 58 ++++++++++++++++++- .../kyuubi/sql/KyuubiSparkSQLExtension.scala | 4 +- .../watchdog/ForcedMaxOutputRowsRule.scala | 14 ++--- .../watchdog/ForcedMaxOutputRowsBase.scala | 49 +--------------- 4 files changed, 62 insertions(+), 63 deletions(-) diff --git a/dev/kyuubi-extension-spark-3-1/src/main/scala/org/apache/kyuubi/sql/watchdog/ForcedMaxOutputRowsRule.scala b/dev/kyuubi-extension-spark-3-1/src/main/scala/org/apache/kyuubi/sql/watchdog/ForcedMaxOutputRowsRule.scala index b8eb153d7ca..f5042146e0f 100644 --- a/dev/kyuubi-extension-spark-3-1/src/main/scala/org/apache/kyuubi/sql/watchdog/ForcedMaxOutputRowsRule.scala +++ b/dev/kyuubi-extension-spark-3-1/src/main/scala/org/apache/kyuubi/sql/watchdog/ForcedMaxOutputRowsRule.scala @@ -18,7 +18,61 @@ package org.apache.kyuubi.sql.watchdog import org.apache.spark.sql.SparkSession +import org.apache.spark.sql.catalyst.plans.logical._ +import org.apache.spark.sql.catalyst.rules.Rule +import org.apache.spark.sql.catalyst.trees.TreeNodeTag -case class ForcedMaxOutputRowsRule(sparkSession: SparkSession) extends ForcedMaxOutputRowsBase {} +import org.apache.kyuubi.sql.KyuubiSQLConf -case class MarkAggregateOrderRule(sparkSession: SparkSession) extends MarkAggregateOrderBase {} +object ForcedMaxOutputRowsConstraint { + val CHILD_AGGREGATE: TreeNodeTag[String] = TreeNodeTag[String]("__kyuubi_child_agg__") + val CHILD_AGGREGATE_FLAG: String = "__kyuubi_child_agg__" +} + +case class ForcedMaxOutputRowsRule(sparkSession: SparkSession) extends ForcedMaxOutputRowsBase { + override protected def 
isChildAggregate(a: Aggregate): Boolean = + a.aggregateExpressions.exists(p => + p.getTagValue(ForcedMaxOutputRowsConstraint.CHILD_AGGREGATE) + .contains(ForcedMaxOutputRowsConstraint.CHILD_AGGREGATE_FLAG)) +} + +/** + * After SPARK-35712, we don't need mark child aggregate for spark 3.2.x or higher version, + * for more detail, please see https://github.com/apache/spark/pull/32470 + */ +case class MarkAggregateOrderRule(sparkSession: SparkSession) extends Rule[LogicalPlan] { + + private def markChildAggregate(a: Aggregate): Unit = { + // mark child aggregate + a.aggregateExpressions.filter(_.resolved).foreach(_.setTagValue( + ForcedMaxOutputRowsConstraint.CHILD_AGGREGATE, + ForcedMaxOutputRowsConstraint.CHILD_AGGREGATE_FLAG)) + } + + protected def findAndMarkChildAggregate(plan: LogicalPlan): LogicalPlan = plan match { + /* + * The case mainly process order not aggregate column but grouping column as below + * SELECT c1, COUNT(*) as cnt + * FROM t1 + * GROUP BY c1 + * ORDER BY c1 + * */ + case a: Aggregate + if a.aggregateExpressions + .exists(x => x.resolved && x.name.equals("aggOrder")) => + markChildAggregate(a) + plan + case _ => + plan.children.foreach(_.foreach { + case agg: Aggregate => markChildAggregate(agg) + case _ => Unit + }) + plan + } + + override def apply(plan: LogicalPlan): LogicalPlan = conf.getConf( + KyuubiSQLConf.WATCHDOG_FORCED_MAXOUTPUTROWS) match { + case Some(_) => findAndMarkChildAggregate(plan) + case _ => plan + } +} diff --git a/dev/kyuubi-extension-spark-3-2/src/main/scala/org/apache/kyuubi/sql/KyuubiSparkSQLExtension.scala b/dev/kyuubi-extension-spark-3-2/src/main/scala/org/apache/kyuubi/sql/KyuubiSparkSQLExtension.scala index 996a7f2df81..38426b8e672 100644 --- a/dev/kyuubi-extension-spark-3-2/src/main/scala/org/apache/kyuubi/sql/KyuubiSparkSQLExtension.scala +++ b/dev/kyuubi-extension-spark-3-2/src/main/scala/org/apache/kyuubi/sql/KyuubiSparkSQLExtension.scala @@ -19,7 +19,7 @@ package org.apache.kyuubi.sql import org.apache.spark.sql.SparkSessionExtensions -import org.apache.kyuubi.sql.watchdog.{ForcedMaxOutputRowsRule, MarkAggregateOrderRule, MaxPartitionStrategy} +import org.apache.kyuubi.sql.watchdog.{ForcedMaxOutputRowsRule, MaxPartitionStrategy} // scalastyle:off line.size.limit /** @@ -37,8 +37,6 @@ class KyuubiSparkSQLExtension extends (SparkSessionExtensions => Unit) { extensions.injectPostHocResolutionRule(DropIgnoreNonexistent) // watchdog extension - // a help rule for ForcedMaxOutputRowsRule - extensions.injectResolutionRule(MarkAggregateOrderRule) extensions.injectPostHocResolutionRule(ForcedMaxOutputRowsRule) extensions.injectPlannerStrategy(MaxPartitionStrategy) } diff --git a/dev/kyuubi-extension-spark-3-2/src/main/scala/org/apache/kyuubi/sql/watchdog/ForcedMaxOutputRowsRule.scala b/dev/kyuubi-extension-spark-3-2/src/main/scala/org/apache/kyuubi/sql/watchdog/ForcedMaxOutputRowsRule.scala index 79e1acc2550..03d243f85e0 100644 --- a/dev/kyuubi-extension-spark-3-2/src/main/scala/org/apache/kyuubi/sql/watchdog/ForcedMaxOutputRowsRule.scala +++ b/dev/kyuubi-extension-spark-3-2/src/main/scala/org/apache/kyuubi/sql/watchdog/ForcedMaxOutputRowsRule.scala @@ -18,9 +18,12 @@ package org.apache.kyuubi.sql.watchdog import org.apache.spark.sql.SparkSession -import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, WithCTE} +import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, LogicalPlan, WithCTE} case class ForcedMaxOutputRowsRule(sparkSession: SparkSession) extends ForcedMaxOutputRowsBase { + + override protected 
def isChildAggregate(a: Aggregate): Boolean = false + override protected def canInsertLimitInner(p: LogicalPlan): Boolean = p match { case WithCTE(plan, _) => this.canInsertLimitInner(plan) case plan: LogicalPlan => super.canInsertLimitInner(plan) @@ -33,12 +36,3 @@ case class ForcedMaxOutputRowsRule(sparkSession: SparkSession) extends ForcedMax } } } - -case class MarkAggregateOrderRule(sparkSession: SparkSession) extends MarkAggregateOrderBase { - override protected def findAndMarkChildAggregate(plan: LogicalPlan): LogicalPlan = plan match { - case withCTE @ WithCTE(plan, _) => - withCTE.copy(plan = this.findAndMarkChildAggregate(plan)) - case _ => - super.findAndMarkChildAggregate(plan) - } -} diff --git a/dev/kyuubi-extension-spark-common/src/main/scala/org/apache/kyuubi/sql/watchdog/ForcedMaxOutputRowsBase.scala b/dev/kyuubi-extension-spark-common/src/main/scala/org/apache/kyuubi/sql/watchdog/ForcedMaxOutputRowsBase.scala index 08bf6e57cfa..7f846d9616e 100644 --- a/dev/kyuubi-extension-spark-common/src/main/scala/org/apache/kyuubi/sql/watchdog/ForcedMaxOutputRowsBase.scala +++ b/dev/kyuubi-extension-spark-common/src/main/scala/org/apache/kyuubi/sql/watchdog/ForcedMaxOutputRowsBase.scala @@ -22,16 +22,10 @@ import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.expressions.Alias import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.rules.Rule -import org.apache.spark.sql.catalyst.trees.TreeNodeTag import org.apache.spark.sql.execution.command.DataWritingCommand import org.apache.kyuubi.sql.KyuubiSQLConf -object ForcedMaxOutputRowsConstraint { - val CHILD_AGGREGATE: TreeNodeTag[String] = TreeNodeTag[String]("__kyuubi_child_agg__") - val CHILD_AGGREGATE_FLAG: String = "__kyuubi_child_agg__" -} - /* * Add ForcedMaxOutputRows rule for output rows limitation * to avoid huge output rows of non_limit query unexpectedly @@ -53,10 +47,7 @@ object ForcedMaxOutputRowsConstraint { * */ trait ForcedMaxOutputRowsBase extends Rule[LogicalPlan] { - protected def isChildAggregate(a: Aggregate): Boolean = a - .aggregateExpressions.exists(p => - p.getTagValue(ForcedMaxOutputRowsConstraint.CHILD_AGGREGATE) - .contains(ForcedMaxOutputRowsConstraint.CHILD_AGGREGATE_FLAG)) + protected def isChildAggregate(a: Aggregate): Boolean protected def isView: Boolean = { val nestedViewDepth = AnalysisContext.get.nestedViewDepth @@ -100,42 +91,4 @@ trait ForcedMaxOutputRowsBase extends Rule[LogicalPlan] { case _ => plan } } - -} - -trait MarkAggregateOrderBase extends Rule[LogicalPlan] { - - private def markChildAggregate(a: Aggregate): Unit = { - // mark child aggregate - a.aggregateExpressions.filter(_.resolved).foreach(_.setTagValue( - ForcedMaxOutputRowsConstraint.CHILD_AGGREGATE, - ForcedMaxOutputRowsConstraint.CHILD_AGGREGATE_FLAG)) - } - - protected def findAndMarkChildAggregate(plan: LogicalPlan): LogicalPlan = plan match { - /* - * The case mainly process order not aggregate column but grouping column as below - * SELECT c1, COUNT(*) as cnt - * FROM t1 - * GROUP BY c1 - * ORDER BY c1 - * */ - case a: Aggregate - if a.aggregateExpressions - .exists(x => x.resolved && x.name.equals("aggOrder")) => - markChildAggregate(a) - plan - case _ => - plan.children.foreach(_.foreach { - case agg: Aggregate => markChildAggregate(agg) - case _ => Unit - }) - plan - } - - override def apply(plan: LogicalPlan): LogicalPlan = conf.getConf( - KyuubiSQLConf.WATCHDOG_FORCED_MAXOUTPUTROWS) match { - case Some(_) => findAndMarkChildAggregate(plan) - case _ => 
plan - } } From 5399a3f09f21a652ef42b282809168fd08af3bf9 Mon Sep 17 00:00:00 2001 From: Fu Chen Date: Tue, 21 Dec 2021 10:25:58 +0800 Subject: [PATCH 5/5] fix style --- .../apache/kyuubi/sql/watchdog/ForcedMaxOutputRowsRule.scala | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/dev/kyuubi-extension-spark-3-1/src/main/scala/org/apache/kyuubi/sql/watchdog/ForcedMaxOutputRowsRule.scala b/dev/kyuubi-extension-spark-3-1/src/main/scala/org/apache/kyuubi/sql/watchdog/ForcedMaxOutputRowsRule.scala index f5042146e0f..e92a69f7196 100644 --- a/dev/kyuubi-extension-spark-3-1/src/main/scala/org/apache/kyuubi/sql/watchdog/ForcedMaxOutputRowsRule.scala +++ b/dev/kyuubi-extension-spark-3-1/src/main/scala/org/apache/kyuubi/sql/watchdog/ForcedMaxOutputRowsRule.scala @@ -58,8 +58,8 @@ case class MarkAggregateOrderRule(sparkSession: SparkSession) extends Rule[Logic * ORDER BY c1 * */ case a: Aggregate - if a.aggregateExpressions - .exists(x => x.resolved && x.name.equals("aggOrder")) => + if a.aggregateExpressions + .exists(x => x.resolved && x.name.equals("aggOrder")) => markChildAggregate(a) plan case _ =>
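
As a reference for how the rules in this series behave once injected, the following is a minimal sketch and not part of the patch: it assumes the kyuubi-extension jar built from these modules is on the driver classpath, and it reuses only names that appear in the diffs above (KyuubiSparkSQLExtension, KyuubiSQLConf.WATCHDOG_FORCED_MAXOUTPUTROWS, GlobalLimit).

    import org.apache.kyuubi.sql.KyuubiSQLConf
    import org.apache.spark.sql.SparkSession
    import org.apache.spark.sql.catalyst.plans.logical.GlobalLimit

    object WatchdogSketch {
      def main(args: Array[String]): Unit = {
        val spark = SparkSession.builder()
          .master("local[*]")
          // Register the extension so ForcedMaxOutputRowsRule and MaxPartitionStrategy
          // are injected while the session is built.
          .config("spark.sql.extensions", "org.apache.kyuubi.sql.KyuubiSparkSQLExtension")
          // Same config entry the test suites toggle via withSQLConf.
          .config(KyuubiSQLConf.WATCHDOG_FORCED_MAXOUTPUTROWS.key, "10")
          .getOrCreate()

        spark.range(0, 100).createOrReplaceTempView("t1")

        // A query without an explicit LIMIT: the post-hoc resolution rule wraps the
        // analyzed plan in a GlobalLimit, which is what WatchDogSuiteBase asserts.
        val df = spark.sql("SELECT * FROM t1")
        assert(df.queryExecution.analyzed.isInstanceOf[GlobalLimit])
        assert(df.collect().length == 10)

        spark.stop()
      }
    }

With KyuubiSQLConf.WATCHDOG_MAX_PARTITIONS set as well, partitioned scans that exceed the threshold fail with MaxPartitionExceedException, as the "watchdog with scan maxPartitions" tests above exercise.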