[SPARK-2253] Add logic to decide when to do key/value pre-aggregation on Map side. #16
New file: ShuffleAggregationManager.scala (+70 lines)

```scala
/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.spark.shuffle

import org.apache.spark.SparkConf
import org.apache.spark.util.collection.AppendOnlyMap
import scala.collection.mutable.MutableList
import scala.collection.Iterator

/**
 * Created by vladio on 7/14/15.
 */
```
Reviewer: You should remove this. I don't see this header on any other files in Spark.

Author: Yup, sorry! :-D
```scala
private[spark] class ShuffleAggregationManager[K, V](
    val conf: SparkConf,
    records: Iterator[Product2[K, V]]) {
```

Reviewer: I'd suggest a name like "PartialAggregationVerifier" or something. Also, you can just pass the SparkConf directly in the constructor and use it immediately when defining your variables, and make your variables vals. Make the class declaration as follows:
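A minimal sketch of the declaration being suggested here, assuming the reviewer's proposed name and reading the settings into vals at construction:

```scala
import org.apache.spark.SparkConf

private[spark] class PartialAggregationVerifier[K, V](
    conf: SparkConf,
    records: Iterator[Product2[K, V]]) {

  // SparkConf is consumed directly at construction; the derived settings are vals.
  private val partialAggCheckInterval = conf.getInt("spark.partialAgg.interval", 10000)
  private val partialAggReduction = conf.getDouble("spark.partialAgg.reduction", 0.5)
}
```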
Author: I used the same idea from the code written by Reynold: apache/spark@master...rxin:partialAggCore. Do you think it is a good idea to pass SparkConf as a constructor parameter? I haven't seen this so far (in my very limited experience with Spark :-D).

Reviewer: BlockManager has sparkConf in the constructor. Reynold's commit is pretty old, so just use whatever seems closest to the newest codebase style.
Reviewer: noob question: why …
```scala
  private val partialAggCheckInterval = conf.getInt("spark.partialAgg.interval", 10000)
  private val partialAggReduction = conf.getDouble("spark.partialAgg.reduction", 0.5)
  private var partialAggEnabled = true

  private val uniqueKeysMap = new AppendOnlyMap[K, Boolean]
  private var iteratedElements = MutableList[Product2[K, V]]()
  private var numIteratedRecords = 0

  def getRestoredIterator(): Iterator[Product2[K, V]] = {
    if (records.hasNext) {
      iteratedElements.toIterator ++ records
    } else {
      iteratedElements.toIterator
    }
  }

  def enableAggregation(): Boolean = {
    while (records.hasNext
        && numIteratedRecords < partialAggCheckInterval
        && partialAggEnabled) {
      val kv = records.next()

      iteratedElements += kv
```
Reviewer: iteratedElements could potentially OOM if we're storing 10k large items in memory. Perhaps we should store it in a size-tracking collection, and stop sampling when either we hit 10k items or the size-tracking collection gets too big? Talking to @mccheah, the other way is to do it "inline"; you can talk to him directly if you want some insight about that.

Reviewer: The inline logic is what the original commit did, but that's pretty hard to do. I like the size-tracking idea, but size tracking in and of itself isn't free, so you should benchmark that. Size tracking is most expensive when you have an RDD of composite objects (i.e. not primitives; think an RDD of HashSet objects).
Author: Hi! Yep, we were aware of that, and I think there are 2 possible solutions to solve it: …
Reviewer: You should just use a size-tracking collection. The collection doesn't have to spill, which significantly simplifies your implementation. The size-tracking collection will be able to report its size in memory, and when it hits some memory threshold you take that as the sample to conduct your heuristic.
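A rough sketch of the sampler the reviewer describes, assuming Spark's private[spark] SizeTrackingVector (an append-only vector whose estimateSize() reports an approximate in-memory footprint) and a hypothetical spark.partialAgg.maxSampleBytes setting; the class name is illustrative:

```scala
import org.apache.spark.SparkConf
import org.apache.spark.util.collection.{AppendOnlyMap, SizeTrackingVector}

// Hypothetical variant that caps the sample by memory as well as by record count.
private[spark] class SizeBoundedAggregationSampler[K, V](
    conf: SparkConf,
    records: Iterator[Product2[K, V]]) {

  private val partialAggCheckInterval = conf.getInt("spark.partialAgg.interval", 10000)
  private val partialAggReduction = conf.getDouble("spark.partialAgg.reduction", 0.5)
  // Assumed setting: stop sampling once the buffered records exceed this many bytes.
  private val maxSampleBytes = conf.getLong("spark.partialAgg.maxSampleBytes", 64L * 1024 * 1024)

  private val uniqueKeysMap = new AppendOnlyMap[K, Boolean]
  // SizeTracker keeps a cheap running estimate of the buffer's size in memory.
  private val sampledElements = new SizeTrackingVector[Product2[K, V]]()
  private var numIteratedRecords = 0

  def enableAggregation(): Boolean = {
    while (records.hasNext
        && numIteratedRecords < partialAggCheckInterval
        && sampledElements.estimateSize() < maxSampleBytes) {
      val kv = records.next()
      sampledElements += kv
      numIteratedRecords += 1
      uniqueKeysMap.update(kv._1, true)
    }
    // Evaluate the heuristic on whatever sample was gathered, whether the loop
    // stopped at the record cap or at the memory cap.
    uniqueKeysMap.size <= numIteratedRecords * partialAggReduction
  }
}
```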
The diff continues:

```scala
      numIteratedRecords += 1

      uniqueKeysMap.update(kv._1, true)

      if (numIteratedRecords == partialAggCheckInterval) {
        val partialAggSize = uniqueKeysMap.size
        if (partialAggSize > numIteratedRecords * partialAggReduction) {
          partialAggEnabled = false
        }
      }
    }

    partialAggEnabled
  }
}
```
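As a usage note, a hypothetical sketch of how a map-side shuffle writer might consult this class before deciding whether to combine; the wiring and the combineAndWrite/writeDirectly helpers are illustrative, not part of this diff:

```scala
package org.apache.spark.shuffle

import org.apache.spark.SparkConf

object MapSideCombineExample {
  def writeWithOptionalCombine[K, V](
      conf: SparkConf,
      records: Iterator[Product2[K, V]]): Unit = {
    val manager = new ShuffleAggregationManager[K, V](conf, records)
    if (manager.enableAggregation()) {
      // Few distinct keys in the sample: map-side combining should pay off.
      // combineAndWrite(manager.getRestoredIterator())
    } else {
      // Mostly unique keys: skip the combiner and write records as-is.
      // writeDirectly(manager.getRestoredIterator())
    }
  }
}
```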
New file: ShuffleAggregationManagerSuite.scala (+82 lines)

```scala
/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.spark
```
Reviewer: You'll need the Apache license header at the top of the file. You can copy and paste it from any other pre-existing Scala file.
```scala
import org.apache.spark.shuffle.ShuffleAggregationManager
import org.apache.spark.shuffle.sort.SortShuffleWriter._
import org.mockito.Mockito._

/**
 * Created by vladio on 7/15/15.
 */
class ShuffleAggregationManagerSuite extends SparkFunSuite {

  test("conditions for doing the pre-aggregation") {
    val conf = new SparkConf(loadDefaults = false)
    conf.set("spark.partialAgg.interval", "4")
    conf.set("spark.partialAgg.reduction", "0.5")

    // This test passes if the first 4 elements contain at most 2 unique keys.
    // Generate the records.
    val records = Iterator((1, "Vlad"), (2, "Marius"), (1, "Marian"), (2, "Cornel"),
      (3, "Patricia"), (4, "Georgeta"))

    // Test.
    val aggManager = new ShuffleAggregationManager[Int, String](conf, records)
    assert(aggManager.enableAggregation() == true)
  }

  test("conditions for skipping the pre-aggregation") {
    val conf = new SparkConf(loadDefaults = false)
    conf.set("spark.partialAgg.interval", "4")
    conf.set("spark.partialAgg.reduction", "0.5")

    val records = Iterator((1, "Vlad"), (2, "Marius"), (3, "Marian"), (2, "Cornel"),
      (3, "Patricia"), (4, "Georgeta"))

    val aggManager = new ShuffleAggregationManager[Int, String](conf, records)
    assert(aggManager.enableAggregation() == false)
  }

  test("restoring the iterator") {
    val conf = new SparkConf(loadDefaults = false)
    conf.set("spark.partialAgg.interval", "4")
    conf.set("spark.partialAgg.reduction", "0.5")

    val listOfElements = List((1, "Vlad"), (2, "Marius"), (1, "Marian"), (2, "Cornel"),
      (3, "Patricia"), (4, "Georgeta"))
    val records = listOfElements.toIterator
    val recordsCopy = listOfElements.toIterator

    val aggManager = new ShuffleAggregationManager[Int, String](conf, records)
    assert(aggManager.enableAggregation() == true)

    val restoredRecords = aggManager.getRestoredIterator()
    assert(restoredRecords.hasNext)

    while (restoredRecords.hasNext && recordsCopy.hasNext) {
      val kv1 = restoredRecords.next()
      val kv2 = recordsCopy.next()

      assert(kv1 == kv2)
    }

    assert(!restoredRecords.hasNext)
    assert(!recordsCopy.hasNext)
  }
}
```