Skip to content

Commit 776a294

Browse files
committed
initial commit
1 parent eeef0e7 commit 776a294

File tree

2 files changed

+85
-22
lines changed

2 files changed

+85
-22
lines changed

sql/core/src/main/scala/org/apache/spark/sql/DataFrameNaFunctions.scala

Lines changed: 40 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -131,20 +131,20 @@ final class DataFrameNaFunctions private[sql](df: DataFrame) {
131131
*
132132
* @since 2.2.0
133133
*/
134-
def fill(value: Long): DataFrame = fill(value, df.columns)
134+
def fill(value: Long): DataFrame = fillValue(value, outputAttributes)
135135

136136
/**
137137
* Returns a new `DataFrame` that replaces null or NaN values in numeric columns with `value`.
138138
* @since 1.3.1
139139
*/
140-
def fill(value: Double): DataFrame = fill(value, df.columns)
140+
def fill(value: Double): DataFrame = fillValue(value, outputAttributes)
141141

142142
/**
143143
* Returns a new `DataFrame` that replaces null values in string columns with `value`.
144144
*
145145
* @since 1.3.1
146146
*/
147-
def fill(value: String): DataFrame = fill(value, df.columns)
147+
def fill(value: String): DataFrame = fillValue(value, outputAttributes)
148148

149149
/**
150150
* Returns a new `DataFrame` that replaces null or NaN values in specified numeric columns.
@@ -168,15 +168,15 @@ final class DataFrameNaFunctions private[sql](df: DataFrame) {
168168
*
169169
* @since 2.2.0
170170
*/
171-
def fill(value: Long, cols: Seq[String]): DataFrame = fillValue(value, cols)
171+
def fill(value: Long, cols: Seq[String]): DataFrame = fillValue(value, toAttributes(cols))
172172

173173
/**
174174
* (Scala-specific) Returns a new `DataFrame` that replaces null or NaN values in specified
175175
* numeric columns. If a specified column is not a numeric column, it is ignored.
176176
*
177177
* @since 1.3.1
178178
*/
179-
def fill(value: Double, cols: Seq[String]): DataFrame = fillValue(value, cols)
179+
def fill(value: Double, cols: Seq[String]): DataFrame = fillValue(value, toAttributes(cols))
180180

181181

182182
/**
@@ -193,22 +193,22 @@ final class DataFrameNaFunctions private[sql](df: DataFrame) {
193193
*
194194
* @since 1.3.1
195195
*/
196-
def fill(value: String, cols: Seq[String]): DataFrame = fillValue(value, cols)
196+
def fill(value: String, cols: Seq[String]): DataFrame = fillValue(value, toAttributes(cols))
197197

198198
/**
199199
* Returns a new `DataFrame` that replaces null values in boolean columns with `value`.
200200
*
201201
* @since 2.3.0
202202
*/
203-
def fill(value: Boolean): DataFrame = fill(value, df.columns)
203+
def fill(value: Boolean): DataFrame = fillValue(value, outputAttributes)
204204

205205
/**
206206
* (Scala-specific) Returns a new `DataFrame` that replaces null values in specified
207207
* boolean columns. If a specified column is not a boolean column, it is ignored.
208208
*
209209
* @since 2.3.0
210210
*/
211-
def fill(value: Boolean, cols: Seq[String]): DataFrame = fillValue(value, cols)
211+
def fill(value: Boolean, cols: Seq[String]): DataFrame = fillValue(value, toAttributes(cols))
212212

213213
/**
214214
* Returns a new `DataFrame` that replaces null values in specified boolean columns.
@@ -434,15 +434,24 @@ final class DataFrameNaFunctions private[sql](df: DataFrame) {
434434

435435
/**
436436
* Returns a [[Column]] expression that replaces null or NaN values in `col` with `replacement`.
437+
* It selects a column based on its name.
437438
*/
438439
private def fillCol[T](col: StructField, replacement: T): Column = {
439440
val quotedColName = "`" + col.name + "`"
440-
val colValue = col.dataType match {
441+
fillCol(col.dataType, col.name, df.col(quotedColName), replacement)
442+
}
443+
444+
/**
445+
* Returns a [[Column]] expression that replaces null or NaN values in `expr` with `replacement`.
446+
* It uses the given `expr` as a column.
447+
*/
448+
private def fillCol[T](dataType: DataType, name: String, expr: Column, replacement: T): Column = {
449+
val colValue = dataType match {
441450
case DoubleType | FloatType =>
442-
nanvl(df.col(quotedColName), lit(null)) // nanvl only supports these types
443-
case _ => df.col(quotedColName)
451+
nanvl(expr, lit(null)) // nanvl only supports these types
452+
case _ => expr
444453
}
445-
coalesce(colValue, lit(replacement).cast(col.dataType)).as(col.name)
454+
coalesce(colValue, lit(replacement).cast(dataType)).as(name)
446455
}
447456

448457
/**
@@ -469,12 +478,22 @@ final class DataFrameNaFunctions private[sql](df: DataFrame) {
469478
s"Unsupported value type ${v.getClass.getName} ($v).")
470479
}
471480

481+
private def toAttributes(cols: Seq[String]): Seq[Attribute] = {
482+
cols.map(name => df.col(name).expr).collect {
483+
case a: Attribute => a
484+
}
485+
}
486+
487+
private def outputAttributes: Seq[Attribute] = {
488+
df.queryExecution.analyzed.output
489+
}
490+
472491
/**
473-
* Returns a new `DataFrame` that replaces null or NaN values in specified
474-
* numeric, string columns. If a specified column is not a numeric, string
475-
* or boolean column it is ignored.
492+
* Returns a new `DataFrame` that replaces null or NaN values in the specified
493+
* columns. If a specified column is not a numeric, string or boolean column,
494+
* it is ignored.
476495
*/
477-
private def fillValue[T](value: T, cols: Seq[String]): DataFrame = {
496+
private def fillValue[T](value: T, cols: Seq[Attribute]): DataFrame = {
478497
// fill[T], where T is Long/Double,
479498
// should apply to all NumericType columns, for example:
480499
// val input = Seq[(java.lang.Integer, java.lang.Double)]((null, 164.3)).toDF("a","b")
@@ -488,20 +507,19 @@ final class DataFrameNaFunctions private[sql](df: DataFrame) {
488507
s"Unsupported value type ${value.getClass.getName} ($value).")
489508
}
490509

491-
val columnEquals = df.sparkSession.sessionState.analyzer.resolver
492-
val projections = df.schema.fields.map { f =>
493-
val typeMatches = (targetType, f.dataType) match {
510+
val projections = outputAttributes.map { col =>
511+
val typeMatches = (targetType, col.dataType) match {
494512
case (NumericType, dt) => dt.isInstanceOf[NumericType]
495513
case (StringType, dt) => dt == StringType
496514
case (BooleanType, dt) => dt == BooleanType
497515
case _ =>
498516
throw new IllegalArgumentException(s"$targetType is not matched at fillValue")
499517
}
500518
// Only fill if the column is part of the cols list.
501-
if (typeMatches && cols.exists(col => columnEquals(f.name, col))) {
502-
fillCol[T](f, value)
519+
if (typeMatches && cols.exists(_.semanticEquals(col))) {
520+
fillCol(col.dataType, col.name, Column(col), value)
503521
} else {
504-
df.col(f.name)
522+
Column(col)
505523
}
506524
}
507525
df.select(projections : _*)

sql/core/src/test/scala/org/apache/spark/sql/DataFrameNaFunctionsSuite.scala

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@ import scala.collection.JavaConverters._
2121

2222
import org.apache.spark.sql.internal.SQLConf
2323
import org.apache.spark.sql.test.SharedSQLContext
24+
import org.apache.spark.sql.types.{StringType, StructType}
2425

2526
class DataFrameNaFunctionsSuite extends QueryTest with SharedSQLContext {
2627
import testImplicits._
@@ -239,6 +240,33 @@ class DataFrameNaFunctionsSuite extends QueryTest with SharedSQLContext {
239240
}
240241
}
241242

243+
test("fill with col(*)") {
244+
val df = createDF()
245+
// If columns are specified with "*", they are ignored.
246+
checkAnswer(df.na.fill("new name", Seq("*")), df.collect())
247+
}
248+
249+
test("fill with nested columns") {
250+
val schema = new StructType()
251+
.add("c1", new StructType()
252+
.add("c1-1", StringType)
253+
.add("c1-2", StringType))
254+
255+
val data = Seq(
256+
Row(Row(null, "a2")),
257+
Row(Row("b1", "b2")),
258+
Row(null))
259+
260+
val df = spark.createDataFrame(
261+
spark.sparkContext.parallelize(data), schema)
262+
263+
checkAnswer(df.select("c1.c1-1"),
264+
Row(null) :: Row("b1") :: Row(null) :: Nil)
265+
266+
// Nested columns are ignored for fill().
267+
checkAnswer(df.na.fill("a1", Seq("c1.c1-1")), data)
268+
}
269+
242270
test("replace") {
243271
val input = createDF()
244272

@@ -349,4 +377,21 @@ class DataFrameNaFunctionsSuite extends QueryTest with SharedSQLContext {
349377
Row(0, 0L, 0.toShort, 0.toByte, Float.NaN, Double.NaN) ::
350378
Row(0, 0L, 0.toShort, 0.toByte, Float.NaN, Double.NaN) :: Nil)
351379
}
380+
381+
test("SPARK-29890: duplicate names are allowed for fill() if column names are not specified.") {
382+
val left = Seq(("1", null), ("3", "4")).toDF("col1", "col2")
383+
val right = Seq(("1", "2"), ("3", null)).toDF("col1", "col2")
384+
val df = left.join(right, Seq("col1"))
385+
386+
// If column names are specified, the following fails due to ambiguity.
387+
val exception = intercept[AnalysisException] {
388+
df.na.fill("hello", Seq("col2"))
389+
}
390+
assert(exception.getMessage.contains("Reference 'col2' is ambiguous"))
391+
392+
// If column names are not specified, fill() is applied to all the eligible columns.
393+
checkAnswer(
394+
df.na.fill("hello"),
395+
Row("1", "hello", "2") :: Row("3", "4", "hello") :: Nil)
396+
}
352397
}

0 commit comments

Comments
 (0)