@@ -131,20 +131,20 @@ final class DataFrameNaFunctions private[sql](df: DataFrame) {
131131 *
132132 * @since 2.2.0
133133 */
134- def fill (value : Long ): DataFrame = fill (value, df.columns )
134+ def fill (value : Long ): DataFrame = fillValue (value, outputAttributes )
135135
136136 /**
137137 * Returns a new `DataFrame` that replaces null or NaN values in numeric columns with `value`.
138138 * @since 1.3.1
139139 */
140- def fill (value : Double ): DataFrame = fill (value, df.columns )
140+ def fill (value : Double ): DataFrame = fillValue (value, outputAttributes )
141141
142142 /**
143143 * Returns a new `DataFrame` that replaces null values in string columns with `value`.
144144 *
145145 * @since 1.3.1
146146 */
147- def fill (value : String ): DataFrame = fill (value, df.columns )
147+ def fill (value : String ): DataFrame = fillValue (value, outputAttributes )
148148
149149 /**
150150 * Returns a new `DataFrame` that replaces null or NaN values in specified numeric columns.
@@ -168,15 +168,15 @@ final class DataFrameNaFunctions private[sql](df: DataFrame) {
168168 *
169169 * @since 2.2.0
170170 */
171- def fill (value : Long , cols : Seq [String ]): DataFrame = fillValue(value, cols)
171+ def fill (value : Long , cols : Seq [String ]): DataFrame = fillValue(value, toAttributes( cols) )
172172
173173 /**
174174 * (Scala-specific) Returns a new `DataFrame` that replaces null or NaN values in specified
175175 * numeric columns. If a specified column is not a numeric column, it is ignored.
176176 *
177177 * @since 1.3.1
178178 */
179- def fill (value : Double , cols : Seq [String ]): DataFrame = fillValue(value, cols)
179+ def fill (value : Double , cols : Seq [String ]): DataFrame = fillValue(value, toAttributes( cols) )
180180
181181
182182 /**
@@ -193,22 +193,22 @@ final class DataFrameNaFunctions private[sql](df: DataFrame) {
193193 *
194194 * @since 1.3.1
195195 */
196- def fill (value : String , cols : Seq [String ]): DataFrame = fillValue(value, cols)
196+ def fill (value : String , cols : Seq [String ]): DataFrame = fillValue(value, toAttributes( cols) )
197197
198198 /**
199199 * Returns a new `DataFrame` that replaces null values in boolean columns with `value`.
200200 *
201201 * @since 2.3.0
202202 */
203- def fill (value : Boolean ): DataFrame = fill (value, df.columns )
203+ def fill (value : Boolean ): DataFrame = fillValue (value, outputAttributes )
204204
205205 /**
206206 * (Scala-specific) Returns a new `DataFrame` that replaces null values in specified
207207 * boolean columns. If a specified column is not a boolean column, it is ignored.
208208 *
209209 * @since 2.3.0
210210 */
211- def fill (value : Boolean , cols : Seq [String ]): DataFrame = fillValue(value, cols)
211+ def fill (value : Boolean , cols : Seq [String ]): DataFrame = fillValue(value, toAttributes( cols) )
212212
213213 /**
214214 * Returns a new `DataFrame` that replaces null values in specified boolean columns.
@@ -434,15 +434,24 @@ final class DataFrameNaFunctions private[sql](df: DataFrame) {
434434
435435 /**
436436 * Returns a [[Column ]] expression that replaces null value in `col` with `replacement`.
437+ * It selects a column based on its name.
437438 */
438439 private def fillCol [T ](col : StructField , replacement : T ): Column = {
439440 val quotedColName = " `" + col.name + " `"
440- val colValue = col.dataType match {
441+ fillCol(col.dataType, col.name, df.col(quotedColName), replacement)
442+ }
443+
444+ /**
445+ * Returns a [[Column ]] expression that replaces null value in `expr` with `replacement`.
446+ * It uses the given `expr` as a column.
447+ */
448+ private def fillCol [T ](dataType : DataType , name : String , expr : Column , replacement : T ): Column = {
449+ val colValue = dataType match {
441450 case DoubleType | FloatType =>
442- nanvl(df.col(quotedColName) , lit(null )) // nanvl only supports these types
443- case _ => df.col(quotedColName)
451+ nanvl(expr , lit(null )) // nanvl only supports these types
452+ case _ => expr
444453 }
445- coalesce(colValue, lit(replacement).cast(col. dataType)).as(col. name)
454+ coalesce(colValue, lit(replacement).cast(dataType)).as(name)
446455 }
447456
448457 /**
@@ -469,12 +478,22 @@ final class DataFrameNaFunctions private[sql](df: DataFrame) {
469478 s " Unsupported value type ${v.getClass.getName} ( $v). " )
470479 }
471480
481+ private def toAttributes (cols : Seq [String ]): Seq [Attribute ] = {
482+ cols.map(name => df.col(name).expr).collect {
483+ case a : Attribute => a
484+ }
485+ }
486+
487+ private def outputAttributes : Seq [Attribute ] = {
488+ df.queryExecution.analyzed.output
489+ }
490+
472491 /**
473- * Returns a new `DataFrame` that replaces null or NaN values in specified
474- * numeric, string columns. If a specified column is not a numeric, string
475- * or boolean column it is ignored.
492+ * Returns a new `DataFrame` that replaces null or NaN values in the specified
493+ * columns. If a specified column is not a numeric, string or boolean column,
494+ * it is ignored.
476495 */
477- private def fillValue [T ](value : T , cols : Seq [String ]): DataFrame = {
496+ private def fillValue [T ](value : T , cols : Seq [Attribute ]): DataFrame = {
478497 // the fill[T] which T is Long/Double,
479498 // should apply on all the NumericType Column, for example:
480499 // val input = Seq[(java.lang.Integer, java.lang.Double)]((null, 164.3)).toDF("a","b")
@@ -488,20 +507,19 @@ final class DataFrameNaFunctions private[sql](df: DataFrame) {
488507 s " Unsupported value type ${value.getClass.getName} ( $value). " )
489508 }
490509
491- val columnEquals = df.sparkSession.sessionState.analyzer.resolver
492- val projections = df.schema.fields.map { f =>
493- val typeMatches = (targetType, f.dataType) match {
510+ val projections = outputAttributes.map { col =>
511+ val typeMatches = (targetType, col.dataType) match {
494512 case (NumericType , dt) => dt.isInstanceOf [NumericType ]
495513 case (StringType , dt) => dt == StringType
496514 case (BooleanType , dt) => dt == BooleanType
497515 case _ =>
498516 throw new IllegalArgumentException (s " $targetType is not matched at fillValue " )
499517 }
500518 // Only fill if the column is part of the cols list.
501- if (typeMatches && cols.exists(col => columnEquals(f.name, col))) {
502- fillCol[ T ](f , value)
519+ if (typeMatches && cols.exists(_.semanticEquals( col))) {
520+ fillCol(col.dataType, col.name, Column (col) , value)
503521 } else {
504- df.col(f.name )
522+ Column (col )
505523 }
506524 }
507525 df.select(projections : _* )
0 commit comments