From 80dd67993f3f21954f2f5c7f6aafaeda58d7e5f6 Mon Sep 17 00:00:00 2001
From: ulysses-you
Date: Mon, 27 Feb 2023 11:18:55 +0800
Subject: [PATCH] Add ReferenceAllColumns to skip rewriting attributes

---
 .../spark/sql/catalyst/plans/QueryPlan.scala  | 37 +++++++++++--------
 .../catalyst/plans/ReferenceAllColumns.scala  | 34 +++++++++++++++++
 .../plans/logical/ScriptTransformation.scala  |  8 ++--
 .../sql/catalyst/plans/logical/object.scala   |  8 +---
 .../catalyst/analysis/TypeCoercionSuite.scala | 18 +++++++++
 .../apache/spark/sql/execution/objects.scala  |  8 +---
 6 files changed, 81 insertions(+), 32 deletions(-)
 create mode 100644 sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/ReferenceAllColumns.scala

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala
index 90d1bd805cb5..ae5e9789dd94 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala
@@ -297,21 +297,28 @@ abstract class QueryPlan[PlanType <: QueryPlan[PlanType]]
         newChild
       }
 
-      val attrMappingForCurrentPlan = attrMapping.filter {
-        // The `attrMappingForCurrentPlan` is used to replace the attributes of the
-        // current `plan`, so the `oldAttr` must be part of `plan.references`.
-        case (oldAttr, _) => plan.references.contains(oldAttr)
-      }
-
-      if (attrMappingForCurrentPlan.nonEmpty) {
-        assert(!attrMappingForCurrentPlan.groupBy(_._1.exprId)
-          .exists(_._2.map(_._2.exprId).distinct.length > 1),
-          "Found duplicate rewrite attributes")
-
-        val attributeRewrites = AttributeMap(attrMappingForCurrentPlan)
-        // Using attrMapping from the children plans to rewrite their parent node.
-        // Note that we shouldn't rewrite a node using attrMapping from its sibling nodes.
-        newPlan = newPlan.rewriteAttrs(attributeRewrites)
+      plan match {
+        case _: ReferenceAllColumns[_] =>
+          // It's dangerous to call `references` on an unresolved `ReferenceAllColumns`, and
+          // it's unnecessary to rewrite its attributes: all its references come from its children.
+
+        case _ =>
+          val attrMappingForCurrentPlan = attrMapping.filter {
+            // The `attrMappingForCurrentPlan` is used to replace the attributes of the
+            // current `plan`, so the `oldAttr` must be part of `plan.references`.
+            case (oldAttr, _) => plan.references.contains(oldAttr)
+          }
+
+          if (attrMappingForCurrentPlan.nonEmpty) {
+            assert(!attrMappingForCurrentPlan.groupBy(_._1.exprId)
+              .exists(_._2.map(_._2.exprId).distinct.length > 1),
+              "Found duplicate rewrite attributes")
+
+            val attributeRewrites = AttributeMap(attrMappingForCurrentPlan)
+            // Using attrMapping from the children plans to rewrite their parent node.
+            // Note that we shouldn't rewrite a node using attrMapping from its sibling nodes.
+            newPlan = newPlan.rewriteAttrs(attributeRewrites)
+          }
       }
 
       val (planAfterRule, newAttrMapping) = CurrentOrigin.withOrigin(origin) {
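The match above is the heart of the change: a node whose references are derived on demand from its children never stores stale expression ids, so there is nothing for transformUpWithNewOutput to rewrite once the children have been replaced. A minimal, self-contained sketch of that invariant (toy types, not Spark's; the names Attr, Leaf, and ConsumeAll are illustrative only):

final case class Attr(name: String, id: Long)

sealed trait Node {
  def children: Seq[Node]
  def output: Seq[Attr]
  // Recomputed from the current children every time it is asked for,
  // so it can never hold an outdated id.
  def references: Set[Attr] = children.flatMap(_.output).toSet
}

final case class Leaf(output: Seq[Attr]) extends Node {
  override def children: Seq[Node] = Nil
}

final case class ConsumeAll(child: Node) extends Node {
  override def children: Seq[Node] = Seq(child)
  override def output: Seq[Attr] = child.output
}

object DerivedReferencesDemo extends App {
  val before = ConsumeAll(Leaf(Seq(Attr("c", id = 1))))
  // Swap the child for one carrying a fresh id, as a rule that changes
  // output would: the parent's references follow automatically, with no
  // attribute rewriting needed.
  val after = before.copy(child = Leaf(Seq(Attr("c", id = 2))))
  assert(after.references == Set(Attr("c", 2)))
}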
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/ReferenceAllColumns.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/ReferenceAllColumns.scala
new file mode 100644
index 000000000000..613e2a06f498
--- /dev/null
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/ReferenceAllColumns.scala
@@ -0,0 +1,34 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.plans
+
+import org.apache.spark.sql.catalyst.expressions.AttributeSet
+
+/**
+ * A trait that overrides `references` to be all of the children's output.
+ *
+ * It's unnecessary to rewrite attributes for a `ReferenceAllColumns` node, since all of its
+ * references come from its children.
+ *
+ * Note: the only place that relies on this is [[QueryPlan.transformUpWithNewOutput]].
+ */
+trait ReferenceAllColumns[PlanType <: QueryPlan[PlanType]] { self: QueryPlan[PlanType] =>
+
+  @transient
+  override final lazy val references: AttributeSet = AttributeSet(children.flatMap(_.outputSet))
+}
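Opting an operator into this behavior is just a mixin, as the call sites below show. A hedged sketch with a made-up pass-through operator (MyPassThrough is illustrative, not part of the patch; it mirrors the shape of FakeReferenceAllColumns in the test suite near the end of this patch):

import org.apache.spark.sql.catalyst.expressions.Attribute
import org.apache.spark.sql.catalyst.plans.ReferenceAllColumns
import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, UnaryNode}

// Hypothetical operator, for illustration only: it consumes every column
// of its child, so deriving `references` from the child's output is exact.
case class MyPassThrough(child: LogicalPlan)
  extends UnaryNode with ReferenceAllColumns[LogicalPlan] {
  override def output: Seq[Attribute] = child.output
  override protected def withNewChildInternal(newChild: LogicalPlan): LogicalPlan =
    copy(child = newChild)
}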
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/ScriptTransformation.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/ScriptTransformation.scala
index 5fe5dc373718..e6ebf981bc4a 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/ScriptTransformation.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/ScriptTransformation.scala
@@ -17,7 +17,8 @@
 
 package org.apache.spark.sql.catalyst.plans.logical
 
-import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeSet}
+import org.apache.spark.sql.catalyst.expressions.Attribute
+import org.apache.spark.sql.catalyst.plans.ReferenceAllColumns
 
 /**
  * Transforms the input by forking and running the specified script.
@@ -30,10 +31,7 @@ case class ScriptTransformation(
     script: String,
     output: Seq[Attribute],
     child: LogicalPlan,
-    ioschema: ScriptInputOutputSchema) extends UnaryNode {
-  @transient
-  override lazy val references: AttributeSet = AttributeSet(child.output)
-
+    ioschema: ScriptInputOutputSchema) extends UnaryNode with ReferenceAllColumns[LogicalPlan] {
   override protected def withNewChildInternal(newChild: LogicalPlan): ScriptTransformation =
     copy(child = newChild)
 }
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/object.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/object.scala
index b27c650cfb29..c6a4779374db 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/object.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/object.scala
@@ -24,6 +24,7 @@ import org.apache.spark.sql.catalyst.analysis.UnresolvedDeserializer
 import org.apache.spark.sql.catalyst.encoders._
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.expressions.objects.Invoke
+import org.apache.spark.sql.catalyst.plans.ReferenceAllColumns
 import org.apache.spark.sql.catalyst.trees.TreePattern._
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.streaming.{GroupStateTimeout, OutputMode}
@@ -64,13 +65,8 @@ trait ObjectProducer extends LogicalPlan {
 /**
  * A trait for logical operators that consumes domain objects as input.
  * The output of its child must be a single-field row containing the input object.
  */
-trait ObjectConsumer extends UnaryNode {
+trait ObjectConsumer extends UnaryNode with ReferenceAllColumns[LogicalPlan] {
   assert(child.output.length == 1)
-
-  // This operator always need all columns of its child, even it doesn't reference to.
-  @transient
-  override lazy val references: AttributeSet = child.outputSet
-
   def inputObjAttr: Attribute = child.output.head
 }
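For context, ObjectConsumer operators show up whenever a typed Dataset transformation consumes whole domain objects, which is why they need every column of their child even without referencing any attribute explicitly. A sketch against the public API (assumes spark-sql on the classpath and a local master; ObjectConsumerDemo and Point are illustrative names):

import org.apache.spark.sql.SparkSession

case class Point(x: Int, y: Int)

object ObjectConsumerDemo {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[1]").appName("demo").getOrCreate()
    import spark.implicits._
    // Dataset.map plans DeserializeToObject -> MapElements (an
    // ObjectConsumer) -> SerializeFromObject; the MapElements node
    // consumes the single object column wholesale.
    val ds = Seq(Point(1, 2), Point(3, 4)).toDS().map(p => p.x + p.y)
    ds.explain(extended = true)
    spark.stop()
  }
}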
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercionSuite.scala
index adce553d1942..e30cce23136f 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercionSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercionSuite.scala
@@ -25,6 +25,7 @@ import org.apache.spark.sql.catalyst.analysis.TypeCoercion._
 import org.apache.spark.sql.catalyst.dsl.expressions._
 import org.apache.spark.sql.catalyst.dsl.plans._
 import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.plans.ReferenceAllColumns
 import org.apache.spark.sql.catalyst.plans.logical._
 import org.apache.spark.sql.catalyst.rules.{Rule, RuleExecutor}
 import org.apache.spark.sql.internal.SQLConf
@@ -1740,6 +1741,16 @@ class TypeCoercionSuite extends TypeCoercionSuiteBase {
       }
     }
   }
+
+  test("SPARK-32638: Add ReferenceAllColumns to skip rewriting attributes") {
+    val t1 = LocalRelation(AttributeReference("c", DecimalType(1, 0))())
+    val t2 = LocalRelation(AttributeReference("c", DecimalType(2, 0))())
+    val unresolved = t1.union(t2).select(UnresolvedStar(None))
+    val referenceAllColumns = FakeReferenceAllColumns(unresolved)
+    val wp1 = widenSetOperationTypes(referenceAllColumns.select(t1.output.head))
+    assert(wp1.isInstanceOf[Project])
+    assert(wp1.expressions.forall(!_.exists(_ == t1.output.head)))
+  }
 }
 
 
@@ -1798,3 +1809,10 @@ object TypeCoercionSuite {
     copy(left = newLeft, right = newRight)
   }
 }
+
+case class FakeReferenceAllColumns(child: LogicalPlan)
+  extends UnaryNode with ReferenceAllColumns[LogicalPlan] {
+  override def output: Seq[Attribute] = child.output
+  override protected def withNewChildInternal(newChild: LogicalPlan): LogicalPlan =
+    copy(child = newChild)
+}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/objects.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/objects.scala
index bda592ff9299..c8d575016fc2 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/objects.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/objects.scala
@@ -32,6 +32,7 @@ import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.expressions.codegen._
 import org.apache.spark.sql.catalyst.expressions.objects.Invoke
+import org.apache.spark.sql.catalyst.plans.ReferenceAllColumns
 import org.apache.spark.sql.catalyst.plans.logical.{EventTimeWatermark, FunctionUtils, LogicalGroupState}
 import org.apache.spark.sql.catalyst.plans.physical._
 import org.apache.spark.sql.execution.python.BatchIterator
@@ -58,13 +59,8 @@ trait ObjectProducerExec extends SparkPlan {
 
 /**
  * Physical version of `ObjectConsumer`.
 */
-trait ObjectConsumerExec extends UnaryExecNode {
+trait ObjectConsumerExec extends UnaryExecNode with ReferenceAllColumns[SparkPlan] {
   assert(child.output.length == 1)
-
-  // This operator always need all columns of its child, even it doesn't reference to.
-  @transient
-  override lazy val references: AttributeSet = child.outputSet
-
   def inputObjectType: DataType = child.output.head.dataType
 }
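The new test's final assertion leans on the fact that attributes are identified by expression id, not name: after widenSetOperationTypes widens the Union, t1.output.head is superseded by a widened attribute with a fresh id, and the old attribute must not leak into the rewritten Project. A small sketch of that identity rule (illustrative, assuming spark-catalyst on the classpath; ExprIdDemo is a made-up name):

import org.apache.spark.sql.catalyst.expressions.AttributeReference
import org.apache.spark.sql.types.IntegerType

object ExprIdDemo extends App {
  // Same name and type, but each AttributeReference mints a fresh ExprId,
  // so these are two distinct attributes.
  val a1 = AttributeReference("c", IntegerType)()
  val a2 = AttributeReference("c", IntegerType)()
  assert(a1 != a2)
  assert(!a1.semanticEquals(a2))
  // transformUpWithNewOutput exists to map old ids to new ones in ancestor
  // nodes; a ReferenceAllColumns node can skip that mapping because its
  // references are always derived from its children.
}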