@@ -0,0 +1,70 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.shuffle
You'll need the Apache license header at the top of the file. You can copy and paste it from any other pre-existing Scala file.



import org.apache.spark.SparkConf
import org.apache.spark.util.collection.AppendOnlyMap
import scala.collection.mutable.MutableList
import scala.collection.Iterator

/**
* Created by vladio on 7/14/15.
*/

You should remove this. I don't see this header on any other files in Spark.

Author:
Yup, sorry! :-D

private[spark] class ShuffleAggregationManager[K, V](
I'd suggest a name like "PartialAggregationVerifier" or something.

Also, you can just pass the SparkConf directly in the constructor and use it immediately when defining your variables, and make your variables vals. Make the class declaration as follows:

private[spark] class ShuffleAggregationManager[K, V](conf: SparkConf, iterator: Iterator[Product2[K, V]]) {
  ...
  private val isSpillEnabled = conf.getBoolean("spark.shuffle.spill", true)
}

Author:

I used the same idea from the code written by Reynold: apache/spark@master...rxin:partialAggCore

Do you think it is a good idea to pass SparkConf as a constructor parameter? I haven't seen this so far (in my very limited experience with Spark :-D)

BlockManager has sparkConf in the constructor.

Reynold's commit is pretty old, so just use whatever seems closest to the newest codebase style.

    val conf: SparkConf,
    records: Iterator[Product2[K, V]]) {
noob question: why Product2[K, V] rather than (K, V)?
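For context (an editorial aside, not part of the thread): Scala's Tuple2 extends Product2, so an Iterator[(K, V)] already satisfies Iterator[Product2[K, V]]. The wider type also admits reusable pairs such as org.apache.spark.util.MutablePair, which Spark can pass through the shuffle to avoid allocating a Tuple2 per record. A minimal, self-contained illustration (the object name is illustrative, not from the PR):

import org.apache.spark.util.MutablePair

object Product2Example {
  def main(args: Array[String]): Unit = {
    // Plain tuples are Product2 instances...
    val fromTuples: Iterator[Product2[Int, String]] = Iterator((1, "a"), (2, "b"))

    // ...and so is a single MutablePair that is updated in place for each record.
    val pair = new MutablePair[Int, String]()
    val fromMutablePairs: Iterator[Product2[Int, String]] =
      Iterator(3 -> "c", 4 -> "d").map { case (k, v) => pair.update(k, v) }

    (fromTuples ++ fromMutablePairs).foreach(kv => println(s"${kv._1} -> ${kv._2}"))
  }
}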


  private val partialAggCheckInterval = conf.getInt("spark.partialAgg.interval", 10000)
  private val partialAggReduction = conf.getDouble("spark.partialAgg.reduction", 0.5)
  private var partialAggEnabled = true
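  // With the defaults above: sample the first 10,000 records and disable map-side
  // aggregation if they contain more than 10,000 * 0.5 = 5,000 distinct keys.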

  private val uniqueKeysMap = new AppendOnlyMap[K, Boolean]
  private var iteratedElements = MutableList[Product2[K, V]]()
  private var numIteratedRecords = 0

  def getRestoredIterator(): Iterator[Product2[K, V]] = {
    if (records.hasNext) {
      iteratedElements.toIterator ++ records
    } else {
      iteratedElements.toIterator
    }
  }

  def enableAggregation(): Boolean = {
    while (records.hasNext
        && numIteratedRecords < partialAggCheckInterval
        && partialAggEnabled) {
      val kv = records.next()

      iteratedElements += kv

iteratedElements could potentially OOM if we are storing 10k large items in memory. Perhaps we should store them in a size-tracking collection, and stop sampling when we either hit 10k items or the size-tracking collection gets too big?

From talking to @mccheah, the other way is to do it "inline"; you can talk to him directly if you want some insight about that.

The inline logic is what the original commit did, but that's pretty hard to do. I like the size tracking idea, but size tracking in and of itself isn't free so you should benchmark that. Size tracking is most expensive when you have an RDD of composite objects (i.e. not primitives, think like an RDD of HashSet objects)
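Not from the thread, but a rough sketch of the kind of micro-benchmark being suggested, assuming Spark's SizeEstimator (which is what SizeTracker samples under the hood; it is package-private in this era, hence the package line). The object name, record counts, and object shapes are arbitrary:

package org.apache.spark.shuffle

import scala.collection.mutable.HashSet
import scala.util.Random

import org.apache.spark.util.SizeEstimator

object SizeEstimationBench {
  private def time[A](label: String)(body: => A): A = {
    val start = System.nanoTime()
    val result = body
    println(f"$label: ${(System.nanoTime() - start) / 1e6}%.2f ms")
    result
  }

  def main(args: Array[String]): Unit = {
    // Primitive-ish records vs. composite records (the expensive case called out above).
    val primitivePairs = Array.fill(10000)((Random.nextInt(), Random.nextInt()))
    val compositeRecords = Array.fill(10000)(HashSet(Seq.fill(32)(Random.nextInt()): _*))

    time("estimate 10k (Int, Int) pairs")(SizeEstimator.estimate(primitivePairs))
    time("estimate 10k HashSet[Int]")(SizeEstimator.estimate(compositeRecords))
  }
}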

Author:
Hi! Yep, we were aware of that, and I think there are two possible solutions:

  1. I think I have a probabilistic algorithm (which assumes that keys are uniformly distributed) that can stop iterating after a bounded number of steps (in this case, at most 10k). I have to think a little bit more (actually, to remember it :-) ) and I'll post it here.
  2. Instead of using a MutableList, maybe we can switch to ExternalAppendOnlyMap or ExternalList (which Matt created in some of his previous commits).
    I'll be thinking about that :-)

You should just use a size tracking collection. The collection doesn't have to spill, significantly simplifying your implementation. The size tracking collection will be able to report its size in memory, and then when it hits some memory threshold you take that as the sample to conduct your heuristic.
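A minimal sketch of this suggestion, not code from the PR: it reuses Spark's private[spark] SizeTrackingVector (whose estimateSize() is cheap to call per record, since SizeTracker only re-measures its footprint periodically) and AppendOnlyMap. The class name and the "spark.partialAgg.maxSampleBytes" key are hypothetical.

package org.apache.spark.shuffle

import org.apache.spark.SparkConf
import org.apache.spark.util.collection.{AppendOnlyMap, SizeTrackingVector}

private[spark] class SizeBoundedAggregationVerifier[K, V](
    conf: SparkConf,
    records: Iterator[Product2[K, V]]) {

  private val maxSampleRecords = conf.getInt("spark.partialAgg.interval", 10000)
  private val maxSampleBytes = conf.getLong("spark.partialAgg.maxSampleBytes", 64L * 1024 * 1024)
  private val partialAggReduction = conf.getDouble("spark.partialAgg.reduction", 0.5)

  private val sampledElements = new SizeTrackingVector[Product2[K, V]]
  private val uniqueKeysMap = new AppendOnlyMap[K, Boolean]

  def enableAggregation(): Boolean = {
    // Stop sampling on whichever bound is hit first: record count or estimated memory footprint.
    while (records.hasNext &&
        sampledElements.size < maxSampleRecords &&
        sampledElements.estimateSize() < maxSampleBytes) {
      val kv = records.next()
      sampledElements += kv
      uniqueKeysMap.update(kv._1, true)
    }
    // Aggregation is worthwhile only if the sample shows enough duplicate keys.
    uniqueKeysMap.size <= sampledElements.size * partialAggReduction
  }

  def getRestoredIterator(): Iterator[Product2[K, V]] = sampledElements.iterator ++ records
}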

      numIteratedRecords += 1

      uniqueKeysMap.update(kv._1, true)

      if (numIteratedRecords == partialAggCheckInterval) {
        val partialAggSize = uniqueKeysMap.size
        if (partialAggSize > numIteratedRecords * partialAggReduction) {
          partialAggEnabled = false
        }
      }
    }

    partialAggEnabled
  }
}
@@ -50,9 +50,17 @@ private[spark] class HashShuffleWriter[K, V](

  /** Write a bunch of records to this task's output */
  override def write(records: Iterator[Product2[K, V]]): Unit = {
    // Decide if it's optimal to do the pre-aggregation.
    val aggManager = new ShuffleAggregationManager[K, V](SparkEnv.get.conf, records)

    val iter = if (dep.aggregator.isDefined) {
      if (dep.mapSideCombine) {
        dep.aggregator.get.combineValuesByKey(records, context)
        if (aggManager.enableAggregation()) {
          dep.aggregator.get.combineValuesByKey(aggManager.getRestoredIterator(), context)
        } else {
          aggManager.getRestoredIterator().map(kv =>
            (kv._1, dep.aggregator.get.createCombiner(kv._2)))
        }
      } else {
        records
      }
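An aside (not PR code) on why the "aggregation disabled" branch above still maps each value through createCombiner: the reduce side expects combiners of type C, so raw values must be wrapped even when no map-side merging happens. With a word-count-style combiner (names here are illustrative):

object CreateCombinerExample {
  def main(args: Array[String]): Unit = {
    val createCombiner = (value: Int) => (value, 1)
    val mergeCombiners = (x: (Int, Int), y: (Int, Int)) => (x._1 + y._1, x._2 + y._2)

    // Map side with aggregation skipped: emit one combiner per record, no merging.
    val mapOutput = Iterator(("Vlad", 3), ("Vlad", 2)).map { case (k, v) => (k, createCombiner(v)) }

    // Reduce side merges combiners as usual and still ends up with ("Vlad", (5, 2)).
    val reduced = mapOutput.toSeq.groupBy(_._1).mapValues(_.map(_._2).reduce(mergeCombiners))
    println(reduced)
  }
}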
@@ -21,10 +21,12 @@ import org.apache.spark._
import org.apache.spark.executor.ShuffleWriteMetrics
import org.apache.spark.scheduler.MapStatus
import org.apache.spark.serializer.Serializer
import org.apache.spark.shuffle.{IndexShuffleBlockResolver, ShuffleWriter, BaseShuffleHandle}
import org.apache.spark.shuffle.{ShuffleAggregationManager, IndexShuffleBlockResolver, ShuffleWriter, BaseShuffleHandle}
import org.apache.spark.storage.ShuffleBlockId
import org.apache.spark.util.collection.ExternalSorter

import scala.reflect.ClassTag

private[spark] class SortShuffleWriter[K, V, C](
    shuffleBlockResolver: IndexShuffleBlockResolver,
    handle: BaseShuffleHandle[K, V, C],
@@ -36,7 +38,7 @@ private[spark] class SortShuffleWriter[K, V, C](

  private val blockManager = SparkEnv.get.blockManager

  private var sorter: SortShuffleFileWriter[K, V] = null
  private var sorter: SortShuffleFileWriter[K, _] = null

  // Are we in the process of stopping? Because map tasks can call stop() with success = true
  // and then call stop() with success = false if they get an exception, we want to make sure
@@ -50,34 +52,66 @@ private[spark] class SortShuffleWriter[K, V, C](

  /** Write a bunch of records to this task's output */
  override def write(records: Iterator[Product2[K, V]]): Unit = {
    sorter = if (dep.mapSideCombine) {
      require(dep.aggregator.isDefined, "Map-side combine without Aggregator specified!")
      new ExternalSorter[K, V, C](
        dep.aggregator, Some(dep.partitioner), dep.keyOrdering, dep.serializer)
    } else if (SortShuffleWriter.shouldBypassMergeSort(
        SparkEnv.get.conf, dep.partitioner.numPartitions, aggregator = None, keyOrdering = None)) {
      // If there are fewer than spark.shuffle.sort.bypassMergeThreshold partitions and we don't
      // need local aggregation and sorting, write numPartitions files directly and just concatenate
      // them at the end. This avoids doing serialization and deserialization twice to merge
      // together the spilled files, which would happen with the normal code path. The downside is
      // having multiple files open at a time and thus more memory allocated to buffers.
      new BypassMergeSortShuffleWriter[K, V](SparkEnv.get.conf, blockManager, dep.partitioner,
        writeMetrics, Serializer.getSerializer(dep.serializer))
    } else {
      // In this case we pass neither an aggregator nor an ordering to the sorter, because we don't
      // care whether the keys get sorted in each partition; that will be done on the reduce side
      // if the operation being run is sortByKey.
      new ExternalSorter[K, V, V](
        aggregator = None, Some(dep.partitioner), ordering = None, dep.serializer)
    // Decide if it's optimal to do the pre-aggregation.
    val aggManager = new ShuffleAggregationManager[K, V](SparkEnv.get.conf, records)
    // Short-circuit when dep.mapSideCombine is explicitly false, so we skip the aggregation
    // check entirely; otherwise ask aggManager whether pre-aggregation would actually be
    // beneficial.
    val enableMapSideCombine = dep.mapSideCombine && aggManager.enableAggregation()

    sorter = (dep.mapSideCombine, enableMapSideCombine) match {
      case (_, true) =>
        assert(dep.mapSideCombine)
        require(dep.aggregator.isDefined, "Map-side combine requested without " +
          "Aggregator specified!")
        val selectedSorter = new ExternalSorter[K, V, C](
          dep.aggregator, Some(dep.partitioner), dep.keyOrdering, dep.serializer)
        writeInternal(aggManager.getRestoredIterator(), selectedSorter)
        selectedSorter
      case (true, false) =>
        // The user requested map-side combine, but we determined that it would be sub-optimal
        // here. So just write out the initial combiners (the reducer expects values of
        // type "C") but don't do the aggregation itself.
        require(dep.aggregator.isDefined, "Map-side combine requested with " +
          "no aggregator specified!")
        val selectedSorter = if (SortShuffleWriter.shouldBypassMergeSort(SparkEnv.get.conf,
            dep.partitioner.numPartitions, aggregator = None, keyOrdering = None)) {
          new BypassMergeSortShuffleWriter[K, C](SparkEnv.get.conf, blockManager, dep.partitioner,
            writeMetrics, Serializer.getSerializer(dep.serializer))
        } else {
          new ExternalSorter[K, C, C](
            aggregator = None, Some(dep.partitioner), ordering = None, dep.serializer)
        }
        val definedAggregator = dep.aggregator.get
        val recordsWithAppliedCreateCombiner = aggManager.getRestoredIterator.map(kv =>
          (kv._1, definedAggregator.createCombiner(kv._2)))
        writeInternal(recordsWithAppliedCreateCombiner, selectedSorter)
        selectedSorter
      case (false, _) => {
        val selectedSorter = if (SortShuffleWriter.shouldBypassMergeSort(SparkEnv.get.conf,
            dep.partitioner.numPartitions, aggregator = None, keyOrdering = None)) {
          new BypassMergeSortShuffleWriter[K, V](SparkEnv.get.conf, blockManager, dep.partitioner,
            writeMetrics, Serializer.getSerializer(dep.serializer))
        } else {
          new ExternalSorter[K, V, V](
            aggregator = None, Some(dep.partitioner), ordering = None, dep.serializer)
        }
        writeInternal(aggManager.getRestoredIterator(), selectedSorter)
        selectedSorter
      }
    }
    sorter.insertAll(records)
    logInfo("SortShuffleWriter - Enable Pre-Aggregation: " + enableMapSideCombine)
  }

  private def writeInternal[VorC](records: Iterator[Product2[K, VorC]],
      selectedSorter: SortShuffleFileWriter[K, VorC]): Unit = {
    selectedSorter.insertAll(records)
    // Don't bother including the time to open the merged output file in the shuffle write time,
    // because it just opens a single file, so is typically too fast to measure accurately
    // (see SPARK-3570).
    val outputFile = shuffleBlockResolver.getDataFile(dep.shuffleId, mapId)
    val blockId = ShuffleBlockId(dep.shuffleId, mapId, IndexShuffleBlockResolver.NOOP_REDUCE_ID)
    val partitionLengths = sorter.writePartitionedFile(blockId, context, outputFile)
    val partitionLengths = selectedSorter.writePartitionedFile(blockId, context, outputFile)
    shuffleBlockResolver.writeIndexFile(dep.shuffleId, mapId, partitionLengths)

    mapStatus = MapStatus(blockManager.shuffleServerId, partitionLengths)
@@ -0,0 +1,82 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark
You'll need the Apache license header at the top of the file. You can copy and paste it from any other pre-existing Scala file.


import org.apache.spark.shuffle.ShuffleAggregationManager
import org.apache.spark.shuffle.sort.SortShuffleWriter._
import org.mockito.Mockito._

/**
* Created by vladio on 7/15/15.
*/
class ShuffleAggregationManagerSuite extends SparkFunSuite {

test("conditions for doing the pre-aggregation") {
val conf = new SparkConf(loadDefaults = false)
conf.set("spark.partialAgg.interval", "4")
conf.set("spark.partialAgg.reduction", "0.5")

// This test will pass if the first 4 elements of a set contains at most 2 unique keys.
// Generate the records.
val records = Iterator((1, "Vlad"), (2, "Marius"), (1, "Marian"), (2, "Cornel"),
(3, "Patricia"), (4, "Georgeta"))

// Test.
val aggManager = new ShuffleAggregationManager[Int, String](conf, records)
assert(aggManager.enableAggregation() == true)
}

test("conditions for skipping the pre-aggregation") {
val conf = new SparkConf(loadDefaults = false)
conf.set("spark.partialAgg.interval", "4")
conf.set("spark.partialAgg.reduction", "0.5")

val records = Iterator((1, "Vlad"), (2, "Marius"), (3, "Marian"), (2, "Cornel"),
(3, "Patricia"), (4, "Georgeta"))

val aggManager = new ShuffleAggregationManager[Int, String](conf, records)
assert(aggManager.enableAggregation() == false)
}

test("restoring the iterator") {
val conf = new SparkConf(loadDefaults = false)
conf.set("spark.partialAgg.interval", "4")
conf.set("spark.partialAgg.reduction", "0.5")

val listOfElements = List((1, "Vlad"), (2, "Marius"), (1, "Marian"), (2, "Cornel"),
(3, "Patricia"), (4, "Georgeta"))
val records = listOfElements.toIterator
val recordsCopy = listOfElements.toIterator

val aggManager = new ShuffleAggregationManager[Int, String](conf, records)
assert(aggManager.enableAggregation() == true)

val restoredRecords = aggManager.getRestoredIterator()
assert(restoredRecords.hasNext)

while (restoredRecords.hasNext && recordsCopy.hasNext) {
val kv1 = restoredRecords.next()
val kv2 = recordsCopy.next()

assert(kv1 == kv2)
}

assert(!restoredRecords.hasNext)
assert(!recordsCopy.hasNext)
}
}
core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala (16 additions, 0 deletions)
@@ -988,6 +988,22 @@ class RDDSuite extends SparkFunSuite with SharedSparkContext {
    assert(thrown.getMessage.contains("SPARK-5063"))
  }

test("combineByKey with pre-aggregation") {
val records = sc.parallelize(Seq(("Vlad", 3), ("Marius", 4), ("Vlad", 2), ("Marius", 6),
("Georgeta", 4), ("Vlad", 5)))
val createCombiner = ((value: Int) => (value, 1))
val mergeValue = (x: (Int, Int), v: Int) => (x._1 + v, x._2 + 1)
val mergeCombiners = (x: (Int, Int), y: (Int, Int)) => (x._1 + y._1, x._2 + y._2)

val combined = records.combineByKey(createCombiner, mergeValue, mergeCombiners)
val combinedMap = combined.collectAsMap()

assert(combinedMap.size == 3)
assert(combinedMap.exists(x => x._1 == "Vlad" && x._2._1 == 10 && x._2._2 == 3))
assert(combinedMap.exists(x => x._1 == "Marius" && x._2._1 == 10 && x._2._2 == 2))
assert(combinedMap.exists(x => x._1 == "Georgeta" && x._2._1 == 4 && x._2._2 == 1))
}

test("cannot run actions after SparkContext has been stopped (SPARK-5063)") {
val existingRDD = sc.parallelize(1 to 100)
sc.stop()