Skip to content

Commit 7e92791

Browse files
cloud-fan authored and CodingCat committed
[SPARK-9929] [SQL] support metadata in withColumn
In MLlib we sometimes need to set metadata for a new column, so we alias the new column with metadata before calling `withColumn`, and inside `withColumn` this column is aliased again. Here I overloaded `withColumn` to allow users to set metadata, just like what we did for `Column.as`. Author: Wenchen Fan <[email protected]> Closes apache#8159 from cloud-fan/withColumn.
1 parent 9463ebd commit 7e92791

5 files changed

Lines changed: 23 additions & 7 deletions

File tree

mllib/src/main/scala/org/apache/spark/ml/classification/OneVsRest.scala

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -131,7 +131,7 @@ final class OneVsRestModel private[ml] (
131131

132132
// output label and label metadata as prediction
133133
aggregatedDataset
134-
.withColumn($(predictionCol), labelUDF(col(accColName)).as($(predictionCol), labelMetadata))
134+
.withColumn($(predictionCol), labelUDF(col(accColName)), labelMetadata)
135135
.drop(accColName)
136136
}
137137

@@ -203,8 +203,8 @@ final class OneVsRest(override val uid: String)
203203
// TODO: use when ... otherwise after SPARK-7321 is merged
204204
val newLabelMeta = BinaryAttribute.defaultAttr.withName("label").toMetadata()
205205
val labelColName = "mc2b$" + index
206-
val labelUDFWithNewMeta = labelUDF(col($(labelCol))).as(labelColName, newLabelMeta)
207-
val trainingDataset = multiclassLabeled.withColumn(labelColName, labelUDFWithNewMeta)
206+
val trainingDataset =
207+
multiclassLabeled.withColumn(labelColName, labelUDF(col($(labelCol))), newLabelMeta)
208208
val classifier = getClassifier
209209
val paramMap = new ParamMap()
210210
paramMap.put(classifier.labelCol -> labelColName)

mllib/src/main/scala/org/apache/spark/ml/feature/Bucketizer.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -75,7 +75,7 @@ final class Bucketizer(override val uid: String)
7575
}
7676
val newCol = bucketizer(dataset($(inputCol)))
7777
val newField = prepOutputField(dataset.schema)
78-
dataset.withColumn($(outputCol), newCol.as($(outputCol), newField.metadata))
78+
dataset.withColumn($(outputCol), newCol, newField.metadata)
7979
}
8080

8181
private def prepOutputField(schema: StructType): StructField = {

mllib/src/main/scala/org/apache/spark/ml/feature/VectorIndexer.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -341,7 +341,7 @@ class VectorIndexerModel private[ml] (
341341
val newField = prepOutputField(dataset.schema)
342342
val transformUDF = udf { (vector: Vector) => transformFunc(vector) }
343343
val newCol = transformUDF(dataset($(inputCol)))
344-
dataset.withColumn($(outputCol), newCol.as($(outputCol), newField.metadata))
344+
dataset.withColumn($(outputCol), newCol, newField.metadata)
345345
}
346346

347347
override def transformSchema(schema: StructType): StructType = {

mllib/src/main/scala/org/apache/spark/ml/feature/VectorSlicer.scala

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -119,8 +119,7 @@ final class VectorSlicer(override val uid: String)
119119
case features: SparseVector => features.slice(inds)
120120
}
121121
}
122-
dataset.withColumn($(outputCol),
123-
slicer(dataset($(inputCol))).as($(outputCol), outputAttr.toMetadata()))
122+
dataset.withColumn($(outputCol), slicer(dataset($(inputCol))), outputAttr.toMetadata())
124123
}
125124

126125
/** Get the feature indices in order: indices, names */

sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1149,6 +1149,23 @@ class DataFrame private[sql](
11491149
}
11501150
}
11511151

1152+
/**
1153+
* Returns a new [[DataFrame]] by adding a column with metadata.
1154+
*/
1155+
private[spark] def withColumn(colName: String, col: Column, metadata: Metadata): DataFrame = {
1156+
val resolver = sqlContext.analyzer.resolver
1157+
val replaced = schema.exists(f => resolver(f.name, colName))
1158+
if (replaced) {
1159+
val colNames = schema.map { field =>
1160+
val name = field.name
1161+
if (resolver(name, colName)) col.as(colName, metadata) else Column(name)
1162+
}
1163+
select(colNames : _*)
1164+
} else {
1165+
select(Column("*"), col.as(colName, metadata))
1166+
}
1167+
}
1168+
11521169
/**
11531170
* Returns a new [[DataFrame]] with a column renamed.
11541171
* This is a no-op if schema doesn't contain existingName.

0 commit comments

Comments
 (0)