Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -131,7 +131,7 @@ final class OneVsRestModel private[ml] (

// output label and label metadata as prediction
aggregatedDataset
.withColumn($(predictionCol), labelUDF(col(accColName)).as($(predictionCol), labelMetadata))
.withColumn($(predictionCol), labelUDF(col(accColName)), labelMetadata)
.drop(accColName)
}

Expand Down Expand Up @@ -203,8 +203,8 @@ final class OneVsRest(override val uid: String)
// TODO: use when ... otherwise after SPARK-7321 is merged
val newLabelMeta = BinaryAttribute.defaultAttr.withName("label").toMetadata()
val labelColName = "mc2b$" + index
val labelUDFWithNewMeta = labelUDF(col($(labelCol))).as(labelColName, newLabelMeta)
val trainingDataset = multiclassLabeled.withColumn(labelColName, labelUDFWithNewMeta)
val trainingDataset =
multiclassLabeled.withColumn(labelColName, labelUDF(col($(labelCol))), newLabelMeta)
val classifier = getClassifier
val paramMap = new ParamMap()
paramMap.put(classifier.labelCol -> labelColName)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ final class Bucketizer(override val uid: String)
}
val newCol = bucketizer(dataset($(inputCol)))
val newField = prepOutputField(dataset.schema)
dataset.withColumn($(outputCol), newCol.as($(outputCol), newField.metadata))
dataset.withColumn($(outputCol), newCol, newField.metadata)
}

private def prepOutputField(schema: StructType): StructField = {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -341,7 +341,7 @@ class VectorIndexerModel private[ml] (
val newField = prepOutputField(dataset.schema)
val transformUDF = udf { (vector: Vector) => transformFunc(vector) }
val newCol = transformUDF(dataset($(inputCol)))
dataset.withColumn($(outputCol), newCol.as($(outputCol), newField.metadata))
dataset.withColumn($(outputCol), newCol, newField.metadata)
}

override def transformSchema(schema: StructType): StructType = {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -119,8 +119,7 @@ final class VectorSlicer(override val uid: String)
case features: SparseVector => features.slice(inds)
}
}
dataset.withColumn($(outputCol),
slicer(dataset($(inputCol))).as($(outputCol), outputAttr.toMetadata()))
dataset.withColumn($(outputCol), slicer(dataset($(inputCol))), outputAttr.toMetadata())
}

/** Get the feature indices in order: indices, names */
Expand Down
17 changes: 17 additions & 0 deletions sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala
Original file line number Diff line number Diff line change
Expand Up @@ -1149,6 +1149,23 @@ class DataFrame private[sql](
}
}

/**
 * Returns a new [[DataFrame]] by adding a column with metadata.
 *
 * If the schema already contains a column whose name matches `colName`
 * (per the analyzer's resolver, which honors case sensitivity settings),
 * that column is replaced in place, preserving the original column order.
 * Otherwise the new column is appended after all existing columns.
 *
 * @param colName  name of the new (or replaced) column
 * @param col      expression producing the column's values
 * @param metadata metadata attached to the column via `Column.as`
 */
private[spark] def withColumn(colName: String, col: Column, metadata: Metadata): DataFrame = {
  val resolve = sqlContext.analyzer.resolver
  if (schema.exists(field => resolve(field.name, colName))) {
    // Rebuild the projection, swapping in the new column where the name matches
    // and keeping every other column untouched and in its original position.
    val projection = schema.map { field =>
      if (resolve(field.name, colName)) col.as(colName, metadata) else Column(field.name)
    }
    select(projection : _*)
  } else {
    // No existing column matches: select everything plus the new aliased column.
    select(Column("*"), col.as(colName, metadata))
  }
}

/**
* Returns a new [[DataFrame]] with a column renamed.
* This is a no-op if schema doesn't contain existingName.
Expand Down