Skip to content

Commit 24bce72

Browse files
szehon-ho authored and sunchao committed
[SPARK-48012][SQL] SPJ: Support Transform Expressions for One Side Shuffle
### Why are the changes needed?
Support SPJ one-side shuffle if the other side has a partition transform expression.

### How was this patch tested?
New unit test in KeyGroupedPartitioningSuite.

### Was this patch authored or co-authored using generative AI tooling?
No.

Closes #46255 from szehon-ho/spj_auto_bucket.

Authored-by: Szehon Ho <szehon.apache@gmail.com>
Signed-off-by: Chao Sun <chao@openai.com>
1 parent 201df0d commit 24bce72

5 files changed

Lines changed: 179 additions & 26 deletions

File tree

core/src/main/scala/org/apache/spark/Partitioner.scala

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@ package org.apache.spark
1919

2020
import java.io.{IOException, ObjectInputStream, ObjectOutputStream}
2121

22+
import scala.collection.immutable.ArraySeq
2223
import scala.collection.mutable
2324
import scala.collection.mutable.ArrayBuffer
2425
import scala.math.log10
@@ -149,7 +150,9 @@ private[spark] class KeyGroupedPartitioner(
149150
override val numPartitions: Int) extends Partitioner {
150151
override def getPartition(key: Any): Int = {
151152
val keys = key.asInstanceOf[Seq[Any]]
152-
valueMap.getOrElseUpdate(keys, Utils.nonNegativeMod(keys.hashCode, numPartitions))
153+
val normalizedKeys = ArraySeq.from(keys)
154+
valueMap.getOrElseUpdate(normalizedKeys,
155+
Utils.nonNegativeMod(normalizedKeys.hashCode, numPartitions))
153156
}
154157
}
155158

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/TransformExpression.scala

Lines changed: 24 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,10 @@
1717

1818
package org.apache.spark.sql.catalyst.expressions
1919

20-
import org.apache.spark.sql.connector.catalog.functions.{BoundFunction, Reducer, ReducibleFunction}
20+
import org.apache.spark.sql.catalyst.InternalRow
21+
import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode}
22+
import org.apache.spark.sql.connector.catalog.functions.{BoundFunction, Reducer, ReducibleFunction, ScalarFunction}
23+
import org.apache.spark.sql.errors.QueryExecutionErrors
2124
import org.apache.spark.sql.types.DataType
2225

