Skip to content

Commit 2a47e2b

Browse files
committed
make rawPrediction optionall
1 parent 0cfc20a commit 2a47e2b

2 files changed

Lines changed: 26 additions & 12 deletions

File tree

mllib/src/main/scala/org/apache/spark/ml/classification/OneVsRest.scala

Lines changed: 24 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -191,6 +191,7 @@ final class OneVsRestModel private[ml] (
191191
val updateUDF = udf { (predictions: Map[Int, Double], prediction: Vector) =>
192192
predictions + ((index, prediction(1)))
193193
}
194+
194195
model.setFeaturesCol($(featuresCol))
195196
val transformedDataset = model.transform(df).select(columns: _*)
196197
val updatedDataset = transformedDataset
@@ -206,18 +207,31 @@ final class OneVsRestModel private[ml] (
206207
}
207208

208209
// output the RawPrediction as vector
209-
val rawPredictionUDF = udf { (predictions: Map[Int, Double]) =>
210-
Vectors.sparse(numClasses, predictions.toList )
211-
}
210+
if (getRawPredictionCol != "") {
211+
val rawPredictionUDF = udf { (predictions: Map[Int, Double]) =>
212+
val myArray = Array.fill[Double](numClasses)(0.0)
213+
predictions.foreach { case (idx, value) => myArray(idx) = value }
214+
Vectors.dense(myArray)
215+
}
212216

213-
// output the index of the classifier with highest confidence as prediction
214-
val labelUDF = udf { (predictions: Vector) => predictions.argmax.toDouble }
217+
// output the index of the classifier with highest confidence as prediction
218+
val labelUDF = udf { (predictions: Vector) => predictions.argmax.toDouble }
215219

216-
// output confidence as rwa prediction, label and label metadata as prediction
217-
aggregatedDataset
218-
.withColumn(getRawPredictionCol, rawPredictionUDF(col(accColName)))
219-
.withColumn(getPredictionCol, labelUDF(col(getRawPredictionCol)), labelMetadata)
220-
.drop(accColName)
220+
aggregatedDataset
221+
.withColumn(getRawPredictionCol, rawPredictionUDF(col(accColName)))
222+
.withColumn(getPredictionCol, labelUDF(col(getRawPredictionCol)), labelMetadata)
223+
.drop(accColName)
224+
}
225+
else {
226+
// output the index of the classifier with highest confidence as prediction
227+
val labelUDF = udf { (predictions: Map[Int, Double]) =>
228+
predictions.maxBy(_._2)._1.toDouble
229+
}
230+
// output confidence as rwa prediction, label and label metadata as prediction
231+
aggregatedDataset
232+
.withColumn(getPredictionCol, labelUDF(col(accColName)), labelMetadata)
233+
.drop(accColName)
234+
}
221235
}
222236

223237
@Since("1.4.1")

mllib/src/test/scala/org/apache/spark/ml/classification/OneVsRestSuite.scala

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -180,10 +180,10 @@ class OneVsRestSuite extends MLTest with DefaultReadWriteTest {
180180
val dataset2 = dataset.select(col("label").as("y"), col("features").as("fea"))
181181
ovaModel.setFeaturesCol("fea")
182182
ovaModel.setPredictionCol("pred")
183-
ovaModel.setRawPredictionCol("rawpred")
183+
ovaModel.setRawPredictionCol("")
184184
val transformedDataset = ovaModel.transform(dataset2)
185185
val outputFields = transformedDataset.schema.fieldNames.toSet
186-
assert(outputFields === Set("y", "fea", "pred", "rawpred"))
186+
assert(outputFields === Set("y", "fea", "pred"))
187187
}
188188

189189
test("SPARK-8049: OneVsRest shouldn't output temp columns") {

0 commit comments

Comments
 (0)