From 7eaa4287a5c112a192cca388863324ffd855203e Mon Sep 17 00:00:00 2001 From: liyuanjian Date: Mon, 5 Sep 2016 16:27:02 +0800 Subject: [PATCH 1/9] [SPARK-4502][SQL]Support parquet nested struct pruning and add relevant tests --- .../apache/spark/sql/types/StructType.scala | 19 +++++- .../datasources/FileSourceStrategy.scala | 63 ++++++++++++++++++- .../apache/spark/sql/internal/SQLConf.scala | 7 +++ .../parquet/ParquetQuerySuite.scala | 38 +++++++++++ 4 files changed, 124 insertions(+), 3 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructType.scala index dd4c88c4c43b..3b16fbf8f660 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructType.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructType.scala @@ -259,8 +259,23 @@ case class StructType(fields: Array[StructField]) extends DataType with Seq[Stru * @throws IllegalArgumentException if a field with the given name does not exist */ def apply(name: String): StructField = { - nameToField.getOrElse(name, - throw new IllegalArgumentException(s"""Field "$name" does not exist.""")) + if (name.contains('.')) { + val curFieldStr = name.split("\\.", 2)(0) + val nextFieldStr = name.split("\\.", 2)(1) + val curField = nameToField.getOrElse(curFieldStr, + throw new IllegalArgumentException(s"""Field "$name" does not exist.""")) + curField.dataType match { + case st: StructType => + val newField = StructType(st.fields).apply(nextFieldStr) + StructField(curField.name, StructType(Seq(newField)), + curField.nullable, curField.metadata) + case _ => + throw new IllegalArgumentException(s"""Field "$curFieldStr" is not struct field.""") + } + } else { + nameToField.getOrElse(name, + throw new IllegalArgumentException(s"""Field "$name" does not exist.""")) + } } /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategy.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategy.scala index 8b36caf6f1e0..f7521d2077d5 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategy.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategy.scala @@ -25,6 +25,7 @@ import org.apache.spark.sql.catalyst.planning.PhysicalOperation import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.execution.FileSourceScanExec import org.apache.spark.sql.execution.SparkPlan +import org.apache.spark.sql.types.StructType /** * A strategy for planning scans over collections of files that might be partitioned or bucketed @@ -97,7 +98,15 @@ object FileSourceStrategy extends Strategy with Logging { dataColumns .filter(requiredAttributes.contains) .filterNot(partitionColumns.contains) - val outputSchema = readDataColumns.toStructType + val outputSchema = if (fsRelation.sqlContext.conf.isParquetNestColumnPruning) { + val requiredColumnsWithNesting = generateRequiredColumnsContainsNesting( + projects, readDataColumns.attrs.map(_.name).toArray) + val totalSchema = readDataColumns.toStructType + val prunedSchema = StructType(requiredColumnsWithNesting.map(totalSchema(_))) + // Merge schema in same StructType and merge with filterAttributes + prunedSchema.fields.map(f => StructType(Array(f))).reduceLeft(_ merge _) + .merge(filterAttributes.toSeq.toStructType) + } else readDataColumns.toStructType logInfo(s"Output Data Schema: ${outputSchema.simpleString(5)}") 
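(A minimal sketch of what the dotted-name StructType.apply introduced above produces; the schema and field names here are illustrative only and are not taken from the patch.)

    import org.apache.spark.sql.types._

    val full = new StructType()
      .add("col", new StructType()
        .add("s1", new StructType().add("s1_1", LongType).add("s1_2", LongType))
        .add("str", StringType))
      .add("num", LongType)

    // A nested required column keeps only the requested leaf:
    full("col.s1.s1_1")
    // => StructField(col, StructType(StructField(s1, StructType(StructField(s1_1, LongType)))))
    full("num")
    // => StructField(num, LongType)
    // Wrapping each result in a single-field StructType and merging the results gives the
    // pruned schema that is logged as "Output Data Schema" above.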
val pushedDownFilters = dataFilters.flatMap(DataSourceStrategy.translateFilter) @@ -126,4 +135,56 @@ object FileSourceStrategy extends Strategy with Logging { case _ => Nil } + + private def generateRequiredColumnsContainsNesting(projects: Seq[Expression], + columns: Array[String]) : Array[String] = { + def generateAttributeMap(nestFieldMap: scala.collection.mutable.Map[String, Seq[String]], + isNestField: Boolean, curString: Option[String], + node: Expression) { + node match { + case ai: GetArrayItem => + // Here we drop the curString for simplify array and map support. + // Same strategy in GetArrayStructFields and GetMapValue + generateAttributeMap(nestFieldMap, isNestField = true, None, ai.child) + + case asf: GetArrayStructFields => + generateAttributeMap(nestFieldMap, isNestField = true, None, asf.child) + + case mv: GetMapValue => + generateAttributeMap(nestFieldMap, isNestField = true, None, mv.child) + + case attr: AttributeReference => + if (isNestField && curString.isDefined) { + val attrStr = attr.name + if (nestFieldMap.contains(attrStr)) { + nestFieldMap(attrStr) = nestFieldMap(attrStr) ++ Seq(attrStr + "." + curString.get) + } else { + nestFieldMap += (attrStr -> Seq(attrStr + "." + curString.get)) + } + } + case sf: GetStructField => + val str = if (curString.isDefined) { + sf.name.get + "." + curString.get + } else sf.name.get + generateAttributeMap(nestFieldMap, isNestField = true, Option(str), sf.child) + case _ => + if (node.children.nonEmpty) { + node.children.foreach(child => generateAttributeMap(nestFieldMap, + isNestField, curString, child)) + } + } + } + + val nestFieldMap = scala.collection.mutable.Map.empty[String, Seq[String]] + projects.foreach(p => generateAttributeMap(nestFieldMap, isNestField = false, None, p)) + val col_list = columns.toList.flatMap(col => { + if (nestFieldMap.contains(col)) { + nestFieldMap.get(col).get.toList + } else { + List(col) + } + }) + col_list.toArray + } + } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/core/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index 1d6ca5a965cb..fb2003731240 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -212,6 +212,11 @@ object SQLConf { .booleanConf .createWithDefault(true) + val PARQUET_NEST_COLUMN_PRUNING = SQLConfigBuilder("spark.sql.parquet.nestColumnPruning") + .doc("When set this to true, we will tell parquet only read the nest column`s leaf fields ") + .booleanConf + .createWithDefault(false) + val PARQUET_CACHE_METADATA = SQLConfigBuilder("spark.sql.parquet.cacheMetadata") .doc("Turns on caching of Parquet schema metadata. 
Can speed up querying of static data.") .booleanConf @@ -661,6 +666,8 @@ private[sql] class SQLConf extends Serializable with CatalystConf with Logging { def isParquetINT96AsTimestamp: Boolean = getConf(PARQUET_INT96_AS_TIMESTAMP) + def isParquetNestColumnPruning: Boolean = getConf(PARQUET_NEST_COLUMN_PRUNING) + def writeLegacyParquetFormat: Boolean = getConf(PARQUET_WRITE_LEGACY_FORMAT) def inMemoryPartitionPruning: Boolean = getConf(IN_MEMORY_PARTITION_PRUNING) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetQuerySuite.scala index 9dd8d9f80496..a7c03638a47c 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetQuerySuite.scala @@ -571,6 +571,44 @@ class ParquetQuerySuite extends QueryTest with ParquetTest with SharedSQLContext } } + test("SPARK-4502 parquet nested fields pruning") { + // Schema of "test-data/nested-array-struct.parquet": + // root + // |-- primitive: integer (nullable = true) + // |-- myComplex: array (nullable = true) + // | |-- element: struct (containsNull = true) + // | | |-- id: integer (nullable = true) + // | | |-- repeatedMessage: array (nullable = true) + // | | | |-- element: struct (containsNull = true) + // | | | | |-- someId: integer (nullable = true) + val df = readResourceParquetFile("test-data/nested-array-struct.parquet") + df.createOrReplaceTempView("tmp_table") + // normal test + val query1 = "select primitive,myComplex[0].id from tmp_table" + val result1 = sql(query1) + withSQLConf(SQLConf.PARQUET_NEST_COLUMN_PRUNING.key -> "true") { + checkAnswer(sql(query1), result1) + } + // test for array in struct + val query2 = "select primitive,myComplex[0].repeatedMessage[0].someId from tmp_table" + val result2 = sql(query2) + withSQLConf(SQLConf.PARQUET_NEST_COLUMN_PRUNING.key -> "true") { + checkAnswer(sql(query2), result2) + } + // test for same struct meta merge + // myComplex.id and myComplex.repeatedMessage.someId should merge + // like myComplex.[id, repeatedMessage.someId] before pass to parquet + val query3 = "select myComplex[0].id, myComplex[0].repeatedMessage[0].someId" + + " from tmp_table" + val result3 = sql(query3) + withSQLConf(SQLConf.PARQUET_NEST_COLUMN_PRUNING.key -> "true") { + checkAnswer(sql(query3), result3) + } + + spark.sessionState.catalog.dropTable( + TableIdentifier("tmp_table"), ignoreIfNotExists = true, purge = false) + } + test("expand UDT in StructType") { val schema = new StructType().add("n", new NestedStructUDT, nullable = true) val expected = new StructType().add("n", new NestedStructUDT().sqlType, nullable = true) From 46153f14ebe183b3327d133f23043bf656c3b99f Mon Sep 17 00:00:00 2001 From: xuanyuanking Date: Wed, 7 Sep 2016 10:43:25 +0800 Subject: [PATCH 2/9] change the seperator of recursive fields in nested struct pruning --- .../apache/spark/sql/types/StructType.scala | 42 +++++++++++-------- .../datasources/FileSourceStrategy.scala | 9 ++-- 2 files changed, 30 insertions(+), 21 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructType.scala index 3b16fbf8f660..de30f9ad0eb9 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructType.scala +++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructType.scala @@ -259,23 +259,8 @@ case class StructType(fields: Array[StructField]) extends DataType with Seq[Stru * @throws IllegalArgumentException if a field with the given name does not exist */ def apply(name: String): StructField = { - if (name.contains('.')) { - val curFieldStr = name.split("\\.", 2)(0) - val nextFieldStr = name.split("\\.", 2)(1) - val curField = nameToField.getOrElse(curFieldStr, - throw new IllegalArgumentException(s"""Field "$name" does not exist.""")) - curField.dataType match { - case st: StructType => - val newField = StructType(st.fields).apply(nextFieldStr) - StructField(curField.name, StructType(Seq(newField)), - curField.nullable, curField.metadata) - case _ => - throw new IllegalArgumentException(s"""Field "$curFieldStr" is not struct field.""") - } - } else { - nameToField.getOrElse(name, - throw new IllegalArgumentException(s"""Field "$name" does not exist.""")) - } + nameToField.getOrElse(name, + throw new IllegalArgumentException(s"""Field "$name" does not exist.""")) } /** @@ -294,6 +279,29 @@ case class StructType(fields: Array[StructField]) extends DataType with Seq[Stru StructType(fields.filter(f => names.contains(f.name))) } + /** + * Extracts the [[StructField]] with the given name recursively. + * + * @throws IllegalArgumentException if the parent field's type is not StructType + */ + def getFieldRecursively(name: String): StructField = { + if (name.contains('#')) { + val curFieldStr = name.split("#", 2)(0) + val nextFieldStr = name.split("#", 2)(1) + val curField = this.apply(curFieldStr) + curField.dataType match { + case st: StructType => + val newField = StructType(st.fields).getFieldRecursively(nextFieldStr) + StructField(curField.name, StructType(Seq(newField)), + curField.nullable, curField.metadata) + case _ => + throw new IllegalArgumentException(s"""Field "$curFieldStr" is not struct field.""") + } + } else { + this.apply(name) + } + } + /** * Returns the index of a given field. * diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategy.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategy.scala index f7521d2077d5..7223d58b5c93 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategy.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategy.scala @@ -102,7 +102,8 @@ object FileSourceStrategy extends Strategy with Logging { val requiredColumnsWithNesting = generateRequiredColumnsContainsNesting( projects, readDataColumns.attrs.map(_.name).toArray) val totalSchema = readDataColumns.toStructType - val prunedSchema = StructType(requiredColumnsWithNesting.map(totalSchema(_))) + val prunedSchema = StructType(requiredColumnsWithNesting + .map(totalSchema.getFieldRecursively)) // Merge schema in same StructType and merge with filterAttributes prunedSchema.fields.map(f => StructType(Array(f))).reduceLeft(_ merge _) .merge(filterAttributes.toSeq.toStructType) @@ -157,14 +158,14 @@ object FileSourceStrategy extends Strategy with Logging { if (isNestField && curString.isDefined) { val attrStr = attr.name if (nestFieldMap.contains(attrStr)) { - nestFieldMap(attrStr) = nestFieldMap(attrStr) ++ Seq(attrStr + "." + curString.get) + nestFieldMap(attrStr) = nestFieldMap(attrStr) ++ Seq(attrStr + "#" + curString.get) } else { - nestFieldMap += (attrStr -> Seq(attrStr + "." 
+ curString.get)) + nestFieldMap += (attrStr -> Seq(attrStr + "#" + curString.get)) } } case sf: GetStructField => val str = if (curString.isDefined) { - sf.name.get + "." + curString.get + sf.name.get + "#" + curString.get } else sf.name.get generateAttributeMap(nestFieldMap, isNestField = true, Option(str), sf.child) case _ => From 46c2474ca76ee190ec9333eeae22eb010dce0195 Mon Sep 17 00:00:00 2001 From: xuanyuanking Date: Wed, 7 Sep 2016 10:52:25 +0800 Subject: [PATCH 3/9] use ',' as the recursive fields seperator --- .../main/scala/org/apache/spark/sql/types/StructType.scala | 6 +++--- .../sql/execution/datasources/FileSourceStrategy.scala | 6 +++--- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructType.scala index de30f9ad0eb9..0e1ca7e66c6a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructType.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructType.scala @@ -285,9 +285,9 @@ case class StructType(fields: Array[StructField]) extends DataType with Seq[Stru * @throws IllegalArgumentException if the parent field's type is not StructType */ def getFieldRecursively(name: String): StructField = { - if (name.contains('#')) { - val curFieldStr = name.split("#", 2)(0) - val nextFieldStr = name.split("#", 2)(1) + if (name.contains(',')) { + val curFieldStr = name.split(",", 2)(0) + val nextFieldStr = name.split(",", 2)(1) val curField = this.apply(curFieldStr) curField.dataType match { case st: StructType => diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategy.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategy.scala index 7223d58b5c93..c0bcd55d511d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategy.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategy.scala @@ -158,14 +158,14 @@ object FileSourceStrategy extends Strategy with Logging { if (isNestField && curString.isDefined) { val attrStr = attr.name if (nestFieldMap.contains(attrStr)) { - nestFieldMap(attrStr) = nestFieldMap(attrStr) ++ Seq(attrStr + "#" + curString.get) + nestFieldMap(attrStr) = nestFieldMap(attrStr) ++ Seq(attrStr + "," + curString.get) } else { - nestFieldMap += (attrStr -> Seq(attrStr + "#" + curString.get)) + nestFieldMap += (attrStr -> Seq(attrStr + "," + curString.get)) } } case sf: GetStructField => val str = if (curString.isDefined) { - sf.name.get + "#" + curString.get + sf.name.get + "," + curString.get } else sf.name.get generateAttributeMap(nestFieldMap, isNestField = true, Option(str), sf.child) case _ => From 1c348779d4f7fff4de677b52a995330c9cfb6d03 Mon Sep 17 00:00:00 2001 From: xuanyuanking Date: Wed, 7 Sep 2016 20:07:11 +0800 Subject: [PATCH 4/9] Get the nested fields not modifying the column names --- .../apache/spark/sql/types/StructType.scala | 23 ------ .../datasources/FileSourceStrategy.scala | 72 ++++++++---------- .../test-data/nested-struct.snappy.parquet | Bin 0 -> 1402 bytes .../parquet/ParquetQuerySuite.scala | 33 ++++---- 4 files changed, 46 insertions(+), 82 deletions(-) create mode 100644 sql/core/src/test/resources/test-data/nested-struct.snappy.parquet diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructType.scala index 
0e1ca7e66c6a..dd4c88c4c43b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructType.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructType.scala @@ -279,29 +279,6 @@ case class StructType(fields: Array[StructField]) extends DataType with Seq[Stru StructType(fields.filter(f => names.contains(f.name))) } - /** - * Extracts the [[StructField]] with the given name recursively. - * - * @throws IllegalArgumentException if the parent field's type is not StructType - */ - def getFieldRecursively(name: String): StructField = { - if (name.contains(',')) { - val curFieldStr = name.split(",", 2)(0) - val nextFieldStr = name.split(",", 2)(1) - val curField = this.apply(curFieldStr) - curField.dataType match { - case st: StructType => - val newField = StructType(st.fields).getFieldRecursively(nextFieldStr) - StructField(curField.name, StructType(Seq(newField)), - curField.nullable, curField.metadata) - case _ => - throw new IllegalArgumentException(s"""Field "$curFieldStr" is not struct field.""") - } - } else { - this.apply(name) - } - } - /** * Returns the index of a given field. * diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategy.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategy.scala index c0bcd55d511d..802cd90345c6 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategy.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategy.scala @@ -25,7 +25,7 @@ import org.apache.spark.sql.catalyst.planning.PhysicalOperation import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.execution.FileSourceScanExec import org.apache.spark.sql.execution.SparkPlan -import org.apache.spark.sql.types.StructType +import org.apache.spark.sql.types.{StructField, StructType} /** * A strategy for planning scans over collections of files that might be partitioned or bucketed @@ -99,11 +99,9 @@ object FileSourceStrategy extends Strategy with Logging { .filter(requiredAttributes.contains) .filterNot(partitionColumns.contains) val outputSchema = if (fsRelation.sqlContext.conf.isParquetNestColumnPruning) { - val requiredColumnsWithNesting = generateRequiredColumnsContainsNesting( - projects, readDataColumns.attrs.map(_.name).toArray) val totalSchema = readDataColumns.toStructType - val prunedSchema = StructType(requiredColumnsWithNesting - .map(totalSchema.getFieldRecursively)) + val prunedSchema = StructType( + generateStructFieldsContainsNesting(projects, totalSchema)) // Merge schema in same StructType and merge with filterAttributes prunedSchema.fields.map(f => StructType(Array(f))).reduceLeft(_ merge _) .merge(filterAttributes.toSeq.toStructType) @@ -137,55 +135,51 @@ object FileSourceStrategy extends Strategy with Logging { case _ => Nil } - private def generateRequiredColumnsContainsNesting(projects: Seq[Expression], - columns: Array[String]) : Array[String] = { - def generateAttributeMap(nestFieldMap: scala.collection.mutable.Map[String, Seq[String]], - isNestField: Boolean, curString: Option[String], - node: Expression) { + private def generateStructFieldsContainsNesting(projects: Seq[Expression], + totalSchema: StructType) : Seq[StructField] = { + def generateStructField(curField: List[String], + node: Expression) : Seq[StructField] = { node match { case ai: GetArrayItem => - // Here we drop the curString for simplify array and map support. 
+ // Here we drop the previous for simplify array and map support. // Same strategy in GetArrayStructFields and GetMapValue - generateAttributeMap(nestFieldMap, isNestField = true, None, ai.child) - + generateStructField(List.empty[String], ai.child) case asf: GetArrayStructFields => - generateAttributeMap(nestFieldMap, isNestField = true, None, asf.child) - + generateStructField(List.empty[String], asf.child) case mv: GetMapValue => - generateAttributeMap(nestFieldMap, isNestField = true, None, mv.child) - + generateStructField(List.empty[String], mv.child) case attr: AttributeReference => - if (isNestField && curString.isDefined) { - val attrStr = attr.name - if (nestFieldMap.contains(attrStr)) { - nestFieldMap(attrStr) = nestFieldMap(attrStr) ++ Seq(attrStr + "," + curString.get) - } else { - nestFieldMap += (attrStr -> Seq(attrStr + "," + curString.get)) - } - } + Seq(getFieldRecursively(totalSchema, attr.name :: curField)) case sf: GetStructField => - val str = if (curString.isDefined) { - sf.name.get + "," + curString.get - } else sf.name.get - generateAttributeMap(nestFieldMap, isNestField = true, Option(str), sf.child) + generateStructField(sf.name.get :: curField, sf.child) case _ => if (node.children.nonEmpty) { - node.children.foreach(child => generateAttributeMap(nestFieldMap, - isNestField, curString, child)) + node.children.flatMap(child => generateStructField(curField, child)) + } else { + Seq.empty[StructField] } } } - val nestFieldMap = scala.collection.mutable.Map.empty[String, Seq[String]] - projects.foreach(p => generateAttributeMap(nestFieldMap, isNestField = false, None, p)) - val col_list = columns.toList.flatMap(col => { - if (nestFieldMap.contains(col)) { - nestFieldMap.get(col).get.toList + def getFieldRecursively(totalSchema: StructType, + name: List[String]): StructField = { + if (name.length > 1) { + val curField = name.head + val curFieldType = totalSchema(curField) + curFieldType.dataType match { + case st: StructType => + val newField = getFieldRecursively(StructType(st.fields), name.drop(1)) + StructField(curFieldType.name, StructType(Seq(newField)), + curFieldType.nullable, curFieldType.metadata) + case _ => + throw new IllegalArgumentException(s"""Field "$curField" is not struct field.""") + } } else { - List(col) + totalSchema(name.head) } - }) - col_list.toArray + } + + projects.flatMap(p => generateStructField(List.empty[String], p)) } } diff --git a/sql/core/src/test/resources/test-data/nested-struct.snappy.parquet b/sql/core/src/test/resources/test-data/nested-struct.snappy.parquet new file mode 100644 index 0000000000000000000000000000000000000000..2f02fb0dea3b4fff2353933a3cc63d86c9b4da8a GIT binary patch literal 1402 zcmbVML5tHs6rM>!I(3VB*og`BP=XtGp&LxnrQOnl$|4HOvUpV~(s-BBgf^ z9z2MM$l}?v_(S{=d+-l<)Hg}fb{AJ$=a6~4H{bWZmv55eyN@i2=n*}rQUZZYNm(Tl z%9cv-EK>M>t0KZv{Dm+oVk8s`2;m!5UdaI|)%|5E!ppE!B4hzNPbF|FsaUv!)%@vptndX7os;g7)S?DIT%t zc{}bt^9GE{XireZ?l1sqh=kaJfJ$3}D#cdYT7)*DCCG}EC zQ%T#C5z!i)88}leDkwzxjFtD;AgxOsH%{>%;dT!HfZ+(>bJH%^YhLxtkYkHNK1U&MKcRMV(^h9fs0h6rV|1M~P_Z`$=)HvGNuhXHqNoBHvR_ zl#jAPkXBZAhEdOCBL;4msWE1J5}4r7$=FNGX!z0`cnNb^!qgw)L{DCgJiV#INaJpz zS9QPd1#YZ2A5Zj<4R{GAhJ1EzmGiN+=JECdpa;XS2k>wl1gsN4d(h=o^<1*voJ^>BAGnno>#&;3yn3f|+pT$Sr|Z}EZHIZ5 "true") { checkAnswer(sql(query1), result1) } - // test for array in struct - val query2 = "select primitive,myComplex[0].repeatedMessage[0].someId from tmp_table" + // test for same struct meta merge + // col.s1.s1_1 and col.str should merge + // like col.[s1.s1_1, str] before pass to parquet + val query2 = 
"select col.s1.s1_1,col.str from tmp_table" val result2 = sql(query2) withSQLConf(SQLConf.PARQUET_NEST_COLUMN_PRUNING.key -> "true") { checkAnswer(sql(query2), result2) } - // test for same struct meta merge - // myComplex.id and myComplex.repeatedMessage.someId should merge - // like myComplex.[id, repeatedMessage.someId] before pass to parquet - val query3 = "select myComplex[0].id, myComplex[0].repeatedMessage[0].someId" + - " from tmp_table" - val result3 = sql(query3) - withSQLConf(SQLConf.PARQUET_NEST_COLUMN_PRUNING.key -> "true") { - checkAnswer(sql(query3), result3) - } spark.sessionState.catalog.dropTable( TableIdentifier("tmp_table"), ignoreIfNotExists = true, purge = false) From 23465babd0f60db8a79e70f0589af2ec5bf360eb Mon Sep 17 00:00:00 2001 From: xuanyuanking Date: Thu, 8 Sep 2016 09:51:44 +0800 Subject: [PATCH 5/9] Add fileFormat check for nested fields pruning, only works for parquetFileFormat --- .../spark/sql/execution/datasources/FileSourceStrategy.scala | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategy.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategy.scala index 802cd90345c6..f46c3a593c02 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategy.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategy.scala @@ -25,6 +25,7 @@ import org.apache.spark.sql.catalyst.planning.PhysicalOperation import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.execution.FileSourceScanExec import org.apache.spark.sql.execution.SparkPlan +import org.apache.spark.sql.execution.datasources.parquet.ParquetFileFormat import org.apache.spark.sql.types.{StructField, StructType} /** @@ -98,7 +99,8 @@ object FileSourceStrategy extends Strategy with Logging { dataColumns .filter(requiredAttributes.contains) .filterNot(partitionColumns.contains) - val outputSchema = if (fsRelation.sqlContext.conf.isParquetNestColumnPruning) { + val outputSchema = if (fsRelation.sqlContext.conf.isParquetNestColumnPruning + && fsRelation.fileFormat.isInstanceOf[ParquetFileFormat]) { val totalSchema = readDataColumns.toStructType val prunedSchema = StructType( generateStructFieldsContainsNesting(projects, totalSchema)) From ab8f5ec15b2682ee40ea0483e0b6642b2a14c7ad Mon Sep 17 00:00:00 2001 From: xuanyuanking Date: Sun, 23 Oct 2016 17:25:22 +0800 Subject: [PATCH 6/9] fix code style and variable name --- .../datasources/FileSourceStrategy.scala | 35 +++++++++++-------- .../apache/spark/sql/internal/SQLConf.scala | 6 ++-- .../parquet/ParquetQuerySuite.scala | 35 +++++++++---------- 3 files changed, 41 insertions(+), 35 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategy.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategy.scala index f46c3a593c02..b85e7328a78d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategy.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategy.scala @@ -99,15 +99,19 @@ object FileSourceStrategy extends Strategy with Logging { dataColumns .filter(requiredAttributes.contains) .filterNot(partitionColumns.contains) - val outputSchema = if (fsRelation.sqlContext.conf.isParquetNestColumnPruning - && fsRelation.fileFormat.isInstanceOf[ParquetFileFormat]) { - 
val totalSchema = readDataColumns.toStructType + val outputSchema = if ( + fsRelation.sqlContext.conf.parquetNestedColumnPruningEnabled && + fsRelation.fileFormat.isInstanceOf[ParquetFileFormat] + ) { + val fullSchema = readDataColumns.toStructType val prunedSchema = StructType( - generateStructFieldsContainsNesting(projects, totalSchema)) + generateStructFieldsContainsNesting(projects, fullSchema)) // Merge schema in same StructType and merge with filterAttributes prunedSchema.fields.map(f => StructType(Array(f))).reduceLeft(_ merge _) .merge(filterAttributes.toSeq.toStructType) - } else readDataColumns.toStructType + } else { + readDataColumns.toStructType + } logInfo(s"Output Data Schema: ${outputSchema.simpleString(5)}") val pushedDownFilters = dataFilters.flatMap(DataSourceStrategy.translateFilter) @@ -137,10 +141,12 @@ object FileSourceStrategy extends Strategy with Logging { case _ => Nil } - private def generateStructFieldsContainsNesting(projects: Seq[Expression], - totalSchema: StructType) : Seq[StructField] = { - def generateStructField(curField: List[String], - node: Expression) : Seq[StructField] = { + private def generateStructFieldsContainsNesting( + projects: Seq[Expression], + fullSchema: StructType) : Seq[StructField] = { + def generateStructField( + curField: List[String], + node: Expression) : Seq[StructField] = { node match { case ai: GetArrayItem => // Here we drop the previous for simplify array and map support. @@ -151,7 +157,7 @@ object FileSourceStrategy extends Strategy with Logging { case mv: GetMapValue => generateStructField(List.empty[String], mv.child) case attr: AttributeReference => - Seq(getFieldRecursively(totalSchema, attr.name :: curField)) + Seq(getFieldRecursively(fullSchema, attr.name :: curField)) case sf: GetStructField => generateStructField(sf.name.get :: curField, sf.child) case _ => @@ -163,11 +169,12 @@ object FileSourceStrategy extends Strategy with Logging { } } - def getFieldRecursively(totalSchema: StructType, - name: List[String]): StructField = { + def getFieldRecursively( + schema: StructType, + name: List[String]): StructField = { if (name.length > 1) { val curField = name.head - val curFieldType = totalSchema(curField) + val curFieldType = schema(curField) curFieldType.dataType match { case st: StructType => val newField = getFieldRecursively(StructType(st.fields), name.drop(1)) @@ -177,7 +184,7 @@ object FileSourceStrategy extends Strategy with Logging { throw new IllegalArgumentException(s"""Field "$curField" is not struct field.""") } } else { - totalSchema(name.head) + schema(name.head) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/core/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index fb2003731240..e2bacbc2c78f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -212,8 +212,8 @@ object SQLConf { .booleanConf .createWithDefault(true) - val PARQUET_NEST_COLUMN_PRUNING = SQLConfigBuilder("spark.sql.parquet.nestColumnPruning") - .doc("When set this to true, we will tell parquet only read the nest column`s leaf fields ") + val PARQUET_NESTED_COLUMN_PRUNING = SQLConfigBuilder("spark.sql.parquet.nestedColumnPruning") + .doc("When true, Parquet column pruning also works for nested fields.") .booleanConf .createWithDefault(false) @@ -666,7 +666,7 @@ private[sql] class SQLConf extends Serializable with CatalystConf with Logging { def isParquetINT96AsTimestamp: Boolean = 
getConf(PARQUET_INT96_AS_TIMESTAMP) - def isParquetNestColumnPruning: Boolean = getConf(PARQUET_NEST_COLUMN_PRUNING) + def parquetNestedColumnPruningEnabled: Boolean = getConf(PARQUET_NESTED_COLUMN_PRUNING) def writeLegacyParquetFormat: Boolean = getConf(PARQUET_WRITE_LEGACY_FORMAT) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetQuerySuite.scala index d3ed5bd013a2..d4d688d4add5 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetQuerySuite.scala @@ -581,25 +581,24 @@ class ParquetQuerySuite extends QueryTest with ParquetTest with SharedSQLContext // | |-- str: string (nullable = true) // |-- num: long (nullable = true) // |-- str: string (nullable = true) - val df = readResourceParquetFile("test-data/nested-struct.snappy.parquet") - df.createOrReplaceTempView("tmp_table") - // normal test - val query1 = "select num,col.s1.s1_1 from tmp_table" - val result1 = sql(query1) - withSQLConf(SQLConf.PARQUET_NEST_COLUMN_PRUNING.key -> "true") { - checkAnswer(sql(query1), result1) - } - // test for same struct meta merge - // col.s1.s1_1 and col.str should merge - // like col.[s1.s1_1, str] before pass to parquet - val query2 = "select col.s1.s1_1,col.str from tmp_table" - val result2 = sql(query2) - withSQLConf(SQLConf.PARQUET_NEST_COLUMN_PRUNING.key -> "true") { - checkAnswer(sql(query2), result2) + withTempView("tmp_table") { + val df = readResourceParquetFile("test-data/nested-struct.snappy.parquet") + df.createOrReplaceTempView("tmp_table") + // normal test + val query1 = "select num,col.s1.s1_1 from tmp_table" + val result1 = sql(query1) + withSQLConf(SQLConf.PARQUET_NESTED_COLUMN_PRUNING.key -> "true") { + checkAnswer(sql(query1), result1) + } + // test for same struct meta merge + // col.s1.s1_1 and col.str should merge + // like col.[s1.s1_1, str] before pass to parquet + val query2 = "select col.s1.s1_1,col.str from tmp_table" + val result2 = sql(query2) + withSQLConf(SQLConf.PARQUET_NESTED_COLUMN_PRUNING.key -> "true") { + checkAnswer(sql(query2), result2) + } } - - spark.sessionState.catalog.dropTable( - TableIdentifier("tmp_table"), ignoreIfNotExists = true, purge = false) } test("expand UDT in StructType") { From 92ed3696adc23be5edc8bba18595e35cfdbf8956 Mon Sep 17 00:00:00 2001 From: xuanyuanking Date: Tue, 25 Oct 2016 21:56:28 +0800 Subject: [PATCH 7/9] add comments and test cases for method generateStructFieldsContainsNesting --- .../datasources/FileSourceStrategy.scala | 10 ++- .../datasources/FileSourceStrategySuite.scala | 86 +++++++++++++++++-- 2 files changed, 87 insertions(+), 9 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategy.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategy.scala index b85e7328a78d..a382f9cffda5 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategy.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategy.scala @@ -144,6 +144,10 @@ object FileSourceStrategy extends Strategy with Logging { private def generateStructFieldsContainsNesting( projects: Seq[Expression], fullSchema: StructType) : Seq[StructField] = { + // By traverse projects, we can fisrt generate the access path of 
nested struct, then use the + // access path reconstruct the schema after pruning. + // In the process of traversing, we should deal with all expressions releted with complex + // struct type like GetArrayItem, GetArrayStructFields, GetMapValue and GetStructField def generateStructField( curField: List[String], node: Expression) : Seq[StructField] = { @@ -157,6 +161,8 @@ object FileSourceStrategy extends Strategy with Logging { case mv: GetMapValue => generateStructField(List.empty[String], mv.child) case attr: AttributeReference => + // Finally reach the leaf node AttributeReference, call getFieldRecursively + // and pass the access path of current nested struct Seq(getFieldRecursively(fullSchema, attr.name :: curField)) case sf: GetStructField => generateStructField(sf.name.get :: curField, sf.child) @@ -169,9 +175,7 @@ object FileSourceStrategy extends Strategy with Logging { } } - def getFieldRecursively( - schema: StructType, - name: List[String]): StructField = { + def getFieldRecursively(schema: StructType, name: List[String]): StructField = { if (name.length > 1) { val curField = name.head val curFieldType = schema(curField) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategySuite.scala index 09fd75018035..7cfa5f6fc386 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategySuite.scala @@ -23,22 +23,23 @@ import java.util.concurrent.atomic.AtomicInteger import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs.{BlockLocation, FileStatus, Path, RawLocalFileSystem} import org.apache.hadoop.mapreduce.Job +import org.scalatest.PrivateMethodTester -import org.apache.spark.{SparkConf, SparkException} +import org.apache.spark.SparkConf import org.apache.spark.sql._ -import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.{util, InternalRow} import org.apache.spark.sql.catalyst.catalog.BucketSpec -import org.apache.spark.sql.catalyst.expressions.{Expression, ExpressionSet, PredicateHelper} -import org.apache.spark.sql.catalyst.util +import org.apache.spark.sql.catalyst.expressions.{Alias, AttributeReference, Expression, ExpressionSet, GetArrayItem, GetStructField, Literal, PredicateHelper} import org.apache.spark.sql.execution.{DataSourceScanExec, SparkPlan} import org.apache.spark.sql.functions._ import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.sources._ import org.apache.spark.sql.test.SharedSQLContext -import org.apache.spark.sql.types.{IntegerType, StructType} +import org.apache.spark.sql.types._ import org.apache.spark.util.Utils -class FileSourceStrategySuite extends QueryTest with SharedSQLContext with PredicateHelper { +class FileSourceStrategySuite extends QueryTest with SharedSQLContext with PredicateHelper + with PrivateMethodTester{ import testImplicits._ protected override val sparkConf = new SparkConf().set("spark.default.parallelism", "1") @@ -441,6 +442,79 @@ class FileSourceStrategySuite extends QueryTest with SharedSQLContext with Predi } } + test("[SPARK-4502] pruning nested schema by projects correctly") { + val testFunc = PrivateMethod[Seq[StructField]]('generateStructFieldsContainsNesting) + // Construct fullSchema like below: + // root + // |-- col: struct (nullable = true) + // | |-- s1: struct (nullable = true) 
+ // | | |-- s1_1: long (nullable = true) + // | | |-- s1_2: long (nullable = true) + // | |-- str: string (nullable = true) + // | |-- info_list: array (nullable = true) + // | | |-- element: struct (containsNull = true) + // | | | |-- s1: struct (nullable = true) + // | | | | |-- s1_1: long (nullable = true) + // | | | | |-- s1_2: long (nullable = true) + // |-- num: long (nullable = true) + // |-- str: string (nullable = true) + val nested_s1 = StructField("s1", + StructType( + Seq( + StructField("s1_1", LongType, true), + StructField("s1_2", LongType, true) + ) + ), true) + val flat_str = StructField("str", StringType, true) + val nested_arr = StructField("info_list", ArrayType(StructType(Seq(nested_s1))), true) + + val fullSchema = StructType( + Seq( + StructField("col", StructType(Seq(nested_s1, flat_str, nested_arr)), true), + StructField("num", LongType, true), + flat_str + )) + + // Attr of struct col + val colAttr = AttributeReference("col", StructType( + Seq(nested_s1, flat_str, nested_arr)), true)() + // Child expression of col.s1.s1_1 + val childExp = GetStructField( + GetStructField(colAttr, 0, Some("s1")), 0, Some("s1_1")) + // Child expression of col.info_list[0].s1.s1_1 + val arrayChildExp = GetStructField( + GetStructField( + GetArrayItem( + GetStructField(colAttr, 0, Some("info_list")), + Literal(0) + ), 0, Some("s1") + ), 0, Some("s1_1") + ) + // Project list of "select num, col.s1.s1_1 as s1_1, col.info_list[0].s1.s1_1 as complex_get" + val projects = Seq( + AttributeReference("num", LongType, true)(), + Alias(childExp, "s1_1")(), + Alias(arrayChildExp, "complex_get")() + ) + val expextResult = + Seq( + StructField("num", LongType, true), + StructField("col", StructType( + Seq( + StructField( + "s1", + StructType(Seq(StructField("s1_1", LongType, true))), + true) + ) + ), true), + StructField("col", StructType(Seq(nested_arr))) + ) + // Call the function generateStructFieldsContainsNesting + val result = FileSourceStrategy.invokePrivate[Seq[StructField]](testFunc(projects, + fullSchema)) + assert(result == expextResult) + } + // Helpers for checking the arguments passed to the FileFormat. 
protected val checkPartitionSchema = From d9aa397683afc4b936529d6983f9b48dd4d2ee15 Mon Sep 17 00:00:00 2001 From: xuanyuanking Date: Wed, 26 Oct 2016 14:09:19 +0800 Subject: [PATCH 8/9] fix test build --- .../sql/execution/datasources/FileSourceStrategySuite.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategySuite.scala index 9f67f89bdb0c..24dbde2e639f 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategySuite.scala @@ -26,7 +26,7 @@ import org.apache.hadoop.fs.{BlockLocation, FileStatus, Path, RawLocalFileSystem import org.apache.hadoop.mapreduce.Job import org.scalatest.PrivateMethodTester -import org.apache.spark.SparkConf +import org.apache.spark.{SparkConf, SparkException} import org.apache.spark.sql._ import org.apache.spark.sql.catalyst.{util, InternalRow} import org.apache.spark.sql.catalyst.catalog.BucketSpec From d093c82906919ce1d9405d437e3cd6656983136b Mon Sep 17 00:00:00 2001 From: xuanyuanking Date: Sun, 30 Oct 2016 18:17:54 +0800 Subject: [PATCH 9/9] support named_struct and struct in schema pruning --- .../datasources/FileSourceStrategy.scala | 15 ++- .../datasources/FileSourceStrategySuite.scala | 109 +++++++++++++----- 2 files changed, 90 insertions(+), 34 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategy.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategy.scala index 50688f62301f..9ad2108a93d4 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategy.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategy.scala @@ -100,8 +100,8 @@ object FileSourceStrategy extends Strategy with Logging { .filter(requiredAttributes.contains) .filterNot(partitionColumns.contains) val outputSchema = if ( - fsRelation.sqlContext.conf.parquetNestedColumnPruningEnabled && - fsRelation.fileFormat.isInstanceOf[ParquetFileFormat] + fsRelation.sqlContext.conf.parquetNestedColumnPruningEnabled && + fsRelation.fileFormat.isInstanceOf[ParquetFileFormat] ) { val fullSchema = readDataColumns.toStructType val prunedSchema = StructType( @@ -141,7 +141,7 @@ object FileSourceStrategy extends Strategy with Logging { case _ => Nil } - private def generateStructFieldsContainsNesting( + private[sql] def generateStructFieldsContainsNesting( projects: Seq[Expression], fullSchema: StructType) : Seq[StructField] = { // By traverse projects, we can fisrt generate the access path of nested struct, then use the @@ -164,8 +164,13 @@ object FileSourceStrategy extends Strategy with Logging { // Finally reach the leaf node AttributeReference, call getFieldRecursively // and pass the access path of current nested struct Seq(getFieldRecursively(fullSchema, attr.name :: curField)) - case sf: GetStructField => - generateStructField(sf.name.get :: curField, sf.child) + case sf: GetStructField if !sf.child.isInstanceOf[CreateNamedStruct] && + !sf.child.isInstanceOf[CreateStruct] => + val name = sf.name.getOrElse(sf.dataType match { + case StructType(fiedls) => + fiedls(sf.ordinal).name + }) + generateStructField(name :: curField, sf.child) case _ => if (node.children.nonEmpty) { 
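// Any other expression (e.g. an Alias or a predicate) is traversed transparently:
// keep the access path built so far and collect nested fields from every child.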
node.children.flatMap(child => generateStructField(curField, child)) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategySuite.scala index 24dbde2e639f..38943febf52f 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategySuite.scala @@ -24,13 +24,12 @@ import java.util.zip.GZIPOutputStream import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs.{BlockLocation, FileStatus, Path, RawLocalFileSystem} import org.apache.hadoop.mapreduce.Job -import org.scalatest.PrivateMethodTester import org.apache.spark.{SparkConf, SparkException} import org.apache.spark.sql._ import org.apache.spark.sql.catalyst.{util, InternalRow} import org.apache.spark.sql.catalyst.catalog.BucketSpec -import org.apache.spark.sql.catalyst.expressions.{Alias, AttributeReference, Expression, ExpressionSet, GetArrayItem, GetStructField, Literal, PredicateHelper} +import org.apache.spark.sql.catalyst.expressions.{Alias, AttributeReference, CreateNamedStruct, Expression, ExpressionSet, GetArrayItem, GetStructField, Literal, PredicateHelper} import org.apache.spark.sql.execution.{DataSourceScanExec, SparkPlan} import org.apache.spark.sql.functions._ import org.apache.spark.sql.internal.SQLConf @@ -39,8 +38,7 @@ import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.sql.types._ import org.apache.spark.util.Utils -class FileSourceStrategySuite extends QueryTest with SharedSQLContext with PredicateHelper - with PrivateMethodTester{ +class FileSourceStrategySuite extends QueryTest with SharedSQLContext with PredicateHelper { import testImplicits._ protected override val sparkConf = new SparkConf().set("spark.default.parallelism", "1") @@ -443,8 +441,7 @@ class FileSourceStrategySuite extends QueryTest with SharedSQLContext with Predi } } - test("[SPARK-4502] pruning nested schema by projects correctly") { - val testFunc = PrivateMethod[Seq[StructField]]('generateStructFieldsContainsNesting) + test("[SPARK-4502] pruning nested schema by GetStructField projects") { // Construct fullSchema like below: // root // |-- col: struct (nullable = true) @@ -452,11 +449,6 @@ class FileSourceStrategySuite extends QueryTest with SharedSQLContext with Predi // | | |-- s1_1: long (nullable = true) // | | |-- s1_2: long (nullable = true) // | |-- str: string (nullable = true) - // | |-- info_list: array (nullable = true) - // | | |-- element: struct (containsNull = true) - // | | | |-- s1: struct (nullable = true) - // | | | | |-- s1_1: long (nullable = true) - // | | | | |-- s1_2: long (nullable = true) // |-- num: long (nullable = true) // |-- str: string (nullable = true) val nested_s1 = StructField("s1", @@ -467,21 +459,70 @@ class FileSourceStrategySuite extends QueryTest with SharedSQLContext with Predi ) ), true) val flat_str = StructField("str", StringType, true) - val nested_arr = StructField("info_list", ArrayType(StructType(Seq(nested_s1))), true) val fullSchema = StructType( Seq( - StructField("col", StructType(Seq(nested_s1, flat_str, nested_arr)), true), + StructField("col", StructType(Seq(nested_s1, flat_str)), true), StructField("num", LongType, true), flat_str )) // Attr of struct col val colAttr = AttributeReference("col", StructType( - Seq(nested_s1, flat_str, nested_arr)), true)() + Seq(nested_s1, flat_str)), 
true)() // Child expression of col.s1.s1_1 val childExp = GetStructField( GetStructField(colAttr, 0, Some("s1")), 0, Some("s1_1")) + + // Project list of "select num, col.s1.s1_1 as s1_1" + val projects = Seq( + AttributeReference("num", LongType, true)(), + Alias(childExp, "s1_1")() + ) + val expextResult = + Seq( + StructField("num", LongType, true), + StructField("col", StructType( + Seq( + StructField( + "s1", + StructType(Seq(StructField("s1_1", LongType, true))), + true) + ) + ), true) + ) + // Call the function generateStructFieldsContainsNesting + val result = FileSourceStrategy.generateStructFieldsContainsNesting(projects, + fullSchema) + assert(result == expextResult) + } + + test("[SPARK-4502] pruning nested schema by GetArrayItem projects") { + // Construct fullSchema like below: + // root + // |-- col: struct (nullable = true) + // | |-- info_list: array (nullable = true) + // | | |-- element: struct (containsNull = true) + // | | | |-- s1: struct (nullable = true) + // | | | | |-- s1_1: long (nullable = true) + // | | | | |-- s1_2: long (nullable = true) + val nested_s1 = StructField("s1", + StructType( + Seq( + StructField("s1_1", LongType, true), + StructField("s1_2", LongType, true) + ) + ), true) + val nested_arr = StructField("info_list", ArrayType(StructType(Seq(nested_s1))), true) + + val fullSchema = StructType( + Seq( + StructField("col", StructType(Seq(nested_arr)), true) + )) + + // Attr of struct col + val colAttr = AttributeReference("col", StructType( + Seq(nested_arr)), true)() // Child expression of col.info_list[0].s1.s1_1 val arrayChildExp = GetStructField( GetStructField( @@ -491,31 +532,41 @@ class FileSourceStrategySuite extends QueryTest with SharedSQLContext with Predi ), 0, Some("s1") ), 0, Some("s1_1") ) - // Project list of "select num, col.s1.s1_1 as s1_1, col.info_list[0].s1.s1_1 as complex_get" + // Project list of "select col.info_list[0].s1.s1_1 as complex_get" val projects = Seq( - AttributeReference("num", LongType, true)(), - Alias(childExp, "s1_1")(), Alias(arrayChildExp, "complex_get")() - ) + ) val expextResult = Seq( - StructField("num", LongType, true), - StructField("col", StructType( - Seq( - StructField( - "s1", - StructType(Seq(StructField("s1_1", LongType, true))), - true) - ) - ), true), StructField("col", StructType(Seq(nested_arr))) ) // Call the function generateStructFieldsContainsNesting - val result = FileSourceStrategy.invokePrivate[Seq[StructField]](testFunc(projects, - fullSchema)) + val result = FileSourceStrategy.generateStructFieldsContainsNesting(projects, + fullSchema) assert(result == expextResult) } + test("[SPARK-4502] pruning nested schema while named_struct in project") { + val schema = new StructType() + .add("f0", IntegerType) + .add("f1", new StructType() + .add("f10", IntegerType)) + + val expr = GetStructField( + CreateNamedStruct(Seq( + Literal("f10"), + AttributeReference("f0", IntegerType)() + )), + 0, + Some("f10") + ) + + val expect = new StructType() + .add("f0", IntegerType) + + assert(FileSourceStrategy.generateStructFieldsContainsNesting(expr :: Nil, schema) == expect) + } + test("spark.files.ignoreCorruptFiles should work in SQL") { val inputFile = File.createTempFile("input-", ".gz") try {