diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFileFormat.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFileFormat.scala
index f04502d113ac..ebd96033c19f 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFileFormat.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFileFormat.scala
@@ -310,6 +310,9 @@ class ParquetFileFormat
     hadoopConf.set(
       SQLConf.SESSION_LOCAL_TIMEZONE.key,
       sparkSession.sessionState.conf.sessionLocalTimeZone)
+    hadoopConf.setBoolean(
+      SQLConf.NESTED_SCHEMA_PRUNING_ENABLED.key,
+      sparkSession.sessionState.conf.nestedSchemaPruningEnabled)
     hadoopConf.setBoolean(
       SQLConf.CASE_SENSITIVE.key,
       sparkSession.sessionState.conf.caseSensitiveAnalysis)
@@ -424,11 +427,12 @@ class ParquetFileFormat
       } else {
         logDebug(s"Falling back to parquet-mr")
         // ParquetRecordReader returns UnsafeRow
+        val readSupport = new ParquetReadSupport(convertTz, usingVectorizedReader = false)
         val reader = if (pushed.isDefined && enableRecordFilter) {
           val parquetFilter = FilterCompat.get(pushed.get, null)
-          new ParquetRecordReader[UnsafeRow](new ParquetReadSupport(convertTz), parquetFilter)
+          new ParquetRecordReader[UnsafeRow](readSupport, parquetFilter)
         } else {
-          new ParquetRecordReader[UnsafeRow](new ParquetReadSupport(convertTz))
+          new ParquetRecordReader[UnsafeRow](readSupport)
         }
         val iter = new RecordReaderIterator(reader)
         // SPARK-23457 Register a task completion lister before `initialization`.
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetReadSupport.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetReadSupport.scala
index 3319e73f2b31..a5ff94298180 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetReadSupport.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetReadSupport.scala
@@ -49,7 +49,8 @@ import org.apache.spark.sql.types._
  * Due to this reason, we no longer rely on [[ReadContext]] to pass requested schema from [[init()]]
  * to [[prepareForRead()]], but use a private `var` for simplicity.
  */
-private[parquet] class ParquetReadSupport(val convertTz: Option[TimeZone])
+private[parquet] class ParquetReadSupport(val convertTz: Option[TimeZone],
+    usingVectorizedReader: Boolean)
   extends ReadSupport[UnsafeRow] with Logging {
   private var catalystRequestedSchema: StructType = _
 
@@ -57,7 +58,7 @@ private[parquet] class ParquetReadSupport(val convertTz: Option[TimeZone])
     // We need a zero-arg constructor for SpecificParquetRecordReaderBase. But that is only
     // used in the vectorized reader, where we get the convertTz value directly, and the value here
     // is ignored.
-    this(None)
+    this(None, usingVectorizedReader = true)
   }
 
   /**
@@ -65,18 +66,65 @@ private[parquet] class ParquetReadSupport(val convertTz: Option[TimeZone])
    * readers. Responsible for figuring out Parquet requested schema used for column pruning.
   */
  override def init(context: InitContext): ReadContext = {
+    val conf = context.getConfiguration
     catalystRequestedSchema = {
-      val conf = context.getConfiguration
       val schemaString = conf.get(ParquetReadSupport.SPARK_ROW_REQUESTED_SCHEMA)
       assert(schemaString != null, "Parquet requested schema not set.")
       StructType.fromString(schemaString)
     }
 
-    val caseSensitive = context.getConfiguration.getBoolean(SQLConf.CASE_SENSITIVE.key,
+    val schemaPruningEnabled = conf.getBoolean(SQLConf.NESTED_SCHEMA_PRUNING_ENABLED.key,
+      SQLConf.NESTED_SCHEMA_PRUNING_ENABLED.defaultValue.get)
+    val caseSensitive = conf.getBoolean(SQLConf.CASE_SENSITIVE.key,
       SQLConf.CASE_SENSITIVE.defaultValue.get)
-    val parquetRequestedSchema = ParquetReadSupport.clipParquetSchema(
-      context.getFileSchema, catalystRequestedSchema, caseSensitive)
-
+    val parquetFileSchema = context.getFileSchema
+    val parquetClippedSchema = ParquetReadSupport.clipParquetSchema(parquetFileSchema,
+      catalystRequestedSchema, caseSensitive)
+
+    // As part of schema clipping, we add to parquetClippedSchema those fields of
+    // catalystRequestedSchema that are missing from parquetFileSchema. However, nested schema
+    // pruning requires that we ignore unrequested field data when reading from a Parquet file.
+    // Therefore we pass two schemas to ParquetRecordMaterializer: the schema of the file data we
+    // want to read (parquetRequestedSchema), and the schema of the rows we want to return
+    // (catalystRequestedSchema). The reader is responsible for reconciling the differences
+    // between the two.
+    //
+    // Aside from checking whether schema pruning is enabled (schemaPruningEnabled), there
+    // is an additional complication to constructing parquetRequestedSchema. The manner in which
+    // Spark's two Parquet readers reconcile the differences between parquetRequestedSchema and
+    // catalystRequestedSchema differs. Spark's vectorized reader does not (currently) support
+    // reading Parquet files with complex types in their schema. Further, it assumes that
+    // parquetRequestedSchema includes all fields requested in catalystRequestedSchema. It includes
+    // logic in its read path to skip fields in parquetRequestedSchema which are not present in the
+    // file.
+    //
+    // Spark's parquet-mr based reader supports reading Parquet files with any kind of complex
+    // schema, and it supports nested schema pruning as well. Unlike the vectorized reader, the
+    // parquet-mr reader requires that parquetRequestedSchema include only those fields present in
+    // the underlying parquetFileSchema. Therefore, in the case where we use the parquet-mr reader,
+    // we intersect the parquetClippedSchema with the parquetFileSchema to construct the
+    // parquetRequestedSchema set in the ReadContext.
+    val parquetRequestedSchema =
+      if (schemaPruningEnabled && !usingVectorizedReader) {
+        ParquetReadSupport.intersectParquetGroups(parquetClippedSchema, parquetFileSchema)
+          .map(intersectionGroup =>
+            new MessageType(intersectionGroup.getName, intersectionGroup.getFields))
+          .getOrElse(ParquetSchemaConverter.EMPTY_MESSAGE)
+      } else {
+        parquetClippedSchema
+      }
+    log.debug {
+      s"""Going to read the following fields from the Parquet file with the following schema:
+         |Parquet file schema:
+         |$parquetFileSchema
+         |Parquet clipped schema:
+         |$parquetClippedSchema
+         |Parquet requested schema:
+         |$parquetRequestedSchema
+         |Catalyst requested schema:
+         |${catalystRequestedSchema.treeString}
+       """.stripMargin
+    }
     new ReadContext(parquetRequestedSchema, Map.empty[String, String].asJava)
   }
 
@@ -90,16 +138,15 @@ private[parquet] class ParquetReadSupport(val convertTz: Option[TimeZone])
       keyValueMetaData: JMap[String, String],
       fileSchema: MessageType,
       readContext: ReadContext): RecordMaterializer[UnsafeRow] = {
-    log.debug(s"Preparing for read Parquet file with message type: $fileSchema")
     val parquetRequestedSchema = readContext.getRequestedSchema
-
-    logInfo {
-      s"""Going to read the following fields from the Parquet file:
-         |
-         |Parquet form:
+    log.debug {
+      s"""Going to read the following fields from the Parquet file with the following schema:
+         |Parquet file schema:
+         |$fileSchema
+         |Parquet read schema:
          |$parquetRequestedSchema
-         |Catalyst form:
-         |$catalystRequestedSchema
+         |Catalyst read schema:
+         |${catalystRequestedSchema.treeString}
        """.stripMargin
     }
 
@@ -322,6 +369,27 @@ private[parquet] object ParquetReadSupport {
     }
   }
 
+  /**
+   * Computes the structural intersection between two Parquet group types.
+   */
+  private def intersectParquetGroups(
+      groupType1: GroupType, groupType2: GroupType): Option[GroupType] = {
+    val fields =
+      groupType1.getFields.asScala
+        .filter(field => groupType2.containsField(field.getName))
+        .flatMap {
+          case field1: GroupType =>
+            intersectParquetGroups(field1, groupType2.getType(field1.getName).asGroupType)
+          case field1 => Some(field1)
+        }
+
+    if (fields.nonEmpty) {
+      Some(groupType1.withNewFields(fields.asJava))
+    } else {
+      None
+    }
+  }
+
   def expandUDT(schema: StructType): StructType = {
     def expand(dataType: DataType): DataType = {
       dataType match {
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetRowConverter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetRowConverter.scala
index 119972594184..2405dd0aba1d 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetRowConverter.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetRowConverter.scala
@@ -130,8 +130,8 @@ private[parquet] class ParquetRowConverter(
   extends ParquetGroupConverter(updater) with Logging {
 
   assert(
-    parquetType.getFieldCount == catalystType.length,
-    s"""Field counts of the Parquet schema and the Catalyst schema don't match:
+    parquetType.getFieldCount <= catalystType.length,
+    s"""Field count of the Parquet schema is greater than the field count of the Catalyst schema:
        |
        |Parquet schema:
        |$parquetType
@@ -182,10 +182,11 @@ private[parquet] class ParquetRowConverter(
 
   // Converters for each field.
   private val fieldConverters: Array[Converter with HasParentContainerUpdater] = {
-    parquetType.getFields.asScala.zip(catalystType).zipWithIndex.map {
-      case ((parquetFieldType, catalystField), ordinal) =>
-        // Converted field value should be set to the `ordinal`-th cell of `currentRow`
-        newConverter(parquetFieldType, catalystField.dataType, new RowUpdater(currentRow, ordinal))
+    parquetType.getFields.asScala.map { parquetField =>
+      val fieldIndex = catalystType.fieldIndex(parquetField.getName)
+      val catalystField = catalystType(fieldIndex)
+      // Converted field value should be set to the `fieldIndex`-th cell of `currentRow`
+      newConverter(parquetField, catalystField.dataType, new RowUpdater(currentRow, fieldIndex))
     }.toArray
   }
 
@@ -193,7 +194,7 @@ private[parquet] class ParquetRowConverter(
 
   override def end(): Unit = {
     var i = 0
-    while (i < currentRow.numFields) {
+    while (i < fieldConverters.length) {
       fieldConverters(i).updater.end()
       i += 1
     }
@@ -202,8 +203,12 @@ private[parquet] class ParquetRowConverter(
 
   override def start(): Unit = {
     var i = 0
-    while (i < currentRow.numFields) {
+    while (i < fieldConverters.length) {
       fieldConverters(i).updater.start()
+      i += 1
+    }
+    i = 0
+    while (i < currentRow.numFields) {
       currentRow.setNullAt(i)
       i += 1
     }
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetSchemaPruningSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetSchemaPruningSuite.scala
index 434c4414edeb..12f8f02f65af 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetSchemaPruningSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetSchemaPruningSuite.scala
@@ -130,7 +130,7 @@ class ParquetSchemaPruningSuite
       Row("X.", 1) :: Row("Y.", 1) :: Row(null, 2) :: Row(null, 2) :: Nil)
   }
 
-  ignore("partial schema intersection - select missing subfield") {
+  testSchemaPruning("partial schema intersection - select missing subfield") {
     val query = sql("select name.middle, address from contacts where p=2")
     checkScan(query, "struct<name:struct<middle:string>,address:string>")
     checkAnswer(query.orderBy("id"),
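
Illustrative sketch (not part of the patch above): the central piece of this change is the structural intersection that the new private[parquet] ParquetReadSupport.intersectParquetGroups helper performs, so that the parquet-mr reader only requests columns that actually exist in the file. The snippet below re-implements the same recursive logic against parquet-column's schema API and runs it on the scenario covered by the re-enabled "partial schema intersection" test, where the file is missing name.middle. The object name IntersectDemo, the intersect method, and the example schemas are invented for illustration; only MessageTypeParser, GroupType and MessageType from org.apache.parquet.schema are assumed to be on the classpath.

import scala.collection.JavaConverters._

import org.apache.parquet.schema.{GroupType, MessageType, MessageTypeParser}

object IntersectDemo {
  // Same recursive rule as the patch: keep only the fields of `left` that also exist in `right`,
  // recursing into nested groups; None means the two schemas share no fields at all.
  def intersect(left: GroupType, right: GroupType): Option[GroupType] = {
    val fields = left.getFields.asScala
      .filter(field => right.containsField(field.getName))
      .flatMap {
        case group: GroupType => intersect(group, right.getType(group.getName).asGroupType)
        case primitive => Some(primitive)
      }
    if (fields.nonEmpty) Some(left.withNewFields(fields.asJava)) else None
  }

  def main(args: Array[String]): Unit = {
    // Schema Spark asks for after clipping: name.first, name.middle and address.
    val clippedSchema = MessageTypeParser.parseMessageType(
      """message spark_schema {
        |  optional group name {
        |    optional binary first (UTF8);
        |    optional binary middle (UTF8);
        |  }
        |  optional binary address (UTF8);
        |}""".stripMargin)
    // Schema actually present in the Parquet file: name.middle does not exist.
    val fileSchema = MessageTypeParser.parseMessageType(
      """message spark_schema {
        |  optional group name {
        |    optional binary first (UTF8);
        |  }
        |  optional binary address (UTF8);
        |}""".stripMargin)
    // The intersection keeps only name.first and address, mirroring the schema the parquet-mr
    // reader is asked to read; the requested-but-missing name.middle simply remains null in the
    // result rows thanks to the by-name field lookup added to ParquetRowConverter.
    val requestedSchema = intersect(clippedSchema, fileSchema)
      .map(group => new MessageType(group.getName, group.getFields))
      .getOrElse(sys.error("no overlapping fields"))
    println(requestedSchema)
  }
}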