
Commit a429aaf

huaxingao authored and chenzhx committed
[SPARK-36645][SQL] Aggregate (Min/Max/Count) push down for Parquet
### What changes were proposed in this pull request?

Push down Min/Max/Count to Parquet, with the following restrictions:

- Nested types such as Array, Map, or Struct are not pushed down.
- Timestamp is not pushed down, because the INT96 sort order is undefined and Parquet doesn't return statistics for INT96.
- If the aggregate column is a partition column, only Count is pushed down; Min and Max are not, because Parquet doesn't return min/max for partition columns.
- If a file somehow lacks statistics for the aggregate columns, Spark throws an exception.
- Currently, if a filter or GROUP BY is involved, Min/Max/Count are not pushed down; this restriction will be lifted when the filter or GROUP BY is on a partition column (https://issues.apache.org/jira/browse/SPARK-36646 and https://issues.apache.org/jira/browse/SPARK-36647).

### Why are the changes needed?

Parquet keeps min, max, and count statistics in its metadata, so we want to take advantage of that information and push Min/Max/Count down to the Parquet layer for better performance.

### Does this PR introduce _any_ user-facing change?

Yes, `SQLConf.PARQUET_AGGREGATE_PUSHDOWN_ENABLED` was added. If set to true, Min/Max/Count are pushed down to Parquet.

### How was this patch tested?

New test suites.

Closes apache#33639 from huaxingao/parquet_agg.

Authored-by: Huaxin Gao <[email protected]>
Signed-off-by: Liang-Chi Hsieh <[email protected]>
1 parent 8653763 · commit a429aaf
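As a quick illustration of the user-facing behavior, a minimal sketch, assuming an active SparkSession `spark` (the output path and column name are invented for the example):

```scala
// With the flag on and no filter or GROUP BY, the MIN/MAX/COUNT below
// qualify for footer-based pushdown instead of a full data scan.
spark.conf.set("spark.sql.parquet.aggregatePushdown", "true")
spark.range(1000).toDF("id").write.mode("overwrite").parquet("/tmp/agg_demo")

val df = spark.read.parquet("/tmp/agg_demo")
df.selectExpr("min(id)", "max(id)", "count(*)").show()
```

The pushdown is implemented in the DSv2 file scan path (note the `FileScanBuilder` change below), so on versions where Parquet defaults to the V1 source, the v2 reader may need to be enabled first (e.g. by removing parquet from `spark.sql.sources.useV1SourceList`).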

9 files changed: 984 additions & 33 deletions


sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala

Lines changed: 10 additions & 0 deletions
```diff
@@ -851,6 +851,14 @@ object SQLConf {
     .checkValue(threshold => threshold >= 0, "The threshold must not be negative.")
     .createWithDefault(10)
 
+  val PARQUET_AGGREGATE_PUSHDOWN_ENABLED = buildConf("spark.sql.parquet.aggregatePushdown")
+    .doc("If true, MAX/MIN/COUNT without filter and group by will be pushed" +
+      " down to Parquet for optimization. MAX/MIN/COUNT for complex types and timestamp" +
+      " can't be pushed down")
+    .version("3.3.0")
+    .booleanConf
+    .createWithDefault(false)
+
   val PARQUET_WRITE_LEGACY_FORMAT = buildConf("spark.sql.parquet.writeLegacyFormat")
     .doc("If true, data will be written in a way of Spark 1.4 and earlier. For example, decimal " +
       "values will be written in Apache Parquet's fixed-length byte array format, which other " +
@@ -3679,6 +3687,8 @@ class SQLConf extends Serializable with Logging {
   def parquetFilterPushDownInFilterThreshold: Int =
     getConf(PARQUET_FILTER_PUSHDOWN_INFILTERTHRESHOLD)
 
+  def parquetAggregatePushDown: Boolean = getConf(PARQUET_AGGREGATE_PUSHDOWN_ENABLED)
+
   def orcFilterPushDown: Boolean = getConf(ORC_FILTER_PUSHDOWN_ENABLED)
 
   def isOrcSchemaMergingEnabled: Boolean = getConf(ORC_SCHEMA_MERGING_ENABLED)
```
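A hedged sketch of how the new accessor is meant to be consulted; the branch bodies are placeholders, not the PR's code:

```scala
import org.apache.spark.sql.internal.SQLConf

// SQLConf.get returns the SQLConf of the active session; the accessor
// added above simply reads PARQUET_AGGREGATE_PUSHDOWN_ENABLED.
val conf = SQLConf.get
if (conf.parquetAggregatePushDown) {
  // try to answer MAX/MIN/COUNT from Parquet footer statistics
} else {
  // fall back to a normal scan plus Spark-side aggregation
}
```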

sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructType.scala

Lines changed: 1 addition & 1 deletion
```diff
@@ -115,7 +115,7 @@ case class StructType(fields: Array[StructField]) extends DataType with Seq[StructField] {
   def names: Array[String] = fieldNames
 
   private lazy val fieldNamesSet: Set[String] = fieldNames.toSet
-  private lazy val nameToField: Map[String, StructField] = fields.map(f => f.name -> f).toMap
+  private[sql] lazy val nameToField: Map[String, StructField] = fields.map(f => f.name -> f).toMap
   private lazy val nameToIndex: Map[String, Int] = fieldNames.zipWithIndex.toMap
 
   override def equals(that: Any): Boolean = {
```
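The widened visibility gives code in the `sql` package a constant-time field lookup by name. A small sketch of the behavior this map backs, with an invented schema:

```scala
import org.apache.spark.sql.types._

val schema = StructType(Seq(
  StructField("id", LongType),
  StructField("name", StringType)))

// StructType.apply(fieldName) resolves through the same nameToField map,
// so the lookup is a Map access rather than a linear scan over fields.
val field: StructField = schema("name")
```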

sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetUtils.scala

Lines changed: 227 additions & 0 deletions
```diff
@@ -16,11 +16,28 @@
  */
 package org.apache.spark.sql.execution.datasources.parquet
 
+import java.util
+
+import scala.collection.mutable
+import scala.language.existentials
+
 import org.apache.hadoop.fs.{FileStatus, Path}
 import org.apache.parquet.hadoop.ParquetFileWriter
+import org.apache.parquet.hadoop.metadata.{ColumnChunkMetaData, ParquetMetadata}
+import org.apache.parquet.io.api.Binary
+import org.apache.parquet.schema.{PrimitiveType, Types}
+import org.apache.parquet.schema.PrimitiveType.PrimitiveTypeName
 
+import org.apache.spark.SparkException
 import org.apache.spark.sql.SparkSession
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.connector.expressions.aggregate.{Aggregation, Count, CountStar, Max, Min}
+import org.apache.spark.sql.execution.RowToColumnConverter
+import org.apache.spark.sql.execution.datasources.PartitioningUtils
+import org.apache.spark.sql.execution.vectorized.{OffHeapColumnVector, OnHeapColumnVector}
+import org.apache.spark.sql.internal.SQLConf.{LegacyBehaviorPolicy, PARQUET_AGGREGATE_PUSHDOWN_ENABLED}
 import org.apache.spark.sql.types.StructType
+import org.apache.spark.sql.vectorized.{ColumnarBatch, ColumnVector}
 
 object ParquetUtils {
   def inferSchema(
@@ -127,4 +144,214 @@
     file.getName == ParquetFileWriter.PARQUET_COMMON_METADATA_FILE ||
       file.getName == ParquetFileWriter.PARQUET_METADATA_FILE
   }
+
+  /**
+   * When the partial aggregates (Max/Min/Count) are pushed down to Parquet, we don't need to
+   * createRowBaseReader to read data from Parquet and aggregate at Spark layer. Instead we want
+   * to get the partial aggregates (Max/Min/Count) result using the statistics information
+   * from Parquet footer file, and then construct an InternalRow from these aggregate results.
+   *
+   * @return Aggregate results in the format of InternalRow
+   */
+  private[sql] def createAggInternalRowFromFooter(
+      footer: ParquetMetadata,
+      filePath: String,
+      dataSchema: StructType,
+      partitionSchema: StructType,
+      aggregation: Aggregation,
+      aggSchema: StructType,
+      datetimeRebaseMode: LegacyBehaviorPolicy.Value,
+      isCaseSensitive: Boolean): InternalRow = {
+    val (primitiveTypes, values) = getPushedDownAggResult(
+      footer, filePath, dataSchema, partitionSchema, aggregation, isCaseSensitive)
+
+    val builder = Types.buildMessage
+    primitiveTypes.foreach(t => builder.addField(t))
+    val parquetSchema = builder.named("root")
+
+    val schemaConverter = new ParquetToSparkSchemaConverter
+    val converter = new ParquetRowConverter(schemaConverter, parquetSchema, aggSchema,
+      None, datetimeRebaseMode, LegacyBehaviorPolicy.CORRECTED, NoopUpdater)
+    val primitiveTypeNames = primitiveTypes.map(_.getPrimitiveTypeName)
+    primitiveTypeNames.zipWithIndex.foreach {
+      case (PrimitiveType.PrimitiveTypeName.BOOLEAN, i) =>
+        val v = values(i).asInstanceOf[Boolean]
+        converter.getConverter(i).asPrimitiveConverter.addBoolean(v)
+      case (PrimitiveType.PrimitiveTypeName.INT32, i) =>
+        val v = values(i).asInstanceOf[Integer]
+        converter.getConverter(i).asPrimitiveConverter.addInt(v)
+      case (PrimitiveType.PrimitiveTypeName.INT64, i) =>
+        val v = values(i).asInstanceOf[Long]
+        converter.getConverter(i).asPrimitiveConverter.addLong(v)
+      case (PrimitiveType.PrimitiveTypeName.FLOAT, i) =>
+        val v = values(i).asInstanceOf[Float]
+        converter.getConverter(i).asPrimitiveConverter.addFloat(v)
+      case (PrimitiveType.PrimitiveTypeName.DOUBLE, i) =>
+        val v = values(i).asInstanceOf[Double]
+        converter.getConverter(i).asPrimitiveConverter.addDouble(v)
+      case (PrimitiveType.PrimitiveTypeName.BINARY, i) =>
+        val v = values(i).asInstanceOf[Binary]
+        converter.getConverter(i).asPrimitiveConverter.addBinary(v)
+      case (PrimitiveType.PrimitiveTypeName.FIXED_LEN_BYTE_ARRAY, i) =>
+        val v = values(i).asInstanceOf[Binary]
+        converter.getConverter(i).asPrimitiveConverter.addBinary(v)
+      case (_, i) =>
+        throw new SparkException("Unexpected parquet type name: " + primitiveTypeNames(i))
+    }
+    converter.currentRecord
+  }
+
+  /**
+   * When the aggregates (Max/Min/Count) are pushed down to Parquet, in the case of
+   * PARQUET_VECTORIZED_READER_ENABLED sets to true, we don't need buildColumnarReader
+   * to read data from Parquet and aggregate at Spark layer. Instead we want
+   * to get the aggregates (Max/Min/Count) result using the statistics information
+   * from Parquet footer file, and then construct a ColumnarBatch from these aggregate results.
+   *
+   * @return Aggregate results in the format of ColumnarBatch
+   */
+  private[sql] def createAggColumnarBatchFromFooter(
+      footer: ParquetMetadata,
+      filePath: String,
+      dataSchema: StructType,
+      partitionSchema: StructType,
+      aggregation: Aggregation,
+      aggSchema: StructType,
+      offHeap: Boolean,
+      datetimeRebaseMode: LegacyBehaviorPolicy.Value,
+      isCaseSensitive: Boolean): ColumnarBatch = {
+    val row = createAggInternalRowFromFooter(
+      footer,
+      filePath,
+      dataSchema,
+      partitionSchema,
+      aggregation,
+      aggSchema,
+      datetimeRebaseMode,
+      isCaseSensitive)
+    val converter = new RowToColumnConverter(aggSchema)
+    val columnVectors = if (offHeap) {
+      OffHeapColumnVector.allocateColumns(1, aggSchema)
+    } else {
+      OnHeapColumnVector.allocateColumns(1, aggSchema)
+    }
+    converter.convert(row, columnVectors.toArray)
+    new ColumnarBatch(columnVectors.asInstanceOf[Array[ColumnVector]], 1)
+  }
+
+  /**
+   * Calculate the pushed down aggregates (Max/Min/Count) result using the statistics
+   * information from Parquet footer file.
+   *
+   * @return A tuple of `Array[PrimitiveType]` and Array[Any].
+   *         The first element is the Parquet PrimitiveType of the aggregate column,
+   *         and the second element is the aggregated value.
+   */
+  private[sql] def getPushedDownAggResult(
+      footer: ParquetMetadata,
+      filePath: String,
+      dataSchema: StructType,
+      partitionSchema: StructType,
+      aggregation: Aggregation,
+      isCaseSensitive: Boolean)
+    : (Array[PrimitiveType], Array[Any]) = {
+    val footerFileMetaData = footer.getFileMetaData
+    val fields = footerFileMetaData.getSchema.getFields
+    val blocks = footer.getBlocks
+    val primitiveTypeBuilder = mutable.ArrayBuilder.make[PrimitiveType]
+    val valuesBuilder = mutable.ArrayBuilder.make[Any]
+
+    assert(aggregation.groupByColumns.length == 0, "group by shouldn't be pushed down")
+    aggregation.aggregateExpressions.foreach { agg =>
+      var value: Any = None
+      var rowCount = 0L
+      var isCount = false
+      var index = 0
+      var schemaName = ""
+      blocks.forEach { block =>
+        val blockMetaData = block.getColumns
+        agg match {
+          case max: Max =>
+            val colName = max.column.fieldNames.head
+            index = dataSchema.fieldNames.toList.indexOf(colName)
+            schemaName = "max(" + colName + ")"
+            val currentMax = getCurrentBlockMaxOrMin(filePath, blockMetaData, index, true)
+            if (value == None || currentMax.asInstanceOf[Comparable[Any]].compareTo(value) > 0) {
+              value = currentMax
+            }
+          case min: Min =>
+            val colName = min.column.fieldNames.head
+            index = dataSchema.fieldNames.toList.indexOf(colName)
+            schemaName = "min(" + colName + ")"
+            val currentMin = getCurrentBlockMaxOrMin(filePath, blockMetaData, index, false)
+            if (value == None || currentMin.asInstanceOf[Comparable[Any]].compareTo(value) < 0) {
+              value = currentMin
+            }
+          case count: Count =>
+            schemaName = "count(" + count.column.fieldNames.head + ")"
+            rowCount += block.getRowCount
+            var isPartitionCol = false
+            if (partitionSchema.fields.map(PartitioningUtils.getColName(_, isCaseSensitive))
+              .toSet.contains(count.column.fieldNames.head)) {
+              isPartitionCol = true
+            }
+            isCount = true
+            if (!isPartitionCol) {
+              index = dataSchema.fieldNames.toList.indexOf(count.column.fieldNames.head)
+              // Count(*) includes the null values, but Count(colName) doesn't.
+              rowCount -= getNumNulls(filePath, blockMetaData, index)
+            }
+          case _: CountStar =>
+            schemaName = "count(*)"
+            rowCount += block.getRowCount
+            isCount = true
+          case _ =>
+        }
+      }
+      if (isCount) {
+        valuesBuilder += rowCount
+        primitiveTypeBuilder += Types.required(PrimitiveTypeName.INT64).named(schemaName)
+      } else {
+        valuesBuilder += value
+        val field = fields.get(index)
+        primitiveTypeBuilder += Types.required(field.asPrimitiveType.getPrimitiveTypeName)
+          .as(field.getLogicalTypeAnnotation)
+          .length(field.asPrimitiveType.getTypeLength)
+          .named(schemaName)
+      }
+    }
+    (primitiveTypeBuilder.result, valuesBuilder.result)
+  }
+
+  /**
+   * Get the Max or Min value for ith column in the current block
+   *
+   * @return the Max or Min value
+   */
+  private def getCurrentBlockMaxOrMin(
+      filePath: String,
+      columnChunkMetaData: util.List[ColumnChunkMetaData],
+      i: Int,
+      isMax: Boolean): Any = {
+    val statistics = columnChunkMetaData.get(i).getStatistics
+    if (!statistics.hasNonNullValue) {
+      throw new UnsupportedOperationException(s"No min/max found for Parquet file $filePath. " +
+        s"Set SQLConf ${PARQUET_AGGREGATE_PUSHDOWN_ENABLED.key} to false and execute again")
+    } else {
+      if (isMax) statistics.genericGetMax else statistics.genericGetMin
+    }
+  }
+
+  private def getNumNulls(
+      filePath: String,
+      columnChunkMetaData: util.List[ColumnChunkMetaData],
+      i: Int): Long = {
+    val statistics = columnChunkMetaData.get(i).getStatistics
+    if (!statistics.isNumNullsSet) {
+      throw new UnsupportedOperationException(s"Number of nulls not set for Parquet file" +
+        s" $filePath. Set SQLConf ${PARQUET_AGGREGATE_PUSHDOWN_ENABLED.key} to false and execute" +
+        s" again")
+    }
+    statistics.getNumNulls
+  }
 }
```
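For context, a standalone, hedged sketch of the per-row-group statistics these helpers consume, read directly with parquet-mr (the file path is an assumption; this is not the PR's code):

```scala
import org.apache.hadoop.conf.Configuration
import org.apache.hadoop.fs.Path
import org.apache.parquet.hadoop.ParquetFileReader
import org.apache.parquet.hadoop.util.HadoopInputFile

// Print per-row-group min/max/null-count statistics: the raw inputs that
// getPushedDownAggResult folds into one aggregate value per column.
val inputFile = HadoopInputFile.fromPath(
  new Path("/tmp/agg_demo/part-00000.parquet"), new Configuration())
val reader = ParquetFileReader.open(inputFile)
try {
  reader.getFooter.getBlocks.forEach { block =>
    block.getColumns.forEach { col =>
      val stats = col.getStatistics
      println(s"${col.getPath}: rows=${block.getRowCount}, " +
        s"min=${stats.genericGetMin}, max=${stats.genericGetMax}, " +
        s"nulls=${stats.getNumNulls}")
    }
  }
} finally {
  reader.close()
}
```

This also makes the failure mode concrete: when a writer recorded no statistics, `hasNonNullValue` or `isNumNullsSet` is false, and `getCurrentBlockMaxOrMin`/`getNumNulls` above throw, pointing users at disabling `spark.sql.parquet.aggregatePushdown`.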

sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/FileScanBuilder.scala

Lines changed: 1 addition & 1 deletion
```diff
@@ -96,6 +96,6 @@ abstract class FileScanBuilder(
   private def createRequiredNameSet(): Set[String] =
     requiredSchema.fields.map(PartitioningUtils.getColName(_, isCaseSensitive)).toSet
 
-  private val partitionNameSet: Set[String] =
+  val partitionNameSet: Set[String] =
     partitionSchema.fields.map(PartitioningUtils.getColName(_, isCaseSensitive)).toSet
 }
```
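`partitionNameSet` is widened from `private` so that the Parquet scan builder can check whether an aggregate references a partition column. A self-contained sketch of that check, with invented names:

```scala
// Illustrative only: per the commit message, Parquet footers carry no
// min/max statistics for partition columns, so only COUNT can be pushed
// down for them.
val partitionNameSet = Set("date", "country")

def minMaxPushable(aggColumn: String): Boolean =
  !partitionNameSet.contains(aggColumn)

assert(!minMaxPushable("date"))  // partition column: COUNT only
assert(minMaxPushable("price")) // data column: MIN/MAX eligible
```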
