Skip to content

Commit 2a0c319

Browse files
committed
Merge remote-tracking branch 'upstream/master' into rowencoder-scalaudf
Conflicts: project/MimaExcludes.scala
2 parents 10a9f91 + 953ff89 commit 2a0c319

8 files changed

Lines changed: 194 additions & 32 deletions

File tree

mllib/src/main/scala/org/apache/spark/ml/clustering/LDA.scala

Lines changed: 73 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -17,18 +17,19 @@
1717

1818
package org.apache.spark.ml.clustering
1919

20-
import org.apache.hadoop.fs.Path
20+
import org.apache.hadoop.fs.{FileSystem, Path}
2121

22-
import org.apache.spark.annotation.{Experimental, Since}
22+
import org.apache.spark.annotation.{DeveloperApi, Experimental, Since}
2323
import org.apache.spark.internal.Logging
2424
import org.apache.spark.ml.{Estimator, Model}
2525
import org.apache.spark.ml.param._
2626
import org.apache.spark.ml.param.shared.{HasCheckpointInterval, HasFeaturesCol, HasMaxIter, HasSeed}
2727
import org.apache.spark.ml.util._
2828
import org.apache.spark.mllib.clustering.{DistributedLDAModel => OldDistributedLDAModel,
29-
EMLDAOptimizer => OldEMLDAOptimizer, LDA => OldLDA, LDAModel => OldLDAModel,
30-
LDAOptimizer => OldLDAOptimizer, LocalLDAModel => OldLocalLDAModel,
31-
OnlineLDAOptimizer => OldOnlineLDAOptimizer}
29+
EMLDAOptimizer => OldEMLDAOptimizer, LDA => OldLDA, LDAModel => OldLDAModel,
30+
LDAOptimizer => OldLDAOptimizer, LocalLDAModel => OldLocalLDAModel,
31+
OnlineLDAOptimizer => OldOnlineLDAOptimizer}
32+
import org.apache.spark.mllib.impl.PeriodicCheckpointer
3233
import org.apache.spark.mllib.linalg.{Matrix, Vector, Vectors, VectorUDT}
3334
import org.apache.spark.rdd.RDD
3435
import org.apache.spark.sql.{DataFrame, Row, SQLContext}
@@ -41,6 +42,7 @@ private[clustering] trait LDAParams extends Params with HasFeaturesCol with HasM
4142

4243
/**
4344
* Param for the number of topics (clusters) to infer. Must be > 1. Default: 10.
45+
*
4446
* @group param
4547
*/
4648
@Since("1.6.0")
@@ -173,6 +175,7 @@ private[clustering] trait LDAParams extends Params with HasFeaturesCol with HasM
173175
* This uses a variational approximation following Hoffman et al. (2010), where the approximate
174176
* distribution is called "gamma." Technically, this method returns this approximation "gamma"
175177
* for each document.
178+
*
176179
* @group param
177180
*/
178181
@Since("1.6.0")
@@ -191,6 +194,7 @@ private[clustering] trait LDAParams extends Params with HasFeaturesCol with HasM
191194
* iterations count less.
192195
* This is called "tau0" in the Online LDA paper (Hoffman et al., 2010)
193196
* Default: 1024, following Hoffman et al.
197+
*
194198
* @group expertParam
195199
*/
196200
@Since("1.6.0")
@@ -207,6 +211,7 @@ private[clustering] trait LDAParams extends Params with HasFeaturesCol with HasM
207211
* This should be between (0.5, 1.0] to guarantee asymptotic convergence.
208212
* This is called "kappa" in the Online LDA paper (Hoffman et al., 2010).
209213
* Default: 0.51, based on Hoffman et al.
214+
*
210215
* @group expertParam
211216
*/
212217
@Since("1.6.0")
@@ -230,6 +235,7 @@ private[clustering] trait LDAParams extends Params with HasFeaturesCol with HasM
230235
* [[org.apache.spark.mllib.clustering.OnlineLDAOptimizer]].
231236
*
232237
* Default: 0.05, i.e., 5% of total documents.
238+
*
233239
* @group param
234240
*/
235241
@Since("1.6.0")
@@ -246,6 +252,7 @@ private[clustering] trait LDAParams extends Params with HasFeaturesCol with HasM
246252
* document-topic distribution) will be optimized during training.
247253
* Setting this to true will make the model more expressive and fit the training data better.
248254
* Default: false
255+
*
249256
* @group expertParam
250257
*/
251258
@Since("1.6.0")
@@ -257,8 +264,32 @@ private[clustering] trait LDAParams extends Params with HasFeaturesCol with HasM
257264
@Since("1.6.0")
258265
def getOptimizeDocConcentration: Boolean = $(optimizeDocConcentration)
259266

