@@ -27,15 +27,17 @@ import scala.collection.mutable
2727import org .scalatest .Assertions ._
2828
2929import org .apache .spark .sql .catalyst .InternalRow
30+ import org .apache .spark .sql .catalyst .expressions .{GenericInternalRow , JoinedRow }
3031import org .apache .spark .sql .catalyst .util .DateTimeUtils
3132import org .apache .spark .sql .connector .catalog ._
3233import org .apache .spark .sql .connector .expressions .{BucketTransform , DaysTransform , HoursTransform , IdentityTransform , MonthsTransform , Transform , YearsTransform }
3334import org .apache .spark .sql .connector .read ._
3435import org .apache .spark .sql .connector .write ._
3536import org .apache .spark .sql .connector .write .streaming .{StreamingDataWriterFactory , StreamingWrite }
3637import org .apache .spark .sql .sources .{And , EqualTo , Filter , IsNotNull }
37- import org .apache .spark .sql .types .{DataType , DateType , StructType , TimestampType }
38+ import org .apache .spark .sql .types .{DataType , DateType , StringType , StructField , StructType , TimestampType }
3839import org .apache .spark .sql .util .CaseInsensitiveStringMap
40+ import org .apache .spark .unsafe .types .UTF8String
3941
4042/**
4143 * A simple in-memory table. Rows are stored as a buffered group produced by each output task.
@@ -45,7 +47,24 @@ class InMemoryTable(
4547 val schema : StructType ,
4648 override val partitioning : Array [Transform ],
4749 override val properties : util.Map [String , String ])
48- extends Table with SupportsRead with SupportsWrite with SupportsDelete {
50+ extends Table with SupportsRead with SupportsWrite with SupportsDelete
51+ with SupportsMetadataColumns {
52+
53+ private object PartitionKeyColumn extends MetadataColumn {
54+ override def name : String = " _partition"
55+ override def dataType : DataType = StringType
56+ override def comment : String = " Partition key used to store the row"
57+ }
58+
59+ private object IndexColumn extends MetadataColumn {
60+ override def name : String = " index"
61+ override def dataType : DataType = StringType
62+ override def comment : String = " Metadata column used to conflict with a data column"
63+ }
64+
65+ // purposely exposes a metadata column that conflicts with a data column in some tests
66+ override val metadataColumns : Array [MetadataColumn ] = Array (IndexColumn , PartitionKeyColumn )
67+ private val metadataColumnNames = metadataColumns.map(_.name).toSet -- schema.map(_.name)
4968
5069 private val allowUnsupportedTransforms =
5170 properties.getOrDefault(" allow-unsupported-transforms" , " false" ).toBoolean
@@ -138,7 +157,7 @@ class InMemoryTable(
138157 val key = getKey(row)
139158 dataMap += dataMap.get(key)
140159 .map(key -> _.withRow(row))
141- .getOrElse(key -> new BufferedRows ().withRow(row))
160+ .getOrElse(key -> new BufferedRows (key.toArray.mkString( " / " ) ).withRow(row))
142161 })
143162 this
144163 }
@@ -152,17 +171,38 @@ class InMemoryTable(
152171 TableCapability .TRUNCATE ).asJava
153172
154173 override def newScanBuilder (options : CaseInsensitiveStringMap ): ScanBuilder = {
155- () => new InMemoryBatchScan (data.map(_.asInstanceOf [InputPartition ]))
174+ new InMemoryScanBuilder (schema)
175+ }
176+
177+ class InMemoryScanBuilder (tableSchema : StructType ) extends ScanBuilder
178+ with SupportsPushDownRequiredColumns {
179+ private var schema : StructType = tableSchema
180+
181+ override def build : Scan =
182+ new InMemoryBatchScan (data.map(_.asInstanceOf [InputPartition ]), schema)
183+
184+ override def pruneColumns (requiredSchema : StructType ): Unit = {
185+ // if metadata columns are projected, return the table schema and metadata columns
186+ val hasMetadataColumns = requiredSchema.map(_.name).exists(metadataColumnNames.contains)
187+ if (hasMetadataColumns) {
188+ schema = StructType (tableSchema ++ metadataColumnNames
189+ .flatMap(name => metadataColumns.find(_.name == name))
190+ .map(col => StructField (col.name, col.dataType, col.isNullable)))
191+ }
192+ }
156193 }
157194
158- class InMemoryBatchScan (data : Array [InputPartition ]) extends Scan with Batch {
195+ class InMemoryBatchScan (data : Array [InputPartition ], schema : StructType ) extends Scan with Batch {
159196 override def readSchema (): StructType = schema
160197
161198 override def toBatch : Batch = this
162199
163200 override def planInputPartitions (): Array [InputPartition ] = data
164201
165- override def createReaderFactory (): PartitionReaderFactory = BufferedRowsReaderFactory
202+ override def createReaderFactory (): PartitionReaderFactory = {
203+ val metadataColumns = schema.map(_.name).filter(metadataColumnNames.contains)
204+ new BufferedRowsReaderFactory (metadataColumns)
205+ }
166206 }
167207
168208 override def newWriteBuilder (info : LogicalWriteInfo ): WriteBuilder = {
@@ -332,7 +372,8 @@ object InMemoryTable {
332372 }
333373}
334374
/**
 * A group of rows buffered by one output task, keyed by the partition string under which
 * they were written. Doubles as the commit message and the input partition for reads.
 */
class BufferedRows(
    val key: String = "") extends WriterCommitMessage with InputPartition with Serializable {
  val rows = new mutable.ArrayBuffer[InternalRow]()

  // Appends the row and returns this buffer, enabling call chaining.
  // NOTE(review): body reconstructed across a diff-hunk boundary from the chained
  // `.withRow(row)` usage in InMemoryTable.withData — confirm against the full file.
  def withRow(row: InternalRow): BufferedRows = {
    rows.append(row)
    this
  }
}
343384
344- private object BufferedRowsReaderFactory extends PartitionReaderFactory {
/** Creates readers over buffered partitions, forwarding the projected metadata column names. */
private class BufferedRowsReaderFactory(
    metadataColumns: Seq[String]) extends PartitionReaderFactory {
  override def createReader(partition: InputPartition): PartitionReader[InternalRow] =
    new BufferedRowsReader(partition.asInstanceOf[BufferedRows], metadataColumns)
}
349391
/**
 * Reads the rows of one buffered partition, appending the requested metadata column
 * values ("index" and/or "_partition") to each data row.
 */
private class BufferedRowsReader(
    partition: BufferedRows,
    metadataColumns: Seq[String]) extends PartitionReader[InternalRow] {
  // Position of the current row within the partition; -1 until the first next() call.
  private var index: Int = -1

  override def next(): Boolean = {
    index += 1
    index < partition.rows.length
  }

  override def get(): InternalRow = appendMetadata(partition.rows(index))

  override def close(): Unit = {}

  // Joins the data row with a row holding the requested metadata values, in projection order.
  // NOTE(review): "index" yields an Int even though IndexColumn declares StringType —
  // presumably intentional for the column-conflict tests; confirm.
  private def appendMetadata(row: InternalRow): InternalRow = {
    val values: Array[Any] = metadataColumns.map {
      case "index" => index
      case "_partition" => UTF8String.fromString(partition.key)
    }.toArray
    new JoinedRow(row, new GenericInternalRow(values))
  }
}
0 commit comments