|
17 | 17 |
|
18 | 18 | package org.apache.spark.util.random |
19 | 19 |
|
| 20 | +import scala.collection.{Map, mutable} |
20 | 21 | import scala.collection.mutable.ArrayBuffer |
21 | | -import scala.collection.{mutable, Map} |
| 22 | +import scala.reflect.ClassTag |
| 23 | + |
22 | 24 | import org.apache.commons.math3.random.RandomDataGenerator |
23 | | -import org.apache.spark.{Logging, TaskContext} |
24 | | -import org.apache.spark.util.random.{PoissonBounds => PB} |
25 | | -import scala.Some |
| 25 | +import org.apache.spark.{Logging, SparkContext, TaskContext} |
26 | 26 | import org.apache.spark.rdd.RDD |
| 27 | +import org.apache.spark.util.Utils |
| 28 | +import org.apache.spark.util.random.{PoissonBounds => PB} |
27 | 29 |
|
28 | 30 | private[spark] object StratifiedSampler extends Logging { |
| 31 | + |
| 32 | + /** |
| 33 | + * A version of [[RDD.aggregate]] that passes the TaskContext to the function that
| 34 | + * aggregates each partition. This avoids adding an extra level to the RDD lineage,
| 35 | + * unlike mapPartitionsWithIndex, and thus yields a slightly improved run time.
| 36 | + */ |
| 37 | + def aggregateWithContext[U: ClassTag, T: ClassTag](zeroValue: U) |
| 38 | + (rdd: RDD[T], |
| 39 | + seqOp: ((TaskContext, U), T) => U, |
| 40 | + combOp: (U, U) => U): U = { |
| 41 | + val sc: SparkContext = rdd.sparkContext |
| 42 | + // Clone the zero value since we will also be serializing it as part of tasks |
| 43 | + var jobResult = Utils.clone(zeroValue, sc.env.closureSerializer.newInstance()) |
| 44 | + // Pad seqOp and combOp with the TaskContext to conform to aggregate's signature in TraversableOnce
| 45 | + val paddedSeqOp = (arg1: (TaskContext, U), item: T) => (arg1._1, seqOp(arg1, item)) |
| 46 | + val paddedCombOp = (arg1: (TaskContext, U), arg2: (TaskContext, U)) =>
| 47 | + (arg1._1, combOp(arg1._2, arg2._2))
| 48 | + val cleanSeqOp = sc.clean(paddedSeqOp) |
| 49 | + val cleanCombOp = sc.clean(paddedCombOp)
| 50 | + val aggregatePartition = (tc: TaskContext, it: Iterator[T]) => |
| 51 | + it.aggregate((tc, zeroValue))(cleanSeqOp, cleanCombOp)._2
| 52 | + val mergeResult = (index: Int, taskResult: U) => jobResult = combOp(jobResult, taskResult) |
| 53 | + sc.runJob(rdd, aggregatePartition, mergeResult) |
| 54 | + jobResult |
| 55 | + } |
| 56 | + |
29 | 57 | /** |
30 | 58 | * Returns the function used by aggregate to collect sampling statistics for each partition. |
31 | 59 | */ |
@@ -153,7 +181,7 @@ private[spark] object StratifiedSampler extends Logging { |
153 | 181 | val seqOp = StratifiedSampler.getSeqOp[K,V](false, fractionByKey, None) |
154 | 182 | val combOp = StratifiedSampler.getCombOp[K]() |
155 | 183 | val zeroU = new Result[K](Map[K, Stratum](), seed = seed) |
156 | | - val finalResult = rdd.aggregateWithContext(zeroU)(seqOp, combOp).resultMap |
| 184 | + val finalResult = aggregateWithContext(zeroU)(rdd, seqOp, combOp).resultMap |
157 | 185 | samplingRateByKey = StratifiedSampler.computeThresholdByKey(finalResult, fractionByKey) |
158 | 186 | } |
159 | 187 | (idx: Int, iter: Iterator[(K, V)]) => { |
@@ -183,7 +211,7 @@ private[spark] object StratifiedSampler extends Logging { |
183 | 211 | val seqOp = StratifiedSampler.getSeqOp[K,V](true, fractionByKey, counts) |
184 | 212 | val combOp = StratifiedSampler.getCombOp[K]() |
185 | 213 | val zeroU = new Result[K](Map[K, Stratum](), seed = seed) |
186 | | - val finalResult = rdd.aggregateWithContext(zeroU)(seqOp, combOp).resultMap |
| 214 | + val finalResult = aggregateWithContext(zeroU)(rdd, seqOp, combOp).resultMap |
187 | 215 | val thresholdByKey = StratifiedSampler.computeThresholdByKey(finalResult, fractionByKey) |
188 | 216 | (idx: Int, iter: Iterator[(K, V)]) => { |
189 | 217 | val random = new RandomDataGenerator() |
|
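For context, here is a minimal sketch of how the new helper would be invoked. It assumes a live `SparkContext` named `sc` and code compiled inside the `org.apache.spark` package (since `StratifiedSampler` is `private[spark]`); the RDD contents and the `Long` accumulator are illustrative only, not part of this patch.

```scala
import org.apache.spark.TaskContext
import org.apache.spark.rdd.RDD
import org.apache.spark.util.random.StratifiedSampler

// Sum the elements of an RDD while giving each partition access to its
// TaskContext, which plain RDD.aggregate does not expose.
val rdd: RDD[Int] = sc.parallelize(1 to 100, numSlices = 4)

// seqOp receives (TaskContext, accumulator) as a pair; combOp merges plain
// accumulators, matching the signatures aggregateWithContext expects.
val seqOp = (arg: (TaskContext, Long), item: Int) => {
  val (tc, acc) = arg
  // tc.partitionId is available here, e.g. to seed a per-partition RNG,
  // which is how the stratified sampling seqOp uses the context.
  acc + item
}
val combOp = (a: Long, b: Long) => a + b

val total = StratifiedSampler.aggregateWithContext(0L)(rdd, seqOp, combOp)
// total == 5050
```

The padded operators thread the `TaskContext` through `Iterator.aggregate` as the first element of a tuple and strip it off before partition results are merged, which is why `combOp` never sees a context.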