diff --git a/core/src/main/scala/org/apache/spark/rdd/RDD.scala b/core/src/main/scala/org/apache/spark/rdd/RDD.scala
index 0359508c00395..bc053a9bac89f 100644
--- a/core/src/main/scala/org/apache/spark/rdd/RDD.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/RDD.scala
@@ -41,7 +41,7 @@ import org.apache.spark.partial.GroupedCountEvaluator
import org.apache.spark.partial.PartialResult
import org.apache.spark.storage.{RDDBlockId, StorageLevel}
import org.apache.spark.util.{BoundedPriorityQueue, Utils}
-import org.apache.spark.util.collection.OpenHashMap
+import org.apache.spark.util.collection.{OpenHashMap, Utils => collectionUtils}
import org.apache.spark.util.random.{BernoulliCellSampler, BernoulliSampler, PoissonSampler,
SamplingUtils}
@@ -1419,7 +1419,7 @@ abstract class RDD[T: ClassTag](
val mapRDDs = mapPartitions { items =>
// Priority keeps the largest elements, so let's reverse the ordering.
val queue = new BoundedPriorityQueue[T](num)(ord.reverse)
- queue ++= util.collection.Utils.takeOrdered(items, num)(ord)
+ queue ++= collectionUtils.takeOrdered(items, num)(ord)
Iterator.single(queue)
}
if (mapRDDs.partitions.length == 0) {
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/impl/PeriodicRDDCheckpointer.scala b/core/src/main/scala/org/apache/spark/rdd/util/PeriodicRDDCheckpointer.scala
similarity index 97%
rename from mllib/src/main/scala/org/apache/spark/mllib/impl/PeriodicRDDCheckpointer.scala
rename to core/src/main/scala/org/apache/spark/rdd/util/PeriodicRDDCheckpointer.scala
index 145dc22b7428e..ab72addb2466b 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/impl/PeriodicRDDCheckpointer.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/util/PeriodicRDDCheckpointer.scala
@@ -15,11 +15,12 @@
* limitations under the License.
*/
-package org.apache.spark.mllib.impl
+package org.apache.spark.rdd.util
import org.apache.spark.SparkContext
import org.apache.spark.rdd.RDD
import org.apache.spark.storage.StorageLevel
+import org.apache.spark.util.PeriodicCheckpointer
/**
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/impl/PeriodicCheckpointer.scala b/core/src/main/scala/org/apache/spark/util/PeriodicCheckpointer.scala
similarity index 95%
rename from mllib/src/main/scala/org/apache/spark/mllib/impl/PeriodicCheckpointer.scala
rename to core/src/main/scala/org/apache/spark/util/PeriodicCheckpointer.scala
index 4dd498cd91b4e..ce06e18879a49 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/impl/PeriodicCheckpointer.scala
+++ b/core/src/main/scala/org/apache/spark/util/PeriodicCheckpointer.scala
@@ -15,7 +15,7 @@
* limitations under the License.
*/
-package org.apache.spark.mllib.impl
+package org.apache.spark.util
import scala.collection.mutable
@@ -58,7 +58,7 @@ import org.apache.spark.storage.StorageLevel
* @param sc SparkContext for the Datasets given to this checkpointer
* @tparam T Dataset type, such as RDD[Double]
*/
-private[mllib] abstract class PeriodicCheckpointer[T](
+private[spark] abstract class PeriodicCheckpointer[T](
val checkpointInterval: Int,
val sc: SparkContext) extends Logging {
@@ -127,6 +127,16 @@ private[mllib] abstract class PeriodicCheckpointer[T](
/** Get list of checkpoint files for this given Dataset */
protected def getCheckpointFiles(data: T): Iterable[String]
+  /**
+   * Call this to unpersist any Datasets still persisted by this checkpointer.
+   */
+ def unpersistDataSet(): Unit = {
+ while (persistedQueue.nonEmpty) {
+ val dataToUnpersist = persistedQueue.dequeue()
+ unpersist(dataToUnpersist)
+ }
+ }
+
/**
* Call this at the end to delete any remaining checkpoint files.
*/
diff --git a/core/src/test/scala/org/apache/spark/rdd/SortingSuite.scala b/core/src/test/scala/org/apache/spark/rdd/SortingSuite.scala
index f9a7f151823a2..7f20206202cb9 100644
--- a/core/src/test/scala/org/apache/spark/rdd/SortingSuite.scala
+++ b/core/src/test/scala/org/apache/spark/rdd/SortingSuite.scala
@@ -135,7 +135,7 @@ class SortingSuite extends SparkFunSuite with SharedSparkContext with Matchers w
}
test("get a range of elements in an array not partitioned by a range partitioner") {
- val pairArr = util.Random.shuffle((1 to 1000).toList).map(x => (x, x))
+ val pairArr = scala.util.Random.shuffle((1 to 1000).toList).map(x => (x, x))
val pairs = sc.parallelize(pairArr, 10)
val range = pairs.filterByRange(200, 800).collect()
assert((800 to 200 by -1).toArray.sorted === range.map(_._1).sorted)
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/impl/PeriodicRDDCheckpointerSuite.scala b/core/src/test/scala/org/apache/spark/util/PeriodicRDDCheckpointerSuite.scala
similarity index 96%
rename from mllib/src/test/scala/org/apache/spark/mllib/impl/PeriodicRDDCheckpointerSuite.scala
rename to core/src/test/scala/org/apache/spark/util/PeriodicRDDCheckpointerSuite.scala
index 14adf8c29fc6b..f9e1b791c86ea 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/impl/PeriodicRDDCheckpointerSuite.scala
+++ b/core/src/test/scala/org/apache/spark/util/PeriodicRDDCheckpointerSuite.scala
@@ -15,18 +15,18 @@
* limitations under the License.
*/
-package org.apache.spark.mllib.impl
+package org.apache.spark.util
import org.apache.hadoop.fs.Path
-import org.apache.spark.{SparkContext, SparkFunSuite}
-import org.apache.spark.mllib.util.MLlibTestSparkContext
+import org.apache.spark.{SharedSparkContext, SparkContext, SparkFunSuite}
import org.apache.spark.rdd.RDD
+import org.apache.spark.rdd.util.PeriodicRDDCheckpointer
import org.apache.spark.storage.StorageLevel
import org.apache.spark.util.Utils
-class PeriodicRDDCheckpointerSuite extends SparkFunSuite with MLlibTestSparkContext {
+class PeriodicRDDCheckpointerSuite extends SparkFunSuite with SharedSparkContext {
import PeriodicRDDCheckpointerSuite._
diff --git a/docs/configuration.md b/docs/configuration.md
index 2fcb3a096aea5..1823fb9b1f53a 100644
--- a/docs/configuration.md
+++ b/docs/configuration.md
@@ -2115,6 +2115,20 @@ showDF(properties, numRows = 200, truncate = FALSE)
+### GraphX
+
+<table class="table">
+<tr><th>Property Name</th><th>Default</th><th>Meaning</th></tr>
+<tr>
+  <td><code>spark.graphx.pregel.checkpointInterval</code></td>
+  <td>-1</td>
+  <td>
+    Checkpoint interval for the graph and messages in Pregel. It is used to avoid a stackOverflowError caused by long
+    lineage chains after many iterations. Checkpointing is disabled by default.
+  </td>
+</tr>
+</table>
+
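+For example, a minimal sketch of enabling periodic Pregel checkpointing (the application name and
+checkpoint directory below are only illustrative):
+
+{% highlight scala %}
+import org.apache.spark.{SparkConf, SparkContext}
+
+// Checkpoint the Pregel graph and messages every 10 iterations.
+val conf = new SparkConf()
+  .setAppName("PregelCheckpointingExample")
+  .set("spark.graphx.pregel.checkpointInterval", "10")
+val sc = new SparkContext(conf)
+
+// A checkpoint directory must also be set, otherwise no checkpoints are written.
+sc.setCheckpointDir("/tmp/pregel-checkpoints")
+{% endhighlight %}
+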
### Deploy
diff --git a/docs/graphx-programming-guide.md b/docs/graphx-programming-guide.md
index e271b28fb4f28..76aa7b405e18c 100644
--- a/docs/graphx-programming-guide.md
+++ b/docs/graphx-programming-guide.md
@@ -708,7 +708,9 @@ messages remaining.
> messaging function. These constraints allow additional optimization within GraphX.
The following is the type signature of the [Pregel operator][GraphOps.pregel] as well as a *sketch*
-of its implementation (note calls to graph.cache have been removed):
+of its implementation (note: to avoid a stackOverflowError caused by long lineage chains, Pregel supports periodically
+checkpointing the graph and messages by setting "spark.graphx.pregel.checkpointInterval" to a positive number,
+e.g. 10, and also setting the checkpoint directory via SparkContext.setCheckpointDir(directory: String)):
{% highlight scala %}
class GraphOps[VD, ED] {
@@ -722,6 +724,7 @@ class GraphOps[VD, ED] {
: Graph[VD, ED] = {
// Receive the initial message at each vertex
var g = mapVertices( (vid, vdata) => vprog(vid, vdata, initialMsg) ).cache()
+
// compute the messages
var messages = g.mapReduceTriplets(sendMsg, mergeMsg)
var activeMessages = messages.count()
@@ -734,8 +737,8 @@ class GraphOps[VD, ED] {
// Send new messages, skipping edges where neither side received a message. We must cache
// messages so it can be materialized on the next line, allowing us to uncache the previous
// iteration.
- messages = g.mapReduceTriplets(
- sendMsg, mergeMsg, Some((oldMessages, activeDirection))).cache()
+ messages = GraphXUtils.mapReduceTriplets(
+ g, sendMsg, mergeMsg, Some((oldMessages, activeDirection))).cache()
activeMessages = messages.count()
i += 1
}
diff --git a/graphx/src/main/scala/org/apache/spark/graphx/Pregel.scala b/graphx/src/main/scala/org/apache/spark/graphx/Pregel.scala
index 646462b4a8350..755c6febc48e6 100644
--- a/graphx/src/main/scala/org/apache/spark/graphx/Pregel.scala
+++ b/graphx/src/main/scala/org/apache/spark/graphx/Pregel.scala
@@ -19,7 +19,10 @@ package org.apache.spark.graphx
import scala.reflect.ClassTag
+import org.apache.spark.graphx.util.PeriodicGraphCheckpointer
import org.apache.spark.internal.Logging
+import org.apache.spark.rdd.RDD
+import org.apache.spark.rdd.util.PeriodicRDDCheckpointer
/**
* Implements a Pregel-like bulk-synchronous message-passing API.
@@ -122,27 +125,39 @@ object Pregel extends Logging {
require(maxIterations > 0, s"Maximum number of iterations must be greater than 0," +
s" but got ${maxIterations}")
- var g = graph.mapVertices((vid, vdata) => vprog(vid, vdata, initialMsg)).cache()
+ val checkpointInterval = graph.vertices.sparkContext.getConf
+ .getInt("spark.graphx.pregel.checkpointInterval", -1)
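+    // -1 (the default) disables periodic checkpointing; a positive value checkpoints the graph and
+    // messages once every checkpointInterval updates, provided a checkpoint directory has been set.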
+ var g = graph.mapVertices((vid, vdata) => vprog(vid, vdata, initialMsg))
+ val graphCheckpointer = new PeriodicGraphCheckpointer[VD, ED](
+ checkpointInterval, graph.vertices.sparkContext)
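+    // update() also persists the new graph, replacing the explicit cache() call used previously.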
+ graphCheckpointer.update(g)
+
// compute the messages
var messages = GraphXUtils.mapReduceTriplets(g, sendMsg, mergeMsg)
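+    // Track the message RDDs the same way: each update() persists the new messages and periodically
+    // checkpoints them to keep the lineage short.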
+ val messageCheckpointer = new PeriodicRDDCheckpointer[(VertexId, A)](
+ checkpointInterval, graph.vertices.sparkContext)
+ messageCheckpointer.update(messages.asInstanceOf[RDD[(VertexId, A)]])
var activeMessages = messages.count()
+
// Loop
var prevG: Graph[VD, ED] = null
var i = 0
while (activeMessages > 0 && i < maxIterations) {
// Receive the messages and update the vertices.
prevG = g
- g = g.joinVertices(messages)(vprog).cache()
+ g = g.joinVertices(messages)(vprog)
+ graphCheckpointer.update(g)
val oldMessages = messages
// Send new messages, skipping edges where neither side received a message. We must cache
// messages so it can be materialized on the next line, allowing us to uncache the previous
// iteration.
messages = GraphXUtils.mapReduceTriplets(
- g, sendMsg, mergeMsg, Some((oldMessages, activeDirection))).cache()
+ g, sendMsg, mergeMsg, Some((oldMessages, activeDirection)))
// The call to count() materializes `messages` and the vertices of `g`. This hides oldMessages
// (depended on by the vertices of g) and the vertices of prevG (depended on by oldMessages
// and the vertices of g).
+ messageCheckpointer.update(messages.asInstanceOf[RDD[(VertexId, A)]])
activeMessages = messages.count()
logInfo("Pregel finished iteration " + i)
@@ -154,7 +169,9 @@ object Pregel extends Logging {
// count the iteration
i += 1
}
- messages.unpersist(blocking = false)
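+    // Unpersist any message RDDs the checkpointer is still holding (the returned graph g stays
+    // cached) and delete the checkpoint files written during the iterations.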
+ messageCheckpointer.unpersistDataSet()
+ graphCheckpointer.deleteAllCheckpoints()
+ messageCheckpointer.deleteAllCheckpoints()
g
} // end of apply
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/impl/PeriodicGraphCheckpointer.scala b/graphx/src/main/scala/org/apache/spark/graphx/util/PeriodicGraphCheckpointer.scala
similarity index 91%
rename from mllib/src/main/scala/org/apache/spark/mllib/impl/PeriodicGraphCheckpointer.scala
rename to graphx/src/main/scala/org/apache/spark/graphx/util/PeriodicGraphCheckpointer.scala
index 80074897567eb..fda501aa757d6 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/impl/PeriodicGraphCheckpointer.scala
+++ b/graphx/src/main/scala/org/apache/spark/graphx/util/PeriodicGraphCheckpointer.scala
@@ -15,11 +15,12 @@
* limitations under the License.
*/
-package org.apache.spark.mllib.impl
+package org.apache.spark.graphx.util
import org.apache.spark.SparkContext
import org.apache.spark.graphx.Graph
import org.apache.spark.storage.StorageLevel
+import org.apache.spark.util.PeriodicCheckpointer
/**
@@ -74,9 +75,8 @@ import org.apache.spark.storage.StorageLevel
* @tparam VD Vertex descriptor type
* @tparam ED Edge descriptor type
*
- * TODO: Move this out of MLlib?
*/
-private[mllib] class PeriodicGraphCheckpointer[VD, ED](
+private[spark] class PeriodicGraphCheckpointer[VD, ED](
checkpointInterval: Int,
sc: SparkContext)
extends PeriodicCheckpointer[Graph[VD, ED]](checkpointInterval, sc) {
@@ -87,10 +87,13 @@ private[mllib] class PeriodicGraphCheckpointer[VD, ED](
override protected def persist(data: Graph[VD, ED]): Unit = {
if (data.vertices.getStorageLevel == StorageLevel.NONE) {
- data.vertices.persist()
+ /* We need to use cache because persist does not honor the default storage level requested
+ * when constructing the graph. Only cache does that.
+ */
+ data.vertices.cache()
}
if (data.edges.getStorageLevel == StorageLevel.NONE) {
- data.edges.persist()
+ data.edges.cache()
}
}
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/impl/PeriodicGraphCheckpointerSuite.scala b/graphx/src/test/scala/org/apache/spark/graphx/util/PeriodicGraphCheckpointerSuite.scala
similarity index 70%
rename from mllib/src/test/scala/org/apache/spark/mllib/impl/PeriodicGraphCheckpointerSuite.scala
rename to graphx/src/test/scala/org/apache/spark/graphx/util/PeriodicGraphCheckpointerSuite.scala
index a13e7f63a9296..e0c65e6940f66 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/impl/PeriodicGraphCheckpointerSuite.scala
+++ b/graphx/src/test/scala/org/apache/spark/graphx/util/PeriodicGraphCheckpointerSuite.scala
@@ -15,77 +15,81 @@
* limitations under the License.
*/
-package org.apache.spark.mllib.impl
+package org.apache.spark.graphx.util
import org.apache.hadoop.fs.Path
import org.apache.spark.{SparkContext, SparkFunSuite}
-import org.apache.spark.graphx.{Edge, Graph}
-import org.apache.spark.mllib.util.MLlibTestSparkContext
+import org.apache.spark.graphx.{Edge, Graph, LocalSparkContext}
import org.apache.spark.storage.StorageLevel
import org.apache.spark.util.Utils
-class PeriodicGraphCheckpointerSuite extends SparkFunSuite with MLlibTestSparkContext {
+class PeriodicGraphCheckpointerSuite extends SparkFunSuite with LocalSparkContext {
import PeriodicGraphCheckpointerSuite._
test("Persisting") {
var graphsToCheck = Seq.empty[GraphToCheck]
- val graph1 = createGraph(sc)
- val checkpointer =
- new PeriodicGraphCheckpointer[Double, Double](10, graph1.vertices.sparkContext)
- checkpointer.update(graph1)
- graphsToCheck = graphsToCheck :+ GraphToCheck(graph1, 1)
- checkPersistence(graphsToCheck, 1)
-
- var iteration = 2
- while (iteration < 9) {
- val graph = createGraph(sc)
- checkpointer.update(graph)
- graphsToCheck = graphsToCheck :+ GraphToCheck(graph, iteration)
- checkPersistence(graphsToCheck, iteration)
- iteration += 1
+ withSpark { sc =>
+ val graph1 = createGraph(sc)
+ val checkpointer =
+ new PeriodicGraphCheckpointer[Double, Double](10, graph1.vertices.sparkContext)
+ checkpointer.update(graph1)
+ graphsToCheck = graphsToCheck :+ GraphToCheck(graph1, 1)
+ checkPersistence(graphsToCheck, 1)
+
+ var iteration = 2
+ while (iteration < 9) {
+ val graph = createGraph(sc)
+ checkpointer.update(graph)
+ graphsToCheck = graphsToCheck :+ GraphToCheck(graph, iteration)
+ checkPersistence(graphsToCheck, iteration)
+ iteration += 1
+ }
}
}
test("Checkpointing") {
- val tempDir = Utils.createTempDir()
- val path = tempDir.toURI.toString
- val checkpointInterval = 2
- var graphsToCheck = Seq.empty[GraphToCheck]
- sc.setCheckpointDir(path)
- val graph1 = createGraph(sc)
- val checkpointer = new PeriodicGraphCheckpointer[Double, Double](
- checkpointInterval, graph1.vertices.sparkContext)
- checkpointer.update(graph1)
- graph1.edges.count()
- graph1.vertices.count()
- graphsToCheck = graphsToCheck :+ GraphToCheck(graph1, 1)
- checkCheckpoint(graphsToCheck, 1, checkpointInterval)
-
- var iteration = 2
- while (iteration < 9) {
- val graph = createGraph(sc)
- checkpointer.update(graph)
- graph.vertices.count()
- graph.edges.count()
- graphsToCheck = graphsToCheck :+ GraphToCheck(graph, iteration)
- checkCheckpoint(graphsToCheck, iteration, checkpointInterval)
- iteration += 1
- }
+ withSpark { sc =>
+ val tempDir = Utils.createTempDir()
+ val path = tempDir.toURI.toString
+ val checkpointInterval = 2
+ var graphsToCheck = Seq.empty[GraphToCheck]
+ sc.setCheckpointDir(path)
+ val graph1 = createGraph(sc)
+ val checkpointer = new PeriodicGraphCheckpointer[Double, Double](
+ checkpointInterval, graph1.vertices.sparkContext)
+ checkpointer.update(graph1)
+ graph1.edges.count()
+ graph1.vertices.count()
+ graphsToCheck = graphsToCheck :+ GraphToCheck(graph1, 1)
+ checkCheckpoint(graphsToCheck, 1, checkpointInterval)
+
+ var iteration = 2
+ while (iteration < 9) {
+ val graph = createGraph(sc)
+ checkpointer.update(graph)
+ graph.vertices.count()
+ graph.edges.count()
+ graphsToCheck = graphsToCheck :+ GraphToCheck(graph, iteration)
+ checkCheckpoint(graphsToCheck, iteration, checkpointInterval)
+ iteration += 1
+ }
- checkpointer.deleteAllCheckpoints()
- graphsToCheck.foreach { graph =>
- confirmCheckpointRemoved(graph.graph)
- }
+ checkpointer.deleteAllCheckpoints()
+ graphsToCheck.foreach { graph =>
+ confirmCheckpointRemoved(graph.graph)
+ }
- Utils.deleteRecursively(tempDir)
+ Utils.deleteRecursively(tempDir)
+ }
}
}
private object PeriodicGraphCheckpointerSuite {
+ private val defaultStorageLevel = StorageLevel.MEMORY_ONLY_SER
case class GraphToCheck(graph: Graph[Double, Double], gIndex: Int)
@@ -96,7 +100,8 @@ private object PeriodicGraphCheckpointerSuite {
Edge[Double](3, 4, 0))
def createGraph(sc: SparkContext): Graph[Double, Double] = {
- Graph.fromEdges[Double, Double](sc.parallelize(edges), 0)
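+    // Construct the graph with an explicit storage level so the persistence checks below can
+    // assert the exact level that cache() preserves.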
+ Graph.fromEdges[Double, Double](
+ sc.parallelize(edges), 0, defaultStorageLevel, defaultStorageLevel)
}
def checkPersistence(graphs: Seq[GraphToCheck], iteration: Int): Unit = {
@@ -116,8 +121,8 @@ private object PeriodicGraphCheckpointerSuite {
assert(graph.vertices.getStorageLevel == StorageLevel.NONE)
assert(graph.edges.getStorageLevel == StorageLevel.NONE)
} else {
- assert(graph.vertices.getStorageLevel != StorageLevel.NONE)
- assert(graph.edges.getStorageLevel != StorageLevel.NONE)
+ assert(graph.vertices.getStorageLevel == defaultStorageLevel)
+ assert(graph.edges.getStorageLevel == defaultStorageLevel)
}
} catch {
case _: AssertionError =>
diff --git a/mllib/src/main/scala/org/apache/spark/ml/clustering/LDA.scala b/mllib/src/main/scala/org/apache/spark/ml/clustering/LDA.scala
index bbcef3502d1dc..a7812643877e1 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/clustering/LDA.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/clustering/LDA.scala
@@ -34,7 +34,6 @@ import org.apache.spark.mllib.clustering.{DistributedLDAModel => OldDistributedL
EMLDAOptimizer => OldEMLDAOptimizer, LDA => OldLDA, LDAModel => OldLDAModel,
LDAOptimizer => OldLDAOptimizer, LocalLDAModel => OldLocalLDAModel,
OnlineLDAOptimizer => OldOnlineLDAOptimizer}
-import org.apache.spark.mllib.impl.PeriodicCheckpointer
import org.apache.spark.mllib.linalg.{Vector => OldVector, Vectors => OldVectors}
import org.apache.spark.mllib.linalg.MatrixImplicits._
import org.apache.spark.mllib.linalg.VectorImplicits._
@@ -43,9 +42,9 @@ import org.apache.spark.rdd.RDD
import org.apache.spark.sql.{DataFrame, Dataset, Row, SparkSession}
import org.apache.spark.sql.functions.{col, monotonically_increasing_id, udf}
import org.apache.spark.sql.types.StructType
+import org.apache.spark.util.PeriodicCheckpointer
import org.apache.spark.util.VersionUtils
-
private[clustering] trait LDAParams extends Params with HasFeaturesCol with HasMaxIter
with HasSeed with HasCheckpointInterval {
diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/GradientBoostedTrees.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/GradientBoostedTrees.scala
index f3bace8181570..1ce7c87dbd15d 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/GradientBoostedTrees.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/GradientBoostedTrees.scala
@@ -21,12 +21,12 @@ import org.apache.spark.internal.Logging
import org.apache.spark.ml.feature.LabeledPoint
import org.apache.spark.ml.linalg.Vector
import org.apache.spark.ml.regression.{DecisionTreeRegressionModel, DecisionTreeRegressor}
-import org.apache.spark.mllib.impl.PeriodicRDDCheckpointer
import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo}
import org.apache.spark.mllib.tree.configuration.{BoostingStrategy => OldBoostingStrategy}
import org.apache.spark.mllib.tree.impurity.{Variance => OldVariance}
import org.apache.spark.mllib.tree.loss.{Loss => OldLoss}
import org.apache.spark.rdd.RDD
+import org.apache.spark.rdd.util.PeriodicRDDCheckpointer
import org.apache.spark.storage.StorageLevel
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAOptimizer.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAOptimizer.scala
index 48bae4276c480..3697a9b46dd84 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAOptimizer.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAOptimizer.scala
@@ -25,7 +25,7 @@ import breeze.stats.distributions.{Gamma, RandBasis}
import org.apache.spark.annotation.{DeveloperApi, Since}
import org.apache.spark.graphx._
-import org.apache.spark.mllib.impl.PeriodicGraphCheckpointer
+import org.apache.spark.graphx.util.PeriodicGraphCheckpointer
import org.apache.spark.mllib.linalg.{DenseVector, Matrices, SparseVector, Vector, Vectors}
import org.apache.spark.rdd.RDD
import org.apache.spark.storage.StorageLevel