Commit 54d4eee

GayathriMurali authored and jkbradley committed
[SPARK-16240][ML] ML persistence backward compatibility for LDA - 2.0 backport
## What changes were proposed in this pull request?

Allow Spark 2.x to load instances of LDA, LocalLDAModel, and DistributedLDAModel saved from Spark 1.6. Backport of #15034 for branch-2.0.

## How was this patch tested?

I tested this manually, saving the 3 types from 1.6 and loading them into master (2.x). In the future, we can add generic tests for backward compatibility across all ML models under SPARK-15573.

Author: Gayathri Murali <[email protected]>
Author: Joseph K. Bradley <[email protected]>

Closes #15205 from jkbradley/lda-backward-2.0.
1 parent 22216d6 commit 54d4eee

2 files changed: 72 additions & 17 deletions
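For context, a minimal sketch of the round trip this backport enables. The paths, dataset, and `setK` value here are hypothetical placeholders, not part of the patch; the save step runs on Spark 1.6 and the loads on Spark 2.0 with this change applied:

```scala
// On Spark 1.6 (sketch): fit and save a spark.ml LDA model.
//   val model = new org.apache.spark.ml.clustering.LDA().setK(10).fit(dataset)
//   model.save("/tmp/lda-model-1.6")

// On Spark 2.0 with this patch: all three 1.6-saved types load again.
import org.apache.spark.ml.clustering.{DistributedLDAModel, LDA, LocalLDAModel}

val estimator = LDA.load("/tmp/lda-estimator-1.6")            // unfitted LDA estimator
val local = LocalLDAModel.load("/tmp/lda-model-1.6")          // online-optimizer model
val distributed = DistributedLDAModel.load("/tmp/lda-em-1.6") // EM-optimizer model
```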

mllib/src/main/scala/org/apache/spark/ml/clustering/LDA.scala

Lines changed: 69 additions & 17 deletions
```diff
@@ -18,6 +18,9 @@
 package org.apache.spark.ml.clustering
 
 import org.apache.hadoop.fs.Path
+import org.json4s.DefaultFormats
+import org.json4s.JsonAST.JObject
+import org.json4s.jackson.JsonMethods._
 
 import org.apache.spark.annotation.{DeveloperApi, Experimental, Since}
 import org.apache.spark.internal.Logging
@@ -26,19 +29,21 @@ import org.apache.spark.ml.linalg.{Matrix, Vector, Vectors, VectorUDT}
 import org.apache.spark.ml.param._
 import org.apache.spark.ml.param.shared.{HasCheckpointInterval, HasFeaturesCol, HasMaxIter, HasSeed}
 import org.apache.spark.ml.util._
+import org.apache.spark.ml.util.DefaultParamsReader.Metadata
 import org.apache.spark.mllib.clustering.{DistributedLDAModel => OldDistributedLDAModel,
   EMLDAOptimizer => OldEMLDAOptimizer, LDA => OldLDA, LDAModel => OldLDAModel,
   LDAOptimizer => OldLDAOptimizer, LocalLDAModel => OldLocalLDAModel,
   OnlineLDAOptimizer => OldOnlineLDAOptimizer}
 import org.apache.spark.mllib.impl.PeriodicCheckpointer
-import org.apache.spark.mllib.linalg.{Matrices => OldMatrices, Vector => OldVector,
-  Vectors => OldVectors}
+import org.apache.spark.mllib.linalg.{Vector => OldVector, Vectors => OldVectors}
 import org.apache.spark.mllib.linalg.MatrixImplicits._
 import org.apache.spark.mllib.linalg.VectorImplicits._
+import org.apache.spark.mllib.util.MLUtils
 import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.{DataFrame, Dataset, Row, SparkSession}
 import org.apache.spark.sql.functions.{col, monotonically_increasing_id, udf}
 import org.apache.spark.sql.types.StructType
+import org.apache.spark.util.VersionUtils
 
 
 private[clustering] trait LDAParams extends Params with HasFeaturesCol with HasMaxIter
@@ -80,6 +85,7 @@ private[clustering] trait LDAParams extends Params with HasFeaturesCol with HasM
    * - Values should be >= 0
    * - default = uniformly (1.0 / k), following the implementation from
    *   [[https://github.com/Blei-Lab/onlineldavb]].
+   *
    * @group param
    */
   @Since("1.6.0")
@@ -121,6 +127,7 @@ private[clustering] trait LDAParams extends Params with HasFeaturesCol with HasM
    * - Value should be >= 0
    * - default = (1.0 / k), following the implementation from
    *   [[https://github.com/Blei-Lab/onlineldavb]].
+   *
    * @group param
    */
   @Since("1.6.0")
@@ -354,6 +361,39 @@ private[clustering] trait LDAParams extends Params with HasFeaturesCol with HasM
   }
 }
 
+private object LDAParams {
+
+  /**
+   * Equivalent to [[DefaultParamsReader.getAndSetParams()]], but handles [[LDA]] and [[LDAModel]]
+   * formats saved with Spark 1.6, which differ from the formats in Spark 2.0+.
+   *
+   * @param model    [[LDA]] or [[LDAModel]] instance. This instance will be modified with
+   *                 [[Param]] values extracted from metadata.
+   * @param metadata Loaded model metadata
+   */
+  def getAndSetParams(model: LDAParams, metadata: Metadata): Unit = {
+    VersionUtils.majorMinorVersion(metadata.sparkVersion) match {
+      case (1, 6) =>
+        implicit val format = DefaultFormats
+        metadata.params match {
+          case JObject(pairs) =>
+            pairs.foreach { case (paramName, jsonValue) =>
+              val origParam =
+                if (paramName == "topicDistribution") "topicDistributionCol" else paramName
+              val param = model.getParam(origParam)
+              val value = param.jsonDecode(compact(render(jsonValue)))
+              model.set(param, value)
+            }
+          case _ =>
+            throw new IllegalArgumentException(
+              s"Cannot recognize JSON metadata: ${metadata.metadataJson}.")
+        }
+      case _ => // 2.0+
+        DefaultParamsReader.getAndSetParams(model, metadata)
+    }
+  }
+}
+
 
 /**
  * :: Experimental ::
@@ -414,11 +454,11 @@ sealed abstract class LDAModel private[ml] (
       val transformer = oldLocalModel.getTopicDistributionMethod(sparkSession.sparkContext)
 
       val t = udf { (v: Vector) => transformer(OldVectors.fromML(v)).asML }
-      dataset.withColumn($(topicDistributionCol), t(col($(featuresCol)))).toDF
+      dataset.withColumn($(topicDistributionCol), t(col($(featuresCol)))).toDF()
     } else {
       logWarning("LDAModel.transform was called without any output columns. Set an output column" +
         " such as topicDistributionCol to produce results.")
-      dataset.toDF
+      dataset.toDF()
     }
   }
 
@@ -574,18 +614,16 @@ object LocalLDAModel extends MLReadable[LocalLDAModel] {
       val metadata = DefaultParamsReader.loadMetadata(path, sc, className)
       val dataPath = new Path(path, "data").toString
       val data = sparkSession.read.parquet(dataPath)
-        .select("vocabSize", "topicsMatrix", "docConcentration", "topicConcentration",
-          "gammaShape")
-        .head()
-      val vocabSize = data.getAs[Int](0)
-      val topicsMatrix = data.getAs[Matrix](1)
-      val docConcentration = data.getAs[Vector](2)
-      val topicConcentration = data.getAs[Double](3)
-      val gammaShape = data.getAs[Double](4)
+      val vectorConverted = MLUtils.convertVectorColumnsToML(data, "docConcentration")
+      val matrixConverted = MLUtils.convertMatrixColumnsToML(vectorConverted, "topicsMatrix")
+      val Row(vocabSize: Int, topicsMatrix: Matrix, docConcentration: Vector,
+        topicConcentration: Double, gammaShape: Double) =
+        matrixConverted.select("vocabSize", "topicsMatrix", "docConcentration",
+          "topicConcentration", "gammaShape").head()
       val oldModel = new OldLocalLDAModel(topicsMatrix, docConcentration, topicConcentration,
         gammaShape)
       val model = new LocalLDAModel(metadata.uid, vocabSize, oldModel, sparkSession)
-      DefaultParamsReader.getAndSetParams(model, metadata)
+      LDAParams.getAndSetParams(model, metadata)
       model
     }
   }
@@ -731,9 +769,9 @@ object DistributedLDAModel extends MLReadable[DistributedLDAModel] {
       val metadata = DefaultParamsReader.loadMetadata(path, sc, className)
       val modelPath = new Path(path, "oldModel").toString
       val oldModel = OldDistributedLDAModel.load(sc, modelPath)
-      val model = new DistributedLDAModel(
-        metadata.uid, oldModel.vocabSize, oldModel, sparkSession, None)
-      DefaultParamsReader.getAndSetParams(model, metadata)
+      val model = new DistributedLDAModel(metadata.uid, oldModel.vocabSize,
+        oldModel, sparkSession, None)
+      LDAParams.getAndSetParams(model, metadata)
       model
     }
   }
@@ -881,7 +919,7 @@ class LDA @Since("1.6.0") (
 }
 
 @Since("2.0.0")
-object LDA extends DefaultParamsReadable[LDA] {
+object LDA extends MLReadable[LDA] {
 
   /** Get dataset for spark.mllib LDA */
   private[clustering] def getOldDataset(
@@ -896,6 +934,20 @@ object LDA extends DefaultParamsReadable[LDA] {
     }
   }
 
+  private class LDAReader extends MLReader[LDA] {
+
+    private val className = classOf[LDA].getName
+
+    override def load(path: String): LDA = {
+      val metadata = DefaultParamsReader.loadMetadata(path, sc, className)
+      val model = new LDA(metadata.uid)
+      LDAParams.getAndSetParams(model, metadata)
+      model
+    }
+  }
+
+  override def read: MLReader[LDA] = new LDAReader
+
   @Since("2.0.0")
   override def load(path: String): LDA = super.load(path)
 }
```
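The heart of the shim is `LDAParams.getAndSetParams` above: for metadata written by 1.6 it walks the params JSON itself and renames the one param whose name changed (`topicDistribution` became `topicDistributionCol` in 2.0); for 2.0+ metadata it defers to `DefaultParamsReader.getAndSetParams`. Below is a self-contained sketch of just that renaming step, using json4s as the patch does; `Lda16ParamCompat` and the sample metadata string are illustrative only, not part of the patch:

```scala
import org.json4s.DefaultFormats
import org.json4s.JsonAST.{JObject, JValue}
import org.json4s.jackson.JsonMethods.{compact, parse, render}

object Lda16ParamCompat {
  // Illustrative sketch: extract (paramName -> compact JSON value) pairs from a
  // 1.6-era params object, mapping "topicDistribution" to its 2.0 name
  // "topicDistributionCol". The real patch feeds each decoded value into
  // model.set(param, value) instead of returning the pairs.
  def remapParams(paramsJson: JValue): Seq[(String, String)] = {
    implicit val formats = DefaultFormats
    paramsJson match {
      case JObject(pairs) =>
        pairs.map { case (name, value) =>
          val newName = if (name == "topicDistribution") "topicDistributionCol" else name
          newName -> compact(render(value))
        }
      case _ =>
        throw new IllegalArgumentException(s"Cannot recognize JSON metadata: $paramsJson")
    }
  }
}

// Example: a 1.6-style params object round-trips to the 2.0 param name.
// Lda16ParamCompat.remapParams(parse("""{"topicDistribution": "topics", "k": 10}"""))
//   == Seq("topicDistributionCol" -> "\"topics\"", "k" -> "10")
```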

project/MimaExcludes.scala

Lines changed: 3 additions & 0 deletions
```diff
@@ -784,6 +784,9 @@ object MimaExcludes {
       // SPARK-17096: Improve exception string reported through the StreamingQueryListener
       ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.sql.streaming.StreamingQueryListener#QueryTerminated.stackTrace"),
       ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.sql.streaming.StreamingQueryListener#QueryTerminated.this")
+    ) ++ Seq(
+      // SPARK-16240: ML persistence backward compatibility for LDA
+      ProblemFilters.exclude[MissingTypesProblem]("org.apache.spark.ml.clustering.LDA$")
     )
   }
 
```