diff --git a/dev/kyuubi-extension-spark-3-1/src/main/scala/org/apache/kyuubi/sql/KyuubiSparkSQLExtension.scala b/dev/kyuubi-extension-spark-3-1/src/main/scala/org/apache/kyuubi/sql/KyuubiSparkSQLExtension.scala index 3f9193731a2..f6b1ef0f754 100644 --- a/dev/kyuubi-extension-spark-3-1/src/main/scala/org/apache/kyuubi/sql/KyuubiSparkSQLExtension.scala +++ b/dev/kyuubi-extension-spark-3-1/src/main/scala/org/apache/kyuubi/sql/KyuubiSparkSQLExtension.scala @@ -32,15 +32,16 @@ import org.apache.kyuubi.sql.watchdog.{ForcedMaxOutputRowsRule, MarkAggregateOrd class KyuubiSparkSQLExtension extends (SparkSessionExtensions => Unit) { override def apply(extensions: SparkSessionExtensions): Unit = { KyuubiSparkSQLCommonExtension.injectCommonExtensions(extensions) - // a help rule for ForcedMaxOutputRowsRule - extensions.injectResolutionRule(MarkAggregateOrderRule) extensions.injectPostHocResolutionRule(KyuubiSqlClassification) extensions.injectPostHocResolutionRule(RepartitionBeforeWritingDatasource) extensions.injectPostHocResolutionRule(RepartitionBeforeWritingHive) - extensions.injectPostHocResolutionRule(ForcedMaxOutputRowsRule) extensions.injectPostHocResolutionRule(DropIgnoreNonexistent) + // watchdog extension + // a help rule for ForcedMaxOutputRowsRule + extensions.injectResolutionRule(MarkAggregateOrderRule) + extensions.injectPostHocResolutionRule(ForcedMaxOutputRowsRule) extensions.injectPlannerStrategy(MaxPartitionStrategy) } } diff --git a/dev/kyuubi-extension-spark-3-1/src/main/scala/org/apache/kyuubi/sql/watchdog/ForcedMaxOutputRowsRule.scala b/dev/kyuubi-extension-spark-3-1/src/main/scala/org/apache/kyuubi/sql/watchdog/ForcedMaxOutputRowsRule.scala index 07ddb2cd48b..e92a69f7196 100644 --- a/dev/kyuubi-extension-spark-3-1/src/main/scala/org/apache/kyuubi/sql/watchdog/ForcedMaxOutputRowsRule.scala +++ b/dev/kyuubi-extension-spark-3-1/src/main/scala/org/apache/kyuubi/sql/watchdog/ForcedMaxOutputRowsRule.scala @@ -18,13 +18,9 @@ package org.apache.kyuubi.sql.watchdog import org.apache.spark.sql.SparkSession -import org.apache.spark.sql.catalyst.analysis.AnalysisContext -import org.apache.spark.sql.catalyst.dsl.expressions._ -import org.apache.spark.sql.catalyst.expressions.Alias -import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, Distinct, Filter, Limit, LogicalPlan, Project, RepartitionByExpression, Sort, Union} +import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.rules.Rule import org.apache.spark.sql.catalyst.trees.TreeNodeTag -import org.apache.spark.sql.execution.command.DataWritingCommand import org.apache.kyuubi.sql.KyuubiSQLConf @@ -33,81 +29,18 @@ object ForcedMaxOutputRowsConstraint { val CHILD_AGGREGATE_FLAG: String = "__kyuubi_child_agg__" } -/* - * Add ForcedMaxOutputRows rule for output rows limitation - * to avoid huge output rows of non_limit query unexpectedly - * mainly applied to cases as below: - * - * case 1: - * {{{ - * SELECT [c1, c2, ...] - * }}} - * - * case 2: - * {{{ - * WITH CTE AS ( - * ...) - * SELECT [c1, c2, ...] FROM CTE ... 
- * }}} - * - * The Logical Rule add a GlobalLimit node before root project - * */ -case class ForcedMaxOutputRowsRule(session: SparkSession) extends Rule[LogicalPlan] { - - private def isChildAggregate(a: Aggregate): Boolean = a - .aggregateExpressions.exists(p => +case class ForcedMaxOutputRowsRule(sparkSession: SparkSession) extends ForcedMaxOutputRowsBase { + override protected def isChildAggregate(a: Aggregate): Boolean = + a.aggregateExpressions.exists(p => p.getTagValue(ForcedMaxOutputRowsConstraint.CHILD_AGGREGATE) .contains(ForcedMaxOutputRowsConstraint.CHILD_AGGREGATE_FLAG)) - - private def isView: Boolean = { - val nestedViewDepth = AnalysisContext.get.nestedViewDepth - nestedViewDepth > 0 - } - - private def canInsertLimitInner(p: LogicalPlan): Boolean = p match { - - case Aggregate(_, Alias(_, "havingCondition") :: Nil, _) => false - case agg: Aggregate => !isChildAggregate(agg) - case _: RepartitionByExpression => true - case _: Distinct => true - case _: Filter => true - case _: Project => true - case Limit(_, _) => true - case _: Sort => true - case Union(children, _, _) => - if (children.exists(_.isInstanceOf[DataWritingCommand])) { - false - } else { - true - } - case _ => false - - } - - private def canInsertLimit(p: LogicalPlan, maxOutputRowsOpt: Option[Int]): Boolean = { - - maxOutputRowsOpt match { - case Some(forcedMaxOutputRows) => canInsertLimitInner(p) && - !p.maxRows.exists(_ <= forcedMaxOutputRows) && - !isView - case None => false - } - } - - override def apply(plan: LogicalPlan): LogicalPlan = { - val maxOutputRowsOpt = conf.getConf(KyuubiSQLConf.WATCHDOG_FORCED_MAXOUTPUTROWS) - plan match { - case p if p.resolved && canInsertLimit(p, maxOutputRowsOpt) => - Limit( - maxOutputRowsOpt.get, - plan) - case _ => plan - } - } - } -case class MarkAggregateOrderRule(session: SparkSession) extends Rule[LogicalPlan] { +/** + * After SPARK-35712, we don't need mark child aggregate for spark 3.2.x or higher version, + * for more detail, please see https://github.com/apache/spark/pull/32470 + */ +case class MarkAggregateOrderRule(sparkSession: SparkSession) extends Rule[LogicalPlan] { private def markChildAggregate(a: Aggregate): Unit = { // mark child aggregate @@ -116,7 +49,7 @@ case class MarkAggregateOrderRule(session: SparkSession) extends Rule[LogicalPla ForcedMaxOutputRowsConstraint.CHILD_AGGREGATE_FLAG)) } - private def findAndMarkChildAggregate(plan: LogicalPlan): LogicalPlan = plan match { + protected def findAndMarkChildAggregate(plan: LogicalPlan): LogicalPlan = plan match { /* * The case mainly process order not aggregate column but grouping column as below * SELECT c1, COUNT(*) as cnt @@ -129,7 +62,6 @@ case class MarkAggregateOrderRule(session: SparkSession) extends Rule[LogicalPla .exists(x => x.resolved && x.name.equals("aggOrder")) => markChildAggregate(a) plan - case _ => plan.children.foreach(_.foreach { case agg: Aggregate => markChildAggregate(agg) diff --git a/dev/kyuubi-extension-spark-3-1/src/test/scala/org/apache/spark/sql/WatchDogSuite.scala b/dev/kyuubi-extension-spark-3-1/src/test/scala/org/apache/spark/sql/WatchDogSuite.scala index 642bd7fd773..957089340ca 100644 --- a/dev/kyuubi-extension-spark-3-1/src/test/scala/org/apache/spark/sql/WatchDogSuite.scala +++ b/dev/kyuubi-extension-spark-3-1/src/test/scala/org/apache/spark/sql/WatchDogSuite.scala @@ -17,387 +17,4 @@ package org.apache.spark.sql -import org.apache.spark.sql.catalyst.plans.logical.GlobalLimit - -import org.apache.kyuubi.sql.KyuubiSQLConf -import 
org.apache.kyuubi.sql.watchdog.MaxPartitionExceedException - -class WatchDogSuite extends KyuubiSparkSQLExtensionTest { - override protected def beforeAll(): Unit = { - super.beforeAll() - setupData() - } - - case class LimitAndExpected(limit: Int, expected: Int) - val limitAndExpecteds = List(LimitAndExpected(1, 1), LimitAndExpected(11, 10)) - - private def checkMaxPartition: Unit = { - withSQLConf(KyuubiSQLConf.WATCHDOG_MAX_PARTITIONS.key -> "100") { - checkAnswer(sql("SELECT count(distinct(p)) FROM test"), Row(10) :: Nil) - } - withSQLConf(KyuubiSQLConf.WATCHDOG_MAX_PARTITIONS.key -> "5") { - sql("SELECT * FROM test where p=1").queryExecution.sparkPlan - - sql(s"SELECT * FROM test WHERE p in (${Range(0, 5).toList.mkString(",")})") - .queryExecution.sparkPlan - - intercept[MaxPartitionExceedException]( - sql("SELECT * FROM test where p != 1").queryExecution.sparkPlan) - - intercept[MaxPartitionExceedException]( - sql("SELECT * FROM test").queryExecution.sparkPlan) - - intercept[MaxPartitionExceedException](sql( - s"SELECT * FROM test WHERE p in (${Range(0, 6).toList.mkString(",")})") - .queryExecution.sparkPlan) - } - } - - test("watchdog with scan maxPartitions -- hive") { - Seq("textfile", "parquet").foreach { format => - withTable("test", "temp") { - sql( - s""" - |CREATE TABLE test(i int) - |PARTITIONED BY (p int) - |STORED AS $format""".stripMargin) - spark.range(0, 10, 1).selectExpr("id as col") - .createOrReplaceTempView("temp") - - for (part <- Range(0, 10)) { - sql( - s""" - |INSERT OVERWRITE TABLE test PARTITION (p='$part') - |select col from temp""".stripMargin) - } - checkMaxPartition - } - } - } - - test("watchdog with scan maxPartitions -- data source") { - withTempDir { dir => - withTempView("test") { - spark.range(10).selectExpr("id", "id as p") - .write - .partitionBy("p") - .mode("overwrite") - .save(dir.getCanonicalPath) - spark.read.load(dir.getCanonicalPath).createOrReplaceTempView("test") - checkMaxPartition - } - } - } - - test("test watchdog: simple SELECT STATEMENT") { - - withSQLConf(KyuubiSQLConf.WATCHDOG_FORCED_MAXOUTPUTROWS.key -> "10") { - - List("", "ORDER BY c1", "ORDER BY c2").foreach { sort => - List("", " DISTINCT").foreach { distinct => - assert(sql( - s""" - |SELECT $distinct * - |FROM t1 - |$sort - |""".stripMargin).queryExecution.analyzed.isInstanceOf[GlobalLimit]) - } - } - - limitAndExpecteds.foreach { case LimitAndExpected(limit, expected) => - List("", "ORDER BY c1", "ORDER BY c2").foreach { sort => - List("", "DISTINCT").foreach { distinct => - assert(sql( - s""" - |SELECT $distinct * - |FROM t1 - |$sort - |LIMIT $limit - |""".stripMargin).queryExecution.analyzed.maxRows.contains(expected)) - } - } - } - } - } - - test("test watchdog: SELECT ... 
WITH AGGREGATE STATEMENT ") { - - withSQLConf(KyuubiSQLConf.WATCHDOG_FORCED_MAXOUTPUTROWS.key -> "10") { - - assert(!sql("SELECT count(*) FROM t1") - .queryExecution.analyzed.isInstanceOf[GlobalLimit]) - - val sorts = List("", "ORDER BY cnt", "ORDER BY c1", "ORDER BY cnt, c1", "ORDER BY c1, cnt") - val havingConditions = List("", "HAVING cnt > 1") - - havingConditions.foreach { having => - sorts.foreach { sort => - assert(sql( - s""" - |SELECT c1, COUNT(*) as cnt - |FROM t1 - |GROUP BY c1 - |$having - |$sort - |""".stripMargin).queryExecution.analyzed.isInstanceOf[GlobalLimit]) - } - } - - limitAndExpecteds.foreach { case LimitAndExpected(limit, expected) => - havingConditions.foreach { having => - sorts.foreach { sort => - assert(sql( - s""" - |SELECT c1, COUNT(*) as cnt - |FROM t1 - |GROUP BY c1 - |$having - |$sort - |LIMIT $limit - |""".stripMargin).queryExecution.analyzed.maxRows.contains(expected)) - } - } - } - } - } - - test("test watchdog: SELECT with CTE forceMaxOutputRows") { - - withSQLConf(KyuubiSQLConf.WATCHDOG_FORCED_MAXOUTPUTROWS.key -> "10") { - - val sorts = List("", "ORDER BY c1", "ORDER BY c2") - - sorts.foreach { sort => - assert(sql( - s""" - |WITH custom_cte AS ( - |SELECT * FROM t1 - |) - |SELECT * - |FROM custom_cte - |$sort - |""".stripMargin).queryExecution.analyzed.isInstanceOf[GlobalLimit]) - } - - limitAndExpecteds.foreach { case LimitAndExpected(limit, expected) => - sorts.foreach { sort => - assert(sql( - s""" - |WITH custom_cte AS ( - |SELECT * FROM t1 - |) - |SELECT * - |FROM custom_cte - |$sort - |LIMIT $limit - |""".stripMargin).queryExecution.analyzed.maxRows.contains(expected)) - } - } - } - } - - test("test watchdog: SELECT AGGREGATE WITH CTE forceMaxOutputRows") { - - withSQLConf(KyuubiSQLConf.WATCHDOG_FORCED_MAXOUTPUTROWS.key -> "10") { - - assert(!sql( - """ - |WITH custom_cte AS ( - |SELECT * FROM t1 - |) - | - |SELECT COUNT(*) - |FROM custom_cte - |""".stripMargin).queryExecution - .analyzed.isInstanceOf[GlobalLimit]) - - val sorts = List("", "ORDER BY cnt", "ORDER BY c1", "ORDER BY cnt, c1", "ORDER BY c1, cnt") - val havingConditions = List("", "HAVING cnt > 1") - - havingConditions.foreach { having => - sorts.foreach { sort => - assert(sql( - s""" - |WITH custom_cte AS ( - |SELECT * FROM t1 - |) - | - |SELECT c1, COUNT(*) as cnt - |FROM custom_cte - |GROUP BY c1 - |$having - |$sort - |""".stripMargin).queryExecution.analyzed.isInstanceOf[GlobalLimit]) - } - } - - limitAndExpecteds.foreach { case LimitAndExpected(limit, expected) => - havingConditions.foreach { having => - sorts.foreach { sort => - assert(sql( - s""" - |WITH custom_cte AS ( - |SELECT * FROM t1 - |) - | - |SELECT c1, COUNT(*) as cnt - |FROM custom_cte - |GROUP BY c1 - |$having - |$sort - |LIMIT $limit - |""".stripMargin).queryExecution.analyzed.maxRows.contains(expected)) - } - } - } - } - } - - test("test watchdog: UNION Statement for forceMaxOutputRows") { - - withSQLConf(KyuubiSQLConf.WATCHDOG_FORCED_MAXOUTPUTROWS.key -> "10") { - - List("", "ALL").foreach { x => - assert(sql( - s""" - |SELECT c1, c2 FROM t1 - |UNION $x - |SELECT c1, c2 FROM t2 - |UNION $x - |SELECT c1, c2 FROM t3 - |""".stripMargin) - .queryExecution.analyzed.isInstanceOf[GlobalLimit]) - } - - val sorts = List("", "ORDER BY cnt", "ORDER BY c1", "ORDER BY cnt, c1", "ORDER BY c1, cnt") - val havingConditions = List("", "HAVING cnt > 1") - - List("", "ALL").foreach { x => - havingConditions.foreach { having => - sorts.foreach { sort => - assert(sql( - s""" - |SELECT c1, count(c2) as cnt - |FROM t1 - |GROUP BY c1 
- |$having - |UNION $x - |SELECT c1, COUNT(c2) as cnt - |FROM t2 - |GROUP BY c1 - |$having - |UNION $x - |SELECT c1, COUNT(c2) as cnt - |FROM t3 - |GROUP BY c1 - |$having - |$sort - |""".stripMargin) - .queryExecution.analyzed.isInstanceOf[GlobalLimit]) - } - } - } - - limitAndExpecteds.foreach { case LimitAndExpected(limit, expected) => - assert(sql( - s""" - |SELECT c1, c2 FROM t1 - |UNION - |SELECT c1, c2 FROM t2 - |UNION - |SELECT c1, c2 FROM t3 - |LIMIT $limit - |""".stripMargin) - .queryExecution.analyzed.maxRows.contains(expected)) - } - } - } - - test("test watchdog: Select View Statement for forceMaxOutputRows") { - withSQLConf(KyuubiSQLConf.WATCHDOG_FORCED_MAXOUTPUTROWS.key -> "3") { - withTable("tmp_table", "tmp_union") { - withView("tmp_view", "tmp_view2") { - sql(s"create table tmp_table (a int, b int)") - sql(s"insert into tmp_table values (1,10),(2,20),(3,30),(4,40),(5,50)") - sql(s"create table tmp_union (a int, b int)") - sql(s"insert into tmp_union values (6,60),(7,70),(8,80),(9,90),(10,100)") - sql(s"create view tmp_view2 as select * from tmp_union") - assert(!sql( - s""" - |CREATE VIEW tmp_view - |as - |SELECT * FROM - |tmp_table - |""".stripMargin) - .queryExecution.analyzed.isInstanceOf[GlobalLimit]) - - assert(sql( - s""" - |SELECT * FROM - |tmp_view - |""".stripMargin) - .queryExecution.analyzed.maxRows.contains(3)) - - assert(sql( - s""" - |SELECT * FROM - |tmp_view - |limit 11 - |""".stripMargin) - .queryExecution.analyzed.maxRows.contains(3)) - - assert(sql( - s""" - |SELECT * FROM - |(select * from tmp_view - |UNION - |select * from tmp_view2) - |ORDER BY a - |DESC - |""".stripMargin) - .collect().head.get(0).equals(10)) - } - } - } - } - - test("test watchdog: Insert Statement for forceMaxOutputRows") { - - withSQLConf(KyuubiSQLConf.WATCHDOG_FORCED_MAXOUTPUTROWS.key -> "10") { - withTable("tmp_table", "tmp_insert") { - spark.sql(s"create table tmp_table (a int, b int)") - spark.sql(s"insert into tmp_table values (1,10),(2,20),(3,30),(4,40),(5,50)") - val multiInsertTableName1: String = "tmp_tbl1" - val multiInsertTableName2: String = "tmp_tbl2" - sql(s"drop table if exists $multiInsertTableName1") - sql(s"drop table if exists $multiInsertTableName2") - sql(s"create table $multiInsertTableName1 like tmp_table") - sql(s"create table $multiInsertTableName2 like tmp_table") - assert(!sql( - s""" - |FROM tmp_table - |insert into $multiInsertTableName1 select * limit 2 - |insert into $multiInsertTableName2 select * - |""".stripMargin) - .queryExecution.analyzed.isInstanceOf[GlobalLimit]) - } - } - } - - test("test watchdog: Distribute by for forceMaxOutputRows") { - - withSQLConf(KyuubiSQLConf.WATCHDOG_FORCED_MAXOUTPUTROWS.key -> "10") { - withTable("tmp_table") { - spark.sql(s"create table tmp_table (a int, b int)") - spark.sql(s"insert into tmp_table values (1,10),(2,20),(3,30),(4,40),(5,50)") - assert(sql( - s""" - |SELECT * - |FROM tmp_table - |DISTRIBUTE BY a - |""".stripMargin) - .queryExecution.analyzed.isInstanceOf[GlobalLimit]) - } - } - } -} +class WatchDogSuite extends WatchDogSuiteBase {} diff --git a/dev/kyuubi-extension-spark-3-2/src/main/scala/org/apache/kyuubi/sql/KyuubiSparkSQLExtension.scala b/dev/kyuubi-extension-spark-3-2/src/main/scala/org/apache/kyuubi/sql/KyuubiSparkSQLExtension.scala index 2b8920a254f..38426b8e672 100644 --- a/dev/kyuubi-extension-spark-3-2/src/main/scala/org/apache/kyuubi/sql/KyuubiSparkSQLExtension.scala +++ b/dev/kyuubi-extension-spark-3-2/src/main/scala/org/apache/kyuubi/sql/KyuubiSparkSQLExtension.scala @@ -19,6 +19,8 @@ 
package org.apache.kyuubi.sql import org.apache.spark.sql.SparkSessionExtensions +import org.apache.kyuubi.sql.watchdog.{ForcedMaxOutputRowsRule, MaxPartitionStrategy} + // scalastyle:off line.size.limit /** * Depend on Spark SQL Extension framework, we can use this extension follow steps @@ -33,5 +35,9 @@ class KyuubiSparkSQLExtension extends (SparkSessionExtensions => Unit) { extensions.injectPostHocResolutionRule(RebalanceBeforeWritingDatasource) extensions.injectPostHocResolutionRule(RebalanceBeforeWritingHive) extensions.injectPostHocResolutionRule(DropIgnoreNonexistent) + + // watchdog extension + extensions.injectPostHocResolutionRule(ForcedMaxOutputRowsRule) + extensions.injectPlannerStrategy(MaxPartitionStrategy) } } diff --git a/dev/kyuubi-extension-spark-3-2/src/main/scala/org/apache/kyuubi/sql/watchdog/ForcedMaxOutputRowsRule.scala b/dev/kyuubi-extension-spark-3-2/src/main/scala/org/apache/kyuubi/sql/watchdog/ForcedMaxOutputRowsRule.scala new file mode 100644 index 00000000000..03d243f85e0 --- /dev/null +++ b/dev/kyuubi-extension-spark-3-2/src/main/scala/org/apache/kyuubi/sql/watchdog/ForcedMaxOutputRowsRule.scala @@ -0,0 +1,38 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.kyuubi.sql.watchdog + +import org.apache.spark.sql.SparkSession +import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, LogicalPlan, WithCTE} + +case class ForcedMaxOutputRowsRule(sparkSession: SparkSession) extends ForcedMaxOutputRowsBase { + + override protected def isChildAggregate(a: Aggregate): Boolean = false + + override protected def canInsertLimitInner(p: LogicalPlan): Boolean = p match { + case WithCTE(plan, _) => this.canInsertLimitInner(plan) + case plan: LogicalPlan => super.canInsertLimitInner(plan) + } + + override protected def canInsertLimit(p: LogicalPlan, maxOutputRowsOpt: Option[Int]): Boolean = { + p match { + case WithCTE(plan, _) => this.canInsertLimit(plan, maxOutputRowsOpt) + case _ => super.canInsertLimit(p, maxOutputRowsOpt) + } + } +} diff --git a/dev/kyuubi-extension-spark-3-2/src/test/scala/org/apache/spark/sql/WatchDogSuite.scala b/dev/kyuubi-extension-spark-3-2/src/test/scala/org/apache/spark/sql/WatchDogSuite.scala new file mode 100644 index 00000000000..957089340ca --- /dev/null +++ b/dev/kyuubi-extension-spark-3-2/src/test/scala/org/apache/spark/sql/WatchDogSuite.scala @@ -0,0 +1,20 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql + +class WatchDogSuite extends WatchDogSuiteBase {} diff --git a/dev/kyuubi-extension-spark-common/src/main/scala/org/apache/kyuubi/sql/watchdog/ForcedMaxOutputRowsBase.scala b/dev/kyuubi-extension-spark-common/src/main/scala/org/apache/kyuubi/sql/watchdog/ForcedMaxOutputRowsBase.scala new file mode 100644 index 00000000000..7f846d9616e --- /dev/null +++ b/dev/kyuubi-extension-spark-common/src/main/scala/org/apache/kyuubi/sql/watchdog/ForcedMaxOutputRowsBase.scala @@ -0,0 +1,94 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.kyuubi.sql.watchdog + +import org.apache.spark.sql.catalyst.analysis.AnalysisContext +import org.apache.spark.sql.catalyst.dsl.expressions._ +import org.apache.spark.sql.catalyst.expressions.Alias +import org.apache.spark.sql.catalyst.plans.logical._ +import org.apache.spark.sql.catalyst.rules.Rule +import org.apache.spark.sql.execution.command.DataWritingCommand + +import org.apache.kyuubi.sql.KyuubiSQLConf + +/* + * Add ForcedMaxOutputRows rule for output rows limitation + * to avoid huge output rows of non_limit query unexpectedly + * mainly applied to cases as below: + * + * case 1: + * {{{ + * SELECT [c1, c2, ...] + * }}} + * + * case 2: + * {{{ + * WITH CTE AS ( + * ...) + * SELECT [c1, c2, ...] FROM CTE ... 
+ * }}} + * + * The Logical Rule add a GlobalLimit node before root project + * */ +trait ForcedMaxOutputRowsBase extends Rule[LogicalPlan] { + + protected def isChildAggregate(a: Aggregate): Boolean + + protected def isView: Boolean = { + val nestedViewDepth = AnalysisContext.get.nestedViewDepth + nestedViewDepth > 0 + } + + protected def canInsertLimitInner(p: LogicalPlan): Boolean = p match { + case Aggregate(_, Alias(_, "havingCondition") :: Nil, _) => false + case agg: Aggregate => !isChildAggregate(agg) + case _: RepartitionByExpression => true + case _: Distinct => true + case _: Filter => true + case _: Project => true + case Limit(_, _) => true + case _: Sort => true + case Union(children, _, _) => + if (children.exists(_.isInstanceOf[DataWritingCommand])) { + false + } else { + true + } + case _ => false + } + + protected def canInsertLimit(p: LogicalPlan, maxOutputRowsOpt: Option[Int]): Boolean = { + maxOutputRowsOpt match { + case Some(forcedMaxOutputRows) => canInsertLimitInner(p) && + !p.maxRows.exists(_ <= forcedMaxOutputRows) && + !isView + case None => false + } + } + + override def apply(plan: LogicalPlan): LogicalPlan = { + val maxOutputRowsOpt = conf.getConf(KyuubiSQLConf.WATCHDOG_FORCED_MAXOUTPUTROWS) + plan match { + case p if p.resolved && canInsertLimit(p, maxOutputRowsOpt) => + Limit( + maxOutputRowsOpt.get, + plan) + case _ => plan + } + } +} diff --git a/dev/kyuubi-extension-spark-3-1/src/main/scala/org/apache/kyuubi/sql/watchdog/KyuubiWatchDogException.scala b/dev/kyuubi-extension-spark-common/src/main/scala/org/apache/kyuubi/sql/watchdog/KyuubiWatchDogException.scala similarity index 100% rename from dev/kyuubi-extension-spark-3-1/src/main/scala/org/apache/kyuubi/sql/watchdog/KyuubiWatchDogException.scala rename to dev/kyuubi-extension-spark-common/src/main/scala/org/apache/kyuubi/sql/watchdog/KyuubiWatchDogException.scala diff --git a/dev/kyuubi-extension-spark-3-1/src/main/scala/org/apache/kyuubi/sql/watchdog/MaxPartitionStrategy.scala b/dev/kyuubi-extension-spark-common/src/main/scala/org/apache/kyuubi/sql/watchdog/MaxPartitionStrategy.scala similarity index 100% rename from dev/kyuubi-extension-spark-3-1/src/main/scala/org/apache/kyuubi/sql/watchdog/MaxPartitionStrategy.scala rename to dev/kyuubi-extension-spark-common/src/main/scala/org/apache/kyuubi/sql/watchdog/MaxPartitionStrategy.scala diff --git a/dev/kyuubi-extension-spark-3-1/src/main/scala/org/apache/spark/sql/PruneFileSourcePartitionHelper.scala b/dev/kyuubi-extension-spark-common/src/main/scala/org/apache/spark/sql/PruneFileSourcePartitionHelper.scala similarity index 96% rename from dev/kyuubi-extension-spark-3-1/src/main/scala/org/apache/spark/sql/PruneFileSourcePartitionHelper.scala rename to dev/kyuubi-extension-spark-common/src/main/scala/org/apache/spark/sql/PruneFileSourcePartitionHelper.scala index 65dd016e3d1..ce496eb474c 100644 --- a/dev/kyuubi-extension-spark-3-1/src/main/scala/org/apache/spark/sql/PruneFileSourcePartitionHelper.scala +++ b/dev/kyuubi-extension-spark-common/src/main/scala/org/apache/spark/sql/PruneFileSourcePartitionHelper.scala @@ -19,7 +19,7 @@ package org.apache.spark.sql import org.apache.spark.sql.catalyst.expressions.{AttributeReference, AttributeSet, Expression, ExpressionSet, PredicateHelper, SubqueryExpression} import org.apache.spark.sql.catalyst.plans.logical.LeafNode -import org.apache.spark.sql.execution.datasources._ +import org.apache.spark.sql.execution.datasources.DataSourceStrategy import org.apache.spark.sql.types.StructType trait 
PruneFileSourcePartitionHelper extends PredicateHelper { diff --git a/dev/kyuubi-extension-spark-common/src/test/scala/org/apache/spark/sql/WatchDogSuiteBase.scala b/dev/kyuubi-extension-spark-common/src/test/scala/org/apache/spark/sql/WatchDogSuiteBase.scala new file mode 100644 index 00000000000..a604308a537 --- /dev/null +++ b/dev/kyuubi-extension-spark-common/src/test/scala/org/apache/spark/sql/WatchDogSuiteBase.scala @@ -0,0 +1,418 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql + +import org.apache.spark.sql.catalyst.plans.logical.GlobalLimit + +import org.apache.kyuubi.sql.KyuubiSQLConf +import org.apache.kyuubi.sql.watchdog.MaxPartitionExceedException + +trait WatchDogSuiteBase extends KyuubiSparkSQLExtensionTest { + override protected def beforeAll(): Unit = { + super.beforeAll() + setupData() + } + + case class LimitAndExpected(limit: Int, expected: Int) + val limitAndExpecteds = List(LimitAndExpected(1, 1), LimitAndExpected(11, 10)) + + private def checkMaxPartition: Unit = { + withSQLConf(KyuubiSQLConf.WATCHDOG_MAX_PARTITIONS.key -> "100") { + checkAnswer(sql("SELECT count(distinct(p)) FROM test"), Row(10) :: Nil) + } + withSQLConf(KyuubiSQLConf.WATCHDOG_MAX_PARTITIONS.key -> "5") { + sql("SELECT * FROM test where p=1").queryExecution.sparkPlan + + sql(s"SELECT * FROM test WHERE p in (${Range(0, 5).toList.mkString(",")})") + .queryExecution.sparkPlan + + intercept[MaxPartitionExceedException]( + sql("SELECT * FROM test where p != 1").queryExecution.sparkPlan) + + intercept[MaxPartitionExceedException]( + sql("SELECT * FROM test").queryExecution.sparkPlan) + + intercept[MaxPartitionExceedException](sql( + s"SELECT * FROM test WHERE p in (${Range(0, 6).toList.mkString(",")})") + .queryExecution.sparkPlan) + } + } + + test("watchdog with scan maxPartitions -- hive") { + Seq("textfile", "parquet").foreach { format => + withTable("test", "temp") { + sql( + s""" + |CREATE TABLE test(i int) + |PARTITIONED BY (p int) + |STORED AS $format""".stripMargin) + spark.range(0, 10, 1).selectExpr("id as col") + .createOrReplaceTempView("temp") + + for (part <- Range(0, 10)) { + sql( + s""" + |INSERT OVERWRITE TABLE test PARTITION (p='$part') + |select col from temp""".stripMargin) + } + checkMaxPartition + } + } + } + + test("watchdog with scan maxPartitions -- data source") { + withTempDir { dir => + withTempView("test") { + spark.range(10).selectExpr("id", "id as p") + .write + .partitionBy("p") + .mode("overwrite") + .save(dir.getCanonicalPath) + spark.read.load(dir.getCanonicalPath).createOrReplaceTempView("test") + checkMaxPartition + } + } + } + + test("test watchdog: simple SELECT STATEMENT") { + + withSQLConf(KyuubiSQLConf.WATCHDOG_FORCED_MAXOUTPUTROWS.key -> "10") { + + List("", "ORDER BY c1", "ORDER 
BY c2").foreach { sort => + List("", " DISTINCT").foreach { distinct => + assert(sql( + s""" + |SELECT $distinct * + |FROM t1 + |$sort + |""".stripMargin).queryExecution.analyzed.isInstanceOf[GlobalLimit]) + } + } + + limitAndExpecteds.foreach { case LimitAndExpected(limit, expected) => + List("", "ORDER BY c1", "ORDER BY c2").foreach { sort => + List("", "DISTINCT").foreach { distinct => + assert(sql( + s""" + |SELECT $distinct * + |FROM t1 + |$sort + |LIMIT $limit + |""".stripMargin).queryExecution.analyzed.maxRows.contains(expected)) + } + } + } + } + } + + test("test watchdog: SELECT ... WITH AGGREGATE STATEMENT ") { + + withSQLConf(KyuubiSQLConf.WATCHDOG_FORCED_MAXOUTPUTROWS.key -> "10") { + + assert(!sql("SELECT count(*) FROM t1") + .queryExecution.analyzed.isInstanceOf[GlobalLimit]) + + val sorts = List("", "ORDER BY cnt", "ORDER BY c1", "ORDER BY cnt, c1", "ORDER BY c1, cnt") + val havingConditions = List("", "HAVING cnt > 1") + + havingConditions.foreach { having => + sorts.foreach { sort => + assert(sql( + s""" + |SELECT c1, COUNT(*) as cnt + |FROM t1 + |GROUP BY c1 + |$having + |$sort + |""".stripMargin).queryExecution.analyzed.isInstanceOf[GlobalLimit]) + } + } + + limitAndExpecteds.foreach { case LimitAndExpected(limit, expected) => + havingConditions.foreach { having => + sorts.foreach { sort => + assert(sql( + s""" + |SELECT c1, COUNT(*) as cnt + |FROM t1 + |GROUP BY c1 + |$having + |$sort + |LIMIT $limit + |""".stripMargin).queryExecution.analyzed.maxRows.contains(expected)) + } + } + } + } + } + + test("test watchdog: SELECT with CTE forceMaxOutputRows") { + // simple CTE + val q1 = + """ + |WITH t2 AS ( + | SELECT * FROM t1 + |) + |""".stripMargin + + // nested CTE + val q2 = + """ + |WITH + | t AS (SELECT * FROM t1), + | t2 AS ( + | WITH t3 AS (SELECT * FROM t1) + | SELECT * FROM t3 + | ) + |""".stripMargin + withSQLConf(KyuubiSQLConf.WATCHDOG_FORCED_MAXOUTPUTROWS.key -> "10") { + + val sorts = List("", "ORDER BY c1", "ORDER BY c2") + + sorts.foreach { sort => + Seq(q1, q2).foreach { withQuery => + assert(sql( + s""" + |$withQuery + |SELECT * FROM t2 + |$sort + |""".stripMargin).queryExecution.analyzed.isInstanceOf[GlobalLimit]) + } + } + + limitAndExpecteds.foreach { case LimitAndExpected(limit, expected) => + sorts.foreach { sort => + Seq(q1, q2).foreach { withQuery => + assert(sql( + s""" + |$withQuery + |SELECT * FROM t2 + |$sort + |LIMIT $limit + |""".stripMargin).queryExecution.optimizedPlan.maxRows.contains(expected)) + } + } + } + } + } + + test("test watchdog: SELECT AGGREGATE WITH CTE forceMaxOutputRows") { + + withSQLConf(KyuubiSQLConf.WATCHDOG_FORCED_MAXOUTPUTROWS.key -> "10") { + + assert(!sql( + """ + |WITH custom_cte AS ( + |SELECT * FROM t1 + |) + | + |SELECT COUNT(*) + |FROM custom_cte + |""".stripMargin).queryExecution + .analyzed.isInstanceOf[GlobalLimit]) + + val sorts = List("", "ORDER BY cnt", "ORDER BY c1", "ORDER BY cnt, c1", "ORDER BY c1, cnt") + val havingConditions = List("", "HAVING cnt > 1") + + havingConditions.foreach { having => + sorts.foreach { sort => + assert(sql( + s""" + |WITH custom_cte AS ( + |SELECT * FROM t1 + |) + | + |SELECT c1, COUNT(*) as cnt + |FROM custom_cte + |GROUP BY c1 + |$having + |$sort + |""".stripMargin).queryExecution.analyzed.isInstanceOf[GlobalLimit]) + } + } + + limitAndExpecteds.foreach { case LimitAndExpected(limit, expected) => + havingConditions.foreach { having => + sorts.foreach { sort => + assert(sql( + s""" + |WITH custom_cte AS ( + |SELECT * FROM t1 + |) + | + |SELECT c1, COUNT(*) as cnt + |FROM 
custom_cte + |GROUP BY c1 + |$having + |$sort + |LIMIT $limit + |""".stripMargin).queryExecution.optimizedPlan.maxRows.contains(expected)) + } + } + } + } + } + + test("test watchdog: UNION Statement for forceMaxOutputRows") { + + withSQLConf(KyuubiSQLConf.WATCHDOG_FORCED_MAXOUTPUTROWS.key -> "10") { + + List("", "ALL").foreach { x => + assert(sql( + s""" + |SELECT c1, c2 FROM t1 + |UNION $x + |SELECT c1, c2 FROM t2 + |UNION $x + |SELECT c1, c2 FROM t3 + |""".stripMargin) + .queryExecution.analyzed.isInstanceOf[GlobalLimit]) + } + + val sorts = List("", "ORDER BY cnt", "ORDER BY c1", "ORDER BY cnt, c1", "ORDER BY c1, cnt") + val havingConditions = List("", "HAVING cnt > 1") + + List("", "ALL").foreach { x => + havingConditions.foreach { having => + sorts.foreach { sort => + assert(sql( + s""" + |SELECT c1, count(c2) as cnt + |FROM t1 + |GROUP BY c1 + |$having + |UNION $x + |SELECT c1, COUNT(c2) as cnt + |FROM t2 + |GROUP BY c1 + |$having + |UNION $x + |SELECT c1, COUNT(c2) as cnt + |FROM t3 + |GROUP BY c1 + |$having + |$sort + |""".stripMargin) + .queryExecution.analyzed.isInstanceOf[GlobalLimit]) + } + } + } + + limitAndExpecteds.foreach { case LimitAndExpected(limit, expected) => + assert(sql( + s""" + |SELECT c1, c2 FROM t1 + |UNION + |SELECT c1, c2 FROM t2 + |UNION + |SELECT c1, c2 FROM t3 + |LIMIT $limit + |""".stripMargin) + .queryExecution.analyzed.maxRows.contains(expected)) + } + } + } + + test("test watchdog: Select View Statement for forceMaxOutputRows") { + withSQLConf(KyuubiSQLConf.WATCHDOG_FORCED_MAXOUTPUTROWS.key -> "3") { + withTable("tmp_table", "tmp_union") { + withView("tmp_view", "tmp_view2") { + sql(s"create table tmp_table (a int, b int)") + sql(s"insert into tmp_table values (1,10),(2,20),(3,30),(4,40),(5,50)") + sql(s"create table tmp_union (a int, b int)") + sql(s"insert into tmp_union values (6,60),(7,70),(8,80),(9,90),(10,100)") + sql(s"create view tmp_view2 as select * from tmp_union") + assert(!sql( + s""" + |CREATE VIEW tmp_view + |as + |SELECT * FROM + |tmp_table + |""".stripMargin) + .queryExecution.analyzed.isInstanceOf[GlobalLimit]) + + assert(sql( + s""" + |SELECT * FROM + |tmp_view + |""".stripMargin) + .queryExecution.analyzed.maxRows.contains(3)) + + assert(sql( + s""" + |SELECT * FROM + |tmp_view + |limit 11 + |""".stripMargin) + .queryExecution.analyzed.maxRows.contains(3)) + + assert(sql( + s""" + |SELECT * FROM + |(select * from tmp_view + |UNION + |select * from tmp_view2) + |ORDER BY a + |DESC + |""".stripMargin) + .collect().head.get(0).equals(10)) + } + } + } + } + + test("test watchdog: Insert Statement for forceMaxOutputRows") { + + withSQLConf(KyuubiSQLConf.WATCHDOG_FORCED_MAXOUTPUTROWS.key -> "10") { + withTable("tmp_table", "tmp_insert") { + spark.sql(s"create table tmp_table (a int, b int)") + spark.sql(s"insert into tmp_table values (1,10),(2,20),(3,30),(4,40),(5,50)") + val multiInsertTableName1: String = "tmp_tbl1" + val multiInsertTableName2: String = "tmp_tbl2" + sql(s"drop table if exists $multiInsertTableName1") + sql(s"drop table if exists $multiInsertTableName2") + sql(s"create table $multiInsertTableName1 like tmp_table") + sql(s"create table $multiInsertTableName2 like tmp_table") + assert(!sql( + s""" + |FROM tmp_table + |insert into $multiInsertTableName1 select * limit 2 + |insert into $multiInsertTableName2 select * + |""".stripMargin) + .queryExecution.analyzed.isInstanceOf[GlobalLimit]) + } + } + } + + test("test watchdog: Distribute by for forceMaxOutputRows") { + + 
withSQLConf(KyuubiSQLConf.WATCHDOG_FORCED_MAXOUTPUTROWS.key -> "10") { + withTable("tmp_table") { + spark.sql(s"create table tmp_table (a int, b int)") + spark.sql(s"insert into tmp_table values (1,10),(2,20),(3,30),(4,40),(5,50)") + assert(sql( + s""" + |SELECT * + |FROM tmp_table + |DISTRIBUTE BY a + |""".stripMargin) + .queryExecution.analyzed.isInstanceOf[GlobalLimit]) + } + } + } +}
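
For reference, a minimal usage sketch of the watchdog wiring added by this patch, not part of the patch itself. The extension class name `org.apache.kyuubi.sql.KyuubiSparkSQLExtension` and the config constant `KyuubiSQLConf.WATCHDOG_FORCED_MAXOUTPUTROWS` come straight from the diff; the local master, the temp view `t1`, and its columns are placeholders assumed only for this example.

// Hypothetical standalone check, assuming the kyuubi-extension jar is on the classpath.
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.catalyst.plans.logical.GlobalLimit

import org.apache.kyuubi.sql.KyuubiSQLConf

object WatchdogQuickCheck {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder()
      .master("local[*]")
      // Register the Kyuubi extension; with this patch it also injects the watchdog rules.
      .config("spark.sql.extensions", "org.apache.kyuubi.sql.KyuubiSparkSQLExtension")
      // Cap non-LIMIT queries at 10 rows, mirroring WatchDogSuiteBase.
      .config(KyuubiSQLConf.WATCHDOG_FORCED_MAXOUTPUTROWS.key, "10")
      .getOrCreate()

    // `t1` is a placeholder table created only for this sketch.
    spark.range(0, 100).selectExpr("id AS c1", "id % 3 AS c2").createOrReplaceTempView("t1")

    val df = spark.sql("SELECT * FROM t1 ORDER BY c1")
    // ForcedMaxOutputRowsRule wraps the analyzed plan in a GlobalLimit when no LIMIT is given.
    assert(df.queryExecution.analyzed.isInstanceOf[GlobalLimit])
    assert(df.count() == 10)

    // An explicit LIMIT at or below the threshold is left untouched.
    assert(spark.sql("SELECT * FROM t1 LIMIT 3").count() == 3)

    spark.stop()
  }
}

The relocated MaxPartitionStrategy guards scans in a similar way: with KyuubiSQLConf.WATCHDOG_MAX_PARTITIONS set, planning a query that would read more partitions than allowed fails with MaxPartitionExceedException, as exercised by the "watchdog with scan maxPartitions" tests above.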