diff --git a/sql/catalyst/src/main/scala-2.12/org/apache/spark/sql/catalyst/expressions/AttributeMap.scala b/sql/catalyst/src/main/scala-2.12/org/apache/spark/sql/catalyst/expressions/AttributeMap.scala
index 42b92d4593c77..189318acd8661 100644
--- a/sql/catalyst/src/main/scala-2.12/org/apache/spark/sql/catalyst/expressions/AttributeMap.scala
+++ b/sql/catalyst/src/main/scala-2.12/org/apache/spark/sql/catalyst/expressions/AttributeMap.scala
@@ -23,6 +23,10 @@ package org.apache.spark.sql.catalyst.expressions
  * of the name, or the expected nullability).
  */
 object AttributeMap {
+  def apply[A](kvs: Map[Attribute, A]): AttributeMap[A] = {
+    new AttributeMap(kvs.map(kv => (kv._1.exprId, kv)))
+  }
+
   def apply[A](kvs: Seq[(Attribute, A)]): AttributeMap[A] = {
     new AttributeMap(kvs.map(kv => (kv._1.exprId, kv)).toMap)
   }
@@ -37,6 +41,8 @@ class AttributeMap[A](val baseMap: Map[ExprId, (Attribute, A)])

   override def get(k: Attribute): Option[A] = baseMap.get(k.exprId).map(_._2)

+  override def getOrElse[B1 >: A](k: Attribute, default: => B1): B1 = get(k).getOrElse(default)
+
   override def contains(k: Attribute): Boolean = get(k).isDefined

   override def + [B1 >: A](kv: (Attribute, B1)): Map[Attribute, B1] = baseMap.values.toMap + kv
diff --git a/sql/catalyst/src/main/scala-2.13/org/apache/spark/sql/catalyst/expressions/AttributeMap.scala b/sql/catalyst/src/main/scala-2.13/org/apache/spark/sql/catalyst/expressions/AttributeMap.scala
index e6b53e3e6548f..77152918bf687 100644
--- a/sql/catalyst/src/main/scala-2.13/org/apache/spark/sql/catalyst/expressions/AttributeMap.scala
+++ b/sql/catalyst/src/main/scala-2.13/org/apache/spark/sql/catalyst/expressions/AttributeMap.scala
@@ -23,6 +23,10 @@ package org.apache.spark.sql.catalyst.expressions
  * of the name, or the expected nullability).
  */
 object AttributeMap {
+  def apply[A](kvs: Map[Attribute, A]): AttributeMap[A] = {
+    new AttributeMap(kvs.map(kv => (kv._1.exprId, kv)))
+  }
+
   def apply[A](kvs: Seq[(Attribute, A)]): AttributeMap[A] = {
     new AttributeMap(kvs.map(kv => (kv._1.exprId, kv)).toMap)
   }
@@ -37,6 +41,8 @@ class AttributeMap[A](val baseMap: Map[ExprId, (Attribute, A)])

   override def get(k: Attribute): Option[A] = baseMap.get(k.exprId).map(_._2)

+  override def getOrElse[B1 >: A](k: Attribute, default: => B1): B1 = get(k).getOrElse(default)
+
   override def contains(k: Attribute): Boolean = get(k).isDefined

   override def updated[B1 >: A](key: Attribute, value: B1): Map[Attribute, B1] =
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/NestedColumnAliasing.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/NestedColumnAliasing.scala
index 5b12667f4a884..cd7032d555992 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/NestedColumnAliasing.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/NestedColumnAliasing.scala
@@ -17,71 +17,151 @@ package org.apache.spark.sql.catalyst.optimizer

+import scala.collection.mutable
+
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.plans.logical._
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.types._

 /**
- * This aims to handle a nested column aliasing pattern inside the `ColumnPruning` optimizer rule.
- * If a project or its child references to nested fields, and not all the fields
- * in a nested attribute are used, we can substitute them by alias attributes; then a project
- * of the nested fields as aliases on the children of the child will be created.
+ * This aims to handle a nested column aliasing pattern inside the [[ColumnPruning]] optimizer rule.
+ * If:
+ * - A [[Project]] or its child references nested fields
+ * - Not all of the fields in a nested attribute are used
+ * Then:
+ * - Substitute the nested field references with alias attributes
+ * - Add grandchild [[Project]]s transforming the nested fields to aliases
+ *
+ * Example 1: Project
+ * ------------------
+ * Before:
+ * +- Project [concat_ws(s#0.a, s#0.b) AS concat_ws(s.a, s.b)#1]
+ *    +- GlobalLimit 5
+ *       +- LocalLimit 5
+ *          +- LocalRelation <empty>, [s#0]
+ * After:
+ * +- Project [concat_ws(_extract_a#2, _extract_b#3) AS concat_ws(s.a, s.b)#1]
+ *    +- GlobalLimit 5
+ *       +- LocalLimit 5
+ *          +- Project [s#0.a AS _extract_a#2, s#0.b AS _extract_b#3]
+ *             +- LocalRelation <empty>, [s#0]
+ *
+ * Example 2: Project above Filter
+ * -------------------------------
+ * Before:
+ * +- Project [s#0.a AS s.a#1]
+ *    +- Filter (length(s#0.b) > 2)
+ *       +- GlobalLimit 5
+ *          +- LocalLimit 5
+ *             +- LocalRelation <empty>, [s#0]
+ * After:
+ * +- Project [_extract_a#2 AS s.a#1]
+ *    +- Filter (length(_extract_b#3) > 2)
+ *       +- GlobalLimit 5
+ *          +- LocalLimit 5
+ *             +- Project [s#0.a AS _extract_a#2, s#0.b AS _extract_b#3]
+ *                +- LocalRelation <empty>, [s#0]
+ *
+ * Example 3: Nested fields with referenced parents
+ * ------------------------------------------------
+ * Before:
+ * +- Project [s#0.a AS s.a#1, s#0.a.a1 AS s.a.a1#2]
+ *    +- GlobalLimit 5
+ *       +- LocalLimit 5
+ *          +- LocalRelation <empty>, [s#0]
+ * After:
+ * +- Project [_extract_a#3 AS s.a#1, _extract_a#3.a1 AS s.a.a1#2]
+ *    +- GlobalLimit 5
+ *       +- LocalLimit 5
+ *          +- Project [s#0.a AS _extract_a#3]
+ *             +- LocalRelation <empty>, [s#0]
+ *
+ * The schema of the datasource relation will be pruned in the [[SchemaPruning]] optimizer rule.
  */
 object NestedColumnAliasing {

   def unapply(plan: LogicalPlan): Option[LogicalPlan] = plan match {
     /**
      * This pattern is needed to support [[Filter]] plan cases like
-     * [[Project]]->[[Filter]]->listed plan in `canProjectPushThrough` (e.g., [[Window]]).
-     * The reason why we don't simply add [[Filter]] in `canProjectPushThrough` is that
+     * [[Project]]->[[Filter]]->listed plan in [[canProjectPushThrough]] (e.g., [[Window]]).
+     * The reason why we don't simply add [[Filter]] in [[canProjectPushThrough]] is that
      * the optimizer can hit an infinite loop during the [[PushDownPredicates]] rule.
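+     * (Roughly: this rule would keep inserting a new [[Project]] below the [[Filter]], while
+     * [[PushDownPredicates]] keeps pushing the [[Filter]] back down below that [[Project]].)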
      */
-    case Project(projectList, Filter(condition, child))
-        if SQLConf.get.nestedSchemaPruningEnabled && canProjectPushThrough(child) =>
-      val exprCandidatesToPrune = projectList ++ Seq(condition) ++ child.expressions
-      getAliasSubMap(exprCandidatesToPrune, child.producedAttributes.toSeq).map {
-        case (nestedFieldToAlias, attrToAliases) =>
-          NestedColumnAliasing.replaceToAliases(plan, nestedFieldToAlias, attrToAliases)
-      }
+    case Project(projectList, Filter(condition, child)) if
+        SQLConf.get.nestedSchemaPruningEnabled && canProjectPushThrough(child) =>
+      rewritePlanIfSubsetFieldsUsed(
+        plan, projectList ++ Seq(condition) ++ child.expressions, child.producedAttributes.toSeq)

-    case Project(projectList, child)
-        if SQLConf.get.nestedSchemaPruningEnabled && canProjectPushThrough(child) =>
-      val exprCandidatesToPrune = projectList ++ child.expressions
-      getAliasSubMap(exprCandidatesToPrune, child.producedAttributes.toSeq).map {
-        case (nestedFieldToAlias, attrToAliases) =>
-          NestedColumnAliasing.replaceToAliases(plan, nestedFieldToAlias, attrToAliases)
-      }
+    case Project(projectList, child) if
+        SQLConf.get.nestedSchemaPruningEnabled && canProjectPushThrough(child) =>
+      rewritePlanIfSubsetFieldsUsed(
+        plan, projectList ++ child.expressions, child.producedAttributes.toSeq)

     case p if SQLConf.get.nestedSchemaPruningEnabled && canPruneOn(p) =>
-      val exprCandidatesToPrune = p.expressions
-      getAliasSubMap(exprCandidatesToPrune, p.producedAttributes.toSeq).map {
-        case (nestedFieldToAlias, attrToAliases) =>
-          NestedColumnAliasing.replaceToAliases(p, nestedFieldToAlias, attrToAliases)
-      }
+      rewritePlanIfSubsetFieldsUsed(
+        plan, p.expressions, p.producedAttributes.toSeq)

     case _ => None
   }

+  /**
+   * Rewrites a plan with aliases if only a subset of the nested fields are used.
+   */
+  def rewritePlanIfSubsetFieldsUsed(
+      plan: LogicalPlan,
+      exprList: Seq[Expression],
+      exclusiveAttrs: Seq[Attribute]): Option[LogicalPlan] = {
+    val attrToExtractValues = getAttributeToExtractValues(exprList, exclusiveAttrs)
+    if (attrToExtractValues.isEmpty) {
+      None
+    } else {
+      Some(rewritePlanWithAliases(plan, attrToExtractValues))
+    }
+  }
+
   /**
    * Replace nested columns to prune unused nested columns later.
    */
-  private def replaceToAliases(
+  def rewritePlanWithAliases(
       plan: LogicalPlan,
-      nestedFieldToAlias: Map[ExtractValue, Alias],
-      attrToAliases: Map[ExprId, Seq[Alias]]): LogicalPlan = plan match {
-    case Project(projectList, child) =>
-      Project(
-        getNewProjectList(projectList, nestedFieldToAlias),
-        replaceWithAliases(child, nestedFieldToAlias, attrToAliases))
-
-    // The operators reaching here was already guarded by `canPruneOn`.
-    case other =>
-      replaceWithAliases(other, nestedFieldToAlias, attrToAliases)
+      attributeToExtractValues: Map[Attribute, Seq[ExtractValue]]): LogicalPlan = {
+    // Each expression can contain multiple nested fields.
+    // Note that we keep the original names to deliver to Parquet in a case-sensitive way.
+    // A new alias is created for each nested field.
+    // Implementation detail: we don't use mapValues here, because it returns a lazy view
+    // whose values would be recomputed (creating fresh aliases with new ExprIds) on access.
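+    // E.g. (cf. Example 1 in the doc above), an entry s#0 -> Seq(s#0.a, s#0.b) becomes
+    //   s#0 -> Seq((s#0.a, Alias(s#0.a, "_extract_a")), (s#0.b, Alias(s#0.b, "_extract_b"))).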
+    val attributeToExtractValuesAndAliases =
+      attributeToExtractValues.map { case (attr, evSeq) =>
+        val evAliasSeq = evSeq.map { ev =>
+          val fieldName = ev match {
+            case g: GetStructField => g.extractFieldName
+            case g: GetArrayStructFields => g.field.name
+          }
+          ev -> Alias(ev, s"_extract_$fieldName")()
+        }
+
+        attr -> evAliasSeq
+      }
+
+    val nestedFieldToAlias = attributeToExtractValuesAndAliases.values.flatten.toMap
+
+    // A reference attribute can have multiple aliases for nested fields.
+    val attrToAliases = AttributeMap(
+      attributeToExtractValuesAndAliases.mapValues(_.map(_._2)).toMap)
+
+    plan match {
+      case Project(projectList, child) =>
+        Project(
+          getNewProjectList(projectList, nestedFieldToAlias),
+          replaceWithAliases(child, nestedFieldToAlias, attrToAliases))
+
+      // The operators reaching here are already guarded by [[canPruneOn]].
+      case other =>
+        replaceWithAliases(other, nestedFieldToAlias, attrToAliases)
+    }
   }

   /**
-   * Return a replaced project list.
+   * Replace the [[ExtractValue]]s in a project list with aliased attributes.
    */
   def getNewProjectList(
       projectList: Seq[NamedExpression],
@@ -93,15 +173,15 @@ object NestedColumnAliasing {
   }

   /**
-   * Return a plan with new children replaced with aliases, and expressions replaced with
-   * aliased attributes.
+   * Replace the grandchildren of a plan with [[Project]]s of the nested fields as aliases,
+   * and replace the [[ExtractValue]] expressions with aliased attributes.
    */
   def replaceWithAliases(
       plan: LogicalPlan,
       nestedFieldToAlias: Map[ExtractValue, Alias],
-      attrToAliases: Map[ExprId, Seq[Alias]]): LogicalPlan = {
+      attrToAliases: AttributeMap[Seq[Alias]]): LogicalPlan = {
     plan.withNewChildren(plan.children.map { plan =>
-      Project(plan.output.flatMap(a => attrToAliases.getOrElse(a.exprId, Seq(a))), plan)
+      Project(plan.output.flatMap(a => attrToAliases.getOrElse(a, Seq(a))), plan)
     }).transformExpressions {
       case f: ExtractValue if nestedFieldToAlias.contains(f) =>
         nestedFieldToAlias(f).toAttribute
@@ -109,7 +189,7 @@ object NestedColumnAliasing {
   }

   /**
-   * Returns true for those operators that we can prune nested column on it.
+   * Returns true for operators on which we can prune nested columns.
    */
   private def canPruneOn(plan: LogicalPlan) = plan match {
     case _: Aggregate => true
@@ -118,7 +198,7 @@ object NestedColumnAliasing {
   }

   /**
-   * Returns true for those operators that project can be pushed through.
+   * Returns true for operators through which project can be pushed.
    */
   private def canProjectPushThrough(plan: LogicalPlan) = plan match {
     case _: GlobalLimit => true
@@ -133,9 +213,10 @@ object NestedColumnAliasing {
   }

   /**
-   * Return root references that are individually accessed as a whole, and `GetStructField`s
-   * or `GetArrayStructField`s which on top of other `ExtractValue`s or special expressions.
-   * Check `SelectedField` to see which expressions should be listed here.
+   * Returns two types of expressions:
+   * - Root references that are individually accessed
+   * - [[GetStructField]] or [[GetArrayStructFields]] on top of other [[ExtractValue]]s
+   *   or special expressions.
    */
   private def collectRootReferenceAndExtractValue(e: Expression): Seq[Expression] = e match {
     case _: AttributeReference => Seq(e)
@@ -149,67 +230,55 @@ object NestedColumnAliasing {
   }

   /**
-   * Return two maps in order to replace nested fields to aliases.
-   *
-   * If `exclusiveAttrs` is given, any nested field accessors of these attributes
-   * won't be considered in nested fields aliasing.
-   *
-   * 1. ExtractValue -> Alias: A new alias is created for each nested field.
-   * 2. ExprId -> Seq[Alias]: A reference attribute has multiple aliases pointing it.
+   * Creates a map from root [[Attribute]]s to non-redundant nested [[ExtractValue]]s.
+   * Nested field accessors of `exclusiveAttrs` are not considered in nested field aliasing.
    */
-  def getAliasSubMap(exprList: Seq[Expression], exclusiveAttrs: Seq[Attribute] = Seq.empty)
-    : Option[(Map[ExtractValue, Alias], Map[ExprId, Seq[Alias]])] = {
-    val (nestedFieldReferences, otherRootReferences) =
-      exprList.flatMap(collectRootReferenceAndExtractValue).partition {
-        case _: ExtractValue => true
-        case _ => false
+  def getAttributeToExtractValues(
+      exprList: Seq[Expression],
+      exclusiveAttrs: Seq[Attribute]): Map[Attribute, Seq[ExtractValue]] = {
+
+    val nestedFieldReferences = new mutable.ArrayBuffer[ExtractValue]()
+    val otherRootReferences = new mutable.ArrayBuffer[AttributeReference]()
+    exprList.foreach { e =>
+      collectRootReferenceAndExtractValue(e).foreach {
+        case ev: ExtractValue =>
+          if (ev.references.size == 1) {
+            nestedFieldReferences.append(ev)
+          }
+        case ar: AttributeReference => otherRootReferences.append(ar)
       }
-
-    // Note that when we group by extractors with their references, we should remove
-    // cosmetic variations.
+    }

     val exclusiveAttrSet = AttributeSet(exclusiveAttrs ++ otherRootReferences)
-    val aliasSub = nestedFieldReferences.asInstanceOf[Seq[ExtractValue]]
+
+    // Remove cosmetic variations when we group extractors by their references.
+    nestedFieldReferences
       .filter(!_.references.subsetOf(exclusiveAttrSet))
      .groupBy(_.references.head.canonicalized.asInstanceOf[Attribute])
-      .flatMap { case (attr, nestedFields: Seq[ExtractValue]) =>
-        // Remove redundant `ExtractValue`s if they share the same parent nest field.
+      .flatMap { case (attr: Attribute, nestedFields: Seq[ExtractValue]) =>
+        // Remove redundant [[ExtractValue]]s if they share the same parent nested field.
         // For example, when `a.b` and `a.b.c` are in project list, we only need to alias `a.b`.
-        // We only need to deal with two `ExtractValue`: `GetArrayStructFields` and
-        // `GetStructField`. Please refer to the method `collectRootReferenceAndExtractValue`.
+        // Because `a.b` requires all of the inner fields of `b`, we cannot prune `a.b.c`.
         val dedupNestedFields = nestedFields.filter {
+          // See [[collectRootReferenceAndExtractValue]]: we only need to deal with
+          // [[GetArrayStructFields]] and [[GetStructField]].
          case e @ (_: GetStructField | _: GetArrayStructFields) =>
            val child = e.children.head
            nestedFields.forall(f => child.find(_.semanticEquals(f)).isEmpty)
          case _ => true
-        }
-
-        // Each expression can contain multiple nested fields.
-        // Note that we keep the original names to deliver to parquet in a case-sensitive way.
-        val nestedFieldToAlias = dedupNestedFields.distinct.map { f =>
-          val exprId = NamedExpression.newExprId
-          (f, Alias(f, s"_gen_alias_${exprId.id}")(exprId, Seq.empty, None))
-        }
+        }.distinct

        // If all nested fields of `attr` are used, we don't need to introduce new aliases.
-        // By default, ColumnPruning rule uses `attr` already.
+        // By default, the [[ColumnPruning]] rule uses `attr` already.
        // Note that we need to remove cosmetic variations first, so we only count a
        // nested field once.
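+        // E.g., for `s: struct<a: struct<a1: int, a2: int>, b: int>` (3 leaf fields in total),
+        // accessing only `s.a` uses 2 leaf fields, so `s.a` is aliased; accessing both `s.a`
+        // and `s.b` uses all 3 leaf fields, so nothing is aliased.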
-        if (nestedFieldToAlias.nonEmpty &&
-          dedupNestedFields.map(_.canonicalized)
-            .distinct
-            .map { nestedField => totalFieldNum(nestedField.dataType) }
-            .sum < totalFieldNum(attr.dataType)) {
-          Some(attr.exprId -> nestedFieldToAlias)
+        val numUsedNestedFields = dedupNestedFields.map(_.canonicalized).distinct
+          .map { nestedField => totalFieldNum(nestedField.dataType) }.sum
+        if (numUsedNestedFields < totalFieldNum(attr.dataType)) {
+          Some((attr, dedupNestedFields.toSeq))
         } else {
           None
         }
       }
-
-    if (aliasSub.isEmpty) {
-      None
-    } else {
-      Some((aliasSub.values.flatten.toMap, aliasSub.map(x => (x._1, x._2.map(_._2)))))
-    }
   }

   /**
@@ -227,31 +296,9 @@ object NestedColumnAliasing {
   }

 /**
- * This prunes unnecessary nested columns from `Generate` and optional `Project` on top
- * of it.
+ * This prunes unnecessary nested columns from [[Generate]], or [[Project]] -> [[Generate]].
  */
 object GeneratorNestedColumnAliasing {
-  // Partitions `attrToAliases` based on whether the attribute is in Generator's output.
-  private def aliasesOnGeneratorOutput(
-      attrToAliases: Map[ExprId, Seq[Alias]],
-      generatorOutput: Seq[Attribute]) = {
-    val generatorOutputExprId = generatorOutput.map(_.exprId)
-    attrToAliases.partition { k =>
-      generatorOutputExprId.contains(k._1)
-    }
-  }
-
-  // Partitions `nestedFieldToAlias` based on whether the attribute of nested field extractor
-  // is in Generator's output.
-  private def nestedFieldOnGeneratorOutput(
-      nestedFieldToAlias: Map[ExtractValue, Alias],
-      generatorOutput: Seq[Attribute]) = {
-    val generatorOutputSet = AttributeSet(generatorOutput)
-    nestedFieldToAlias.partition { pair =>
-      pair._1.references.subsetOf(generatorOutputSet)
-    }
-  }
-
   def unapply(plan: LogicalPlan): Option[LogicalPlan] = plan match {
     // Either `nestedPruningOnExpressions` or `nestedSchemaPruningEnabled` is enabled, we
     // need to prune nested columns through Project and under Generate. The difference is
@@ -261,103 +308,100 @@ object GeneratorNestedColumnAliasing {
         SQLConf.get.nestedSchemaPruningEnabled) && canPruneGenerator(g.generator) =>
       // On top on `Generate`, a `Project` that might have nested column accessors.
       // We try to get alias maps for both project list and generator's children expressions.
-      val exprsToPrune = projectList ++ g.generator.children
-      NestedColumnAliasing.getAliasSubMap(exprsToPrune).map {
-        case (nestedFieldToAlias, attrToAliases) =>
-          val (nestedFieldsOnGenerator, nestedFieldsNotOnGenerator) =
-            nestedFieldOnGeneratorOutput(nestedFieldToAlias, g.qualifiedGeneratorOutput)
-          val (attrToAliasesOnGenerator, attrToAliasesNotOnGenerator) =
-            aliasesOnGeneratorOutput(attrToAliases, g.qualifiedGeneratorOutput)
-
-          // Push nested column accessors through `Generator`.
-          // Defer updating `Generate.unrequiredChildIndex` to next round of `ColumnPruning`.
-          val newChild = NestedColumnAliasing.replaceWithAliases(g,
-            nestedFieldsNotOnGenerator, attrToAliasesNotOnGenerator)
-          val pushedThrough = Project(NestedColumnAliasing
-            .getNewProjectList(projectList, nestedFieldsNotOnGenerator), newChild)
-
-          // If the generator output is `ArrayType`, we cannot push through the extractor.
-          // It is because we don't allow field extractor on two-level array,
-          // i.e., attr.field when attr is a ArrayType(ArrayType(...)).
-          // Similarily, we also cannot push through if the child of generator is `MapType`.
-          g.generator.children.head.dataType match {
-            case _: MapType => return Some(pushedThrough)
-            case ArrayType(_: ArrayType, _) => return Some(pushedThrough)
-            case _ =>
-          }
-
-          // Pruning on `Generator`'s output. We only process single field case.
-          // For multiple field case, we cannot directly move field extractor into
-          // the generator expression. A workaround is to re-construct array of struct
-          // from multiple fields. But it will be more complicated and may not worth.
-          // TODO(SPARK-34956): support multiple fields.
-          if (nestedFieldsOnGenerator.size > 1 || nestedFieldsOnGenerator.isEmpty) {
-            pushedThrough
-          } else {
-            // Only one nested column accessor.
-            // E.g., df.select(explode($"items").as("item")).select($"item.a")
-            pushedThrough match {
-              case p @ Project(_, newG: Generate) =>
-                // Replace the child expression of `ExplodeBase` generator with
-                // nested column accessor.
-                // E.g., df.select(explode($"items").as("item")).select($"item.a") =>
-                //       df.select(explode($"items.a").as("item.a"))
-                val rewrittenG = newG.transformExpressions {
-                  case e: ExplodeBase =>
-                    val extractor = nestedFieldsOnGenerator.head._1.transformUp {
-                      case _: Attribute =>
-                        e.child
-                      case g: GetStructField =>
-                        ExtractValue(g.child, Literal(g.extractFieldName), SQLConf.get.resolver)
-                    }
-                    e.withNewChildren(Seq(extractor))
-                }
+      val attrToExtractValues = NestedColumnAliasing.getAttributeToExtractValues(
+        projectList ++ g.generator.children, Seq.empty)
+      if (attrToExtractValues.isEmpty) {
+        return None
+      }
+      val generatorOutputSet = AttributeSet(g.qualifiedGeneratorOutput)
+      val (attrToExtractValuesOnGenerator, attrToExtractValuesNotOnGenerator) =
+        attrToExtractValues.partition { case (attr, _) =>
+          attr.references.subsetOf(generatorOutputSet)
+        }
+
+      val pushedThrough = NestedColumnAliasing.rewritePlanWithAliases(
+        plan, attrToExtractValuesNotOnGenerator)
+
+      // If the generator output is `ArrayType`, we cannot push through the extractor.
+      // It is because we don't allow field extractor on two-level array,
+      // i.e., attr.field when attr is an ArrayType(ArrayType(...)).
+      // Similarly, we also cannot push through if the child of generator is `MapType`.
+      g.generator.children.head.dataType match {
+        case _: MapType => return Some(pushedThrough)
+        case ArrayType(_: ArrayType, _) => return Some(pushedThrough)
+        case _ =>
+      }

-          // As we change the child of the generator, its output data type must be updated.
-          val updatedGeneratorOutput = rewrittenG.generatorOutput
-            .zip(rewrittenG.generator.elementSchema.toAttributes)
-            .map { case (oldAttr, newAttr) =>
-              newAttr.withExprId(oldAttr.exprId).withName(oldAttr.name)
-            }
-          assert(updatedGeneratorOutput.length == rewrittenG.generatorOutput.length,
-            "Updated generator output must have the same length " +
-              "with original generator output.")
-          val updatedGenerate = rewrittenG.copy(generatorOutput = updatedGeneratorOutput)
-
-          // Replace nested column accessor with generator output.
-          p.withNewChildren(Seq(updatedGenerate)).transformExpressions {
-            case f: ExtractValue if nestedFieldsOnGenerator.contains(f) =>
-              updatedGenerate.output
-                .find(a => attrToAliasesOnGenerator.contains(a.exprId))
-                .getOrElse(f)
+      // Pruning on `Generator`'s output. We only process the single-field case.
+      // For the multi-field case, we cannot directly move the field extractor into
+      // the generator expression. A workaround is to re-construct an array of structs
+      // from the multiple fields, but that is more complicated and may not be worth it.
+      // TODO(SPARK-34956): support multiple fields.
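+      // E.g., df.select(explode($"items").as("item")).select($"item.a", $"item.b") is left
+      // unchanged here: the two extracted fields would have to be re-assembled into a single
+      // array of structs before being exploded.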
+      val nestedFieldsOnGenerator = attrToExtractValuesOnGenerator.values.flatten.toSet
+      if (nestedFieldsOnGenerator.size > 1 || nestedFieldsOnGenerator.isEmpty) {
+        Some(pushedThrough)
+      } else {
+        // Only one nested column accessor.
+        // E.g., df.select(explode($"items").as("item")).select($"item.a")
+        val nestedFieldOnGenerator = nestedFieldsOnGenerator.head
+        pushedThrough match {
+          case p @ Project(_, newG: Generate) =>
+            // Replace the child expression of `ExplodeBase` generator with
+            // nested column accessor.
+            // E.g., df.select(explode($"items").as("item")).select($"item.a") =>
+            //       df.select(explode($"items.a").as("item.a"))
+            val rewrittenG = newG.transformExpressions {
+              case e: ExplodeBase =>
+                val extractor = nestedFieldOnGenerator.transformUp {
+                  case _: Attribute =>
+                    e.child
+                  case g: GetStructField =>
+                    ExtractValue(g.child, Literal(g.extractFieldName), SQLConf.get.resolver)
+                }
+                e.withNewChildren(Seq(extractor))
+            }

-          case other =>
-            // We should not reach here.
-            throw new IllegalStateException(s"Unreasonable plan after optimization: $other")
+            // As we change the child of the generator, its output data type must be updated.
+            val updatedGeneratorOutput = rewrittenG.generatorOutput
+              .zip(rewrittenG.generator.elementSchema.toAttributes)
+              .map { case (oldAttr, newAttr) =>
+                newAttr.withExprId(oldAttr.exprId).withName(oldAttr.name)
+              }
+            assert(updatedGeneratorOutput.length == rewrittenG.generatorOutput.length,
+              "Updated generator output must have the same length " +
+                "with original generator output.")
+            val updatedGenerate = rewrittenG.copy(generatorOutput = updatedGeneratorOutput)
+
+            // Replace nested column accessor with generator output.
+            val attrExprIdsOnGenerator = attrToExtractValuesOnGenerator.keys.map(_.exprId).toSet
+            val updatedProject = p.withNewChildren(Seq(updatedGenerate)).transformExpressions {
+              case f: ExtractValue if nestedFieldsOnGenerator.contains(f) =>
+                updatedGenerate.output
+                  .find(a => attrExprIdsOnGenerator.contains(a.exprId))
+                  .getOrElse(f)
+            }
-          }
+            Some(updatedProject)
+
+          case other =>
+            // We should not reach here.
+            throw new IllegalStateException(s"Unreasonable plan after optimization: $other")
+        }
+      }

     case g: Generate if SQLConf.get.nestedSchemaPruningEnabled &&
-      canPruneGenerator(g.generator) =>
+        canPruneGenerator(g.generator) =>
       // If any child output is required by higher projection, we cannot prune on it even we
       // only use part of nested column of it. A required child output means it is referred
       // as a whole or partially by higher projection, pruning it here will cause unresolved
       // query plan.
-      NestedColumnAliasing.getAliasSubMap(
-        g.generator.children, g.requiredChildOutput).map {
-        case (nestedFieldToAlias, attrToAliases) =>
-          // Defer updating `Generate.unrequiredChildIndex` to next round of `ColumnPruning`.
-          NestedColumnAliasing.replaceWithAliases(g, nestedFieldToAlias, attrToAliases)
-      }
+      NestedColumnAliasing.rewritePlanIfSubsetFieldsUsed(
+        plan, g.generator.children, g.requiredChildOutput)

     case _ => None
   }

   /**
-   * This is a while-list for pruning nested fields at `Generator`.
+   * Types of [[Generator]] on which we can prune nested fields.
    */
   def canPruneGenerator(g: Generator): Boolean = g match {
     case _: Explode => true
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
index 16e3e43356b9c..c90f4bcdd2602 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
@@ -785,7 +785,7 @@ object ColumnPruning extends Rule[LogicalPlan] {
       p.copy(child = g.copy(child = newChild, unrequiredChildIndex = unrequiredIndices))

     // prune unrequired nested fields from `Generate`.
-    case GeneratorNestedColumnAliasing(p) => p
+    case GeneratorNestedColumnAliasing(rewrittenPlan) => rewrittenPlan

     // Eliminate unneeded attributes from right side of a Left Existence Join.
     case j @ Join(_, right, LeftExistence(_), _, _) =>
@@ -819,7 +819,7 @@ object ColumnPruning extends Rule[LogicalPlan] {
     // Can't prune the columns on LeafNode
     case p @ Project(_, _: LeafNode) => p

-    case NestedColumnAliasing(p) => p
+    case NestedColumnAliasing(rewrittenPlan) => rewrittenPlan

     // for all other logical plans that inherits the output from it's children
     // Project over project is handled by the first case, skip it here.
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/NestedColumnAliasingSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/NestedColumnAliasingSuite.scala
index a856caa6781e8..643974c9c707d 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/NestedColumnAliasingSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/NestedColumnAliasingSuite.scala
@@ -714,7 +714,7 @@ object NestedColumnAliasingSuite {
   def collectGeneratedAliases(query: LogicalPlan): ArrayBuffer[String] = {
     val aliases = ArrayBuffer[String]()
     query.transformAllExpressions {
-      case a @ Alias(_, name) if name.startsWith("_gen_alias_") =>
+      case a @ Alias(_, name) if name.startsWith("_extract_") =>
         aliases += name
         a
     }