267+
/**
268+
* For EM optimizer, if using checkpointing, this indicates whether to keep the last
269+
* checkpoint. If false, then the checkpoint will be deleted. Deleting the checkpoint can
270+
* cause failures if a data partition is lost, so set this bit with care.
271+
* Note that checkpoints will be cleaned up via reference counting, regardless.
272+
*
273+
* See [[DistributedLDAModel.getCheckpointFiles]] for getting remaining checkpoints and
274+
* [[DistributedLDAModel.deleteCheckpointFiles]] for removing remaining checkpoints.
275+
*
276+
* Default: true
277+
*
278+
* @group expertParam
279+
*/
280+
@Since("2.0.0")
281+
final val keepLastCheckpoint = new BooleanParam(this, "keepLastCheckpoint",
282+
"For EM optimizer, if using checkpointing, this indicates whether to keep the last" +
283+
" checkpoint. If false, then the checkpoint will be deleted. Deleting the checkpoint can" +
284+
" cause failures if a data partition is lost, so set this bit with care.")
285+
286+
/** @group expertGetParam */
287+
@Since("2.0.0")
288+
def getKeepLastCheckpoint: Boolean = $(keepLastCheckpoint)
289+
260290
/**
261291
* Validates and transforms the input schema.
292+
*
262293
* @param schema input schema
263294
* @return output schema
264295
*/
@@ -303,6 +334,7 @@ private[clustering] trait LDAParams extends Params with HasFeaturesCol with HasM
303334
.setOptimizeDocConcentration($(optimizeDocConcentration))
304335
case "em" =>
305336
new OldEMLDAOptimizer()
337+
.setKeepLastCheckpoint($(keepLastCheckpoint))
306338
}
307339
}
308340

