@@ -21,7 +21,7 @@ import org.apache.spark.sql.catalyst.{FunctionIdentifier, TableIdentifier}
import org.apache.spark.sql.catalyst.catalog.BucketSpec
import org.apache.spark.sql.catalyst.parser.CatalystSqlParser
import org.apache.spark.sql.catalyst.util.quoteIfNeeded
import org.apache.spark.sql.connector.expressions.{BucketTransform, IdentityTransform, LogicalExpressions, Transform}
import org.apache.spark.sql.connector.expressions.{IdentityTransform, LogicalExpressions, Transform}
import org.apache.spark.sql.errors.QueryCompilationErrors

/**
@@ -37,7 +37,7 @@ private[sql] object CatalogV2Implicits {
}

implicit class BucketSpecHelper(spec: BucketSpec) {
def asTransform: BucketTransform = {
def asTransform: Transform = {
val references = spec.bucketColumnNames.map(col => reference(Seq(col)))
if (spec.sortColumnNames.nonEmpty) {
val sortedCol = spec.sortColumnNames.map(col => reference(Seq(col)))
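For context on the widened return type above: once sort columns are present the helper has to emit a different transform, so it can no longer promise a `BucketTransform`. A minimal, self-contained sketch of that branching; `SimpleSpec` and `transformName` are illustrative stand-ins, not Spark API:

```scala
// Illustrative stand-in for Spark's BucketSpec; not the real class.
final case class SimpleSpec(numBuckets: Int, bucketCols: Seq[String], sortCols: Seq[String])

// Mirrors the shape of asTransform: sort columns present => "sorted_bucket", otherwise "bucket".
def transformName(spec: SimpleSpec): String =
  if (spec.sortCols.nonEmpty) "sorted_bucket" else "bucket"

// transformName(SimpleSpec(4, Seq("c", "d"), Seq("e", "f"))) == "sorted_bucket"
// transformName(SimpleSpec(4, Seq("b"), Nil))                == "bucket"
```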
@@ -17,6 +17,7 @@

package org.apache.spark.sql.connector.expressions

import org.apache.spark.SparkException
import org.apache.spark.sql.catalyst
import org.apache.spark.sql.catalyst.parser.CatalystSqlParser
import org.apache.spark.sql.types.{DataType, IntegerType, StringType}
@@ -48,8 +49,8 @@ private[sql] object LogicalExpressions {
def bucket(
numBuckets: Int,
references: Array[NamedReference],
sortedCols: Array[NamedReference]): BucketTransform =
BucketTransform(literal(numBuckets, IntegerType), references, sortedCols)
sortedCols: Array[NamedReference]): SortedBucketTransform =
SortedBucketTransform(literal(numBuckets, IntegerType), references, sortedCols)

def identity(reference: NamedReference): IdentityTransform = IdentityTransform(reference)

@@ -101,8 +102,7 @@ private[sql] abstract class SingleColumnTransform(ref: NamedReference) extends R

private[sql] final case class BucketTransform(
numBuckets: Literal[Int],
columns: Seq[NamedReference],
sortedColumns: Seq[NamedReference] = Seq.empty[NamedReference]) extends RewritableTransform {
columns: Seq[NamedReference]) extends RewritableTransform {

override val name: String = "bucket"

@@ -112,46 +112,62 @@ private[sql] final case class BucketTransform(

override def arguments: Array[Expression] = numBuckets +: columns.toArray

override def toString: String =
if (sortedColumns.nonEmpty) {
s"bucket(${arguments.map(_.describe).mkString(", ")}," +
s" ${sortedColumns.map(_.describe).mkString(", ")})"
} else {
s"bucket(${arguments.map(_.describe).mkString(", ")})"
}
override def describe: String = s"bucket(${arguments.map(_.describe).mkString(", ")})"

override def toString: String = describe

override def withReferences(newReferences: Seq[NamedReference]): Transform = {
this.copy(columns = newReferences)
}
}

private[sql] object BucketTransform {
def unapply(expr: Expression): Option[(Int, FieldReference, FieldReference)] =
expr match {
case transform: Transform =>
def unapply(transform: Transform): Option[(Int, Seq[NamedReference], Seq[NamedReference])] =
transform match {
case BucketTransform(n, FieldReference(parts), FieldReference(sortCols)) =>
Some((n, FieldReference(parts), FieldReference(sortCols)))
case NamedTransform("sorted_bucket", arguments) =>
var posOfLit: Int = -1
var numOfBucket: Int = -1
arguments.zipWithIndex.foreach {
case (Lit(value: Int, IntegerType), i) =>
numOfBucket = value
posOfLit = i
case _ =>
None
}
Some(numOfBucket, arguments.take(posOfLit).map(_.asInstanceOf[NamedReference]),
arguments.drop(posOfLit + 1).map(_.asInstanceOf[NamedReference]))
case NamedTransform("bucket", arguments) =>
var numOfBucket: Int = -1
arguments(0) match {
case Lit(value: Int, IntegerType) =>
numOfBucket = value
case _ => throw new SparkException("The first element in BucketTransform arguments " +
"should be an Integer Literal.")
}
Some(numOfBucket, arguments.drop(1).map(_.asInstanceOf[NamedReference]),
Seq.empty[FieldReference])
case _ =>
None
}
}

def unapply(transform: Transform): Option[(Int, NamedReference, NamedReference)] =
transform match {
case NamedTransform("bucket", Seq(
Lit(value: Int, IntegerType),
Ref(partCols: Seq[String]),
Ref(sortCols: Seq[String]))) =>
Some((value, FieldReference(partCols), FieldReference(sortCols)))
case NamedTransform("bucket", Seq(
Lit(value: Int, IntegerType),
Ref(partCols: Seq[String]))) =>
Some((value, FieldReference(partCols), FieldReference(Seq.empty[String])))
case _ =>
None
private[sql] final case class SortedBucketTransform(
numBuckets: Literal[Int],
columns: Seq[NamedReference],
sortedColumns: Seq[NamedReference] = Seq.empty[NamedReference]) extends RewritableTransform {

override val name: String = "sorted_bucket"

override def references: Array[NamedReference] = {
arguments.collect { case named: NamedReference => named }
}

override def arguments: Array[Expression] = (columns.toArray :+ numBuckets) ++ sortedColumns

override def toString: String = s"$name(${arguments.map(_.describe).mkString(", ")})"

override def withReferences(newReferences: Seq[NamedReference]): Transform = {
this.copy(columns = newReferences.take(columns.length),
sortedColumns = newReferences.drop(columns.length))
}
}

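To make the two-shape extractor above easier to follow, here is a small, self-contained sketch of the argument layouts and of how the `sorted_bucket` case splits them around the bucket-count literal. The `Expr`, `Ref`, and `IntLit` types below are simplified stand-ins for Spark's `Expression` classes, not the real API:

```scala
object BucketArgsSketch {
  // Simplified stand-ins for Spark's Expression / NamedReference / Literal types.
  sealed trait Expr
  final case class Ref(fieldNames: Seq[String]) extends Expr
  final case class IntLit(value: Int) extends Expr

  // Argument layouts used by the transforms in this diff:
  //   bucket        -> numBuckets +: columns
  //   sorted_bucket -> (columns :+ numBuckets) ++ sortedColumns
  def splitSortedBucket(args: Seq[Expr]): (Int, Seq[Ref], Seq[Ref]) = {
    val pos = args.indexWhere(_.isInstanceOf[IntLit]) // position of the bucket-count literal
    val numBuckets = args(pos) match {
      case IntLit(n) => n
      case other     => sys.error(s"expected an integer literal, got $other")
    }
    val cols     = args.take(pos).collect { case r: Ref => r }     // before the literal: bucket columns
    val sortCols = args.drop(pos + 1).collect { case r: Ref => r } // after the literal: sort columns
    (numBuckets, cols, sortCols)
  }

  def main(args: Array[String]): Unit = {
    // sorted_bucket(a, b, 4, c, d)
    val (n, cols, sortCols) = splitSortedBucket(
      Seq(Ref(Seq("a")), Ref(Seq("b")), IntLit(4), Ref(Seq("c")), Ref(Seq("d"))))
    assert(n == 4)
    assert(cols.map(_.fieldNames.head) == Seq("a", "b"))
    assert(sortCols.map(_.fieldNames.head) == Seq("c", "d"))
  }
}
```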
@@ -80,6 +80,7 @@ class InMemoryTable(
case _: DaysTransform =>
case _: HoursTransform =>
case _: BucketTransform =>
case _: SortedBucketTransform =>
case t if !allowUnsupportedTransforms =>
throw new IllegalArgumentException(s"Transform $t is not a supported transform")
}
@@ -161,10 +162,15 @@ class InMemoryTable(
case (v, t) =>
throw new IllegalArgumentException(s"Match: unsupported argument(s) type - ($v, $t)")
}
case BucketTransform(numBuckets, ref, _) =>
val (value, dataType) = extractor(ref.fieldNames, cleanedSchema, row)
val valueHashCode = if (value == null) 0 else value.hashCode
((valueHashCode + 31 * dataType.hashCode()) & Integer.MAX_VALUE) % numBuckets
case BucketTransform(numBuckets, cols, _) =>
val valueTypePairs = cols.map(col => extractor(col.fieldNames, cleanedSchema, row))
var valueHashCode = 0
valueTypePairs.foreach( pair =>
if ( pair._1 != null) valueHashCode += pair._1.hashCode()
)
var dataTypeHashCode = 0
valueTypePairs.foreach(dataTypeHashCode += _._2.hashCode())
((valueHashCode + 31 * dataTypeHashCode) & Integer.MAX_VALUE) % numBuckets
}
}

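A minimal sketch of the multi-column bucket-id computation that the InMemoryTable change above performs: the hash codes of the non-null column values are summed, the hash codes of the column data types are summed, and the two sums are combined, sign-masked, and reduced modulo the bucket count. The function below is illustrative, not Spark code:

```scala
// valueTypePairs: one (value, dataType) pair per bucket column, as produced by the extractor.
def bucketId(valueTypePairs: Seq[(Any, Any)], numBuckets: Int): Int = {
  val valueHash = valueTypePairs.collect { case (v, _) if v != null => v.hashCode() }.sum
  val typeHash  = valueTypePairs.map { case (_, t) => t.hashCode() }.sum
  ((valueHash + 31 * typeHash) & Integer.MAX_VALUE) % numBuckets
}

// bucketId(Seq(("a", "StringType"), (42, "IntegerType")), 4) always lands in [0, 4).
```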
@@ -19,6 +19,7 @@ package org.apache.spark.sql.connector.expressions

import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.catalyst
import org.apache.spark.sql.connector.expressions.LogicalExpressions.bucket
import org.apache.spark.sql.types.DataType

class TransformExtractorSuite extends SparkFunSuite {
@@ -139,9 +140,9 @@ class TransformExtractorSuite extends SparkFunSuite {
}

bucketTransform match {
case BucketTransform(numBuckets, FieldReference(seq), _) =>
case BucketTransform(numBuckets, cols, _) =>
assert(numBuckets === 16)
assert(seq === Seq("a", "b"))
assert(cols(0).fieldNames === Seq("a", "b"))
case _ =>
fail("Did not match BucketTransform extractor")
}
@@ -153,4 +154,61 @@
// expected
}
}

test("Sorted Bucket extractor") {
val col = Array(ref("a"), ref("b"))
val sortedCol = Array(ref("c"), ref("d"))

val sortedBucketTransform = new Transform {
override def name: String = "sorted_bucket"
override def references: Array[NamedReference] = col ++ sortedCol
override def arguments: Array[Expression] = (col :+ lit(16)) ++ sortedCol
override def describe: String = s"bucket(16, ${col(0).describe}, ${col(1).describe} " +
s"${sortedCol(0).describe} ${sortedCol(1).describe})"
}

sortedBucketTransform match {
case BucketTransform(numBuckets, cols, sortCols) =>
assert(numBuckets === 16)
assert(cols.flatMap(c => c.fieldNames()) === Seq("a", "b"))
assert(sortCols.flatMap(c => c.fieldNames()) === Seq("c", "d"))
case _ =>
fail("Did not match BucketTransform extractor")
}
}

test("test bucket") {
val col = Array(ref("a"), ref("b"))
val sortedCol = Array(ref("c"), ref("d"))

val bucketTransform = bucket(16, col)
val reference1 = bucketTransform.references
assert(reference1.length == 2)
assert(reference1(0).fieldNames() === Seq("a"))
assert(reference1(1).fieldNames() === Seq("b"))
val arguments1 = bucketTransform.arguments
assert(arguments1.length == 3)
assert(arguments1(0).asInstanceOf[LiteralValue[Integer]].value === 16)
assert(arguments1(1).asInstanceOf[NamedReference].fieldNames() === Seq("a"))
assert(arguments1(2).asInstanceOf[NamedReference].fieldNames() === Seq("b"))
val copied1 = bucketTransform.withReferences(reference1)
assert(copied1.equals(bucketTransform))

val sortedBucketTransform = bucket(16, col, sortedCol)
val reference2 = sortedBucketTransform.references
assert(reference2.length == 4)
assert(reference2(0).fieldNames() === Seq("a"))
assert(reference2(1).fieldNames() === Seq("b"))
assert(reference2(2).fieldNames() === Seq("c"))
assert(reference2(3).fieldNames() === Seq("d"))
val arguments2 = sortedBucketTransform.arguments
assert(arguments2.length == 5)
assert(arguments2(0).asInstanceOf[NamedReference].fieldNames() === Seq("a"))
assert(arguments2(1).asInstanceOf[NamedReference].fieldNames() === Seq("b"))
assert(arguments2(2).asInstanceOf[LiteralValue[Integer]].value === 16)
assert(arguments2(3).asInstanceOf[NamedReference].fieldNames() === Seq("c"))
assert(arguments2(4).asInstanceOf[NamedReference].fieldNames() === Seq("d"))
val copied2 = sortedBucketTransform.withReferences(reference2)
assert(copied2.equals(sortedBucketTransform))
}
}
@@ -318,8 +318,13 @@ private[sql] object V2SessionCatalog {
case IdentityTransform(FieldReference(Seq(col))) =>
identityCols += col

case BucketTransform(numBuckets, FieldReference(Seq(col)), FieldReference(Seq(sortCol))) =>
bucketSpec = Some(BucketSpec(numBuckets, col :: Nil, sortCol :: Nil))
case BucketTransform(numBuckets, col, sortCol) =>
if (sortCol.isEmpty) {
bucketSpec = Some(BucketSpec(numBuckets, col.map(_.fieldNames.mkString(".")), Nil))
} else {
bucketSpec = Some(BucketSpec(numBuckets, col.map(_.fieldNames.mkString(".")),
sortCol.map(_.fieldNames.mkString("."))))
}

case transform =>
throw QueryExecutionErrors.unsupportedPartitionTransformError(transform)
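For the V2SessionCatalog change above, the extractor now yields sequences of references, and each reference's field names are joined with dots to form catalog column names; an empty sort-column sequence maps to an empty sort list. A small illustrative sketch, where `SimpleBucketSpec` is a stand-in for Spark's `BucketSpec`:

```scala
// Stand-in for org.apache.spark.sql.catalyst.catalog.BucketSpec; not the real class.
final case class SimpleBucketSpec(
    numBuckets: Int,
    bucketColumnNames: Seq[String],
    sortColumnNames: Seq[String])

def toBucketSpec(numBuckets: Int, cols: Seq[Seq[String]], sortCols: Seq[Seq[String]]): SimpleBucketSpec =
  SimpleBucketSpec(
    numBuckets,
    cols.map(_.mkString(".")),      // nested field names become dotted names, e.g. Seq("a", "b") -> "a.b"
    sortCols.map(_.mkString(".")))  // empty for a plain bucket() transform

// toBucketSpec(4, Seq(Seq("c"), Seq("d")), Nil) == SimpleBucketSpec(4, Seq("c", "d"), Nil)
```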
@@ -410,9 +410,12 @@ class DataSourceV2SQLSuite
test("SPARK-36850: CreateTableAsSelect partitions can be specified using " +
"PARTITIONED BY and/or CLUSTERED BY") {
val identifier = "testcat.table_name"
val df = spark.createDataFrame(Seq((1L, "a", "a1", "a2", "a3"), (2L, "b", "b1", "b2", "b3"),
(3L, "c", "c1", "c2", "c3"))).toDF("id", "data1", "data2", "data3", "data4")
df.createOrReplaceTempView("source_table")
withTable(identifier) {
spark.sql(s"CREATE TABLE $identifier USING foo PARTITIONED BY (id) " +
s"CLUSTERED BY (data) INTO 4 BUCKETS AS SELECT * FROM source")
s"CLUSTERED BY (data1, data2, data3, data4) INTO 4 BUCKETS AS SELECT * FROM source_table")
val describe = spark.sql(s"DESCRIBE $identifier")
val part1 = describe
.filter("col_name = 'Part 0'")
@@ -421,18 +424,22 @@
val part2 = describe
.filter("col_name = 'Part 1'")
.select("data_type").head.getString(0)
assert(part2 === "bucket(4, data)")
assert(part2 === "bucket(4, data1, data2, data3, data4)")
}
}

test("SPARK-36850: ReplaceTableAsSelect partitions can be specified using " +
"PARTITIONED BY and/or CLUSTERED BY") {
val identifier = "testcat.table_name"
val df = spark.createDataFrame(Seq((1L, "a", "a1", "a2", "a3"), (2L, "b", "b1", "b2", "b3"),
(3L, "c", "c1", "c2", "c3"))).toDF("id", "data1", "data2", "data3", "data4")
df.createOrReplaceTempView("source_table")
withTable(identifier) {
spark.sql(s"CREATE TABLE $identifier USING foo " +
"AS SELECT id FROM source")
spark.sql(s"REPLACE TABLE $identifier USING foo PARTITIONED BY (id) " +
s"CLUSTERED BY (data) INTO 4 BUCKETS AS SELECT * FROM source")
s"CLUSTERED BY (data1, data2) SORTED by (data3, data4) INTO 4 BUCKETS " +
s"AS SELECT * FROM source_table")
val describe = spark.sql(s"DESCRIBE $identifier")
val part1 = describe
.filter("col_name = 'Part 0'")
@@ -441,7 +448,7 @@
val part2 = describe
.filter("col_name = 'Part 1'")
.select("data_type").head.getString(0)
assert(part2 === "bucket(4, data)")
assert(part2 === "sorted_bucket(data1, data2, 4, data3, data4)")
}
}

@@ -1566,18 +1573,21 @@ class DataSourceV2SQLSuite
test("create table using - with sorted bucket") {
val identifier = "testcat.table_name"
withTable(identifier) {
sql(s"CREATE TABLE $identifier (a int, b string, c int) USING $v2Source PARTITIONED BY (c)" +
s" CLUSTERED BY (b) SORTED by (a) INTO 4 BUCKETS")
Contributor:
why changing this test?

Contributor Author:
Just want to make sure multiple columns/sortedColumns work ok.

val table = getTableMetadata(identifier)
sql(s"CREATE TABLE $identifier (a int, b string, c int, d int, e int, f int) USING" +
s" $v2Source PARTITIONED BY (a, b) CLUSTERED BY (c, d) SORTED by (e, f) INTO 4 BUCKETS")
val describe = spark.sql(s"DESCRIBE $identifier")
val part1 = describe
.filter("col_name = 'Part 0'")
.select("data_type").head.getString(0)
assert(part1 === "c")
assert(part1 === "a")
val part2 = describe
.filter("col_name = 'Part 1'")
.select("data_type").head.getString(0)
assert(part2 === "bucket(4, b, a)")
assert(part2 === "b")
val part3 = describe
.filter("col_name = 'Part 2'")
.select("data_type").head.getString(0)
assert(part3 === "sorted_bucket(c, d, 4, e, f)")
}
}
