Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions mllib/src/main/scala/org/apache/spark/ml/Predictor.scala
Original file line number Diff line number Diff line change
Expand Up @@ -204,8 +204,8 @@ abstract class PredictionModel[FeaturesType, M <: PredictionModel[FeaturesType,
if ($(predictionCol).nonEmpty) {
transformImpl(dataset)
} else {
this.logWarning(s"$uid: Predictor.transform() was called as NOOP" +
" since no output columns were set.")
this.logWarning(s"$uid: Predictor.transform() does nothing" +
" because no output columns were set.")
dataset.toDF
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
package org.apache.spark.ml.classification

import org.apache.spark.SparkException
import org.apache.spark.annotation.{DeveloperApi, Since}
import org.apache.spark.annotation.DeveloperApi
import org.apache.spark.ml.{PredictionModel, Predictor, PredictorParams}
import org.apache.spark.ml.feature.LabeledPoint
import org.apache.spark.ml.linalg.{Vector, VectorUDT}
Expand Down Expand Up @@ -204,8 +204,8 @@ abstract class ClassificationModel[FeaturesType, M <: ClassificationModel[Featur
}

if (numColsOutput == 0) {
logWarning(s"$uid: ClassificationModel.transform() was called as NOOP" +
" since no output columns were set.")
logWarning(s"$uid: ClassificationModel.transform() does nothing" +
" because no output columns were set.")
}
outputData.toDF
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -169,9 +169,9 @@ final class OneVsRestModel private[ml] (
// Check schema
transformSchema(dataset.schema, logging = true)

if (getPredictionCol == "" && getRawPredictionCol == "") {
logWarning(s"$uid: OneVsRestModel.transform() was called as NOOP" +
" since no output columns were set.")
if (getPredictionCol.isEmpty && getRawPredictionCol.isEmpty) {
logWarning(s"$uid: OneVsRestModel.transform() does nothing" +
" because no output columns were set.")
return dataset.toDF
}

Expand Down Expand Up @@ -218,7 +218,7 @@ final class OneVsRestModel private[ml] (
var predictionColNames = Seq.empty[String]
var predictionColumns = Seq.empty[Column]

if (getRawPredictionCol != "") {
if (getRawPredictionCol.nonEmpty) {
val numClass = models.length

// output the RawPrediction as vector
Expand All @@ -228,18 +228,18 @@ final class OneVsRestModel private[ml] (
Vectors.dense(predArray)
}

predictionColNames = predictionColNames :+ getRawPredictionCol
predictionColumns = predictionColumns :+ rawPredictionUDF(col(accColName))
predictionColNames :+= getRawPredictionCol
predictionColumns :+= rawPredictionUDF(col(accColName))
}

if (getPredictionCol != "") {
if (getPredictionCol.nonEmpty) {
// output the index of the classifier with highest confidence as prediction
val labelUDF = udf { (predictions: Map[Int, Double]) =>
predictions.maxBy(_._2)._1.toDouble
}

predictionColNames = predictionColNames :+ getPredictionCol
predictionColumns = predictionColumns :+ labelUDF(col(accColName))
predictionColNames :+= getPredictionCol
predictionColumns :+= labelUDF(col(accColName))
.as(getPredictionCol, labelMetadata)
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -147,8 +147,8 @@ abstract class ProbabilisticClassificationModel[
}

if (numColsOutput == 0) {
this.logWarning(s"$uid: ProbabilisticClassificationModel.transform() was called as NOOP" +
" since no output columns were set.")
this.logWarning(s"$uid: ProbabilisticClassificationModel.transform() does nothing" +
" because no output columns were set.")
}
outputData.toDF
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ import org.apache.spark.ml.util.Instrumentation.instrumented
import org.apache.spark.mllib.linalg.{Matrices => OldMatrices, Matrix => OldMatrix,
Vector => OldVector, Vectors => OldVectors}
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.{DataFrame, Dataset, Row, SparkSession}
import org.apache.spark.sql.{Column, DataFrame, Dataset, Row, SparkSession}
import org.apache.spark.sql.functions.udf
import org.apache.spark.sql.types.{IntegerType, StructType}
import org.apache.spark.storage.StorageLevel
Expand Down Expand Up @@ -110,11 +110,29 @@ class GaussianMixtureModel private[ml] (
@Since("2.0.0")
override def transform(dataset: Dataset[_]): DataFrame = {
transformSchema(dataset.schema, logging = true)
val predUDF = udf((vector: Vector) => predict(vector))
val probUDF = udf((vector: Vector) => predictProbability(vector))
dataset
.withColumn($(predictionCol), predUDF(DatasetUtils.columnToVector(dataset, getFeaturesCol)))
.withColumn($(probabilityCol), probUDF(DatasetUtils.columnToVector(dataset, getFeaturesCol)))

var predictionColNames = Seq.empty[String]
var predictionColumns = Seq.empty[Column]

if ($(predictionCol).nonEmpty) {
val predUDF = udf((vector: Vector) => predict(vector))
predictionColNames :+= $(predictionCol)
predictionColumns :+= predUDF(DatasetUtils.columnToVector(dataset, getFeaturesCol))
}

if ($(probabilityCol).nonEmpty) {
val probUDF = udf((vector: Vector) => predictProbability(vector))
predictionColNames :+= $(probabilityCol)
predictionColumns :+= probUDF(DatasetUtils.columnToVector(dataset, getFeaturesCol))
}

if (predictionColNames.nonEmpty) {
dataset.withColumns(predictionColNames, predictionColumns)
} else {
this.logWarning(s"$uid: GaussianMixtureModel.transform() does nothing" +
" because no output columns were set.")
dataset.toDF()
}
}

@Since("2.0.0")
Expand Down
15 changes: 4 additions & 11 deletions mllib/src/main/scala/org/apache/spark/ml/clustering/LDA.scala
Original file line number Diff line number Diff line change
Expand Up @@ -461,17 +461,10 @@ abstract class LDAModel private[ml] (
override def transform(dataset: Dataset[_]): DataFrame = {
transformSchema(dataset.schema, logging = true)

if ($(topicDistributionCol).nonEmpty) {
val func = getTopicDistributionMethod
val transformer = udf(func)

dataset.withColumn($(topicDistributionCol),
transformer(DatasetUtils.columnToVector(dataset, getFeaturesCol)))
} else {
logWarning("LDAModel.transform was called without any output columns. Set an output column" +
" such as topicDistributionCol to produce results.")
dataset.toDF()
}
val func = getTopicDistributionMethod
val transformer = udf(func)
dataset.withColumn($(topicDistributionCol),
transformer(DatasetUtils.columnToVector(dataset, getFeaturesCol)))
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ import org.apache.spark.mllib.linalg.VectorImplicits._
import org.apache.spark.mllib.stat.MultivariateOnlineSummarizer
import org.apache.spark.mllib.util.MLUtils
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.{DataFrame, Dataset, Row}
import org.apache.spark.sql.{Column, DataFrame, Dataset, Row}
import org.apache.spark.sql.functions._
import org.apache.spark.sql.types.{DoubleType, StructType}
import org.apache.spark.storage.StorageLevel
Expand Down Expand Up @@ -355,13 +355,28 @@ class AFTSurvivalRegressionModel private[ml] (
@Since("2.0.0")
override def transform(dataset: Dataset[_]): DataFrame = {
transformSchema(dataset.schema, logging = true)
val predictUDF = udf { features: Vector => predict(features) }
val predictQuantilesUDF = udf { features: Vector => predictQuantiles(features)}

var predictionColNames = Seq.empty[String]
var predictionColumns = Seq.empty[Column]

if ($(predictionCol).nonEmpty) {
val predictUDF = udf { features: Vector => predict(features) }
predictionColNames :+= $(predictionCol)
predictionColumns :+= predictUDF(col($(featuresCol)))
}

if (hasQuantilesCol) {
dataset.withColumn($(predictionCol), predictUDF(col($(featuresCol))))
.withColumn($(quantilesCol), predictQuantilesUDF(col($(featuresCol))))
val predictQuantilesUDF = udf { features: Vector => predictQuantiles(features)}
predictionColNames :+= $(quantilesCol)
predictionColumns :+= predictQuantilesUDF(col($(featuresCol)))
}

if (predictionColNames.nonEmpty) {
dataset.withColumns(predictionColNames, predictionColumns)
} else {
dataset.withColumn($(predictionCol), predictUDF(col($(featuresCol))))
this.logWarning(s"$uid: AFTSurvivalRegressionModel.transform() does nothing" +
" because no output columns were set.")
dataset.toDF()
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ import org.apache.spark.ml.util.Instrumentation.instrumented
import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo, Strategy => OldStrategy}
import org.apache.spark.mllib.tree.model.{DecisionTreeModel => OldDecisionTreeModel}
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.{DataFrame, Dataset, Row}
import org.apache.spark.sql.{Column, DataFrame, Dataset, Row}
import org.apache.spark.sql.functions._
import org.apache.spark.sql.types.DoubleType

Expand Down Expand Up @@ -216,16 +216,28 @@ class DecisionTreeRegressionModel private[ml] (
}

override protected def transformImpl(dataset: Dataset[_]): DataFrame = {
val predictUDF = udf { (features: Vector) => predict(features) }
val predictVarianceUDF = udf { (features: Vector) => predictVariance(features) }
var output = dataset.toDF()
var predictionColNames = Seq.empty[String]
var predictionColumns = Seq.empty[Column]

if ($(predictionCol).nonEmpty) {
output = output.withColumn($(predictionCol), predictUDF(col($(featuresCol))))
val predictUDF = udf { (features: Vector) => predict(features) }
predictionColNames :+= $(predictionCol)
predictionColumns :+= predictUDF(col($(featuresCol)))
}

if (isDefined(varianceCol) && $(varianceCol).nonEmpty) {
output = output.withColumn($(varianceCol), predictVarianceUDF(col($(featuresCol))))
val predictVarianceUDF = udf { (features: Vector) => predictVariance(features) }
predictionColNames :+= $(varianceCol)
predictionColumns :+= predictVarianceUDF(col($(featuresCol)))
}

if (predictionColNames.nonEmpty) {
dataset.withColumns(predictionColNames, predictionColumns)
} else {
this.logWarning(s"$uid: DecisionTreeRegressionModel.transform() does nothing" +
" because no output columns were set.")
dataset.toDF()
}
output
}

@Since("1.4.0")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1041,18 +1041,31 @@ class GeneralizedLinearRegressionModel private[ml] (
}

override protected def transformImpl(dataset: Dataset[_]): DataFrame = {
val predictUDF = udf { (features: Vector, offset: Double) => predict(features, offset) }
val predictLinkUDF = udf { (features: Vector, offset: Double) => predictLink(features, offset) }
var predictionColNames = Seq.empty[String]
var predictionColumns = Seq.empty[Column]

val offset = if (!hasOffsetCol) lit(0.0) else col($(offsetCol)).cast(DoubleType)
var output = dataset

if ($(predictionCol).nonEmpty) {
output = output.withColumn($(predictionCol), predictUDF(col($(featuresCol)), offset))
val predictUDF = udf { (features: Vector, offset: Double) => predict(features, offset) }
predictionColNames :+= $(predictionCol)
predictionColumns :+= predictUDF(col($(featuresCol)), offset)
}

if (hasLinkPredictionCol) {
output = output.withColumn($(linkPredictionCol), predictLinkUDF(col($(featuresCol)), offset))
val predictLinkUDF =
udf { (features: Vector, offset: Double) => predictLink(features, offset) }
predictionColNames :+= $(linkPredictionCol)
predictionColumns :+= predictLinkUDF(col($(featuresCol)), offset)
}

if (predictionColNames.nonEmpty) {
dataset.withColumns(predictionColNames, predictionColumns)
} else {
this.logWarning(s"$uid: GeneralizedLinearRegressionModel.transform() does nothing" +
" because no output columns were set.")
dataset.toDF()
}
output.toDF()
}

/**
Expand Down