File tree Expand file tree Collapse file tree
mllib/src/main/scala/org/apache/spark/ml
sql/core/src/main/scala/org/apache/spark/sql Expand file tree Collapse file tree Original file line number Diff line number Diff line change @@ -131,7 +131,7 @@ final class OneVsRestModel private[ml] (
131131
132132 // output label and label metadata as prediction
133133 aggregatedDataset
134- .withColumn($(predictionCol), labelUDF(col(accColName)).as($(predictionCol), labelMetadata))
134+ .withColumn($(predictionCol), labelUDF(col(accColName)), labelMetadata)
135135 .drop(accColName)
136136 }
137137
@@ -203,8 +203,8 @@ final class OneVsRest(override val uid: String)
203203 // TODO: use when ... otherwise after SPARK-7321 is merged
204204 val newLabelMeta = BinaryAttribute.defaultAttr.withName("label").toMetadata()
205205 val labelColName = "mc2b$" + index
206- val labelUDFWithNewMeta = labelUDF(col($(labelCol))).as(labelColName, newLabelMeta)
207- val trainingDataset = multiclassLabeled.withColumn(labelColName, labelUDFWithNewMeta)
206+ val trainingDataset =
207+ multiclassLabeled.withColumn(labelColName, labelUDF(col($(labelCol))), newLabelMeta)
208208 val classifier = getClassifier
209209 val paramMap = new ParamMap()
210210 paramMap.put(classifier.labelCol -> labelColName)
Original file line number Diff line number Diff line change @@ -75,7 +75,7 @@ final class Bucketizer(override val uid: String)
7575 }
7676 val newCol = bucketizer(dataset($(inputCol)))
7777 val newField = prepOutputField(dataset.schema)
78- dataset.withColumn($(outputCol), newCol.as($(outputCol), newField.metadata))
78+ dataset.withColumn($(outputCol), newCol, newField.metadata)
7979 }
8080
8181 private def prepOutputField(schema: StructType): StructField = {
Original file line number Diff line number Diff line change @@ -341,7 +341,7 @@ class VectorIndexerModel private[ml] (
341341 val newField = prepOutputField(dataset.schema)
342342 val transformUDF = udf { (vector: Vector) => transformFunc(vector) }
343343 val newCol = transformUDF(dataset($(inputCol)))
344- dataset.withColumn($(outputCol), newCol.as($(outputCol), newField.metadata))
344+ dataset.withColumn($(outputCol), newCol, newField.metadata)
345345 }
346346
347347 override def transformSchema(schema: StructType): StructType = {
Original file line number Diff line number Diff line change @@ -119,8 +119,7 @@ final class VectorSlicer(override val uid: String)
119119 case features: SparseVector => features.slice(inds)
120120 }
121121 }
122- dataset.withColumn($(outputCol),
123- slicer(dataset($(inputCol))).as($(outputCol), outputAttr.toMetadata()))
122+ dataset.withColumn($(outputCol), slicer(dataset($(inputCol))), outputAttr.toMetadata())
124123 }
125124
126125 /** Get the feature indices in order: indices, names */
Original file line number Diff line number Diff line change @@ -1149,6 +1149,23 @@ class DataFrame private[sql](
11491149 }
11501150 }
11511151
1152+ /**
1153+ * Returns a new [[DataFrame]] by adding a column with metadata.
1154+ */
1155+ private[spark] def withColumn(colName: String, col: Column, metadata: Metadata): DataFrame = {
1156+ val resolver = sqlContext.analyzer.resolver
1157+ val replaced = schema.exists(f => resolver(f.name, colName))
1158+ if (replaced) {
1159+ val colNames = schema.map { field =>
1160+ val name = field.name
1161+ if (resolver(name, colName)) col.as(colName, metadata) else Column(name)
1162+ }
1163+ select(colNames: _*)
1164+ } else {
1165+ select(Column("*"), col.as(colName, metadata))
1166+ }
1167+ }
1168+
1168+
11521169 /**
11531170 * Returns a new [[DataFrame]] with a column renamed.
11541171 * This is a no-op if schema doesn't contain existingName.
You can’t perform that action at this time.
0 commit comments