diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2ScanPartitioning.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2ScanPartitioning.scala
index 9a5a7e6aab63a..64e80081018a7 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2ScanPartitioning.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2ScanPartitioning.scala
@@ -39,8 +39,18 @@ object V2ScanPartitioning extends Rule[LogicalPlan] with SQLConfHelper {
       }
 
       val catalystPartitioning = scan.outputPartitioning() match {
-        case kgp: KeyGroupedPartitioning => sequenceToOption(kgp.keys().map(
-          V2ExpressionUtils.toCatalystOpt(_, relation, funCatalogOpt)))
+        case kgp: KeyGroupedPartitioning =>
+          val partitioning = sequenceToOption(kgp.keys().map(
+            V2ExpressionUtils.toCatalystOpt(_, relation, funCatalogOpt)))
+          if (partitioning.isEmpty) {
+            None
+          } else {
+            if (partitioning.get.forall(p => p.references.subsetOf(d.outputSet))) {
+              partitioning
+            } else {
+              None
+            }
+          }
         case _: UnknownPartitioning => None
         case p => throw new IllegalArgumentException("Unsupported data source V2 partitioning " +
           "type: " + p.getClass.getSimpleName)
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/MetadataColumnSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/MetadataColumnSuite.scala
index 95b9c4f72356a..7f0e74f6bc7ef 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/connector/MetadataColumnSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/MetadataColumnSuite.scala
@@ -216,4 +216,20 @@ class MetadataColumnSuite extends DatasourceV2SQLBase {
       .withColumn("right_all", struct($"right.*"))
     checkAnswer(dfQuery, Row(1, "a", "b", Row(1, "a"), Row(1, "b")))
   }
+
+  test("SPARK-40429: Only set KeyGroupedPartitioning when the referenced column is in the output") {
+    withTable(tbl) {
+      sql(s"CREATE TABLE $tbl (id bigint, data string) PARTITIONED BY (id)")
+      sql(s"INSERT INTO $tbl VALUES (1, 'a'), (2, 'b'), (3, 'c')")
+      checkAnswer(
+        spark.table(tbl).select("index", "_partition"),
+        Seq(Row(0, "3"), Row(0, "2"), Row(0, "1"))
+      )
+
+      checkAnswer(
+        spark.table(tbl).select("id", "index", "_partition"),
+        Seq(Row(3, 0, "3"), Row(2, 0, "2"), Row(1, 0, "1"))
+      )
+    }
+  }
 }
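
Review note (not part of the patch): the nested if/else added in V2ScanPartitioning.scala is behaviorally equivalent to a single Option.filter: drop the partitioning when it is absent or when any key expression references a column outside the scan's output. A minimal, self-contained sketch of that equivalence, using plain Scala types rather than Spark's internal expression classes (names here are illustrative, not from the patch):

object OptionFilterEquivalence extends App {
  // The patch's shape: None when empty, otherwise keep only if every key passes.
  def nestedGuard[A](opt: Option[Seq[A]])(keep: A => Boolean): Option[Seq[A]] =
    if (opt.isEmpty) {
      None
    } else {
      if (opt.get.forall(keep)) opt else None
    }

  // The same logic expressed with Option.filter.
  def viaFilter[A](opt: Option[Seq[A]])(keep: A => Boolean): Option[Seq[A]] =
    opt.filter(_.forall(keep))

  val keys = Option(Seq(1, 2, 3))
  assert(nestedGuard(keys)(_ > 0) == viaFilter(keys)(_ > 0)) // Some(Seq(1, 2, 3))
  assert(nestedGuard(keys)(_ > 1) == viaFilter(keys)(_ > 1)) // None
  assert(nestedGuard(Option.empty[Seq[Int]])(_ => true) ==
    viaFilter(Option.empty[Seq[Int]])(_ => true))            // None == None
  println("nested if/else and Option.filter agree")
}

Whether to fold the guard into partitioning.filter(_.forall(p => p.references.subsetOf(d.outputSet))) is a style choice; doing so would also shrink the hunk, so the header line counts above would need regenerating if the patch were rewritten that way.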