[SPARK-2174][MLLIB] treeReduce and treeAggregate #1110
Changes from 9 commits
@@ -20,7 +20,10 @@ package org.apache.spark.mllib.rdd

```scala
import scala.language.implicitConversions
import scala.reflect.ClassTag

import org.apache.spark.HashPartitioner
import org.apache.spark.SparkContext._
import org.apache.spark.rdd.RDD
import org.apache.spark.util.Utils

/**
 * Machine learning specific RDD functions.
```
@@ -44,6 +47,65 @@ class RDDFunctions[T: ClassTag](self: RDD[T]) {

```scala
      new SlidingRDD[T](self, windowSize)
    }
  }

  /**
   * Reduces the elements of this RDD in a tree pattern.
   * @param depth suggested depth of the tree
   * @see [[org.apache.spark.rdd.RDD#reduce]]
   */
  def treeReduce(f: (T, T) => T, depth: Int): T = {
```
Contributor: Is this used at all? If not, maybe we don't need it?

Author: Not in this PR, but it is useful for testing (simpler than treeAggregate), and it could be used in the future.
```scala
    require(depth >= 1, s"Depth must be greater than 1 but got $depth.")
```
Contributor: greater or equal to?

Author: done.
```scala
    val cleanF = self.context.clean(f)
    val reducePartition: Iterator[T] => Option[T] = iter => {
      if (iter.hasNext) {
        Some(iter.reduceLeft(cleanF))
      } else {
        None
      }
    }
    val local = self.mapPartitions(it => Iterator(reducePartition(it)))
    val op: (Option[T], Option[T]) => Option[T] = (c, x) => {
      if (c.isDefined && x.isDefined) {
        Some(cleanF(c.get, x.get))
      } else if (c.isDefined) {
        c
      } else if (x.isDefined) {
        x
      } else {
        None
      }
    }
    RDDFunctions.fromRDD(local).treeAggregate(Option.empty[T])(op, op, depth)
      .getOrElse(throw new UnsupportedOperationException("empty collection"))
  }
```
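For orientation while reading the diff: `treeReduce` has the same contract as `RDD.reduce`, but combines the per-partition results in rounds on the executors instead of pulling them all to the driver at once. A minimal usage sketch follows; the data, partition count, and depth are illustrative, `sc` is an assumed `SparkContext`, and since the companion object appears to be `private[mllib]`, a call like this would live inside MLlib code:

```scala
import org.apache.spark.mllib.rdd.RDDFunctions._

// Illustrative only: sum values spread over 1000 partitions. A plain
// rdd.reduce sends 1000 partial sums straight to the driver; with depth = 2,
// treeReduce first shuffles them down to about sqrt(1000) ~ 31 partial sums
// on the executors before the final reduce.
val rdd = sc.parallelize(1 to 1000000, 1000)
val sum = rdd.treeReduce((a, b) => a + b, depth = 2)
```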
```scala
  /**
   * Aggregates the elements of this RDD in a tree pattern.
```
Contributor: "in a tree pattern" -> "in a multi-level aggregation tree pattern".
```scala
   * @param depth suggested depth of the tree
   * @see [[org.apache.spark.rdd.RDD#aggregate]]
   */
  def treeAggregate[U: ClassTag](zeroValue: U)(
      seqOp: (U, T) => U,
```
Contributor: the convention is 4 space indent for arguments

Author: done
```scala
      combOp: (U, U) => U,
      depth: Int): U = {
    require(depth >= 1, s"Depth must be greater than 1 but got $depth.")
```
```scala
    if (self.partitions.size == 0) {
      return Utils.clone(zeroValue, self.context.env.closureSerializer.newInstance())
    }
    val cleanSeqOp = self.context.clean(seqOp)
    val cleanCombOp = self.context.clean(combOp)
    val aggregatePartition = (it: Iterator[T]) => it.aggregate(zeroValue)(cleanSeqOp, cleanCombOp)
    var local = self.mapPartitions(it => Iterator(aggregatePartition(it)))
```
Contributor: I think partial or partiallyAggregated is probably a better name

Author: renamed to …
```scala
    var numPartitions = local.partitions.size
    val scale = math.max(math.ceil(math.pow(numPartitions, 1.0 / depth)).toInt, 2)
    while (numPartitions > scale + numPartitions / scale) {
```
Contributor: Would be great to add some comments here.

Contributor: Also, just for the sake of the program always terminating, I'd set a cap on max depth (like 8).

Author: done
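To illustrate the loop the reviewers are discussing (the numbers below are hypothetical, not from the PR): `scale` is the per-level fan-in, and each pass through the loop collapses the partial aggregates by that factor until one final reduce is cheaper than another shuffle round. A standalone trace, no Spark required:

```scala
// Trace of the level schedule for 64 partitions and depth = 2.
val depth = 2
var numPartitions = 64
val scale = math.max(math.ceil(math.pow(numPartitions, 1.0 / depth)).toInt, 2)
println(s"scale = $scale") // scale = 8
while (numPartitions > scale + numPartitions / scale) {
  numPartitions /= scale
  println(s"partitions after this level: $numPartitions") // prints 8, once
}
// The loop stops at 8 partitions: 8 > 8 + 8/8 = 9 is false, so the
// remaining 8 partial aggregates go straight into the final reduce.
```

Since `scale` is at least 2 and the guard requires `numPartitions` to exceed `scale`, the count strictly shrinks each pass, so the trace above always terminates.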
```scala
      numPartitions /= scale
      local = local.mapPartitionsWithIndex { (i, iter) =>
        iter.map((i % numPartitions, _))
      }.reduceByKey(new HashPartitioner(numPartitions), cleanCombOp).values
```
|
Contributor: This is beyond your PR -- but @mateiz and I talked about adding some native primitive to shuffle to improve specifically this pattern (basically there is no need to create numPartitions streams within each map task).

Author: that would be great ~
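To make the keying concrete (a standalone sketch, not PR code): at each level, the partial aggregate from partition `i` is tagged with key `i % numPartitions`, so `reduceByKey` folds a whole group of source partitions into one partition at the next level:

```scala
// Which of 8 source partitions merge together when shrinking to 2 partitions.
val nextLevelPartitions = 2
val groups = (0 until 8).groupBy(i => i % nextLevelPartitions)
// groups: Map(0 -> Vector(0, 2, 4, 6), 1 -> Vector(1, 3, 5, 7))
// Each group is folded by cleanCombOp into a single record, and each map
// task writes one output stream per group -- the overhead the reviewer notes.
```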
```scala
    }
    local.reduce(cleanCombOp)
  }
}

private[mllib]
```
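And a usage sketch for `treeAggregate` itself (values and depth are illustrative; as with `treeReduce` above, this assumes MLlib-internal code because of the `private[mllib]` companion):

```scala
import org.apache.spark.mllib.rdd.RDDFunctions._

// Illustrative: compute a mean as (sum, count) in one pass, combining
// partial results executor-side in a tree of depth 2 rather than
// collecting every partition's result at the driver.
val data = sc.parallelize(1 to 1000000, 1000).map(_.toDouble)
val (sum, count) = data.treeAggregate((0.0, 0L))(
    seqOp = (acc, x) => (acc._1 + x, acc._2 + 1),
    combOp = (a, b) => (a._1 + b._1, a._2 + b._2),
    depth = 2)
val mean = sum / count
```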
Contributor: maybe you can set depth to be 2 by default so you don't need to repeat this ...
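For concreteness, here is one shape that suggestion could take. This is a sketch only, not part of the diff; the class name is hypothetical and the bodies are elided with `???`:

```scala
import scala.reflect.ClassTag
import org.apache.spark.rdd.RDD

// Hypothetical variant: default depth = 2 so call sites can omit it.
class RDDFunctionsSketch[T: ClassTag](self: RDD[T]) {
  def treeReduce(f: (T, T) => T, depth: Int = 2): T =
    ??? // body unchanged from the diff above

  def treeAggregate[U: ClassTag](zeroValue: U)(
      seqOp: (U, T) => U,
      combOp: (U, U) => U,
      depth: Int = 2): U =
    ??? // body unchanged from the diff above
}
```

With defaults in place, callers could write `rdd.treeReduce(_ + _)` and pass `depth` only when tuning.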