4 changes: 2 additions & 2 deletions core/src/main/scala/org/apache/spark/rdd/RDD.scala
@@ -41,7 +41,7 @@ import org.apache.spark.partial.GroupedCountEvaluator
import org.apache.spark.partial.PartialResult
import org.apache.spark.storage.{RDDBlockId, StorageLevel}
import org.apache.spark.util.{BoundedPriorityQueue, Utils}
-import org.apache.spark.util.collection.OpenHashMap
+import org.apache.spark.util.collection.{OpenHashMap, Utils => collectionUtils}
import org.apache.spark.util.random.{BernoulliCellSampler, BernoulliSampler, PoissonSampler,
SamplingUtils}

@@ -1419,7 +1419,7 @@ abstract class RDD[T: ClassTag](
val mapRDDs = mapPartitions { items =>
// Priority keeps the largest elements, so let's reverse the ordering.
val queue = new BoundedPriorityQueue[T](num)(ord.reverse)
-      queue ++= util.collection.Utils.takeOrdered(items, num)(ord)
+      queue ++= collectionUtils.takeOrdered(items, num)(ord)
Iterator.single(queue)
}
if (mapRDDs.partitions.length == 0) {
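The code comment above ("Priority keeps the largest elements, so let's reverse the ordering") is easy to misread. As a rough standalone sketch of the same idea, using only the Scala standard library rather than Spark's private BoundedPriorityQueue (the helper name is hypothetical):

```scala
import scala.collection.mutable

// Keep the `num` smallest elements of an iterator. Under the natural ordering,
// the priority queue's head is the LARGEST retained element, so it is the one
// evicted when a smaller element arrives -- the same trick takeOrdered relies on.
def takeOrderedSketch[T](items: Iterator[T], num: Int)(implicit ord: Ordering[T]): Seq[T] = {
  val queue = mutable.PriorityQueue.empty[T](ord) // max-heap under ord
  items.foreach { item =>
    if (queue.size < num) {
      queue.enqueue(item)
    } else if (num > 0 && ord.lt(item, queue.head)) {
      queue.dequeue() // evict the current largest
      queue.enqueue(item)
    }
  }
  queue.dequeueAll.reverse // ascending order
}
```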
@@ -15,11 +15,12 @@
* limitations under the License.
*/

-package org.apache.spark.mllib.impl
+package org.apache.spark.rdd.util

import org.apache.spark.SparkContext
import org.apache.spark.rdd.RDD
import org.apache.spark.storage.StorageLevel
+import org.apache.spark.util.PeriodicCheckpointer


/**
@@ -15,7 +15,7 @@
* limitations under the License.
*/

-package org.apache.spark.mllib.impl
+package org.apache.spark.util

import scala.collection.mutable

@@ -58,7 +58,7 @@ import org.apache.spark.storage.StorageLevel
* @param sc SparkContext for the Datasets given to this checkpointer
* @tparam T Dataset type, such as RDD[Double]
*/
-private[mllib] abstract class PeriodicCheckpointer[T](
+private[spark] abstract class PeriodicCheckpointer[T](
val checkpointInterval: Int,
val sc: SparkContext) extends Logging {
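For readers unfamiliar with this class: a subclass supplies the dataset-specific hooks. Below is a minimal illustrative subclass, loosely modeled on the PeriodicRDDCheckpointer in this PR; it is a sketch under the assumption that the abstract hooks are checkpoint, isCheckpointed, persist, unpersist, and getCheckpointFiles, not the actual implementation.

```scala
import org.apache.spark.SparkContext
import org.apache.spark.rdd.RDD
import org.apache.spark.storage.StorageLevel
import org.apache.spark.util.PeriodicCheckpointer

// Sketch only: a concrete checkpointer for plain RDDs.
class SimpleRDDCheckpointer[T](checkpointInterval: Int, sc: SparkContext)
  extends PeriodicCheckpointer[RDD[T]](checkpointInterval, sc) {

  override protected def checkpoint(data: RDD[T]): Unit = data.checkpoint()

  override protected def isCheckpointed(data: RDD[T]): Boolean = data.isCheckpointed

  override protected def persist(data: RDD[T]): Unit = {
    // Respect a storage level the caller may already have chosen.
    if (data.getStorageLevel == StorageLevel.NONE) data.persist()
  }

  override protected def unpersist(data: RDD[T]): Unit = data.unpersist(blocking = false)

  override protected def getCheckpointFiles(data: RDD[T]): Iterable[String] =
    data.getCheckpointFile // Option[String], implicitly an Iterable[String]
}
```

A caller then invokes update(newDataset) once per iteration; every checkpointInterval updates, the checkpointer checkpoints the latest dataset and cleans up older checkpoint files.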

@@ -127,6 +127,16 @@ private[mllib] abstract class PeriodicCheckpointer[T](
/** Get list of checkpoint files for this given Dataset */
protected def getCheckpointFiles(data: T): Iterable[String]

+  /**
+   * Call this to unpersist the Dataset.
+   */
+  def unpersistDataSet(): Unit = {
+    while (persistedQueue.nonEmpty) {
+      val dataToUnpersist = persistedQueue.dequeue()
+      unpersist(dataToUnpersist)
+    }
+  }

Review comment (Member): I don't know if it's the goal, but this isn't thread-safe?

Review comment (Contributor): I would agree with you that this is not thread-safe. Is that a concern?

Review comment (Member): With the limited internal-only use, it should be OK.
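If thread safety ever became a requirement, a minimal synchronized variant might look like the following; this is purely hypothetical, since the PR deliberately leaves the method unsynchronized given its internal-only use.

```scala
// Hypothetical thread-safe variant; not part of this PR. Assumes the same
// persistedQueue (a mutable.Queue[T]) and unpersist(data: T) members as above.
def unpersistDataSet(): Unit = persistedQueue.synchronized {
  while (persistedQueue.nonEmpty) {
    val dataToUnpersist = persistedQueue.dequeue()
    unpersist(dataToUnpersist)
  }
}
```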

/**
* Call this at the end to delete any remaining checkpoint files.
*/
@@ -135,7 +135,7 @@ class SortingSuite extends SparkFunSuite with SharedSparkContext with Matchers w
}

test("get a range of elements in an array not partitioned by a range partitioner") {
-    val pairArr = util.Random.shuffle((1 to 1000).toList).map(x => (x, x))
+    val pairArr = scala.util.Random.shuffle((1 to 1000).toList).map(x => (x, x))
val pairs = sc.parallelize(pairArr, 10)
val range = pairs.filterByRange(200, 800).collect()
assert((800 to 200 by -1).toArray.sorted === range.map(_._1).sorted)
@@ -15,18 +15,18 @@
* limitations under the License.
*/

-package org.apache.spark.mllib.impl
+package org.apache.spark.utils

import org.apache.hadoop.fs.Path

-import org.apache.spark.{SparkContext, SparkFunSuite}
-import org.apache.spark.mllib.util.MLlibTestSparkContext
+import org.apache.spark.{SharedSparkContext, SparkContext, SparkFunSuite}
import org.apache.spark.rdd.RDD
+import org.apache.spark.rdd.util.PeriodicRDDCheckpointer
import org.apache.spark.storage.StorageLevel
import org.apache.spark.util.Utils


-class PeriodicRDDCheckpointerSuite extends SparkFunSuite with MLlibTestSparkContext {
+class PeriodicRDDCheckpointerSuite extends SparkFunSuite with SharedSparkContext {

import PeriodicRDDCheckpointerSuite._

14 changes: 14 additions & 0 deletions docs/configuration.md
@@ -2115,6 +2115,20 @@ showDF(properties, numRows = 200, truncate = FALSE)

</table>

+### GraphX
+
+<table class="table">
+<tr><th>Property Name</th><th>Default</th><th>Meaning</th></tr>
+<tr>
+  <td><code>spark.graphx.pregel.checkpointInterval</code></td>
+  <td>-1</td>
+  <td>
+    Checkpoint interval for the graph and messages in Pregel. It is used to avoid StackOverflowError
+    due to long lineage chains after many iterations. Checkpointing is disabled by default.
+  </td>
+</tr>
+</table>
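A minimal sketch of enabling this from user code (application name and checkpoint path are illustrative; Pregel checkpointing also requires a checkpoint directory):

```scala
import org.apache.spark.{SparkConf, SparkContext}

// Checkpoint the Pregel graph and messages every 10 iterations (illustrative).
val conf = new SparkConf()
  .setAppName("pregel-checkpoint-example")
  .set("spark.graphx.pregel.checkpointInterval", "10")
val sc = new SparkContext(conf)

// A checkpoint directory must be set for checkpointing to take effect.
sc.setCheckpointDir("/tmp/pregel-checkpoints")
```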

### Deploy

<table class="table">
9 changes: 6 additions & 3 deletions docs/graphx-programming-guide.md
@@ -708,7 +708,9 @@ messages remaining.
> messaging function. These constraints allow additional optimization within GraphX.

The following is the type signature of the [Pregel operator][GraphOps.pregel] as well as a *sketch*
-of its implementation (note calls to graph.cache have been removed):
+of its implementation (note: to avoid StackOverflowError due to long lineage chains, Pregel supports periodically
+checkpointing the graph and messages by setting "spark.graphx.pregel.checkpointInterval" to a positive number,
+say 10, and also setting the checkpoint directory using SparkContext.setCheckpointDir(directory: String)):

{% highlight scala %}
class GraphOps[VD, ED] {
@@ -722,6 +724,7 @@ class GraphOps[VD, ED] {
: Graph[VD, ED] = {
// Receive the initial message at each vertex
var g = mapVertices( (vid, vdata) => vprog(vid, vdata, initialMsg) ).cache()

// compute the messages
var messages = g.mapReduceTriplets(sendMsg, mergeMsg)
var activeMessages = messages.count()
@@ -734,8 +737,8 @@ class GraphOps[VD, ED] {
// Send new messages, skipping edges where neither side received a message. We must cache
// messages so it can be materialized on the next line, allowing us to uncache the previous
// iteration.
-    messages = g.mapReduceTriplets(
-      sendMsg, mergeMsg, Some((oldMessages, activeDirection))).cache()
+    messages = GraphXUtils.mapReduceTriplets(
+      g, sendMsg, mergeMsg, Some((oldMessages, activeDirection))).cache()
activeMessages = messages.count()
i += 1
}
25 changes: 21 additions & 4 deletions graphx/src/main/scala/org/apache/spark/graphx/Pregel.scala
@@ -19,7 +19,10 @@ package org.apache.spark.graphx

import scala.reflect.ClassTag

+import org.apache.spark.graphx.util.PeriodicGraphCheckpointer
import org.apache.spark.internal.Logging
+import org.apache.spark.rdd.RDD
+import org.apache.spark.rdd.util.PeriodicRDDCheckpointer

/**
* Implements a Pregel-like bulk-synchronous message-passing API.
@@ -122,27 +125,39 @@ object Pregel extends Logging {
require(maxIterations > 0, s"Maximum number of iterations must be greater than 0," +
s" but got ${maxIterations}")

-    var g = graph.mapVertices((vid, vdata) => vprog(vid, vdata, initialMsg)).cache()
+    val checkpointInterval = graph.vertices.sparkContext.getConf
+      .getInt("spark.graphx.pregel.checkpointInterval", -1)
+    var g = graph.mapVertices((vid, vdata) => vprog(vid, vdata, initialMsg))
+    val graphCheckpointer = new PeriodicGraphCheckpointer[VD, ED](
+      checkpointInterval, graph.vertices.sparkContext)
+    graphCheckpointer.update(g)

// compute the messages
var messages = GraphXUtils.mapReduceTriplets(g, sendMsg, mergeMsg)
+    val messageCheckpointer = new PeriodicRDDCheckpointer[(VertexId, A)](
+      checkpointInterval, graph.vertices.sparkContext)
+    messageCheckpointer.update(messages.asInstanceOf[RDD[(VertexId, A)]])
var activeMessages = messages.count()

// Loop
var prevG: Graph[VD, ED] = null
var i = 0
while (activeMessages > 0 && i < maxIterations) {
// Receive the messages and update the vertices.
prevG = g
-      g = g.joinVertices(messages)(vprog).cache()
+      g = g.joinVertices(messages)(vprog)
+      graphCheckpointer.update(g)

val oldMessages = messages
// Send new messages, skipping edges where neither side received a message. We must cache
// messages so it can be materialized on the next line, allowing us to uncache the previous
// iteration.
      messages = GraphXUtils.mapReduceTriplets(
-        g, sendMsg, mergeMsg, Some((oldMessages, activeDirection))).cache()
+        g, sendMsg, mergeMsg, Some((oldMessages, activeDirection)))
      // The call to count() materializes `messages` and the vertices of `g`. This hides oldMessages
      // (depended on by the vertices of g) and the vertices of prevG (depended on by oldMessages
      // and the vertices of g).
+      messageCheckpointer.update(messages.asInstanceOf[RDD[(VertexId, A)]])

Review comment (Member): Should the comment here be updated?
activeMessages = messages.count()

logInfo("Pregel finished iteration " + i)
@@ -154,7 +169,9 @@
// count the iteration
i += 1
}
-    messages.unpersist(blocking = false)
+    messageCheckpointer.unpersistDataSet()
Review comment (Contributor): Sorry, I don't understand this change. Why do we replace `messages.unpersist(blocking = false)` with `messageCheckpointer.unpersistDataSet()`? Especially because this adds a new public method to PeriodicCheckpointer that no other code has needed before.

Review comment (Contributor, author): We use messageCheckpointer.update to do the caching, so to pair with it we can use the checkpointer to unpersist the data as well. Please correct me if I understand this wrong. I also think it's fine to add this new method: there is already a public method that caches data into the persisted queue, so we should provide a public method to clean the queue.
+    graphCheckpointer.deleteAllCheckpoints()
+    messageCheckpointer.deleteAllCheckpoints()
g
} // end of apply
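To see the new code path end to end, here is a sketch of a caller exercising checkpointed Pregel with the standard single-source shortest paths example from the GraphX guide; it assumes sc was configured with spark.graphx.pregel.checkpointInterval and a checkpoint directory as shown in the docs/configuration.md change above.

```scala
import org.apache.spark.graphx._

// Assumed input: an existing graph with Double vertex/edge attributes.
val graph: Graph[Double, Double] = ??? // placeholder
val sourceId: VertexId = 0L

val initialGraph = graph.mapVertices((id, _) =>
  if (id == sourceId) 0.0 else Double.PositiveInfinity)

// Checkpointing happens inside Pregel.apply; callers are unchanged.
val sssp = Pregel(initialGraph, Double.PositiveInfinity)(
  (id, dist, newDist) => math.min(dist, newDist), // vprog
  triplet => // sendMsg
    if (triplet.srcAttr + triplet.attr < triplet.dstAttr) {
      Iterator((triplet.dstId, triplet.srcAttr + triplet.attr))
    } else {
      Iterator.empty
    },
  (a, b) => math.min(a, b) // mergeMsg
)
```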

@@ -15,11 +15,12 @@
* limitations under the License.
*/

-package org.apache.spark.mllib.impl
+package org.apache.spark.graphx.util

import org.apache.spark.SparkContext
import org.apache.spark.graphx.Graph
import org.apache.spark.storage.StorageLevel
+import org.apache.spark.util.PeriodicCheckpointer


/**
@@ -74,9 +75,8 @@ import org.apache.spark.storage.StorageLevel
* @tparam VD Vertex descriptor type
* @tparam ED Edge descriptor type
*
- * TODO: Move this out of MLlib?
*/
-private[mllib] class PeriodicGraphCheckpointer[VD, ED](
+private[spark] class PeriodicGraphCheckpointer[VD, ED](
checkpointInterval: Int,
sc: SparkContext)
extends PeriodicCheckpointer[Graph[VD, ED]](checkpointInterval, sc) {
@@ -87,10 +87,13 @@ private[mllib] class PeriodicGraphCheckpointer[VD, ED](

override protected def persist(data: Graph[VD, ED]): Unit = {
if (data.vertices.getStorageLevel == StorageLevel.NONE) {
-      data.vertices.persist()
+      /* We need to use cache because persist does not honor the default storage level requested
+       * when constructing the graph. Only cache does that.
+       */
+      data.vertices.cache()

Review comment (Member): Isn't persist better? This could potentially support different storage levels later.

Review comment (Contributor): We need to use cache because persist does not honor the default storage level requested when constructing the graph. Only cache does that. It's confusing, but true. To verify this for yourself, change these values to persist and run the PeriodicGraphCheckpointerSuite tests.

Review comment (Member): @mallman @dding3 It would be good to add a comment on that here.
}
if (data.edges.getStorageLevel == StorageLevel.NONE) {
-      data.edges.persist()
+      data.edges.cache()
}
}
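To make the reviewer exchange concrete: GraphX's vertex and edge RDDs remember the storage level requested when the graph was constructed, and cache() replays that level, while the no-argument persist() inherited from RDD always requests MEMORY_ONLY. A hedged sketch (placeholder RDDs, illustrative storage levels):

```scala
import org.apache.spark.graphx.{Edge, Graph, VertexId}
import org.apache.spark.rdd.RDD
import org.apache.spark.storage.StorageLevel

// Placeholder inputs, assumed defined elsewhere.
val vertexData: RDD[(VertexId, Int)] = ???
val edgeData: RDD[Edge[Int]] = ???

// Request a non-default storage level at construction time.
val graph = Graph(vertexData, edgeData,
  defaultVertexAttr = 0,
  edgeStorageLevel = StorageLevel.MEMORY_AND_DISK,
  vertexStorageLevel = StorageLevel.MEMORY_AND_DISK)

graph.vertices.cache()    // honors the MEMORY_AND_DISK requested above
// graph.vertices.persist() // would request RDD's default, MEMORY_ONLY
```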
