@@ -104,24 +104,29 @@ private[sql] final case class BucketTransform(
     columns: Seq[NamedReference],
     sortedColumns: Seq[NamedReference] = Seq.empty[NamedReference]) extends RewritableTransform {

-  override val name: String = "bucket"
+  override val name: String = if (sortedColumns.nonEmpty) "sortedBucket" else "bucket"
[Review comment] cloud-fan (Contributor), Jan 5, 2022:
Can we create a new class SortedBucketTransform to be clearer?

[Author reply]
Added a new class SortedBucketTransform. Thanks!
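For context, a minimal sketch of the shape such a dedicated class could take (illustrative only; Expression, FieldRef, and IntLit below are simplified stand-ins for the Spark connector interfaces, and the class this PR actually adds may differ):

```scala
// Simplified stand-ins, just enough to show the shape of the class.
trait Expression { def describe: String }
case class FieldRef(name: String) extends Expression {
  override def describe: String = name
}
case class IntLit(value: Int) extends Expression {
  override def describe: String = value.toString
}

// A dedicated transform, rather than a BucketTransform carrying an
// optional sortedColumns parameter.
case class SortedBucketTransformSketch(
    numBuckets: Int,
    columns: Seq[FieldRef],
    sortedColumns: Seq[FieldRef]) extends Expression {
  val name: String = "sortedBucket"
  // Same argument layout as the diff below: columns, bucket count, sort columns.
  def arguments: Seq[Expression] = (columns :+ IntLit(numBuckets)) ++ sortedColumns
  override def describe: String = s"$name(${arguments.map(_.describe).mkString(", ")})"
}

// SortedBucketTransformSketch(4, Seq(FieldRef("c"), FieldRef("d")),
//   Seq(FieldRef("e"), FieldRef("f"))).describe == "sortedBucket(c, d, 4, e, f)"
```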


   override def references: Array[NamedReference] = {
     arguments.collect { case named: NamedReference => named }
   }

-  override def arguments: Array[Expression] = numBuckets +: columns.toArray

-  override def toString: String =
+  override def arguments: Array[Expression] = {
     if (sortedColumns.nonEmpty) {
-      s"bucket(${arguments.map(_.describe).mkString(", ")}," +
-        s" ${sortedColumns.map(_.describe).mkString(", ")})"
+      (columns.toArray :+ numBuckets) ++ sortedColumns
     } else {
-      s"bucket(${arguments.map(_.describe).mkString(", ")})"
+      numBuckets +: columns.toArray
[Review comment] Member:
Shall we keep a consistent order of columns and numBuckets for the two cases in arguments?

[Author reply]
If there are sortedColumns, we need numBuckets in between columns and sortedColumns, because we need a way to figure out which elements in the array are for columns and which are for sortedColumns.
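To make that layout concrete, a small standalone sketch of the round trip (illustrative only; Ref is a stand-in for NamedReference, and a plain Int stands in for the bucket-count literal):

```scala
// The Int literal acts as a delimiter inside the flattened arguments array,
// so a decoder can recover both column groups purely by position.
case class Ref(name: String) // stand-in for NamedReference

val columns = Seq(Ref("c"), Ref("d"))
val sortedColumns = Seq(Ref("e"), Ref("f"))
val numBuckets = 4

// Encode: bucket columns first, then the bucket count, then the sort columns.
val arguments: Seq[Any] = (columns :+ numBuckets) ++ sortedColumns

// Decode: everything before the Int is a bucket column, everything after it
// is a sort column.
val posOfLit = arguments.indexWhere(_.isInstanceOf[Int])
val (bucketCols, rest) = arguments.splitAt(posOfLit)
assert(bucketCols == columns)
assert(rest.head == numBuckets)
assert(rest.drop(1) == sortedColumns)
```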

     }
+  }

+  override def toString: String = s"$name(${arguments.map(_.describe).mkString(", ")})"

   override def withReferences(newReferences: Seq[NamedReference]): Transform = {
-    this.copy(columns = newReferences)
+    if (sortedColumns.isEmpty) {
+      this.copy(columns = newReferences)
+    } else {
+      val splits = newReferences.grouped(columns.length).toList
+      this.copy(columns = splits(0), sortedColumns = splits(1))
[Review comment] Contributor:
Is it: columns = newReferences.take(columns.length), sortedColumns = newReferences.drop(columns.length)?

[Author reply]
Changed. Thanks!
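As a standalone sketch (plain strings standing in for NamedReference), the suggested take/drop split also behaves better than grouped(columns.length), which silently loses references whenever there are more sort columns than bucket columns:

```scala
// Split newReferences at the bucket-column count; the remainder is the sort
// columns, however many there are.
def splitRefs(newReferences: Seq[String], numBucketCols: Int): (Seq[String], Seq[String]) =
  (newReferences.take(numBucketCols), newReferences.drop(numBucketCols))

// Two bucket columns followed by three sort columns:
assert(splitRefs(Seq("a", "b", "c", "d", "e"), 2) == (Seq("a", "b"), Seq("c", "d", "e")))
// grouped(2) would yield List(Seq("a", "b"), Seq("c", "d"), Seq("e")), so
// taking splits(0) and splits(1) would drop "e" entirely.
```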

+    }
   }
 }

@@ -140,15 +145,22 @@ private[sql] object BucketTransform {
   }

   def unapply(transform: Transform): Option[(Int, NamedReference, NamedReference)] =
     transform match {
-      case NamedTransform("bucket", Seq(
-          Lit(value: Int, IntegerType),
-          Ref(partCols: Seq[String]),
-          Ref(sortCols: Seq[String]))) =>
-        Some((value, FieldReference(partCols), FieldReference(sortCols)))
-      case NamedTransform("bucket", Seq(
-          Lit(value: Int, IntegerType),
-          Ref(partCols: Seq[String]))) =>
+      case NamedTransform("sortedBucket", s) =>
+        var index: Int = -1
+        var posOfLit: Int = -1
+        var numOfBucket: Int = -1
+        s.foreach {
+          case Lit(value: Int, IntegerType) =>
+            numOfBucket = value
+            index = index + 1
+            posOfLit = index
+          case _ => index = index + 1
+        }
+        val splits = s.splitAt(posOfLit)
+        Some(numOfBucket, FieldReference(
+          splits._1.map(_.describe)), FieldReference(splits._2.drop(1).map(_.describe)))
+      case NamedTransform("bucket", Seq(Lit(value: Int, IntegerType), Ref(partCols: Seq[String]))) =>
         Some((value, FieldReference(partCols), FieldReference(Seq.empty[String])))
       case _ =>
         None
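The new sortedBucket branch above scans for the IntegerType literal with mutable counters. A side-effect-free sketch of the same logic (illustrative; IntLit and Col are simplified stand-ins for Lit(value, IntegerType) and a NamedReference's describe output, not the file's actual extractors):

```scala
sealed trait Arg
case class IntLit(value: Int) extends Arg
case class Col(name: String) extends Arg

// Locate the bucket-count literal, then split the argument list around it.
def splitSortedBucket(args: Seq[Arg]): Option[(Int, Seq[String], Seq[String])] = {
  val pos = args.indexWhere(_.isInstanceOf[IntLit])
  if (pos < 0) None // malformed: no bucket-count literal present
  else {
    val numBuckets = args(pos).asInstanceOf[IntLit].value
    val (bucketCols, rest) = args.splitAt(pos)
    Some((numBuckets,
      bucketCols.collect { case Col(n) => n },
      rest.drop(1).collect { case Col(n) => n })) // drop the literal itself
  }
}

// Layout produced by the sortedBucket arguments: c, d, 4, e, f
assert(splitSortedBucket(Seq(Col("c"), Col("d"), IntLit(4), Col("e"), Col("f")))
  .contains((4, Seq("c", "d"), Seq("e", "f"))))
```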
@@ -19,6 +19,7 @@ package org.apache.spark.sql.connector.expressions

 import org.apache.spark.SparkFunSuite
 import org.apache.spark.sql.catalyst
+import org.apache.spark.sql.connector.expressions.LogicalExpressions.bucket
 import org.apache.spark.sql.types.DataType

 class TransformExtractorSuite extends SparkFunSuite {
@@ -153,4 +154,61 @@ class TransformExtractorSuite extends SparkFunSuite {
       // expected
     }
   }
+
+  test("Sorted Bucket extractor") {
+    val col = Array(ref("a"), ref("b"))
+    val sortedCol = Array(ref("c"), ref("d"))
+
+    val sortedBucketTransform = new Transform {
+      override def name: String = "sortedBucket"
+      override def references: Array[NamedReference] = col ++ sortedCol
+      override def arguments: Array[Expression] = (col :+ lit(16)) ++ sortedCol
+      override def describe: String = s"bucket(16, ${col(0).describe}, ${col(1).describe} " +
+        s"${sortedCol(0).describe} ${sortedCol(1).describe})"
+    }
+
+    sortedBucketTransform match {
+      case BucketTransform(numBuckets, FieldReference(seq), FieldReference(sorted)) =>
+        assert(numBuckets === 16)
+        assert(seq === Seq("a", "b"))
+        assert(sorted === Seq("c", "d"))
+      case _ =>
+        fail("Did not match BucketTransform extractor")
+    }
+  }
+
+  test("test bucket") {
+    val col = Array(ref("a"), ref("b"))
+    val sortedCol = Array(ref("c"), ref("d"))
+
+    val bucketTransform = bucket(16, col)
+    val reference1 = bucketTransform.references
+    assert(reference1.length == 2)
+    assert(reference1(0).fieldNames() === Seq("a"))
+    assert(reference1(1).fieldNames() === Seq("b"))
+    val arguments1 = bucketTransform.arguments
+    assert(arguments1.length == 3)
+    assert(arguments1(0).asInstanceOf[LiteralValue[Integer]].value === 16)
+    assert(arguments1(1).asInstanceOf[NamedReference].fieldNames() === Seq("a"))
+    assert(arguments1(2).asInstanceOf[NamedReference].fieldNames() === Seq("b"))
+    val copied1 = bucketTransform.withReferences(reference1)
+    assert(copied1.equals(bucketTransform))
+
+    val sortedBucketTransform = bucket(16, col, sortedCol)
+    val reference2 = sortedBucketTransform.references
+    assert(reference2.length == 4)
+    assert(reference2(0).fieldNames() === Seq("a"))
+    assert(reference2(1).fieldNames() === Seq("b"))
+    assert(reference2(2).fieldNames() === Seq("c"))
+    assert(reference2(3).fieldNames() === Seq("d"))
+    val arguments2 = sortedBucketTransform.arguments
+    assert(arguments2.length == 5)
+    assert(arguments2(0).asInstanceOf[NamedReference].fieldNames() === Seq("a"))
+    assert(arguments2(1).asInstanceOf[NamedReference].fieldNames() === Seq("b"))
+    assert(arguments2(2).asInstanceOf[LiteralValue[Integer]].value === 16)
+    assert(arguments2(3).asInstanceOf[NamedReference].fieldNames() === Seq("c"))
+    assert(arguments2(4).asInstanceOf[NamedReference].fieldNames() === Seq("d"))
+    val copied2 = sortedBucketTransform.withReferences(reference2)
+    assert(copied2.equals(sortedBucketTransform))
+  }
 }
@@ -1566,18 +1566,22 @@ class DataSourceV2SQLSuite
   test("create table using - with sorted bucket") {
     val identifier = "testcat.table_name"
     withTable(identifier) {
-      sql(s"CREATE TABLE $identifier (a int, b string, c int) USING $v2Source PARTITIONED BY (c)" +
-        s" CLUSTERED BY (b) SORTED by (a) INTO 4 BUCKETS")
[Review comment] Contributor:
Why change this test?

[Author reply]
Just want to make sure multiple columns/sortedColumns work OK.
-      val table = getTableMetadata(identifier)
+      sql(s"CREATE TABLE $identifier (a int, b string, c int, d int, e int, f int) USING" +
+        s" $v2Source PARTITIONED BY (a, b) CLUSTERED BY (c, d) SORTED by (e, f) INTO 4 BUCKETS")
       val describe = spark.sql(s"DESCRIBE $identifier")
+      describe.show(false)
       val part1 = describe
         .filter("col_name = 'Part 0'")
         .select("data_type").head.getString(0)
-      assert(part1 === "c")
+      assert(part1 === "a")
       val part2 = describe
         .filter("col_name = 'Part 1'")
         .select("data_type").head.getString(0)
-      assert(part2 === "bucket(4, b, a)")
+      assert(part2 === "b")
+      val part3 = describe
+        .filter("col_name = 'Part 2'")
+        .select("data_type").head.getString(0)
+      assert(part3 === "sortedBucket(c, d, 4, e, f)")
     }
   }
