Skip to content
Closed
Show file tree
Hide file tree
Changes from 10 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -143,22 +143,41 @@ object ResolveHints {
object ResolveCoalesceHints extends Rule[LogicalPlan] {
private val COALESCE_HINT_NAMES = Set("COALESCE", "REPARTITION")

private def createRepartitionByExpression(
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

numPartitions: Int, parameters: Seq[Any], h: UnresolvedHint): RepartitionByExpression = {
val exprs = parameters.drop(1)
val errExprs = exprs.filter(!_.isInstanceOf[UnresolvedAttribute])
if (errExprs.nonEmpty) throw new AnalysisException(
s"""Invalid type exprs : $errExprs
|expects UnresolvedAttribute type
""".stripMargin)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Plz add tests for this exception.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

OK, I will add this later.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You don't add this test yet?

RepartitionByExpression(
exprs.map(_.asInstanceOf[UnresolvedAttribute]), h.child, numPartitions)
}

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

How about the case, SELECT /*+ REPARTITION(a) */ * FROM t?

def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperators {
case h: UnresolvedHint if COALESCE_HINT_NAMES.contains(h.name.toUpperCase(Locale.ROOT)) =>
val hintName = h.name.toUpperCase(Locale.ROOT)
val shuffle = hintName match {
case "REPARTITION" => true
case "COALESCE" => false
}
val numPartitions = h.parameters match {

h.parameters match {
case Seq(IntegerLiteral(numPartitions)) =>
numPartitions
Repartition(numPartitions, shuffle, h.child)
case Seq(numPartitions: Int) =>
numPartitions
Repartition(numPartitions, shuffle, h.child)

case param @ Seq(IntegerLiteral(numPartitions), _*) if shuffle =>
createRepartitionByExpression(numPartitions, param, h)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

how about passing param.tail instead of param here?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Eum, that will be more elegant。

case param @ Seq(numPartitions: Int, _*) if shuffle =>
createRepartitionByExpression(numPartitions, param, h)

case _ =>
throw new AnalysisException(s"$hintName Hint expects a partition number as parameter")
throw new AnalysisException("Repartition hint expects a partition number " +
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

plz keep $hintName here.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The test is here.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ah, nice.

"and columns")
}
Repartition(numPartitions, shuffle, h.child)
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,10 +24,11 @@ import org.apache.log4j.spi.LoggingEvent

import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.dsl.plans._
import org.apache.spark.sql.catalyst.expressions.Literal
import org.apache.spark.sql.catalyst.expressions.{AttributeReference, Literal}
import org.apache.spark.sql.catalyst.parser.CatalystSqlParser
import org.apache.spark.sql.catalyst.plans.Inner
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.types.IntegerType

class ResolveHintsSuite extends AnalysisTest {
import org.apache.spark.sql.catalyst.analysis.TestRelations._
Expand Down Expand Up @@ -150,24 +151,42 @@ class ResolveHintsSuite extends AnalysisTest {
UnresolvedHint("RePARTITion", Seq(Literal(200)), table("TaBlE")),
Repartition(numPartitions = 200, shuffle = true, child = testRelation))

val errMsgCoal = "COALESCE Hint expects a partition number as parameter"
val errMsg = "Repartition hint expects a partition number and columns"

assertAnalysisError(
UnresolvedHint("COALESCE", Seq.empty, table("TaBlE")),
Seq(errMsgCoal))
Seq(errMsg))
assertAnalysisError(
UnresolvedHint("COALESCE", Seq(Literal(10), Literal(false)), table("TaBlE")),
Seq(errMsgCoal))
Seq(errMsg))
assertAnalysisError(
UnresolvedHint("COALESCE", Seq(Literal(1.0)), table("TaBlE")),
Seq(errMsgCoal))
Seq(errMsg))

val errMsgRepa = "REPARTITION Hint expects a partition number as parameter"
assertAnalysisError(
UnresolvedHint("REPARTITION", Seq(UnresolvedAttribute("a")), table("TaBlE")),
Seq(errMsgRepa))
Seq(errMsg))
assertAnalysisError(
UnresolvedHint("REPARTITION", Seq(Literal(true)), table("TaBlE")),
Seq(errMsgRepa))
Seq(errMsg))

checkAnalysis(
UnresolvedHint("RePartition", Seq(Literal(10), UnresolvedAttribute("a")), table("TaBlE")),
RepartitionByExpression(Seq(AttributeReference("a", IntegerType)()), testRelation, 10))

checkAnalysis(
UnresolvedHint("REPARTITION", Seq(Literal(10), UnresolvedAttribute("a")), table("TaBlE")),
RepartitionByExpression(Seq(AttributeReference("a", IntegerType)()), testRelation, 10))

assertAnalysisError(
UnresolvedHint("REPARTITION", Seq(AttributeReference("a", IntegerType)()), table("TaBlE")),
Seq(errMsg))

assertAnalysisError(
UnresolvedHint("REPARTITION",
Seq(Literal(1.0), AttributeReference("a", IntegerType)()),
table("TaBlE")),
Seq(errMsg))
}

test("log warnings for invalid hints") {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -625,6 +625,37 @@ class PlanParserSuite extends AnalysisTest {
table("t").select(star()))))

intercept("SELECT /*+ COALESCE(30 + 50) */ * FROM t", "mismatched input")

comparePlans(
parsePlan("SELECT /*+ REPARTITION(100, c) */ * FROM t"),
UnresolvedHint("REPARTITION", Seq(Literal(100), UnresolvedAttribute("c")),
table("t").select(star())))

comparePlans(
parsePlan("SELECT /*+ REPARTITION(100, c), COALESCE(50) */ * FROM t"),
UnresolvedHint("REPARTITION", Seq(Literal(100), UnresolvedAttribute("c")),
UnresolvedHint("COALESCE", Seq(Literal(50)),
table("t").select(star()))))

comparePlans(
parsePlan("SELECT /*+ REPARTITION(100, c), BROADCASTJOIN(u), COALESCE(50) */ * FROM t"),
UnresolvedHint("REPARTITION", Seq(Literal(100), UnresolvedAttribute("c")),
UnresolvedHint("BROADCASTJOIN", Seq($"u"),
UnresolvedHint("COALESCE", Seq(Literal(50)),
table("t").select(star())))))

comparePlans(
parsePlan(
"""
|SELECT
|/*+ REPARTITION(100, c), BROADCASTJOIN(u), COALESCE(50), REPARTITION(300, c) */
|* FROM t
""".stripMargin),
UnresolvedHint("REPARTITION", Seq(Literal(100), UnresolvedAttribute("c")),
UnresolvedHint("BROADCASTJOIN", Seq($"u"),
UnresolvedHint("COALESCE", Seq(Literal(50)),
UnresolvedHint("REPARTITION", Seq(Literal(300), UnresolvedAttribute("c")),
table("t").select(star()))))))
}

test("SPARK-20854: select hint syntax with expressions") {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -68,5 +68,9 @@ class DataFrameHintSuite extends AnalysisTest with SharedSQLContext {
check(
df.hint("REPARTITION", 100),
UnresolvedHint("REPARTITION", Seq(100), df.logicalPlan))

check(
df.hint("REPARTITION", 10, $"id".expr),
UnresolvedHint("REPARTITION", Seq(10, $"id".expr), df.logicalPlan))
}
}