diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ProjectionOverSchema.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ProjectionOverSchema.scala index 241c761624b76..dea5926e8b2b0 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ProjectionOverSchema.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ProjectionOverSchema.scala @@ -51,6 +51,20 @@ case class ProjectionOverSchema(schema: StructType) { s"unmatched child schema for GetArrayStructFields: ${projSchema.toString}" ) } + case ExtractNestedArrayField(child, _, _, field, containsNull, containsNullSeq) => + getProjection(child).map(p => (p, p.dataType)).map { + case (projection, ExtractNestedArrayType(projSchema @ StructType(_), _, _)) => + ExtractNestedArrayField(projection, + projSchema.fieldIndex(field.name), + projSchema.fields.length, + projSchema(field.name), + containsNull, + containsNullSeq) + case (_, projSchema) => + throw new IllegalStateException( + s"unmatched child schema for ExtractNestedArrayField: ${projSchema.toString}" + ) + } case MapKeys(child) => getProjection(child).map { projection => MapKeys(projection) } case MapValues(child) => diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SelectedField.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SelectedField.scala index f2acb75ea6ac4..3e206efef339e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SelectedField.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SelectedField.scala @@ -65,7 +65,8 @@ object SelectedField { /** * Convert an expression into the parts of the schema (the field) it accesses. */ - private def selectField(expr: Expression, dataTypeOpt: Option[DataType]): Option[StructField] = { + private def selectField(expr: Expression, dataTypeOpt: Option[DataType], + nestArray: Boolean = false): Option[StructField] = { expr match { case a: Attribute => dataTypeOpt.map { dt => @@ -81,16 +82,37 @@ object SelectedField { // GetArrayStructFields is the top level extractor. This means its result is // not pruned and we need to use the element type of the array its producing. field.dataType - case Some(ArrayType(dataType, _)) => + case Some(ArrayType(dataType, nullable)) => // GetArrayStructFields is part of a chain of extractors and its result is pruned // by a parent expression. In this case need to use the parent element type. - dataType + if (nestArray) ArrayType(dataType, nullable) else dataType case Some(x) => // This should not happen. throw new AnalysisException(s"DataType '$x' is not supported by GetArrayStructFields.") } val newField = StructField(field.name, newFieldDataType, field.nullable) selectField(child, Option(ArrayType(struct(newField), containsNull))) + case ExtractNestedArrayField(child, _, _, field @ StructField(_, _, _, _), _, _) => + val newFieldDataType = dataTypeOpt match { + case None => + // ExtractNestedArrayField is the top level extractor. This means its result is + // not pruned and we need to use the element type of the array its producing. 
+ field.dataType + case Some(dataType) => + dataType + } + val structType = struct(StructField(field.name, newFieldDataType, field.nullable)) + + val newDataType = child match { + case ExtractNestedArrayField(_, _, _, childField, containsNull, _) => + childField.dataType match { + case _: ArrayType => ArrayType(structType, containsNull) + case _ => structType + } + case GetArrayStructFields(_, _, _, _, nullable) => ArrayType(structType, nullable) + case _ => structType + } + selectField(child, Some(newDataType), nestArray = true) case GetMapValue(child, _, _) => // GetMapValue does not select a field from a struct (i.e. prune the struct) so it can't be // the top-level extractor. However it can be part of an extractor chain. diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala index 767650d022200..6bdb1f05d2895 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala @@ -20,8 +20,9 @@ package org.apache.spark.sql.catalyst.expressions import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.analysis._ -import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodeGenerator, ExprCode} -import org.apache.spark.sql.catalyst.util.{quoteIdentifier, ArrayData, GenericArrayData, MapData, TypeUtils} +import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodeGenerator, + CodegenFallback, ExprCode} +import org.apache.spark.sql.catalyst.util._ import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types._ @@ -36,12 +37,13 @@ object ExtractValue { * Returns the resolved `ExtractValue`. It will return one kind of concrete `ExtractValue`, * depend on the type of `child` and `extraction`. 
* - * `child` | `extraction` | concrete `ExtractValue` - * ---------------------------------------------------------------- - * Struct | Literal String | GetStructField - * Array[Struct] | Literal String | GetArrayStructFields - * Array | Integral type | GetArrayItem - * Map | map key type | GetMapValue + * `child` | `extraction` | concrete `ExtractValue` + * -------------------------------------------------------------------------------- + * Struct | Literal String | GetStructField + * Array[Struct] | Literal String | GetArrayStructFields + * Array[ ...Array[struct] ] | Literal String | ExtractNestedArrayField + * Array | Integral type | GetArrayItem + * Map | map key type | GetMapValue */ def apply( child: Expression, @@ -60,6 +62,13 @@ object ExtractValue { GetArrayStructFields(child, fields(ordinal).copy(name = fieldName), ordinal, fields.length, containsNull || fields(ordinal).nullable) + case (ExtractNestedArrayType(StructType(fields), containsNull, containsNullSeq), + NonNullLiteral(v, StringType)) if containsNullSeq.nonEmpty => + val fieldName = v.toString + val ordinal = findField(fields, fieldName, resolver) + ExtractNestedArrayField(child, ordinal, fields.length, + fields(ordinal).copy(name = fieldName), containsNull, containsNullSeq) + case (_: ArrayType, _) => GetArrayItem(child, extraction) case (MapType(kt, _, _), _) => GetMapValue(child, extraction) @@ -218,6 +227,85 @@ case class GetArrayStructFields( } } +/** + * ExtractNestedArrayType is used to match consecutive nested array types. + * + * ReturnType: (DataType: the innermost dataType, Boolean: the outermost array contains null + * , Seq[Boolean]: the second outer layer to the innermost layer contains null) + * + */ +object ExtractNestedArrayType { + type ReturnType = Option[(DataType, Boolean, Seq[Boolean])] + + def unapply(dataType: DataType): ReturnType = { + dataType match { + case ArrayType(dt, containsNull) => + unapply(dt) match { + case Some((d, cn, seq)) => Some((d, containsNull, cn +: seq)) + case None => Some((dt, containsNull, Seq.empty[Boolean])) + } + case _ => None + } + } +} + +/** + * For a child whose data type is a nested array containing struct at the innermost level, extracts + * the `ordinal`-th fields of multi-level nested array, and returns them as a new nested array. 
+ */ +case class ExtractNestedArrayField( + child: Expression, + ordinal: Int, + numFields: Int, + field: StructField, + containsNull: Boolean, + containsNullSeq: Seq[Boolean]) extends UnaryExpression + with ExtractValue with NullIntolerant with CodegenFallback { + + protected override def nullSafeEval(input: Any): Any = { + val array = input.asInstanceOf[ArrayData] + new GenericArrayData( + (0 until array.numElements()).map(n => evalArrayItem(n, array, containsNullSeq.size))) + } + + private def evalArrayItem(original: Int, array: ArrayData, num: Int): ArrayData = { + if (array.isNullAt(original)) { + null + } + else { + val innerArray = array.get(original, nestedArrayType(num)).asInstanceOf[ArrayData] + new GenericArrayData((0 until innerArray.numElements()).map(n => { + if (num == 1) { + extractStruct(n, innerArray) + } + else { + evalArrayItem(n, innerArray, num - 1) + } + })) + } + } + + private def extractStruct(n: Int, array: ArrayData): Any = { + if (array.isNullAt(n)) { + null + } else { + val row = array.getStruct(n, numFields) + if (row.isNullAt(ordinal)) { + null + } else { + row.get(ordinal, field.dataType) + } + } + } + + override def dataType: DataType = ArrayType(nestedArrayType(0), containsNull) + + def nestedArrayType(num: Int): DataType = { + (num until containsNullSeq.size).reverse + .foldLeft(field.dataType) { (e, i) => ArrayType(e, containsNullSeq(i))} + } +} + /** * Returns the field at `ordinal` in the Array `child`. * diff --git a/sql/core/src/test/scala/org/apache/spark/sql/ComplexTypesSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/ComplexTypesSuite.scala index bdcf7230e3211..85440c6dadb9d 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/ComplexTypesSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/ComplexTypesSuite.scala @@ -117,4 +117,32 @@ class ComplexTypesSuite extends QueryTest with SharedSparkSession { val df = spark.createDataFrame(List(Row(Seq(Row(1), Row(null)))).asJava, schema) checkAnswer(df.select($"arr".getField("i")), Row(Seq(1, null))) } + + test("SPARK-32002: Support ExtractValue from nested ArrayStruct") { + val jsonStr1 = """{"a": [{"b": [{"c": [1,2]}]}]}""" + val jsonStr2 = """{"a": [{"b": [{"c": [1]}, {"c": [2]}]}]}""" + val df = spark.read.json(Seq(jsonStr1, jsonStr2).toDS()) + checkAnswer(df.select($"a.b.c"), Row(Seq(Seq(Seq(1, 2)))) + :: Row(Seq(Seq(Seq(1), Seq(2)))) :: Nil) + + def genJson(start: Char, end: Char, vStr: String): String = { + (start to end).map(c => s"""{"$c": [""").mkString + + vStr + (start to end).map(_ => "]}").mkString + } + + def genResult(start: Char, end: Char, r: Seq[Int]): Any = { + (start until end).fold(r) { (z, _) => Seq(z)} + } + + val start: Char = 'a' + for (i <- 2 to 10) { + val end: Char = (start + i).toChar + val json = genJson(start, end, "1,2,3") + val df = spark.read.json(Seq(json).toDS()) + checkAnswer(df.select((start to end).mkString(".")), + Row(genResult(start, end, Seq(1, 2, 3)))) + } + + } + } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/NestArraySchemaPruningSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/NestArraySchemaPruningSuite.scala new file mode 100644 index 0000000000000..99c5931b2c31b --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/NestArraySchemaPruningSuite.scala @@ -0,0 +1,129 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. 
See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.datasources
+
+import java.io.File
+
+import org.scalactic.Equality
+
+import org.apache.spark.sql.{DataFrame, QueryTest, Row}
+import org.apache.spark.sql.catalyst.SchemaPruningTest
+import org.apache.spark.sql.catalyst.parser.CatalystSqlParser
+import org.apache.spark.sql.execution.FileSourceScanExec
+import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanHelper
+import org.apache.spark.sql.internal.SQLConf
+import org.apache.spark.sql.test.SharedSparkSession
+import org.apache.spark.sql.types.StructType
+
+
+class NestArraySchemaPruningSuite
+  extends QueryTest
+  with FileBasedDataSourceTest
+  with SchemaPruningTest
+  with SharedSparkSession
+  with AdaptiveSparkPlanHelper {
+  case class AdRecord(positions: Array[Positions])
+  case class Positions(imps: Array[Impression])
+  case class Impression(id: String, ad: Advertising, clicks: Array[Clicks])
+  case class Advertising(index: Int)
+  case class Clicks(fraud_type: Int)
+
+  val adRecords = AdRecord(Array(Positions(Array(Impression("1", Advertising(1),
+    Array(Clicks(0), Clicks(1))))))) :: AdRecord(Array(Positions(Array(
+    Impression("2", Advertising(2), Array(Clicks(1), Clicks(2))))))) :: Nil
+
+  testSchemaPruning("Nested arrays for pruning schema") {
+    val queryIndex = sql("select positions.imps.ad.index from adRecords")
+    checkScan(queryIndex,
+      "struct<positions:array<struct<imps:array<struct<ad:struct<index:int>>>>>>")
+    checkAnswer(queryIndex, Row(Seq(Seq(1))) :: Row(Seq(Seq(2))) :: Nil)
+
+    val queryId = sql("select positions.imps.id from adRecords")
+    checkScan(queryId,
+      "struct<positions:array<struct<imps:array<struct<id:string>>>>>")
+    checkAnswer(queryId, Row(Seq(Seq("1"))) :: Row(Seq(Seq("2"))) :: Nil)
+
+    val queryIndexAndFraud =
+      sql("select positions.imps.ad.index, positions.imps.clicks.fraud_type from adRecords")
+    checkScan(queryIndexAndFraud, "struct<positions:array<struct<imps:array<struct<" +
+      "ad:struct<index:int>, clicks:array<struct<fraud_type:int>>>>>>>")
+    checkAnswer(queryIndexAndFraud, Row(Seq(Seq(1)), Seq(Seq(Seq(0, 1))))
+      :: Row(Seq(Seq(2)), Seq(Seq(Seq(1, 2)))) :: Nil)
+  }
+
+  protected def testSchemaPruning(testName: String)(testThunk: => Unit): Unit = {
+    test(s"$testName") {
+      withSQLConf(vectorizedReaderEnabledKey -> "true") {
+        withData(testThunk)
+      }
+      withSQLConf(vectorizedReaderEnabledKey -> "false") {
+        withData(testThunk)
+      }
+    }
+  }
+
+  private def withData(testThunk: => Unit): Unit = {
+    withTempPath { dir =>
+      val path = dir.getCanonicalPath
+
+      makeDataSourceFile(adRecords, new File(path + "/ad_records/a=1"))
+
+      val schema = "`positions` ARRAY<STRUCT<`imps`: ARRAY<STRUCT<`id`: STRING, " +
+        "`ad`: STRUCT<`index`: INT>, `clicks`: ARRAY<STRUCT<`fraud_type`: INT>>>>>>"
+      spark.read.format(dataSourceName).schema(schema).load(path + "/ad_records")
+        .createOrReplaceTempView("adRecords")
+
+      testThunk
+    }
+  }
+
+  protected val schemaEquality = new Equality[StructType] {
+    override def areEqual(a: StructType, b: Any): Boolean =
+      b match {
+        case otherType: StructType => a.sameType(otherType)
+        case _ => false
+      }
+  }
+
+  protected def checkScan(df: DataFrame,
expectedSchemaCatalogStrings: String*): Unit = { + checkScanSchemata(df, expectedSchemaCatalogStrings: _*) + // We check here that we can execute the query without throwing an exception. The results + // themselves are irrelevant, and should be checked elsewhere as needed + df.collect() + } + + protected def checkScanSchemata(df: DataFrame, expectedSchemaCatalogStrings: String*): Unit = { + val fileSourceScanSchemata = + collect(df.queryExecution.executedPlan) { + case scan: FileSourceScanExec => scan.requiredSchema + } + assert(fileSourceScanSchemata.size === expectedSchemaCatalogStrings.size, + s"Found ${fileSourceScanSchemata.size} file sources in dataframe, " + + s"but expected $expectedSchemaCatalogStrings") + fileSourceScanSchemata.zip(expectedSchemaCatalogStrings).foreach { + case (scanSchema, expectedScanSchemaCatalogString) => + val expectedScanSchema = CatalystSqlParser.parseDataType(expectedScanSchemaCatalogString) + implicit val equality = schemaEquality + assert(scanSchema === expectedScanSchema) + } + } + + override protected val dataSourceName: String = "parquet" + override protected val vectorizedReaderEnabledKey: String = + SQLConf.PARQUET_VECTORIZED_READER_ENABLED.key +}
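
For reviewers who want to try the change locally, here is a minimal standalone sketch of the behaviour the new ExtractNestedArrayField expression enables, mirroring the JSON case exercised in ComplexTypesSuite above. The object name and the local SparkSession setup are illustrative and not part of the patch.

import org.apache.spark.sql.SparkSession

object NestedArrayExtractExample {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder()
      .master("local[*]")
      .appName("nested-array-extract")
      .getOrCreate()
    import spark.implicits._

    // `a` is inferred as array<struct<b: array<struct<c: array<bigint>>>>>, so the
    // field `c` sits below two levels of arrays of structs.
    val df = spark.read.json(Seq("""{"a": [{"b": [{"c": [1, 2]}]}]}""").toDS())

    // Previously this dotted path was not supported by ExtractValue; with
    // ExtractNestedArrayField it resolves and keeps the array nesting of the
    // intermediate levels, returning a single row containing [[[1, 2]]].
    df.select($"a.b.c").show(false)

    spark.stop()
  }
}

The schema pruning test added in NestArraySchemaPruningSuite covers the complementary effect on the read path: only the leaf fields actually accessed through the nested arrays are required from the Parquet scan.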