Skip to content

Commit fa552c3

Browse files
committed
[SPARK-24867][SQL] Add AnalysisBarrier to DataFrameWriter
```Scala val udf1 = udf({(x: Int, y: Int) => x + y}) val df = spark.range(0, 3).toDF("a") .withColumn("b", udf1($"a", udf1($"a", lit(10)))) df.cache() df.write.saveAsTable("t") ``` Cache is not being used because the plans do not match with the cached plan. This is a regression caused by the changes we made in AnalysisBarrier, since not all the Analyzer rules are idempotent. Added a test. Also found a bug in the DSV1 write path. This is not a regression. Thus, opened a separate JIRA https://issues.apache.org/jira/browse/SPARK-24869 Author: Xiao Li <[email protected]> Closes apache#21821 from gatorsmile/testMaster22. (cherry picked from commit d2e7deb) Signed-off-by: Xiao Li <[email protected]>
1 parent 740606e commit fa552c3

3 files changed

Lines changed: 51 additions & 8 deletions

File tree

sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -254,7 +254,7 @@ final class DataFrameWriter[T] private[sql](ds: Dataset[T]) {
254254
val writer = ws.createWriter(jobId, df.logicalPlan.schema, mode, options)
255255
if (writer.isPresent) {
256256
runCommand(df.sparkSession, "save") {
257-
WriteToDataSourceV2(writer.get(), df.logicalPlan)
257+
WriteToDataSourceV2(writer.get(), df.planWithBarrier)
258258
}
259259
}
260260

@@ -275,7 +275,7 @@ final class DataFrameWriter[T] private[sql](ds: Dataset[T]) {
275275
sparkSession = df.sparkSession,
276276
className = source,
277277
partitionColumns = partitioningColumns.getOrElse(Nil),
278-
options = extraOptions.toMap).planForWriting(mode, AnalysisBarrier(df.logicalPlan))
278+
options = extraOptions.toMap).planForWriting(mode, df.planWithBarrier)
279279
}
280280
}
281281

@@ -323,7 +323,7 @@ final class DataFrameWriter[T] private[sql](ds: Dataset[T]) {
323323
InsertIntoTable(
324324
table = UnresolvedRelation(tableIdent),
325325
partition = Map.empty[String, Option[String]],
326-
query = df.logicalPlan,
326+
query = df.planWithBarrier,
327327
overwrite = mode == SaveMode.Overwrite,
328328
ifPartitionNotExists = false)
329329
}
@@ -455,7 +455,9 @@ final class DataFrameWriter[T] private[sql](ds: Dataset[T]) {
455455
partitionColumnNames = partitioningColumns.getOrElse(Nil),
456456
bucketSpec = getBucketSpec)
457457

458-
runCommand(df.sparkSession, "saveAsTable")(CreateTable(tableDesc, mode, Some(df.logicalPlan)))
458+
runCommand(df.sparkSession, "saveAsTable") {
459+
CreateTable(tableDesc, mode, Some(df.planWithBarrier))
460+
}
459461
}
460462

461463
/**

sql/core/src/main/scala/org/apache/spark/sql/execution/command/ddl.scala

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@ import org.apache.hadoop.mapred.{FileInputFormat, JobConf}
2929

3030
import org.apache.spark.sql.{AnalysisException, Row, SparkSession}
3131
import org.apache.spark.sql.catalyst.TableIdentifier
32-
import org.apache.spark.sql.catalyst.analysis.{NoSuchTableException, Resolver}
32+
import org.apache.spark.sql.catalyst.analysis.{EliminateBarriers, NoSuchTableException, Resolver}
3333
import org.apache.spark.sql.catalyst.catalog._
3434
import org.apache.spark.sql.catalyst.catalog.CatalogTypes.TablePartitionSpec
3535
import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference}
@@ -889,8 +889,9 @@ object DDLUtils {
889889
* Throws exception if outputPath tries to overwrite inputpath.
890890
*/
891891
def verifyNotReadPath(query: LogicalPlan, outputPath: Path) : Unit = {
892-
val inputPaths = query.collect {
893-
case LogicalRelation(r: HadoopFsRelation, _, _, _) => r.location.rootPaths
892+
val inputPaths = EliminateBarriers(query).collect {
893+
case LogicalRelation(r: HadoopFsRelation, _, _, _) =>
894+
r.location.rootPaths
894895
}.flatten
895896

896897
if (inputPaths.contains(outputPath)) {

sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala

Lines changed: 41 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,11 +19,16 @@ package org.apache.spark.sql
1919

2020
import org.apache.spark.sql.api.java._
2121
import org.apache.spark.sql.catalyst.plans.logical.Project
22-
import org.apache.spark.sql.execution.command.ExplainCommand
22+
import org.apache.spark.sql.execution.QueryExecution
23+
import org.apache.spark.sql.execution.columnar.InMemoryRelation
24+
import org.apache.spark.sql.execution.command.{CreateDataSourceTableAsSelectCommand, ExplainCommand}
25+
import org.apache.spark.sql.execution.datasources.InsertIntoHadoopFsRelationCommand
2326
import org.apache.spark.sql.functions.{lit, udf}
2427
import org.apache.spark.sql.test.SharedSQLContext
2528
import org.apache.spark.sql.test.SQLTestData._
2629
import org.apache.spark.sql.types.{DataTypes, DoubleType}
30+
import org.apache.spark.sql.util.QueryExecutionListener
31+
2732

2833
private case class FunctionResult(f1: String, f2: String)
2934

@@ -305,6 +310,41 @@ class UDFSuite extends QueryTest with SharedSQLContext {
305310
.contains(s"UDF:$udf1Name(UDF:$udf2Name(1))"))
306311
}
307312

313+
test("cached Data should be used in the write path") {
314+
withTable("t") {
315+
withTempPath { path =>
316+
var numTotalCachedHit = 0
317+
val listener = new QueryExecutionListener {
318+
override def onFailure(f: String, qe: QueryExecution, e: Exception): Unit = {}
319+
320+
override def onSuccess(funcName: String, qe: QueryExecution, duration: Long): Unit = {
321+
qe.withCachedData match {
322+
case c: CreateDataSourceTableAsSelectCommand
323+
if c.query.isInstanceOf[InMemoryRelation] =>
324+
numTotalCachedHit += 1
325+
case i: InsertIntoHadoopFsRelationCommand
326+
if i.query.isInstanceOf[InMemoryRelation] =>
327+
numTotalCachedHit += 1
328+
case _ =>
329+
}
330+
}
331+
}
332+
spark.listenerManager.register(listener)
333+
334+
val udf1 = udf({ (x: Int, y: Int) => x + y })
335+
val df = spark.range(0, 3).toDF("a")
336+
.withColumn("b", udf1($"a", lit(10)))
337+
df.cache()
338+
df.write.saveAsTable("t")
339+
assert(numTotalCachedHit == 1, "expected to be cached in saveAsTable")
340+
df.write.insertInto("t")
341+
assert(numTotalCachedHit == 2, "expected to be cached in insertInto")
342+
df.write.save(path.getCanonicalPath)
343+
assert(numTotalCachedHit == 3, "expected to be cached in save for native")
344+
}
345+
}
346+
}
347+
308348
test("SPARK-24891 Fix HandleNullInputsForUDF rule") {
309349
val udf1 = udf({(x: Int, y: Int) => x + y})
310350
val df = spark.range(0, 3).toDF("a")

0 commit comments

Comments
 (0)