@@ -964,11 +964,37 @@ class Analyzer(override val catalogManager: CatalogManager)
   * columns are not accidentally selected by *.
   */
  object AddMetadataColumns extends Rule[LogicalPlan] {
    import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Implicits._

    private def hasMetadataCol(plan: LogicalPlan): Boolean = {
      plan.expressions.exists(_.find {
        case a: Attribute => a.isMetadataCol
        case _ => false
      }.isDefined)
    }

    private def addMetadataCol(plan: LogicalPlan): LogicalPlan = plan match {
      case r: DataSourceV2Relation => r.withMetadataColumns()
      case _ => plan.withNewChildren(plan.children.map(addMetadataCol))
    }

    def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperatorsUp {
      case node if node.resolved && node.children.nonEmpty && node.missingInput.nonEmpty =>
        node resolveOperatorsUp {
          case rel: DataSourceV2Relation =>
            rel.withMetadataColumns()
      case node if node.children.nonEmpty && node.resolved && hasMetadataCol(node) =>
        val inputAttrs = AttributeSet(node.children.flatMap(_.output))
        val metaCols = node.expressions.flatMap(_.collect {
          case a: Attribute if a.isMetadataCol && !inputAttrs.contains(a) => a
        })
        if (metaCols.isEmpty) {
          node
        } else {
          val newNode = addMetadataCol(node)

Review comment (Member):
No matter how many meta cols we actually reference, we always add all meta cols, right?

Reply (Contributor Author):
Yeah, it's good enough because:

  1. We guarantee that the outer plan will project out extra columns, so the final output schema won't change.
  2. Column pruning will work, and eventually the data source doesn't need to produce unreferenced columns.
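
A minimal sketch of the behavior described in the reply (not part of the diff; it assumes a SparkSession whose `testcat` catalog is backed by the InMemoryTable implementation and a table created like the ones in the tests below):

// Referencing one metadata column makes AddMetadataColumns attach all of the relation's
// metadata columns, but the Project it adds keeps the query's output schema unchanged,
// and column pruning later drops the unreferenced metadata columns from the scan.
val df = spark.sql("SELECT id, data, _partition FROM testcat.ns1.ns2.tbl")
assert(df.schema.fieldNames.toSeq == Seq("id", "data", "_partition"))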

          // We should not change the output schema of the plan. We should project away the extr
          // metadata columns if necessary.

Review comment (Contributor):
nit: extr -> extra

          if (newNode.sameOutput(node)) {
            newNode
          } else {
            Project(node.output, newNode)
          }
        }
    }
  }
@@ -21,12 +21,14 @@ import scala.collection.JavaConverters._

import org.apache.spark.sql.AnalysisException
import org.apache.spark.sql.catalyst.analysis.{PartitionSpec, ResolvedPartitionSpec, UnresolvedPartitionSpec}
import org.apache.spark.sql.catalyst.expressions.AttributeReference
import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference}
import org.apache.spark.sql.connector.catalog.{MetadataColumn, SupportsAtomicPartitionManagement, SupportsDelete, SupportsPartitionManagement, SupportsRead, SupportsWrite, Table, TableCapability}
import org.apache.spark.sql.types.{StructField, StructType}
import org.apache.spark.sql.types.{MetadataBuilder, StructField, StructType}
import org.apache.spark.sql.util.CaseInsensitiveStringMap

object DataSourceV2Implicits {
  private val METADATA_COL_ATTR_KEY = "__metadata_col"

  implicit class TableHelper(table: Table) {
    def asReadable: SupportsRead = {
      table match {
@@ -83,7 +85,8 @@ object DataSourceV2Implicits {
  implicit class MetadataColumnsHelper(metadata: Array[MetadataColumn]) {
    def asStruct: StructType = {
      val fields = metadata.map { metaCol =>
        val field = StructField(metaCol.name, metaCol.dataType, metaCol.isNullable)
        val fieldMeta = new MetadataBuilder().putBoolean(METADATA_COL_ATTR_KEY, true).build()

Review comment (Contributor @imback82, Feb 5, 2021):
nit: can be created outside the loop (or even at object level - new metadata related object to go with METADATA_COL_ATTR_KEY).

        val field = StructField(metaCol.name, metaCol.dataType, metaCol.isNullable, fieldMeta)
        Option(metaCol.comment).map(field.withComment).getOrElse(field)
      }
      StructType(fields)
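
The nit above suggests building the field metadata once instead of inside the map loop. A rough sketch of that refactor (illustrative only; the standalone object and the `METADATA_COL_FIELD_METADATA` name are hypothetical, not part of this diff):

import org.apache.spark.sql.connector.catalog.MetadataColumn
import org.apache.spark.sql.types.{Metadata, MetadataBuilder, StructField, StructType}

object MetadataColumnsSketch {
  private val METADATA_COL_ATTR_KEY = "__metadata_col"

  // Built once: the boolean flag is identical for every metadata column, so there is
  // no need to rebuild the Metadata instance per field inside the loop.
  private val METADATA_COL_FIELD_METADATA: Metadata =
    new MetadataBuilder().putBoolean(METADATA_COL_ATTR_KEY, true).build()

  def asStruct(metadata: Array[MetadataColumn]): StructType = {
    val fields = metadata.map { metaCol =>
      val field = StructField(
        metaCol.name, metaCol.dataType, metaCol.isNullable, METADATA_COL_FIELD_METADATA)
      Option(metaCol.comment).map(field.withComment).getOrElse(field)
    }
    StructType(fields)
  }
}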
@@ -92,6 +95,11 @@
    def toAttributes: Seq[AttributeReference] = asStruct.toAttributes
  }

  implicit class MetadataColumnHelper(attr: Attribute) {
    def isMetadataCol: Boolean = attr.metadata.contains(METADATA_COL_ATTR_KEY) &&
      attr.metadata.getBoolean(METADATA_COL_ATTR_KEY)
  }

  implicit class OptionsHelper(options: Map[String, String]) {
    def asOptions: CaseInsensitiveStringMap = {
      new CaseInsensitiveStringMap(options.asJava)
@@ -36,7 +36,7 @@ import org.apache.spark.sql.connector.read._
import org.apache.spark.sql.connector.write._
import org.apache.spark.sql.connector.write.streaming.{StreamingDataWriterFactory, StreamingWrite}
import org.apache.spark.sql.sources.{And, EqualNullSafe, EqualTo, Filter, IsNotNull, IsNull}
import org.apache.spark.sql.types.{DataType, DateType, StringType, StructField, StructType, TimestampType}
import org.apache.spark.sql.types.{DataType, DateType, IntegerType, StringType, StructField, StructType, TimestampType}
import org.apache.spark.sql.util.CaseInsensitiveStringMap
import org.apache.spark.unsafe.types.UTF8String

@@ -61,7 +61,7 @@ class InMemoryTable(

  private object IndexColumn extends MetadataColumn {
    override def name: String = "index"
    override def dataType: DataType = StringType
    override def dataType: DataType = IntegerType

Review comment (Contributor Author):
The actual data is int.

    override def comment: String = "Metadata column used to conflict with a data column"
  }

@@ -142,7 +142,7 @@ class DataSourceV2SQLSuite
Array("Part 0", "id", ""),
Array("", "", ""),
Array("# Metadata Columns", "", ""),
Array("index", "string", "Metadata column used to conflict with a data column"),
Array("index", "int", "Metadata column used to conflict with a data column"),
Array("_partition", "string", "Partition key used to store the row"),
Array("", "", ""),
Array("# Detailed Table Information", "", ""),
@@ -2443,9 +2443,12 @@
"PARTITIONED BY (bucket(4, id), id)")
sql(s"INSERT INTO $t1 VALUES (1, 'a'), (2, 'b'), (3, 'c')")

checkAnswer(
spark.sql(s"SELECT id, data, _partition FROM $t1"),
Seq(Row(1, "a", "3/1"), Row(2, "b", "0/2"), Row(3, "c", "1/3")))
val sqlQuery = spark.sql(s"SELECT id, data, index, _partition FROM $t1")
val dfQuery = spark.table(t1).select("id", "data", "index", "_partition")

Seq(sqlQuery, dfQuery).foreach { query =>
checkAnswer(query, Seq(Row(1, "a", 0, "3/1"), Row(2, "b", 0, "0/2"), Row(3, "c", 0, "1/3")))
}
}
}

@@ -2456,9 +2459,12 @@
"PARTITIONED BY (bucket(4, index), index)")
sql(s"INSERT INTO $t1 VALUES (3, 'c'), (2, 'b'), (1, 'a')")

checkAnswer(
spark.sql(s"SELECT index, data, _partition FROM $t1"),
Seq(Row(3, "c", "1/3"), Row(2, "b", "0/2"), Row(1, "a", "3/1")))
val sqlQuery = spark.sql(s"SELECT index, data, _partition FROM $t1")
val dfQuery = spark.table(t1).select("index", "data", "_partition")

Seq(sqlQuery, dfQuery).foreach { query =>
checkAnswer(query, Seq(Row(3, "c", "1/3"), Row(2, "b", "0/2"), Row(1, "a", "3/1")))
}
}
}

@@ -2469,9 +2475,27 @@
"PARTITIONED BY (bucket(4, id), id)")
sql(s"INSERT INTO $t1 VALUES (3, 'c'), (2, 'b'), (1, 'a')")

checkAnswer(
spark.sql(s"SELECT * FROM $t1"),
Seq(Row(3, "c"), Row(2, "b"), Row(1, "a")))
val sqlQuery = spark.sql(s"SELECT * FROM $t1")
val dfQuery = spark.table(t1)

Seq(sqlQuery, dfQuery).foreach { query =>
checkAnswer(query, Seq(Row(3, "c"), Row(2, "b"), Row(1, "a")))
}
}
}

test("SPARK-31255: metadata column should only be produced when necessary") {
val t1 = s"${catalogAndNamespace}table"
withTable(t1) {
sql(s"CREATE TABLE $t1 (id bigint, data string) USING $v2Format " +
"PARTITIONED BY (bucket(4, id), id)")

val sqlQuery = spark.sql(s"SELECT * FROM $t1 WHERE index = 0")
val dfQuery = spark.table(t1).filter("index = 0")

Seq(sqlQuery, dfQuery).foreach { query =>
assert(query.schema.fieldNames.toSeq == Seq("id", "data"))
}
}
}
