@@ -191,6 +191,7 @@ final class OneVsRestModel private[ml] (
191191 val updateUDF = udf { (predictions : Map [Int , Double ], prediction : Vector ) =>
192192 predictions + ((index, prediction(1 )))
193193 }
194+
194195 model.setFeaturesCol($(featuresCol))
195196 val transformedDataset = model.transform(df).select(columns : _* )
196197 val updatedDataset = transformedDataset
@@ -206,18 +207,31 @@ final class OneVsRestModel private[ml] (
206207 }
207208
208209 // output the RawPrediction as vector
209- val rawPredictionUDF = udf { (predictions : Map [Int , Double ]) =>
210- Vectors .sparse(numClasses, predictions.toList )
211- }
210+ if (getRawPredictionCol != " " ) {
211+ val rawPredictionUDF = udf { (predictions : Map [Int , Double ]) =>
212+ val myArray = Array .fill[Double ](numClasses)(0.0 )
213+ predictions.foreach { case (idx, value) => myArray(idx) = value }
214+ Vectors .dense(myArray)
215+ }
212216
213- // output the index of the classifier with highest confidence as prediction
214- val labelUDF = udf { (predictions : Vector ) => predictions.argmax.toDouble }
217+ // output the index of the classifier with highest confidence as prediction
218+ val labelUDF = udf { (predictions : Vector ) => predictions.argmax.toDouble }
215219
216- // output confidence as rwa prediction, label and label metadata as prediction
217- aggregatedDataset
218- .withColumn(getRawPredictionCol, rawPredictionUDF(col(accColName)))
219- .withColumn(getPredictionCol, labelUDF(col(getRawPredictionCol)), labelMetadata)
220- .drop(accColName)
220+ aggregatedDataset
221+ .withColumn(getRawPredictionCol, rawPredictionUDF(col(accColName)))
222+ .withColumn(getPredictionCol, labelUDF(col(getRawPredictionCol)), labelMetadata)
223+ .drop(accColName)
224+ }
225+ else {
226+ // output the index of the classifier with highest confidence as prediction
227+ val labelUDF = udf { (predictions : Map [Int , Double ]) =>
228+ predictions.maxBy(_._2)._1.toDouble
229+ }
230+ // output label and label metadata as prediction
231+ aggregatedDataset
232+ .withColumn(getPredictionCol, labelUDF(col(accColName)), labelMetadata)
233+ .drop(accColName)
234+ }
221235 }
222236
223237 @ Since (" 1.4.1" )
0 commit comments