-
Notifications
You must be signed in to change notification settings - Fork 29k
[SPARK-8575][SQL] Deprecate callUDF in favor of udf #6993
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 5 commits
477709f
84d6780
197ec82
0ea30b3
fe2a10b
bbdeaf3
49e4904
a672228
1305492
94345b5
8013409
0ebd0da
48ca15e
1ddb452
26f5a7a
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -108,7 +108,7 @@ abstract class ClassificationModel[FeaturesType, M <: ClassificationModel[Featur | |
| } | ||
| if (getPredictionCol != "") { | ||
| val predUDF = if (getRawPredictionCol != "") { | ||
| callUDF(raw2prediction _, DoubleType, col(getRawPredictionCol)) | ||
| udf[Double, Vector](raw2prediction).apply(col(getRawPredictionCol)) | ||
| } else { | ||
| callUDF(predict _, DoubleType, col(getFeaturesCol)) | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Please update this line as well. |
||
| } | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -90,7 +90,7 @@ final class OneVsRestModel private[ml] ( | |
| val accColName = "mbc$acc" + UUID.randomUUID().toString | ||
| val init: () => Map[Int, Double] = () => {Map()} | ||
| val mapType = MapType(IntegerType, DoubleType, valueContainsNull = false) | ||
| val newDataset = dataset.withColumn(accColName, callUDF(init, mapType)) | ||
| val newDataset = dataset.withColumn(accColName, udf(init).apply()) | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Style nit: I think we call |
||
|
|
||
| // persist if underlying dataset is not persistent. | ||
| val handlePersistence = dataset.rdd.getStorageLevel == StorageLevel.NONE | ||
|
|
@@ -110,9 +110,9 @@ final class OneVsRestModel private[ml] ( | |
| (predictions: Map[Int, Double], prediction: Vector) => { | ||
| predictions + ((index, prediction(1))) | ||
| } | ||
| val updateUdf = callUDF(update, mapType, col(accColName), col(rawPredictionCol)) | ||
| val updateUDF = callUDF(update, mapType, col(accColName), col(rawPredictionCol)) | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Try this:
val update = udf { (predictions: Map[Int, Double], prediction: Vector) =>
  predictions + ((index, prediction(1)))
}
val transformedDataset = model.transform(df).select(columns : _*)
val updatedDataset = transformedDataset.withColumn(
  tmpColName, update(col(accColName), col(rawPredictionCol))) |
||
| val transformedDataset = model.transform(df).select(columns : _*) | ||
| val updatedDataset = transformedDataset.withColumn(tmpColName, updateUdf) | ||
| val updatedDataset = transformedDataset.withColumn(tmpColName, updateUDF) | ||
| val newColumns = origCols ++ List(col(tmpColName)) | ||
|
|
||
| // switch out the intermediate column with the accumulator column | ||
|
|
@@ -129,8 +129,8 @@ final class OneVsRestModel private[ml] ( | |
| } | ||
|
|
||
| // output label and label metadata as prediction | ||
| val labelUdf = callUDF(label, DoubleType, col(accColName)) | ||
| aggregatedDataset.withColumn($(predictionCol), labelUdf.as($(predictionCol), labelMetadata)) | ||
| val labelUDF = udf(label).apply(col(accColName)) | ||
| aggregatedDataset.withColumn($(predictionCol), labelUDF.as($(predictionCol), labelMetadata)) | ||
| .drop(accColName) | ||
| } | ||
|
|
||
|
|
@@ -175,12 +175,12 @@ final class OneVsRest(override val uid: String) | |
| } | ||
| val numClasses = MetadataUtils.getNumClasses(labelSchema).fold(computeNumClasses())(identity) | ||
|
|
||
| val multiclassLabeled = dataset.select($(labelCol), $(featuresCol)) | ||
| val multiClassLabeled = dataset.select($(labelCol), $(featuresCol)) | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more.
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. My bad. |
||
|
|
||
| // persist if underlying dataset is not persistent. | ||
| val handlePersistence = dataset.rdd.getStorageLevel == StorageLevel.NONE | ||
| if (handlePersistence) { | ||
| multiclassLabeled.persist(StorageLevel.MEMORY_AND_DISK) | ||
| multiClassLabeled.persist(StorageLevel.MEMORY_AND_DISK) | ||
| } | ||
|
|
||
| // create k columns, one for each binary classifier. | ||
|
|
@@ -192,17 +192,17 @@ final class OneVsRest(override val uid: String) | |
|
|
||
| // generate new label metadata for the binary problem. | ||
| // TODO: use when ... otherwise after SPARK-7321 is merged | ||
| val labelUDF = callUDF(label, DoubleType, col($(labelCol))) | ||
| val labelUDF = udf(label).apply(col($(labelCol))) | ||
| val newLabelMeta = BinaryAttribute.defaultAttr.withName("label").toMetadata() | ||
| val labelColName = "mc2b$" + index | ||
| val labelUDFWithNewMeta = labelUDF.as(labelColName, newLabelMeta) | ||
| val trainingDataset = multiclassLabeled.withColumn(labelColName, labelUDFWithNewMeta) | ||
| val trainingDataset = multiClassLabeled.withColumn(labelColName, labelUDFWithNewMeta) | ||
| val classifier = getClassifier | ||
| classifier.fit(trainingDataset, classifier.labelCol -> labelColName) | ||
| }.toArray[ClassificationModel[_, _]] | ||
|
|
||
| if (handlePersistence) { | ||
| multiclassLabeled.unpersist() | ||
| multiClassLabeled.unpersist() | ||
| } | ||
|
|
||
| // extract label metadata from label column if present, or create a nominal attribute | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -113,9 +113,9 @@ private[spark] abstract class ProbabilisticClassificationModel[ | |
| } | ||
| if ($(predictionCol).nonEmpty) { | ||
| val predUDF = if ($(rawPredictionCol).nonEmpty) { | ||
| callUDF(raw2prediction _, DoubleType, col($(rawPredictionCol))) | ||
| udf[Double, Vector](raw2prediction).apply(col($(rawPredictionCol))) | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. See my previous comments about type info. |
||
| } else if ($(probabilityCol).nonEmpty) { | ||
| callUDF(probability2prediction _, DoubleType, col($(probabilityCol))) | ||
| udf[Double, Vector](probability2prediction).apply(col($(probabilityCol))) | ||
| } else { | ||
| callUDF(predict _, DoubleType, col($(featuresCol))) | ||
| } | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -301,15 +301,6 @@ class DataFrameSuite extends QueryTest { | |
| ) | ||
| } | ||
|
|
||
| test("deprecated callUdf in SQLContext") { | ||
| val df = Seq(("id1", 1), ("id2", 4), ("id3", 5)).toDF("id", "value") | ||
| val sqlctx = df.sqlContext | ||
| sqlctx.udf.register("simpleUdf", (v: Int) => v * v) | ||
| checkAnswer( | ||
| df.select($"id", callUdf("simpleUdf", $"value")), | ||
| Row("id1", 1) :: Row("id2", 16) :: Row("id3", 25) :: Nil) | ||
| } | ||
|
|
||
| test("callUDF in SQLContext") { | ||
|
||
| val df = Seq(("id1", 1), ("id2", 4), ("id3", 5)).toDF("id", "value") | ||
| val sqlctx = df.sqlContext | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The type info is not necessary. Try
The `callUDF` on line 106 needs some explicit type info since `FeaturesType` doesn't have a type tag. We can write this: