Skip to content

Commit a2d5c9c

Browse files
ulysses-you authored and cloud-fan committed
[SPARK-42548][SQL] Add ReferenceAllColumns to skip rewriting attributes
### What changes were proposed in this pull request? Add a new trait `ReferenceAllColumns ` that overrides `references` using children output. Then we can skip it during rewriting attributes in transformUpWithNewOutput. ### Why are the changes needed? There are two reasons with this new trait: 1. it's dangerous to call `references` on an unresolved plan that all of references come from children 2. it's unnecessary to rewrite its attributes that all of references come from children ### Does this PR introduce _any_ user-facing change? prevent potential bug ### How was this patch tested? add test and pass CI Closes #40154 from ulysses-you/references. Authored-by: ulysses-you <ulyssesyou18@gmail.com> Signed-off-by: Wenchen Fan <wenchen@databricks.com> (cherry picked from commit db0e822) Signed-off-by: Wenchen Fan <wenchen@databricks.com>
1 parent 816774a commit a2d5c9c

6 files changed

Lines changed: 81 additions & 32 deletions

File tree

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala

Lines changed: 22 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -297,21 +297,28 @@ abstract class QueryPlan[PlanType <: QueryPlan[PlanType]]
297297
newChild
298298
}
299299

300-
val attrMappingForCurrentPlan = attrMapping.filter {
301-
// The `attrMappingForCurrentPlan` is used to replace the attributes of the
302-
// current `plan`, so the `oldAttr` must be part of `plan.references`.
303-
case (oldAttr, _) => plan.references.contains(oldAttr)
304-
}
305-
306-
if (attrMappingForCurrentPlan.nonEmpty) {
307-
assert(!attrMappingForCurrentPlan.groupBy(_._1.exprId)
308-
.exists(_._2.map(_._2.exprId).distinct.length > 1),
309-
"Found duplicate rewrite attributes")
310-
311-
val attributeRewrites = AttributeMap(attrMappingForCurrentPlan)
312-
// Using attrMapping from the children plans to rewrite their parent node.
313-
// Note that we shouldn't rewrite a node using attrMapping from its sibling nodes.
314-
newPlan = newPlan.rewriteAttrs(attributeRewrites)
300+
plan match {
301+
case _: ReferenceAllColumns[_] =>
302+
// It's dangerous to call `references` on an unresolved `ReferenceAllColumns`, and
303+
// it's unnecessary to rewrite its attributes that all of references come from children
304+
305+
case _ =>
306+
val attrMappingForCurrentPlan = attrMapping.filter {
307+
// The `attrMappingForCurrentPlan` is used to replace the attributes of the
308+
// current `plan`, so the `oldAttr` must be part of `plan.references`.
309+
case (oldAttr, _) => plan.references.contains(oldAttr)
310+
}
311+
312+
if (attrMappingForCurrentPlan.nonEmpty) {
313+
assert(!attrMappingForCurrentPlan.groupBy(_._1.exprId)
314+
.exists(_._2.map(_._2.exprId).distinct.length > 1),
315+
"Found duplicate rewrite attributes")
316+
317+
val attributeRewrites = AttributeMap(attrMappingForCurrentPlan)
318+
// Using attrMapping from the children plans to rewrite their parent node.
319+
// Note that we shouldn't rewrite a node using attrMapping from its sibling nodes.
320+
newPlan = newPlan.rewriteAttrs(attributeRewrites)
321+
}
315322
}
316323

317324
val (planAfterRule, newAttrMapping) = CurrentOrigin.withOrigin(origin) {
Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,34 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one or more
3+
* contributor license agreements. See the NOTICE file distributed with
4+
* this work for additional information regarding copyright ownership.
5+
* The ASF licenses this file to You under the Apache License, Version 2.0
6+
* (the "License"); you may not use this file except in compliance with
7+
* the License. You may obtain a copy of the License at
8+
*
9+
* http://www.apache.org/licenses/LICENSE-2.0
10+
*
11+
* Unless required by applicable law or agreed to in writing, software
12+
* distributed under the License is distributed on an "AS IS" BASIS,
13+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
* See the License for the specific language governing permissions and
15+
* limitations under the License.
16+
*/
17+
18+
package org.apache.spark.sql.catalyst.plans
19+
20+
import org.apache.spark.sql.catalyst.expressions.AttributeSet
21+
22+
/**
 * A trait that overrides `references` using the children's output.
 *
 * It's unnecessary to rewrite attributes for a `ReferenceAllColumns` node, since all of its
 * references come from its children.
 *
 * Note: the only place this is consulted is [[QueryPlan.transformUpWithNewOutput]], which
 * skips attribute rewriting for nodes mixing in this trait (calling `references` on such a
 * node while it is still unresolved would be unsafe otherwise).
 */
trait ReferenceAllColumns[PlanType <: QueryPlan[PlanType]] { self: QueryPlan[PlanType] =>

  // Not serialized with the plan; recomputed lazily from the children on demand.
  @transient
  override final lazy val references: AttributeSet = AttributeSet(children.flatMap(_.outputSet))
}

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/ScriptTransformation.scala

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,8 @@
1717

1818
package org.apache.spark.sql.catalyst.plans.logical
1919

20-
import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeSet}
20+
import org.apache.spark.sql.catalyst.expressions.Attribute
21+
import org.apache.spark.sql.catalyst.plans.ReferenceAllColumns
2122

2223
/**
 * Transforms the input by forking and running the specified script.
 *
 * All columns of the child are fed to the script, so this node mixes in
 * `ReferenceAllColumns`: its references are exactly the child's output and never
 * need to be rewritten by `transformUpWithNewOutput`.
 *
 * @param script the command that should be executed.
 * @param output the attributes produced by the script.
 * @param child the logical plan whose output is piped into the script.
 * @param ioschema the serde/format schema used to feed input to and read output from the script.
 */
case class ScriptTransformation(
    script: String,
    output: Seq[Attribute],
    child: LogicalPlan,
    ioschema: ScriptInputOutputSchema) extends UnaryNode with ReferenceAllColumns[LogicalPlan] {
  override protected def withNewChildInternal(newChild: LogicalPlan): ScriptTransformation =
    copy(child = newChild)
}

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/object.scala

Lines changed: 2 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@ import org.apache.spark.sql.catalyst.analysis.UnresolvedDeserializer
2424
import org.apache.spark.sql.catalyst.encoders._
2525
import org.apache.spark.sql.catalyst.expressions._
2626
import org.apache.spark.sql.catalyst.expressions.objects.Invoke
27+
import org.apache.spark.sql.catalyst.plans.ReferenceAllColumns
2728
import org.apache.spark.sql.catalyst.trees.TreePattern._
2829
import org.apache.spark.sql.internal.SQLConf
2930
import org.apache.spark.sql.streaming.{GroupStateTimeout, OutputMode}
@@ -64,13 +65,8 @@ trait ObjectProducer extends LogicalPlan {
6465
/**
 * A trait for logical operators that consume domain objects as input.
 * The output of its child must be a single-field row containing the input object.
 *
 * This operator always needs all columns of its child, even ones it doesn't explicitly
 * reference, hence the `ReferenceAllColumns` mixin.
 */
trait ObjectConsumer extends UnaryNode with ReferenceAllColumns[LogicalPlan] {
  // The child must produce exactly one column: the (de)serialized input object.
  assert(child.output.length == 1)

  // The single attribute holding the input object.
  def inputObjAttr: Attribute = child.output.head
}
7672

sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercionSuite.scala

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@ import org.apache.spark.sql.catalyst.analysis.TypeCoercion._
2525
import org.apache.spark.sql.catalyst.dsl.expressions._
2626
import org.apache.spark.sql.catalyst.dsl.plans._
2727
import org.apache.spark.sql.catalyst.expressions._
28+
import org.apache.spark.sql.catalyst.plans.ReferenceAllColumns
2829
import org.apache.spark.sql.catalyst.plans.logical._
2930
import org.apache.spark.sql.catalyst.rules.{Rule, RuleExecutor}
3031
import org.apache.spark.sql.internal.SQLConf
@@ -1740,6 +1741,16 @@ class TypeCoercionSuite extends TypeCoercionSuiteBase {
17401741
}
17411742
}
17421743
}
1744+
1745+
// Fixed the JIRA id in the test title: this test was added by SPARK-42548
// (the commit's ticket), not SPARK-32638.
test("SPARK-42548: Add ReferenceAllColumns to skip rewriting attributes") {
  // Two relations whose union requires widening DecimalType(1, 0) to DecimalType(2, 0).
  val t1 = LocalRelation(AttributeReference("c", DecimalType(1, 0))())
  val t2 = LocalRelation(AttributeReference("c", DecimalType(2, 0))())
  // Keep the plan unresolved on purpose: calling `references` on it would be unsafe,
  // so the rule must skip the ReferenceAllColumns node entirely.
  val unresolved = t1.union(t2).select(UnresolvedStar(None))
  val referenceAllColumns = FakeReferenceAllColumns(unresolved)
  val wp1 = widenSetOperationTypes(referenceAllColumns.select(t1.output.head))
  assert(wp1.isInstanceOf[Project])
  // After widening, the project must no longer reference t1's original attribute.
  assert(wp1.expressions.forall(!_.exists(_ == t1.output.head)))
}
17431754
}
17441755

17451756

@@ -1798,3 +1809,10 @@ object TypeCoercionSuite {
17981809
copy(left = newLeft, right = newRight)
17991810
}
18001811
}
// Minimal test-only unary node mixing in ReferenceAllColumns. Used to verify that
// transformUpWithNewOutput skips attribute rewriting for nodes whose references
// all come from their children.
case class FakeReferenceAllColumns(child: LogicalPlan)
  extends UnaryNode with ReferenceAllColumns[LogicalPlan] {
  override protected def withNewChildInternal(newChild: LogicalPlan): LogicalPlan =
    copy(child = newChild)
  // Pass the child's output through unchanged.
  override def output: Seq[Attribute] = child.output
}

sql/core/src/main/scala/org/apache/spark/sql/execution/objects.scala

Lines changed: 2 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@ import org.apache.spark.sql.catalyst.InternalRow
3232
import org.apache.spark.sql.catalyst.expressions._
3333
import org.apache.spark.sql.catalyst.expressions.codegen._
3434
import org.apache.spark.sql.catalyst.expressions.objects.Invoke
35+
import org.apache.spark.sql.catalyst.plans.ReferenceAllColumns
3536
import org.apache.spark.sql.catalyst.plans.logical.{EventTimeWatermark, FunctionUtils, LogicalGroupState}
3637
import org.apache.spark.sql.catalyst.plans.physical._
3738
import org.apache.spark.sql.execution.python.BatchIterator
@@ -58,13 +59,8 @@ trait ObjectProducerExec extends SparkPlan {
5859
/**
 * Physical version of `ObjectConsumer`: a physical operator whose child emits a
 * single-field row containing the input object.
 *
 * This operator always needs all columns of its child, even ones it doesn't explicitly
 * reference, hence the `ReferenceAllColumns` mixin.
 */
trait ObjectConsumerExec extends UnaryExecNode with ReferenceAllColumns[SparkPlan] {
  // The child must produce exactly one column: the (de)serialized input object.
  assert(child.output.length == 1)

  // The data type of the input object produced by the child.
  def inputObjectType: DataType = child.output.head.dataType
}
7066

0 commit comments

Comments
 (0)