File tree Expand file tree Collapse file tree
mllib/src/main/scala/org/apache/spark/ml
catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis
core/src/main/scala/org/apache/spark/sql Expand file tree Collapse file tree Original file line number Diff line number Diff line change @@ -131,7 +131,7 @@ final class OneVsRestModel private[ml] (
131131
132132 // output label and label metadata as prediction
133133 aggregatedDataset
134- .withColumn($(predictionCol), labelUDF(col(accColName)).as($(predictionCol), labelMetadata))
134+ .withColumn($(predictionCol), labelUDF(col(accColName)), Some ( labelMetadata))
135135 .drop(accColName)
136136 }
137137
@@ -203,8 +203,8 @@ final class OneVsRest(override val uid: String)
203203 // TODO: use when ... otherwise after SPARK-7321 is merged
204204 val newLabelMeta = BinaryAttribute .defaultAttr.withName(" label" ).toMetadata()
205205 val labelColName = " mc2b$" + index
206- val labelUDFWithNewMeta = labelUDF(col($(labelCol))).as(labelColName, newLabelMeta)
207- val trainingDataset = multiclassLabeled .withColumn(labelColName, labelUDFWithNewMeta )
206+ val trainingDataset = multiclassLabeled
207+ .withColumn(labelColName, labelUDF(col($(labelCol))), Some (newLabelMeta) )
208208 val classifier = getClassifier
209209 val paramMap = new ParamMap ()
210210 paramMap.put(classifier.labelCol -> labelColName)
Original file line number Diff line number Diff line change @@ -75,7 +75,7 @@ final class Bucketizer(override val uid: String)
7575 }
7676 val newCol = bucketizer(dataset($(inputCol)))
7777 val newField = prepOutputField(dataset.schema)
78- dataset.withColumn($(outputCol), newCol.as($(outputCol), newField.metadata))
78+ dataset.withColumn($(outputCol), newCol, Some ( newField.metadata))
7979 }
8080
8181 private def prepOutputField (schema : StructType ): StructField = {
Original file line number Diff line number Diff line change @@ -341,7 +341,7 @@ class VectorIndexerModel private[ml] (
341341 val newField = prepOutputField(dataset.schema)
342342 val transformUDF = udf { (vector : Vector ) => transformFunc(vector) }
343343 val newCol = transformUDF(dataset($(inputCol)))
344- dataset.withColumn($(outputCol), newCol.as($(outputCol), newField.metadata))
344+ dataset.withColumn($(outputCol), newCol, Some ( newField.metadata))
345345 }
346346
347347 override def transformSchema (schema : StructType ): StructType = {
Original file line number Diff line number Diff line change @@ -119,8 +119,7 @@ final class VectorSlicer(override val uid: String)
119119 case features : SparseVector => features.slice(inds)
120120 }
121121 }
122- dataset.withColumn($(outputCol),
123- slicer(dataset($(inputCol))).as($(outputCol), outputAttr.toMetadata()))
122+ dataset.withColumn($(outputCol), slicer(dataset($(inputCol))), Some (outputAttr.toMetadata()))
124123 }
125124
126125 /** Get the feature indices in order: indices, names */
Original file line number Diff line number Diff line change @@ -133,5 +133,7 @@ class AnalysisSuite extends AnalysisTest {
133133 // CreateStruct is a special case that we should not trim Alias for it.
134134 plan = testRelation.select(CreateStruct (Seq (a, (a + 1 ).as(" a+1" ))).as(" col" ))
135135 checkAnalysis(plan, plan)
136+ plan = testRelation.select(CreateStructUnsafe (Seq (a, (a + 1 ).as(" a+1" ))).as(" col" ))
137+ checkAnalysis(plan, plan)
136138 }
137139}
Original file line number Diff line number Diff line change @@ -1135,17 +1135,18 @@ class DataFrame private[sql](
11351135 * @group dfops
11361136 * @since 1.3.0
11371137 */
1138- def withColumn (colName : String , col : Column ): DataFrame = {
1138+ def withColumn (colName : String , col : Column , metadata : Option [ Metadata ] = None ): DataFrame = {
11391139 val resolver = sqlContext.analyzer.resolver
11401140 val replaced = schema.exists(f => resolver(f.name, colName))
1141+ val aliasedColumn = metadata.map(md => col.as(colName, md)).getOrElse(col.as(colName))
11411142 if (replaced) {
11421143 val colNames = schema.map { field =>
11431144 val name = field.name
1144- if (resolver(name, colName)) col.as(colName) else Column (name)
1145+ if (resolver(name, colName)) aliasedColumn else Column (name)
11451146 }
11461147 select(colNames : _* )
11471148 } else {
1148- select(Column (" *" ), col.as(colName) )
1149+ select(Column (" *" ), aliasedColumn )
11491150 }
11501151 }
11511152
You can’t perform that action at this time.
0 commit comments