Commit d2e7deb

[SPARK-24867][SQL] Add AnalysisBarrier to DataFrameWriter
## What changes were proposed in this pull request?

```Scala
val udf1 = udf({(x: Int, y: Int) => x + y})
val df = spark.range(0, 3).toDF("a")
  .withColumn("b", udf1($"a", udf1($"a", lit(10))))
df.cache()
df.write.saveAsTable("t")
```

The cache is not used here because the write path's plan no longer matches the cached plan. This is a regression caused by the changes we made in AnalysisBarrier, since not all the Analyzer rules are idempotent.

## How was this patch tested?

Added a test.

Also found a bug in the DSV1 write path. It is not a regression, so a separate JIRA was opened: https://issues.apache.org/jira/browse/SPARK-24869

Author: Xiao Li <[email protected]>

Closes #21821 from gatorsmile/testMaster22.
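For context on the mechanism, here is a minimal sketch of the idea behind `AnalysisBarrier` (the real class in `org.apache.spark.sql.catalyst.plans.logical` carries additional overrides such as `innerChildren`; this is an illustration, not the Spark source):

```Scala
import org.apache.spark.sql.catalyst.expressions.Attribute
import org.apache.spark.sql.catalyst.plans.logical.{LeafNode, LogicalPlan}

// Minimal sketch of the barrier idea: modeled as a leaf node, so analyzer
// rules never descend into the wrapped child. Non-idempotent rules (e.g.
// HandleNullInputsForUDF) therefore cannot rewrite an already-analyzed plan
// a second time and break its equality with the cached copy.
case class AnalysisBarrier(child: LogicalPlan) extends LeafNode {
  override def output: Seq[Attribute] = child.output
}
```

CacheManager matches cached plans structurally, so any rule that rewrites the plan between `df.cache()` and the write command makes the lookup miss; the barrier prevents exactly that.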
3 files changed: 51 additions & 8 deletions

sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala

Lines changed: 6 additions & 4 deletions
```diff
@@ -254,7 +254,7 @@ final class DataFrameWriter[T] private[sql](ds: Dataset[T]) {
       val writer = ws.createWriter(jobId, df.logicalPlan.schema, mode, options)
       if (writer.isPresent) {
         runCommand(df.sparkSession, "save") {
-          WriteToDataSourceV2(writer.get(), df.logicalPlan)
+          WriteToDataSourceV2(writer.get(), df.planWithBarrier)
         }
       }
 
@@ -275,7 +275,7 @@ final class DataFrameWriter[T] private[sql](ds: Dataset[T]) {
         sparkSession = df.sparkSession,
         className = source,
         partitionColumns = partitioningColumns.getOrElse(Nil),
-        options = extraOptions.toMap).planForWriting(mode, AnalysisBarrier(df.logicalPlan))
+        options = extraOptions.toMap).planForWriting(mode, df.planWithBarrier)
     }
   }
 
@@ -323,7 +323,7 @@ final class DataFrameWriter[T] private[sql](ds: Dataset[T]) {
       InsertIntoTable(
         table = UnresolvedRelation(tableIdent),
         partition = Map.empty[String, Option[String]],
-        query = df.logicalPlan,
+        query = df.planWithBarrier,
         overwrite = mode == SaveMode.Overwrite,
         ifPartitionNotExists = false)
     }
 
@@ -459,7 +459,9 @@ final class DataFrameWriter[T] private[sql](ds: Dataset[T]) {
       partitionColumnNames = partitioningColumns.getOrElse(Nil),
       bucketSpec = getBucketSpec)
 
-    runCommand(df.sparkSession, "saveAsTable")(CreateTable(tableDesc, mode, Some(df.logicalPlan)))
+    runCommand(df.sparkSession, "saveAsTable") {
+      CreateTable(tableDesc, mode, Some(df.planWithBarrier))
+    }
   }
 
   /**
```
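All four call sites above switch from `df.logicalPlan` to `df.planWithBarrier`. As a sketch of what that field amounts to, approximating the `Dataset` internals this diff relies on (names and modifiers indicative only):

```Scala
import org.apache.spark.sql.catalyst.plans.logical.{AnalysisBarrier, LogicalPlan}

// Illustrative stand-in for the Dataset internals: the plan handed to the
// write path is the analyzed plan wrapped in a barrier, so the analyzer pass
// run by the write command leaves it equal to the plan that CacheManager
// indexed when df.cache() ran.
trait DatasetLike {
  def logicalPlan: LogicalPlan
  def planWithBarrier: LogicalPlan = AnalysisBarrier(logicalPlan)
}
```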

sql/core/src/main/scala/org/apache/spark/sql/execution/command/ddl.scala

Lines changed: 4 additions & 3 deletions
```diff
@@ -29,7 +29,7 @@ import org.apache.hadoop.mapred.{FileInputFormat, JobConf}
 
 import org.apache.spark.sql.{AnalysisException, Row, SparkSession}
 import org.apache.spark.sql.catalyst.TableIdentifier
-import org.apache.spark.sql.catalyst.analysis.{NoSuchTableException, Resolver}
+import org.apache.spark.sql.catalyst.analysis.{EliminateBarriers, NoSuchTableException, Resolver}
 import org.apache.spark.sql.catalyst.catalog._
 import org.apache.spark.sql.catalyst.catalog.CatalogTypes.TablePartitionSpec
 import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference}
 
@@ -891,8 +891,9 @@ object DDLUtils {
    * Throws exception if outputPath tries to overwrite inputpath.
    */
   def verifyNotReadPath(query: LogicalPlan, outputPath: Path) : Unit = {
-    val inputPaths = query.collect {
-      case LogicalRelation(r: HadoopFsRelation, _, _, _) => r.location.rootPaths
+    val inputPaths = EliminateBarriers(query).collect {
+      case LogicalRelation(r: HadoopFsRelation, _, _, _) =>
+        r.location.rootPaths
     }.flatten
 
     if (inputPaths.contains(outputPath)) {
```
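The `EliminateBarriers(query)` call is needed because a plan arriving here may now be wrapped in `AnalysisBarrier`, which presents itself as a leaf node: a plain `query.collect` would never reach the `LogicalRelation`s underneath it, and the overwrite check would silently pass. A sketch of what the rule does (the real rule lives in `org.apache.spark.sql.catalyst.analysis`):

```Scala
import org.apache.spark.sql.catalyst.plans.logical.{AnalysisBarrier, LogicalPlan}
import org.apache.spark.sql.catalyst.rules.Rule

// Sketch of the barrier-stripping rule: replace each AnalysisBarrier with its
// wrapped child so that downstream traversals (collect, transform, ...) can
// see the whole plan again.
object EliminateBarriers extends Rule[LogicalPlan] {
  override def apply(plan: LogicalPlan): LogicalPlan = plan transformDown {
    case AnalysisBarrier(child) => child
  }
}
```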

sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala

Lines changed: 41 additions & 1 deletion
```diff
@@ -19,11 +19,16 @@ package org.apache.spark.sql
 
 import org.apache.spark.sql.api.java._
 import org.apache.spark.sql.catalyst.plans.logical.Project
-import org.apache.spark.sql.execution.command.ExplainCommand
+import org.apache.spark.sql.execution.QueryExecution
+import org.apache.spark.sql.execution.columnar.InMemoryRelation
+import org.apache.spark.sql.execution.command.{CreateDataSourceTableAsSelectCommand, ExplainCommand}
+import org.apache.spark.sql.execution.datasources.InsertIntoHadoopFsRelationCommand
 import org.apache.spark.sql.functions.{lit, udf}
 import org.apache.spark.sql.test.SharedSQLContext
 import org.apache.spark.sql.test.SQLTestData._
 import org.apache.spark.sql.types.{DataTypes, DoubleType}
+import org.apache.spark.sql.util.QueryExecutionListener
+
 
 private case class FunctionResult(f1: String, f2: String)
 
@@ -325,6 +330,41 @@ class UDFSuite extends QueryTest with SharedSQLContext {
     }
   }
 
+  test("cached Data should be used in the write path") {
+    withTable("t") {
+      withTempPath { path =>
+        var numTotalCachedHit = 0
+        val listener = new QueryExecutionListener {
+          override def onFailure(f: String, qe: QueryExecution, e: Exception): Unit = {}
+
+          override def onSuccess(funcName: String, qe: QueryExecution, duration: Long): Unit = {
+            qe.withCachedData match {
+              case c: CreateDataSourceTableAsSelectCommand
+                  if c.query.isInstanceOf[InMemoryRelation] =>
+                numTotalCachedHit += 1
+              case i: InsertIntoHadoopFsRelationCommand
+                  if i.query.isInstanceOf[InMemoryRelation] =>
+                numTotalCachedHit += 1
+              case _ =>
+            }
+          }
+        }
+        spark.listenerManager.register(listener)
+
+        val udf1 = udf({ (x: Int, y: Int) => x + y })
+        val df = spark.range(0, 3).toDF("a")
+          .withColumn("b", udf1($"a", lit(10)))
+        df.cache()
+        df.write.saveAsTable("t")
+        assert(numTotalCachedHit == 1, "expected to be cached in saveAsTable")
+        df.write.insertInto("t")
+        assert(numTotalCachedHit == 2, "expected to be cached in insertInto")
+        df.write.save(path.getCanonicalPath)
+        assert(numTotalCachedHit == 3, "expected to be cached in save for native")
+      }
+    }
+  }
+
   test("SPARK-24891 Fix HandleNullInputsForUDF rule") {
     val udf1 = udf({(x: Int, y: Int) => x + y})
     val df = spark.range(0, 3).toDF("a")
```
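The listener counts write commands whose `query` child was substituted with an `InMemoryRelation`, which is the signal that the cache was picked up. A quick hypothetical way to make the same check interactively, outside the listener machinery (not part of this commit):

```Scala
import org.apache.spark.sql.execution.columnar.InMemoryRelation

// Hypothetical spot check on a cached DataFrame such as the test's df: once
// df.cache() has registered the plan, QueryExecution.withCachedData
// substitutes the matching subtree with an InMemoryRelation, so this should
// report true when the cache lookup succeeds.
val hitsCache = df.queryExecution.withCachedData.isInstanceOf[InMemoryRelation]
```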
