[SPARK-5484][GraphX] Periodically do checkpoint in Pregel #15125
Changes from 16 commits
@@ -15,7 +15,7 @@
  * limitations under the License.
  */

-package org.apache.spark.mllib.impl
+package org.apache.spark.util

 import org.apache.spark.SparkContext
 import org.apache.spark.rdd.RDD
@@ -708,7 +708,9 @@ messages remaining.
 > messaging function. These constraints allow additional optimization within GraphX.

 The following is the type signature of the [Pregel operator][GraphOps.pregel] as well as a *sketch*
-of its implementation (note calls to graph.cache have been removed):
+of its implementation (note: to avoid a StackOverflowError caused by long lineage chains, the graph
+and messages are periodically checkpointed; the checkpoint interval is set by
+"spark.graphx.pregel.checkpointInterval" and checkpointing can be disabled by setting it to -1):

 {% highlight scala %}
 class GraphOps[VD, ED] {
@@ -722,8 +724,9 @@ class GraphOps[VD, ED] {
     : Graph[VD, ED] = {
     // Receive the initial message at each vertex
     var g = mapVertices( (vid, vdata) => vprog(vid, vdata, initialMsg) ).cache()

     // compute the messages
-    var messages = g.mapReduceTriplets(sendMsg, mergeMsg)
+    var messages = GraphXUtils.mapReduceTriplets(g, sendMsg, mergeMsg)
     var activeMessages = messages.count()
     // Loop until no messages remain or maxIterations is achieved
     var i = 0
@@ -732,10 +735,10 @@ class GraphOps[VD, ED] {
       g = g.joinVertices(messages)(vprog).cache()
       val oldMessages = messages
       // Send new messages, skipping edges where neither side received a message. We must cache
-      // messages so it can be materialized on the next line, allowing us to uncache the previous
-      // iteration.
-      messages = g.mapReduceTriplets(
-        sendMsg, mergeMsg, Some((oldMessages, activeDirection))).cache()
+      // and periodically checkpoint messages so they can be materialized on the next line,
+      // avoiding a long lineage chain.
+      messages = GraphXUtils.mapReduceTriplets(
+        g, sendMsg, mergeMsg, Some((oldMessages, activeDirection))).cache()
       activeMessages = messages.count()
       i += 1
     }
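As a usage note, the interval is read from the Spark configuration, so it can be tuned per application. Below is a minimal sketch of doing so; only the config key and its default of 10 come from this change, while the app name and checkpoint directory are illustrative:

{% highlight scala %}
import org.apache.spark.{SparkConf, SparkContext}

val conf = new SparkConf()
  .setAppName("PregelCheckpointExample")
  .set("spark.graphx.pregel.checkpointInterval", "5") // default is 10; -1 disables
val sc = new SparkContext(conf)
// Periodic checkpointing requires a checkpoint directory to be set.
sc.setCheckpointDir("/tmp/pregel-checkpoints")
{% endhighlight %}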
@@ -19,7 +19,10 @@ package org.apache.spark.graphx

 import scala.reflect.ClassTag

+import org.apache.spark.graphx.util.PeriodicGraphCheckpointer
 import org.apache.spark.internal.Logging
+import org.apache.spark.rdd.RDD
+import org.apache.spark.util.PeriodicRDDCheckpointer

 /**
  * Implements a Pregel-like bulk-synchronous message-passing API.
@@ -122,27 +125,39 @@ object Pregel extends Logging {
     require(maxIterations > 0, s"Maximum number of iterations must be greater than 0," +
       s" but got ${maxIterations}")

-    var g = graph.mapVertices((vid, vdata) => vprog(vid, vdata, initialMsg)).cache()
+    val checkpointInterval = graph.vertices.sparkContext.getConf
+      .getInt("spark.graphx.pregel.checkpointInterval", 10)
+    var g = graph.mapVertices((vid, vdata) => vprog(vid, vdata, initialMsg))
+    val graphCheckpointer = new PeriodicGraphCheckpointer[VD, ED](
+      checkpointInterval, graph.vertices.sparkContext)
+    graphCheckpointer.update(g)

     // compute the messages
     var messages = GraphXUtils.mapReduceTriplets(g, sendMsg, mergeMsg)
+    val messageCheckpointer = new PeriodicRDDCheckpointer[(VertexId, A)](
+      checkpointInterval, graph.vertices.sparkContext)
+    messageCheckpointer.update(messages.asInstanceOf[RDD[(VertexId, A)]])
Member: Do we need to update it before the loop? I think it should be enough to do the update in the loop.

Contributor: I agree. What do you think, @dding3?

Contributor: Actually, for the sake of simplicity and consistency, I'm going to suggest we keep the checkpointer update calls but remove all .cache() calls.

Contributor (Author): I think we need to cache the graph and messages here so they don't need to be computed again in the loop. I agree with you, and I will keep the checkpointer update calls and remove all .cache() calls.
     var activeMessages = messages.count()

     // Loop
     var prevG: Graph[VD, ED] = null
     var i = 0
     while (activeMessages > 0 && i < maxIterations) {
       // Receive the messages and update the vertices.
       prevG = g
-      g = g.joinVertices(messages)(vprog).cache()
+      g = g.joinVertices(messages)(vprog)
+      graphCheckpointer.update(g)

       val oldMessages = messages
       // Send new messages, skipping edges where neither side received a message. We must cache
-      // messages so it can be materialized on the next line, allowing us to uncache the previous
-      // iteration.
+      // and periodically checkpoint messages so they can be materialized on the next line,
+      // avoiding a long lineage chain.
       messages = GraphXUtils.mapReduceTriplets(
-        g, sendMsg, mergeMsg, Some((oldMessages, activeDirection))).cache()
+        g, sendMsg, mergeMsg, Some((oldMessages, activeDirection)))
       // The call to count() materializes `messages` and the vertices of `g`. This hides oldMessages
Member: Should the comment here be updated?
       // (depended on by the vertices of g) and the vertices of prevG (depended on by oldMessages
       // and the vertices of g).
+      messageCheckpointer.update(messages.asInstanceOf[RDD[(VertexId, A)]])
       activeMessages = messages.count()

       logInfo("Pregel finished iteration " + i)
@@ -154,7 +169,9 @@ object Pregel extends Logging {
       // count the iteration
       i += 1
     }
-    messages.unpersist(blocking = false)
+    messageCheckpointer.unpersistDataSet()
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Sorry, I don't understand this change. Why do we replace messages.unpersist(blocking = false)with messageCheckpointer.unpersistDataSet()Especially because this adds a new public method to
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I think the thing is we use messageCheckpointer.update to do the cache, to make a pair, we can use it to unpersist data. Please correct me if I understand wrong. |
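For context, a rough sketch of the pairing the author describes, using only the checkpointer methods visible in this diff (update, unpersistDataSet, deleteAllCheckpoints). The helper function is hypothetical, and since PeriodicRDDCheckpointer is Spark-internal this only compiles inside Spark itself:

{% highlight scala %}
import org.apache.spark.rdd.RDD
import org.apache.spark.util.PeriodicRDDCheckpointer

// Hypothetical lifecycle sketch: the checkpointer both persists and checkpoints
// the latest dataset, so callers pair update() with unpersistDataSet() instead
// of calling rdd.cache()/rdd.unpersist() directly.
def checkpointedLoop[A](initial: RDD[A], interval: Int): Unit = {
  val checkpointer = new PeriodicRDDCheckpointer[A](interval, initial.sparkContext)
  var messages = initial
  checkpointer.update(messages)        // replaces messages.cache()
  // ... each iteration produces a new `messages` and calls checkpointer.update ...
  checkpointer.unpersistDataSet()      // replaces messages.unpersist(blocking = false)
  checkpointer.deleteAllCheckpoints()  // remove checkpoint files when done
}
{% endhighlight %}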
+    graphCheckpointer.deleteAllCheckpoints()
+    messageCheckpointer.deleteAllCheckpoints()
Member: Shouldn't this be inside a finally clause, to make sure checkpoint data is cleaned up even in error cases?

Contributor (Author): I think when there is an exception during training, if we keep the checkpoints, there is a chance for the user to recover from them. I checked RandomForest/GBT in Spark, and it looks like they only delete the checkpoints when training finishes successfully.
     g
   } // end of apply
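For illustration only, the reviewer's finally-clause suggestion above would look roughly like the following sketch. This is not what the PR adopts, since the author intentionally leaves checkpoints in place on failure so users can recover from them; the variables are those from the diff:

{% highlight scala %}
try {
  while (activeMessages > 0 && i < maxIterations) {
    // ... Pregel iteration body as in the diff above ...
  }
} finally {
  // Clean up checkpoint data even if an iteration throws.
  graphCheckpointer.deleteAllCheckpoints()
  messageCheckpointer.deleteAllCheckpoints()
}
{% endhighlight %}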
@@ -15,11 +15,12 @@
  * limitations under the License.
  */

-package org.apache.spark.mllib.impl
+package org.apache.spark.graphx.util

 import org.apache.spark.SparkContext
 import org.apache.spark.graphx.Graph
 import org.apache.spark.storage.StorageLevel
+import org.apache.spark.util.PeriodicCheckpointer

 /**
@@ -74,9 +75,8 @@ import org.apache.spark.storage.StorageLevel
  * @tparam VD Vertex descriptor type
  * @tparam ED Edge descriptor type
  *
- * TODO: Move this out of MLlib?
  */
-private[mllib] class PeriodicGraphCheckpointer[VD, ED](
+private[spark] class PeriodicGraphCheckpointer[VD, ED](
     checkpointInterval: Int,
     sc: SparkContext)
   extends PeriodicCheckpointer[Graph[VD, ED]](checkpointInterval, sc) {
@@ -87,10 +87,10 @@ private[mllib] class PeriodicGraphCheckpointer[VD, ED](

   override protected def persist(data: Graph[VD, ED]): Unit = {
     if (data.vertices.getStorageLevel == StorageLevel.NONE) {
-      data.vertices.persist()
+      data.vertices.cache()
Member: Isn't persist better? This could potentially support different storage levels later.

Contributor: We need to use …
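For background on the cache-versus-persist question, this is generic Spark behavior rather than anything specific to this PR: for an RDD, cache() is shorthand for persist(StorageLevel.MEMORY_ONLY), while persist(level) lets the caller choose a level. A small illustrative sketch:

{% highlight scala %}
import org.apache.spark.SparkContext
import org.apache.spark.storage.StorageLevel

def storageLevelDemo(sc: SparkContext): Unit = {
  val rdd = sc.parallelize(1 to 10)
  rdd.cache()                                    // same as persist(StorageLevel.MEMORY_ONLY)
  rdd.unpersist(blocking = false)
  rdd.persist(StorageLevel.MEMORY_AND_DISK_SER)  // explicit, non-default level
}
{% endhighlight %}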
     }
     if (data.edges.getStorageLevel == StorageLevel.NONE) {
-      data.edges.persist()
+      data.edges.cache()
     }
   }
@@ -15,77 +15,81 @@
  * limitations under the License.
  */

-package org.apache.spark.mllib.impl
+package org.apache.spark.graphx.util

 import org.apache.hadoop.fs.Path

Contributor: Is there a reason this test suite isn't moved into the GraphX codebase?
 import org.apache.spark.{SparkContext, SparkFunSuite}
-import org.apache.spark.graphx.{Edge, Graph}
-import org.apache.spark.mllib.util.MLlibTestSparkContext
+import org.apache.spark.graphx.{Edge, Graph, LocalSparkContext}
 import org.apache.spark.storage.StorageLevel
 import org.apache.spark.util.Utils


-class PeriodicGraphCheckpointerSuite extends SparkFunSuite with MLlibTestSparkContext {
+class PeriodicGraphCheckpointerSuite extends SparkFunSuite with LocalSparkContext {

   import PeriodicGraphCheckpointerSuite._

| test("Persisting") { | ||
| var graphsToCheck = Seq.empty[GraphToCheck] | ||
|
|
||
| val graph1 = createGraph(sc) | ||
| val checkpointer = | ||
| new PeriodicGraphCheckpointer[Double, Double](10, graph1.vertices.sparkContext) | ||
| checkpointer.update(graph1) | ||
| graphsToCheck = graphsToCheck :+ GraphToCheck(graph1, 1) | ||
| checkPersistence(graphsToCheck, 1) | ||
|
|
||
| var iteration = 2 | ||
| while (iteration < 9) { | ||
| val graph = createGraph(sc) | ||
| checkpointer.update(graph) | ||
| graphsToCheck = graphsToCheck :+ GraphToCheck(graph, iteration) | ||
| checkPersistence(graphsToCheck, iteration) | ||
| iteration += 1 | ||
| withSpark { sc => | ||
| val graph1 = createGraph(sc) | ||
| val checkpointer = | ||
| new PeriodicGraphCheckpointer[Double, Double](10, graph1.vertices.sparkContext) | ||
| checkpointer.update(graph1) | ||
| graphsToCheck = graphsToCheck :+ GraphToCheck(graph1, 1) | ||
| checkPersistence(graphsToCheck, 1) | ||
|
|
||
| var iteration = 2 | ||
| while (iteration < 9) { | ||
| val graph = createGraph(sc) | ||
| checkpointer.update(graph) | ||
| graphsToCheck = graphsToCheck :+ GraphToCheck(graph, iteration) | ||
| checkPersistence(graphsToCheck, iteration) | ||
| iteration += 1 | ||
| } | ||
| } | ||
| } | ||
|
|
||
| test("Checkpointing") { | ||
| val tempDir = Utils.createTempDir() | ||
| val path = tempDir.toURI.toString | ||
| val checkpointInterval = 2 | ||
| var graphsToCheck = Seq.empty[GraphToCheck] | ||
| sc.setCheckpointDir(path) | ||
| val graph1 = createGraph(sc) | ||
| val checkpointer = new PeriodicGraphCheckpointer[Double, Double]( | ||
| checkpointInterval, graph1.vertices.sparkContext) | ||
| checkpointer.update(graph1) | ||
| graph1.edges.count() | ||
| graph1.vertices.count() | ||
| graphsToCheck = graphsToCheck :+ GraphToCheck(graph1, 1) | ||
| checkCheckpoint(graphsToCheck, 1, checkpointInterval) | ||
|
|
||
| var iteration = 2 | ||
| while (iteration < 9) { | ||
| val graph = createGraph(sc) | ||
| checkpointer.update(graph) | ||
| graph.vertices.count() | ||
| graph.edges.count() | ||
| graphsToCheck = graphsToCheck :+ GraphToCheck(graph, iteration) | ||
| checkCheckpoint(graphsToCheck, iteration, checkpointInterval) | ||
| iteration += 1 | ||
| } | ||
| withSpark { sc => | ||
| val tempDir = Utils.createTempDir() | ||
| val path = tempDir.toURI.toString | ||
| val checkpointInterval = 2 | ||
| var graphsToCheck = Seq.empty[GraphToCheck] | ||
| sc.setCheckpointDir(path) | ||
| val graph1 = createGraph(sc) | ||
| val checkpointer = new PeriodicGraphCheckpointer[Double, Double]( | ||
| checkpointInterval, graph1.vertices.sparkContext) | ||
| checkpointer.update(graph1) | ||
| graph1.edges.count() | ||
| graph1.vertices.count() | ||
| graphsToCheck = graphsToCheck :+ GraphToCheck(graph1, 1) | ||
| checkCheckpoint(graphsToCheck, 1, checkpointInterval) | ||
|
|
||
| var iteration = 2 | ||
| while (iteration < 9) { | ||
| val graph = createGraph(sc) | ||
| checkpointer.update(graph) | ||
| graph.vertices.count() | ||
| graph.edges.count() | ||
| graphsToCheck = graphsToCheck :+ GraphToCheck(graph, iteration) | ||
| checkCheckpoint(graphsToCheck, iteration, checkpointInterval) | ||
| iteration += 1 | ||
| } | ||
|
|
||
| checkpointer.deleteAllCheckpoints() | ||
| graphsToCheck.foreach { graph => | ||
| confirmCheckpointRemoved(graph.graph) | ||
| } | ||
| checkpointer.deleteAllCheckpoints() | ||
| graphsToCheck.foreach { graph => | ||
| confirmCheckpointRemoved(graph.graph) | ||
| } | ||
|
|
||
| Utils.deleteRecursively(tempDir) | ||
| Utils.deleteRecursively(tempDir) | ||
| } | ||
| } | ||
| } | ||
|
|
||
 private object PeriodicGraphCheckpointerSuite {
+  private val defaultStorageLevel = StorageLevel.MEMORY_ONLY_SER

   case class GraphToCheck(graph: Graph[Double, Double], gIndex: Int)

@@ -96,7 +100,8 @@ private object PeriodicGraphCheckpointerSuite {
     Edge[Double](3, 4, 0))

   def createGraph(sc: SparkContext): Graph[Double, Double] = {
-    Graph.fromEdges[Double, Double](sc.parallelize(edges), 0)
+    Graph.fromEdges[Double, Double](
+      sc.parallelize(edges), 0, defaultStorageLevel, defaultStorageLevel)
   }

   def checkPersistence(graphs: Seq[GraphToCheck], iteration: Int): Unit = {
@@ -116,8 +121,8 @@ private object PeriodicGraphCheckpointerSuite {
       assert(graph.vertices.getStorageLevel == StorageLevel.NONE)
       assert(graph.edges.getStorageLevel == StorageLevel.NONE)
     } else {
-      assert(graph.vertices.getStorageLevel != StorageLevel.NONE)
-      assert(graph.edges.getStorageLevel != StorageLevel.NONE)
+      assert(graph.vertices.getStorageLevel == defaultStorageLevel)
+      assert(graph.edges.getStorageLevel == defaultStorageLevel)
     }
   } catch {
     case _: AssertionError =>
Reviewer: I don't know if it's the goal, but this isn't thread-safe?

Reviewer: I would agree with you that this is not thread-safe. Is that a concern?

Reviewer: With the limited internal-only use, it should be OK.