
Commit e7a489c

gatorsmile authored and jkbradley committed
[SPARK-15644][MLLIB][SQL] Replace SQLContext with SparkSession in MLlib
#### What changes were proposed in this pull request?

This PR uses the latest `SparkSession` to replace the existing `SQLContext` in `MLlib`; `SQLContext` is removed from `MLlib`. It also fixes a test case issue in `BroadcastJoinSuite`. Note that `SQLContext` is not used in the `MLlib` test suites.

#### How was this patch tested?

Existing test cases.

Author: gatorsmile <[email protected]>
Author: xiaoli <[email protected]>
Author: Xiao Li <[email protected]>

Closes #13380 from gatorsmile/sqlContextML.

(cherry picked from commit 0e3ce75)
Signed-off-by: Joseph K. Bradley <[email protected]>
1 parent f3a2ebe commit e7a489c

31 files changed: +100 -81 lines
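The mechanical pattern is the same in every file below: each `MLWriter.saveImpl` / `MLReader.load` that previously reached for `sqlContext` now uses `sparkSession`. As a minimal, self-contained sketch of the save/load round trip these implementations perform (the `Data` case class, the path, and the app name here are illustrative stand-ins, not MLlib's private helpers):

```scala
import org.apache.hadoop.fs.Path
import org.apache.spark.sql.SparkSession

// Illustrative stand-in for the private Data case classes the models define.
case class Data(numFeatures: Int, intercept: Double)

object SaveLoadSketch {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().appName("sketch").master("local[*]").getOrCreate()

    // Hypothetical model directory; the real readers/writers receive `path` as an argument.
    val dataPath = new Path("/tmp/model-sketch", "data").toString

    // Save side: one row of model data, written as Parquet (same shape as the diffs below).
    spark.createDataFrame(Seq(Data(10, 0.5)))
      .repartition(1).write.mode("overwrite").parquet(dataPath)

    // Load side: read it back and pull the fields out of the single row.
    val row = spark.read.parquet(dataPath).select("numFeatures", "intercept").head()
    val numFeatures = row.getInt(0)
    val intercept = row.getDouble(1)
    println(s"numFeatures=$numFeatures intercept=$intercept")

    spark.stop()
  }
}
```

The call shapes (`createDataFrame(Seq(data)).repartition(1).write.parquet(...)` on the save side, `read.parquet(...).select(...).head()` on the load side) are identical on `SparkSession` and the old `SQLContext`, which is why most hunks below are one-line renames.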

mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala

Lines changed: 2 additions & 2 deletions
@@ -243,7 +243,7 @@ object DecisionTreeClassificationModel extends MLReadable[DecisionTreeClassifica
       DefaultParamsWriter.saveMetadata(instance, path, sc, Some(extraMetadata))
       val (nodeData, _) = NodeData.build(instance.rootNode, 0)
       val dataPath = new Path(path, "data").toString
-      sqlContext.createDataFrame(nodeData).write.parquet(dataPath)
+      sparkSession.createDataFrame(nodeData).write.parquet(dataPath)
     }
   }

@@ -258,7 +258,7 @@ object DecisionTreeClassificationModel extends MLReadable[DecisionTreeClassifica
       val metadata = DefaultParamsReader.loadMetadata(path, sc, className)
       val numFeatures = (metadata.metadata \ "numFeatures").extract[Int]
       val numClasses = (metadata.metadata \ "numClasses").extract[Int]
-      val root = loadTreeNodes(path, metadata, sqlContext)
+      val root = loadTreeNodes(path, metadata, sparkSession)
       val model = new DecisionTreeClassificationModel(metadata.uid, root, numFeatures, numClasses)
       DefaultParamsReader.getAndSetParams(model, metadata)
       model

mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala

Lines changed: 2 additions & 2 deletions
@@ -270,7 +270,7 @@ object GBTClassificationModel extends MLReadable[GBTClassificationModel] {
       val extraMetadata: JObject = Map(
         "numFeatures" -> instance.numFeatures,
         "numTrees" -> instance.getNumTrees)
-      EnsembleModelReadWrite.saveImpl(instance, path, sqlContext, extraMetadata)
+      EnsembleModelReadWrite.saveImpl(instance, path, sparkSession, extraMetadata)
     }
   }

@@ -283,7 +283,7 @@ object GBTClassificationModel extends MLReadable[GBTClassificationModel] {
     override def load(path: String): GBTClassificationModel = {
      implicit val format = DefaultFormats
       val (metadata: Metadata, treesData: Array[(Metadata, Node)], treeWeights: Array[Double]) =
-        EnsembleModelReadWrite.loadImpl(path, sqlContext, className, treeClassName)
+        EnsembleModelReadWrite.loadImpl(path, sparkSession, className, treeClassName)
       val numFeatures = (metadata.metadata \ "numFeatures").extract[Int]
       val numTrees = (metadata.metadata \ "numTrees").extract[Int]

mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala

Lines changed: 2 additions & 2 deletions
@@ -660,7 +660,7 @@ object LogisticRegressionModel extends MLReadable[LogisticRegressionModel] {
       val data = Data(instance.numClasses, instance.numFeatures, instance.intercept,
         instance.coefficients)
       val dataPath = new Path(path, "data").toString
-      sqlContext.createDataFrame(Seq(data)).repartition(1).write.parquet(dataPath)
+      sparkSession.createDataFrame(Seq(data)).repartition(1).write.parquet(dataPath)
     }
   }

@@ -674,7 +674,7 @@ object LogisticRegressionModel extends MLReadable[LogisticRegressionModel] {
       val metadata = DefaultParamsReader.loadMetadata(path, sc, className)

       val dataPath = new Path(path, "data").toString
-      val data = sqlContext.read.format("parquet").load(dataPath)
+      val data = sparkSession.read.format("parquet").load(dataPath)
         .select("numClasses", "numFeatures", "intercept", "coefficients").head()
       // We will need numClasses, numFeatures in the future for multinomial logreg support.
       // val numClasses = data.getInt(0)

mllib/src/main/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifier.scala

Lines changed: 2 additions & 2 deletions
@@ -356,7 +356,7 @@ object MultilayerPerceptronClassificationModel
       // Save model data: layers, weights
       val data = Data(instance.layers, instance.weights)
       val dataPath = new Path(path, "data").toString
-      sqlContext.createDataFrame(Seq(data)).repartition(1).write.parquet(dataPath)
+      sparkSession.createDataFrame(Seq(data)).repartition(1).write.parquet(dataPath)
     }
   }

@@ -370,7 +370,7 @@ object MultilayerPerceptronClassificationModel
       val metadata = DefaultParamsReader.loadMetadata(path, sc, className)

       val dataPath = new Path(path, "data").toString
-      val data = sqlContext.read.parquet(dataPath).select("layers", "weights").head()
+      val data = sparkSession.read.parquet(dataPath).select("layers", "weights").head()
       val layers = data.getAs[Seq[Int]](0).toArray
       val weights = data.getAs[Vector](1)
       val model = new MultilayerPerceptronClassificationModel(metadata.uid, layers, weights)

mllib/src/main/scala/org/apache/spark/ml/classification/NaiveBayes.scala

Lines changed: 2 additions & 2 deletions
@@ -262,7 +262,7 @@ object NaiveBayesModel extends MLReadable[NaiveBayesModel] {
       // Save model data: pi, theta
       val data = Data(instance.pi, instance.theta)
       val dataPath = new Path(path, "data").toString
-      sqlContext.createDataFrame(Seq(data)).repartition(1).write.parquet(dataPath)
+      sparkSession.createDataFrame(Seq(data)).repartition(1).write.parquet(dataPath)
     }
   }

@@ -275,7 +275,7 @@ object NaiveBayesModel extends MLReadable[NaiveBayesModel] {
       val metadata = DefaultParamsReader.loadMetadata(path, sc, className)

       val dataPath = new Path(path, "data").toString
-      val data = sqlContext.read.parquet(dataPath).select("pi", "theta").head()
+      val data = sparkSession.read.parquet(dataPath).select("pi", "theta").head()
       val pi = data.getAs[Vector](0)
       val theta = data.getAs[Matrix](1)
       val model = new NaiveBayesModel(metadata.uid, pi, theta)

mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala

Lines changed: 2 additions & 2 deletions
@@ -282,7 +282,7 @@ object RandomForestClassificationModel extends MLReadable[RandomForestClassifica
         "numFeatures" -> instance.numFeatures,
         "numClasses" -> instance.numClasses,
         "numTrees" -> instance.getNumTrees)
-      EnsembleModelReadWrite.saveImpl(instance, path, sqlContext, extraMetadata)
+      EnsembleModelReadWrite.saveImpl(instance, path, sparkSession, extraMetadata)
     }
   }

@@ -296,7 +296,7 @@ object RandomForestClassificationModel extends MLReadable[RandomForestClassifica
     override def load(path: String): RandomForestClassificationModel = {
       implicit val format = DefaultFormats
       val (metadata: Metadata, treesData: Array[(Metadata, Node)], _) =
-        EnsembleModelReadWrite.loadImpl(path, sqlContext, className, treeClassName)
+        EnsembleModelReadWrite.loadImpl(path, sparkSession, className, treeClassName)
       val numFeatures = (metadata.metadata \ "numFeatures").extract[Int]
       val numClasses = (metadata.metadata \ "numClasses").extract[Int]
       val numTrees = (metadata.metadata \ "numTrees").extract[Int]

mllib/src/main/scala/org/apache/spark/ml/clustering/GaussianMixture.scala

Lines changed: 2 additions & 2 deletions
@@ -195,7 +195,7 @@ object GaussianMixtureModel extends MLReadable[GaussianMixtureModel] {
       val sigmas = gaussians.map(c => OldMatrices.fromML(c.cov))
       val data = Data(weights, mus, sigmas)
       val dataPath = new Path(path, "data").toString
-      sqlContext.createDataFrame(Seq(data)).repartition(1).write.parquet(dataPath)
+      sparkSession.createDataFrame(Seq(data)).repartition(1).write.parquet(dataPath)
     }
   }

@@ -208,7 +208,7 @@ object GaussianMixtureModel extends MLReadable[GaussianMixtureModel] {
       val metadata = DefaultParamsReader.loadMetadata(path, sc, className)

       val dataPath = new Path(path, "data").toString
-      val row = sqlContext.read.parquet(dataPath).select("weights", "mus", "sigmas").head()
+      val row = sparkSession.read.parquet(dataPath).select("weights", "mus", "sigmas").head()
       val weights = row.getSeq[Double](0).toArray
       val mus = row.getSeq[OldVector](1).toArray
       val sigmas = row.getSeq[OldMatrix](2).toArray

mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala

Lines changed: 5 additions & 5 deletions
@@ -211,7 +211,7 @@ object KMeansModel extends MLReadable[KMeansModel] {
         Data(idx, center)
       }
       val dataPath = new Path(path, "data").toString
-      sqlContext.createDataFrame(data).repartition(1).write.parquet(dataPath)
+      sparkSession.createDataFrame(data).repartition(1).write.parquet(dataPath)
     }
   }

@@ -222,8 +222,8 @@ object KMeansModel extends MLReadable[KMeansModel] {

     override def load(path: String): KMeansModel = {
       // Import implicits for Dataset Encoder
-      val sqlContext = super.sqlContext
-      import sqlContext.implicits._
+      val sparkSession = super.sparkSession
+      import sparkSession.implicits._

       val metadata = DefaultParamsReader.loadMetadata(path, sc, className)
       val dataPath = new Path(path, "data").toString

@@ -232,11 +232,11 @@ object KMeansModel extends MLReadable[KMeansModel] {
       val versionRegex(major, _) = metadata.sparkVersion

       val clusterCenters = if (major.toInt >= 2) {
-        val data: Dataset[Data] = sqlContext.read.parquet(dataPath).as[Data]
+        val data: Dataset[Data] = sparkSession.read.parquet(dataPath).as[Data]
         data.collect().sortBy(_.clusterIdx).map(_.clusterCenter).map(OldVectors.fromML)
       } else {
         // Loads KMeansModel stored with the old format used by Spark 1.6 and earlier.
-        sqlContext.read.parquet(dataPath).as[OldData].head().clusterCenters
+        sparkSession.read.parquet(dataPath).as[OldData].head().clusterCenters
       }
       val model = new KMeansModel(metadata.uid, new MLlibKMeansModel(clusterCenters))
       DefaultParamsReader.getAndSetParams(model, metadata)
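The `KMeansModel` reader above is the one place where the change is more than a rename: `.as[Data]` needs an implicit `Encoder`, and `import ....implicits._` only compiles against a stable identifier, which is why the loader first binds `super.sparkSession` to a local `val`. A minimal sketch of that idiom (the `Data` case class and the path are hypothetical stand-ins for MLlib's private ones):

```scala
import org.apache.spark.sql.{Dataset, SparkSession}

// Hypothetical stand-in for KMeansModel's private Data(clusterIdx, clusterCenter).
case class Data(clusterIdx: Int, clusterCenter: Array[Double])

object EncoderImportSketch {
  def main(args: Array[String]): Unit = {
    // A local val is a stable identifier, so the import below compiles;
    // importing implicits off a bare expression would not.
    val spark = SparkSession.builder().appName("sketch").master("local[*]").getOrCreate()
    import spark.implicits._  // brings Encoder[Data] into scope for .as[Data]

    val dataPath = "/tmp/kmeans-sketch"  // illustrative path
    spark.createDataFrame(Seq(Data(0, Array(1.0, 2.0)), Data(1, Array(3.0, 4.0))))
      .write.mode("overwrite").parquet(dataPath)

    // Typed read-back, mirroring the new-format branch of the loader above.
    val data: Dataset[Data] = spark.read.parquet(dataPath).as[Data]
    data.collect().sortBy(_.clusterIdx)
      .foreach(d => println(d.clusterCenter.mkString("[", ", ", "]")))

    spark.stop()
  }
}
```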

mllib/src/main/scala/org/apache/spark/ml/feature/ChiSqSelector.scala

Lines changed: 2 additions & 2 deletions
@@ -202,7 +202,7 @@ object ChiSqSelectorModel extends MLReadable[ChiSqSelectorModel] {
       DefaultParamsWriter.saveMetadata(instance, path, sc)
       val data = Data(instance.selectedFeatures.toSeq)
       val dataPath = new Path(path, "data").toString
-      sqlContext.createDataFrame(Seq(data)).repartition(1).write.parquet(dataPath)
+      sparkSession.createDataFrame(Seq(data)).repartition(1).write.parquet(dataPath)
     }
   }

@@ -213,7 +213,7 @@ object ChiSqSelectorModel extends MLReadable[ChiSqSelectorModel] {
     override def load(path: String): ChiSqSelectorModel = {
       val metadata = DefaultParamsReader.loadMetadata(path, sc, className)
       val dataPath = new Path(path, "data").toString
-      val data = sqlContext.read.parquet(dataPath).select("selectedFeatures").head()
+      val data = sparkSession.read.parquet(dataPath).select("selectedFeatures").head()
       val selectedFeatures = data.getAs[Seq[Int]](0).toArray
       val oldModel = new feature.ChiSqSelectorModel(selectedFeatures)
       val model = new ChiSqSelectorModel(metadata.uid, oldModel)

mllib/src/main/scala/org/apache/spark/ml/feature/CountVectorizer.scala

Lines changed: 2 additions & 2 deletions
@@ -297,7 +297,7 @@ object CountVectorizerModel extends MLReadable[CountVectorizerModel] {
       DefaultParamsWriter.saveMetadata(instance, path, sc)
       val data = Data(instance.vocabulary)
       val dataPath = new Path(path, "data").toString
-      sqlContext.createDataFrame(Seq(data)).repartition(1).write.parquet(dataPath)
+      sparkSession.createDataFrame(Seq(data)).repartition(1).write.parquet(dataPath)
     }
   }

@@ -308,7 +308,7 @@ object CountVectorizerModel extends MLReadable[CountVectorizerModel] {
     override def load(path: String): CountVectorizerModel = {
       val metadata = DefaultParamsReader.loadMetadata(path, sc, className)
       val dataPath = new Path(path, "data").toString
-      val data = sqlContext.read.parquet(dataPath)
+      val data = sparkSession.read.parquet(dataPath)
         .select("vocabulary")
         .head()
       val vocabulary = data.getAs[Seq[String]](0).toArray