@@ -341,6 +373,7 @@ sealed abstract class LDAModel private[ml] (
341373
/**
342374
* The features for LDA should be a [[Vector]] representing the word counts in a document.
343375
* The vector should be of length vocabSize, with counts for each term (word).
376+
*
344377
* @group setParam
345378
*/
346379
@Since("1.6.0")
@@ -619,6 +652,35 @@ class DistributedLDAModel private[ml] (
619652
@Since("1.6.0")
620653
lazy val logPrior: Double = oldDistributedModel.logPrior
621654

655+
private var _checkpointFiles: Array[String] = oldDistributedModel.checkpointFiles
656+
657+
/**
658+
* If using checkpointing and [[LDA.keepLastCheckpoint]] is set to true, then there may be
659+
* saved checkpoint files. This method is provided so that users can manage those files.
660+
*
661+
* Note that removing the checkpoints can cause failures if a partition is lost and is needed
662+
* by certain [[DistributedLDAModel]] methods. Reference counting will clean up the checkpoints
663+
* when this model and derivative data go out of scope.
664+
*
665+
* @return Checkpoint files from training
666+
*/
667+
@DeveloperApi
668+
@Since("2.0.0")
669+
def getCheckpointFiles: Array[String] = _checkpointFiles
670+
671+
/**
672+
* Remove any remaining checkpoint files from training.
673+
*
674+
* @see [[getCheckpointFiles]]
675+
*/
676+
@DeveloperApi
677+
@Since("2.0.0")
678+
def deleteCheckpointFiles(): Unit = {
679+
val fs = FileSystem.get(sqlContext.sparkContext.hadoopConfiguration)
680+
_checkpointFiles.foreach(PeriodicCheckpointer.removeCheckpointFile(_, fs))
681+
_checkpointFiles = Array.empty[String]
682+
}
683+
622684
@Since("1.6.0")
623685
override def write: MLWriter = new DistributedLDAModel.DistributedWriter(this)
624686
}
@@ -696,11 +758,12 @@ class LDA @Since("1.6.0") (
696758

697759
setDefault(maxIter -> 20, k -> 10, optimizer -> "online", checkpointInterval -> 10,
698760
learningOffset -> 1024, learningDecay -> 0.51, subsamplingRate -> 0.05,
699-
optimizeDocConcentration -> true)
761+
optimizeDocConcentration -> true, keepLastCheckpoint -> true)
700762

701763
/**
702764
* The features for LDA should be a [[Vector]] representing the word counts in a document.
703765
* The vector should be of length vocabSize, with counts for each term (word).
766+
*
704767
* @group setParam
705768
*/
706769
@Since("1.6.0")
@@ -758,6 +821,10 @@ class LDA @Since("1.6.0") (
758821
@Since("1.6.0")
759822
def setOptimizeDocConcentration(value: Boolean): this.type = set(optimizeDocConcentration, value)
760823

824+
/** @group expertSetParam */
825+
@Since("2.0.0")
826+
def setKeepLastCheckpoint(value: Boolean): this.type = set(keepLastCheckpoint, value)
827+
761828
@Since("1.6.0")
762829
override def copy(extra: ParamMap): LDA = defaultCopy(extra)
763830

mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAModel.scala

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -534,7 +534,8 @@ class DistributedLDAModel private[clustering] (
534534
@Since("1.5.0") override val docConcentration: Vector,
535535
@Since("1.5.0") override val topicConcentration: Double,
536536
private[spark] val iterationTimes: Array[Double],
537-
override protected[clustering] val gammaShape: Double = 100)
537+
override protected[clustering] val gammaShape: Double = DistributedLDAModel.defaultGammaShape,
538+
private[spark] val checkpointFiles: Array[String] = Array.empty[String])
538539
extends LDAModel {
539540

540541
import LDA._
@@ -806,11 +807,9 @@ class DistributedLDAModel private[clustering] (
806807

807808
override protected def formatVersion = "1.0"
808809

809-
/**
810-
* Java-friendly version of [[topicDistributions]]
811-
*/
812810
@Since("1.5.0")
813811
override def save(sc: SparkContext, path: String): Unit = {
812+
// Note: This intentionally does not save checkpointFiles.
814813
DistributedLDAModel.SaveLoadV1_0.save(
815814
sc, path, graph, globalTopicTotals, k, vocabSize, docConcentration, topicConcentration,
816815
iterationTimes, gammaShape)
@@ -822,6 +821,12 @@ class DistributedLDAModel private[clustering] (
822821
@Since("1.5.0")
823822
object DistributedLDAModel extends Loader[DistributedLDAModel] {
824823

824+
/**
825+
* The [[DistributedLDAModel]] constructor's default arguments assume gammaShape = 100
826+
* to ensure equivalence in LDAModel.toLocal conversion.
827+
*/
828+
private[clustering] val defaultGammaShape: Double = 100
829+
825830
private object SaveLoadV1_0 {
826831

827832
val thisFormatVersion = "1.0"

mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAOptimizer.scala

Lines changed: 30 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -80,9 +80,29 @@ final class EMLDAOptimizer extends LDAOptimizer {
8080

8181
import LDA._
8282

83+
// Adjustable parameters
84+
private var keepLastCheckpoint: Boolean = true
85+
8386
/**
84-
* The following fields will only be initialized through the initialize() method
87+
* If using checkpointing, this indicates whether to keep the last checkpoint (vs clean up).
88+
*/
89+
@Since("2.0.0")
90+
def getKeepLastCheckpoint: Boolean = this.keepLastCheckpoint
91+
92+
/**
93+
* If using checkpointing, this indicates whether to keep the last checkpoint (vs clean up).
94+
* Deleting the checkpoint can cause failures if a data partition is lost, so set this bit with
95+
* care. Note that checkpoints will be cleaned up via reference counting, regardless.
96+
*
97+
* Default: true
8598
*/
99+
@Since("2.0.0")
100+
def setKeepLastCheckpoint(keepLastCheckpoint: Boolean): this.type = {
101+
this.keepLastCheckpoint = keepLastCheckpoint
102+
this
103+
}
104+
105+
// The following fields will only be initialized through the initialize() method
86106
private[clustering] var graph: Graph[TopicCounts, TokenCount] = null
87107
private[clustering] var k: Int = 0
88108
private[clustering] var vocabSize: Int = 0
@@ -208,12 +228,18 @@ final class EMLDAOptimizer extends LDAOptimizer {
208228

209229
override private[clustering] def getLDAModel(iterationTimes: Array[Double]): LDAModel = {
210230
require(graph != null, "graph is null, EMLDAOptimizer not initialized.")
211-
this.graphCheckpointer.deleteAllCheckpoints()
231+
val checkpointFiles: Array[String] = if (keepLastCheckpoint) {
232+
this.graphCheckpointer.deleteAllCheckpointsButLast()
233+
this.graphCheckpointer.getAllCheckpointFiles
234+
} else {
235+
this.graphCheckpointer.deleteAllCheckpoints()
236+
Array.empty[String]
237+
}
212238
// The constructor's default arguments assume gammaShape = 100 to ensure equivalence in
213-
// LDAModel.toLocal conversion
239+
// LDAModel.toLocal conversion.
214240
new DistributedLDAModel(this.graph, this.globalTopicTotals, this.k, this.vocabSize,
215241
Vectors.dense(Array.fill(this.k)(this.docConcentration)), this.topicConcentration,
216-
iterationTimes)
242+
iterationTimes, DistributedLDAModel.defaultGammaShape, checkpointFiles)
217243
}
218244
}
219245

mllib/src/main/scala/org/apache/spark/mllib/impl/PeriodicCheckpointer.scala

Lines changed: 32 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -133,6 +133,24 @@ private[mllib] abstract class PeriodicCheckpointer[T](
133133
}
134134
}
135135

136+
/**
137+
* Call this at the end to delete any remaining checkpoint files, except for the last checkpoint.
138+
* Note that there may not be any checkpoints at all.
139+
*/
140+
def deleteAllCheckpointsButLast(): Unit = {
141+
while (checkpointQueue.size > 1) {
142+
removeCheckpointFile()
143+
}
144+
}
145+
146+
/**
147+
* Get all current checkpoint files.
148+
* This is useful in combination with [[deleteAllCheckpointsButLast()]].
149+
*/
150+
def getAllCheckpointFiles: Array[String] = {
151+
checkpointQueue.flatMap(getCheckpointFiles).toArray
152+
}
153+
136154
/**
137155
* Dequeue the oldest checkpointed Dataset, and remove its checkpoint files.
138156
* This prints a warning but does not fail if the files cannot be removed.
@@ -141,15 +159,20 @@ private[mllib] abstract class PeriodicCheckpointer[T](
141159
val old = checkpointQueue.dequeue()
142160
// Since the old checkpoint is not deleted by Spark, we manually delete it.
143161
val fs = FileSystem.get(sc.hadoopConfiguration)
144-
getCheckpointFiles(old).foreach { checkpointFile =>
145-
try {
146-
fs.delete(new Path(checkpointFile), true)
147-
} catch {
148-
case e: Exception =>
149-
logWarning("PeriodicCheckpointer could not remove old checkpoint file: " +
150-
checkpointFile)
151-
}
152-
}
162+
getCheckpointFiles(old).foreach(PeriodicCheckpointer.removeCheckpointFile(_, fs))
153163
}
164+
}
165+
166+
private[spark] object PeriodicCheckpointer extends Logging {
154167

168+
/** Delete a checkpoint file, and log a warning if deletion fails. */
169+
def removeCheckpointFile(path: String, fs: FileSystem): Unit = {
170+
try {
171+
fs.delete(new Path(path), true)
172+
} catch {
173+
case e: Exception =>
174+
logWarning("PeriodicCheckpointer could not remove old checkpoint file: " +
175+
path)
176+
}
177+
}
155178
}

mllib/src/test/scala/org/apache/spark/ml/clustering/LDASuite.scala

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,8 @@
1717

1818
package org.apache.spark.ml.clustering
1919

20+
import org.apache.hadoop.fs.{FileSystem, Path}
21+
2022
import org.apache.spark.SparkFunSuite
2123
import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils}
2224
import org.apache.spark.mllib.linalg.{Vector, Vectors}
@@ -261,4 +263,30 @@ class LDASuite extends SparkFunSuite with MLlibTestSparkContext with DefaultRead
261263
testEstimatorAndModelReadWrite(lda, dataset,
262264
LDASuite.allParamSettings ++ Map("optimizer" -> "em"), checkModelData)
263265
}
266+
267+
test("EM LDA checkpointing: save last checkpoint") {
268+
// Checkpoint dir is set by MLlibTestSparkContext
269+
val lda = new LDA().setK(2).setSeed(1).setOptimizer("em").setMaxIter(3).setCheckpointInterval(1)
270+
val model_ = lda.fit(dataset)
271+
assert(model_.isInstanceOf[DistributedLDAModel])
272+
val model = model_.asInstanceOf[DistributedLDAModel]
273+
274+
// There should be 1 checkpoint remaining.
275+
assert(model.getCheckpointFiles.length === 1)
276+
val fs = FileSystem.get(sqlContext.sparkContext.hadoopConfiguration)
277+
assert(fs.exists(new Path(model.getCheckpointFiles.head)))
278+
model.deleteCheckpointFiles()
279+
assert(model.getCheckpointFiles.isEmpty)
280+
}
281+
282+
test("EM LDA checkpointing: remove last checkpoint") {
283+
// Checkpoint dir is set by MLlibTestSparkContext
284+
val lda = new LDA().setK(2).setSeed(1).setOptimizer("em").setMaxIter(3).setCheckpointInterval(1)
285+
.setKeepLastCheckpoint(false)
286+
val model_ = lda.fit(dataset)
287+
assert(model_.isInstanceOf[DistributedLDAModel])
288+
val model = model_.asInstanceOf[DistributedLDAModel]
289+
290+
assert(model.getCheckpointFiles.isEmpty)
291+
}
264292
}

0 commit comments

Comments
 (0)