Changes from 7 commits
@@ -21,7 +21,7 @@ import org.apache.spark.sql.catalyst.{FunctionIdentifier, TableIdentifier}
import org.apache.spark.sql.catalyst.catalog.BucketSpec
import org.apache.spark.sql.catalyst.parser.CatalystSqlParser
import org.apache.spark.sql.catalyst.util.quoteIfNeeded
import org.apache.spark.sql.connector.expressions.{BucketTransform, IdentityTransform, LogicalExpressions, Transform}
import org.apache.spark.sql.connector.expressions.{IdentityTransform, LogicalExpressions, Transform}
import org.apache.spark.sql.errors.QueryCompilationErrors

/**
Expand All @@ -37,7 +37,7 @@ private[sql] object CatalogV2Implicits {
}

implicit class BucketSpecHelper(spec: BucketSpec) {
def asTransform: BucketTransform = {
def asTransform: Transform = {
val references = spec.bucketColumnNames.map(col => reference(Seq(col)))
if (spec.sortColumnNames.nonEmpty) {
val sortedCol = spec.sortColumnNames.map(col => reference(Seq(col)))
@@ -48,8 +48,8 @@ private[sql] object LogicalExpressions {
def bucket(
numBuckets: Int,
references: Array[NamedReference],
sortedCols: Array[NamedReference]): BucketTransform =
BucketTransform(literal(numBuckets, IntegerType), references, sortedCols)
sortedCols: Array[NamedReference]): SortedBucketTransform =
SortedBucketTransform(literal(numBuckets, IntegerType), references, sortedCols)

def identity(reference: NamedReference): IdentityTransform = IdentityTransform(reference)

@@ -101,8 +101,7 @@ private[sql] abstract class SingleColumnTransform(ref: NamedReference) extends R

private[sql] final case class BucketTransform(
numBuckets: Literal[Int],
columns: Seq[NamedReference],
sortedColumns: Seq[NamedReference] = Seq.empty[NamedReference]) extends RewritableTransform {
columns: Seq[NamedReference]) extends RewritableTransform {

override val name: String = "bucket"

@@ -112,26 +111,21 @@ private[sql] final case class BucketTransform(

override def arguments: Array[Expression] = numBuckets +: columns.toArray

override def toString: String =
if (sortedColumns.nonEmpty) {
s"bucket(${arguments.map(_.describe).mkString(", ")}," +
s" ${sortedColumns.map(_.describe).mkString(", ")})"
} else {
s"bucket(${arguments.map(_.describe).mkString(", ")})"
}
override def describe: String = s"bucket(${arguments.map(_.describe).mkString(", ")})"

override def toString: String = describe

override def withReferences(newReferences: Seq[NamedReference]): Transform = {
this.copy(columns = newReferences)
}
}

private[sql] object BucketTransform {
def unapply(expr: Expression): Option[(Int, FieldReference, FieldReference)] =
expr match {
def unapply(expr: Expression): Option[(Int, FieldReference, FieldReference)] = expr match {
Contributor: where do we use this unapply?

Contributor (author): This was introduced in #30706 but doesn't seem to be used. I will remove for now.

Member: Could you add some comments on unapply (if it is really used) about what it returns?

Member: BTW, why does def unapply(expr: Expression) address only BucketTransform, while def unapply(transform: Transform) addresses both sorted_bucket and bucket?
case transform: Transform =>
transform match {
case BucketTransform(n, FieldReference(parts), FieldReference(sortCols)) =>
Some((n, FieldReference(parts), FieldReference(sortCols)))
case BucketTransform(n, FieldReference(parts), _) =>
Some((n, FieldReference(parts), FieldReference(Seq.empty[String])))
case _ =>
None
}
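
A sketch of the kind of comment being requested (wording is illustrative only, not from the PR), describing what the Transform-based extractor returns in the code below:

/**
 * Matches a bucket or sorted_bucket transform and returns
 * (number of buckets, a FieldReference whose parts are the bucket column names,
 * a FieldReference whose parts are the sort column names).
 * The last element is FieldReference(Seq.empty) for an unsorted bucket.
 */
def unapply(transform: Transform): Option[(Int, NamedReference, NamedReference)]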
@@ -141,20 +135,47 @@ private[sql] object BucketTransform {

def unapply(transform: Transform): Option[(Int, NamedReference, NamedReference)] =
transform match {
case NamedTransform("bucket", Seq(
Lit(value: Int, IntegerType),
Ref(partCols: Seq[String]),
Ref(sortCols: Seq[String]))) =>
Some((value, FieldReference(partCols), FieldReference(sortCols)))
case NamedTransform("bucket", Seq(
Lit(value: Int, IntegerType),
Ref(partCols: Seq[String]))) =>
Some((value, FieldReference(partCols), FieldReference(Seq.empty[String])))
case NamedTransform("sorted_bucket", arguments) =>
var index: Int = -1
var posOfLit: Int = -1
var numOfBucket: Int = -1
arguments.foreach {
Contributor: nit: we can do arguments.zipWithIndex.foreach, so that it's much easier to get posOfLit.

case Lit(value: Int, IntegerType) =>
numOfBucket = value
index = index + 1
posOfLit = index
case _ => index = index + 1
}
Some(numOfBucket, FieldReference(arguments.take(posOfLit).map(_.describe)),
Contributor: we know that the arguments of bucket/sorted_bucket are all NamedReference; how about arguments.take(posOfLit).map(_.asInstanceOf[NamedReference])?

FieldReference(arguments.drop(posOfLit + 1).map(_.describe)))
case NamedTransform("bucket", Seq(Lit(value: Int, IntegerType), Ref(seq: Seq[String]))) =>
Contributor: this doesn't seem to be right. It only matches bucket with a single bucket column.

Contributor (author): Seems somehow only a single column is supported in BucketTransform. Will fix this.

Some(value, FieldReference(seq), FieldReference(Seq.empty[String]))
case _ =>
None
}
}
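
For reference, a sketch of how the two suggestions above (zipWithIndex, and casting the non-literal arguments to NamedReference) might look in the sorted_bucket branch. It reuses the file's existing Lit, NamedTransform, and FieldReference helpers; rebuilding the references via fieldNames is this sketch's choice, and none of this is the code actually merged in the PR:

case NamedTransform("sorted_bucket", arguments) =>
  var posOfLit: Int = -1
  var numOfBucket: Int = -1
  // zipWithIndex yields the literal's position directly, so no manual index counter is needed.
  arguments.zipWithIndex.foreach {
    case (Lit(value: Int, IntegerType), i) =>
      numOfBucket = value
      posOfLit = i
    case _ =>
  }
  // Every non-literal argument is a NamedReference, so cast instead of
  // round-tripping through describe strings.
  val bucketCols = arguments.take(posOfLit).map(_.asInstanceOf[NamedReference])
  val sortCols = arguments.drop(posOfLit + 1).map(_.asInstanceOf[NamedReference])
  Some((numOfBucket,
    FieldReference(bucketCols.flatMap(_.fieldNames.toSeq)),
    FieldReference(sortCols.flatMap(_.fieldNames.toSeq))))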

private[sql] final case class SortedBucketTransform(
numBuckets: Literal[Int],
columns: Seq[NamedReference],
sortedColumns: Seq[NamedReference] = Seq.empty[NamedReference]) extends RewritableTransform {

override val name: String = "sorted_bucket"

override def references: Array[NamedReference] = {
arguments.collect { case named: NamedReference => named }
}

override def arguments: Array[Expression] = (columns.toArray :+ numBuckets) ++ sortedColumns

override def toString: String = s"$name(${arguments.map(_.describe).mkString(", ")})"

override def withReferences(newReferences: Seq[NamedReference]): Transform = {
this.copy(columns = newReferences.take(columns.length),
sortedColumns = newReferences.drop(columns.length))
}
}

private[sql] final case class ApplyTransform(
name: String,
args: Seq[Expression]) extends Transform {
@@ -80,6 +80,7 @@ class InMemoryTable(
case _: DaysTransform =>
case _: HoursTransform =>
case _: BucketTransform =>
case _: SortedBucketTransform =>
case t if !allowUnsupportedTransforms =>
throw new IllegalArgumentException(s"Transform $t is not a supported transform")
}
@@ -19,6 +19,7 @@ package org.apache.spark.sql.connector.expressions

import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.catalyst
import org.apache.spark.sql.connector.expressions.LogicalExpressions.bucket
import org.apache.spark.sql.types.DataType

class TransformExtractorSuite extends SparkFunSuite {
@@ -153,4 +154,61 @@ class TransformExtractorSuite extends SparkFunSuite {
// expected
}
}

test("Sorted Bucket extractor") {
val col = Array(ref("a"), ref("b"))
val sortedCol = Array(ref("c"), ref("d"))

val sortedBucketTransform = new Transform {
override def name: String = "sorted_bucket"
override def references: Array[NamedReference] = col ++ sortedCol
override def arguments: Array[Expression] = (col :+ lit(16)) ++ sortedCol
override def describe: String = s"bucket(16, ${col(0).describe}, ${col(1).describe} " +
s"${sortedCol(0).describe} ${sortedCol(1).describe})"
}

sortedBucketTransform match {
case BucketTransform(numBuckets, FieldReference(seq), FieldReference(sorted)) =>
assert(numBuckets === 16)
assert(seq === Seq("a", "b"))
assert(sorted === Seq("c", "d"))
case _ =>
fail("Did not match BucketTransform extractor")
}
}

test("test bucket") {
val col = Array(ref("a"), ref("b"))
val sortedCol = Array(ref("c"), ref("d"))

val bucketTransform = bucket(16, col)
val reference1 = bucketTransform.references
assert(reference1.length == 2)
assert(reference1(0).fieldNames() === Seq("a"))
assert(reference1(1).fieldNames() === Seq("b"))
val arguments1 = bucketTransform.arguments
assert(arguments1.length == 3)
assert(arguments1(0).asInstanceOf[LiteralValue[Integer]].value === 16)
assert(arguments1(1).asInstanceOf[NamedReference].fieldNames() === Seq("a"))
assert(arguments1(2).asInstanceOf[NamedReference].fieldNames() === Seq("b"))
val copied1 = bucketTransform.withReferences(reference1)
assert(copied1.equals(bucketTransform))

val sortedBucketTransform = bucket(16, col, sortedCol)
val reference2 = sortedBucketTransform.references
assert(reference2.length == 4)
assert(reference2(0).fieldNames() === Seq("a"))
assert(reference2(1).fieldNames() === Seq("b"))
assert(reference2(2).fieldNames() === Seq("c"))
assert(reference2(3).fieldNames() === Seq("d"))
val arguments2 = sortedBucketTransform.arguments
assert(arguments2.length == 5)
assert(arguments2(0).asInstanceOf[NamedReference].fieldNames() === Seq("a"))
assert(arguments2(1).asInstanceOf[NamedReference].fieldNames() === Seq("b"))
assert(arguments2(2).asInstanceOf[LiteralValue[Integer]].value === 16)
assert(arguments2(3).asInstanceOf[NamedReference].fieldNames() === Seq("c"))
assert(arguments2(4).asInstanceOf[NamedReference].fieldNames() === Seq("d"))
val copied2 = sortedBucketTransform.withReferences(reference2)
assert(copied2.equals(sortedBucketTransform))
}
}
@@ -319,7 +319,11 @@ private[sql] object V2SessionCatalog {
identityCols += col

case BucketTransform(numBuckets, FieldReference(Seq(col)), FieldReference(Seq(sortCol))) =>
bucketSpec = Some(BucketSpec(numBuckets, col :: Nil, sortCol :: Nil))
if (sortCol.isEmpty) {
bucketSpec = Some(BucketSpec(numBuckets, col :: Nil, Nil))
} else {
bucketSpec = Some(BucketSpec(numBuckets, col :: Nil, sortCol :: Nil))
}

case transform =>
throw QueryExecutionErrors.unsupportedPartitionTransformError(transform)
@@ -1566,18 +1566,21 @@ class DataSourceV2SQLSuite
test("create table using - with sorted bucket") {
val identifier = "testcat.table_name"
withTable(identifier) {
sql(s"CREATE TABLE $identifier (a int, b string, c int) USING $v2Source PARTITIONED BY (c)" +
s" CLUSTERED BY (b) SORTED by (a) INTO 4 BUCKETS")
Contributor: why changing this test?

Contributor (author): Just want to make sure multiple columns/sortedColumns work ok.

val table = getTableMetadata(identifier)
sql(s"CREATE TABLE $identifier (a int, b string, c int, d int, e int, f int) USING" +
s" $v2Source PARTITIONED BY (a, b) CLUSTERED BY (c, d) SORTED by (e, f) INTO 4 BUCKETS")
val describe = spark.sql(s"DESCRIBE $identifier")
val part1 = describe
.filter("col_name = 'Part 0'")
.select("data_type").head.getString(0)
assert(part1 === "c")
assert(part1 === "a")
val part2 = describe
.filter("col_name = 'Part 1'")
.select("data_type").head.getString(0)
assert(part2 === "bucket(4, b, a)")
assert(part2 === "b")
val part3 = describe
.filter("col_name = 'Part 2'")
.select("data_type").head.getString(0)
assert(part3 === "sorted_bucket(c, d, 4, e, f)")
}
}
