Commit 14f8c3e

Add SQL tests for metadata column projection.
1 parent 0ebbeea commit 14f8c3e

3 files changed: 112 additions & 13 deletions

sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/MetadataColumn.java

Lines changed: 6 additions & 2 deletions
@@ -43,12 +43,16 @@ default boolean isNullable() {
    *
    * @return a documentation String
    */
-  String comment();
+  default String comment() {
+    return null;
+  }
 
   /**
    * The {@link Transform} used to produce this metadata column from data rows, or null.
    *
    * @return a {@link Transform} used to produce the column's values, or null if there isn't one
    */
-  Transform transform();
+  default Transform transform() {
+    return null;
+  }
 }
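Since comment() and transform() now default to returning null, a connector only has to supply name() and dataType(). A minimal sketch of such an implementation, in the style of the test columns below (the FileNameColumn object and the _file name are hypothetical, not part of this commit):

import org.apache.spark.sql.connector.catalog.MetadataColumn
import org.apache.spark.sql.types.{DataType, StringType}

// Hypothetical column that relies on the new defaults: comment() and
// transform() both return null without being overridden here.
object FileNameColumn extends MetadataColumn {
  override def name: String = "_file"
  override def dataType: DataType = StringType
}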

sql/catalyst/src/test/scala/org/apache/spark/sql/connector/InMemoryTable.scala

Lines changed: 63 additions & 11 deletions
@@ -27,15 +27,17 @@ import scala.collection.mutable
 import org.scalatest.Assertions._
 
 import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.expressions.{GenericInternalRow, JoinedRow}
 import org.apache.spark.sql.catalyst.util.DateTimeUtils
 import org.apache.spark.sql.connector.catalog._
 import org.apache.spark.sql.connector.expressions.{BucketTransform, DaysTransform, HoursTransform, IdentityTransform, MonthsTransform, Transform, YearsTransform}
 import org.apache.spark.sql.connector.read._
 import org.apache.spark.sql.connector.write._
 import org.apache.spark.sql.connector.write.streaming.{StreamingDataWriterFactory, StreamingWrite}
 import org.apache.spark.sql.sources.{And, EqualTo, Filter, IsNotNull}
-import org.apache.spark.sql.types.{DataType, DateType, StructType, TimestampType}
+import org.apache.spark.sql.types.{DataType, DateType, StringType, StructField, StructType, TimestampType}
 import org.apache.spark.sql.util.CaseInsensitiveStringMap
+import org.apache.spark.unsafe.types.UTF8String
 
 /**
  * A simple in-memory table. Rows are stored as a buffered group produced by each output task.
@@ -45,7 +47,24 @@ class InMemoryTable(
     val schema: StructType,
     override val partitioning: Array[Transform],
     override val properties: util.Map[String, String])
-  extends Table with SupportsRead with SupportsWrite with SupportsDelete {
+  extends Table with SupportsRead with SupportsWrite with SupportsDelete
+  with SupportsMetadataColumns {
+
+  private object PartitionKeyColumn extends MetadataColumn {
+    override def name: String = "_partition"
+    override def dataType: DataType = StringType
+    override def comment: String = "Partition key used to store the row"
+  }
+
+  private object IndexColumn extends MetadataColumn {
+    override def name: String = "index"
+    override def dataType: DataType = StringType
+    override def comment: String = "Metadata column used to conflict with a data column"
+  }
+
+  // purposely exposes a metadata column that conflicts with a data column in some tests
+  override val metadataColumns: Array[MetadataColumn] = Array(IndexColumn, PartitionKeyColumn)
+  private val metadataColumnNames = metadataColumns.map(_.name).toSet -- schema.map(_.name)
 
   private val allowUnsupportedTransforms =
     properties.getOrDefault("allow-unsupported-transforms", "false").toBoolean
@@ -138,7 +157,7 @@ class InMemoryTable(
       val key = getKey(row)
       dataMap += dataMap.get(key)
         .map(key -> _.withRow(row))
-        .getOrElse(key -> new BufferedRows().withRow(row))
+        .getOrElse(key -> new BufferedRows(key.toArray.mkString("/")).withRow(row))
     })
     this
   }
@@ -152,17 +171,38 @@ class InMemoryTable(
       TableCapability.TRUNCATE).asJava
 
   override def newScanBuilder(options: CaseInsensitiveStringMap): ScanBuilder = {
-    () => new InMemoryBatchScan(data.map(_.asInstanceOf[InputPartition]))
+    new InMemoryScanBuilder(schema)
+  }
+
+  class InMemoryScanBuilder(tableSchema: StructType) extends ScanBuilder
+      with SupportsPushDownRequiredColumns {
+    private var schema: StructType = tableSchema
+
+    override def build: Scan =
+      new InMemoryBatchScan(data.map(_.asInstanceOf[InputPartition]), schema)
+
+    override def pruneColumns(requiredSchema: StructType): Unit = {
+      // if metadata columns are projected, return the table schema and metadata columns
+      val hasMetadataColumns = requiredSchema.map(_.name).exists(metadataColumnNames.contains)
+      if (hasMetadataColumns) {
+        schema = StructType(tableSchema ++ metadataColumnNames
+            .flatMap(name => metadataColumns.find(_.name == name))
+            .map(col => StructField(col.name, col.dataType, col.isNullable)))
+      }
+    }
   }
 
-  class InMemoryBatchScan(data: Array[InputPartition]) extends Scan with Batch {
+  class InMemoryBatchScan(data: Array[InputPartition], schema: StructType) extends Scan with Batch {
     override def readSchema(): StructType = schema
 
     override def toBatch: Batch = this
 
     override def planInputPartitions(): Array[InputPartition] = data
 
-    override def createReaderFactory(): PartitionReaderFactory = BufferedRowsReaderFactory
+    override def createReaderFactory(): PartitionReaderFactory = {
+      val metadataColumns = schema.map(_.name).filter(metadataColumnNames.contains)
+      new BufferedRowsReaderFactory(metadataColumns)
+    }
   }
 
   override def newWriteBuilder(info: LogicalWriteInfo): WriteBuilder = {
@@ -332,7 +372,8 @@ object InMemoryTable {
   }
 }
 
-class BufferedRows extends WriterCommitMessage with InputPartition with Serializable {
+class BufferedRows(
+    val key: String = "") extends WriterCommitMessage with InputPartition with Serializable {
   val rows = new mutable.ArrayBuffer[InternalRow]()
 
   def withRow(row: InternalRow): BufferedRows = {
@@ -341,21 +382,32 @@ class BufferedRows extends WriterCommitMessage with InputPartition with Serializ
   }
 }
 
-private object BufferedRowsReaderFactory extends PartitionReaderFactory {
+private class BufferedRowsReaderFactory(
+    metadataColumns: Seq[String]) extends PartitionReaderFactory {
   override def createReader(partition: InputPartition): PartitionReader[InternalRow] = {
-    new BufferedRowsReader(partition.asInstanceOf[BufferedRows])
+    new BufferedRowsReader(partition.asInstanceOf[BufferedRows], metadataColumns)
   }
 }
 
-private class BufferedRowsReader(partition: BufferedRows) extends PartitionReader[InternalRow] {
+private class BufferedRowsReader(
+    partition: BufferedRows,
+    metadataColumns: Seq[String]) extends PartitionReader[InternalRow] {
+  private def addMetadata(row: InternalRow): InternalRow = {
+    val metadataRow = new GenericInternalRow(metadataColumns.map {
+      case "index" => index
+      case "_partition" => UTF8String.fromString(partition.key)
+    }.toArray)
+    new JoinedRow(row, metadataRow)
+  }
+
   private var index: Int = -1
 
   override def next(): Boolean = {
     index += 1
     index < partition.rows.length
  }
 
-  override def get(): InternalRow = partition.rows(index)
+  override def get(): InternalRow = addMetadata(partition.rows(index))
 
   override def close(): Unit = {}
 }
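The reader appends the metadata values to each data row with a JoinedRow rather than rebuilding the row. A minimal sketch of that mechanic with hand-built rows (the values mirror the test data below; this snippet is illustrative, not part of the commit):

import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.{GenericInternalRow, JoinedRow}
import org.apache.spark.unsafe.types.UTF8String

// A data row (id, data) with an appended metadata row (index, _partition).
val dataRow = InternalRow(1L, UTF8String.fromString("a"))
val metadataRow = new GenericInternalRow(Array[Any](0, UTF8String.fromString("3/1")))
val joined = new JoinedRow(dataRow, metadataRow)

assert(joined.getLong(0) == 1L)                    // data column, position 0
assert(joined.getUTF8String(3).toString == "3/1")  // metadata column, position 3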

sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2SQLSuite.scala

Lines changed: 43 additions & 0 deletions
@@ -161,6 +161,10 @@ class DataSourceV2SQLSuite
       Array("# Partitioning", "", ""),
       Array("Part 0", "id", ""),
       Array("", "", ""),
+      Array("# Metadata Columns", "", ""),
+      Array("index", "string", "Metadata column used to conflict with a data column"),
+      Array("_partition", "string", "Partition key used to store the row"),
+      Array("", "", ""),
       Array("# Detailed Table Information", "", ""),
       Array("Name", "testcat.table_name", ""),
       Array("Comment", "this is a test table", ""),
@@ -2583,6 +2587,45 @@ class DataSourceV2SQLSuite
     }
   }
 
+  test("SPARK-31255: Project a metadata column") {
+    val t1 = s"${catalogAndNamespace}table"
+    withTable(t1) {
+      sql(s"CREATE TABLE $t1 (id bigint, data string) USING $v2Format " +
+        "PARTITIONED BY (bucket(4, id), id)")
+      sql(s"INSERT INTO $t1 VALUES (1, 'a'), (2, 'b'), (3, 'c')")
+
+      checkAnswer(
+        spark.sql(s"SELECT id, data, _partition FROM $t1"),
+        Seq(Row(1, "a", "3/1"), Row(2, "b", "2/2"), Row(3, "c", "2/3")))
+    }
+  }
+
+  test("SPARK-31255: Projects data column when metadata column has the same name") {
+    val t1 = s"${catalogAndNamespace}table"
+    withTable(t1) {
+      sql(s"CREATE TABLE $t1 (index bigint, data string) USING $v2Format " +
+        "PARTITIONED BY (bucket(4, index), index)")
+      sql(s"INSERT INTO $t1 VALUES (3, 'c'), (2, 'b'), (1, 'a')")
+
+      checkAnswer(
+        spark.sql(s"SELECT index, data, _partition FROM $t1"),
+        Seq(Row(3, "c", "2/3"), Row(2, "b", "2/2"), Row(1, "a", "3/1")))
+    }
+  }
+
+  test("SPARK-31255: * expansion does not include metadata columns") {
+    val t1 = s"${catalogAndNamespace}table"
+    withTable(t1) {
+      sql(s"CREATE TABLE $t1 (id bigint, data string) USING $v2Format " +
+        "PARTITIONED BY (bucket(4, id), id)")
+      sql(s"INSERT INTO $t1 VALUES (3, 'c'), (2, 'b'), (1, 'a')")
+
+      checkAnswer(
+        spark.sql(s"SELECT * FROM $t1"),
+        Seq(Row(3, "c"), Row(2, "b"), Row(1, "a")))
+    }
+  }
+
   private def testV1Command(sqlCommand: String, sqlParams: String): Unit = {
     val e = intercept[AnalysisException] {
       sql(s"$sqlCommand $sqlParams")
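The expected _partition strings follow from the BufferedRows change above: the partition key, here (bucket(4, id), id), is rendered as its values joined with "/". A small illustration, with the bucket value taken from the expected rows:

// For id = 2 the key is (bucket, id) = (2, 2), per the expected Row(2, "b", "2/2").
val key = Seq(2, 2)
val partitionString = key.mkString("/")  // "2/2"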
