Skip to content

Commit 36c5b8b

Browse files
committed
fix mllib: route column metadata through the new `withColumn(colName, col, metadata)` parameter instead of `Column.as(name, metadata)`
1 parent 4cd4fad commit 36c5b8b

6 files changed

Lines changed: 12 additions & 10 deletions

File tree

mllib/src/main/scala/org/apache/spark/ml/classification/OneVsRest.scala

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -131,7 +131,7 @@ final class OneVsRestModel private[ml] (

     // output label and label metadata as prediction
     aggregatedDataset
-      .withColumn($(predictionCol), labelUDF(col(accColName)).as($(predictionCol), labelMetadata))
+      .withColumn($(predictionCol), labelUDF(col(accColName)), Some(labelMetadata))
       .drop(accColName)
   }

@@ -203,8 +203,8 @@ final class OneVsRest(override val uid: String)
       // TODO: use when ... otherwise after SPARK-7321 is merged
       val newLabelMeta = BinaryAttribute.defaultAttr.withName("label").toMetadata()
       val labelColName = "mc2b$" + index
-      val labelUDFWithNewMeta = labelUDF(col($(labelCol))).as(labelColName, newLabelMeta)
-      val trainingDataset = multiclassLabeled.withColumn(labelColName, labelUDFWithNewMeta)
+      val trainingDataset = multiclassLabeled
+        .withColumn(labelColName, labelUDF(col($(labelCol))), Some(newLabelMeta))
       val classifier = getClassifier
       val paramMap = new ParamMap()
       paramMap.put(classifier.labelCol -> labelColName)

mllib/src/main/scala/org/apache/spark/ml/feature/Bucketizer.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -75,7 +75,7 @@ final class Bucketizer(override val uid: String)
     }
     val newCol = bucketizer(dataset($(inputCol)))
     val newField = prepOutputField(dataset.schema)
-    dataset.withColumn($(outputCol), newCol.as($(outputCol), newField.metadata))
+    dataset.withColumn($(outputCol), newCol, Some(newField.metadata))
   }

   private def prepOutputField(schema: StructType): StructField = {

mllib/src/main/scala/org/apache/spark/ml/feature/VectorIndexer.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -341,7 +341,7 @@ class VectorIndexerModel private[ml] (
     val newField = prepOutputField(dataset.schema)
     val transformUDF = udf { (vector: Vector) => transformFunc(vector) }
     val newCol = transformUDF(dataset($(inputCol)))
-    dataset.withColumn($(outputCol), newCol.as($(outputCol), newField.metadata))
+    dataset.withColumn($(outputCol), newCol, Some(newField.metadata))
   }

   override def transformSchema(schema: StructType): StructType = {

mllib/src/main/scala/org/apache/spark/ml/feature/VectorSlicer.scala

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -119,8 +119,7 @@ final class VectorSlicer(override val uid: String)
         case features: SparseVector => features.slice(inds)
       }
     }
-    dataset.withColumn($(outputCol),
-      slicer(dataset($(inputCol))).as($(outputCol), outputAttr.toMetadata()))
+    dataset.withColumn($(outputCol), slicer(dataset($(inputCol))), Some(outputAttr.toMetadata()))
   }

   /** Get the feature indices in order: indices, names */

sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -133,5 +133,7 @@ class AnalysisSuite extends AnalysisTest {
     // CreateStruct is a special case that we should not trim Alias for it.
     plan = testRelation.select(CreateStruct(Seq(a, (a + 1).as("a+1"))).as("col"))
     checkAnalysis(plan, plan)
+    plan = testRelation.select(CreateStructUnsafe(Seq(a, (a + 1).as("a+1"))).as("col"))
+    checkAnalysis(plan, plan)
   }
 }

sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1135,17 +1135,18 @@ class DataFrame private[sql](
    * @group dfops
    * @since 1.3.0
    */
-  def withColumn(colName: String, col: Column): DataFrame = {
+  def withColumn(colName: String, col: Column, metadata: Option[Metadata] = None): DataFrame = {
     val resolver = sqlContext.analyzer.resolver
     val replaced = schema.exists(f => resolver(f.name, colName))
+    val aliasedColumn = metadata.map(md => col.as(colName, md)).getOrElse(col.as(colName))
     if (replaced) {
       val colNames = schema.map { field =>
         val name = field.name
-        if (resolver(name, colName)) col.as(colName) else Column(name)
+        if (resolver(name, colName)) aliasedColumn else Column(name)
       }
       select(colNames : _*)
     } else {
-      select(Column("*"), col.as(colName))
+      select(Column("*"), aliasedColumn)
     }
   }

0 commit comments

Comments
 (0)