diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/connector/catalog/CatalogV2Implicits.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/connector/catalog/CatalogV2Implicits.scala
index 185a1a2644e2..b5e38659724e 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/connector/catalog/CatalogV2Implicits.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/connector/catalog/CatalogV2Implicits.scala
@@ -21,7 +21,7 @@ import org.apache.spark.sql.catalyst.{FunctionIdentifier, TableIdentifier}
 import org.apache.spark.sql.catalyst.catalog.BucketSpec
 import org.apache.spark.sql.catalyst.parser.CatalystSqlParser
 import org.apache.spark.sql.catalyst.util.quoteIfNeeded
-import org.apache.spark.sql.connector.expressions.{BucketTransform, IdentityTransform, LogicalExpressions, Transform}
+import org.apache.spark.sql.connector.expressions.{IdentityTransform, LogicalExpressions, Transform}
 import org.apache.spark.sql.errors.QueryCompilationErrors

 /**
@@ -37,7 +37,7 @@ private[sql] object CatalogV2Implicits {
   }

   implicit class BucketSpecHelper(spec: BucketSpec) {
-    def asTransform: BucketTransform = {
+    def asTransform: Transform = {
       val references = spec.bucketColumnNames.map(col => reference(Seq(col)))
       if (spec.sortColumnNames.nonEmpty) {
         val sortedCol = spec.sortColumnNames.map(col => reference(Seq(col)))
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/connector/expressions/expressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/connector/expressions/expressions.scala
index e3eab6f6730f..9402bceafbc5 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/connector/expressions/expressions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/connector/expressions/expressions.scala
@@ -17,6 +17,7 @@

 package org.apache.spark.sql.connector.expressions

+import org.apache.spark.SparkException
 import org.apache.spark.sql.catalyst
 import org.apache.spark.sql.catalyst.parser.CatalystSqlParser
 import org.apache.spark.sql.types.{DataType, IntegerType, StringType}
@@ -48,8 +49,8 @@ private[sql] object LogicalExpressions {
   def bucket(
       numBuckets: Int,
       references: Array[NamedReference],
-      sortedCols: Array[NamedReference]): BucketTransform =
-    BucketTransform(literal(numBuckets, IntegerType), references, sortedCols)
+      sortedCols: Array[NamedReference]): SortedBucketTransform =
+    SortedBucketTransform(literal(numBuckets, IntegerType), references, sortedCols)

   def identity(reference: NamedReference): IdentityTransform = IdentityTransform(reference)

@@ -101,8 +102,7 @@ private[sql] abstract class SingleColumnTransform(ref: NamedReference) extends R

 private[sql] final case class BucketTransform(
     numBuckets: Literal[Int],
-    columns: Seq[NamedReference],
-    sortedColumns: Seq[NamedReference] = Seq.empty[NamedReference]) extends RewritableTransform {
+    columns: Seq[NamedReference]) extends RewritableTransform {

   override val name: String = "bucket"

@@ -112,13 +112,9 @@ private[sql] final case class BucketTransform(

   override def arguments: Array[Expression] = numBuckets +: columns.toArray

-  override def toString: String =
-    if (sortedColumns.nonEmpty) {
-      s"bucket(${arguments.map(_.describe).mkString(", ")}," +
-        s" ${sortedColumns.map(_.describe).mkString(", ")})"
-    } else {
-      s"bucket(${arguments.map(_.describe).mkString(", ")})"
-    }
+  override def describe: String = s"bucket(${arguments.map(_.describe).mkString(", ")})"
+
+  override def toString: String = describe

   override def withReferences(newReferences: Seq[NamedReference]): Transform = {
     this.copy(columns = newReferences)
@@ -126,32 +122,52 @@ private[sql] final case class BucketTransform(
 }

 private[sql] object BucketTransform {
-  def unapply(expr: Expression): Option[(Int, FieldReference, FieldReference)] =
-    expr match {
-      case transform: Transform =>
+  def unapply(transform: Transform): Option[(Int, Seq[NamedReference], Seq[NamedReference])] =
     transform match {
-      case BucketTransform(n, FieldReference(parts), FieldReference(sortCols)) =>
-        Some((n, FieldReference(parts), FieldReference(sortCols)))
+      case NamedTransform("sorted_bucket", arguments) =>
+        var posOfLit: Int = -1
+        var numOfBucket: Int = -1
+        arguments.zipWithIndex.foreach {
+          case (Lit(value: Int, IntegerType), i) =>
+            numOfBucket = value
+            posOfLit = i
           case _ =>
-            None
         }
+        Some(numOfBucket, arguments.take(posOfLit).map(_.asInstanceOf[NamedReference]),
+          arguments.drop(posOfLit + 1).map(_.asInstanceOf[NamedReference]))
+      case NamedTransform("bucket", arguments) =>
+        var numOfBucket: Int = -1
+        arguments(0) match {
+          case Lit(value: Int, IntegerType) =>
+            numOfBucket = value
+          case _ => throw new SparkException("The first element in BucketTransform arguments " +
+            "should be an Integer Literal.")
+        }
+        Some(numOfBucket, arguments.drop(1).map(_.asInstanceOf[NamedReference]),
+          Seq.empty[FieldReference])
       case _ =>
         None
     }
+}

-  def unapply(transform: Transform): Option[(Int, NamedReference, NamedReference)] =
-    transform match {
-      case NamedTransform("bucket", Seq(
-        Lit(value: Int, IntegerType),
-        Ref(partCols: Seq[String]),
-        Ref(sortCols: Seq[String]))) =>
-        Some((value, FieldReference(partCols), FieldReference(sortCols)))
-      case NamedTransform("bucket", Seq(
-        Lit(value: Int, IntegerType),
-        Ref(partCols: Seq[String]))) =>
-        Some((value, FieldReference(partCols), FieldReference(Seq.empty[String])))
-      case _ =>
-        None
+private[sql] final case class SortedBucketTransform(
+    numBuckets: Literal[Int],
+    columns: Seq[NamedReference],
+    sortedColumns: Seq[NamedReference] = Seq.empty[NamedReference]) extends RewritableTransform {
+
+  override val name: String = "sorted_bucket"
+
+  override def references: Array[NamedReference] = {
+    arguments.collect { case named: NamedReference => named }
+  }
+
+  override def arguments: Array[Expression] = (columns.toArray :+ numBuckets) ++ sortedColumns
+
+  override def toString: String = s"$name(${arguments.map(_.describe).mkString(", ")})"
+
+  override def withReferences(newReferences: Seq[NamedReference]): Transform = {
+    this.copy(columns = newReferences.take(columns.length),
+      sortedColumns = newReferences.drop(columns.length))
   }
 }
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryTable.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryTable.scala
index fa8be1b8fa3c..8e5e920d89ab 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryTable.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryTable.scala
@@ -80,6 +80,7 @@ class InMemoryTable(
       case _: DaysTransform =>
       case _: HoursTransform =>
       case _: BucketTransform =>
+      case _: SortedBucketTransform =>
       case t if !allowUnsupportedTransforms =>
         throw new IllegalArgumentException(s"Transform $t is not a supported transform")
     }
@@ -161,10 +162,15 @@ class InMemoryTable(
         case (v, t) =>
           throw new IllegalArgumentException(s"Match: unsupported argument(s) type - ($v, $t)")
       }
-      case BucketTransform(numBuckets, ref, _) =>
-        val (value, dataType) = extractor(ref.fieldNames, cleanedSchema, row)
-        val valueHashCode = if (value == null) 0 else value.hashCode
-        ((valueHashCode + 31 * dataType.hashCode()) & Integer.MAX_VALUE) % numBuckets
+      case BucketTransform(numBuckets, cols, _) =>
+        val valueTypePairs = cols.map(col => extractor(col.fieldNames, cleanedSchema, row))
+        var valueHashCode = 0
+        valueTypePairs.foreach( pair =>
+          if ( pair._1 != null) valueHashCode += pair._1.hashCode()
+        )
+        var dataTypeHashCode = 0
+        valueTypePairs.foreach(dataTypeHashCode += _._2.hashCode())
+        ((valueHashCode + 31 * dataTypeHashCode) & Integer.MAX_VALUE) % numBuckets
     }
   }
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/expressions/TransformExtractorSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/expressions/TransformExtractorSuite.scala
index b2371ce667ff..54ab1df3fa8f 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/expressions/TransformExtractorSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/expressions/TransformExtractorSuite.scala
@@ -19,6 +19,7 @@ package org.apache.spark.sql.connector.expressions

 import org.apache.spark.SparkFunSuite
 import org.apache.spark.sql.catalyst
+import org.apache.spark.sql.connector.expressions.LogicalExpressions.bucket
 import org.apache.spark.sql.types.DataType

 class TransformExtractorSuite extends SparkFunSuite {
@@ -139,9 +140,9 @@ class TransformExtractorSuite extends SparkFunSuite {
     }

     bucketTransform match {
-      case BucketTransform(numBuckets, FieldReference(seq), _) =>
+      case BucketTransform(numBuckets, cols, _) =>
         assert(numBuckets === 16)
-        assert(seq === Seq("a", "b"))
+        assert(cols(0).fieldNames === Seq("a", "b"))
       case _ =>
         fail("Did not match BucketTransform extractor")
     }
@@ -153,4 +154,61 @@ class TransformExtractorSuite extends SparkFunSuite {
       // expected
     }
   }
+
+  test("Sorted Bucket extractor") {
+    val col = Array(ref("a"), ref("b"))
+    val sortedCol = Array(ref("c"), ref("d"))
+
+    val sortedBucketTransform = new Transform {
+      override def name: String = "sorted_bucket"
+      override def references: Array[NamedReference] = col ++ sortedCol
+      override def arguments: Array[Expression] = (col :+ lit(16)) ++ sortedCol
+      override def describe: String = s"bucket(16, ${col(0).describe}, ${col(1).describe} " +
+        s"${sortedCol(0).describe} ${sortedCol(1).describe})"
+    }
+
+    sortedBucketTransform match {
+      case BucketTransform(numBuckets, cols, sortCols) =>
+        assert(numBuckets === 16)
+        assert(cols.flatMap(c => c.fieldNames()) === Seq("a", "b"))
+        assert(sortCols.flatMap(c => c.fieldNames()) === Seq("c", "d"))
+      case _ =>
+        fail("Did not match BucketTransform extractor")
+    }
+  }
+
+  test("test bucket") {
+    val col = Array(ref("a"), ref("b"))
+    val sortedCol = Array(ref("c"), ref("d"))
+
+    val bucketTransform = bucket(16, col)
+    val reference1 = bucketTransform.references
+    assert(reference1.length == 2)
+    assert(reference1(0).fieldNames() === Seq("a"))
+    assert(reference1(1).fieldNames() === Seq("b"))
+    val arguments1 = bucketTransform.arguments
+    assert(arguments1.length == 3)
+    assert(arguments1(0).asInstanceOf[LiteralValue[Integer]].value === 16)
+    assert(arguments1(1).asInstanceOf[NamedReference].fieldNames() === Seq("a"))
+    assert(arguments1(2).asInstanceOf[NamedReference].fieldNames() === Seq("b"))
+    val copied1 = bucketTransform.withReferences(reference1)
+    assert(copied1.equals(bucketTransform))
+
+    val sortedBucketTransform = bucket(16, col, sortedCol)
+    val reference2 = sortedBucketTransform.references
+    assert(reference2.length == 4)
+    assert(reference2(0).fieldNames() === Seq("a"))
+    assert(reference2(1).fieldNames() === Seq("b"))
+    assert(reference2(2).fieldNames() === Seq("c"))
+    assert(reference2(3).fieldNames() === Seq("d"))
+    val arguments2 = sortedBucketTransform.arguments
+    assert(arguments2.length == 5)
+    assert(arguments2(0).asInstanceOf[NamedReference].fieldNames() === Seq("a"))
+    assert(arguments2(1).asInstanceOf[NamedReference].fieldNames() === Seq("b"))
+    assert(arguments2(2).asInstanceOf[LiteralValue[Integer]].value === 16)
+    assert(arguments2(3).asInstanceOf[NamedReference].fieldNames() === Seq("c"))
+    assert(arguments2(4).asInstanceOf[NamedReference].fieldNames() === Seq("d"))
+    val copied2 = sortedBucketTransform.withReferences(reference2)
+    assert(copied2.equals(sortedBucketTransform))
+  }
 }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2SessionCatalog.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2SessionCatalog.scala
index d5547c1f3c1e..f3833f53dcf6 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2SessionCatalog.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2SessionCatalog.scala
@@ -318,8 +318,13 @@
       case IdentityTransform(FieldReference(Seq(col))) =>
         identityCols += col

-      case BucketTransform(numBuckets, FieldReference(Seq(col)), FieldReference(Seq(sortCol))) =>
-        bucketSpec = Some(BucketSpec(numBuckets, col :: Nil, sortCol :: Nil))
+      case BucketTransform(numBuckets, col, sortCol) =>
+        if (sortCol.isEmpty) {
+          bucketSpec = Some(BucketSpec(numBuckets, col.map(_.fieldNames.mkString(".")), Nil))
+        } else {
+          bucketSpec = Some(BucketSpec(numBuckets, col.map(_.fieldNames.mkString(".")),
+            sortCol.map(_.fieldNames.mkString("."))))
+        }

       case transform =>
         throw QueryExecutionErrors.unsupportedPartitionTransformError(transform)
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2SQLSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2SQLSuite.scala
index 90f9f157b028..ec9360dc55c9 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2SQLSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2SQLSuite.scala
@@ -410,9 +410,12 @@ class DataSourceV2SQLSuite
   test("SPARK-36850: CreateTableAsSelect partitions can be specified using " +
     "PARTITIONED BY and/or CLUSTERED BY") {
     val identifier = "testcat.table_name"
+    val df = spark.createDataFrame(Seq((1L, "a", "a1", "a2", "a3"), (2L, "b", "b1", "b2", "b3"),
+      (3L, "c", "c1", "c2", "c3"))).toDF("id", "data1", "data2", "data3", "data4")
+    df.createOrReplaceTempView("source_table")
     withTable(identifier) {
       spark.sql(s"CREATE TABLE $identifier USING foo PARTITIONED BY (id) " +
-        s"CLUSTERED BY (data) INTO 4 BUCKETS AS SELECT * FROM source")
+        s"CLUSTERED BY (data1, data2, data3, data4) INTO 4 BUCKETS AS SELECT * FROM source_table")
       val describe = spark.sql(s"DESCRIBE $identifier")
       val part1 = describe
         .filter("col_name = 'Part 0'")
@@ -421,18 +424,22 @@ class DataSourceV2SQLSuite
       val part2 = describe
         .filter("col_name = 'Part 1'")
         .select("data_type").head.getString(0)
-      assert(part2 === "bucket(4, data)")
+      assert(part2 === "bucket(4, data1, data2, data3, data4)")
     }
   }

   test("SPARK-36850: ReplaceTableAsSelect partitions can be specified using " +
     "PARTITIONED BY and/or CLUSTERED BY") {
     val identifier = "testcat.table_name"
"a3"), (2L, "b", "b1", "b2", "b3"), + (3L, "c", "c1", "c2", "c3"))).toDF("id", "data1", "data2", "data3", "data4") + df.createOrReplaceTempView("source_table") withTable(identifier) { spark.sql(s"CREATE TABLE $identifier USING foo " + "AS SELECT id FROM source") spark.sql(s"REPLACE TABLE $identifier USING foo PARTITIONED BY (id) " + - s"CLUSTERED BY (data) INTO 4 BUCKETS AS SELECT * FROM source") + s"CLUSTERED BY (data1, data2) SORTED by (data3, data4) INTO 4 BUCKETS " + + s"AS SELECT * FROM source_table") val describe = spark.sql(s"DESCRIBE $identifier") val part1 = describe .filter("col_name = 'Part 0'") @@ -441,7 +448,7 @@ class DataSourceV2SQLSuite val part2 = describe .filter("col_name = 'Part 1'") .select("data_type").head.getString(0) - assert(part2 === "bucket(4, data)") + assert(part2 === "sorted_bucket(data1, data2, 4, data3, data4)") } } @@ -1566,18 +1573,21 @@ class DataSourceV2SQLSuite test("create table using - with sorted bucket") { val identifier = "testcat.table_name" withTable(identifier) { - sql(s"CREATE TABLE $identifier (a int, b string, c int) USING $v2Source PARTITIONED BY (c)" + - s" CLUSTERED BY (b) SORTED by (a) INTO 4 BUCKETS") - val table = getTableMetadata(identifier) + sql(s"CREATE TABLE $identifier (a int, b string, c int, d int, e int, f int) USING" + + s" $v2Source PARTITIONED BY (a, b) CLUSTERED BY (c, d) SORTED by (e, f) INTO 4 BUCKETS") val describe = spark.sql(s"DESCRIBE $identifier") val part1 = describe .filter("col_name = 'Part 0'") .select("data_type").head.getString(0) - assert(part1 === "c") + assert(part1 === "a") val part2 = describe .filter("col_name = 'Part 1'") .select("data_type").head.getString(0) - assert(part2 === "bucket(4, b, a)") + assert(part2 === "b") + val part3 = describe + .filter("col_name = 'Part 2'") + .select("data_type").head.getString(0) + assert(part3 === "sorted_bucket(c, d, 4, e, f)") } }