Skip to content
Closed
Show file tree
Hide file tree
Changes from 7 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,8 @@ case class ProjectionOverSchema(schema: StructType) {
getProjection(a.child).map(p => (p, p.dataType)).map {
case (projection, ArrayType(projSchema @ StructType(_), _)) =>
// For case-sensitivity aware field resolution, we should take `ordinal` which
// points to correct struct field.
// points to correct struct field, because `ExtractValue` actually does column
// name resolving correctly.
val selectedField = a.child.dataType.asInstanceOf[ArrayType]
.elementType.asInstanceOf[StructType](a.ordinal)
val prunedField = projSchema(selectedField.name)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -231,6 +231,27 @@ object NestedColumnAliasing {
* of it.
*/
object GeneratorNestedColumnAliasing {
// Splits `attrToAliases` by whether each key `ExprId` belongs to one of the attributes in
// `generatorOutput`. The first map of the returned pair holds the entries whose key is the
// `exprId` of a generator output attribute; the second map holds all remaining entries.
// The expression IDs are collected into a `Set` once, so each membership check does not
// rescan the generator output sequence.
private def aliasesOnGeneratorOutput(
    attrToAliases: Map[ExprId, Seq[Alias]],
    generatorOutput: Seq[Attribute]): (Map[ExprId, Seq[Alias]], Map[ExprId, Seq[Alias]]) = {
  val generatorOutputExprIds = generatorOutput.map(_.exprId).toSet
  attrToAliases.partition { case (exprId, _) =>
    generatorOutputExprIds.contains(exprId)
  }
}

// Splits `nestedFieldToAlias` into a pair of maps. The first map keeps the entries whose
// nested field extractor only references attributes found in `generatorOutput`; the second
// map keeps every other entry.
private def nestedFieldOnGeneratorOutput(
    nestedFieldToAlias: Map[ExtractValue, Alias],
    generatorOutput: Seq[Attribute]) = {
  val outputAttributes = AttributeSet(generatorOutput)
  nestedFieldToAlias.partition { case (extractor, _) =>
    extractor.references.subsetOf(outputAttributes)
  }
}

def unapply(plan: LogicalPlan): Option[LogicalPlan] = plan match {
// When either `nestedPruningOnExpressions` or `nestedSchemaPruningEnabled` is enabled, we
// need to prune nested columns through Project and under Generate. The difference is
Expand All @@ -241,12 +262,68 @@ object GeneratorNestedColumnAliasing {
// On top of `Generate`, a `Project` that might have nested column accessors.
// We try to get alias maps for both project list and generator's children expressions.
val exprsToPrune = projectList ++ g.generator.children
NestedColumnAliasing.getAliasSubMap(exprsToPrune, g.qualifiedGeneratorOutput).map {
NestedColumnAliasing.getAliasSubMap(exprsToPrune).map {
case (nestedFieldToAlias, attrToAliases) =>
val (nestedFieldsOnGenerator, nestedFieldsNotOnGenerator) =
nestedFieldOnGeneratorOutput(nestedFieldToAlias, g.qualifiedGeneratorOutput)
val (attrToAliasesOnGenerator, attrToAliasesNotOnGenerator) =
aliasesOnGeneratorOutput(attrToAliases, g.qualifiedGeneratorOutput)

// Push nested column accessors through `Generator`.
// Defer updating `Generate.unrequiredChildIndex` to next round of `ColumnPruning`.
val newChild =
NestedColumnAliasing.replaceWithAliases(g, nestedFieldToAlias, attrToAliases)
Project(NestedColumnAliasing.getNewProjectList(projectList, nestedFieldToAlias), newChild)
val newChild = NestedColumnAliasing.replaceWithAliases(g,
nestedFieldsNotOnGenerator, attrToAliasesNotOnGenerator)
val pushedThrough = Project(NestedColumnAliasing
.getNewProjectList(projectList, nestedFieldsNotOnGenerator), newChild)

// Pruning on `Generator`'s output. We only process single field case.
// For multiple field case, we cannot directly move field extractor into
// the generator expression. A workaround is to re-construct array of struct
// from multiple fields. But it will be more complicated and may not be worth it.
// TODO(SPARK-34956): support multiple fields.
if (nestedFieldsOnGenerator.size > 1 || nestedFieldsOnGenerator.size == 0) {
pushedThrough
} else {
// Only one nested column accessor.
// E.g., df.select(explode($"items").as("item")).select($"item.a")
pushedThrough match {
case p @ Project(_, newG: Generate) =>
// Replace the child expression of `ExplodeBase` generator with
// nested column accessor.
// E.g., df.select(explode($"items").as("item")).select($"item.a") =>
// df.select(explode($"items.a").as("item.a"))
val rewrittenG = newG.transformExpressions {
case e: ExplodeBase =>
val extractor = nestedFieldsOnGenerator.head._1.transformUp {
case _: Attribute =>
e.child
case g: GetStructField =>
ExtractValue(g.child, Literal(g.extractFieldName), SQLConf.get.resolver)
}
e.withNewChildren(Seq(extractor))
}

// As we change the child of the generator, its output data type must be updated.
val updatedGeneratorOutput = rewrittenG.generatorOutput
.zip(rewrittenG.generator.elementSchema.toAttributes)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Indentation? It seems that two-space is enough in this case.

.map { case (oldAttr, newAttr) =>
newAttr.withExprId(oldAttr.exprId).withName(oldAttr.name)
}
assert(updatedGeneratorOutput.length == rewrittenG.generatorOutput.length,
"Updated generator output must have same length as original generator output.")
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe, same length as -> the same length with?

val updatedGenerate = rewrittenG.copy(generatorOutput = updatedGeneratorOutput)

// Replace nested column accessor with generator output.
p.withNewChildren(Seq(updatedGenerate)).transformExpressions {
case f: ExtractValue if nestedFieldsOnGenerator.contains(f) =>
updatedGenerate.output
.find(a => attrToAliasesOnGenerator.contains(a.exprId))
.getOrElse(f)
}

case _ => pushedThrough
}
}
}

case g: Generate if SQLConf.get.nestedSchemaPruningEnabled &&
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -329,14 +329,14 @@ class NestedColumnAliasingSuite extends SchemaPruningTest {
comparePlans(optimized, expected)
}

test("Nested field pruning for Project and Generate: not prune on generator output") {
test("Nested field pruning for Project and Generate: multiple-field case is not supported") {
val companies = LocalRelation(
'id.int,
'employers.array(employer))

val query = companies
.generate(Explode('employers.getField("company")), outputNames = Seq("company"))
.select('company.getField("name"))
.select('company.getField("name"), 'company.getField("address"))
.analyze
val optimized = Optimize.execute(query)

Expand All @@ -347,7 +347,8 @@ class NestedColumnAliasingSuite extends SchemaPruningTest {
.generate(Explode($"${aliases(0)}"),
unrequiredChildIndex = Seq(0),
outputNames = Seq("company"))
.select('company.getField("name").as("company.name"))
.select('company.getField("name").as("company.name"),
'company.getField("address").as("company.address"))
.analyze
comparePlans(optimized, expected)
}
Expand Down Expand Up @@ -684,6 +685,29 @@ class NestedColumnAliasingSuite extends SchemaPruningTest {
).analyze
comparePlans(optimized2, expected2)
}

test("SPARK-34638: nested column prune on generator output for one field") {
  val relation = LocalRelation('id.int, 'employers.array(employer))

  // Explode `employers.company`, then read a single nested field of the generator output.
  val analyzedQuery = relation
    .generate(Explode('employers.getField("company")), outputNames = Seq("company"))
    .select('company.getField("name"))
    .analyze
  val optimizedPlan = Optimize.execute(analyzedQuery)

  // The alias name is produced by the optimizer, so read it back from the optimized plan.
  val generatedAliases = collectGeneratedAliases(optimizedPlan)
  val pushedDownAlias = generatedAliases(0)

  // Expected shape: the `name` accessor is applied below the generator, and the top
  // project just re-aliases the generator output.
  val expectedPlan = relation
    .select('employers.getField("company").getField("name").as(pushedDownAlias))
    .generate(
      Explode($"$pushedDownAlias"),
      unrequiredChildIndex = Seq(0),
      outputNames = Seq("company"))
    .select('company.as("company.name"))
    .analyze

  comparePlans(optimizedPlan, expectedPlan)
}
}

object NestedColumnAliasingSuite {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -351,6 +351,41 @@ abstract class SchemaPruningSuite
}
}

testSchemaPruning("SPARK-34638: nested column prune on generator output") {
  // Builds a fresh DataFrame on every call, so each query below is constructed at the
  // same point (and under the same SQLConf) as if it were written out in full.
  def explodedFriends = spark.table("contacts")
    .select(explode(col("friends")).as("friend"))

  // Scan schema expected when no field of the `friends` element struct is pruned.
  val unprunedScan = "struct<friends:array<struct<first:string,middle:string,last:string>>>"

  // A single nested field read from the generator output: only that field is scanned.
  val singleField = explodedFriends.select("friend.first")
  checkScan(singleField, "struct<friends:array<struct<first:string>>>")
  checkAnswer(singleField, Seq(Row("Susan")))

  // Several nested fields read from the generator output are not pruned for now.
  val twoFields = explodedFriends.select("friend.first", "friend.middle")
  checkScan(twoFields, unprunedScan)
  checkAnswer(twoFields, Seq(Row("Susan", "Z.")))

  // Nested fields selected together with the whole struct: the full struct is scanned.
  val fieldsAndStruct = explodedFriends.select("friend.first", "friend.middle", "friend")
  checkScan(fieldsAndStruct, unprunedScan)
  checkAnswer(fieldsAndStruct, Seq(Row("Susan", "Z.", Row("Susan", "Z.", "Smith"))))

  // Field names written with a different case than the schema, under case-insensitive
  // resolution.
  withSQLConf(SQLConf.CASE_SENSITIVE.key -> "false") {
    val capitalizedField = explodedFriends.select("friend.First")
    checkScan(capitalizedField, "struct<friends:array<struct<first:string>>>")
    checkAnswer(capitalizedField, Seq(Row("Susan")))

    val upperCaseField = explodedFriends.select("friend.MIDDLE")
    checkScan(upperCaseField, "struct<friends:array<struct<middle:string>>>")
    checkAnswer(upperCaseField, Seq(Row("Z.")))
  }
}

testSchemaPruning("select one deep nested complex field after repartition") {
val query = sql("select * from contacts")
.repartition(100)
Expand Down