core/src/main/scala/org/apache/spark/rdd/RDD.scala (18 additions, 0 deletions)
@@ -705,6 +705,24 @@ abstract class RDD[T: ClassTag](
preservesPartitioning)
}

+ /**
+  * [performance] Spark's internal mapPartitions method which skips closure cleaning. It is a
+  * performance API to be used carefully only if we are sure that the RDD elements are
+  * serializable and don't require closure cleaning.
+  *
+  * @param preservesPartitioning indicates whether the input function preserves the partitioner,
+  * which should be `false` unless this is a pair RDD and the input function doesn't modify
+  * the keys.
+  */
+ private[spark] def mapPartitionsInternal[U: ClassTag](
+     f: Iterator[T] => Iterator[U],
+     preservesPartitioning: Boolean = false): RDD[U] = withScope {
+   new MapPartitionsRDD(
+     this,
+     (context: TaskContext, index: Int, iter: Iterator[T]) => f(iter),
+     preservesPartitioning)
+ }
+
/**
* Return a new RDD by applying a function to each partition of this RDD, while tracking the index
* of the original partition.
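
The new method mirrors the public mapPartitions but hands the closure straight to MapPartitionsRDD without running it through ClosureCleaner. Below is a minimal sketch of the intended usage; it assumes it is compiled somewhere under the org.apache.spark package tree (the method is private[spark], so user code cannot call it), and both the package name and the doubleAll helper are illustrative, not part of the patch:

// Sketch only: assumes a package inside the org.apache.spark tree so that the
// private[spark] method is visible; `doubleAll` is a hypothetical helper.
package org.apache.spark.demo

import org.apache.spark.rdd.RDD

object MapPartitionsInternalSketch {
  def doubleAll(input: RDD[Int]): RDD[Int] = {
    // Public API: the closure is passed through ClosureCleaner before serialization.
    //   input.mapPartitions(iter => iter.map(_ * 2))
    // Internal API added by this patch: same semantics, but closure cleaning is
    // skipped because the closure is known not to capture anything problematic.
    input.mapPartitionsInternal(iter => iter.map(_ * 2))
  }
}

The SQL operators changed below invoke these transformations on every query execution with closures that capture only local, serializable state, which is presumably why the per-call closure-cleaning pass is worth avoiding there.
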
@@ -125,7 +125,7 @@ private[sql] case class InMemoryRelation(

private def buildBuffers(): Unit = {
val output = child.output
- val cached = child.execute().mapPartitions { rowIterator =>
+ val cached = child.execute().mapPartitionsInternal { rowIterator =>
new Iterator[CachedBatch] {
def next(): CachedBatch = {
val columnBuilders = output.map { attribute =>
@@ -292,7 +292,7 @@ private[sql] case class InMemoryColumnarTableScan(
val relOutput = relation.output
val buffers = relation.cachedColumnBuffers

- buffers.mapPartitions { cachedBatchIterator =>
+ buffers.mapPartitionsInternal { cachedBatchIterator =>
val partitionFilter = newPredicate(
partitionFilters.reduceOption(And).getOrElse(Literal(true)),
schema)
@@ -168,7 +168,7 @@ case class Exchange(
case RangePartitioning(sortingExpressions, numPartitions) =>
// Internally, RangePartitioner runs a job on the RDD that samples keys to compute
// partition bounds. To get accurate samples, we need to copy the mutable keys.
- val rddForSampling = rdd.mapPartitions { iter =>
+ val rddForSampling = rdd.mapPartitionsInternal { iter =>
val mutablePair = new MutablePair[InternalRow, Null]()
iter.map(row => mutablePair.update(row.copy(), null))
}
@@ -200,12 +200,12 @@
}
val rddWithPartitionIds: RDD[Product2[Int, InternalRow]] = {
if (needToCopyObjectsBeforeShuffle(part, serializer)) {
- rdd.mapPartitions { iter =>
+ rdd.mapPartitionsInternal { iter =>
val getPartitionKey = getPartitionKeyExtractor()
iter.map { row => (part.getPartition(getPartitionKey(row)), row.copy()) }
}
} else {
- rdd.mapPartitions { iter =>
+ rdd.mapPartitionsInternal { iter =>
val getPartitionKey = getPartitionKeyExtractor()
val mutablePair = new MutablePair[Int, InternalRow]()
iter.map { row => mutablePair.update(part.getPartition(getPartitionKey(row)), row) }
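
The comment above the rddForSampling change explains why rows must be copied before RangePartitioner samples them: operators reuse a single mutable row object per partition, so recording references without copying records the same object over and over. A self-contained sketch of that aliasing hazard (MutableRow here is a stand-in, not Spark's class):

object RowCopyDemo extends App {
  // Stand-in for a reused, mutable row object.
  final class MutableRow(var value: Int) {
    def copy(): MutableRow = new MutableRow(value)
  }

  val reused = new MutableRow(0)

  // Collecting references without copying: every element is the same object,
  // which ends up holding only the last value written.
  val withoutCopy = (1 to 3).map { i => reused.value = i; reused }
  assert(withoutCopy.map(_.value) == Seq(3, 3, 3))

  // Copying at sampling time preserves each observed value.
  val withCopy = (1 to 3).map { i => reused.value = i; reused.copy() }
  assert(withCopy.map(_.value) == Seq(1, 2, 3))
}

The second Exchange hunk shows the flip side of the same trade-off: when copies are not required, a reused MutablePair avoids allocating one pair object per row.
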
@@ -59,7 +59,7 @@ case class Generate(
protected override def doExecute(): RDD[InternalRow] = {
// boundGenerator.terminate() should be triggered after all of the rows in the partition
if (join) {
- child.execute().mapPartitions { iter =>
+ child.execute().mapPartitionsInternal { iter =>
val generatorNullRow = InternalRow.fromSeq(Seq.fill[Any](generator.elementTypes.size)(null))
val joinedRow = new JoinedRow

@@ -79,7 +79,7 @@
}
}
} else {
- child.execute().mapPartitions { iter =>
+ child.execute().mapPartitionsInternal { iter =>
iter.flatMap(row => boundGenerator.eval(row)) ++
LazyIterator(() => boundGenerator.terminate())
}
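
The comment in the Generate hunk notes that boundGenerator.terminate() must run only after every row in the partition has been processed; the code achieves that by appending a lazily built iterator with ++. A small sketch of the pattern, with LazyIterator written out as a stand-in for Spark's internal helper of the same name:

// Stand-in for Spark's internal LazyIterator: `func` is not invoked until the
// appended iterator is actually consumed.
class LazyIterator[T](func: () => TraversableOnce[T]) extends Iterator[T] {
  private lazy val results = func().toIterator
  override def hasNext: Boolean = results.hasNext
  override def next(): T = results.next()
}

object GenerateDemo extends App {
  def terminate(): Seq[Int] = { println("terminate() called"); Seq(-1) }

  val rows = Iterator(1, 2, 3)
  // terminate() is not called here; it runs only once the rows are exhausted.
  val withTerminal = rows.map(_ * 10) ++ new LazyIterator(() => terminate())
  println(withTerminal.toList) // prints "terminate() called", then List(10, 20, 30, -1)
}
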
@@ -69,7 +69,7 @@ case class SortBasedAggregate(
protected override def doExecute(): RDD[InternalRow] = attachTree(this, "execute") {
val numInputRows = longMetric("numInputRows")
val numOutputRows = longMetric("numOutputRows")
- child.execute().mapPartitions { iter =>
+ child.execute().mapPartitionsInternal { iter =>
// Because the constructor of an aggregation iterator will read at least the first row,
// we need to get the value of iter.hasNext first.
val hasInput = iter.hasNext
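
The SortBasedAggregate hunk reads iter.hasNext before building the aggregation iterator because that iterator's constructor consumes the first row eagerly. A sketch of the trap being avoided (EagerIterator is hypothetical, not Spark's SortBasedAggregationIterator):

object EmptyPartitionDemo extends App {
  // Hypothetical iterator that, like the aggregation iterator, reads the first
  // input row in its constructor and would throw on an empty partition.
  class EagerIterator(input: Iterator[Int]) extends Iterator[Int] {
    private var buffered: Option[Int] = Some(input.next())
    override def hasNext: Boolean = buffered.isDefined
    override def next(): Int = {
      val row = buffered.get
      buffered = if (input.hasNext) Some(input.next()) else None
      row
    }
  }

  def processPartition(iter: Iterator[Int]): Iterator[Int] = {
    val hasInput = iter.hasNext // must be read before the eager constructor runs
    if (!hasInput) Iterator.empty else new EagerIterator(iter)
  }

  println(processPartition(Iterator.empty).toList)     // List()
  println(processPartition(Iterator(1, 2, 3)).toList)  // List(1, 2, 3)
}
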
@@ -43,7 +43,7 @@ case class TungstenProject(projectList: Seq[NamedExpression], child: SparkPlan)

protected override def doExecute(): RDD[InternalRow] = {
val numRows = longMetric("numRows")
- child.execute().mapPartitions { iter =>
+ child.execute().mapPartitionsInternal { iter =>
val project = UnsafeProjection.create(projectList, child.output,
subexpressionEliminationEnabled)
iter.map { row =>
@@ -67,7 +67,7 @@ case class Filter(condition: Expression, child: SparkPlan) extends UnaryNode {
protected override def doExecute(): RDD[InternalRow] = {
val numInputRows = longMetric("numInputRows")
val numOutputRows = longMetric("numOutputRows")
- child.execute().mapPartitions { iter =>
+ child.execute().mapPartitionsInternal { iter =>
val predicate = newPredicate(condition, child.output)
iter.filter { row =>
numInputRows += 1
@@ -161,19 +161,19 @@ case class Limit(limit: Int, child: SparkPlan)

protected override def doExecute(): RDD[InternalRow] = {
val rdd: RDD[_ <: Product2[Boolean, InternalRow]] = if (sortBasedShuffleOn) {
- child.execute().mapPartitions { iter =>
+ child.execute().mapPartitionsInternal { iter =>
iter.take(limit).map(row => (false, row.copy()))
}
} else {
- child.execute().mapPartitions { iter =>
+ child.execute().mapPartitionsInternal { iter =>
val mutablePair = new MutablePair[Boolean, InternalRow]()
iter.take(limit).map(row => mutablePair.update(false, row))
}
}
val part = new HashPartitioner(1)
val shuffled = new ShuffledRDD[Boolean, InternalRow, InternalRow](rdd, part)
shuffled.setSerializer(new SparkSqlSerializer(child.sqlContext.sparkContext.getConf))
- shuffled.mapPartitions(_.take(limit).map(_._2))
+ shuffled.mapPartitionsInternal(_.take(limit).map(_._2))
}
}
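
For context, Limit's doExecute above works in two stages: each partition keeps at most `limit` rows locally, the survivors are shuffled into a single partition, and `limit` rows are taken once more. Below is a sketch of the same idea on the public RDD API, using coalesce(1) in place of the internal ShuffledRDD plumbing shown in the diff:

import org.apache.spark.{SparkConf, SparkContext}

object TwoStageLimitDemo {
  def main(args: Array[String]): Unit = {
    val sc = new SparkContext(new SparkConf().setAppName("limit-demo").setMaster("local[2]"))
    val limit = 3
    val rows = sc.parallelize(1 to 100, numSlices = 4)

    // Stage 1: keep at most `limit` rows per partition, locally.
    val perPartition = rows.mapPartitions(_.take(limit))
    // Stage 2: bring the at most numPartitions * limit survivors together and
    // take `limit` once more.
    val result = perPartition.coalesce(1).mapPartitions(_.take(limit))

    println(result.collect().toList) // List(1, 2, 3)
    sc.stop()
  }
}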

@@ -294,7 +294,7 @@ case class MapPartitions[T, U](
child: SparkPlan) extends UnaryNode {

override protected def doExecute(): RDD[InternalRow] = {
- child.execute().mapPartitions { iter =>
+ child.execute().mapPartitionsInternal { iter =>
val tBoundEncoder = tEncoder.bind(child.output)
func(iter.map(tBoundEncoder.fromRow)).map(uEncoder.toRow)
}
@@ -318,7 +318,7 @@ case class AppendColumns[T, U](
override def output: Seq[Attribute] = child.output ++ newColumns

override protected def doExecute(): RDD[InternalRow] = {
- child.execute().mapPartitions { iter =>
+ child.execute().mapPartitionsInternal { iter =>
val tBoundEncoder = tEncoder.bind(child.output)
val combiner = GenerateUnsafeRowJoiner.create(tEncoder.schema, uEncoder.schema)
iter.map { row =>
@@ -350,7 +350,7 @@ case class MapGroups[K, T, U](
Seq(groupingAttributes.map(SortOrder(_, Ascending)))

override protected def doExecute(): RDD[InternalRow] = {
- child.execute().mapPartitions { iter =>
+ child.execute().mapPartitionsInternal { iter =>
val grouped = GroupedIterator(iter, groupingAttributes, child.output)
val groupKeyEncoder = kEncoder.bind(groupingAttributes)
val groupDataEncoder = tEncoder.bind(child.output)
@@ -54,15 +54,15 @@ case class BroadcastLeftSemiJoinHash(
val hashSet = buildKeyHashSet(input.toIterator, SQLMetrics.nullLongMetric)
val broadcastedRelation = sparkContext.broadcast(hashSet)

- left.execute().mapPartitions { streamIter =>
+ left.execute().mapPartitionsInternal { streamIter =>
hashSemiJoin(streamIter, numLeftRows, broadcastedRelation.value, numOutputRows)
}
} else {
val hashRelation =
HashedRelation(input.toIterator, SQLMetrics.nullLongMetric, rightKeyGenerator, input.size)
val broadcastedRelation = sparkContext.broadcast(hashRelation)

- left.execute().mapPartitions { streamIter =>
+ left.execute().mapPartitionsInternal { streamIter =>
val hashedRelation = broadcastedRelation.value
hashedRelation match {
case unsafe: UnsafeHashedRelation =>
@@ -46,7 +46,7 @@ case class CartesianProduct(left: SparkPlan, right: SparkPlan) extends BinaryNode
row.copy()
}

- leftResults.cartesian(rightResults).mapPartitions { iter =>
+ leftResults.cartesian(rightResults).mapPartitionsInternal { iter =>
val joinedRow = new JoinedRow
iter.map { r =>
numOutputRows += 1
@@ -47,7 +47,7 @@ case class Sort(
if (global) OrderedDistribution(sortOrder) :: Nil else UnspecifiedDistribution :: Nil

protected override def doExecute(): RDD[InternalRow] = attachTree(this, "sort") {
- child.execute().mapPartitions( { iterator =>
+ child.execute().mapPartitionsInternal( { iterator =>
val ordering = newOrdering(sortOrder, child.output)
val sorter = new ExternalSorter[InternalRow, Null, InternalRow](
TaskContext.get(), ordering = Some(ordering))