@@ -32,15 +32,16 @@ import org.apache.kyuubi.sql.watchdog.{ForcedMaxOutputRowsRule, MarkAggregateOrderRule, MaxPartitionStrategy}
 class KyuubiSparkSQLExtension extends (SparkSessionExtensions => Unit) {
   override def apply(extensions: SparkSessionExtensions): Unit = {
     KyuubiSparkSQLCommonExtension.injectCommonExtensions(extensions)
-    // a helper rule for ForcedMaxOutputRowsRule
-    extensions.injectResolutionRule(MarkAggregateOrderRule)

     extensions.injectPostHocResolutionRule(KyuubiSqlClassification)
     extensions.injectPostHocResolutionRule(RepartitionBeforeWritingDatasource)
     extensions.injectPostHocResolutionRule(RepartitionBeforeWritingHive)
-    extensions.injectPostHocResolutionRule(ForcedMaxOutputRowsRule)
     extensions.injectPostHocResolutionRule(DropIgnoreNonexistent)

     // watchdog extension
+    // a helper rule for ForcedMaxOutputRowsRule
+    extensions.injectResolutionRule(MarkAggregateOrderRule)
+    extensions.injectPostHocResolutionRule(ForcedMaxOutputRowsRule)
     extensions.injectPlannerStrategy(MaxPartitionStrategy)
   }
 }
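This hunk groups the two ForcedMaxOutputRows-related injections under the watchdog section, so all watchdog rules are registered in one place. For context, none of these rules fire unless the extension itself is registered with Spark. A minimal sketch of enabling it, assuming a standard Spark 3.x setup (the builder call and local master below are illustrative, not part of this diff):

import org.apache.spark.sql.SparkSession

// Register the Kyuubi SQL extension so the rules injected above run in every
// session; equivalently, pass
//   --conf spark.sql.extensions=org.apache.kyuubi.sql.KyuubiSparkSQLExtension
val spark = SparkSession.builder()
  .master("local[*]")
  .config("spark.sql.extensions", "org.apache.kyuubi.sql.KyuubiSparkSQLExtension")
  // Cap the rows a query without LIMIT may return; the key is assumed to be
  // the one behind KyuubiSQLConf.WATCHDOG_FORCED_MAXOUTPUTROWS.
  .config("spark.sql.watchdog.forcedMaxOutputRows", "1000")
  .getOrCreate()

With that in place, MarkAggregateOrderRule runs as a resolution rule, ForcedMaxOutputRowsRule as a post-hoc resolution rule, and MaxPartitionStrategy as a planner strategy. The second hunk below rewrites the watchdog rule file itself.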
@@ -18,129 +18,7 @@
 package org.apache.kyuubi.sql.watchdog

 import org.apache.spark.sql.SparkSession
-import org.apache.spark.sql.catalyst.analysis.AnalysisContext
-import org.apache.spark.sql.catalyst.dsl.expressions._
-import org.apache.spark.sql.catalyst.expressions.Alias
-import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, Distinct, Filter, Limit, LogicalPlan, Project, RepartitionByExpression, Sort, Union}
-import org.apache.spark.sql.catalyst.rules.Rule
-import org.apache.spark.sql.catalyst.trees.TreeNodeTag
-import org.apache.spark.sql.execution.command.DataWritingCommand

 import org.apache.kyuubi.sql.KyuubiSQLConf
+case class ForcedMaxOutputRowsRule(sparkSession: SparkSession) extends ForcedMaxOutputRowsBase {}

-object ForcedMaxOutputRowsConstraint {
-  val CHILD_AGGREGATE: TreeNodeTag[String] = TreeNodeTag[String]("__kyuubi_child_agg__")
-  val CHILD_AGGREGATE_FLAG: String = "__kyuubi_child_agg__"
-}
-
-/*
- * The ForcedMaxOutputRows rule limits the number of output rows, so that a
- * query without a LIMIT clause cannot unexpectedly return a huge result set.
- * It mainly applies to cases such as:
- *
- * case 1:
- * {{{
- * SELECT [c1, c2, ...]
- * }}}
- *
- * case 2:
- * {{{
- * WITH CTE AS (
- * ...)
- * SELECT [c1, c2, ...] FROM CTE ...
- * }}}
- *
- * The logical rule adds a GlobalLimit node above the root Project.
- */
-case class ForcedMaxOutputRowsRule(session: SparkSession) extends Rule[LogicalPlan] {
-
-  private def isChildAggregate(a: Aggregate): Boolean = a
-    .aggregateExpressions.exists(p =>
-      p.getTagValue(ForcedMaxOutputRowsConstraint.CHILD_AGGREGATE)
-        .contains(ForcedMaxOutputRowsConstraint.CHILD_AGGREGATE_FLAG))
-
-  private def isView: Boolean = {
-    val nestedViewDepth = AnalysisContext.get.nestedViewDepth
-    nestedViewDepth > 0
-  }
-
-  private def canInsertLimitInner(p: LogicalPlan): Boolean = p match {
-    case Aggregate(_, Alias(_, "havingCondition") :: Nil, _) => false
-    case agg: Aggregate => !isChildAggregate(agg)
-    case _: RepartitionByExpression => true
-    case _: Distinct => true
-    case _: Filter => true
-    case _: Project => true
-    case Limit(_, _) => true
-    case _: Sort => true
-    case Union(children, _, _) =>
-      if (children.exists(_.isInstanceOf[DataWritingCommand])) {
-        false
-      } else {
-        true
-      }
-    case _ => false
-  }
-
-  private def canInsertLimit(p: LogicalPlan, maxOutputRowsOpt: Option[Int]): Boolean = {
-    maxOutputRowsOpt match {
-      case Some(forcedMaxOutputRows) => canInsertLimitInner(p) &&
-        !p.maxRows.exists(_ <= forcedMaxOutputRows) &&
-        !isView
-      case None => false
-    }
-  }
-
-  override def apply(plan: LogicalPlan): LogicalPlan = {
-    val maxOutputRowsOpt = conf.getConf(KyuubiSQLConf.WATCHDOG_FORCED_MAXOUTPUTROWS)
-    plan match {
-      case p if p.resolved && canInsertLimit(p, maxOutputRowsOpt) =>
-        Limit(maxOutputRowsOpt.get, plan)
-      case _ => plan
-    }
-  }
-}

-case class MarkAggregateOrderRule(session: SparkSession) extends Rule[LogicalPlan] {
-
-  private def markChildAggregate(a: Aggregate): Unit = {
-    // mark the child aggregate's expressions
-    a.aggregateExpressions.filter(_.resolved).foreach(_.setTagValue(
-      ForcedMaxOutputRowsConstraint.CHILD_AGGREGATE,
-      ForcedMaxOutputRowsConstraint.CHILD_AGGREGATE_FLAG))
-  }
-
-  private def findAndMarkChildAggregate(plan: LogicalPlan): LogicalPlan = plan match {
-    /*
-     * This case handles ORDER BY on a grouping column rather than on an
-     * aggregate column, as below:
-     *   SELECT c1, COUNT(*) AS cnt
-     *   FROM t1
-     *   GROUP BY c1
-     *   ORDER BY c1
-     */
-    case a: Aggregate
-        if a.aggregateExpressions
-          .exists(x => x.resolved && x.name.equals("aggOrder")) =>
-      markChildAggregate(a)
-      plan
-    case _ =>
-      plan.children.foreach(_.foreach {
-        case agg: Aggregate => markChildAggregate(agg)
-        case _ => ()
-      })
-      plan
-  }

-  override def apply(plan: LogicalPlan): LogicalPlan = conf.getConf(
-    KyuubiSQLConf.WATCHDOG_FORCED_MAXOUTPUTROWS) match {
-    case Some(_) => findAndMarkChildAggregate(plan)
-    case _ => plan
-  }
-}
+case class MarkAggregateOrderRule(sparkSession: SparkSession) extends MarkAggregateOrderBase {}
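The net effect of this second hunk: the full rule implementations are deleted and replaced by one-line wrappers, which suggests the logic moved into shared base classes (ForcedMaxOutputRowsBase, MarkAggregateOrderBase) that multiple Spark-version modules can reuse. A minimal sketch of the shape this implies; the trait body below is an assumption reconstructed from the removed code, not the actual base class:

import org.apache.spark.sql.catalyst.expressions.Literal
import org.apache.spark.sql.catalyst.plans.logical.{Limit, LogicalPlan}
import org.apache.spark.sql.catalyst.rules.Rule

import org.apache.kyuubi.sql.KyuubiSQLConf

// Hypothetical shared base: carries the version-independent logic so each
// Spark-version module only declares the thin case-class wrapper seen above.
trait ForcedMaxOutputRowsBase extends Rule[LogicalPlan] {
  // Simplified guard; the removed code above also rejects views, child
  // aggregates, and plans already bounded below the threshold.
  protected def canInsertLimit(plan: LogicalPlan): Boolean = plan.resolved

  override def apply(plan: LogicalPlan): LogicalPlan =
    conf.getConf(KyuubiSQLConf.WATCHDOG_FORCED_MAXOUTPUTROWS) match {
      case Some(maxRows) if canInsertLimit(plan) => Limit(Literal(maxRows), plan)
      case _ => plan
    }
}

Logic that differs across Spark releases, such as matching the Union constructor whose arity changed between versions in the removed canInsertLimitInner, can then be overridden in each version module rather than duplicated wholesale.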