Skip to content
Closed
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -108,7 +108,7 @@ abstract class ClassificationModel[FeaturesType, M <: ClassificationModel[Featur
}
if (getPredictionCol != "") {
val predUDF = if (getRawPredictionCol != "") {
callUDF(raw2prediction _, DoubleType, col(getRawPredictionCol))
udf[Double, Vector](raw2prediction).apply(col(getRawPredictionCol))
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The type info is not necessary. Try

udf(raw2prediction _).apply(col(getRawPredictionCol))

The callUDF on line 106 needs some explicit type info since FeaturesType doesn't have a type tag. We can write this:

      val predictRawUDF = udf { (features: Any) =>
        predictRaw(features.asInstanceOf[FeaturesType])
      }
      outputData = outputData.withColumn(getRawPredictionCol,
        predictRawUDF(col(getFeaturesCol)))

} else {
callUDF(predict _, DoubleType, col(getFeaturesCol))
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please update this line as well.

}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,7 @@ final class OneVsRestModel private[ml] (
val accColName = "mbc$acc" + UUID.randomUUID().toString
val init: () => Map[Int, Double] = () => {Map()}
val mapType = MapType(IntegerType, DoubleType, valueContainsNull = false)
val newDataset = dataset.withColumn(accColName, callUDF(init, mapType))
val newDataset = dataset.withColumn(accColName, udf(init).apply())
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Style nit: I think we should call udf above, where init is defined, so that the invocation here just looks like init().


// persist if underlying dataset is not persistent.
val handlePersistence = dataset.rdd.getStorageLevel == StorageLevel.NONE
Expand All @@ -110,9 +110,9 @@ final class OneVsRestModel private[ml] (
(predictions: Map[Int, Double], prediction: Vector) => {
predictions + ((index, prediction(1)))
}
val updateUdf = callUDF(update, mapType, col(accColName), col(rawPredictionCol))
val updateUDF = callUDF(update, mapType, col(accColName), col(rawPredictionCol))
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Try this:

        val update = udf { (predictions: Map[Int, Double], prediction: Vector) =>
          predictions + ((index, prediction(1)))
        }
        val transformedDataset = model.transform(df).select(columns : _*)
        val updatedDataset = transformedDataset.withColumn(
          tmpColName, update(col(accColName), col(rawPredictionCol)))

val transformedDataset = model.transform(df).select(columns : _*)
val updatedDataset = transformedDataset.withColumn(tmpColName, updateUdf)
val updatedDataset = transformedDataset.withColumn(tmpColName, updateUDF)
val newColumns = origCols ++ List(col(tmpColName))

// switch out the intermediate column with the accumulator column
Expand All @@ -129,8 +129,8 @@ final class OneVsRestModel private[ml] (
}

// output label and label metadata as prediction
val labelUdf = callUDF(label, DoubleType, col(accColName))
aggregatedDataset.withColumn($(predictionCol), labelUdf.as($(predictionCol), labelMetadata))
val labelUDF = udf(label).apply(col(accColName))
aggregatedDataset.withColumn($(predictionCol), labelUDF.as($(predictionCol), labelMetadata))
.drop(accColName)
}

Expand Down Expand Up @@ -175,12 +175,12 @@ final class OneVsRest(override val uid: String)
}
val numClasses = MetadataUtils.getNumClasses(labelSchema).fold(computeNumClasses())(identity)

val multiclassLabeled = dataset.select($(labelCol), $(featuresCol))
val multiClassLabeled = dataset.select($(labelCol), $(featuresCol))
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

multiclass is a single word in ML vocab.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

My bad.


// persist if underlying dataset is not persistent.
val handlePersistence = dataset.rdd.getStorageLevel == StorageLevel.NONE
if (handlePersistence) {
multiclassLabeled.persist(StorageLevel.MEMORY_AND_DISK)
multiClassLabeled.persist(StorageLevel.MEMORY_AND_DISK)
}

// create k columns, one for each binary classifier.
Expand All @@ -192,17 +192,17 @@ final class OneVsRest(override val uid: String)

// generate new label metadata for the binary problem.
// TODO: use when ... otherwise after SPARK-7321 is merged
val labelUDF = callUDF(label, DoubleType, col($(labelCol)))
val labelUDF = udf(label).apply(col($(labelCol)))
val newLabelMeta = BinaryAttribute.defaultAttr.withName("label").toMetadata()
val labelColName = "mc2b$" + index
val labelUDFWithNewMeta = labelUDF.as(labelColName, newLabelMeta)
val trainingDataset = multiclassLabeled.withColumn(labelColName, labelUDFWithNewMeta)
val trainingDataset = multiClassLabeled.withColumn(labelColName, labelUDFWithNewMeta)
val classifier = getClassifier
classifier.fit(trainingDataset, classifier.labelCol -> labelColName)
}.toArray[ClassificationModel[_, _]]

if (handlePersistence) {
multiclassLabeled.unpersist()
multiClassLabeled.unpersist()
}

// extract label metadata from label column if present, or create a nominal attribute
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -113,9 +113,9 @@ private[spark] abstract class ProbabilisticClassificationModel[
}
if ($(predictionCol).nonEmpty) {
val predUDF = if ($(rawPredictionCol).nonEmpty) {
callUDF(raw2prediction _, DoubleType, col($(rawPredictionCol)))
udf[Double, Vector](raw2prediction).apply(col($(rawPredictionCol)))
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

See my previous comments about type info.

} else if ($(probabilityCol).nonEmpty) {
callUDF(probability2prediction _, DoubleType, col($(probabilityCol)))
udf[Double, Vector](probability2prediction).apply(col($(probabilityCol)))
} else {
callUDF(predict _, DoubleType, col($(featuresCol)))
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -301,15 +301,6 @@ class DataFrameSuite extends QueryTest {
)
}

// Regression test for the deprecated `callUdf` entry point: registers a
// squaring UDF by name on the DataFrame's SQLContext and checks that invoking
// it through `callUdf` yields the squared `value` column next to `id`.
test("deprecated callUdf in SQLContext") {
val input = Seq(("id1", 1), ("id2", 4), ("id3", 5)).toDF("id", "value")
// Register on the context that owns `input` so the name resolves at analysis time.
input.sqlContext.udf.register("simpleUdf", (v: Int) => v * v)
val squared = input.select($"id", callUdf("simpleUdf", $"value"))
val expected = Row("id1", 1) :: Row("id2", 16) :: Row("id3", 25) :: Nil
checkAnswer(squared, expected)
}

test("callUDF in SQLContext") {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We should keep this unit test even though callUdf is deprecated.

val df = Seq(("id1", 1), ("id2", 4), ("id3", 5)).toDF("id", "value")
val sqlctx = df.sqlContext
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -137,13 +137,12 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll with SQLTestUtils {

test("SPARK-7158 collect and take return different results") {
import java.util.UUID
import org.apache.spark.sql.types._

val df = Seq(Tuple1(1), Tuple1(2), Tuple1(3)).toDF("index")
// we expect the id to be materialized once
def id: () => String = () => { UUID.randomUUID().toString() }
val idUdf = udf(() => UUID.randomUUID().toString)

val dfWithId = df.withColumn("id", callUDF(id, StringType))
val dfWithId = df.withColumn("id", idUdf())
// Make a new DataFrame (actually the same reference to the old one)
val cached = dfWithId.cache()
// Trigger the cache
Expand Down