Changes from 12 commits
----------------------------------------------------------------------
@@ -15,7 +15,7 @@
* limitations under the License.
*/

-package org.apache.spark.mllib.impl
+package org.apache.spark.util

import scala.collection.mutable

@@ -58,7 +58,7 @@ import org.apache.spark.storage.StorageLevel
* @param sc SparkContext for the Datasets given to this checkpointer
* @tparam T Dataset type, such as RDD[Double]
*/
-private[mllib] abstract class PeriodicCheckpointer[T](
+private[spark] abstract class PeriodicCheckpointer[T](
val checkpointInterval: Int,
val sc: SparkContext) extends Logging {

@@ -127,6 +127,16 @@ private[mllib] abstract class PeriodicCheckpointer[T](
/** Get list of checkpoint files for this given Dataset */
protected def getCheckpointFiles(data: T): Iterable[String]

+  /**
+   * Call this to unpersist the Dataset.
+   */
+  def unpersistDataSet(): Unit = {
+    while (persistedQueue.nonEmpty) {
Member: I don't know if it's the goal, but this isn't thread-safe?

Contributor: I would agree that this is not thread-safe. Is that a concern?

Member: With the limited, internal-only use, it should be OK.

+      val dataToUnpersist = persistedQueue.dequeue()
+      unpersist(dataToUnpersist)
+    }
+  }
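For reference, a minimal sketch of what a synchronized variant could look like, if thread safety were ever needed. This is hypothetical and not part of the change; it assumes only the persistedQueue field and unpersist(data: T) method of PeriodicCheckpointer:

  // Hypothetical synchronized variant; not part of this PR.
  def unpersistDataSet(): Unit = this.synchronized {
    while (persistedQueue.nonEmpty) {
      unpersist(persistedQueue.dequeue())
    }
  }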

/**
* Call this at the end to delete any remaining checkpoint files.
*/
----------------------------------------------------------------------
@@ -15,7 +15,7 @@
* limitations under the License.
*/

-package org.apache.spark.mllib.impl
+package org.apache.spark.util
Member (@felixcheung, Apr 19, 2017): It's PeriodicRDDCheckpointer; shouldn't this be in the org.apache.spark.rdd.util namespace?


import org.apache.spark.SparkContext
import org.apache.spark.rdd.RDD
----------------------------------------------------------------------
@@ -15,18 +15,16 @@
* limitations under the License.
*/

-package org.apache.spark.mllib.impl
+package org.apache.spark.util

import org.apache.hadoop.fs.Path

-import org.apache.spark.{SparkContext, SparkFunSuite}
-import org.apache.spark.mllib.util.MLlibTestSparkContext
+import org.apache.spark.{SharedSparkContext, SparkContext, SparkFunSuite}
import org.apache.spark.rdd.RDD
import org.apache.spark.storage.StorageLevel
import org.apache.spark.util.Utils


-class PeriodicRDDCheckpointerSuite extends SparkFunSuite with MLlibTestSparkContext {
+class PeriodicRDDCheckpointerSuite extends SparkFunSuite with SharedSparkContext {

import PeriodicRDDCheckpointerSuite._

----------------------------------------------------------------------
29 changes: 23 additions & 6 deletions graphx/src/main/scala/org/apache/spark/graphx/Pregel.scala
@@ -19,7 +19,10 @@ package org.apache.spark.graphx

import scala.reflect.ClassTag

+import org.apache.spark.graphx.util.PeriodicGraphCheckpointer
import org.apache.spark.internal.Logging
+import org.apache.spark.rdd.RDD
+import org.apache.spark.util.PeriodicRDDCheckpointer

/**
* Implements a Pregel-like bulk-synchronous message-passing API.
@@ -122,27 +125,39 @@ object Pregel extends Logging {
require(maxIterations > 0, s"Maximum number of iterations must be greater than 0," +
s" but got ${maxIterations}")

-    var g = graph.mapVertices((vid, vdata) => vprog(vid, vdata, initialMsg)).cache()
+    val checkpointInterval = graph.vertices.sparkContext.getConf
+      .getInt("spark.graphx.pregel.checkpointInterval", 10)
Member (@viirya, Feb 18, 2017): Nit: do we need to document this config in the GraphX documentation?

Contributor Author: I think so. Currently I added the documentation in graphx-programming-guide.md, but I am not sure whether that is the right place; please let me know if there is a better one.

Contributor: I would also suggest incorporating this change into the Spark 2.2 release notes under a section for GraphX, but I don't see where these notes are maintained. The release notes for 2.1 are published at http://spark.apache.org/releases/spark-release-2-1-0.html, but I can't find them in the repo. Does anybody know how they are generated, or how to contribute to them? Is there another repo for this documentation?
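For context, enabling the new interval from user code might look like the following sketch; the app name, interval value, and checkpoint directory are illustrative:

  import org.apache.spark.{SparkConf, SparkContext}

  // Checkpoint Pregel's graph and message RDDs every 5 iterations
  // (the default is 10). Checkpointing also needs a checkpoint directory.
  val conf = new SparkConf()
    .setAppName("PregelCheckpointExample")
    .set("spark.graphx.pregel.checkpointInterval", "5")
  val sc = new SparkContext(conf)
  sc.setCheckpointDir("/tmp/spark-checkpoints")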

+    var g = graph.mapVertices((vid, vdata) => vprog(vid, vdata, initialMsg))
+    val graphCheckpointer = new PeriodicGraphCheckpointer[VD, ED](
+      checkpointInterval, graph.vertices.sparkContext)
+    graphCheckpointer.update(g)

// compute the messages
var messages = GraphXUtils.mapReduceTriplets(g, sendMsg, mergeMsg)
+    val messageCheckpointer = new PeriodicRDDCheckpointer[(VertexId, A)](
+      checkpointInterval, graph.vertices.sparkContext)
+    messageCheckpointer.update(messages.asInstanceOf[RDD[(VertexId, A)]])
Member: Do we need to update it before the loop? I think it should be enough to do the update in the loop.

Contributor: I agree. What do you think, @dding3?

Contributor: Actually, for the sake of simplicity and consistency, I'm going to suggest we keep the checkpointer update calls but remove all .cache() calls. The update calls persist the underlying data, making the calls to .cache() unnecessary.

Contributor Author: I think we need to cache the graph/messages here so they don't have to be computed again in the loop. I agree with you; I will keep the checkpointer update calls and remove all .cache() calls.
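To make the point about redundant .cache() calls concrete, here is a rough sketch of what an update call does. This paraphrases PeriodicCheckpointer rather than quoting it; field names beyond persistedQueue are illustrative:

  // Paraphrased shape of PeriodicCheckpointer.update (not verbatim).
  def update(newData: T): Unit = {
    persist(newData)                  // persists, so callers need no .cache()
    persistedQueue.enqueue(newData)
    updatesSinceCheckpoint += 1
    if (updatesSinceCheckpoint >= checkpointInterval) {
      checkpoint(newData)             // truncates the lineage chain
      updatesSinceCheckpoint = 0
    }
  }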

var activeMessages = messages.count()

// Loop
var prevG: Graph[VD, ED] = null
var i = 0
while (activeMessages > 0 && i < maxIterations) {
// Receive the messages and update the vertices.
prevG = g
-      g = g.joinVertices(messages)(vprog).cache()
+      g = g.joinVertices(messages)(vprog)
+      graphCheckpointer.update(g)

val oldMessages = messages
      // Send new messages, skipping edges where neither side received a message. We must cache
-      // messages so it can be materialized on the next line, allowing us to uncache the previous
-      // iteration.
+      // and periodic checkpoint messages so it can be materialized on the next line, and avoid
+      // to have a long lineage chain.
messages = GraphXUtils.mapReduceTriplets(
-        g, sendMsg, mergeMsg, Some((oldMessages, activeDirection))).cache()
+        g, sendMsg, mergeMsg, Some((oldMessages, activeDirection)))
// The call to count() materializes `messages` and the vertices of `g`. This hides oldMessages
Member: Should the comment here be updated?

// (depended on by the vertices of g) and the vertices of prevG (depended on by oldMessages
// and the vertices of g).
+      messageCheckpointer.update(messages.asInstanceOf[RDD[(VertexId, A)]])
activeMessages = messages.count()

logInfo("Pregel finished iteration " + i)
@@ -154,7 +169,9 @@ object Pregel extends Logging {
// count the iteration
i += 1
}
-    messages.unpersist(blocking = false)
+    messageCheckpointer.unpersistDataSet()
Contributor: Sorry, I don't understand this change. Why do we replace

    messages.unpersist(blocking = false)

with

    messageCheckpointer.unpersistDataSet()

especially since this adds a new public method to PeriodicCheckpointer that no other code has needed before?

Contributor Author: The idea is that since we use messageCheckpointer.update to cache the data, using the checkpointer to unpersist it as well makes a matched pair. Please correct me if I have misunderstood. I also think it's fine to add this new method: there is already a public method that caches data into the persisted queue, so we should provide a public method to clean that queue out.

+    graphCheckpointer.deleteAllCheckpoints()
+    messageCheckpointer.deleteAllCheckpoints()
Member: Shouldn't this be inside a finally clause, to make sure checkpoint data is cleaned up even in error cases?

Contributor Author: I think that when there is an exception during training, keeping the checkpoints gives the user a chance to recover from them. I checked RandomForest/GBT in Spark, and it looks like they also delete the checkpoints only when training finishes successfully.
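For concreteness, the suggested (but not adopted) variant would look roughly like this:

  // Sketch of cleanup in a finally clause, as suggested above; not adopted,
  // so that checkpoints survive a failed iteration and can aid recovery.
  try {
    // ... the Pregel iteration loop ...
  } finally {
    messageCheckpointer.unpersistDataSet()
    graphCheckpointer.deleteAllCheckpoints()
    messageCheckpointer.deleteAllCheckpoints()
  }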

g
} // end of apply

----------------------------------------------------------------------
@@ -15,11 +15,12 @@
* limitations under the License.
*/

-package org.apache.spark.mllib.impl
+package org.apache.spark.graphx.util

import org.apache.spark.SparkContext
import org.apache.spark.graphx.Graph
import org.apache.spark.storage.StorageLevel
+import org.apache.spark.util.PeriodicCheckpointer


/**
@@ -74,9 +75,8 @@ import org.apache.spark.storage.StorageLevel
* @tparam VD Vertex descriptor type
* @tparam ED Edge descriptor type
*
- * TODO: Move this out of MLlib?
 */
-private[mllib] class PeriodicGraphCheckpointer[VD, ED](
+private[spark] class PeriodicGraphCheckpointer[VD, ED](
checkpointInterval: Int,
sc: SparkContext)
extends PeriodicCheckpointer[Graph[VD, ED]](checkpointInterval, sc) {
@@ -87,10 +87,10 @@ private[mllib] class PeriodicGraphCheckpointer[VD, ED](

override protected def persist(data: Graph[VD, ED]): Unit = {
if (data.vertices.getStorageLevel == StorageLevel.NONE) {
-      data.vertices.persist()
+      data.vertices.cache()
Member: Isn't persist better? This could potentially support different storage levels later.

Contributor: We need to use cache because persist does not honor the default storage level requested when constructing the graph; only cache does that. It's confusing, but true. To verify this for yourself, change these calls to persist and run the PeriodicGraphCheckpointerSuite tests.

Member: @mallman @dding3 it would be good to add a comment about that here.
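A small sketch of the behavior described above, assuming an active SparkContext sc; the storage level is illustrative:

  import org.apache.spark.graphx.{Edge, Graph}
  import org.apache.spark.storage.StorageLevel

  // Build a graph with a non-default storage level.
  val graph = Graph.fromEdges(
    sc.parallelize(Seq(Edge(1L, 2L, 0.0))), 0.0,
    StorageLevel.MEMORY_ONLY_SER, StorageLevel.MEMORY_ONLY_SER)

  // Per the explanation above: cache() re-persists at the storage level
  // given when the graph was constructed (MEMORY_ONLY_SER here), whereas
  // the no-argument persist() would use the generic RDD default instead.
  graph.vertices.cache()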

}
if (data.edges.getStorageLevel == StorageLevel.NONE) {
-      data.edges.persist()
+      data.edges.cache()
}
}

----------------------------------------------------------------------
@@ -15,77 +15,81 @@
* limitations under the License.
*/

-package org.apache.spark.mllib.impl
+package org.apache.spark.graphx.util

import org.apache.hadoop.fs.Path

Contributor: Is there a reason this test suite isn't moved into the GraphX codebase?
import org.apache.spark.{SparkContext, SparkFunSuite}
-import org.apache.spark.graphx.{Edge, Graph}
-import org.apache.spark.mllib.util.MLlibTestSparkContext
+import org.apache.spark.graphx.{Edge, Graph, LocalSparkContext}
import org.apache.spark.storage.StorageLevel
import org.apache.spark.util.Utils


-class PeriodicGraphCheckpointerSuite extends SparkFunSuite with MLlibTestSparkContext {
+class PeriodicGraphCheckpointerSuite extends SparkFunSuite with LocalSparkContext {

import PeriodicGraphCheckpointerSuite._

test("Persisting") {
var graphsToCheck = Seq.empty[GraphToCheck]

-    val graph1 = createGraph(sc)
-    val checkpointer =
-      new PeriodicGraphCheckpointer[Double, Double](10, graph1.vertices.sparkContext)
-    checkpointer.update(graph1)
-    graphsToCheck = graphsToCheck :+ GraphToCheck(graph1, 1)
-    checkPersistence(graphsToCheck, 1)
-
-    var iteration = 2
-    while (iteration < 9) {
-      val graph = createGraph(sc)
-      checkpointer.update(graph)
-      graphsToCheck = graphsToCheck :+ GraphToCheck(graph, iteration)
-      checkPersistence(graphsToCheck, iteration)
-      iteration += 1
+    withSpark { sc =>
+      val graph1 = createGraph(sc)
+      val checkpointer =
+        new PeriodicGraphCheckpointer[Double, Double](10, graph1.vertices.sparkContext)
+      checkpointer.update(graph1)
+      graphsToCheck = graphsToCheck :+ GraphToCheck(graph1, 1)
+      checkPersistence(graphsToCheck, 1)
+
+      var iteration = 2
+      while (iteration < 9) {
+        val graph = createGraph(sc)
+        checkpointer.update(graph)
+        graphsToCheck = graphsToCheck :+ GraphToCheck(graph, iteration)
+        checkPersistence(graphsToCheck, iteration)
+        iteration += 1
+      }
+    }
}

test("Checkpointing") {
-    val tempDir = Utils.createTempDir()
-    val path = tempDir.toURI.toString
-    val checkpointInterval = 2
-    var graphsToCheck = Seq.empty[GraphToCheck]
-    sc.setCheckpointDir(path)
-    val graph1 = createGraph(sc)
-    val checkpointer = new PeriodicGraphCheckpointer[Double, Double](
-      checkpointInterval, graph1.vertices.sparkContext)
-    checkpointer.update(graph1)
-    graph1.edges.count()
-    graph1.vertices.count()
-    graphsToCheck = graphsToCheck :+ GraphToCheck(graph1, 1)
-    checkCheckpoint(graphsToCheck, 1, checkpointInterval)
-
-    var iteration = 2
-    while (iteration < 9) {
-      val graph = createGraph(sc)
-      checkpointer.update(graph)
-      graph.vertices.count()
-      graph.edges.count()
-      graphsToCheck = graphsToCheck :+ GraphToCheck(graph, iteration)
-      checkCheckpoint(graphsToCheck, iteration, checkpointInterval)
-      iteration += 1
-    }
+    withSpark { sc =>
+      val tempDir = Utils.createTempDir()
+      val path = tempDir.toURI.toString
+      val checkpointInterval = 2
+      var graphsToCheck = Seq.empty[GraphToCheck]
+      sc.setCheckpointDir(path)
+      val graph1 = createGraph(sc)
+      val checkpointer = new PeriodicGraphCheckpointer[Double, Double](
+        checkpointInterval, graph1.vertices.sparkContext)
+      checkpointer.update(graph1)
+      graph1.edges.count()
+      graph1.vertices.count()
+      graphsToCheck = graphsToCheck :+ GraphToCheck(graph1, 1)
+      checkCheckpoint(graphsToCheck, 1, checkpointInterval)
+
+      var iteration = 2
+      while (iteration < 9) {
+        val graph = createGraph(sc)
+        checkpointer.update(graph)
+        graph.vertices.count()
+        graph.edges.count()
+        graphsToCheck = graphsToCheck :+ GraphToCheck(graph, iteration)
+        checkCheckpoint(graphsToCheck, iteration, checkpointInterval)
+        iteration += 1
+      }

-    checkpointer.deleteAllCheckpoints()
-    graphsToCheck.foreach { graph =>
-      confirmCheckpointRemoved(graph.graph)
-    }
+      checkpointer.deleteAllCheckpoints()
+      graphsToCheck.foreach { graph =>
+        confirmCheckpointRemoved(graph.graph)
+      }

-    Utils.deleteRecursively(tempDir)
+      Utils.deleteRecursively(tempDir)
+    }
}
}

private object PeriodicGraphCheckpointerSuite {
+  private val defaultStorageLevel = StorageLevel.MEMORY_ONLY_SER

case class GraphToCheck(graph: Graph[Double, Double], gIndex: Int)

@@ -96,7 +100,8 @@ private object PeriodicGraphCheckpointerSuite {
Edge[Double](3, 4, 0))

def createGraph(sc: SparkContext): Graph[Double, Double] = {
-    Graph.fromEdges[Double, Double](sc.parallelize(edges), 0)
+    Graph.fromEdges[Double, Double](
+      sc.parallelize(edges), 0, defaultStorageLevel, defaultStorageLevel)
}

def checkPersistence(graphs: Seq[GraphToCheck], iteration: Int): Unit = {
@@ -116,8 +121,8 @@ private object PeriodicGraphCheckpointerSuite {
assert(graph.vertices.getStorageLevel == StorageLevel.NONE)
assert(graph.edges.getStorageLevel == StorageLevel.NONE)
} else {
-        assert(graph.vertices.getStorageLevel != StorageLevel.NONE)
-        assert(graph.edges.getStorageLevel != StorageLevel.NONE)
+        assert(graph.vertices.getStorageLevel == defaultStorageLevel)
+        assert(graph.edges.getStorageLevel == defaultStorageLevel)
}
} catch {
case _: AssertionError =>
----------------------------------------------------------------------
@@ -34,7 +34,6 @@ import org.apache.spark.mllib.clustering.{DistributedLDAModel => OldDistributedLDAModel,
EMLDAOptimizer => OldEMLDAOptimizer, LDA => OldLDA, LDAModel => OldLDAModel,
LDAOptimizer => OldLDAOptimizer, LocalLDAModel => OldLocalLDAModel,
OnlineLDAOptimizer => OldOnlineLDAOptimizer}
-import org.apache.spark.mllib.impl.PeriodicCheckpointer
import org.apache.spark.mllib.linalg.{Vector => OldVector, Vectors => OldVectors}
import org.apache.spark.mllib.linalg.MatrixImplicits._
import org.apache.spark.mllib.linalg.VectorImplicits._
@@ -43,9 +42,9 @@ import org.apache.spark.rdd.RDD
import org.apache.spark.sql.{DataFrame, Dataset, Row, SparkSession}
import org.apache.spark.sql.functions.{col, monotonically_increasing_id, udf}
import org.apache.spark.sql.types.StructType
+import org.apache.spark.util.PeriodicCheckpointer
import org.apache.spark.util.VersionUtils


private[clustering] trait LDAParams extends Params with HasFeaturesCol with HasMaxIter
with HasSeed with HasCheckpointInterval {

----------------------------------------------------------------------
@@ -21,13 +21,13 @@ import org.apache.spark.internal.Logging
import org.apache.spark.ml.feature.LabeledPoint
import org.apache.spark.ml.linalg.Vector
import org.apache.spark.ml.regression.{DecisionTreeRegressionModel, DecisionTreeRegressor}
-import org.apache.spark.mllib.impl.PeriodicRDDCheckpointer
import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo}
import org.apache.spark.mllib.tree.configuration.{BoostingStrategy => OldBoostingStrategy}
import org.apache.spark.mllib.tree.impurity.{Variance => OldVariance}
import org.apache.spark.mllib.tree.loss.{Loss => OldLoss}
import org.apache.spark.rdd.RDD
import org.apache.spark.storage.StorageLevel
+import org.apache.spark.util.PeriodicRDDCheckpointer


private[spark] object GradientBoostedTrees extends Logging {
----------------------------------------------------------------------
@@ -25,7 +25,7 @@ import breeze.stats.distributions.{Gamma, RandBasis}

import org.apache.spark.annotation.{DeveloperApi, Since}
import org.apache.spark.graphx._
-import org.apache.spark.mllib.impl.PeriodicGraphCheckpointer
+import org.apache.spark.graphx.util.PeriodicGraphCheckpointer
import org.apache.spark.mllib.linalg.{DenseVector, Matrices, SparseVector, Vector, Vectors}
import org.apache.spark.rdd.RDD
import org.apache.spark.storage.StorageLevel