From 73aa26baabd79e845dc6d52f5862c5177885f99f Mon Sep 17 00:00:00 2001 From: Cheng Su Date: Wed, 13 Oct 2021 00:31:29 -0700 Subject: [PATCH 1/7] Aggregate push down for ORC reader --- .../apache/spark/sql/internal/SQLConf.scala | 9 + .../apache/spark/sql/types/StructType.scala | 2 +- .../datasources/orc/OrcColumnsStatistics.java | 56 ++++++ .../datasources/orc/OrcFooterReader.java | 67 +++++++ .../datasources/AggregatePushDownUtils.scala | 140 +++++++++++++++ .../datasources/orc/OrcDeserializer.scala | 16 ++ .../execution/datasources/orc/OrcUtils.scala | 132 +++++++++++++- .../datasources/parquet/ParquetUtils.scala | 41 ----- .../v2/orc/OrcPartitionReaderFactory.scala | 94 ++++++++-- .../datasources/v2/orc/OrcScan.scala | 46 ++++- .../datasources/v2/orc/OrcScanBuilder.scala | 57 +++++- .../ParquetPartitionReaderFactory.scala | 14 +- .../datasources/v2/parquet/ParquetScan.scala | 12 +- .../v2/parquet/ParquetScanBuilder.scala | 85 +++------ .../org/apache/spark/sql/FileScanSuite.scala | 2 +- ...=> FileSourceAggregatePushDownSuite.scala} | 165 ++++++++++-------- 16 files changed, 710 insertions(+), 228 deletions(-) create mode 100644 sql/core/src/main/java/org/apache/spark/sql/execution/datasources/orc/OrcColumnsStatistics.java create mode 100644 sql/core/src/main/java/org/apache/spark/sql/execution/datasources/orc/OrcFooterReader.java create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/AggregatePushDownUtils.scala rename sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/{parquet/ParquetAggregatePushDownSuite.scala => FileSourceAggregatePushDownSuite.scala} (80%) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index 5023b4a616555..b992f5547cd91 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -960,6 +960,13 @@ object SQLConf { .booleanConf .createWithDefault(true) + val ORC_AGGREGATE_PUSHDOWN_ENABLED = buildConf("spark.sql.orc.aggregatePushdown") + .doc("If true, MAX/MIN/COUNT without filter and group by will be pushed" + + " down to ORC for optimization. 
MAX/MIN for complex types can't be pushed down") + .version("3.3.0") + .booleanConf + .createWithDefault(false) + val ORC_SCHEMA_MERGING_ENABLED = buildConf("spark.sql.orc.mergeSchema") .doc("When true, the Orc data source merges schemas collected from all data files, " + "otherwise the schema is picked from a random data file.") @@ -3698,6 +3705,8 @@ class SQLConf extends Serializable with Logging { def orcFilterPushDown: Boolean = getConf(ORC_FILTER_PUSHDOWN_ENABLED) + def orcAggregatePushDown: Boolean = getConf(ORC_AGGREGATE_PUSHDOWN_ENABLED) + def isOrcSchemaMergingEnabled: Boolean = getConf(ORC_SCHEMA_MERGING_ENABLED) def verifyPartitionPath: Boolean = getConf(HIVE_VERIFY_PARTITION_PATH) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructType.scala index 205b08f680aee..6707fb2071d00 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructType.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructType.scala @@ -115,7 +115,7 @@ case class StructType(fields: Array[StructField]) extends DataType with Seq[Stru def names: Array[String] = fieldNames private lazy val fieldNamesSet: Set[String] = fieldNames.toSet - private[sql] lazy val nameToField: Map[String, StructField] = fields.map(f => f.name -> f).toMap + private lazy val nameToField: Map[String, StructField] = fields.map(f => f.name -> f).toMap private lazy val nameToIndex: Map[String, Int] = fieldNames.zipWithIndex.toMap override def equals(that: Any): Boolean = { diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/orc/OrcColumnsStatistics.java b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/orc/OrcColumnsStatistics.java new file mode 100644 index 0000000000000..8ce7e013f8228 --- /dev/null +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/orc/OrcColumnsStatistics.java @@ -0,0 +1,56 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.datasources.orc; + +import org.apache.orc.ColumnStatistics; + +import java.util.ArrayList; +import java.util.List; + +/** + * Columns statistics interface wrapping ORC {@link ColumnStatistics}s. + * + * Because ORC {@link ColumnStatistics}s are stored as an flatten array in ORC file footer, + * this class is used to covert ORC {@link ColumnStatistics}s from array to nested tree structure, + * according to data types. This is used for aggregate push down in ORC. 
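For illustration of the tree shape this class exposes (the schema below is a hypothetical example, not taken from the patch): ORC keeps one flat statistics entry per nested column in pre-order, and the wrapper lets callers walk that layout by ordinal.

  // a minimal sketch, assuming `root` is the OrcColumnsStatistics tree built for a
  // file whose schema is struct<a:int, m:map<string,int>>
  val statsForA   = root.get(0).getStatistics      // column `a`
  val mapNode     = root.get(1)                    // column `m`
  val statsForKey = mapNode.get(0).getStatistics   // statistics of the map keys
  val statsForVal = mapNode.get(1).getStatistics   // statistics of the map values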
+ */ +public class OrcColumnsStatistics { + private final ColumnStatistics statistics; + private final List children; + + public OrcColumnsStatistics(ColumnStatistics statistics) { + this.statistics = statistics; + this.children = new ArrayList<>(); + } + + public ColumnStatistics getStatistics() { + return statistics; + } + + public OrcColumnsStatistics get(int ordinal) { + if (ordinal < 0 || ordinal >= children.size()) { + throw new IndexOutOfBoundsException( + String.format("Ordinal %d out of bounds of statistics size %d", ordinal, children.size())); + } + return children.get(ordinal); + } + + public void add(OrcColumnsStatistics newChild) { + children.add(newChild); + } +} diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/orc/OrcFooterReader.java b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/orc/OrcFooterReader.java new file mode 100644 index 0000000000000..14773500f075d --- /dev/null +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/orc/OrcFooterReader.java @@ -0,0 +1,67 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.datasources.orc; + +import org.apache.orc.ColumnStatistics; +import org.apache.orc.Reader; +import org.apache.orc.TypeDescription; +import org.apache.spark.sql.types.*; + +import java.util.Arrays; +import java.util.LinkedList; +import java.util.Queue; + +/** + * `OrcFooterReader` is a util class which encapsulates the helper + * methods of reading ORC file footer. + */ +public class OrcFooterReader { + + /** + * Read the columns statistics from ORC file footer. + * + * @param orcReader the reader to read ORC file footer. + * @return Statistics for all columns in the file. + */ + public static OrcColumnsStatistics readStatistics(Reader orcReader) { + TypeDescription orcSchema = orcReader.getSchema(); + ColumnStatistics[] orcStatistics = orcReader.getStatistics(); + StructType dataType = OrcUtils.toCatalystSchema(orcSchema); + return convertStatistics(dataType, new LinkedList<>(Arrays.asList(orcStatistics))); + } + + /** + * Convert a queue of ORC {@link ColumnStatistics}s into Spark {@link OrcColumnsStatistics}. + * The queue of ORC {@link ColumnStatistics}s are assumed to be ordered as tree pre-order. 
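To make the pre-order assumption concrete, here is a small self-contained Scala sketch of the same reconstruction pattern, independent of the ORC classes (all names below are made up): each recursive call takes its own entry from the head of the queue, then consumes the entries of its children in order.

  import scala.collection.mutable

  // simplified stand-ins for a schema node and a statistics tree node
  sealed trait Shape
  case class Leaf() extends Shape
  case class Branch(children: Seq[Shape]) extends Shape
  case class StatsNode(stat: Int, children: mutable.Buffer[StatsNode] = mutable.Buffer.empty)

  // rebuild the statistics tree from a pre-order queue, mirroring convertStatistics above
  def rebuild(shape: Shape, queue: mutable.Queue[Int]): StatsNode = {
    val node = StatsNode(queue.dequeue())   // this node's own entry comes first
    shape match {
      case Branch(children) => children.foreach(c => node.children += rebuild(c, queue))
      case Leaf() =>          // leaf columns consume nothing further
    }
    node
  }

  // struct<a, m:map<key,value>> flattens pre-order to: root, a, m, key, value
  val tree = rebuild(Branch(Seq(Leaf(), Branch(Seq(Leaf(), Leaf())))), mutable.Queue(0, 1, 2, 3, 4))
  // tree.children.map(_.stat) == Seq(1, 2) and tree.children(1).children.map(_.stat) == Seq(3, 4)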
+ */ + private static OrcColumnsStatistics convertStatistics( + DataType dataType, Queue orcStatistics) { + OrcColumnsStatistics statistics = new OrcColumnsStatistics(orcStatistics.remove()); + if (dataType instanceof StructType) { + for (StructField field : ((StructType) dataType).fields()) { + statistics.add(convertStatistics(field.dataType(), orcStatistics)); + } + } else if (dataType instanceof MapType) { + statistics.add(convertStatistics(((MapType) dataType).keyType(), orcStatistics)); + statistics.add(convertStatistics(((MapType) dataType).valueType(), orcStatistics)); + } else if (dataType instanceof ArrayType) { + statistics.add(convertStatistics(((ArrayType) dataType).elementType(), orcStatistics)); + } + return statistics; + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/AggregatePushDownUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/AggregatePushDownUtils.scala new file mode 100644 index 0000000000000..4d3fcf0d8f0ff --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/AggregatePushDownUtils.scala @@ -0,0 +1,140 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.datasources + +import org.apache.spark.sql.SparkSession +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions.Expression +import org.apache.spark.sql.connector.expressions.NamedReference +import org.apache.spark.sql.connector.expressions.aggregate.{AggregateFunc, Aggregation, Count, CountStar, Max, Min} +import org.apache.spark.sql.execution.RowToColumnConverter +import org.apache.spark.sql.execution.vectorized.{OffHeapColumnVector, OnHeapColumnVector} +import org.apache.spark.sql.types.{DataType, LongType, StructField, StructType} +import org.apache.spark.sql.vectorized.{ColumnarBatch, ColumnVector} + +/** + * Utility class for aggregate push down to Parquet and ORC. + */ +object AggregatePushDownUtils { + + /** + * Get the data schema for aggregate to be pushed down. 
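As a concrete example of the contract (table and column names are illustrative): for a table t(c1 INT, c2 STRING) and the query SELECT MAX(c1), COUNT(c2), COUNT(*) FROM t with no data filters and no grouping, the schema handed back for the pushed aggregation is expected to look like:

  import org.apache.spark.sql.types.{IntegerType, LongType, StructType}

  // one field per pushed aggregate: MIN/MAX keep the source column's type,
  // COUNT columns are always LongType
  val expected = new StructType()
    .add("max(c1)", IntegerType)
    .add("count(c2)", LongType)
    .add("count(*)", LongType)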
+ */ + def getSchemaForPushedAggregation( + aggregation: Aggregation, + schema: StructType, + partitionNameSet: Set[String], + dataFilters: Seq[Expression], + isAllowedTypeForMinMaxAggregate: DataType => Boolean, + sparkSession: SparkSession): Option[StructType] = { + + var finalSchema = new StructType() + + def getStructFieldForCol(col: NamedReference): StructField = { + schema.apply(col.fieldNames.head) + } + + def isPartitionCol(col: NamedReference) = { + partitionNameSet.contains(col.fieldNames.head) + } + + def processMinOrMax(agg: AggregateFunc): Boolean = { + val (column, aggType) = agg match { + case max: Max => (max.column, "max") + case min: Min => (min.column, "min") + case _ => + throw new IllegalArgumentException(s"Unexpected type of AggregateFunc ${agg.describe}") + } + + if (isPartitionCol(column)) { + // don't push down partition column, footer doesn't have max/min for partition column + return false + } + val structField = getStructFieldForCol(column) + + if (isAllowedTypeForMinMaxAggregate(structField.dataType)) { + finalSchema = finalSchema.add(structField.copy(s"$aggType(" + structField.name + ")")) + true + } else { + false + } + } + + if (aggregation.groupByColumns.nonEmpty || dataFilters.nonEmpty) { + // Parquet/ORC footer has max/min/count for columns + // e.g. SELECT COUNT(col1) FROM t + // but footer doesn't have max/min/count for a column if max/min/count + // are combined with filter or group by + // e.g. SELECT COUNT(col1) FROM t WHERE col2 = 8 + // SELECT COUNT(col1) FROM t GROUP BY col2 + // However, if the filter is on partition column, max/min/count can still be pushed down + // Todo: add support if groupby column is partition col + // (https://issues.apache.org/jira/browse/SPARK-36646) + return None + } + + aggregation.groupByColumns.foreach { col => + if (col.fieldNames.length != 1) return None + finalSchema = finalSchema.add(getStructFieldForCol(col)) + } + + aggregation.aggregateExpressions.foreach { + case max: Max => + if (!processMinOrMax(max)) return None + case min: Min => + if (!processMinOrMax(min)) return None + case count: Count => + if (count.column.fieldNames.length != 1 || count.isDistinct) return None + finalSchema = + finalSchema.add(StructField(s"count(" + count.column.fieldNames.head + ")", LongType)) + case _: CountStar => + finalSchema = finalSchema.add(StructField("count(*)", LongType)) + case _ => + return None + } + + Some(finalSchema) + } + + /** + * Check if two Aggregation `a` and `b` is equal or not. + */ + def equivalentAggregations(a: Aggregation, b: Aggregation): Boolean = { + a.aggregateExpressions.sortBy(_.hashCode()) + .sameElements(b.aggregateExpressions.sortBy(_.hashCode())) && + a.groupByColumns.sortBy(_.hashCode()).sameElements(b.groupByColumns.sortBy(_.hashCode())) + } + + /** + * Convert the aggregates result from `InternalRow` to `ColumnarBatch`. + * This is used for columnar reader. 
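A usage sketch for the columnar read path (the literal value is made up): a single InternalRow of aggregate results becomes a one-row ColumnarBatch.

  import org.apache.spark.sql.catalyst.InternalRow
  import org.apache.spark.sql.execution.datasources.AggregatePushDownUtils
  import org.apache.spark.sql.types.{LongType, StructType}

  // assume one pushed-down count(*) that evaluated to 42 from the footer
  val aggSchema = new StructType().add("count(*)", LongType)
  val batch = AggregatePushDownUtils.convertAggregatesRowToBatch(
    InternalRow(42L), aggSchema, offHeap = false)
  assert(batch.numRows() == 1 && batch.column(0).getLong(0) == 42L)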
+ */ + def convertAggregatesRowToBatch( + aggregatesAsRow: InternalRow, + aggregatesSchema: StructType, + offHeap: Boolean): ColumnarBatch = { + val converter = new RowToColumnConverter(aggregatesSchema) + val columnVectors = if (offHeap) { + OffHeapColumnVector.allocateColumns(1, aggregatesSchema) + } else { + OnHeapColumnVector.allocateColumns(1, aggregatesSchema) + } + converter.convert(aggregatesAsRow, columnVectors.toArray) + new ColumnarBatch(columnVectors.asInstanceOf[Array[ColumnVector]], 1) + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcDeserializer.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcDeserializer.scala index 1476083bc3d49..91408332b8624 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcDeserializer.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcDeserializer.scala @@ -68,6 +68,22 @@ class OrcDeserializer( resultRow } + def deserializeFromValues(orcValues: Seq[WritableComparable[_]]): InternalRow = { + var targetColumnIndex = 0 + while (targetColumnIndex < fieldWriters.length) { + if (fieldWriters(targetColumnIndex) != null) { + val value = orcValues(requestedColIds(targetColumnIndex)) + if (value == null) { + resultRow.setNullAt(targetColumnIndex) + } else { + fieldWriters(targetColumnIndex)(value) + } + } + targetColumnIndex += 1 + } + resultRow + } + /** * Creates a writer to write ORC values to Catalyst data structure at the given ordinal. */ diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcUtils.scala index 475448a9cc0e2..9ced33898736b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcUtils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcUtils.scala @@ -25,17 +25,21 @@ import scala.collection.mutable.ArrayBuffer import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs.{FileStatus, Path} -import org.apache.orc.{OrcConf, OrcFile, Reader, TypeDescription, Writer} +import org.apache.hadoop.hive.serde2.io.{DateWritable, HiveDecimalWritable} +import org.apache.hadoop.io.{BooleanWritable, ByteWritable, DoubleWritable, FloatWritable, IntWritable, LongWritable, ShortWritable, Text, WritableComparable} +import org.apache.orc.{BooleanColumnStatistics, ColumnStatistics, DateColumnStatistics, DecimalColumnStatistics, DoubleColumnStatistics, IntegerColumnStatistics, OrcConf, OrcFile, Reader, StringColumnStatistics, TypeDescription, Writer} -import org.apache.spark.SPARK_VERSION_SHORT +import org.apache.spark.{SPARK_VERSION_SHORT, SparkException} import org.apache.spark.deploy.SparkHadoopUtil import org.apache.spark.internal.Logging import org.apache.spark.sql.{SPARK_VERSION_METADATA_KEY, SparkSession} +import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.analysis.caseSensitiveResolution import org.apache.spark.sql.catalyst.parser.CatalystSqlParser import org.apache.spark.sql.catalyst.util.{quoteIdentifier, CharVarcharUtils} +import org.apache.spark.sql.connector.expressions.aggregate.{Aggregation, Count, CountStar, Max, Min} import org.apache.spark.sql.errors.QueryExecutionErrors -import org.apache.spark.sql.execution.datasources.SchemaMergeUtils +import org.apache.spark.sql.execution.datasources.{PartitioningUtils, SchemaMergeUtils} import org.apache.spark.sql.types._ import 
org.apache.spark.util.{ThreadUtils, Utils} @@ -87,7 +91,7 @@ object OrcUtils extends Logging { } } - private def toCatalystSchema(schema: TypeDescription): StructType = { + def toCatalystSchema(schema: TypeDescription): StructType = { import TypeDescription.Category def toCatalystType(orcType: TypeDescription): DataType = { @@ -377,4 +381,124 @@ object OrcUtils extends Logging { case _ => false } } + + /** + * When the partial aggregates (Max/Min/Count) are pushed down to ORC, we don't need to read data + * from ORC and aggregate at Spark layer. Instead we want to get the partial aggregates + * (Max/Min/Count) result using the statistics information from ORC file footer, and then + * construct an InternalRow from these aggregate results. + * + * @return Aggregate results in the format of InternalRow + */ + def createAggInternalRowFromFooter( + reader: Reader, + filePath: String, + dataSchema: StructType, + partitionSchema: StructType, + aggregation: Aggregation, + aggSchema: StructType, + isCaseSensitive: Boolean): InternalRow = { + require(aggregation.groupByColumns.length == 0, + s"aggregate $aggregation with group-by column shouldn't be pushed down") + val columnsStatistics = OrcFooterReader.readStatistics(reader) + + // Get column statistics with column name. + def getColumnStatistics(columnName: String): ColumnStatistics = { + val columnIndex = dataSchema.fieldNames.indexOf(columnName) + columnsStatistics.get(columnIndex).getStatistics + } + + // Get Min/Max statistics and store as ORC `WritableComparable` format. + def getMinMaxFromColumnStatistics( + statistics: ColumnStatistics, + dataType: DataType, + isMax: Boolean): WritableComparable[_] = { + statistics match { + case s: BooleanColumnStatistics => + val value = if (isMax) s.getTrueCount > 0 else !(s.getFalseCount > 0) + new BooleanWritable(value) + case s: IntegerColumnStatistics => + val value = if (isMax) s.getMaximum else s.getMinimum + dataType match { + case ByteType => new ByteWritable(value.toByte) + case ShortType => new ShortWritable(value.toShort) + case IntegerType => new IntWritable(value.toInt) + case LongType => new LongWritable(value) + case _ => throw new IllegalArgumentException( + s"getMaxFromColumnStatistics should not take type $dataType" + + "for IntegerColumnStatistics") + } + case s: DoubleColumnStatistics => + val value = if (isMax) s.getMaximum else s.getMinimum + dataType match { + case FloatType => new FloatWritable(value.toFloat) + case DoubleType => new DoubleWritable(value) + case _ => throw new IllegalArgumentException( + s"getMaxFromColumnStatistics should not take type $dataType" + + "for DoubleColumnStatistics") + } + case s: DecimalColumnStatistics => + new HiveDecimalWritable(if (isMax) s.getMaximum else s.getMinimum) + case s: StringColumnStatistics => + new Text(if (isMax) s.getMaximum else s.getMinimum) + case s: DateColumnStatistics => + new DateWritable( + if (isMax) s.getMaximumDayOfEpoch.toInt else s.getMinimumDayOfEpoch.toInt) + case _ => throw new IllegalArgumentException( + s"getMaxFromColumnStatistics should not take ${statistics.getClass.getName}: " + + s"$statistics as the ORC column statistics") + } + } + + val aggORCValues: Seq[WritableComparable[_]] = + aggregation.aggregateExpressions.zipWithIndex.map { + case (max: Max, index) => + val columnName = max.column.fieldNames.head + val statistics = getColumnStatistics(columnName) + val dataType = aggSchema(index).dataType + val value = getMinMaxFromColumnStatistics(statistics, dataType, isMax = true) + value + case (min: Min, 
index) => + val columnName = min.column.fieldNames.head + val statistics = getColumnStatistics(columnName) + val dataType = aggSchema.apply(index).dataType + val value = getMinMaxFromColumnStatistics(statistics, dataType, isMax = false) + value + case (count: Count, _) => + val columnName = count.column.fieldNames.head + val isPartitionColumn = partitionSchema.fields + .map(PartitioningUtils.getColName(_, isCaseSensitive)) + .contains(columnName) + // NOTE: Count(columnName) doesn't include null values. + // org.apache.orc.ColumnStatistics.getNumberOfValues() returns number of non-null values + // for ColumnStatistics of individual column. In addition to this, ORC also returns + // number of non-null and null values for its top-level + // ColumnStatistics.getNumberOfValues(). + val nonNullRowsCount = if (isPartitionColumn) { + val topLevelStatistics = columnsStatistics.getStatistics + if (topLevelStatistics.hasNull) { + throw new SparkException(s"Illegal ORC top-level column statistics with NULL " + + s"values: $topLevelStatistics. Aggregate expression: $count") + } + topLevelStatistics.getNumberOfValues + } else { + getColumnStatistics(columnName).getNumberOfValues + } + new LongWritable(nonNullRowsCount) + case (_: CountStar, _) => + // Count(*) includes both null and non-null values. + val topLevelStatistics = columnsStatistics.getStatistics + if (topLevelStatistics.hasNull) { + throw new SparkException(s"Illegal ORC top-level column statistics with NULL " + + s"values: $topLevelStatistics. Aggregate expression: Count(*)") + } + new LongWritable(topLevelStatistics.getNumberOfValues) + case (x, _) => + throw new IllegalArgumentException( + s"createAggInternalRowFromFooter should not take $x as the aggregate expression") + } + + val orcValuesDeserializer = new OrcDeserializer(aggSchema, (0 until aggSchema.length).toArray) + orcValuesDeserializer.deserializeFromValues(aggORCValues) + } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetUtils.scala index 1093f9c5aa51b..0e4b9283d4866 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetUtils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetUtils.scala @@ -32,12 +32,9 @@ import org.apache.spark.SparkException import org.apache.spark.sql.SparkSession import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.connector.expressions.aggregate.{Aggregation, Count, CountStar, Max, Min} -import org.apache.spark.sql.execution.RowToColumnConverter import org.apache.spark.sql.execution.datasources.PartitioningUtils -import org.apache.spark.sql.execution.vectorized.{OffHeapColumnVector, OnHeapColumnVector} import org.apache.spark.sql.internal.SQLConf.{LegacyBehaviorPolicy, PARQUET_AGGREGATE_PUSHDOWN_ENABLED} import org.apache.spark.sql.types.StructType -import org.apache.spark.sql.vectorized.{ColumnarBatch, ColumnVector} object ParquetUtils { def inferSchema( @@ -201,44 +198,6 @@ object ParquetUtils { converter.currentRecord } - /** - * When the aggregates (Max/Min/Count) are pushed down to Parquet, in the case of - * PARQUET_VECTORIZED_READER_ENABLED sets to true, we don't need buildColumnarReader - * to read data from Parquet and aggregate at Spark layer. 
Instead we want - * to get the aggregates (Max/Min/Count) result using the statistics information - * from Parquet footer file, and then construct a ColumnarBatch from these aggregate results. - * - * @return Aggregate results in the format of ColumnarBatch - */ - private[sql] def createAggColumnarBatchFromFooter( - footer: ParquetMetadata, - filePath: String, - dataSchema: StructType, - partitionSchema: StructType, - aggregation: Aggregation, - aggSchema: StructType, - offHeap: Boolean, - datetimeRebaseMode: LegacyBehaviorPolicy.Value, - isCaseSensitive: Boolean): ColumnarBatch = { - val row = createAggInternalRowFromFooter( - footer, - filePath, - dataSchema, - partitionSchema, - aggregation, - aggSchema, - datetimeRebaseMode, - isCaseSensitive) - val converter = new RowToColumnConverter(aggSchema) - val columnVectors = if (offHeap) { - OffHeapColumnVector.allocateColumns(1, aggSchema) - } else { - OnHeapColumnVector.allocateColumns(1, aggSchema) - } - converter.convert(row, columnVectors.toArray) - new ColumnarBatch(columnVectors.asInstanceOf[Array[ColumnVector]], 1) - } - /** * Calculate the pushed down aggregates (Max/Min/Count) result using the statistics * information from Parquet footer file. diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/orc/OrcPartitionReaderFactory.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/orc/OrcPartitionReaderFactory.scala index c5020cb79524c..be1a3045e3a3f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/orc/OrcPartitionReaderFactory.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/orc/OrcPartitionReaderFactory.scala @@ -23,15 +23,16 @@ import org.apache.hadoop.fs.Path import org.apache.hadoop.mapreduce.{JobID, TaskAttemptID, TaskID, TaskType} import org.apache.hadoop.mapreduce.lib.input.FileSplit import org.apache.hadoop.mapreduce.task.TaskAttemptContextImpl -import org.apache.orc.{OrcConf, OrcFile, TypeDescription} +import org.apache.orc.{OrcConf, OrcFile, Reader, TypeDescription} import org.apache.orc.mapred.OrcStruct import org.apache.orc.mapreduce.OrcInputFormat import org.apache.spark.broadcast.Broadcast import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.connector.expressions.aggregate.Aggregation import org.apache.spark.sql.connector.read.{InputPartition, PartitionReader} import org.apache.spark.sql.execution.WholeStageCodegenExec -import org.apache.spark.sql.execution.datasources.PartitionedFile +import org.apache.spark.sql.execution.datasources.{AggregatePushDownUtils, PartitionedFile} import org.apache.spark.sql.execution.datasources.orc.{OrcColumnarBatchReader, OrcDeserializer, OrcFilters, OrcUtils} import org.apache.spark.sql.execution.datasources.v2._ import org.apache.spark.sql.internal.SQLConf @@ -55,7 +56,8 @@ case class OrcPartitionReaderFactory( dataSchema: StructType, readDataSchema: StructType, partitionSchema: StructType, - filters: Array[Filter]) extends FilePartitionReaderFactory { + filters: Array[Filter], + aggregation: Option[Aggregation]) extends FilePartitionReaderFactory { private val resultSchema = StructType(readDataSchema.fields ++ partitionSchema.fields) private val isCaseSensitive = sqlConf.caseSensitiveAnalysis private val capacity = sqlConf.orcVectorizedReaderBatchSize @@ -81,17 +83,14 @@ case class OrcPartitionReaderFactory( override def buildReader(file: PartitionedFile): PartitionReader[InternalRow] = { val conf = broadcastedConf.value.value - - 
OrcConf.IS_SCHEMA_EVOLUTION_CASE_SENSITIVE.setBoolean(conf, isCaseSensitive) - val filePath = new Path(new URI(file.filePath)) - pushDownPredicates(filePath, conf) + if (aggregation.nonEmpty) { + return buildReaderWithAggregates(filePath, conf) + } - val fs = filePath.getFileSystem(conf) - val readerOptions = OrcFile.readerOptions(conf).filesystem(fs) val resultedColPruneInfo = - Utils.tryWithResource(OrcFile.createReader(filePath, readerOptions)) { reader => + Utils.tryWithResource(createORCReader(filePath, conf)) { reader => OrcUtils.requestedColumnIds( isCaseSensitive, dataSchema, readDataSchema, reader, conf) } @@ -128,17 +127,14 @@ case class OrcPartitionReaderFactory( override def buildColumnarReader(file: PartitionedFile): PartitionReader[ColumnarBatch] = { val conf = broadcastedConf.value.value - - OrcConf.IS_SCHEMA_EVOLUTION_CASE_SENSITIVE.setBoolean(conf, isCaseSensitive) - val filePath = new Path(new URI(file.filePath)) - pushDownPredicates(filePath, conf) + if (aggregation.nonEmpty) { + return buildColumnarReaderWithAggregates(filePath, conf) + } - val fs = filePath.getFileSystem(conf) - val readerOptions = OrcFile.readerOptions(conf).filesystem(fs) val resultedColPruneInfo = - Utils.tryWithResource(OrcFile.createReader(filePath, readerOptions)) { reader => + Utils.tryWithResource(createORCReader(filePath, conf)) { reader => OrcUtils.requestedColumnIds( isCaseSensitive, dataSchema, readDataSchema, reader, conf) } @@ -173,4 +169,68 @@ case class OrcPartitionReaderFactory( } } + private def createORCReader(filePath: Path, conf: Configuration): Reader = { + OrcConf.IS_SCHEMA_EVOLUTION_CASE_SENSITIVE.setBoolean(conf, isCaseSensitive) + + pushDownPredicates(filePath, conf) + + val fs = filePath.getFileSystem(conf) + val readerOptions = OrcFile.readerOptions(conf).filesystem(fs) + OrcFile.createReader(filePath, readerOptions) + } + + /** + * Build reader with aggregate push down. + */ + private def buildReaderWithAggregates( + filePath: Path, + conf: Configuration): PartitionReader[InternalRow] = { + new PartitionReader[InternalRow] { + private var hasNext = true + private lazy val row: InternalRow = { + Utils.tryWithResource(createORCReader(filePath, conf)) { reader => + OrcUtils.createAggInternalRowFromFooter( + reader, filePath.toString, dataSchema, partitionSchema, aggregation.get, + readDataSchema, isCaseSensitive) + } + } + + override def next(): Boolean = hasNext + + override def get(): InternalRow = { + hasNext = false + row + } + + override def close(): Unit = {} + } + } + + /** + * Build columnar reader with aggregate push down. 
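Both aggregate readers in this factory follow the same single-result pattern; a stripped-down sketch of that pattern (the class name is made up for illustration):

  import org.apache.spark.sql.connector.read.PartitionReader

  // next() answers true exactly once; get() returns the one value computed lazily
  // from the file footer and flips the flag, so the scan emits a single row or batch
  class SingleResultReader[T](compute: => T) extends PartitionReader[T] {
    private var hasNext = true
    private lazy val result: T = compute
    override def next(): Boolean = hasNext
    override def get(): T = { hasNext = false; result }
    override def close(): Unit = {}
  }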
+ */ + private def buildColumnarReaderWithAggregates( + filePath: Path, + conf: Configuration): PartitionReader[ColumnarBatch] = { + new PartitionReader[ColumnarBatch] { + private var hasNext = true + private lazy val batch: ColumnarBatch = { + Utils.tryWithResource(createORCReader(filePath, conf)) { reader => + val row = OrcUtils.createAggInternalRowFromFooter( + reader, filePath.toString, dataSchema, partitionSchema, aggregation.get, + readDataSchema, isCaseSensitive) + AggregatePushDownUtils.convertAggregatesRowToBatch(row, readDataSchema, offHeap = false) + } + } + + override def next(): Boolean = hasNext + + override def get(): ColumnarBatch = { + hasNext = false + batch + } + + override def close(): Unit = {} + } + } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/orc/OrcScan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/orc/OrcScan.scala index 7619e3c503139..ac57c9dc6cdfd 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/orc/OrcScan.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/orc/OrcScan.scala @@ -18,11 +18,11 @@ package org.apache.spark.sql.execution.datasources.v2.orc import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs.Path - import org.apache.spark.sql.SparkSession import org.apache.spark.sql.catalyst.expressions.Expression +import org.apache.spark.sql.connector.expressions.aggregate.Aggregation import org.apache.spark.sql.connector.read.PartitionReaderFactory -import org.apache.spark.sql.execution.datasources.PartitioningAwareFileIndex +import org.apache.spark.sql.execution.datasources.{AggregatePushDownUtils, PartitioningAwareFileIndex} import org.apache.spark.sql.execution.datasources.v2.FileScan import org.apache.spark.sql.sources.Filter import org.apache.spark.sql.types.StructType @@ -37,10 +37,25 @@ case class OrcScan( readDataSchema: StructType, readPartitionSchema: StructType, options: CaseInsensitiveStringMap, + pushedAggregate: Option[Aggregation] = None, pushedFilters: Array[Filter], partitionFilters: Seq[Expression] = Seq.empty, dataFilters: Seq[Expression] = Seq.empty) extends FileScan { - override def isSplitable(path: Path): Boolean = true + override def isSplitable(path: Path): Boolean = { + // If aggregate is pushed down, only the file footer will be read once, + // so file should be not split across multiple tasks. + pushedAggregate.isEmpty + } + + override def readSchema(): StructType = { + // If aggregate is pushed down, schema has already been pruned in `OrcScanBuilder` + // and no need to call super.readSchema() + if (pushedAggregate.nonEmpty) { + readDataSchema + } else { + super.readSchema() + } + } override def createReaderFactory(): PartitionReaderFactory = { val broadcastedConf = sparkSession.sparkContext.broadcast( @@ -48,24 +63,39 @@ case class OrcScan( // The partition values are already truncated in `FileScan.partitions`. // We should use `readPartitionSchema` as the partition schema here. 
OrcPartitionReaderFactory(sparkSession.sessionState.conf, broadcastedConf, - dataSchema, readDataSchema, readPartitionSchema, pushedFilters) + dataSchema, readDataSchema, readPartitionSchema, pushedFilters, pushedAggregate) } override def equals(obj: Any): Boolean = obj match { case o: OrcScan => + val pushedDownAggEqual = if (pushedAggregate.nonEmpty && o.pushedAggregate.nonEmpty) { + AggregatePushDownUtils.equivalentAggregations(pushedAggregate.get, o.pushedAggregate.get) + } else { + pushedAggregate.isEmpty && o.pushedAggregate.isEmpty + } super.equals(o) && dataSchema == o.dataSchema && options == o.options && - equivalentFilters(pushedFilters, o.pushedFilters) - + equivalentFilters(pushedFilters, o.pushedFilters) && pushedDownAggEqual case _ => false } override def hashCode(): Int = getClass.hashCode() + lazy private val (pushedAggregationsStr, pushedGroupByStr) = if (pushedAggregate.nonEmpty) { + (seqToString(pushedAggregate.get.aggregateExpressions), + seqToString(pushedAggregate.get.groupByColumns)) + } else { + ("[]", "[]") + } + override def description(): String = { - super.description() + ", PushedFilters: " + seqToString(pushedFilters) + super.description() + ", PushedFilters: " + seqToString(pushedFilters) + + ", PushedAggregation: " + pushedAggregationsStr + + ", PushedGroupBy: " + pushedGroupByStr } override def getMetaData(): Map[String, String] = { - super.getMetaData() ++ Map("PushedFilters" -> seqToString(pushedFilters)) + super.getMetaData() ++ Map("PushedFilters" -> seqToString(pushedFilters)) ++ + Map("PushedAggregation" -> pushedAggregationsStr) ++ + Map("PushedGroupBy" -> pushedGroupByStr) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/orc/OrcScanBuilder.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/orc/OrcScanBuilder.scala index cfa396f5482f4..0636edd9b4d12 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/orc/OrcScanBuilder.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/orc/OrcScanBuilder.scala @@ -20,13 +20,14 @@ package org.apache.spark.sql.execution.datasources.v2.orc import scala.collection.JavaConverters._ import org.apache.spark.sql.SparkSession -import org.apache.spark.sql.connector.read.Scan -import org.apache.spark.sql.execution.datasources.PartitioningAwareFileIndex +import org.apache.spark.sql.connector.expressions.aggregate.Aggregation +import org.apache.spark.sql.connector.read.{Scan, SupportsPushDownAggregates} +import org.apache.spark.sql.execution.datasources.{AggregatePushDownUtils, PartitioningAwareFileIndex} import org.apache.spark.sql.execution.datasources.orc.OrcFilters import org.apache.spark.sql.execution.datasources.v2.FileScanBuilder import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.sources.Filter -import org.apache.spark.sql.types.StructType +import org.apache.spark.sql.types.{ArrayType, BinaryType, DataType, MapType, StructType, TimestampType} import org.apache.spark.sql.util.CaseInsensitiveStringMap case class OrcScanBuilder( @@ -35,18 +36,31 @@ case class OrcScanBuilder( schema: StructType, dataSchema: StructType, options: CaseInsensitiveStringMap) - extends FileScanBuilder(sparkSession, fileIndex, dataSchema) { + extends FileScanBuilder(sparkSession, fileIndex, dataSchema) + with SupportsPushDownAggregates { + lazy val hadoopConf = { val caseSensitiveMap = options.asCaseSensitiveMap.asScala.toMap // Hadoop Configurations are case sensitive. 
sparkSession.sessionState.newHadoopConfWithOptions(caseSensitiveMap) } + private var finalSchema = new StructType() + + private var pushedAggregations = Option.empty[Aggregation] + override protected val supportsNestedSchemaPruning: Boolean = true override def build(): Scan = { - OrcScan(sparkSession, hadoopConf, fileIndex, dataSchema, readDataSchema(), - readPartitionSchema(), options, pushedDataFilters, partitionFilters, dataFilters) + // the `finalSchema` is either pruned in pushAggregation (if aggregates are + // pushed down), or pruned in readDataSchema() (in regular column pruning). These + // two are mutual exclusive. + if (pushedAggregations.isEmpty) { + finalSchema = readDataSchema() + } + OrcScan(sparkSession, hadoopConf, fileIndex, dataSchema, finalSchema, + readPartitionSchema(), options, pushedAggregations, pushedDataFilters, partitionFilters, + dataFilters) } override def pushDataFilters(dataFilters: Array[Filter]): Array[Filter] = { @@ -58,4 +72,35 @@ case class OrcScanBuilder( Array.empty[Filter] } } + + override def pushAggregation(aggregation: Aggregation): Boolean = { + if (!sparkSession.sessionState.conf.orcAggregatePushDown) { + return false + } + + def isAllowedTypeForMinMaxAggregate(dataType: DataType): Boolean = { + dataType match { + // Not push down complex and Timestamp type. + // Not push down Binary type as ORC does not write min/max statistics for it. + case StructType(_) | ArrayType(_, _) | MapType(_, _, _) | TimestampType | BinaryType => + false + case _ => true + } + } + + AggregatePushDownUtils.getSchemaForPushedAggregation( + aggregation, + schema, + partitionNameSet, + dataFilters, + isAllowedTypeForMinMaxAggregate, + sparkSession) match { + + case Some(schema) => + finalSchema = schema + this.pushedAggregations = Some(aggregation) + true + case _ => false + } + } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/parquet/ParquetPartitionReaderFactory.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/parquet/ParquetPartitionReaderFactory.scala index 111018b579ed2..afc4e46dfb115 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/parquet/ParquetPartitionReaderFactory.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/parquet/ParquetPartitionReaderFactory.scala @@ -36,7 +36,7 @@ import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.util.DateTimeUtils import org.apache.spark.sql.connector.expressions.aggregate.Aggregation import org.apache.spark.sql.connector.read.{InputPartition, PartitionReader} -import org.apache.spark.sql.execution.datasources.{DataSourceUtils, PartitionedFile, RecordReaderIterator} +import org.apache.spark.sql.execution.datasources.{AggregatePushDownUtils, DataSourceUtils, PartitionedFile, RecordReaderIterator} import org.apache.spark.sql.execution.datasources.parquet._ import org.apache.spark.sql.execution.datasources.v2._ import org.apache.spark.sql.internal.SQLConf @@ -175,24 +175,26 @@ case class ParquetPartitionReaderFactory( } else { new PartitionReader[ColumnarBatch] { private var hasNext = true - private val row: ColumnarBatch = { + private val batch: ColumnarBatch = { val footer = getFooter(file) if (footer != null && footer.getBlocks.size > 0) { - ParquetUtils.createAggColumnarBatchFromFooter(footer, file.filePath, dataSchema, - partitionSchema, aggregation.get, readDataSchema, enableOffHeapColumnVector, + val row = ParquetUtils.createAggInternalRowFromFooter(footer, 
file.filePath, + dataSchema, partitionSchema, aggregation.get, readDataSchema, getDatetimeRebaseMode(footer.getFileMetaData), isCaseSensitive) + AggregatePushDownUtils.convertAggregatesRowToBatch( + row, readDataSchema, enableOffHeapColumnVector) } else { null } } override def next(): Boolean = { - hasNext && row != null + hasNext && batch != null } override def get(): ColumnarBatch = { hasNext = false - row + batch } override def close(): Unit = {} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/parquet/ParquetScan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/parquet/ParquetScan.scala index 42dc287f73129..084aae8c6afca 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/parquet/ParquetScan.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/parquet/ParquetScan.scala @@ -17,16 +17,14 @@ package org.apache.spark.sql.execution.datasources.v2.parquet import scala.collection.JavaConverters._ - import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs.Path import org.apache.parquet.hadoop.ParquetInputFormat - import org.apache.spark.sql.SparkSession import org.apache.spark.sql.catalyst.expressions.Expression import org.apache.spark.sql.connector.expressions.aggregate.Aggregation import org.apache.spark.sql.connector.read.PartitionReaderFactory -import org.apache.spark.sql.execution.datasources.PartitioningAwareFileIndex +import org.apache.spark.sql.execution.datasources.{AggregatePushDownUtils, PartitioningAwareFileIndex} import org.apache.spark.sql.execution.datasources.parquet.{ParquetOptions, ParquetReadSupport, ParquetWriteSupport} import org.apache.spark.sql.execution.datasources.v2.FileScan import org.apache.spark.sql.internal.SQLConf @@ -101,7 +99,7 @@ case class ParquetScan( override def equals(obj: Any): Boolean = obj match { case p: ParquetScan => val pushedDownAggEqual = if (pushedAggregate.nonEmpty && p.pushedAggregate.nonEmpty) { - equivalentAggregations(pushedAggregate.get, p.pushedAggregate.get) + AggregatePushDownUtils.equivalentAggregations(pushedAggregate.get, p.pushedAggregate.get) } else { pushedAggregate.isEmpty && p.pushedAggregate.isEmpty } @@ -130,10 +128,4 @@ case class ParquetScan( Map("PushedAggregation" -> pushedAggregationsStr) ++ Map("PushedGroupBy" -> pushedGroupByStr) } - - private def equivalentAggregations(a: Aggregation, b: Aggregation): Boolean = { - a.aggregateExpressions.sortBy(_.hashCode()) - .sameElements(b.aggregateExpressions.sortBy(_.hashCode())) && - a.groupByColumns.sortBy(_.hashCode()).sameElements(b.groupByColumns.sortBy(_.hashCode())) - } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/parquet/ParquetScanBuilder.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/parquet/ParquetScanBuilder.scala index da4938134785d..939146e57278c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/parquet/ParquetScanBuilder.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/parquet/ParquetScanBuilder.scala @@ -20,15 +20,14 @@ package org.apache.spark.sql.execution.datasources.v2.parquet import scala.collection.JavaConverters._ import org.apache.spark.sql.SparkSession -import org.apache.spark.sql.connector.expressions.NamedReference -import org.apache.spark.sql.connector.expressions.aggregate.{AggregateFunc, Aggregation, Count, CountStar, Max, Min} +import 
org.apache.spark.sql.connector.expressions.aggregate.Aggregation import org.apache.spark.sql.connector.read.{Scan, SupportsPushDownAggregates} -import org.apache.spark.sql.execution.datasources.PartitioningAwareFileIndex +import org.apache.spark.sql.execution.datasources.{AggregatePushDownUtils, PartitioningAwareFileIndex} import org.apache.spark.sql.execution.datasources.parquet.{ParquetFilters, SparkToParquetSchemaConverter} import org.apache.spark.sql.execution.datasources.v2.FileScanBuilder import org.apache.spark.sql.internal.SQLConf.LegacyBehaviorPolicy import org.apache.spark.sql.sources.Filter -import org.apache.spark.sql.types.{BooleanType, ByteType, DateType, DoubleType, FloatType, IntegerType, LongType, ShortType, StructField, StructType} +import org.apache.spark.sql.types.{BooleanType, ByteType, DataType, DateType, DoubleType, FloatType, IntegerType, LongType, ShortType, StructType} import org.apache.spark.sql.util.CaseInsensitiveStringMap case class ParquetScanBuilder( @@ -87,30 +86,12 @@ case class ParquetScanBuilder( override def pushedFilters(): Array[Filter] = pushedParquetFilters override def pushAggregation(aggregation: Aggregation): Boolean = { - - def getStructFieldForCol(col: NamedReference): StructField = { - schema.nameToField(col.fieldNames.head) - } - - def isPartitionCol(col: NamedReference) = { - partitionNameSet.contains(col.fieldNames.head) + if (!sparkSession.sessionState.conf.parquetAggregatePushDown) { + return false } - def processMinOrMax(agg: AggregateFunc): Boolean = { - val (column, aggType) = agg match { - case max: Max => (max.column, "max") - case min: Min => (min.column, "min") - case _ => - throw new IllegalArgumentException(s"Unexpected type of AggregateFunc ${agg.describe}") - } - - if (isPartitionCol(column)) { - // don't push down partition column, footer doesn't have max/min for partition column - return false - } - val structField = getStructFieldForCol(column) - - structField.dataType match { + def isAllowedTypeForMinMaxAggregate(dataType: DataType): Boolean = { + dataType match { // not push down complex type // not push down Timestamp because INT96 sort order is undefined, // Parquet doesn't return statistics for INT96 @@ -119,55 +100,35 @@ case class ParquetScanBuilder( // could be Spark StringType, BinaryType or DecimalType case BooleanType | ByteType | ShortType | IntegerType | LongType | FloatType | DoubleType | DateType => - finalSchema = finalSchema.add(structField.copy(s"$aggType(" + structField.name + ")")) true case _ => false } } - if (!sparkSession.sessionState.conf.parquetAggregatePushDown || - aggregation.groupByColumns.nonEmpty || dataFilters.length > 0) { - // Parquet footer has max/min/count for columns - // e.g. SELECT COUNT(col1) FROM t - // but footer doesn't have max/min/count for a column if max/min/count - // are combined with filter or group by - // e.g. 
SELECT COUNT(col1) FROM t WHERE col2 = 8 - // SELECT COUNT(col1) FROM t GROUP BY col2 - // However, if the filter is on partition column, max/min/count can still be pushed down - // Todo: add support if groupby column is partition col - // (https://issues.apache.org/jira/browse/SPARK-36646) - return false - } - - aggregation.groupByColumns.foreach { col => - if (col.fieldNames.length != 1) return false - finalSchema = finalSchema.add(getStructFieldForCol(col)) + AggregatePushDownUtils.getSchemaForPushedAggregation( + aggregation, + schema, + partitionNameSet, + dataFilters, + isAllowedTypeForMinMaxAggregate, + sparkSession) match { + + case Some(schema) => + finalSchema = schema + this.pushedAggregations = Some(aggregation) + true + case _ => false } - - aggregation.aggregateExpressions.foreach { - case max: Max => - if (!processMinOrMax(max)) return false - case min: Min => - if (!processMinOrMax(min)) return false - case count: Count => - if (count.column.fieldNames.length != 1 || count.isDistinct) return false - finalSchema = - finalSchema.add(StructField(s"count(" + count.column.fieldNames.head + ")", LongType)) - case _: CountStar => - finalSchema = finalSchema.add(StructField("count(*)", LongType)) - case _ => - return false - } - this.pushedAggregations = Some(aggregation) - true } override def build(): Scan = { // the `finalSchema` is either pruned in pushAggregation (if aggregates are // pushed down), or pruned in readDataSchema() (in regular column pruning). These // two are mutual exclusive. - if (pushedAggregations.isEmpty) finalSchema = readDataSchema() + if (pushedAggregations.isEmpty) { + finalSchema = readDataSchema() + } ParquetScan(sparkSession, hadoopConf, fileIndex, dataSchema, finalSchema, readPartitionSchema(), pushedParquetFilters, options, pushedAggregations, partitionFilters, dataFilters) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/FileScanSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/FileScanSuite.scala index 604a8927aa7af..14b59ba23d09f 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/FileScanSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/FileScanSuite.scala @@ -358,7 +358,7 @@ class FileScanSuite extends FileScanSuiteBase { Seq.empty), ("OrcScan", (s, fi, ds, rds, rps, f, o, pf, df) => - OrcScan(s, s.sessionState.newHadoopConf(), fi, ds, rds, rps, o, f, pf, df), + OrcScan(s, s.sessionState.newHadoopConf(), fi, ds, rds, rps, o, None, f, pf, df), Seq.empty), ("CSVScan", (s, fi, ds, rds, rps, f, o, pf, df) => CSVScan(s, fi, ds, rds, rps, o, f, pf, df), diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetAggregatePushDownSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/FileSourceAggregatePushDownSuite.scala similarity index 80% rename from sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetAggregatePushDownSuite.scala rename to sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/FileSourceAggregatePushDownSuite.scala index 77ecd288d400d..f9e1925989f4a 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetAggregatePushDownSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/FileSourceAggregatePushDownSuite.scala @@ -15,33 +15,39 @@ * limitations under the License. 
*/ -package org.apache.spark.sql.execution.datasources.parquet +package org.apache.spark.sql.execution.datasources import java.sql.{Date, Timestamp} import org.apache.spark.SparkConf -import org.apache.spark.sql._ +import org.apache.spark.sql.{ExplainSuiteHelper, QueryTest, Row} +import org.apache.spark.sql.execution.datasources.orc.OrcTest +import org.apache.spark.sql.execution.datasources.parquet.ParquetTest import org.apache.spark.sql.execution.datasources.v2.DataSourceV2ScanRelation import org.apache.spark.sql.functions.min import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.test.SharedSparkSession -import org.apache.spark.sql.types._ +import org.apache.spark.sql.types.{BinaryType, BooleanType, ByteType, DateType, Decimal, DecimalType, DoubleType, FloatType, IntegerType, LongType, ShortType, StringType, StructField, StructType, TimestampType} /** - * A test suite that tests Max/Min/Count push down. + * A test suite that tests aggregate push down for Parquet and ORC. */ -abstract class ParquetAggregatePushDownSuite +trait FileSourceAggregatePushDownSuite extends QueryTest - with ParquetTest + with FileBasedDataSourceTest with SharedSparkSession with ExplainSuiteHelper { + import testImplicits._ - test("aggregate push down - nested column: Max(top level column) not push down") { + protected def format: String + // The SQL config key for enabling aggregate push down. + protected val aggPushDownEnabledKey: String + + test("nested column: Max(top level column) not push down") { val data = (1 to 10).map(i => Tuple1((i, Seq(s"val_$i")))) - withSQLConf( - SQLConf.PARQUET_AGGREGATE_PUSHDOWN_ENABLED.key -> "true") { - withParquetTable(data, "t") { + withSQLConf(aggPushDownEnabledKey -> "true") { + withDataSourceTable(data, "t") { val max = sql("SELECT Max(_1) FROM t") max.queryExecution.optimizedPlan.collect { case _: DataSourceV2ScanRelation => @@ -53,11 +59,10 @@ abstract class ParquetAggregatePushDownSuite } } - test("aggregate push down - nested column: Count(top level column) push down") { + test("nested column: Count(top level column) push down") { val data = (1 to 10).map(i => Tuple1((i, Seq(s"val_$i")))) - withSQLConf( - SQLConf.PARQUET_AGGREGATE_PUSHDOWN_ENABLED.key -> "true") { - withParquetTable(data, "t") { + withSQLConf(aggPushDownEnabledKey -> "true") { + withDataSourceTable(data, "t") { val count = sql("SELECT Count(_1) FROM t") count.queryExecution.optimizedPlan.collect { case _: DataSourceV2ScanRelation => @@ -70,11 +75,10 @@ abstract class ParquetAggregatePushDownSuite } } - test("aggregate push down - nested column: Max(nested column) not push down") { + test("nested column: Max(nested sub-field) not push down") { val data = (1 to 10).map(i => Tuple1((i, Seq(s"val_$i")))) - withSQLConf( - SQLConf.PARQUET_AGGREGATE_PUSHDOWN_ENABLED.key -> "true") { - withParquetTable(data, "t") { + withSQLConf(aggPushDownEnabledKey-> "true") { + withDataSourceTable(data, "t") { val max = sql("SELECT Max(_1._2[0]) FROM t") max.queryExecution.optimizedPlan.collect { case _: DataSourceV2ScanRelation => @@ -86,11 +90,10 @@ abstract class ParquetAggregatePushDownSuite } } - test("aggregate push down - nested column: Count(nested column) not push down") { + test("nested column: Count(nested sub-field) not push down") { val data = (1 to 10).map(i => Tuple1((i, Seq(s"val_$i")))) - withSQLConf( - SQLConf.PARQUET_AGGREGATE_PUSHDOWN_ENABLED.key -> "true") { - withParquetTable(data, "t") { + withSQLConf(aggPushDownEnabledKey -> "true") { + withDataSourceTable(data, "t") { val count = 
sql("SELECT Count(_1._2[0]) FROM t") count.queryExecution.optimizedPlan.collect { case _: DataSourceV2ScanRelation => @@ -103,13 +106,13 @@ abstract class ParquetAggregatePushDownSuite } } - test("aggregate push down - Max(partition Col): not push dow") { + test("Max(partition column): not push down") { withTempPath { dir => spark.range(10).selectExpr("id", "id % 3 as p") - .write.partitionBy("p").parquet(dir.getCanonicalPath) + .write.partitionBy("p").format(format).save(dir.getCanonicalPath) withTempView("tmp") { - spark.read.parquet(dir.getCanonicalPath).createOrReplaceTempView("tmp"); - withSQLConf(SQLConf.PARQUET_AGGREGATE_PUSHDOWN_ENABLED.key -> "true") { + spark.read.format(format).load(dir.getCanonicalPath).createOrReplaceTempView("tmp"); + withSQLConf(aggPushDownEnabledKey -> "true") { val max = sql("SELECT Max(p) FROM tmp") max.queryExecution.optimizedPlan.collect { case _: DataSourceV2ScanRelation => @@ -146,12 +149,11 @@ abstract class ParquetAggregatePushDownSuite } } - test("aggregate push down - Filter alias over aggregate") { + test("filter alias over aggregate") { val data = Seq((-2, "abc", 2), (3, "def", 4), (6, "ghi", 2), (0, null, 19), (9, "mno", 7), (2, null, 6)) - withParquetTable(data, "t") { - withSQLConf( - SQLConf.PARQUET_AGGREGATE_PUSHDOWN_ENABLED.key -> "true") { + withDataSourceTable(data, "t") { + withSQLConf(aggPushDownEnabledKey -> "true") { val selectAgg = sql("SELECT min(_1) + max(_1) as res FROM t having res > 1") selectAgg.queryExecution.optimizedPlan.collect { case _: DataSourceV2ScanRelation => @@ -164,12 +166,11 @@ abstract class ParquetAggregatePushDownSuite } } - test("aggregate push down - alias over aggregate") { + test("alias over aggregate") { val data = Seq((-2, "abc", 2), (3, "def", 4), (6, "ghi", 2), (0, null, 19), (9, "mno", 7), (2, null, 6)) - withParquetTable(data, "t") { - withSQLConf( - SQLConf.PARQUET_AGGREGATE_PUSHDOWN_ENABLED.key -> "true") { + withDataSourceTable(data, "t") { + withSQLConf(aggPushDownEnabledKey -> "true") { val selectAgg = sql("SELECT min(_1) + 1 as minPlus1, min(_1) + 2 as minPlus2 FROM t") selectAgg.queryExecution.optimizedPlan.collect { case _: DataSourceV2ScanRelation => @@ -182,12 +183,11 @@ abstract class ParquetAggregatePushDownSuite } } - test("aggregate push down - aggregate over alias not push down") { + test("aggregate over alias not push down") { val data = Seq((-2, "abc", 2), (3, "def", 4), (6, "ghi", 2), (0, null, 19), (9, "mno", 7), (2, null, 6)) - withParquetTable(data, "t") { - withSQLConf( - SQLConf.PARQUET_AGGREGATE_PUSHDOWN_ENABLED.key -> "true") { + withDataSourceTable(data, "t") { + withSQLConf(aggPushDownEnabledKey -> "true") { val df = spark.table("t") val query = df.select($"_1".as("col1")).agg(min($"col1")) query.queryExecution.optimizedPlan.collect { @@ -201,12 +201,11 @@ abstract class ParquetAggregatePushDownSuite } } - test("aggregate push down - query with group by not push down") { + test("query with group by not push down") { val data = Seq((-2, "abc", 2), (3, "def", 4), (6, "ghi", 2), (0, null, 19), (9, "mno", 7), (2, null, 7)) - withParquetTable(data, "t") { - withSQLConf( - SQLConf.PARQUET_AGGREGATE_PUSHDOWN_ENABLED.key -> "true") { + withDataSourceTable(data, "t") { + withSQLConf(aggPushDownEnabledKey -> "true") { // aggregate not pushed down if there is group by val selectAgg = sql("SELECT min(_1) FROM t GROUP BY _3 ") selectAgg.queryExecution.optimizedPlan.collect { @@ -220,12 +219,11 @@ abstract class ParquetAggregatePushDownSuite } } - test("aggregate push down - aggregate 
with data filter cannot be pushed down") { + test("aggregate with data filter cannot be pushed down") { val data = Seq((-2, "abc", 2), (3, "def", 4), (6, "ghi", 2), (0, null, 19), (9, "mno", 7), (2, null, 7)) - withParquetTable(data, "t") { - withSQLConf( - SQLConf.PARQUET_AGGREGATE_PUSHDOWN_ENABLED.key -> "true") { + withDataSourceTable(data, "t") { + withSQLConf(aggPushDownEnabledKey -> "true") { // aggregate not pushed down if there is filter val selectAgg = sql("SELECT min(_3) FROM t WHERE _1 > 0") selectAgg.queryExecution.optimizedPlan.collect { @@ -239,14 +237,14 @@ abstract class ParquetAggregatePushDownSuite } } - test("aggregate push down - aggregate with partition filter can be pushed down") { + test("aggregate with partition filter can be pushed down") { withTempPath { dir => spark.range(10).selectExpr("id", "id % 3 as p") - .write.partitionBy("p").parquet(dir.getCanonicalPath) + .write.partitionBy("p").format(format).save(dir.getCanonicalPath) withTempView("tmp") { - spark.read.parquet(dir.getCanonicalPath).createOrReplaceTempView("tmp"); + spark.read.format(format).load(dir.getCanonicalPath).createOrReplaceTempView("tmp") Seq("false", "true").foreach { enableVectorizedReader => - withSQLConf(SQLConf.PARQUET_AGGREGATE_PUSHDOWN_ENABLED.key -> "true", + withSQLConf(aggPushDownEnabledKey -> "true", vectorizedReaderEnabledKey -> enableVectorizedReader) { val max = sql("SELECT max(id), min(id), count(id) FROM tmp WHERE p = 0") max.queryExecution.optimizedPlan.collect { @@ -262,12 +260,11 @@ abstract class ParquetAggregatePushDownSuite } } - test("aggregate push down - push down only if all the aggregates can be pushed down") { + test("push down only if all the aggregates can be pushed down") { val data = Seq((-2, "abc", 2), (3, "def", 4), (6, "ghi", 2), (0, null, 19), (9, "mno", 7), (2, null, 7)) - withParquetTable(data, "t") { - withSQLConf( - SQLConf.PARQUET_AGGREGATE_PUSHDOWN_ENABLED.key -> "true") { + withDataSourceTable(data, "t") { + withSQLConf(aggPushDownEnabledKey -> "true") { // not push down since sum can't be pushed down val selectAgg = sql("SELECT min(_1), sum(_3) FROM t") selectAgg.queryExecution.optimizedPlan.collect { @@ -284,9 +281,8 @@ abstract class ParquetAggregatePushDownSuite test("aggregate push down - MIN/MAX/COUNT") { val data = Seq((-2, "abc", 2), (3, "def", 4), (6, "ghi", 2), (0, null, 19), (9, "mno", 7), (2, null, 6)) - withParquetTable(data, "t") { - withSQLConf( - SQLConf.PARQUET_AGGREGATE_PUSHDOWN_ENABLED.key -> "true") { + withDataSourceTable(data, "t") { + withSQLConf(aggPushDownEnabledKey -> "true") { val selectAgg = sql("SELECT min(_3), min(_3), max(_3), min(_1), max(_1), max(_1)," + " count(*), count(_1), count(_2), count(_3) FROM t") selectAgg.queryExecution.optimizedPlan.collect { @@ -375,11 +371,11 @@ abstract class ParquetAggregatePushDownSuite val rdd = sparkContext.parallelize(rows) withTempPath { file => - spark.createDataFrame(rdd, schema).write.parquet(file.getCanonicalPath) + spark.createDataFrame(rdd, schema).write.format(format).save(file.getCanonicalPath) withTempView("test") { - spark.read.parquet(file.getCanonicalPath).createOrReplaceTempView("test") + spark.read.format(format).load(file.getCanonicalPath).createOrReplaceTempView("test") Seq("false", "true").foreach { enableVectorizedReader => - withSQLConf(SQLConf.PARQUET_AGGREGATE_PUSHDOWN_ENABLED.key -> "true", + withSQLConf(aggPushDownEnabledKey -> "true", vectorizedReaderEnabledKey -> enableVectorizedReader) { val testMinWithAllTypes = sql("SELECT min(StringCol), 
min(BooleanCol), min(ByteCol), " + @@ -389,7 +385,8 @@ abstract class ParquetAggregatePushDownSuite // INT96 (Timestamp) sort order is undefined, parquet doesn't return stats for this type // so aggregates are not pushed down // In addition, Parquet Binary min/max could be truncated, so we disable aggregate - // push down for Parquet Binary (could be Spark StringType, BinaryType or DecimalType) + // push down for Parquet Binary (could be Spark StringType, BinaryType or DecimalType). + // Also do not push down for ORC with same reason. testMinWithAllTypes.queryExecution.optimizedPlan.collect { case _: DataSourceV2ScanRelation => val expected_plan_fragment = @@ -430,7 +427,8 @@ abstract class ParquetAggregatePushDownSuite // INT96 (Timestamp) sort order is undefined, parquet doesn't return stats for this type // so aggregates are not pushed down // In addition, Parquet Binary min/max could be truncated, so we disable aggregate - // push down for Parquet Binary (could be Spark StringType, BinaryType or DecimalType) + // push down for Parquet Binary (could be Spark StringType, BinaryType or DecimalType). + // Also do not push down for ORC with same reason. testMaxWithAllTypes.queryExecution.optimizedPlan.collect { case _: DataSourceV2ScanRelation => val expected_plan_fragment = @@ -494,15 +492,15 @@ abstract class ParquetAggregatePushDownSuite } } - test("aggregate push down - column name case sensitivity") { + test("column name case sensitivity") { Seq("false", "true").foreach { enableVectorizedReader => - withSQLConf(SQLConf.PARQUET_AGGREGATE_PUSHDOWN_ENABLED.key -> "true", + withSQLConf(aggPushDownEnabledKey -> "true", vectorizedReaderEnabledKey -> enableVectorizedReader) { withTempPath { dir => spark.range(10).selectExpr("id", "id % 3 as p") - .write.partitionBy("p").parquet(dir.getCanonicalPath) + .write.partitionBy("p").format(format).save(dir.getCanonicalPath) withTempView("tmp") { - spark.read.parquet(dir.getCanonicalPath).createOrReplaceTempView("tmp"); + spark.read.format(format).load(dir.getCanonicalPath).createOrReplaceTempView("tmp") val selectAgg = sql("SELECT max(iD), min(Id) FROM tmp") selectAgg.queryExecution.optimizedPlan.collect { case _: DataSourceV2ScanRelation => @@ -518,18 +516,41 @@ abstract class ParquetAggregatePushDownSuite } } +abstract class ParquetAggregatePushDownSuite + extends FileSourceAggregatePushDownSuite with ParquetTest { + + override def format: String = "parquet" + override protected val aggPushDownEnabledKey: String = + SQLConf.PARQUET_AGGREGATE_PUSHDOWN_ENABLED.key +} + class ParquetV1AggregatePushDownSuite extends ParquetAggregatePushDownSuite { override protected def sparkConf: SparkConf = - super - .sparkConf - .set(SQLConf.USE_V1_SOURCE_LIST, "parquet") + super.sparkConf.set(SQLConf.USE_V1_SOURCE_LIST, "parquet") } class ParquetV2AggregatePushDownSuite extends ParquetAggregatePushDownSuite { override protected def sparkConf: SparkConf = - super - .sparkConf - .set(SQLConf.USE_V1_SOURCE_LIST, "") + super.sparkConf.set(SQLConf.USE_V1_SOURCE_LIST, "") +} + +abstract class OrcAggregatePushDownSuite extends OrcTest with FileSourceAggregatePushDownSuite { + + override def format: String = "orc" + override protected val aggPushDownEnabledKey: String = + SQLConf.ORC_AGGREGATE_PUSHDOWN_ENABLED.key +} + +class OrcV1AggregatePushDownSuite extends OrcAggregatePushDownSuite { + + override protected def sparkConf: SparkConf = + super.sparkConf.set(SQLConf.USE_V1_SOURCE_LIST, "orc") +} + +class OrcV2AggregatePushDownSuite extends OrcAggregatePushDownSuite { + + 
override protected def sparkConf: SparkConf = + super.sparkConf.set(SQLConf.USE_V1_SOURCE_LIST, "") } From cca5a30bf539c4d3e1321201f62d28156635e613 Mon Sep 17 00:00:00 2001 From: Cheng Su Date: Fri, 15 Oct 2021 20:22:01 -0700 Subject: [PATCH 2/7] Fix style --- .../apache/spark/sql/execution/datasources/v2/orc/OrcScan.scala | 1 + .../sql/execution/datasources/v2/parquet/ParquetScan.scala | 2 ++ 2 files changed, 3 insertions(+) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/orc/OrcScan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/orc/OrcScan.scala index ac57c9dc6cdfd..6b9d181a7f4c0 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/orc/OrcScan.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/orc/OrcScan.scala @@ -18,6 +18,7 @@ package org.apache.spark.sql.execution.datasources.v2.orc import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs.Path + import org.apache.spark.sql.SparkSession import org.apache.spark.sql.catalyst.expressions.Expression import org.apache.spark.sql.connector.expressions.aggregate.Aggregation diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/parquet/ParquetScan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/parquet/ParquetScan.scala index 084aae8c6afca..b92ed82190ae8 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/parquet/ParquetScan.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/parquet/ParquetScan.scala @@ -17,9 +17,11 @@ package org.apache.spark.sql.execution.datasources.v2.parquet import scala.collection.JavaConverters._ + import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs.Path import org.apache.parquet.hadoop.ParquetInputFormat + import org.apache.spark.sql.SparkSession import org.apache.spark.sql.catalyst.expressions.Expression import org.apache.spark.sql.connector.expressions.aggregate.Aggregation From 829cafbd618d5df72a3733bbc9e37b6bdb866fcb Mon Sep 17 00:00:00 2001 From: Cheng Su Date: Sat, 16 Oct 2021 16:39:58 -0700 Subject: [PATCH 3/7] Address comment in AggregatePushDownUtils to remove unnecessary statements --- .../datasources/AggregatePushDownUtils.scala | 11 +++-------- 1 file changed, 3 insertions(+), 8 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/AggregatePushDownUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/AggregatePushDownUtils.scala index 4d3fcf0d8f0ff..8a48deed87a2c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/AggregatePushDownUtils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/AggregatePushDownUtils.scala @@ -88,11 +88,6 @@ object AggregatePushDownUtils { return None } - aggregation.groupByColumns.foreach { col => - if (col.fieldNames.length != 1) return None - finalSchema = finalSchema.add(getStructFieldForCol(col)) - } - aggregation.aggregateExpressions.foreach { case max: Max => if (!processMinOrMax(max)) return None @@ -125,9 +120,9 @@ object AggregatePushDownUtils { * This is used for columnar reader. 
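+ * For example, a pushed-down (max, min, count) result row with schema
+ * (INT, INT, LONG) becomes a ColumnarBatch holding a single row backed by
+ * three column vectors (illustrative; see the allocation of a 1-row batch below).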
*/ def convertAggregatesRowToBatch( - aggregatesAsRow: InternalRow, - aggregatesSchema: StructType, - offHeap: Boolean): ColumnarBatch = { + aggregatesAsRow: InternalRow, + aggregatesSchema: StructType, + offHeap: Boolean): ColumnarBatch = { val converter = new RowToColumnConverter(aggregatesSchema) val columnVectors = if (offHeap) { OffHeapColumnVector.allocateColumns(1, aggregatesSchema) From c3e1a127d7c2914fa48e0f1ad9b74559dad2301d Mon Sep 17 00:00:00 2001 From: Cheng Su Date: Fri, 22 Oct 2021 12:38:47 -0700 Subject: [PATCH 4/7] Addressed all comments --- ...atistics.java => OrcColumnStatistics.java} | 13 ++++---- .../datasources/orc/OrcFooterReader.java | 28 ++++++++--------- .../datasources/AggregatePushDownUtils.scala | 8 ++--- .../execution/datasources/orc/OrcUtils.scala | 30 +++++-------------- .../v2/orc/OrcPartitionReaderFactory.scala | 6 ++-- .../datasources/v2/orc/OrcScanBuilder.scala | 9 +++--- .../v2/parquet/ParquetScanBuilder.scala | 3 +- .../FileSourceAggregatePushDownSuite.scala | 15 +++++----- 8 files changed, 48 insertions(+), 64 deletions(-) rename sql/core/src/main/java/org/apache/spark/sql/execution/datasources/orc/{OrcColumnsStatistics.java => OrcColumnStatistics.java} (80%) diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/orc/OrcColumnsStatistics.java b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/orc/OrcColumnStatistics.java similarity index 80% rename from sql/core/src/main/java/org/apache/spark/sql/execution/datasources/orc/OrcColumnsStatistics.java rename to sql/core/src/main/java/org/apache/spark/sql/execution/datasources/orc/OrcColumnStatistics.java index 8ce7e013f8228..db6dfc0ee29f1 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/orc/OrcColumnsStatistics.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/orc/OrcColumnStatistics.java @@ -27,13 +27,14 @@ * * Because ORC {@link ColumnStatistics}s are stored as an flatten array in ORC file footer, * this class is used to covert ORC {@link ColumnStatistics}s from array to nested tree structure, - * according to data types. This is used for aggregate push down in ORC. + * according to data types. The flatten array stores all data types (including nested types) in + * tree pre-ordering. This is used for aggregate push down in ORC. 
*/ -public class OrcColumnsStatistics { +public class OrcColumnStatistics { private final ColumnStatistics statistics; - private final List children; + private final List children; - public OrcColumnsStatistics(ColumnStatistics statistics) { + public OrcColumnStatistics(ColumnStatistics statistics) { this.statistics = statistics; this.children = new ArrayList<>(); } @@ -42,7 +43,7 @@ public ColumnStatistics getStatistics() { return statistics; } - public OrcColumnsStatistics get(int ordinal) { + public OrcColumnStatistics get(int ordinal) { if (ordinal < 0 || ordinal >= children.size()) { throw new IndexOutOfBoundsException( String.format("Ordinal %d out of bounds of statistics size %d", ordinal, children.size())); @@ -50,7 +51,7 @@ public OrcColumnsStatistics get(int ordinal) { return children.get(ordinal); } - public void add(OrcColumnsStatistics newChild) { + public void add(OrcColumnStatistics newChild) { children.add(newChild); } } diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/orc/OrcFooterReader.java b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/orc/OrcFooterReader.java index 14773500f075d..74091e9bd8074 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/orc/OrcFooterReader.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/orc/OrcFooterReader.java @@ -38,29 +38,29 @@ public class OrcFooterReader { * @param orcReader the reader to read ORC file footer. * @return Statistics for all columns in the file. */ - public static OrcColumnsStatistics readStatistics(Reader orcReader) { + public static OrcColumnStatistics readStatistics(Reader orcReader) { TypeDescription orcSchema = orcReader.getSchema(); ColumnStatistics[] orcStatistics = orcReader.getStatistics(); - StructType dataType = OrcUtils.toCatalystSchema(orcSchema); - return convertStatistics(dataType, new LinkedList<>(Arrays.asList(orcStatistics))); + StructType sparkSchema = OrcUtils.toCatalystSchema(orcSchema); + return convertStatistics(sparkSchema, new LinkedList<>(Arrays.asList(orcStatistics))); } /** - * Convert a queue of ORC {@link ColumnStatistics}s into Spark {@link OrcColumnsStatistics}. + * Convert a queue of ORC {@link ColumnStatistics}s into Spark {@link OrcColumnStatistics}. * The queue of ORC {@link ColumnStatistics}s are assumed to be ordered as tree pre-order. 
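+ * For example, for a schema with columns a (int) and b (struct with a single field c: long),
+ * the queue holds statistics in the order: root struct, a, b, b.c (illustrative).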
*/ - private static OrcColumnsStatistics convertStatistics( - DataType dataType, Queue orcStatistics) { - OrcColumnsStatistics statistics = new OrcColumnsStatistics(orcStatistics.remove()); - if (dataType instanceof StructType) { - for (StructField field : ((StructType) dataType).fields()) { + private static OrcColumnStatistics convertStatistics( + DataType sparkSchema, Queue orcStatistics) { + OrcColumnStatistics statistics = new OrcColumnStatistics(orcStatistics.remove()); + if (sparkSchema instanceof StructType) { + for (StructField field : ((StructType) sparkSchema).fields()) { statistics.add(convertStatistics(field.dataType(), orcStatistics)); } - } else if (dataType instanceof MapType) { - statistics.add(convertStatistics(((MapType) dataType).keyType(), orcStatistics)); - statistics.add(convertStatistics(((MapType) dataType).valueType(), orcStatistics)); - } else if (dataType instanceof ArrayType) { - statistics.add(convertStatistics(((ArrayType) dataType).elementType(), orcStatistics)); + } else if (sparkSchema instanceof MapType) { + statistics.add(convertStatistics(((MapType) sparkSchema).keyType(), orcStatistics)); + statistics.add(convertStatistics(((MapType) sparkSchema).valueType(), orcStatistics)); + } else if (sparkSchema instanceof ArrayType) { + statistics.add(convertStatistics(((ArrayType) sparkSchema).elementType(), orcStatistics)); } return statistics; } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/AggregatePushDownUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/AggregatePushDownUtils.scala index 8a48deed87a2c..6e447c55939a1 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/AggregatePushDownUtils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/AggregatePushDownUtils.scala @@ -17,7 +17,6 @@ package org.apache.spark.sql.execution.datasources -import org.apache.spark.sql.SparkSession import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.Expression import org.apache.spark.sql.connector.expressions.NamedReference @@ -38,10 +37,9 @@ object AggregatePushDownUtils { def getSchemaForPushedAggregation( aggregation: Aggregation, schema: StructType, - partitionNameSet: Set[String], + partitionNames: Set[String], dataFilters: Seq[Expression], - isAllowedTypeForMinMaxAggregate: DataType => Boolean, - sparkSession: SparkSession): Option[StructType] = { + isAllowedTypeForMinMaxAggregate: DataType => Boolean): Option[StructType] = { var finalSchema = new StructType() @@ -50,7 +48,7 @@ object AggregatePushDownUtils { } def isPartitionCol(col: NamedReference) = { - partitionNameSet.contains(col.fieldNames.head) + partitionNames.contains(col.fieldNames.head) } def processMinOrMax(agg: AggregateFunc): Boolean = { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcUtils.scala index 9ced33898736b..1ffddc11247ee 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcUtils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcUtils.scala @@ -29,7 +29,7 @@ import org.apache.hadoop.hive.serde2.io.{DateWritable, HiveDecimalWritable} import org.apache.hadoop.io.{BooleanWritable, ByteWritable, DoubleWritable, FloatWritable, IntWritable, LongWritable, ShortWritable, Text, WritableComparable} import org.apache.orc.{BooleanColumnStatistics, 
ColumnStatistics, DateColumnStatistics, DecimalColumnStatistics, DoubleColumnStatistics, IntegerColumnStatistics, OrcConf, OrcFile, Reader, StringColumnStatistics, TypeDescription, Writer} -import org.apache.spark.{SPARK_VERSION_SHORT, SparkException} +import org.apache.spark.SPARK_VERSION_SHORT import org.apache.spark.deploy.SparkHadoopUtil import org.apache.spark.internal.Logging import org.apache.spark.sql.{SPARK_VERSION_METADATA_KEY, SparkSession} @@ -392,7 +392,6 @@ object OrcUtils extends Logging { */ def createAggInternalRowFromFooter( reader: Reader, - filePath: String, dataSchema: StructType, partitionSchema: StructType, aggregation: Aggregation, @@ -425,7 +424,7 @@ object OrcUtils extends Logging { case IntegerType => new IntWritable(value.toInt) case LongType => new LongWritable(value) case _ => throw new IllegalArgumentException( - s"getMaxFromColumnStatistics should not take type $dataType" + + s"getMaxFromColumnStatistics should not take type $dataType " + "for IntegerColumnStatistics") } case s: DoubleColumnStatistics => @@ -456,14 +455,12 @@ object OrcUtils extends Logging { val columnName = max.column.fieldNames.head val statistics = getColumnStatistics(columnName) val dataType = aggSchema(index).dataType - val value = getMinMaxFromColumnStatistics(statistics, dataType, isMax = true) - value + getMinMaxFromColumnStatistics(statistics, dataType, isMax = true) case (min: Min, index) => val columnName = min.column.fieldNames.head val statistics = getColumnStatistics(columnName) val dataType = aggSchema.apply(index).dataType - val value = getMinMaxFromColumnStatistics(statistics, dataType, isMax = false) - value + getMinMaxFromColumnStatistics(statistics, dataType, isMax = false) case (count: Count, _) => val columnName = count.column.fieldNames.head val isPartitionColumn = partitionSchema.fields @@ -471,28 +468,17 @@ object OrcUtils extends Logging { .contains(columnName) // NOTE: Count(columnName) doesn't include null values. // org.apache.orc.ColumnStatistics.getNumberOfValues() returns number of non-null values - // for ColumnStatistics of individual column. In addition to this, ORC also returns - // number of non-null and null values for its top-level - // ColumnStatistics.getNumberOfValues(). + // for ColumnStatistics of individual column. In addition to this, ORC also stores number + // of all values (null and non-null) separately. val nonNullRowsCount = if (isPartitionColumn) { - val topLevelStatistics = columnsStatistics.getStatistics - if (topLevelStatistics.hasNull) { - throw new SparkException(s"Illegal ORC top-level column statistics with NULL " + - s"values: $topLevelStatistics. Aggregate expression: $count") - } - topLevelStatistics.getNumberOfValues + columnsStatistics.getStatistics.getNumberOfValues } else { getColumnStatistics(columnName).getNumberOfValues } new LongWritable(nonNullRowsCount) case (_: CountStar, _) => // Count(*) includes both null and non-null values. - val topLevelStatistics = columnsStatistics.getStatistics - if (topLevelStatistics.hasNull) { - throw new SparkException(s"Illegal ORC top-level column statistics with NULL " + - s"values: $topLevelStatistics. 
Aggregate expression: Count(*)") - } - new LongWritable(topLevelStatistics.getNumberOfValues) + new LongWritable(columnsStatistics.getStatistics.getNumberOfValues) case (x, _) => throw new IllegalArgumentException( s"createAggInternalRowFromFooter should not take $x as the aggregate expression") diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/orc/OrcPartitionReaderFactory.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/orc/OrcPartitionReaderFactory.scala index be1a3045e3a3f..55ae71c8d66ff 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/orc/OrcPartitionReaderFactory.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/orc/OrcPartitionReaderFactory.scala @@ -190,8 +190,7 @@ case class OrcPartitionReaderFactory( private lazy val row: InternalRow = { Utils.tryWithResource(createORCReader(filePath, conf)) { reader => OrcUtils.createAggInternalRowFromFooter( - reader, filePath.toString, dataSchema, partitionSchema, aggregation.get, - readDataSchema, isCaseSensitive) + reader, dataSchema, partitionSchema, aggregation.get, readDataSchema, isCaseSensitive) } } @@ -217,8 +216,7 @@ case class OrcPartitionReaderFactory( private lazy val batch: ColumnarBatch = { Utils.tryWithResource(createORCReader(filePath, conf)) { reader => val row = OrcUtils.createAggInternalRowFromFooter( - reader, filePath.toString, dataSchema, partitionSchema, aggregation.get, - readDataSchema, isCaseSensitive) + reader, dataSchema, partitionSchema, aggregation.get, readDataSchema, isCaseSensitive) AggregatePushDownUtils.convertAggregatesRowToBatch(row, readDataSchema, offHeap = false) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/orc/OrcScanBuilder.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/orc/OrcScanBuilder.scala index 0636edd9b4d12..5ba116142cdbc 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/orc/OrcScanBuilder.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/orc/OrcScanBuilder.scala @@ -27,7 +27,7 @@ import org.apache.spark.sql.execution.datasources.orc.OrcFilters import org.apache.spark.sql.execution.datasources.v2.FileScanBuilder import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.sources.Filter -import org.apache.spark.sql.types.{ArrayType, BinaryType, DataType, MapType, StructType, TimestampType} +import org.apache.spark.sql.types.{ArrayType, BinaryType, DataType, MapType, StringType, StructType, TimestampType} import org.apache.spark.sql.util.CaseInsensitiveStringMap case class OrcScanBuilder( @@ -82,7 +82,9 @@ case class OrcScanBuilder( dataType match { // Not push down complex and Timestamp type. // Not push down Binary type as ORC does not write min/max statistics for it. - case StructType(_) | ArrayType(_, _) | MapType(_, _, _) | TimestampType | BinaryType => + // Not push down String type as ORC truncates min/max statistics for it. 
+ case StructType(_) | ArrayType(_, _) | MapType(_, _, _) | TimestampType | BinaryType | + StringType => false case _ => true } @@ -93,8 +95,7 @@ case class OrcScanBuilder( schema, partitionNameSet, dataFilters, - isAllowedTypeForMinMaxAggregate, - sparkSession) match { + isAllowedTypeForMinMaxAggregate) match { case Some(schema) => finalSchema = schema diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/parquet/ParquetScanBuilder.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/parquet/ParquetScanBuilder.scala index 939146e57278c..5d04549e37628 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/parquet/ParquetScanBuilder.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/parquet/ParquetScanBuilder.scala @@ -111,8 +111,7 @@ case class ParquetScanBuilder( schema, partitionNameSet, dataFilters, - isAllowedTypeForMinMaxAggregate, - sparkSession) match { + isAllowedTypeForMinMaxAggregate) match { case Some(schema) => finalSchema = schema diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/FileSourceAggregatePushDownSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/FileSourceAggregatePushDownSuite.scala index f9e1925989f4a..a053b0c9a9a62 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/FileSourceAggregatePushDownSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/FileSourceAggregatePushDownSuite.scala @@ -111,7 +111,7 @@ trait FileSourceAggregatePushDownSuite spark.range(10).selectExpr("id", "id % 3 as p") .write.partitionBy("p").format(format).save(dir.getCanonicalPath) withTempView("tmp") { - spark.read.format(format).load(dir.getCanonicalPath).createOrReplaceTempView("tmp"); + spark.read.format(format).load(dir.getCanonicalPath).createOrReplaceTempView("tmp") withSQLConf(aggPushDownEnabledKey -> "true") { val max = sql("SELECT Max(p) FROM tmp") max.queryExecution.optimizedPlan.collect { @@ -126,15 +126,16 @@ trait FileSourceAggregatePushDownSuite } } - test("aggregate push down - Count(partition Col): push down") { + test("Count(partition column): push down") { withTempPath { dir => spark.range(10).selectExpr("id", "id % 3 as p") - .write.partitionBy("p").parquet(dir.getCanonicalPath) + .write.partitionBy("p").format(format).save(dir.getCanonicalPath) withTempView("tmp") { - spark.read.parquet(dir.getCanonicalPath).createOrReplaceTempView("tmp"); - Seq("false", "true").foreach { enableVectorizedReader => - withSQLConf(SQLConf.PARQUET_AGGREGATE_PUSHDOWN_ENABLED.key -> "true", - vectorizedReaderEnabledKey -> enableVectorizedReader) { + spark.read.format(format).load(dir.getCanonicalPath).createOrReplaceTempView("tmp") + val enableVectorizedReader = Seq("false", "true") + for (testVectorizedReader <- enableVectorizedReader) { + withSQLConf(aggPushDownEnabledKey -> "true", + vectorizedReaderEnabledKey -> testVectorizedReader) { val count = sql("SELECT COUNT(p) FROM tmp") count.queryExecution.optimizedPlan.collect { case _: DataSourceV2ScanRelation => From 9b0a39977f56a945109f7d9656e1b245c7594302 Mon Sep 17 00:00:00 2001 From: Cheng Su Date: Fri, 22 Oct 2021 13:30:58 -0700 Subject: [PATCH 5/7] Rebase to latest master --- .../datasources/AggregatePushDownUtils.scala | 24 ++++++++++++------- .../execution/datasources/orc/OrcUtils.scala | 10 +++----- .../datasources/v2/orc/OrcScanBuilder.scala | 17 ++----------- .../v2/parquet/ParquetScanBuilder.scala | 21 
++-------------- 4 files changed, 23 insertions(+), 49 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/AggregatePushDownUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/AggregatePushDownUtils.scala index 6e447c55939a1..6340d97af1a04 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/AggregatePushDownUtils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/AggregatePushDownUtils.scala @@ -23,7 +23,7 @@ import org.apache.spark.sql.connector.expressions.NamedReference import org.apache.spark.sql.connector.expressions.aggregate.{AggregateFunc, Aggregation, Count, CountStar, Max, Min} import org.apache.spark.sql.execution.RowToColumnConverter import org.apache.spark.sql.execution.vectorized.{OffHeapColumnVector, OnHeapColumnVector} -import org.apache.spark.sql.types.{DataType, LongType, StructField, StructType} +import org.apache.spark.sql.types.{BooleanType, ByteType, DateType, DoubleType, FloatType, IntegerType, LongType, ShortType, StructField, StructType} import org.apache.spark.sql.vectorized.{ColumnarBatch, ColumnVector} /** @@ -38,8 +38,7 @@ object AggregatePushDownUtils { aggregation: Aggregation, schema: StructType, partitionNames: Set[String], - dataFilters: Seq[Expression], - isAllowedTypeForMinMaxAggregate: DataType => Boolean): Option[StructType] = { + dataFilters: Seq[Expression]): Option[StructType] = { var finalSchema = new StructType() @@ -65,11 +64,20 @@ object AggregatePushDownUtils { } val structField = getStructFieldForCol(column) - if (isAllowedTypeForMinMaxAggregate(structField.dataType)) { - finalSchema = finalSchema.add(structField.copy(s"$aggType(" + structField.name + ")")) - true - } else { - false + structField.dataType match { + // not push down complex type + // not push down Timestamp because INT96 sort order is undefined, + // Parquet doesn't return statistics for INT96 + // not push down Parquet Binary because min/max could be truncated + // (https://issues.apache.org/jira/browse/PARQUET-1685), Parquet Binary + // could be Spark StringType, BinaryType or DecimalType. + // not push down for ORC with same reason. 
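+ // For example (illustrative), given columns (i INT, s STRING, dec DECIMAL(8, 2)),
+ // MIN(i)/MAX(i) can be pushed down, while MIN/MAX on s or dec cannot, because
+ // their min/max file statistics may be truncated (see the comments above).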
+ case BooleanType | ByteType | ShortType | IntegerType + | LongType | FloatType | DoubleType | DateType => + finalSchema = finalSchema.add(structField.copy(s"$aggType(" + structField.name + ")")) + true + case _ => + false } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcUtils.scala index 1ffddc11247ee..e1fc97e91bb8c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcUtils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcUtils.scala @@ -25,9 +25,9 @@ import scala.collection.mutable.ArrayBuffer import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs.{FileStatus, Path} -import org.apache.hadoop.hive.serde2.io.{DateWritable, HiveDecimalWritable} -import org.apache.hadoop.io.{BooleanWritable, ByteWritable, DoubleWritable, FloatWritable, IntWritable, LongWritable, ShortWritable, Text, WritableComparable} -import org.apache.orc.{BooleanColumnStatistics, ColumnStatistics, DateColumnStatistics, DecimalColumnStatistics, DoubleColumnStatistics, IntegerColumnStatistics, OrcConf, OrcFile, Reader, StringColumnStatistics, TypeDescription, Writer} +import org.apache.hadoop.hive.serde2.io.DateWritable +import org.apache.hadoop.io.{BooleanWritable, ByteWritable, DoubleWritable, FloatWritable, IntWritable, LongWritable, ShortWritable, WritableComparable} +import org.apache.orc.{BooleanColumnStatistics, ColumnStatistics, DateColumnStatistics, DoubleColumnStatistics, IntegerColumnStatistics, OrcConf, OrcFile, Reader, TypeDescription, Writer} import org.apache.spark.SPARK_VERSION_SHORT import org.apache.spark.deploy.SparkHadoopUtil @@ -436,10 +436,6 @@ object OrcUtils extends Logging { s"getMaxFromColumnStatistics should not take type $dataType" + "for DoubleColumnStatistics") } - case s: DecimalColumnStatistics => - new HiveDecimalWritable(if (isMax) s.getMaximum else s.getMinimum) - case s: StringColumnStatistics => - new Text(if (isMax) s.getMaximum else s.getMinimum) case s: DateColumnStatistics => new DateWritable( if (isMax) s.getMaximumDayOfEpoch.toInt else s.getMinimumDayOfEpoch.toInt) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/orc/OrcScanBuilder.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/orc/OrcScanBuilder.scala index 5ba116142cdbc..d2c17fda4a382 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/orc/OrcScanBuilder.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/orc/OrcScanBuilder.scala @@ -27,7 +27,7 @@ import org.apache.spark.sql.execution.datasources.orc.OrcFilters import org.apache.spark.sql.execution.datasources.v2.FileScanBuilder import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.sources.Filter -import org.apache.spark.sql.types.{ArrayType, BinaryType, DataType, MapType, StringType, StructType, TimestampType} +import org.apache.spark.sql.types.StructType import org.apache.spark.sql.util.CaseInsensitiveStringMap case class OrcScanBuilder( @@ -78,24 +78,11 @@ case class OrcScanBuilder( return false } - def isAllowedTypeForMinMaxAggregate(dataType: DataType): Boolean = { - dataType match { - // Not push down complex and Timestamp type. - // Not push down Binary type as ORC does not write min/max statistics for it. - // Not push down String type as ORC truncates min/max statistics for it. 
- case StructType(_) | ArrayType(_, _) | MapType(_, _, _) | TimestampType | BinaryType | - StringType => - false - case _ => true - } - } - AggregatePushDownUtils.getSchemaForPushedAggregation( aggregation, schema, partitionNameSet, - dataFilters, - isAllowedTypeForMinMaxAggregate) match { + dataFilters) match { case Some(schema) => finalSchema = schema diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/parquet/ParquetScanBuilder.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/parquet/ParquetScanBuilder.scala index 5d04549e37628..74d11b62b4c26 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/parquet/ParquetScanBuilder.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/parquet/ParquetScanBuilder.scala @@ -27,7 +27,7 @@ import org.apache.spark.sql.execution.datasources.parquet.{ParquetFilters, Spark import org.apache.spark.sql.execution.datasources.v2.FileScanBuilder import org.apache.spark.sql.internal.SQLConf.LegacyBehaviorPolicy import org.apache.spark.sql.sources.Filter -import org.apache.spark.sql.types.{BooleanType, ByteType, DataType, DateType, DoubleType, FloatType, IntegerType, LongType, ShortType, StructType} +import org.apache.spark.sql.types.StructType import org.apache.spark.sql.util.CaseInsensitiveStringMap case class ParquetScanBuilder( @@ -90,28 +90,11 @@ case class ParquetScanBuilder( return false } - def isAllowedTypeForMinMaxAggregate(dataType: DataType): Boolean = { - dataType match { - // not push down complex type - // not push down Timestamp because INT96 sort order is undefined, - // Parquet doesn't return statistics for INT96 - // not push down Parquet Binary because min/max could be truncated - // (https://issues.apache.org/jira/browse/PARQUET-1685), Parquet Binary - // could be Spark StringType, BinaryType or DecimalType - case BooleanType | ByteType | ShortType | IntegerType - | LongType | FloatType | DoubleType | DateType => - true - case _ => - false - } - } - AggregatePushDownUtils.getSchemaForPushedAggregation( aggregation, schema, partitionNameSet, - dataFilters, - isAllowedTypeForMinMaxAggregate) match { + dataFilters) match { case Some(schema) => finalSchema = schema From 8c7c6178ae145190a6fae6fd2024946578362312 Mon Sep 17 00:00:00 2001 From: Cheng Su Date: Tue, 26 Oct 2021 16:36:54 -0700 Subject: [PATCH 6/7] Address all comments --- .../apache/spark/sql/internal/SQLConf.scala | 5 +- .../datasources/orc/OrcColumnStatistics.java | 23 +++ .../execution/datasources/orc/OrcUtils.scala | 32 ++-- .../v2/orc/OrcPartitionReaderFactory.scala | 5 +- .../ParquetPartitionReaderFactory.scala | 2 +- .../FileSourceAggregatePushDownSuite.scala | 146 +++++++++++------- 6 files changed, 140 insertions(+), 73 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index b992f5547cd91..be5f9d42a8631 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -961,8 +961,9 @@ object SQLConf { .createWithDefault(true) val ORC_AGGREGATE_PUSHDOWN_ENABLED = buildConf("spark.sql.orc.aggregatePushdown") - .doc("If true, MAX/MIN/COUNT without filter and group by will be pushed" + - " down to ORC for optimization. MAX/MIN for complex types can't be pushed down") + .doc("If true, aggregates will be pushed down to ORC for optimization. 
Support MIN, MAX and " + + "COUNT as aggregate expression. For MIN/MAX, support boolean, integer, float and date " + + "type. For COUNT, support all data types.") .version("3.3.0") .booleanConf .createWithDefault(false) diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/orc/OrcColumnStatistics.java b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/orc/OrcColumnStatistics.java index db6dfc0ee29f1..77d82fd3d7f2d 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/orc/OrcColumnStatistics.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/orc/OrcColumnStatistics.java @@ -29,6 +29,29 @@ * this class is used to covert ORC {@link ColumnStatistics}s from array to nested tree structure, * according to data types. The flatten array stores all data types (including nested types) in * tree pre-ordering. This is used for aggregate push down in ORC. + * + * For nested data types (array, map and struct), the sub-field statistics are stored recursively + * inside parent column's `children` field. Here is an example of `OrcColumnStatistics`: + * + * Data schema: + * c1: int + * c2: struct + * c3: map + * c4: array + * + * OrcColumnStatistics + * | (children) + * --------------------------------------------- + * / | \ \ + * c1 c2 c3 c4 + * (integer) (struct) (map) (array) +* (min:1, | (children) | (children) | (children) + * max:10) ----- ----- element + * / \ / \ (integer) + * c2.f1 c2.f2 key value + * (integer) (float) (integer) (string) + * (min:0.1, (min:"a", + * max:100.5) max:"zzz") */ public class OrcColumnStatistics { private final ColumnStatistics statistics; diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcUtils.scala index e1fc97e91bb8c..56c63cb7e1977 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcUtils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcUtils.scala @@ -29,7 +29,7 @@ import org.apache.hadoop.hive.serde2.io.DateWritable import org.apache.hadoop.io.{BooleanWritable, ByteWritable, DoubleWritable, FloatWritable, IntWritable, LongWritable, ShortWritable, WritableComparable} import org.apache.orc.{BooleanColumnStatistics, ColumnStatistics, DateColumnStatistics, DoubleColumnStatistics, IntegerColumnStatistics, OrcConf, OrcFile, Reader, TypeDescription, Writer} -import org.apache.spark.SPARK_VERSION_SHORT +import org.apache.spark.{SPARK_VERSION_SHORT, SparkException} import org.apache.spark.deploy.SparkHadoopUtil import org.apache.spark.internal.Logging import org.apache.spark.sql.{SPARK_VERSION_METADATA_KEY, SparkSession} @@ -39,7 +39,7 @@ import org.apache.spark.sql.catalyst.parser.CatalystSqlParser import org.apache.spark.sql.catalyst.util.{quoteIdentifier, CharVarcharUtils} import org.apache.spark.sql.connector.expressions.aggregate.{Aggregation, Count, CountStar, Max, Min} import org.apache.spark.sql.errors.QueryExecutionErrors -import org.apache.spark.sql.execution.datasources.{PartitioningUtils, SchemaMergeUtils} +import org.apache.spark.sql.execution.datasources.SchemaMergeUtils import org.apache.spark.sql.types._ import org.apache.spark.util.{ThreadUtils, Utils} @@ -392,14 +392,21 @@ object OrcUtils extends Logging { */ def createAggInternalRowFromFooter( reader: Reader, + filePath: String, dataSchema: StructType, partitionSchema: StructType, aggregation: Aggregation, - aggSchema: 
StructType, - isCaseSensitive: Boolean): InternalRow = { + aggSchema: StructType): InternalRow = { require(aggregation.groupByColumns.length == 0, s"aggregate $aggregation with group-by column shouldn't be pushed down") - val columnsStatistics = OrcFooterReader.readStatistics(reader) + var columnsStatistics: OrcColumnStatistics = null + try { + columnsStatistics = OrcFooterReader.readStatistics(reader) + } catch { case e: RuntimeException => + throw new SparkException( + s"Cannot read columns statistics in file: $filePath. Please consider disabling " + + s"ORC aggregate push down by setting 'spark.sql.orc.aggregatePushdown' to false.", e) + } // Get column statistics with column name. def getColumnStatistics(columnName: String): ColumnStatistics = { @@ -408,10 +415,15 @@ object OrcUtils extends Logging { } // Get Min/Max statistics and store as ORC `WritableComparable` format. + // Return null if number of non-null values is zero. def getMinMaxFromColumnStatistics( statistics: ColumnStatistics, dataType: DataType, isMax: Boolean): WritableComparable[_] = { + if (statistics.getNumberOfValues == 0) { + return null + } + statistics match { case s: BooleanColumnStatistics => val value = if (isMax) s.getTrueCount > 0 else !(s.getFalseCount > 0) @@ -424,7 +436,7 @@ object OrcUtils extends Logging { case IntegerType => new IntWritable(value.toInt) case LongType => new LongWritable(value) case _ => throw new IllegalArgumentException( - s"getMaxFromColumnStatistics should not take type $dataType " + + s"getMinMaxFromColumnStatistics should not take type $dataType " + "for IntegerColumnStatistics") } case s: DoubleColumnStatistics => @@ -433,14 +445,14 @@ object OrcUtils extends Logging { case FloatType => new FloatWritable(value.toFloat) case DoubleType => new DoubleWritable(value) case _ => throw new IllegalArgumentException( - s"getMaxFromColumnStatistics should not take type $dataType" + + s"getMinMaxFromColumnStatistics should not take type $dataType " + "for DoubleColumnStatistics") } case s: DateColumnStatistics => new DateWritable( if (isMax) s.getMaximumDayOfEpoch.toInt else s.getMinimumDayOfEpoch.toInt) case _ => throw new IllegalArgumentException( - s"getMaxFromColumnStatistics should not take ${statistics.getClass.getName}: " + + s"getMinMaxFromColumnStatistics should not take ${statistics.getClass.getName}: " + s"$statistics as the ORC column statistics") } } @@ -459,9 +471,7 @@ object OrcUtils extends Logging { getMinMaxFromColumnStatistics(statistics, dataType, isMax = false) case (count: Count, _) => val columnName = count.column.fieldNames.head - val isPartitionColumn = partitionSchema.fields - .map(PartitioningUtils.getColName(_, isCaseSensitive)) - .contains(columnName) + val isPartitionColumn = partitionSchema.fields.map(_.name).contains(columnName) // NOTE: Count(columnName) doesn't include null values. // org.apache.orc.ColumnStatistics.getNumberOfValues() returns number of non-null values // for ColumnStatistics of individual column. 
In addition to this, ORC also stores number diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/orc/OrcPartitionReaderFactory.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/orc/OrcPartitionReaderFactory.scala index 55ae71c8d66ff..246f160cfb79d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/orc/OrcPartitionReaderFactory.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/orc/OrcPartitionReaderFactory.scala @@ -190,7 +190,7 @@ case class OrcPartitionReaderFactory( private lazy val row: InternalRow = { Utils.tryWithResource(createORCReader(filePath, conf)) { reader => OrcUtils.createAggInternalRowFromFooter( - reader, dataSchema, partitionSchema, aggregation.get, readDataSchema, isCaseSensitive) + reader, filePath.toString, dataSchema, partitionSchema, aggregation.get, readDataSchema) } } @@ -216,7 +216,8 @@ case class OrcPartitionReaderFactory( private lazy val batch: ColumnarBatch = { Utils.tryWithResource(createORCReader(filePath, conf)) { reader => val row = OrcUtils.createAggInternalRowFromFooter( - reader, dataSchema, partitionSchema, aggregation.get, readDataSchema, isCaseSensitive) + reader, filePath.toString, dataSchema, partitionSchema, aggregation.get, + readDataSchema) AggregatePushDownUtils.convertAggregatesRowToBatch(row, readDataSchema, offHeap = false) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/parquet/ParquetPartitionReaderFactory.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/parquet/ParquetPartitionReaderFactory.scala index afc4e46dfb115..6f021ff2e97f5 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/parquet/ParquetPartitionReaderFactory.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/parquet/ParquetPartitionReaderFactory.scala @@ -182,7 +182,7 @@ case class ParquetPartitionReaderFactory( dataSchema, partitionSchema, aggregation.get, readDataSchema, getDatetimeRebaseMode(footer.getFileMetaData), isCaseSensitive) AggregatePushDownUtils.convertAggregatesRowToBatch( - row, readDataSchema, enableOffHeapColumnVector) + row, readDataSchema, enableOffHeapColumnVector && Option(TaskContext.get()).isDefined) } else { null } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/FileSourceAggregatePushDownSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/FileSourceAggregatePushDownSuite.scala index a053b0c9a9a62..a3d01e483209a 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/FileSourceAggregatePushDownSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/FileSourceAggregatePushDownSuite.scala @@ -128,7 +128,7 @@ trait FileSourceAggregatePushDownSuite test("Count(partition column): push down") { withTempPath { dir => - spark.range(10).selectExpr("id", "id % 3 as p") + spark.range(10).selectExpr("if(id % 2 = 0, null, id) AS n", "id % 3 as p") .write.partitionBy("p").format(format).save(dir.getCanonicalPath) withTempView("tmp") { spark.read.format(format).load(dir.getCanonicalPath).createOrReplaceTempView("tmp") @@ -305,7 +305,13 @@ trait FileSourceAggregatePushDownSuite } } - test("aggregate push down - different data types") { + private def testPushDownForAllDataTypes( + inputRows: Seq[Row], + expectedMinWithAllTypes: Seq[Row], + expectedMinWithOutTSAndBinary: Seq[Row], + expectedMaxWithAllTypes: 
Seq[Row], + expectedMaxWithOutTSAndBinary: Seq[Row], + expectedCount: Seq[Row]): Unit = { implicit class StringToDate(s: String) { def date: Date = Date.valueOf(s) } @@ -314,49 +320,6 @@ trait FileSourceAggregatePushDownSuite def ts: Timestamp = Timestamp.valueOf(s) } - val rows = - Seq( - Row( - "a string", - true, - 10.toByte, - "Spark SQL".getBytes, - 12.toShort, - 3, - Long.MaxValue, - 0.15.toFloat, - 0.75D, - Decimal("12.345678"), - ("2021-01-01").date, - ("2015-01-01 23:50:59.123").ts), - Row( - "test string", - false, - 1.toByte, - "Parquet".getBytes, - 2.toShort, - null, - Long.MinValue, - 0.25.toFloat, - 0.85D, - Decimal("1.2345678"), - ("2015-01-01").date, - ("2021-01-01 23:50:59.123").ts), - Row( - null, - true, - 10000.toByte, - "Spark ML".getBytes, - 222.toShort, - 113, - 11111111L, - 0.25.toFloat, - 0.75D, - Decimal("12345.678"), - ("2004-06-19").date, - ("1999-08-26 10:43:59.123").ts) - ) - val schema = StructType(List(StructField("StringCol", StringType, true), StructField("BooleanCol", BooleanType, false), StructField("ByteCol", ByteType, false), @@ -370,7 +333,7 @@ trait FileSourceAggregatePushDownSuite StructField("DateCol", DateType, false), StructField("TimestampCol", TimestampType, false)).toArray) - val rdd = sparkContext.parallelize(rows) + val rdd = sparkContext.parallelize(inputRows) withTempPath { file => spark.createDataFrame(rdd, schema).write.format(format).save(file.getCanonicalPath) withTempView("test") { @@ -395,9 +358,7 @@ trait FileSourceAggregatePushDownSuite checkKeywordsExistsInExplain(testMinWithAllTypes, expected_plan_fragment) } - checkAnswer(testMinWithAllTypes, Seq(Row("a string", false, 1.toByte, - "Parquet".getBytes, 2.toShort, 3, -9223372036854775808L, 0.15.toFloat, 0.75D, - 1.23457, ("2004-06-19").date, ("1999-08-26 10:43:59.123").ts))) + checkAnswer(testMinWithAllTypes, expectedMinWithAllTypes) val testMinWithOutTSAndBinary = sql("SELECT min(BooleanCol), min(ByteCol), " + "min(ShortCol), min(IntegerCol), min(LongCol), min(FloatCol), " + @@ -417,8 +378,7 @@ trait FileSourceAggregatePushDownSuite checkKeywordsExistsInExplain(testMinWithOutTSAndBinary, expected_plan_fragment) } - checkAnswer(testMinWithOutTSAndBinary, Seq(Row(false, 1.toByte, - 2.toShort, 3, -9223372036854775808L, 0.15.toFloat, 0.75D, ("2004-06-19").date))) + checkAnswer(testMinWithOutTSAndBinary, expectedMinWithOutTSAndBinary) val testMaxWithAllTypes = sql("SELECT max(StringCol), max(BooleanCol), " + "max(ByteCol), max(BinaryCol), max(ShortCol), max(IntegerCol), max(LongCol), " + @@ -437,9 +397,7 @@ trait FileSourceAggregatePushDownSuite checkKeywordsExistsInExplain(testMaxWithAllTypes, expected_plan_fragment) } - checkAnswer(testMaxWithAllTypes, Seq(Row("test string", true, 16.toByte, - "Spark SQL".getBytes, 222.toShort, 113, 9223372036854775807L, 0.25.toFloat, 0.85D, - 12345.678, ("2021-01-01").date, ("2021-01-01 23:50:59.123").ts))) + checkAnswer(testMaxWithAllTypes, expectedMaxWithAllTypes) val testMaxWithoutTSAndBinary = sql("SELECT max(BooleanCol), max(ByteCol), " + "max(ShortCol), max(IntegerCol), max(LongCol), max(FloatCol), " + @@ -459,8 +417,7 @@ trait FileSourceAggregatePushDownSuite checkKeywordsExistsInExplain(testMaxWithoutTSAndBinary, expected_plan_fragment) } - checkAnswer(testMaxWithoutTSAndBinary, Seq(Row(true, 16.toByte, - 222.toShort, 113, 9223372036854775807L, 0.25.toFloat, 0.85D, ("2021-01-01").date))) + checkAnswer(testMaxWithoutTSAndBinary, expectedMaxWithOutTSAndBinary) val testCount = sql("SELECT count(StringCol), count(BooleanCol)," + " 
count(ByteCol), count(BinaryCol), count(ShortCol), count(IntegerCol)," + @@ -486,13 +443,88 @@ trait FileSourceAggregatePushDownSuite checkKeywordsExistsInExplain(testCount, expected_plan_fragment) } - checkAnswer(testCount, Seq(Row(2, 3, 3, 3, 3, 2, 3, 3, 3, 3, 3, 3))) + checkAnswer(testCount, expectedCount) } } } } } + test("aggregate push down - different data types") { + implicit class StringToDate(s: String) { + def date: Date = Date.valueOf(s) + } + + implicit class StringToTs(s: String) { + def ts: Timestamp = Timestamp.valueOf(s) + } + + val rows = + Seq( + Row( + "a string", + true, + 10.toByte, + "Spark SQL".getBytes, + 12.toShort, + 3, + Long.MaxValue, + 0.15.toFloat, + 0.75D, + Decimal("12.345678"), + ("2021-01-01").date, + ("2015-01-01 23:50:59.123").ts), + Row( + "test string", + false, + 1.toByte, + "Parquet".getBytes, + 2.toShort, + null, + Long.MinValue, + 0.25.toFloat, + 0.85D, + Decimal("1.2345678"), + ("2015-01-01").date, + ("2021-01-01 23:50:59.123").ts), + Row( + null, + true, + 10000.toByte, + "Spark ML".getBytes, + 222.toShort, + 113, + 11111111L, + 0.25.toFloat, + 0.75D, + Decimal("12345.678"), + ("2004-06-19").date, + ("1999-08-26 10:43:59.123").ts) + ) + + testPushDownForAllDataTypes( + rows, + Seq(Row("a string", false, 1.toByte, + "Parquet".getBytes, 2.toShort, 3, -9223372036854775808L, 0.15.toFloat, 0.75D, + 1.23457, ("2004-06-19").date, ("1999-08-26 10:43:59.123").ts)), + Seq(Row(false, 1.toByte, + 2.toShort, 3, -9223372036854775808L, 0.15.toFloat, 0.75D, ("2004-06-19").date)), + Seq(Row("test string", true, 16.toByte, + "Spark SQL".getBytes, 222.toShort, 113, 9223372036854775807L, 0.25.toFloat, 0.85D, + 12345.678, ("2021-01-01").date, ("2021-01-01 23:50:59.123").ts)), + Seq(Row(true, 16.toByte, + 222.toShort, 113, 9223372036854775807L, 0.25.toFloat, 0.85D, ("2021-01-01").date)), + Seq(Row(2, 3, 3, 3, 3, 2, 3, 3, 3, 3, 3, 3)) + ) + + // Test for 0 row (empty file) + val nullRow = Row.fromSeq((1 to 12).map(_ => null)) + val nullRowWithOutTSAndBinary = Row.fromSeq((1 to 8).map(_ => null)) + val zeroCount = Row.fromSeq((1 to 12).map(_ => 0)) + testPushDownForAllDataTypes(Seq.empty, Seq(nullRow), Seq(nullRowWithOutTSAndBinary), + Seq(nullRow), Seq(nullRowWithOutTSAndBinary), Seq(zeroCount)) + } + test("column name case sensitivity") { Seq("false", "true").foreach { enableVectorizedReader => withSQLConf(aggPushDownEnabledKey -> "true", From d85d4ba720a0936ee636320be8cfefe9a72565f9 Mon Sep 17 00:00:00 2001 From: Cheng Su Date: Thu, 28 Oct 2021 11:39:33 -0700 Subject: [PATCH 7/7] Addressed all comments --- .../sql/execution/datasources/orc/OrcColumnStatistics.java | 2 +- .../spark/sql/execution/datasources/orc/OrcFooterReader.java | 2 +- .../apache/spark/sql/execution/datasources/orc/OrcUtils.scala | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/orc/OrcColumnStatistics.java b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/orc/OrcColumnStatistics.java index 77d82fd3d7f2d..8adb9e8ca20be 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/orc/OrcColumnStatistics.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/orc/OrcColumnStatistics.java @@ -31,7 +31,7 @@ * tree pre-ordering. This is used for aggregate push down in ORC. * * For nested data types (array, map and struct), the sub-field statistics are stored recursively - * inside parent column's `children` field. 
Here is an example of `OrcColumnStatistics`: + * inside parent column's children field. Here is an example of {@link OrcColumnStatistics}: * * Data schema: * c1: int diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/orc/OrcFooterReader.java b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/orc/OrcFooterReader.java index 74091e9bd8074..546b048648844 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/orc/OrcFooterReader.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/orc/OrcFooterReader.java @@ -27,7 +27,7 @@ import java.util.Queue; /** - * `OrcFooterReader` is a util class which encapsulates the helper + * {@link OrcFooterReader} is a util class which encapsulates the helper * methods of reading ORC file footer. */ public class OrcFooterReader { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcUtils.scala index 56c63cb7e1977..b2624150a9151 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcUtils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcUtils.scala @@ -402,7 +402,7 @@ object OrcUtils extends Logging { var columnsStatistics: OrcColumnStatistics = null try { columnsStatistics = OrcFooterReader.readStatistics(reader) - } catch { case e: RuntimeException => + } catch { case e: Exception => throw new SparkException( s"Cannot read columns statistics in file: $filePath. Please consider disabling " + s"ORC aggregate push down by setting 'spark.sql.orc.aggregatePushdown' to false.", e)
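Below is a minimal end-to-end usage sketch of the feature added by this series. It assumes a local SparkSession; the object name, temp view name and output path are illustrative placeholders, the two configuration keys are the ones referenced in the patches above, and the explain output mentioned in the comments is indicative rather than exact.

import org.apache.spark.sql.SparkSession

object OrcAggregatePushDownExample {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder()
      .master("local[*]")
      // Enable the ORC aggregate push down introduced in this patch series.
      .config("spark.sql.orc.aggregatePushdown", "true")
      // Push down only applies to the DataSource V2 path, so take "orc" out of
      // the V1 source list (OrcV2AggregatePushDownSuite above does the same).
      .config("spark.sql.sources.useV1SourceList", "")
      .getOrCreate()
    import spark.implicits._

    // Illustrative output path.
    val path = "/tmp/orc_agg_pushdown_example"
    (1 to 100).map(i => (i, i.toDouble)).toDF("i", "d")
      .write.mode("overwrite").orc(path)
    spark.read.orc(path).createOrReplaceTempView("t")

    // MAX/MIN on integral, floating point and date columns, plus COUNT, can be
    // answered from ORC file footer statistics. When push down succeeds, the
    // optimized plan reports the pushed aggregates in the scan node
    // (e.g. a PushedAggregation entry).
    val agg = spark.sql("SELECT max(i), min(d), count(*) FROM t")
    agg.explain()
    agg.show()

    spark.stop()
  }
}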