1818package org .apache .spark .ml .clustering
1919
2020import org .apache .hadoop .fs .Path
21+ import org .json4s .DefaultFormats
22+ import org .json4s .JsonAST .JObject
23+ import org .json4s .jackson .JsonMethods ._
2124
2225import org .apache .spark .annotation .{DeveloperApi , Experimental , Since }
2326import org .apache .spark .internal .Logging
@@ -26,19 +29,21 @@ import org.apache.spark.ml.linalg.{Matrix, Vector, Vectors, VectorUDT}
2629import org .apache .spark .ml .param ._
2730import org .apache .spark .ml .param .shared .{HasCheckpointInterval , HasFeaturesCol , HasMaxIter , HasSeed }
2831import org .apache .spark .ml .util ._
32+ import org .apache .spark .ml .util .DefaultParamsReader .Metadata
2933import org .apache .spark .mllib .clustering .{DistributedLDAModel => OldDistributedLDAModel ,
3034 EMLDAOptimizer => OldEMLDAOptimizer , LDA => OldLDA , LDAModel => OldLDAModel ,
3135 LDAOptimizer => OldLDAOptimizer , LocalLDAModel => OldLocalLDAModel ,
3236 OnlineLDAOptimizer => OldOnlineLDAOptimizer }
3337import org .apache .spark .mllib .impl .PeriodicCheckpointer
34- import org .apache .spark .mllib .linalg .{Matrices => OldMatrices , Vector => OldVector ,
35- Vectors => OldVectors }
38+ import org .apache .spark .mllib .linalg .{Vector => OldVector , Vectors => OldVectors }
3639import org .apache .spark .mllib .linalg .MatrixImplicits ._
3740import org .apache .spark .mllib .linalg .VectorImplicits ._
41+ import org .apache .spark .mllib .util .MLUtils
3842import org .apache .spark .rdd .RDD
3943import org .apache .spark .sql .{DataFrame , Dataset , Row , SparkSession }
4044import org .apache .spark .sql .functions .{col , monotonically_increasing_id , udf }
4145import org .apache .spark .sql .types .StructType
46+ import org .apache .spark .util .VersionUtils
4247
4348
4449private [clustering] trait LDAParams extends Params with HasFeaturesCol with HasMaxIter
@@ -80,6 +85,7 @@ private[clustering] trait LDAParams extends Params with HasFeaturesCol with HasM
8085 * - Values should be >= 0
8186 * - default = uniformly (1.0 / k), following the implementation from
8287 * [[https://github.com/Blei-Lab/onlineldavb ]].
88+ *
8389 * @group param
8490 */
8591 @ Since (" 1.6.0" )
@@ -121,6 +127,7 @@ private[clustering] trait LDAParams extends Params with HasFeaturesCol with HasM
121127 * - Value should be >= 0
122128 * - default = (1.0 / k), following the implementation from
123129 * [[https://github.com/Blei-Lab/onlineldavb ]].
130+ *
124131 * @group param
125132 */
126133 @ Since (" 1.6.0" )
@@ -354,6 +361,39 @@ private[clustering] trait LDAParams extends Params with HasFeaturesCol with HasM
354361 }
355362}
356363
private object LDAParams {

  /**
   * Equivalent to [[DefaultParamsReader.getAndSetParams()]], but handles [[LDA]] and [[LDAModel]]
   * formats saved with Spark 1.6, which differ from the formats in Spark 2.0+.
   *
   * @param model    [[LDA]] or [[LDAModel]] instance. This instance will be modified with
   *                 [[Param]] values extracted from metadata.
   * @param metadata Loaded model metadata.
   */
  def getAndSetParams(model: LDAParams, metadata: Metadata): Unit = {
    val (major, minor) = VersionUtils.majorMinorVersion(metadata.sparkVersion)
    if (major == 1 && minor == 6) {
      // Spark 1.6 metadata: decode each param by hand, translating renamed params.
      implicit val format = DefaultFormats
      metadata.params match {
        case JObject(pairs) =>
          for ((paramName, jsonValue) <- pairs) {
            // "topicDistribution" was renamed to "topicDistributionCol" in 2.0.
            val currentName =
              if (paramName == "topicDistribution") "topicDistributionCol" else paramName
            val param = model.getParam(currentName)
            model.set(param, param.jsonDecode(compact(render(jsonValue))))
          }
        case _ =>
          throw new IllegalArgumentException(
            s"Cannot recognize JSON metadata: ${metadata.metadataJson}.")
      }
    } else {
      // 2.0+ metadata uses the standard format; delegate to the default reader.
      DefaultParamsReader.getAndSetParams(model, metadata)
    }
  }
}
396+
357397
358398/**
359399 * :: Experimental ::
@@ -414,11 +454,11 @@ sealed abstract class LDAModel private[ml] (
414454 val transformer = oldLocalModel.getTopicDistributionMethod(sparkSession.sparkContext)
415455
416456 val t = udf { (v : Vector ) => transformer(OldVectors .fromML(v)).asML }
417- dataset.withColumn($(topicDistributionCol), t(col($(featuresCol)))).toDF
457+ dataset.withColumn($(topicDistributionCol), t(col($(featuresCol)))).toDF()
418458 } else {
419459 logWarning(" LDAModel.transform was called without any output columns. Set an output column" +
420460 " such as topicDistributionCol to produce results." )
421- dataset.toDF
461+ dataset.toDF()
422462 }
423463 }
424464
@@ -574,18 +614,16 @@ object LocalLDAModel extends MLReadable[LocalLDAModel] {
574614 val metadata = DefaultParamsReader .loadMetadata(path, sc, className)
575615 val dataPath = new Path (path, " data" ).toString
576616 val data = sparkSession.read.parquet(dataPath)
577- .select(" vocabSize" , " topicsMatrix" , " docConcentration" , " topicConcentration" ,
578- " gammaShape" )
579- .head()
580- val vocabSize = data.getAs[Int ](0 )
581- val topicsMatrix = data.getAs[Matrix ](1 )
582- val docConcentration = data.getAs[Vector ](2 )
583- val topicConcentration = data.getAs[Double ](3 )
584- val gammaShape = data.getAs[Double ](4 )
617+ val vectorConverted = MLUtils .convertVectorColumnsToML(data, " docConcentration" )
618+ val matrixConverted = MLUtils .convertMatrixColumnsToML(vectorConverted, " topicsMatrix" )
619+ val Row (vocabSize : Int , topicsMatrix : Matrix , docConcentration : Vector ,
620+ topicConcentration : Double , gammaShape : Double ) =
621+ matrixConverted.select(" vocabSize" , " topicsMatrix" , " docConcentration" ,
622+ " topicConcentration" , " gammaShape" ).head()
585623 val oldModel = new OldLocalLDAModel (topicsMatrix, docConcentration, topicConcentration,
586624 gammaShape)
587625 val model = new LocalLDAModel (metadata.uid, vocabSize, oldModel, sparkSession)
588- DefaultParamsReader .getAndSetParams(model, metadata)
626+ LDAParams .getAndSetParams(model, metadata)
589627 model
590628 }
591629 }
@@ -731,9 +769,9 @@ object DistributedLDAModel extends MLReadable[DistributedLDAModel] {
731769 val metadata = DefaultParamsReader .loadMetadata(path, sc, className)
732770 val modelPath = new Path (path, " oldModel" ).toString
733771 val oldModel = OldDistributedLDAModel .load(sc, modelPath)
734- val model = new DistributedLDAModel (
735- metadata.uid, oldModel.vocabSize, oldModel, sparkSession, None )
736- DefaultParamsReader .getAndSetParams(model, metadata)
772+ val model = new DistributedLDAModel (metadata.uid, oldModel.vocabSize,
773+ oldModel, sparkSession, None )
774+ LDAParams .getAndSetParams(model, metadata)
737775 model
738776 }
739777 }
@@ -881,7 +919,7 @@ class LDA @Since("1.6.0") (
881919}
882920
883921@ Since (" 2.0.0" )
884- object LDA extends DefaultParamsReadable [LDA ] {
922+ object LDA extends MLReadable [LDA ] {
885923
886924 /** Get dataset for spark.mllib LDA */
887925 private [clustering] def getOldDataset (
@@ -896,6 +934,20 @@ object LDA extends DefaultParamsReadable[LDA] {
896934 }
897935 }
898936
937+ private class LDAReader extends MLReader [LDA ] {
938+
939+ private val className = classOf [LDA ].getName
940+
941+ override def load (path : String ): LDA = {
942+ val metadata = DefaultParamsReader .loadMetadata(path, sc, className)
943+ val model = new LDA (metadata.uid)
944+ LDAParams .getAndSetParams(model, metadata)
945+ model
946+ }
947+ }
948+
949+ override def read : MLReader [LDA ] = new LDAReader
950+
899951 @ Since (" 2.0.0" )
900952 override def load (path : String ): LDA = super .load(path)
901953}
0 commit comments