2326
/**
@@ -30,7 +33,7 @@ import org.apache.spark.sql.types.DataType
3033
case class TransformExpression(
3134
function: BoundFunction,
3235
children: Seq[Expression],
33-
numBucketsOpt: Option[Int] = None) extends Expression with Unevaluable {
36+
numBucketsOpt: Option[Int] = None) extends Expression {
3437

3538
override def nullable: Boolean = true
3639

@@ -113,4 +116,23 @@ case class TransformExpression(
113116

114117
override protected def withNewChildrenInternal(newChildren: IndexedSeq[Expression]): Expression =
115118
copy(children = newChildren)
119+
120+
private lazy val resolvedFunction: Option[Expression] = this match {
121+
case TransformExpression(scalarFunc: ScalarFunction[_], arguments, Some(numBuckets)) =>
122+
Some(V2ExpressionUtils.resolveScalarFunction(scalarFunc,
123+
Seq(Literal(numBuckets)) ++ arguments))
124+
case TransformExpression(scalarFunc: ScalarFunction[_], arguments, None) =>
125+
Some(V2ExpressionUtils.resolveScalarFunction(scalarFunc, arguments))
126+
case _ => None
127+
}
128+
129+
override def eval(input: InternalRow): Any = {
130+
resolvedFunction match {
131+
case Some(fn) => fn.eval(input)
132+
case None => throw QueryExecutionErrors.cannotEvaluateExpressionError(this)
133+
}
134+
}
135+
136+
override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode =
137+
throw QueryExecutionErrors.cannotGenerateCodeForExpressionError(this)
116138
}

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala

Lines changed: 22 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -871,12 +871,30 @@ case class KeyGroupedShuffleSpec(
871871
if (results.forall(p => p.isEmpty)) None else Some(results)
872872
}
873873

874-
override def canCreatePartitioning: Boolean = SQLConf.get.v2BucketingShuffleEnabled &&
875-
// Only support partition expressions are AttributeReference for now
876-
partitioning.expressions.forall(_.isInstanceOf[AttributeReference])
874+
override def canCreatePartitioning: Boolean = {
875+
// Allow one side shuffle for SPJ for now only if partially-clustered is not enabled
876+
// and for join keys less than partition keys only if transforms are not enabled.
877+
val checkExprType = if (SQLConf.get.v2BucketingAllowJoinKeysSubsetOfPartitionKeys) {
878+
e: Expression => e.isInstanceOf[AttributeReference]
879+
} else {
880+
e: Expression => e.isInstanceOf[AttributeReference] || e.isInstanceOf[TransformExpression]
881+
}
882+
SQLConf.get.v2BucketingShuffleEnabled &&
883+
!SQLConf.get.v2BucketingPartiallyClusteredDistributionEnabled &&
884+
partitioning.expressions.forall(checkExprType)
885+
}
886+
887+
877888

878889
override def createPartitioning(clustering: Seq[Expression]): Partitioning = {
879-
KeyGroupedPartitioning(clustering, partitioning.numPartitions, partitioning.partitionValues)
890+
val newExpressions: Seq[Expression] = clustering.zip(partitioning.expressions).map {
891+
case (c, e: TransformExpression) => TransformExpression(
892+
e.function, Seq(c), e.numBucketsOpt)
893+
case (c, _) => c
894+
}
895+
KeyGroupedPartitioning(newExpressions,
896+
partitioning.numPartitions,
897+
partitioning.partitionValues)
880898
}
881899
}
882900

sql/core/src/test/scala/org/apache/spark/sql/connector/KeyGroupedPartitioningSuite.scala

Lines changed: 119 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -1136,7 +1136,7 @@ class KeyGroupedPartitioningSuite extends DistributionAndOrderingSuiteBase {
11361136
val df = createJoinTestDF(Seq("arrive_time" -> "time"))
11371137
val shuffles = collectShuffles(df.queryExecution.executedPlan)
11381138
if (shuffle) {
1139-
assert(shuffles.size == 2, "partitioning with transform not work now")
1139+
assert(shuffles.size == 1, "partitioning with transform should trigger SPJ")
11401140
} else {
11411141
assert(shuffles.size == 2, "should add two side shuffle when bucketing shuffle one side" +
11421142
" is not enabled")
@@ -1991,22 +1991,19 @@ class KeyGroupedPartitioningSuite extends DistributionAndOrderingSuiteBase {
19911991
"(6, 50.0, cast('2023-02-01' as timestamp))")
19921992

19931993
Seq(true, false).foreach { pushdownValues =>
1994-
Seq(true, false).foreach { partiallyClustered =>
1995-
withSQLConf(
1996-
SQLConf.V2_BUCKETING_SHUFFLE_ENABLED.key -> "true",
1997-
SQLConf.V2_BUCKETING_PUSH_PART_VALUES_ENABLED.key -> pushdownValues.toString,
1998-
SQLConf.V2_BUCKETING_PARTIALLY_CLUSTERED_DISTRIBUTION_ENABLED.key
1999-
-> partiallyClustered.toString,
2000-
SQLConf.V2_BUCKETING_ALLOW_JOIN_KEYS_SUBSET_OF_PARTITION_KEYS.key -> "true") {
2001-
val df = createJoinTestDF(Seq("id" -> "item_id"))
2002-
val shuffles = collectShuffles(df.queryExecution.executedPlan)
2003-
assert(shuffles.size == 1, "SPJ should be triggered")
2004-
checkAnswer(df, Seq(Row(1, "aa", 30.0, 42.0),
2005-
Row(1, "aa", 30.0, 89.0),
2006-
Row(1, "aa", 40.0, 42.0),
2007-
Row(1, "aa", 40.0, 89.0),
2008-
Row(3, "bb", 10.0, 19.5)))
2009-
}
1994+
withSQLConf(
1995+
SQLConf.V2_BUCKETING_SHUFFLE_ENABLED.key -> "true",
1996+
SQLConf.V2_BUCKETING_PUSH_PART_VALUES_ENABLED.key -> pushdownValues.toString,
1997+
SQLConf.V2_BUCKETING_PARTIALLY_CLUSTERED_DISTRIBUTION_ENABLED.key -> "false",
1998+
SQLConf.V2_BUCKETING_ALLOW_JOIN_KEYS_SUBSET_OF_PARTITION_KEYS.key -> "true") {
1999+
val df = createJoinTestDF(Seq("id" -> "item_id"))
2000+
val shuffles = collectShuffles(df.queryExecution.executedPlan)
2001+
assert(shuffles.size == 1, "SPJ should be triggered")
2002+
checkAnswer(df, Seq(Row(1, "aa", 30.0, 42.0),
2003+
Row(1, "aa", 30.0, 89.0),
2004+
Row(1, "aa", 40.0, 42.0),
2005+
Row(1, "aa", 40.0, 89.0),
2006+
Row(3, "bb", 10.0, 19.5)))
20102007
}
20112008
}
20122009
}
@@ -2052,4 +2049,109 @@ class KeyGroupedPartitioningSuite extends DistributionAndOrderingSuiteBase {
20522049
}
20532050
}
20542051
}
2052+
2053+
test("SPARK-48012: one-side shuffle with partition transforms") {
2054+
val items_partitions = Array(bucket(2, "id"), identity("arrive_time"))
2055+
val items_partitions2 = Array(identity("arrive_time"), bucket(2, "id"))
2056+
2057+
Seq(items_partitions, items_partitions2).foreach { partition =>
2058+
catalog.clearTables()
2059+
2060+
createTable(items, itemsColumns, partition)
2061+
sql(s"INSERT INTO testcat.ns.$items VALUES " +
2062+
"(1, 'aa', 40.0, cast('2020-01-01' as timestamp)), " +
2063+
"(1, 'bb', 30.0, cast('2020-01-01' as timestamp)), " +
2064+
"(1, 'cc', 30.0, cast('2020-01-02' as timestamp)), " +
2065+
"(3, 'dd', 10.0, cast('2020-01-01' as timestamp)), " +
2066+
"(4, 'ee', 15.5, cast('2020-02-01' as timestamp)), " +
2067+
"(5, 'ff', 32.1, cast('2020-03-01' as timestamp))")
2068+
2069+
createTable(purchases, purchasesColumns, Array.empty)
2070+
sql(s"INSERT INTO testcat.ns.$purchases VALUES " +
2071+
"(1, 42.0, cast('2020-01-01' as timestamp)), " +
2072+
"(2, 10.7, cast('2020-01-01' as timestamp))," +
2073+
"(3, 19.5, cast('2020-02-01' as timestamp))," +
2074+
"(4, 56.5, cast('2020-02-01' as timestamp))")
2075+
2076+
withSQLConf(
2077+
SQLConf.V2_BUCKETING_SHUFFLE_ENABLED.key -> "true") {
2078+
val df = createJoinTestDF(Seq("id" -> "item_id", "arrive_time" -> "time"))
2079+
val shuffles = collectShuffles(df.queryExecution.executedPlan)
2080+
assert(shuffles.size == 1, "only shuffle side that does not report partitioning")
2081+
2082+
checkAnswer(df, Seq(
2083+
Row(1, "bb", 30.0, 42.0),
2084+
Row(1, "aa", 40.0, 42.0),
2085+
Row(4, "ee", 15.5, 56.5)))
2086+
}
2087+
}
2088+
}
2089+
2090+
test("SPARK-48012: one-side shuffle with partition transforms and pushdown values") {
2091+
val items_partitions = Array(bucket(2, "id"), identity("arrive_time"))
2092+
createTable(items, itemsColumns, items_partitions)
2093+
2094+
sql(s"INSERT INTO testcat.ns.$items VALUES " +
2095+
"(1, 'aa', 40.0, cast('2020-01-01' as timestamp)), " +
2096+
"(1, 'bb', 30.0, cast('2020-01-01' as timestamp)), " +
2097+
"(1, 'cc', 30.0, cast('2020-01-02' as timestamp))")
2098+
2099+
createTable(purchases, purchasesColumns, Array.empty)
2100+
sql(s"INSERT INTO testcat.ns.$purchases VALUES " +
2101+
"(1, 42.0, cast('2020-01-01' as timestamp)), " +
2102+
"(2, 10.7, cast('2020-01-01' as timestamp))")
2103+
2104+
Seq(true, false).foreach { pushDown => {
2105+
withSQLConf(
2106+
SQLConf.V2_BUCKETING_SHUFFLE_ENABLED.key -> "true",
2107+
SQLConf.V2_BUCKETING_PUSH_PART_VALUES_ENABLED.key ->
2108+
pushDown.toString) {
2109+
val df = createJoinTestDF(Seq("id" -> "item_id", "arrive_time" -> "time"))
2110+
val shuffles = collectShuffles(df.queryExecution.executedPlan)
2111+
assert(shuffles.size == 1, "only shuffle side that does not report partitioning")
2112+
2113+
checkAnswer(df, Seq(
2114+
Row(1, "bb", 30.0, 42.0),
2115+
Row(1, "aa", 40.0, 42.0)))
2116+
}
2117+
}
2118+
}
2119+
}
2120+
2121+
test("SPARK-48012: one-side shuffle with partition transforms " +
2122+
"with fewer join keys than partition kes") {
2123+
val items_partitions = Array(bucket(2, "id"), identity("name"))
2124+
createTable(items, itemsColumns, items_partitions)
2125+
2126+
sql(s"INSERT INTO testcat.ns.$items VALUES " +
2127+
"(1, 'aa', 40.0, cast('2020-01-01' as timestamp)), " +
2128+
"(1, 'aa', 30.0, cast('2020-01-02' as timestamp)), " +
2129+
"(3, 'bb', 10.0, cast('2020-01-01' as timestamp)), " +
2130+
"(4, 'cc', 15.5, cast('2020-02-01' as timestamp))")
2131+
2132+
createTable(purchases, purchasesColumns, Array.empty)
2133+
sql(s"INSERT INTO testcat.ns.$purchases VALUES " +
2134+
"(1, 42.0, cast('2020-01-01' as timestamp)), " +
2135+
"(1, 89.0, cast('2020-01-03' as timestamp)), " +
2136+
"(3, 19.5, cast('2020-02-01' as timestamp)), " +
2137+
"(5, 26.0, cast('2023-01-01' as timestamp)), " +
2138+
"(6, 50.0, cast('2023-02-01' as timestamp))")
2139+
2140+
withSQLConf(
2141+
SQLConf.REQUIRE_ALL_CLUSTER_KEYS_FOR_CO_PARTITION.key -> "false",
2142+
SQLConf.V2_BUCKETING_SHUFFLE_ENABLED.key -> "true",
2143+
SQLConf.V2_BUCKETING_PUSH_PART_VALUES_ENABLED.key -> "true",
2144+
SQLConf.V2_BUCKETING_PARTIALLY_CLUSTERED_DISTRIBUTION_ENABLED.key -> "false",
2145+
SQLConf.V2_BUCKETING_ALLOW_JOIN_KEYS_SUBSET_OF_PARTITION_KEYS.key -> "true") {
2146+
val df = createJoinTestDF(Seq("id" -> "item_id"))
2147+
val shuffles = collectShuffles(df.queryExecution.executedPlan)
2148+
assert(shuffles.size == 2, "SPJ should not be triggered for transform expression with" +
2149+
"less join keys than partition keys for now.")
2150+
checkAnswer(df, Seq(Row(1, "aa", 30.0, 42.0),
2151+
Row(1, "aa", 30.0, 89.0),
2152+
Row(1, "aa", 40.0, 42.0),
2153+
Row(1, "aa", 40.0, 89.0),
2154+
Row(3, "bb", 10.0, 19.5)))
2155+
}
2156+
}
20552157
}

sql/core/src/test/scala/org/apache/spark/sql/connector/catalog/functions/transformFunctions.scala

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,9 +16,11 @@
1616
*/
1717
package org.apache.spark.sql.connector.catalog.functions
1818

19-
import java.sql.Timestamp
19+
import java.time.{Instant, LocalDate, ZoneId}
20+
import java.time.temporal.ChronoUnit
2021

2122
import org.apache.spark.sql.catalyst.InternalRow
23+
import org.apache.spark.sql.catalyst.util.DateTimeUtils
2224
import org.apache.spark.sql.types._
2325
import org.apache.spark.unsafe.types.UTF8String
2426

@@ -44,7 +46,13 @@ object YearsFunction extends ScalarFunction[Long] {
4446
override def name(): String = "years"
4547
override def canonicalName(): String = name()
4648

47-
def invoke(ts: Long): Long = new Timestamp(ts).getYear + 1900
49+
val UTC: ZoneId = ZoneId.of("UTC")
50+
val EPOCH_LOCAL_DATE: LocalDate = Instant.EPOCH.atZone(UTC).toLocalDate
51+
52+
def invoke(ts: Long): Long = {
53+
val localDate = DateTimeUtils.microsToInstant(ts).atZone(UTC).toLocalDate
54+
ChronoUnit.YEARS.between(EPOCH_LOCAL_DATE, localDate)
55+
}
4856
}
4957

5058
object DaysFunction extends BoundFunction {

0 commit comments

Comments
 (0)