diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/StaticSQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/StaticSQLConf.scala
index 10f54f856a19..ca1074fcf6fc 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/StaticSQLConf.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/StaticSQLConf.scala
@@ -127,6 +127,16 @@ object StaticSQLConf {
.toSequence
.createOptional
+ val SPARK_CACHE_SERIALIZER = buildStaticConf("spark.sql.cache.serializer")
+ .doc("The name of a class that implements " +
+ "org.apache.spark.sql.columnar.CachedBatchSerializer. It will be used to " +
+ "translate SQL data into a format that can more efficiently be cached. The underlying " +
+ "API is subject to change so use with caution. Multiple classes cannot be specified. " +
+ "The class must have a no-arg constructor.")
+ .version("3.1.0")
+ .stringConf
+ .createWithDefault("org.apache.spark.sql.execution.columnar.DefaultCachedBatchSerializer")
+
val QUERY_EXECUTION_LISTENERS = buildStaticConf("spark.sql.queryExecutionListeners")
.doc("List of class names implementing QueryExecutionListener that will be automatically " +
"added to newly created sessions. The classes should have either a no-arg constructor, " +
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/columnar/CachedBatchSerializer.scala b/sql/core/src/main/scala/org/apache/spark/sql/columnar/CachedBatchSerializer.scala
new file mode 100644
index 000000000000..1113e63cab33
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/columnar/CachedBatchSerializer.scala
@@ -0,0 +1,343 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.columnar
+
+import org.apache.spark.annotation.{DeveloperApi, Since}
+import org.apache.spark.internal.Logging
+import org.apache.spark.rdd.RDD
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.dsl.expressions._
+import org.apache.spark.sql.catalyst.expressions.{And, Attribute, AttributeReference, BindReferences, EqualNullSafe, EqualTo, Expression, GreaterThan, GreaterThanOrEqual, In, IsNotNull, IsNull, Length, LessThan, LessThanOrEqual, Literal, Or, Predicate, StartsWith}
+import org.apache.spark.sql.execution.columnar.{ColumnStatisticsSchema, PartitionStatistics}
+import org.apache.spark.sql.internal.SQLConf
+import org.apache.spark.sql.types.{AtomicType, BinaryType, StructType}
+import org.apache.spark.sql.vectorized.ColumnarBatch
+import org.apache.spark.storage.StorageLevel
+
+/**
+ * Basic interface that all cached batches of data must support. This is primarily to allow
+ * for metrics to be handled outside of the encoding and decoding steps in a standard way.
+ */
+@DeveloperApi
+@Since("3.1.0")
+trait CachedBatch {
+ def numRows: Int
+ def sizeInBytes: Long
+}
+
+/**
+ * Provides APIs that handle transformations of SQL data associated with the cache/persist APIs.
+ */
+@DeveloperApi
+@Since("3.1.0")
+trait CachedBatchSerializer extends Serializable {
+ /**
+ * Can `convertColumnarBatchToCachedBatch()` be called instead of
+ * `convertInternalRowToCachedBatch()` for this given schema? True if it can and false if it
+ * cannot. Columnar input is only supported if the plan could produce columnar output. Currently
+ * this is mostly supported by input formats like Parquet and ORC, but more operations are likely
+ * to be supported soon.
+ * @param schema the schema of the data being stored.
+ * @return True if columnar input can be supported, else false.
+ */
+ def supportsColumnarInput(schema: Seq[Attribute]): Boolean
+
+ /**
+ * Convert an `RDD[InternalRow]` into an `RDD[CachedBatch]` in preparation for caching the data.
+ * @param input the input `RDD` to be converted.
+ * @param schema the schema of the data being stored.
+ * @param storageLevel where the data will be stored.
+ * @param conf the config for the query.
+ * @return The data converted into a format more suitable for caching.
+ */
+ def convertInternalRowToCachedBatch(
+ input: RDD[InternalRow],
+ schema: Seq[Attribute],
+ storageLevel: StorageLevel,
+ conf: SQLConf): RDD[CachedBatch]
+
+ /**
+ * Convert an `RDD[ColumnarBatch]` into an `RDD[CachedBatch]` in preparation for caching the data.
+ * This will only be called if `supportsColumnarInput()` returned true for the given schema and
+ * the plan up to this point could produce columnar output without modifying it.
+ * @param input the input `RDD` to be converted.
+ * @param schema the schema of the data being stored.
+ * @param storageLevel where the data will be stored.
+ * @param conf the config for the query.
+ * @return The data converted into a format more suitable for caching.
+ */
+ def convertColumnarBatchToCachedBatch(
+ input: RDD[ColumnarBatch],
+ schema: Seq[Attribute],
+ storageLevel: StorageLevel,
+ conf: SQLConf): RDD[CachedBatch]
+
+ /**
+ * Builds a function that can be used to filter batches prior to being decompressed.
+ * In most cases extending [[SimpleMetricsCachedBatchSerializer]] will provide the necessary filter
+ * logic. You will need to provide metrics for this to work; [[SimpleMetricsCachedBatch]] provides
+ * the APIs to hold those metrics and explains how they are used. The metrics are really just min
+ * and max values. Note that this is intended to skip batches that are not needed, and the actual
+ * filtering of individual rows is handled later.
+ * @param predicates the set of expressions to use for filtering.
+ * @param cachedAttributes the schema/attributes of the data that is cached. This can be helpful
+ * if you don't store it with the data.
+ * @return a function that takes the partition id and the iterator of batches in the partition.
+ * It returns an iterator of batches that should be decompressed.
+ */
+ def buildFilter(
+ predicates: Seq[Expression],
+ cachedAttributes: Seq[Attribute]): (Int, Iterator[CachedBatch]) => Iterator[CachedBatch]
+
+ /**
+ * Can `convertCachedBatchToColumnarBatch()` be called instead of
+ * `convertCachedBatchToInternalRow()` for this given schema? True if it can and false if it
+ * cannot. Columnar output is typically preferred because it is more efficient. Note that
+ * `convertCachedBatchToInternalRow()` must always be supported as there are other checks that
+ * can force row-based output.
+ * @param schema the schema of the data being checked.
+ * @return true if columnar output should be used for this schema, else false.
+ */
+ def supportsColumnarOutput(schema: StructType): Boolean
+
+ /**
+ * The exact Java types of the columns that are output in columnar processing mode. This
+ * is a performance optimization for code generation and is optional.
+ * @param attributes the attributes to be output.
+ * @param conf the config for the query that will read the data.
+ */
+ def vectorTypes(attributes: Seq[Attribute], conf: SQLConf): Option[Seq[String]] = None
+
+ /**
+ * Convert the cached data into a ColumnarBatch. This is currently only used if
+ * `supportsColumnarOutput()` returns true for the associated schema, but there are other checks
+ * that can force row-based output. One of the main advantages of columnar output over row-based
+ * output is that the code generation is more standard and can be combined with code generation
+ * for downstream operations.
+ * @param input the cached batches that should be converted.
+ * @param cacheAttributes the attributes of the data in the batch.
+ * @param selectedAttributes the fields that should be loaded from the data and the order they
+ * should appear in the output batch.
+ * @param conf the configuration for the job.
+ * @return an RDD of the input cached batches transformed into the ColumnarBatch format.
+ */
+ def convertCachedBatchToColumnarBatch(
+ input: RDD[CachedBatch],
+ cacheAttributes: Seq[Attribute],
+ selectedAttributes: Seq[Attribute],
+ conf: SQLConf): RDD[ColumnarBatch]
+
+ /**
+ * Convert the cached batch into `InternalRow`s. If you want this to be performant, code
+ * generation is advised.
+ * @param input the cached batches that should be converted.
+ * @param cacheAttributes the attributes of the data in the batch.
+ * @param selectedAttributes the fields that should be loaded from the data and the order they
+ * should appear in the output rows.
+ * @param conf the configuration for the job.
+ * @return RDD of the rows that were stored in the cached batches.
+ */
+ def convertCachedBatchToInternalRow(
+ input: RDD[CachedBatch],
+ cacheAttributes: Seq[Attribute],
+ selectedAttributes: Seq[Attribute],
+ conf: SQLConf): RDD[InternalRow]
+}
+
+/**
+ * A [[CachedBatch]] that stores some simple metrics that can be used for filtering of batches with
+ * the [[SimpleMetricsCachedBatchSerializer]].
+ * The metrics are returned by the `stats` value. For each column in the batch, five fields of
+ * metadata are needed in the row.
+ */
+@DeveloperApi
+@Since("3.1.0")
+trait SimpleMetricsCachedBatch extends CachedBatch {
+ /**
+ * Holds stats for each cached column. The optional `upperBound` and `lowerBound` should be
+ * of the same type as the original column. If they are null, then it is assumed that they
+ * are not provided, and will not be used for filtering.
+ *
+ * - `lowerBound` (optional)
+ * - `upperBound` (optional)
+ * - `nullCount`: `Int`
+ * - `rowCount`: `Int`
+ * - `sizeInBytes`: `Long`
+ *
+ * These are repeated for each column in the original cached data.
+ */
+ val stats: InternalRow
+ override def sizeInBytes: Long =
+ Range.apply(4, stats.numFields, 5).map(stats.getLong).sum
+}
+
+// Currently, statistics are only used for atomic types that are not `BinaryType`.
+private object ExtractableLiteral {
+ def unapply(expr: Expression): Option[Literal] = expr match {
+ case lit: Literal => lit.dataType match {
+ case BinaryType => None
+ case _: AtomicType => Some(lit)
+ case _ => None
+ }
+ case _ => None
+ }
+}
+
+/**
+ * Provides basic filtering for [[CachedBatchSerializer]] implementations.
+ * To extend this, all of the batches produced by your serializer must be instances of
+ * [[SimpleMetricsCachedBatch]].
+ * This does not calculate the metrics that need to be stored in the batches; that is up to each
+ * implementation. The required metrics are really just min and max values, and they are optional,
+ * especially for complex types. Because those metrics are simple, and the data is likely to be
+ * compressed as well, each implementation is left to decide on the most efficient way to calculate
+ * the metrics, possibly combining that work with compression passes that might also be done across
+ * the data.
+ */
+@DeveloperApi
+@Since("3.1.0")
+abstract class SimpleMetricsCachedBatchSerializer extends CachedBatchSerializer with Logging {
+ override def buildFilter(
+ predicates: Seq[Expression],
+ cachedAttributes: Seq[Attribute]): (Int, Iterator[CachedBatch]) => Iterator[CachedBatch] = {
+ val stats = new PartitionStatistics(cachedAttributes)
+ val statsSchema = stats.schema
+
+ def statsFor(a: Attribute): ColumnStatisticsSchema = {
+ stats.forAttribute(a)
+ }
+
+ // Returned filter predicate should return false iff it is impossible for the input expression
+ // to evaluate to `true` based on statistics collected about this partition batch.
+ @transient lazy val buildFilter: PartialFunction[Expression, Expression] = {
+ case And(lhs: Expression, rhs: Expression)
+ if buildFilter.isDefinedAt(lhs) || buildFilter.isDefinedAt(rhs) =>
+ (buildFilter.lift(lhs) ++ buildFilter.lift(rhs)).reduce(_ && _)
+
+ case Or(lhs: Expression, rhs: Expression)
+ if buildFilter.isDefinedAt(lhs) && buildFilter.isDefinedAt(rhs) =>
+ buildFilter(lhs) || buildFilter(rhs)
+
+ case EqualTo(a: AttributeReference, ExtractableLiteral(l)) =>
+ statsFor(a).lowerBound <= l && l <= statsFor(a).upperBound
+ case EqualTo(ExtractableLiteral(l), a: AttributeReference) =>
+ statsFor(a).lowerBound <= l && l <= statsFor(a).upperBound
+
+ case EqualNullSafe(a: AttributeReference, ExtractableLiteral(l)) =>
+ statsFor(a).lowerBound <= l && l <= statsFor(a).upperBound
+ case EqualNullSafe(ExtractableLiteral(l), a: AttributeReference) =>
+ statsFor(a).lowerBound <= l && l <= statsFor(a).upperBound
+
+ case LessThan(a: AttributeReference, ExtractableLiteral(l)) => statsFor(a).lowerBound < l
+ case LessThan(ExtractableLiteral(l), a: AttributeReference) => l < statsFor(a).upperBound
+
+ case LessThanOrEqual(a: AttributeReference, ExtractableLiteral(l)) =>
+ statsFor(a).lowerBound <= l
+ case LessThanOrEqual(ExtractableLiteral(l), a: AttributeReference) =>
+ l <= statsFor(a).upperBound
+
+ case GreaterThan(a: AttributeReference, ExtractableLiteral(l)) => l < statsFor(a).upperBound
+ case GreaterThan(ExtractableLiteral(l), a: AttributeReference) => statsFor(a).lowerBound < l
+
+ case GreaterThanOrEqual(a: AttributeReference, ExtractableLiteral(l)) =>
+ l <= statsFor(a).upperBound
+ case GreaterThanOrEqual(ExtractableLiteral(l), a: AttributeReference) =>
+ statsFor(a).lowerBound <= l
+
+ case IsNull(a: Attribute) => statsFor(a).nullCount > 0
+ case IsNotNull(a: Attribute) => statsFor(a).count - statsFor(a).nullCount > 0
+
+ case In(a: AttributeReference, list: Seq[Expression])
+ if list.forall(ExtractableLiteral.unapply(_).isDefined) && list.nonEmpty =>
+ list.map(l => statsFor(a).lowerBound <= l.asInstanceOf[Literal] &&
+ l.asInstanceOf[Literal] <= statsFor(a).upperBound).reduce(_ || _)
+ // This is an example to explain how it works. Imagine that the id column is stored as follows:
+ // __________________________________________
+ // | Partition ID | lowerBound | upperBound |
+ // |--------------|------------|------------|
+ // | p1 | '1' | '9' |
+ // | p2 | '10' | '19' |
+ // | p3 | '20' | '29' |
+ // | p4 | '30' | '39' |
+ // | p5 | '40' | '49' |
+ // |______________|____________|____________|
+ //
+ // A filter: df.filter($"id".startsWith("2")).
+ // In this case we take substrings of lowerBound and upperBound:
+ // ________________________________________________________________________________________
+ // | Partition ID | lowerBound.substr(0, Length("2")) | upperBound.substr(0, Length("2")) |
+ // |--------------|-----------------------------------|-----------------------------------|
+ // | p1 | '1' | '9' |
+ // | p2 | '1' | '1' |
+ // | p3 | '2' | '2' |
+ // | p4 | '3' | '3' |
+ // | p5 | '4' | '4' |
+ // |______________|___________________________________|___________________________________|
+ //
+ // We can see that we only need to read p1 and p3.
+ case StartsWith(a: AttributeReference, ExtractableLiteral(l)) =>
+ statsFor(a).lowerBound.substr(0, Length(l)) <= l &&
+ l <= statsFor(a).upperBound.substr(0, Length(l))
+ }
+
+ // When we bind the filters we need to do it against the stats schema
+ val partitionFilters: Seq[Expression] = {
+ predicates.flatMap { p =>
+ val filter = buildFilter.lift(p)
+ val boundFilter =
+ filter.map(
+ BindReferences.bindReference(
+ _,
+ statsSchema,
+ allowFailures = true))
+
+ boundFilter.foreach(_ =>
+ filter.foreach(f => logInfo(s"Predicate $p generates partition filter: $f")))
+
+ // If the filter can't be resolved then we are missing required statistics.
+ boundFilter.filter(_.resolved)
+ }
+ }
+
+ def ret(index: Int, cachedBatchIterator: Iterator[CachedBatch]): Iterator[CachedBatch] = {
+ val partitionFilter = Predicate.create(
+ partitionFilters.reduceOption(And).getOrElse(Literal(true)),
+ cachedAttributes)
+
+ partitionFilter.initialize(index)
+ val schemaIndex = cachedAttributes.zipWithIndex
+
+ cachedBatchIterator.filter { cb =>
+ val cachedBatch = cb.asInstanceOf[SimpleMetricsCachedBatch]
+ if (!partitionFilter.eval(cachedBatch.stats)) {
+ logDebug {
+ val statsString = schemaIndex.map { case (a, i) =>
+ val value = cachedBatch.stats.get(i, a.dataType)
+ s"${a.name}: $value"
+ }.mkString(", ")
+ s"Skipping partition based on stats $statsString"
+ }
+ false
+ } else {
+ true
+ }
+ }
+ }
+ ret
+ }
+}
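
For reviewers, a small sketch of the stats layout a SimpleMetricsCachedBatch is expected to carry, assuming the per-column field order used by ColumnStatisticsSchema (lowerBound, upperBound, nullCount, count, sizeInBytes); the concrete values are made up for illustration.

    import org.apache.spark.sql.catalyst.InternalRow
    import org.apache.spark.unsafe.types.UTF8String

    // Two cached columns (an int column and a string column), five stats fields each.
    val stats = InternalRow.fromSeq(Seq(
      1, 49, 0, 100, 400L,                                     // int column: lower, upper, nulls, count, bytes
      UTF8String.fromString("a"), UTF8String.fromString("z"),
      2, 100, 1200L))                                          // string column

    // SimpleMetricsCachedBatch.sizeInBytes sums every fifth field starting at index 4,
    // i.e. the per-column sizeInBytes values: 400 + 1200 = 1600.
    assert(Range(4, stats.numFields, 5).map(stats.getLong).sum == 1600L)
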
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/CacheManager.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/CacheManager.scala
index 7d86c4801540..7201026b11b6 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/CacheManager.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/CacheManager.scala
@@ -27,11 +27,10 @@ import org.apache.spark.sql.catalyst.expressions.{Attribute, SubqueryExpression}
import org.apache.spark.sql.catalyst.optimizer.EliminateResolvedHint
import org.apache.spark.sql.catalyst.plans.logical.{IgnoreCachedData, LogicalPlan, ResolvedHint}
import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanHelper
-import org.apache.spark.sql.execution.columnar.InMemoryRelation
+import org.apache.spark.sql.execution.columnar.{DefaultCachedBatchSerializer, InMemoryRelation}
import org.apache.spark.sql.execution.command.CommandUtils
import org.apache.spark.sql.execution.datasources.{FileIndex, HadoopFsRelation, LogicalRelation}
import org.apache.spark.sql.execution.datasources.v2.{DataSourceV2Relation, FileTable}
-import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.storage.StorageLevel
import org.apache.spark.storage.StorageLevel.MEMORY_AND_DISK
@@ -85,11 +84,9 @@ class CacheManager extends Logging with AdaptiveSparkPlanHelper {
val inMemoryRelation = sessionWithAqeOff.withActive {
val qe = sessionWithAqeOff.sessionState.executePlan(planToCache)
InMemoryRelation(
- sessionWithAqeOff.sessionState.conf.useCompression,
- sessionWithAqeOff.sessionState.conf.columnBatchSize, storageLevel,
- qe.executedPlan,
- tableName,
- optimizedPlan = qe.optimizedPlan)
+ storageLevel,
+ qe,
+ tableName)
}
this.synchronized {
@@ -195,9 +192,7 @@ class CacheManager extends Logging with AdaptiveSparkPlanHelper {
val sessionWithAqeOff = getOrCloneSessionWithAqeOff(spark)
val newCache = sessionWithAqeOff.withActive {
val qe = sessionWithAqeOff.sessionState.executePlan(cd.plan)
- InMemoryRelation(
- cacheBuilder = cd.cachedRepresentation.cacheBuilder.copy(cachedPlan = qe.executedPlan),
- optimizedPlan = qe.optimizedPlan)
+ InMemoryRelation(cd.cachedRepresentation.cacheBuilder, qe)
}
val recomputedPlan = cd.copy(cachedRepresentation = newCache)
this.synchronized {
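
For context, a hedged sketch of the user-facing path that now reaches this code: Dataset.persist routes through CacheManager.cacheQuery, which builds the InMemoryRelation from the whole QueryExecution so the configured serializer can choose between row-based and columnar input. The parquet path is illustrative, and a SparkSession named spark (as in the earlier sketch) is assumed.

    import org.apache.spark.storage.StorageLevel

    val df = spark.read.parquet("/tmp/example")   // illustrative input path
    df.persist(StorageLevel.MEMORY_AND_DISK)      // MEMORY_AND_DISK is the default cache level
    df.count()                                    // the first action materializes the cached batches
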
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/Columnar.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/Columnar.scala
index e01cd8598db0..13ea609f7bfa 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/Columnar.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/Columnar.scala
@@ -49,6 +49,13 @@ class ColumnarRule {
def postColumnarTransitions: Rule[SparkPlan] = plan => plan
}
+/**
+ * A trait that is used as a tag to indicate a transition from columns to rows. This allows plugins
+ * to replace the current [[ColumnarToRowExec]] with an optimized version and still have operations
+ * that walk a spark plan looking for this type of transition properly match it.
+ */
+trait ColumnarToRowTransition extends UnaryExecNode
+
/**
* Provides a common executor to translate an [[RDD]] of [[ColumnarBatch]] into an [[RDD]] of
* [[InternalRow]]. This is inserted whenever such a transition is determined to be needed.
@@ -57,7 +64,7 @@ class ColumnarRule {
* [[org.apache.spark.sql.execution.python.ArrowEvalPythonExec]] and
* [[MapPartitionsInRWithArrowExec]]. Eventually this should replace those implementations.
*/
-case class ColumnarToRowExec(child: SparkPlan) extends UnaryExecNode with CodegenSupport {
+case class ColumnarToRowExec(child: SparkPlan) extends ColumnarToRowTransition with CodegenSupport {
assert(child.supportsColumnar)
override def output: Seq[Attribute] = child.output
@@ -479,7 +486,9 @@ case class RowToColumnarExec(child: SparkPlan) extends UnaryExecNode {
* Apply any user defined [[ColumnarRule]]s and find the correct place to insert transitions
* to/from columnar formatted data.
*/
-case class ApplyColumnarRulesAndInsertTransitions(conf: SQLConf, columnarRules: Seq[ColumnarRule])
+case class ApplyColumnarRulesAndInsertTransitions(
+ conf: SQLConf,
+ columnarRules: Seq[ColumnarRule])
extends Rule[SparkPlan] {
/**
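
A brief sketch of why the ColumnarToRowTransition marker trait added above matters: plan-walking code should match on the trait rather than on ColumnarToRowExec, so that an optimized replacement supplied by a plugin is still recognized. The helper below is hypothetical, not part of this patch.

    import org.apache.spark.sql.execution.{ColumnarToRowTransition, SparkPlan}

    // Collect every columnar-to-row transition in a physical plan, whether it is the built-in
    // ColumnarToRowExec or a plugin-provided replacement that extends the marker trait.
    def collectColumnarToRowTransitions(plan: SparkPlan): Seq[ColumnarToRowTransition] =
      plan.collect { case t: ColumnarToRowTransition => t }
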
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/ColumnStats.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/ColumnStats.scala
index 20ecc57c49e7..45557bfbada6 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/ColumnStats.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/ColumnStats.scala
@@ -22,7 +22,7 @@ import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeMap, Attri
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.types.{CalendarInterval, UTF8String}
-private[columnar] class ColumnStatisticsSchema(a: Attribute) extends Serializable {
+class ColumnStatisticsSchema(a: Attribute) extends Serializable {
val upperBound = AttributeReference(a.name + ".upperBound", a.dataType, nullable = true)()
val lowerBound = AttributeReference(a.name + ".lowerBound", a.dataType, nullable = true)()
val nullCount = AttributeReference(a.name + ".nullCount", IntegerType, nullable = false)()
@@ -32,7 +32,7 @@ private[columnar] class ColumnStatisticsSchema(a: Attribute) extends Serializabl
val schema = Seq(lowerBound, upperBound, nullCount, count, sizeInBytes)
}
-private[columnar] class PartitionStatistics(tableSchema: Seq[Attribute]) extends Serializable {
+class PartitionStatistics(tableSchema: Seq[Attribute]) extends Serializable {
val (forAttribute: AttributeMap[ColumnStatisticsSchema], schema: Seq[AttributeReference]) = {
val allStats = tableSchema.map(a => a -> new ColumnStatisticsSchema(a))
(AttributeMap(allStats), allStats.flatMap(_._2.schema))
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/GenerateColumnAccessor.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/GenerateColumnAccessor.scala
index bd2d06665a91..eb0663830dd6 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/GenerateColumnAccessor.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/GenerateColumnAccessor.scala
@@ -28,7 +28,7 @@ import org.apache.spark.unsafe.types.CalendarInterval
* An Iterator to walk through the InternalRows from a CachedBatch
*/
abstract class ColumnarIterator extends Iterator[InternalRow] {
- def initialize(input: Iterator[CachedBatch], columnTypes: Array[DataType],
+ def initialize(input: Iterator[DefaultCachedBatch], columnTypes: Array[DataType],
columnIndexes: Array[Int]): Unit
}
@@ -203,7 +203,8 @@ object GenerateColumnAccessor extends CodeGenerator[Seq[DataType], ColumnarItera
return false;
}
- ${classOf[CachedBatch].getName} batch = (${classOf[CachedBatch].getName}) input.next();
+ ${classOf[DefaultCachedBatch].getName} batch =
+ (${classOf[DefaultCachedBatch].getName}) input.next();
currentRow = 0;
numRowsInBatch = batch.numRows();
for (int i = 0; i < columnIndexes.length; i ++) {
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryRelation.scala
index cf9f3ddeb42a..be3dc5934e84 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryRelation.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryRelation.scala
@@ -19,33 +19,189 @@ package org.apache.spark.sql.execution.columnar
import org.apache.commons.lang3.StringUtils
+import org.apache.spark.TaskContext
import org.apache.spark.network.util.JavaUtils
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.analysis.MultiInstanceRelation
import org.apache.spark.sql.catalyst.expressions._
-import org.apache.spark.sql.catalyst.plans.QueryPlan
-import org.apache.spark.sql.catalyst.plans.logical
+import org.apache.spark.sql.catalyst.plans.{logical, QueryPlan}
import org.apache.spark.sql.catalyst.plans.logical.{ColumnStat, LogicalPlan, Statistics}
import org.apache.spark.sql.catalyst.util.truncatedString
-import org.apache.spark.sql.execution.SparkPlan
+import org.apache.spark.sql.columnar.{CachedBatch, CachedBatchSerializer, SimpleMetricsCachedBatch, SimpleMetricsCachedBatchSerializer}
+import org.apache.spark.sql.execution.{ColumnarToRowTransition, InputAdapter, QueryExecution, SparkPlan, WholeStageCodegenExec}
+import org.apache.spark.sql.execution.vectorized.{OffHeapColumnVector, OnHeapColumnVector, WritableColumnVector}
+import org.apache.spark.sql.internal.{SQLConf, StaticSQLConf}
+import org.apache.spark.sql.types.{BooleanType, ByteType, DoubleType, FloatType, IntegerType, LongType, ShortType, StructType, UserDefinedType}
+import org.apache.spark.sql.vectorized.{ColumnarBatch, ColumnVector}
import org.apache.spark.storage.StorageLevel
-import org.apache.spark.util.LongAccumulator
-
+import org.apache.spark.util.{LongAccumulator, Utils}
/**
- * CachedBatch is a cached batch of rows.
+ * The default implementation of CachedBatch.
*
* @param numRows The total number of rows in this batch
* @param buffers The buffers for serialized columns
* @param stats The stat of columns
*/
-private[columnar]
-case class CachedBatch(numRows: Int, buffers: Array[Array[Byte]], stats: InternalRow)
+case class DefaultCachedBatch(numRows: Int, buffers: Array[Array[Byte]], stats: InternalRow)
+ extends SimpleMetricsCachedBatch
+
+/**
+ * The default implementation of CachedBatchSerializer.
+ */
+class DefaultCachedBatchSerializer extends SimpleMetricsCachedBatchSerializer {
+ override def supportsColumnarInput(schema: Seq[Attribute]): Boolean = false
+
+ override def convertColumnarBatchToCachedBatch(
+ input: RDD[ColumnarBatch],
+ schema: Seq[Attribute],
+ storageLevel: StorageLevel,
+ conf: SQLConf): RDD[CachedBatch] =
+ throw new IllegalStateException("Columnar input is not supported")
+
+ override def convertInternalRowToCachedBatch(
+ input: RDD[InternalRow],
+ schema: Seq[Attribute],
+ storageLevel: StorageLevel,
+ conf: SQLConf): RDD[CachedBatch] = {
+ val batchSize = conf.columnBatchSize
+ val useCompression = conf.useCompression
+ convertForCacheInternal(input, schema, batchSize, useCompression)
+ }
+
+ def convertForCacheInternal(
+ input: RDD[InternalRow],
+ output: Seq[Attribute],
+ batchSize: Int,
+ useCompression: Boolean): RDD[CachedBatch] = {
+ input.mapPartitionsInternal { rowIterator =>
+ new Iterator[DefaultCachedBatch] {
+ def next(): DefaultCachedBatch = {
+ val columnBuilders = output.map { attribute =>
+ ColumnBuilder(attribute.dataType, batchSize, attribute.name, useCompression)
+ }.toArray
+
+ var rowCount = 0
+ var totalSize = 0L
+ while (rowIterator.hasNext && rowCount < batchSize
+ && totalSize < ColumnBuilder.MAX_BATCH_SIZE_IN_BYTE) {
+ val row = rowIterator.next()
+
+ // Added for SPARK-6082. This assertion can be useful for scenarios when something
+ // like Hive TRANSFORM is used. The external data generation script used in TRANSFORM
+ // may result in malformed rows, causing ArrayIndexOutOfBoundsException, which is somewhat
+ // hard to decipher.
+ assert(
+ row.numFields == columnBuilders.length,
+ s"Row column number mismatch, expected ${output.size} columns, " +
+ s"but got ${row.numFields}." +
+ s"\nRow content: $row")
+
+ var i = 0
+ totalSize = 0
+ while (i < row.numFields) {
+ columnBuilders(i).appendFrom(row, i)
+ totalSize += columnBuilders(i).columnStats.sizeInBytes
+ i += 1
+ }
+ rowCount += 1
+ }
+
+ val stats = InternalRow.fromSeq(
+ columnBuilders.flatMap(_.columnStats.collectedStatistics).toSeq)
+ DefaultCachedBatch(rowCount, columnBuilders.map { builder =>
+ JavaUtils.bufferToArray(builder.build())
+ }, stats)
+ }
+
+ def hasNext: Boolean = rowIterator.hasNext
+ }
+ }
+ }
+
+ override def supportsColumnarOutput(schema: StructType): Boolean = schema.fields.forall(f =>
+ f.dataType match {
+ // More types can be supported, but this is to match the original implementation that
+ // only supported primitive types "for ease of review"
+ case BooleanType | ByteType | ShortType | IntegerType | LongType |
+ FloatType | DoubleType => true
+ case _ => false
+ })
+
+ override def vectorTypes(attributes: Seq[Attribute], conf: SQLConf): Option[Seq[String]] =
+ Option(Seq.fill(attributes.length)(
+ if (!conf.offHeapColumnVectorEnabled) {
+ classOf[OnHeapColumnVector].getName
+ } else {
+ classOf[OffHeapColumnVector].getName
+ }
+ ))
+
+ override def convertCachedBatchToColumnarBatch(
+ input: RDD[CachedBatch],
+ cacheAttributes: Seq[Attribute],
+ selectedAttributes: Seq[Attribute],
+ conf: SQLConf): RDD[ColumnarBatch] = {
+ val offHeapColumnVectorEnabled = conf.offHeapColumnVectorEnabled
+ val outputSchema = StructType.fromAttributes(selectedAttributes)
+ val columnIndices =
+ selectedAttributes.map(a => cacheAttributes.map(o => o.exprId).indexOf(a.exprId)).toArray
+
+ def createAndDecompressColumn(cb: CachedBatch): ColumnarBatch = {
+ val cachedColumnarBatch = cb.asInstanceOf[DefaultCachedBatch]
+ val rowCount = cachedColumnarBatch.numRows
+ val taskContext = Option(TaskContext.get())
+ val columnVectors = if (!offHeapColumnVectorEnabled || taskContext.isEmpty) {
+ OnHeapColumnVector.allocateColumns(rowCount, outputSchema)
+ } else {
+ OffHeapColumnVector.allocateColumns(rowCount, outputSchema)
+ }
+ val columnarBatch = new ColumnarBatch(columnVectors.asInstanceOf[Array[ColumnVector]])
+ columnarBatch.setNumRows(rowCount)
+
+ for (i <- selectedAttributes.indices) {
+ ColumnAccessor.decompress(
+ cachedColumnarBatch.buffers(columnIndices(i)),
+ columnarBatch.column(i).asInstanceOf[WritableColumnVector],
+ outputSchema.fields(i).dataType, rowCount)
+ }
+ taskContext.foreach(_.addTaskCompletionListener[Unit](_ => columnarBatch.close()))
+ columnarBatch
+ }
+
+ input.map(createAndDecompressColumn)
+ }
+
+ override def convertCachedBatchToInternalRow(
+ input: RDD[CachedBatch],
+ cacheAttributes: Seq[Attribute],
+ selectedAttributes: Seq[Attribute],
+ conf: SQLConf): RDD[InternalRow] = {
+ // Find the ordinals and data types of the requested columns.
+ val (requestedColumnIndices, requestedColumnDataTypes) =
+ selectedAttributes.map { a =>
+ cacheAttributes.map(_.exprId).indexOf(a.exprId) -> a.dataType
+ }.unzip
+
+ val columnTypes = requestedColumnDataTypes.map {
+ case udt: UserDefinedType[_] => udt.sqlType
+ case other => other
+ }.toArray
+
+ input.mapPartitionsInternal { cachedBatchIterator =>
+ val columnarIterator = GenerateColumnAccessor.generate(columnTypes)
+ columnarIterator.initialize(cachedBatchIterator.asInstanceOf[Iterator[DefaultCachedBatch]],
+ columnTypes,
+ requestedColumnIndices.toArray)
+ columnarIterator
+ }
+ }
+}
+private[sql]
case class CachedRDDBuilder(
- useCompression: Boolean,
- batchSize: Int,
+ serializer: CachedBatchSerializer,
storageLevel: StorageLevel,
@transient cachedPlan: SparkPlan,
tableName: Option[String]) {
@@ -85,54 +241,24 @@ case class CachedRDDBuilder(
}
private def buildBuffers(): RDD[CachedBatch] = {
- val output = cachedPlan.output
- val cached = cachedPlan.execute().mapPartitionsInternal { rowIterator =>
- new Iterator[CachedBatch] {
- def next(): CachedBatch = {
- val columnBuilders = output.map { attribute =>
- ColumnBuilder(attribute.dataType, batchSize, attribute.name, useCompression)
- }.toArray
-
- var rowCount = 0
- var totalSize = 0L
- while (rowIterator.hasNext && rowCount < batchSize
- && totalSize < ColumnBuilder.MAX_BATCH_SIZE_IN_BYTE) {
- val row = rowIterator.next()
-
- // Added for SPARK-6082. This assertion can be useful for scenarios when something
- // like Hive TRANSFORM is used. The external data generation script used in TRANSFORM
- // may result malformed rows, causing ArrayIndexOutOfBoundsException, which is somewhat
- // hard to decipher.
- assert(
- row.numFields == columnBuilders.length,
- s"Row column number mismatch, expected ${output.size} columns, " +
- s"but got ${row.numFields}." +
- s"\nRow content: $row")
-
- var i = 0
- totalSize = 0
- while (i < row.numFields) {
- columnBuilders(i).appendFrom(row, i)
- totalSize += columnBuilders(i).columnStats.sizeInBytes
- i += 1
- }
- rowCount += 1
- }
-
- sizeInBytesStats.add(totalSize)
- rowCountStats.add(rowCount)
-
- val stats = InternalRow.fromSeq(
- columnBuilders.flatMap(_.columnStats.collectedStatistics).toSeq)
- CachedBatch(rowCount, columnBuilders.map { builder =>
- JavaUtils.bufferToArray(builder.build())
- }, stats)
- }
-
- def hasNext: Boolean = rowIterator.hasNext
- }
+ val cb = if (cachedPlan.supportsColumnar) {
+ serializer.convertColumnarBatchToCachedBatch(
+ cachedPlan.executeColumnar(),
+ cachedPlan.output,
+ storageLevel,
+ cachedPlan.conf)
+ } else {
+ serializer.convertInternalRowToCachedBatch(
+ cachedPlan.execute(),
+ cachedPlan.output,
+ storageLevel,
+ cachedPlan.conf)
+ }
+ val cached = cb.map { batch =>
+ sizeInBytesStats.add(batch.sizeInBytes)
+ rowCountStats.add(batch.numRows)
+ batch
}.persist(storageLevel)
-
cached.setName(cachedName)
cached
}
@@ -140,22 +266,71 @@ case class CachedRDDBuilder(
object InMemoryRelation {
+ private[this] var ser: Option[CachedBatchSerializer] = None
+ private[this] def getSerializer(sqlConf: SQLConf): CachedBatchSerializer = synchronized {
+ if (ser.isEmpty) {
+ val serName = sqlConf.getConf(StaticSQLConf.SPARK_CACHE_SERIALIZER)
+ val serClass = Utils.classForName(serName)
+ val instance = serClass.getConstructor().newInstance().asInstanceOf[CachedBatchSerializer]
+ ser = Some(instance)
+ }
+ ser.get
+ }
+
+ def convertToColumnarIfPossible(plan: SparkPlan): SparkPlan = plan match {
+ case gen: WholeStageCodegenExec => gen.child match {
+ case c2r: ColumnarToRowTransition => c2r.child match {
+ case ia: InputAdapter => ia.child
+ case _ => plan
+ }
+ case _ => plan
+ }
+ case c2r: ColumnarToRowTransition => // This matches when whole stage code gen is disabled.
+ c2r.child
+ case _ => plan
+ }
+
def apply(
- useCompression: Boolean,
- batchSize: Int,
+ storageLevel: StorageLevel,
+ qe: QueryExecution,
+ tableName: Option[String]): InMemoryRelation = {
+ val optimizedPlan = qe.optimizedPlan
+ val serializer = getSerializer(optimizedPlan.conf)
+ val child = if (serializer.supportsColumnarInput(optimizedPlan.output)) {
+ convertToColumnarIfPossible(qe.executedPlan)
+ } else {
+ qe.executedPlan
+ }
+ val cacheBuilder = CachedRDDBuilder(serializer, storageLevel, child, tableName)
+ val relation = new InMemoryRelation(child.output, cacheBuilder, optimizedPlan.outputOrdering)
+ relation.statsOfPlanToCache = optimizedPlan.stats
+ relation
+ }
+
+ /**
+ * This API is intended only to be used for testing.
+ */
+ def apply(
+ serializer: CachedBatchSerializer,
storageLevel: StorageLevel,
child: SparkPlan,
tableName: Option[String],
optimizedPlan: LogicalPlan): InMemoryRelation = {
- val cacheBuilder = CachedRDDBuilder(useCompression, batchSize, storageLevel, child, tableName)
+ val cacheBuilder = CachedRDDBuilder(serializer, storageLevel, child, tableName)
val relation = new InMemoryRelation(child.output, cacheBuilder, optimizedPlan.outputOrdering)
relation.statsOfPlanToCache = optimizedPlan.stats
relation
}
- def apply(cacheBuilder: CachedRDDBuilder, optimizedPlan: LogicalPlan): InMemoryRelation = {
+ def apply(cacheBuilder: CachedRDDBuilder, qe: QueryExecution): InMemoryRelation = {
+ val optimizedPlan = qe.optimizedPlan
+ val newBuilder = if (cacheBuilder.serializer.supportsColumnarInput(optimizedPlan.output)) {
+ cacheBuilder.copy(cachedPlan = convertToColumnarIfPossible(qe.executedPlan))
+ } else {
+ cacheBuilder.copy(cachedPlan = qe.executedPlan)
+ }
val relation = new InMemoryRelation(
- cacheBuilder.cachedPlan.output, cacheBuilder, optimizedPlan.outputOrdering)
+ newBuilder.cachedPlan.output, newBuilder, optimizedPlan.outputOrdering)
relation.statsOfPlanToCache = optimizedPlan.stats
relation
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryTableScanExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryTableScanExec.scala
index f03c2586048b..e4194562b7a5 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryTableScanExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryTableScanExec.scala
@@ -17,19 +17,15 @@
package org.apache.spark.sql.execution.columnar
-import org.apache.spark.TaskContext
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.catalyst.InternalRow
-import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans.QueryPlan
import org.apache.spark.sql.catalyst.plans.physical.Partitioning
+import org.apache.spark.sql.columnar.CachedBatch
import org.apache.spark.sql.execution.{LeafExecNode, SparkPlan, WholeStageCodegenExec}
import org.apache.spark.sql.execution.metric.SQLMetrics
-import org.apache.spark.sql.execution.vectorized._
-import org.apache.spark.sql.types._
-import org.apache.spark.sql.vectorized.{ColumnarBatch, ColumnVector}
-
+import org.apache.spark.sql.vectorized.ColumnarBatch
case class InMemoryTableScanExec(
attributes: Seq[Attribute],
@@ -57,68 +53,29 @@ case class InMemoryTableScanExec(
relation = relation.canonicalized.asInstanceOf[InMemoryRelation])
override def vectorTypes: Option[Seq[String]] =
- Option(Seq.fill(attributes.length)(
- if (!conf.offHeapColumnVectorEnabled) {
- classOf[OnHeapColumnVector].getName
- } else {
- classOf[OffHeapColumnVector].getName
- }
- ))
+ relation.cacheBuilder.serializer.vectorTypes(attributes, conf)
/**
* If true, get data from ColumnVector in ColumnarBatch, which are generally faster.
* If false, get data from UnsafeRow build from CachedBatch
*/
override val supportsColumnar: Boolean = {
- // In the initial implementation, for ease of review
- // support only primitive data types and # of fields is less than wholeStageMaxNumFields
- conf.cacheVectorizedReaderEnabled && relation.schema.fields.forall(f => f.dataType match {
- case BooleanType | ByteType | ShortType | IntegerType | LongType |
- FloatType | DoubleType => true
- case _ => false
- }) && !WholeStageCodegenExec.isTooManyFields(conf, relation.schema)
- }
-
- private val columnIndices =
- attributes.map(a => relation.output.map(o => o.exprId).indexOf(a.exprId)).toArray
-
- private val relationSchema = relation.schema.toArray
-
- private lazy val columnarBatchSchema = new StructType(columnIndices.map(i => relationSchema(i)))
-
- private def createAndDecompressColumn(
- cachedColumnarBatch: CachedBatch,
- offHeapColumnVectorEnabled: Boolean): ColumnarBatch = {
- val rowCount = cachedColumnarBatch.numRows
- val taskContext = Option(TaskContext.get())
- val columnVectors = if (!offHeapColumnVectorEnabled || taskContext.isEmpty) {
- OnHeapColumnVector.allocateColumns(rowCount, columnarBatchSchema)
- } else {
- OffHeapColumnVector.allocateColumns(rowCount, columnarBatchSchema)
- }
- val columnarBatch = new ColumnarBatch(columnVectors.asInstanceOf[Array[ColumnVector]])
- columnarBatch.setNumRows(rowCount)
-
- for (i <- attributes.indices) {
- ColumnAccessor.decompress(
- cachedColumnarBatch.buffers(columnIndices(i)),
- columnarBatch.column(i).asInstanceOf[WritableColumnVector],
- columnarBatchSchema.fields(i).dataType, rowCount)
- }
- taskContext.foreach(_.addTaskCompletionListener[Unit](_ => columnarBatch.close()))
- columnarBatch
+ conf.cacheVectorizedReaderEnabled &&
+ !WholeStageCodegenExec.isTooManyFields(conf, relation.schema) &&
+ relation.cacheBuilder.serializer.supportsColumnarOutput(relation.schema)
}
private lazy val columnarInputRDD: RDD[ColumnarBatch] = {
val numOutputRows = longMetric("numOutputRows")
val buffers = filteredCachedBatches()
- val offHeapColumnVectorEnabled = conf.offHeapColumnVectorEnabled
- buffers
- .map(createAndDecompressColumn(_, offHeapColumnVectorEnabled))
- .map { buffer =>
- numOutputRows += buffer.numRows()
- buffer
- }
+ relation.cacheBuilder.serializer.convertCachedBatchToColumnarBatch(
+ buffers,
+ relation.output,
+ attributes,
+ conf).map { cb =>
+ numOutputRows += cb.numRows()
+ cb
+ }
}
private lazy val inputRDD: RDD[InternalRow] = {
@@ -130,35 +87,24 @@ case class InMemoryTableScanExec(
val numOutputRows = longMetric("numOutputRows")
// Using these variables here to avoid serialization of entire objects (if referenced
// directly) within the map Partitions closure.
- val relOutput: AttributeSeq = relation.output
-
- filteredCachedBatches().mapPartitionsInternal { cachedBatchIterator =>
- // Find the ordinals and data types of the requested columns.
- val (requestedColumnIndices, requestedColumnDataTypes) =
- attributes.map { a =>
- relOutput.indexOf(a.exprId) -> a.dataType
- }.unzip
-
- // update SQL metrics
- val withMetrics = cachedBatchIterator.map { batch =>
- if (enableAccumulatorsForTest) {
- readBatches.add(1)
+ val relOutput = relation.output
+ val serializer = relation.cacheBuilder.serializer
+
+ // update SQL metrics
+ val withMetrics =
+ filteredCachedBatches().mapPartitionsInternal { iter =>
+ if (enableAccumulatorsForTest && iter.hasNext) {
+ readPartitions.add(1)
+ }
+ iter.map { batch =>
+ if (enableAccumulatorsForTest) {
+ readBatches.add(1)
+ }
+ numOutputRows += batch.numRows
+ batch
}
- numOutputRows += batch.numRows
- batch
- }
-
- val columnTypes = requestedColumnDataTypes.map {
- case udt: UserDefinedType[_] => udt.sqlType
- case other => other
- }.toArray
- val columnarIterator = GenerateColumnAccessor.generate(columnTypes)
- columnarIterator.initialize(withMetrics, columnTypes, requestedColumnIndices.toArray)
- if (enableAccumulatorsForTest && columnarIterator.hasNext) {
- readPartitions.add(1)
}
- columnarIterator
- }
+ serializer.convertCachedBatchToInternalRow(withMetrics, relOutput, attributes, conf)
}
override def output: Seq[Attribute] = attributes
@@ -186,114 +132,6 @@ case class InMemoryTableScanExec(
override def outputOrdering: Seq[SortOrder] =
relation.cachedPlan.outputOrdering.map(updateAttribute(_).asInstanceOf[SortOrder])
- // Keeps relation's partition statistics because we don't serialize relation.
- private val stats = relation.partitionStatistics
- private def statsFor(a: Attribute) = stats.forAttribute(a)
-
- // Currently, only use statistics from atomic types except binary type only.
- private object ExtractableLiteral {
- def unapply(expr: Expression): Option[Literal] = expr match {
- case lit: Literal => lit.dataType match {
- case BinaryType => None
- case _: AtomicType => Some(lit)
- case _ => None
- }
- case _ => None
- }
- }
-
- // Returned filter predicate should return false iff it is impossible for the input expression
- // to evaluate to `true` based on statistics collected about this partition batch.
- @transient lazy val buildFilter: PartialFunction[Expression, Expression] = {
- case And(lhs: Expression, rhs: Expression)
- if buildFilter.isDefinedAt(lhs) || buildFilter.isDefinedAt(rhs) =>
- (buildFilter.lift(lhs) ++ buildFilter.lift(rhs)).reduce(_ && _)
-
- case Or(lhs: Expression, rhs: Expression)
- if buildFilter.isDefinedAt(lhs) && buildFilter.isDefinedAt(rhs) =>
- buildFilter(lhs) || buildFilter(rhs)
-
- case EqualTo(a: AttributeReference, ExtractableLiteral(l)) =>
- statsFor(a).lowerBound <= l && l <= statsFor(a).upperBound
- case EqualTo(ExtractableLiteral(l), a: AttributeReference) =>
- statsFor(a).lowerBound <= l && l <= statsFor(a).upperBound
-
- case EqualNullSafe(a: AttributeReference, ExtractableLiteral(l)) =>
- statsFor(a).lowerBound <= l && l <= statsFor(a).upperBound
- case EqualNullSafe(ExtractableLiteral(l), a: AttributeReference) =>
- statsFor(a).lowerBound <= l && l <= statsFor(a).upperBound
-
- case LessThan(a: AttributeReference, ExtractableLiteral(l)) => statsFor(a).lowerBound < l
- case LessThan(ExtractableLiteral(l), a: AttributeReference) => l < statsFor(a).upperBound
-
- case LessThanOrEqual(a: AttributeReference, ExtractableLiteral(l)) =>
- statsFor(a).lowerBound <= l
- case LessThanOrEqual(ExtractableLiteral(l), a: AttributeReference) =>
- l <= statsFor(a).upperBound
-
- case GreaterThan(a: AttributeReference, ExtractableLiteral(l)) => l < statsFor(a).upperBound
- case GreaterThan(ExtractableLiteral(l), a: AttributeReference) => statsFor(a).lowerBound < l
-
- case GreaterThanOrEqual(a: AttributeReference, ExtractableLiteral(l)) =>
- l <= statsFor(a).upperBound
- case GreaterThanOrEqual(ExtractableLiteral(l), a: AttributeReference) =>
- statsFor(a).lowerBound <= l
-
- case IsNull(a: Attribute) => statsFor(a).nullCount > 0
- case IsNotNull(a: Attribute) => statsFor(a).count - statsFor(a).nullCount > 0
-
- case In(a: AttributeReference, list: Seq[Expression])
- if list.forall(ExtractableLiteral.unapply(_).isDefined) && list.nonEmpty =>
- list.map(l => statsFor(a).lowerBound <= l.asInstanceOf[Literal] &&
- l.asInstanceOf[Literal] <= statsFor(a).upperBound).reduce(_ || _)
-
- // This is an example to explain how it works, imagine that the id column stored as follows:
- // __________________________________________
- // | Partition ID | lowerBound | upperBound |
- // |--------------|------------|------------|
- // | p1 | '1' | '9' |
- // | p2 | '10' | '19' |
- // | p3 | '20' | '29' |
- // | p4 | '30' | '39' |
- // | p5 | '40' | '49' |
- // |______________|____________|____________|
- //
- // A filter: df.filter($"id".startsWith("2")).
- // In this case it substr lowerBound and upperBound:
- // ________________________________________________________________________________________
- // | Partition ID | lowerBound.substr(0, Length("2")) | upperBound.substr(0, Length("2")) |
- // |--------------|-----------------------------------|-----------------------------------|
- // | p1 | '1' | '9' |
- // | p2 | '1' | '1' |
- // | p3 | '2' | '2' |
- // | p4 | '3' | '3' |
- // | p5 | '4' | '4' |
- // |______________|___________________________________|___________________________________|
- //
- // We can see that we only need to read p1 and p3.
- case StartsWith(a: AttributeReference, ExtractableLiteral(l)) =>
- statsFor(a).lowerBound.substr(0, Length(l)) <= l &&
- l <= statsFor(a).upperBound.substr(0, Length(l))
- }
-
- lazy val partitionFilters: Seq[Expression] = {
- predicates.flatMap { p =>
- val filter = buildFilter.lift(p)
- val boundFilter =
- filter.map(
- BindReferences.bindReference(
- _,
- stats.schema,
- allowFailures = true))
-
- boundFilter.foreach(_ =>
- filter.foreach(f => logInfo(s"Predicate $p generates partition filter: $f")))
-
- // If the filter can't be resolved then we are missing required statistics.
- boundFilter.filter(_.resolved)
- }
- }
-
lazy val enableAccumulatorsForTest: Boolean = sqlContext.conf.inMemoryTableScanStatisticsEnabled
// Accumulators used for testing purposes
@@ -303,37 +141,13 @@ case class InMemoryTableScanExec(
private val inMemoryPartitionPruningEnabled = sqlContext.conf.inMemoryPartitionPruning
private def filteredCachedBatches(): RDD[CachedBatch] = {
- // Using these variables here to avoid serialization of entire objects (if referenced directly)
- // within the map Partitions closure.
- val schema = stats.schema
- val schemaIndex = schema.zipWithIndex
val buffers = relation.cacheBuilder.cachedColumnBuffers
- buffers.mapPartitionsWithIndexInternal { (index, cachedBatchIterator) =>
- val partitionFilter = Predicate.create(
- partitionFilters.reduceOption(And).getOrElse(Literal(true)),
- schema)
- partitionFilter.initialize(index)
-
- // Do partition batch pruning if enabled
- if (inMemoryPartitionPruningEnabled) {
- cachedBatchIterator.filter { cachedBatch =>
- if (!partitionFilter.eval(cachedBatch.stats)) {
- logDebug {
- val statsString = schemaIndex.map { case (a, i) =>
- val value = cachedBatch.stats.get(i, a.dataType)
- s"${a.name}: $value"
- }.mkString(", ")
- s"Skipping partition based on stats $statsString"
- }
- false
- } else {
- true
- }
- }
- } else {
- cachedBatchIterator
- }
+ if (inMemoryPartitionPruningEnabled) {
+ val filterFunc = relation.cacheBuilder.serializer.buildFilter(predicates, relation.output)
+ buffers.mapPartitionsWithIndexInternal(filterFunc)
+ } else {
+ buffers
}
}
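
A hedged sketch of the pruning path this rewires: with spark.sql.inMemoryColumnarStorage.partitionPruning enabled (the default), the scan asks the serializer's buildFilter for a batch-level filter built from the pushed-down predicates. A SparkSession named spark is assumed and the data is illustrative.

    import spark.implicits._

    spark.conf.set("spark.sql.inMemoryColumnarStorage.partitionPruning", "true")
    val cached = spark.range(0, 1000, 1, 10).toDF("id").cache()
    cached.count()                        // materializes one cached batch per partition
    // Batches whose max id is not above 900 can be skipped before decompression.
    cached.filter($"id" > 900).collect()
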
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/CachedBatchSerializerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/CachedBatchSerializerSuite.scala
new file mode 100644
index 000000000000..72eba7f6e690
--- /dev/null
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/CachedBatchSerializerSuite.scala
@@ -0,0 +1,134 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.columnar
+
+import org.apache.spark.SparkConf
+import org.apache.spark.rdd.RDD
+import org.apache.spark.sql.{QueryTest, Row}
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.expressions.{Attribute, Expression}
+import org.apache.spark.sql.columnar.{CachedBatch, CachedBatchSerializer}
+import org.apache.spark.sql.execution.vectorized.OnHeapColumnVector
+import org.apache.spark.sql.internal.{SQLConf, StaticSQLConf}
+import org.apache.spark.sql.test.SharedSparkSession
+import org.apache.spark.sql.types.{IntegerType, StructType}
+import org.apache.spark.sql.vectorized.{ColumnarBatch, ColumnVector}
+import org.apache.spark.storage.StorageLevel
+
+case class SingleIntCachedBatch(data: Array[Int]) extends CachedBatch {
+ override def numRows: Int = data.length
+ override def sizeInBytes: Long = 4 * data.length
+}
+
+/**
+ * Very simple serializer that only supports a single int column, but does support columnar.
+ */
+class TestSingleIntColumnarCachedBatchSerializer extends CachedBatchSerializer {
+ override def supportsColumnarInput(schema: Seq[Attribute]): Boolean = true
+
+ override def convertInternalRowToCachedBatch(
+ input: RDD[InternalRow],
+ schema: Seq[Attribute],
+ storageLevel: StorageLevel,
+ conf: SQLConf): RDD[CachedBatch] = {
+ throw new IllegalStateException("This does not work. This is only for testing")
+ }
+
+ override def convertColumnarBatchToCachedBatch(
+ input: RDD[ColumnarBatch],
+ schema: Seq[Attribute],
+ storageLevel: StorageLevel,
+ conf: SQLConf): RDD[CachedBatch] = {
+ if (schema.length != 1 || schema.head.dataType != IntegerType) {
+ throw new IllegalArgumentException("Only a single column of non-nullable ints works. " +
+ s"This is for testing $schema")
+ }
+ input.map { cb =>
+ val column = cb.column(0)
+ val data = column.getInts(0, cb.numRows())
+ SingleIntCachedBatch(data)
+ }
+ }
+
+ override def supportsColumnarOutput(schema: StructType): Boolean = true
+ override def vectorTypes(attributes: Seq[Attribute], conf: SQLConf): Option[Seq[String]] =
+ Some(attributes.map(_ => classOf[OnHeapColumnVector].getName))
+
+ override def convertCachedBatchToColumnarBatch(
+ input: RDD[CachedBatch],
+ cacheAttributes: Seq[Attribute],
+ selectedAttributes: Seq[Attribute],
+ conf: SQLConf): RDD[ColumnarBatch] = {
+ if (selectedAttributes.isEmpty) {
+ input.map { cached =>
+ val single = cached.asInstanceOf[SingleIntCachedBatch]
+ new ColumnarBatch(new Array[ColumnVector](0), single.numRows)
+ }
+ } else {
+ if (selectedAttributes.length > 1 ||
+ selectedAttributes.head.dataType != IntegerType) {
+ throw new IllegalArgumentException("Only a single column of non-nullable ints works. " +
+ s"This is for testing")
+ }
+ input.map { cached =>
+ val single = cached.asInstanceOf[SingleIntCachedBatch]
+ val cv = OnHeapColumnVector.allocateColumns(single.numRows, selectedAttributes.toStructType)
+ val data = single.data
+ cv(0).putInts(0, data.length, data, 0)
+ new ColumnarBatch(cv.toArray, single.numRows)
+ }
+ }
+ }
+
+ override def convertCachedBatchToInternalRow(
+ input: RDD[CachedBatch],
+ cacheAttributes: Seq[Attribute],
+ selectedAttributes: Seq[Attribute],
+ conf: SQLConf): RDD[InternalRow] = {
+ throw new IllegalStateException("This does not work. This is only for testing")
+ }
+
+ override def buildFilter(
+ predicates: Seq[Expression],
+ cachedAttributes: Seq[Attribute]): (Int, Iterator[CachedBatch]) => Iterator[CachedBatch] = {
+ def ret(index: Int, cb: Iterator[CachedBatch]): Iterator[CachedBatch] = cb
+ ret
+ }
+}
+
+class CachedBatchSerializerSuite extends QueryTest with SharedSparkSession {
+ import testImplicits._
+
+ override protected def sparkConf: SparkConf = {
+ super.sparkConf.set(
+ StaticSQLConf.SPARK_CACHE_SERIALIZER.key,
+ classOf[TestSingleIntColumnarCachedBatchSerializer].getName)
+ }
+
+ test("Columnar Cache Plugin") {
+ withTempPath { workDir =>
+ val workDirPath = workDir.getAbsolutePath
+ val input = Seq(100, 200, 300).toDF("count")
+ input.write.parquet(workDirPath)
+ val data = spark.read.parquet(workDirPath)
+ data.cache()
+ assert(data.count() == 3)
+ checkAnswer(data, Row(100) :: Row(200) :: Row(300) :: Nil)
+ }
+ }
+}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/InMemoryColumnarQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/InMemoryColumnarQuerySuite.scala
index 18f29f7b90ad..b8f73f4563ef 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/InMemoryColumnarQuerySuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/InMemoryColumnarQuerySuite.scala
@@ -20,18 +20,32 @@ package org.apache.spark.sql.execution.columnar
import java.nio.charset.StandardCharsets
import java.sql.{Date, Timestamp}
+import org.apache.spark.rdd.RDD
import org.apache.spark.sql.{DataFrame, QueryTest, Row}
-import org.apache.spark.sql.catalyst.expressions.{AttributeReference, AttributeSet, In}
-import org.apache.spark.sql.catalyst.plans.logical.LocalRelation
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference, AttributeSet, In}
import org.apache.spark.sql.catalyst.plans.physical.HashPartitioning
-import org.apache.spark.sql.execution.{ColumnarToRowExec, FilterExec, InputAdapter, LocalTableScanExec, WholeStageCodegenExec}
+import org.apache.spark.sql.columnar.CachedBatch
+import org.apache.spark.sql.execution.{ColumnarToRowExec, FilterExec, InputAdapter, WholeStageCodegenExec}
import org.apache.spark.sql.functions._
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.test.SharedSparkSession
import org.apache.spark.sql.test.SQLTestData._
import org.apache.spark.sql.types._
+import org.apache.spark.storage.StorageLevel
import org.apache.spark.storage.StorageLevel._
-import org.apache.spark.util.Utils
+
+class TestCachedBatchSerializer(
+ useCompression: Boolean,
+ batchSize: Int) extends DefaultCachedBatchSerializer {
+
+ override def convertInternalRowToCachedBatch(input: RDD[InternalRow],
+ schema: Seq[Attribute],
+ storageLevel: StorageLevel,
+ conf: SQLConf): RDD[CachedBatch] = {
+ convertForCacheInternal(input, schema, batchSize, useCompression)
+ }
+}
class InMemoryColumnarQuerySuite extends QueryTest with SharedSparkSession {
import testImplicits._
@@ -42,12 +56,12 @@ class InMemoryColumnarQuerySuite extends QueryTest with SharedSparkSession {
data.createOrReplaceTempView(s"testData$dataType")
val storageLevel = MEMORY_ONLY
val plan = spark.sessionState.executePlan(data.logicalPlan).sparkPlan
- val inMemoryRelation = InMemoryRelation(useCompression = true, 5, storageLevel, plan, None,
- data.logicalPlan)
+ val inMemoryRelation = InMemoryRelation(new TestCachedBatchSerializer(useCompression = true, 5),
+ storageLevel, plan, None, data.logicalPlan)
assert(inMemoryRelation.cacheBuilder.cachedColumnBuffers.getStorageLevel == storageLevel)
inMemoryRelation.cacheBuilder.cachedColumnBuffers.collect().head match {
- case _: CachedBatch =>
+ case _: DefaultCachedBatch =>
case other => fail(s"Unexpected cached batch type: ${other.getClass.getName}")
}
checkAnswer(inMemoryRelation, data.collect().toSeq)
@@ -119,8 +133,8 @@ class InMemoryColumnarQuerySuite extends QueryTest with SharedSparkSession {
test("simple columnar query") {
val plan = spark.sessionState.executePlan(testData.logicalPlan).sparkPlan
- val scan = InMemoryRelation(useCompression = true, 5, MEMORY_ONLY, plan, None,
- testData.logicalPlan)
+ val scan = InMemoryRelation(new TestCachedBatchSerializer(useCompression = true, 5),
+ MEMORY_ONLY, plan, None, testData.logicalPlan)
checkAnswer(scan, testData.collect().toSeq)
}
@@ -140,8 +154,8 @@ class InMemoryColumnarQuerySuite extends QueryTest with SharedSparkSession {
test("projection") {
val logicalPlan = testData.select('value, 'key).logicalPlan
val plan = spark.sessionState.executePlan(logicalPlan).sparkPlan
- val scan = InMemoryRelation(useCompression = true, 5, MEMORY_ONLY, plan, None,
- logicalPlan)
+ val scan = InMemoryRelation(new TestCachedBatchSerializer(useCompression = true, 5),
+ MEMORY_ONLY, plan, None, logicalPlan)
checkAnswer(scan, testData.collect().map {
case Row(key: Int, value: String) => value -> key
@@ -157,8 +171,8 @@ class InMemoryColumnarQuerySuite extends QueryTest with SharedSparkSession {
test("SPARK-1436 regression: in-memory columns must be able to be accessed multiple times") {
val plan = spark.sessionState.executePlan(testData.logicalPlan).sparkPlan
- val scan = InMemoryRelation(useCompression = true, 5, MEMORY_ONLY, plan, None,
- testData.logicalPlan)
+ val scan = InMemoryRelation(new TestCachedBatchSerializer(useCompression = true, 5),
+ MEMORY_ONLY, plan, None, testData.logicalPlan)
checkAnswer(scan, testData.collect().toSeq)
checkAnswer(scan, testData.collect().toSeq)
@@ -336,7 +350,8 @@ class InMemoryColumnarQuerySuite extends QueryTest with SharedSparkSession {
test("SPARK-17549: cached table size should be correctly calculated") {
val data = spark.sparkContext.parallelize(1 to 10, 5).toDF()
val plan = spark.sessionState.executePlan(data.logicalPlan).sparkPlan
- val cached = InMemoryRelation(true, 5, MEMORY_ONLY, plan, None, data.logicalPlan)
+ val cached = InMemoryRelation(new TestCachedBatchSerializer(true, 5),
+ MEMORY_ONLY, plan, None, data.logicalPlan)
// Materialize the data.
val expectedAnswer = data.collect()
@@ -349,7 +364,8 @@ class InMemoryColumnarQuerySuite extends QueryTest with SharedSparkSession {
test("cached row count should be calculated") {
val data = spark.range(6).toDF
val plan = spark.sessionState.executePlan(data.logicalPlan).sparkPlan
- val cached = InMemoryRelation(true, 5, MEMORY_ONLY, plan, None, data.logicalPlan)
+ val cached = InMemoryRelation(new TestCachedBatchSerializer(true, 5),
+ MEMORY_ONLY, plan, None, data.logicalPlan)
// Materialize the data.
val expectedAnswer = data.collect()
@@ -474,12 +490,8 @@ class InMemoryColumnarQuerySuite extends QueryTest with SharedSparkSession {
test("SPARK-22249: buildFilter should not throw exception when In contains an empty list") {
val attribute = AttributeReference("a", IntegerType)()
- val localTableScanExec = LocalTableScanExec(Seq(attribute), Nil)
- val testRelation = InMemoryRelation(false, 1, MEMORY_ONLY, localTableScanExec, None,
- LocalRelation(Seq(attribute), Nil))
- val tableScanExec = InMemoryTableScanExec(Seq(attribute),
- Seq(In(attribute, Nil)), testRelation)
- assert(tableScanExec.partitionFilters.isEmpty)
+ val testSerializer = new TestCachedBatchSerializer(false, 1)
+ testSerializer.buildFilter(Seq(In(attribute, Nil)), Seq(attribute))
}
testWithWholeStageCodegenOnAndOff("SPARK-22348: table cache " +