diff --git a/mllib/src/main/scala/org/apache/spark/ml/Predictor.scala b/mllib/src/main/scala/org/apache/spark/ml/Predictor.scala index d8f3dfa87443..58815434cbda 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/Predictor.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/Predictor.scala @@ -204,8 +204,8 @@ abstract class PredictionModel[FeaturesType, M <: PredictionModel[FeaturesType, if ($(predictionCol).nonEmpty) { transformImpl(dataset) } else { - this.logWarning(s"$uid: Predictor.transform() was called as NOOP" + - " since no output columns were set.") + this.logWarning(s"$uid: Predictor.transform() does nothing" + + " because no output columns were set.") dataset.toDF } } diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/Classifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/Classifier.scala index e35e6ce7fdad..568cdd11a12a 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/Classifier.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/Classifier.scala @@ -18,7 +18,7 @@ package org.apache.spark.ml.classification import org.apache.spark.SparkException -import org.apache.spark.annotation.{DeveloperApi, Since} +import org.apache.spark.annotation.DeveloperApi import org.apache.spark.ml.{PredictionModel, Predictor, PredictorParams} import org.apache.spark.ml.feature.LabeledPoint import org.apache.spark.ml.linalg.{Vector, VectorUDT} @@ -204,8 +204,8 @@ abstract class ClassificationModel[FeaturesType, M <: ClassificationModel[Featur } if (numColsOutput == 0) { - logWarning(s"$uid: ClassificationModel.transform() was called as NOOP" + - " since no output columns were set.") + logWarning(s"$uid: ClassificationModel.transform() does nothing" + + " because no output columns were set.") } outputData.toDF } diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/OneVsRest.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/OneVsRest.scala index e1fceb1fc96a..675315e3bb07 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/OneVsRest.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/OneVsRest.scala @@ -169,9 +169,9 @@ final class OneVsRestModel private[ml] ( // Check schema transformSchema(dataset.schema, logging = true) - if (getPredictionCol == "" && getRawPredictionCol == "") { - logWarning(s"$uid: OneVsRestModel.transform() was called as NOOP" + - " since no output columns were set.") + if (getPredictionCol.isEmpty && getRawPredictionCol.isEmpty) { + logWarning(s"$uid: OneVsRestModel.transform() does nothing" + + " because no output columns were set.") return dataset.toDF } @@ -218,7 +218,7 @@ final class OneVsRestModel private[ml] ( var predictionColNames = Seq.empty[String] var predictionColumns = Seq.empty[Column] - if (getRawPredictionCol != "") { + if (getRawPredictionCol.nonEmpty) { val numClass = models.length // output the RawPrediction as vector @@ -228,18 +228,18 @@ final class OneVsRestModel private[ml] ( Vectors.dense(predArray) } - predictionColNames = predictionColNames :+ getRawPredictionCol - predictionColumns = predictionColumns :+ rawPredictionUDF(col(accColName)) + predictionColNames :+= getRawPredictionCol + predictionColumns :+= rawPredictionUDF(col(accColName)) } - if (getPredictionCol != "") { + if (getPredictionCol.nonEmpty) { // output the index of the classifier with highest confidence as prediction val labelUDF = udf { (predictions: Map[Int, Double]) => predictions.maxBy(_._2)._1.toDouble } - predictionColNames = predictionColNames :+ getPredictionCol - predictionColumns = predictionColumns :+ labelUDF(col(accColName)) + predictionColNames :+= getPredictionCol + predictionColumns :+= labelUDF(col(accColName)) .as(getPredictionCol, labelMetadata) } diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/ProbabilisticClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/ProbabilisticClassifier.scala index 730fcab333e1..5046caa568d5 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/ProbabilisticClassifier.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/ProbabilisticClassifier.scala @@ -147,8 +147,8 @@ abstract class ProbabilisticClassificationModel[ } if (numColsOutput == 0) { - this.logWarning(s"$uid: ProbabilisticClassificationModel.transform() was called as NOOP" + - " since no output columns were set.") + this.logWarning(s"$uid: ProbabilisticClassificationModel.transform() does nothing" + + " because no output columns were set.") } outputData.toDF } diff --git a/mllib/src/main/scala/org/apache/spark/ml/clustering/GaussianMixture.scala b/mllib/src/main/scala/org/apache/spark/ml/clustering/GaussianMixture.scala index fb4698ab5564..9a51d2f18846 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/clustering/GaussianMixture.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/clustering/GaussianMixture.scala @@ -33,7 +33,7 @@ import org.apache.spark.ml.util.Instrumentation.instrumented import org.apache.spark.mllib.linalg.{Matrices => OldMatrices, Matrix => OldMatrix, Vector => OldVector, Vectors => OldVectors} import org.apache.spark.rdd.RDD -import org.apache.spark.sql.{DataFrame, Dataset, Row, SparkSession} +import org.apache.spark.sql.{Column, DataFrame, Dataset, Row, SparkSession} import org.apache.spark.sql.functions.udf import org.apache.spark.sql.types.{IntegerType, StructType} import org.apache.spark.storage.StorageLevel @@ -110,11 +110,29 @@ class GaussianMixtureModel private[ml] ( @Since("2.0.0") override def transform(dataset: Dataset[_]): DataFrame = { transformSchema(dataset.schema, logging = true) - val predUDF = udf((vector: Vector) => predict(vector)) - val probUDF = udf((vector: Vector) => predictProbability(vector)) - dataset - .withColumn($(predictionCol), predUDF(DatasetUtils.columnToVector(dataset, getFeaturesCol))) - .withColumn($(probabilityCol), probUDF(DatasetUtils.columnToVector(dataset, getFeaturesCol))) + + var predictionColNames = Seq.empty[String] + var predictionColumns = Seq.empty[Column] + + if ($(predictionCol).nonEmpty) { + val predUDF = udf((vector: Vector) => predict(vector)) + predictionColNames :+= $(predictionCol) + predictionColumns :+= predUDF(DatasetUtils.columnToVector(dataset, getFeaturesCol)) + } + + if ($(probabilityCol).nonEmpty) { + val probUDF = udf((vector: Vector) => predictProbability(vector)) + predictionColNames :+= $(probabilityCol) + predictionColumns :+= probUDF(DatasetUtils.columnToVector(dataset, getFeaturesCol)) + } + + if (predictionColNames.nonEmpty) { + dataset.withColumns(predictionColNames, predictionColumns) + } else { + this.logWarning(s"$uid: GaussianMixtureModel.transform() does nothing" + + " because no output columns were set.") + dataset.toDF() + } } @Since("2.0.0") diff --git a/mllib/src/main/scala/org/apache/spark/ml/clustering/LDA.scala b/mllib/src/main/scala/org/apache/spark/ml/clustering/LDA.scala index 681bb9515618..91201e7bd03f 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/clustering/LDA.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/clustering/LDA.scala @@ -461,17 +461,10 @@ abstract class LDAModel private[ml] ( override def transform(dataset: Dataset[_]): DataFrame = { transformSchema(dataset.schema, logging = true) - if ($(topicDistributionCol).nonEmpty) { - val func = getTopicDistributionMethod - val transformer = udf(func) - - dataset.withColumn($(topicDistributionCol), - transformer(DatasetUtils.columnToVector(dataset, getFeaturesCol))) - } else { - logWarning("LDAModel.transform was called without any output columns. Set an output column" + - " such as topicDistributionCol to produce results.") - dataset.toDF() - } + val func = getTopicDistributionMethod + val transformer = udf(func) + dataset.withColumn($(topicDistributionCol), + transformer(DatasetUtils.columnToVector(dataset, getFeaturesCol))) } /** diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/AFTSurvivalRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/AFTSurvivalRegression.scala index 067dfa43433e..1565782dd631 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/AFTSurvivalRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/AFTSurvivalRegression.scala @@ -37,7 +37,7 @@ import org.apache.spark.mllib.linalg.VectorImplicits._ import org.apache.spark.mllib.stat.MultivariateOnlineSummarizer import org.apache.spark.mllib.util.MLUtils import org.apache.spark.rdd.RDD -import org.apache.spark.sql.{DataFrame, Dataset, Row} +import org.apache.spark.sql.{Column, DataFrame, Dataset, Row} import org.apache.spark.sql.functions._ import org.apache.spark.sql.types.{DoubleType, StructType} import org.apache.spark.storage.StorageLevel @@ -355,13 +355,28 @@ class AFTSurvivalRegressionModel private[ml] ( @Since("2.0.0") override def transform(dataset: Dataset[_]): DataFrame = { transformSchema(dataset.schema, logging = true) - val predictUDF = udf { features: Vector => predict(features) } - val predictQuantilesUDF = udf { features: Vector => predictQuantiles(features)} + + var predictionColNames = Seq.empty[String] + var predictionColumns = Seq.empty[Column] + + if ($(predictionCol).nonEmpty) { + val predictUDF = udf { features: Vector => predict(features) } + predictionColNames :+= $(predictionCol) + predictionColumns :+= predictUDF(col($(featuresCol))) + } + if (hasQuantilesCol) { - dataset.withColumn($(predictionCol), predictUDF(col($(featuresCol)))) - .withColumn($(quantilesCol), predictQuantilesUDF(col($(featuresCol)))) + val predictQuantilesUDF = udf { features: Vector => predictQuantiles(features)} + predictionColNames :+= $(quantilesCol) + predictionColumns :+= predictQuantilesUDF(col($(featuresCol))) + } + + if (predictionColNames.nonEmpty) { + dataset.withColumns(predictionColNames, predictionColumns) } else { - dataset.withColumn($(predictionCol), predictUDF(col($(featuresCol)))) + this.logWarning(s"$uid: AFTSurvivalRegressionModel.transform() does nothing" + + " because no output columns were set.") + dataset.toDF() } } diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala index f4f4e56a3578..6348289de516 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala @@ -34,7 +34,7 @@ import org.apache.spark.ml.util.Instrumentation.instrumented import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo, Strategy => OldStrategy} import org.apache.spark.mllib.tree.model.{DecisionTreeModel => OldDecisionTreeModel} import org.apache.spark.rdd.RDD -import org.apache.spark.sql.{DataFrame, Dataset, Row} +import org.apache.spark.sql.{Column, DataFrame, Dataset, Row} import org.apache.spark.sql.functions._ import org.apache.spark.sql.types.DoubleType @@ -216,16 +216,28 @@ class DecisionTreeRegressionModel private[ml] ( } override protected def transformImpl(dataset: Dataset[_]): DataFrame = { - val predictUDF = udf { (features: Vector) => predict(features) } - val predictVarianceUDF = udf { (features: Vector) => predictVariance(features) } - var output = dataset.toDF() + var predictionColNames = Seq.empty[String] + var predictionColumns = Seq.empty[Column] + if ($(predictionCol).nonEmpty) { - output = output.withColumn($(predictionCol), predictUDF(col($(featuresCol)))) + val predictUDF = udf { (features: Vector) => predict(features) } + predictionColNames :+= $(predictionCol) + predictionColumns :+= predictUDF(col($(featuresCol))) } + if (isDefined(varianceCol) && $(varianceCol).nonEmpty) { - output = output.withColumn($(varianceCol), predictVarianceUDF(col($(featuresCol)))) + val predictVarianceUDF = udf { (features: Vector) => predictVariance(features) } + predictionColNames :+= $(varianceCol) + predictionColumns :+= predictVarianceUDF(col($(featuresCol))) + } + + if (predictionColNames.nonEmpty) { + dataset.withColumns(predictionColNames, predictionColumns) + } else { + this.logWarning(s"$uid: DecisionTreeRegressionModel.transform() does nothing" + + " because no output columns were set.") + dataset.toDF() } - output } @Since("1.4.0") diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala index 885b13bf8dac..b1a8f95c1261 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala @@ -1041,18 +1041,31 @@ class GeneralizedLinearRegressionModel private[ml] ( } override protected def transformImpl(dataset: Dataset[_]): DataFrame = { - val predictUDF = udf { (features: Vector, offset: Double) => predict(features, offset) } - val predictLinkUDF = udf { (features: Vector, offset: Double) => predictLink(features, offset) } + var predictionColNames = Seq.empty[String] + var predictionColumns = Seq.empty[Column] val offset = if (!hasOffsetCol) lit(0.0) else col($(offsetCol)).cast(DoubleType) - var output = dataset + if ($(predictionCol).nonEmpty) { - output = output.withColumn($(predictionCol), predictUDF(col($(featuresCol)), offset)) + val predictUDF = udf { (features: Vector, offset: Double) => predict(features, offset) } + predictionColNames :+= $(predictionCol) + predictionColumns :+= predictUDF(col($(featuresCol)), offset) } + if (hasLinkPredictionCol) { - output = output.withColumn($(linkPredictionCol), predictLinkUDF(col($(featuresCol)), offset)) + val predictLinkUDF = + udf { (features: Vector, offset: Double) => predictLink(features, offset) } + predictionColNames :+= $(linkPredictionCol) + predictionColumns :+= predictLinkUDF(col($(featuresCol)), offset) + } + + if (predictionColNames.nonEmpty) { + dataset.withColumns(predictionColNames, predictionColumns) + } else { + this.logWarning(s"$uid: GeneralizedLinearRegressionModel.transform() does nothing" + + " because no output columns were set.") + dataset.toDF() } - output.toDF() } /**