From a4562264c45c24a88f4d508c2d34d4e7aed50631 Mon Sep 17 00:00:00 2001
From: Wenchen Fan
Date: Tue, 9 Oct 2018 01:18:44 +0800
Subject: [PATCH 1/6] SQL execution listener shouldn't happen on execution
 thread

---
 project/MimaExcludes.scala                    |  2 +
 .../apache/spark/sql/DataFrameWriter.scala    | 13 +--
 .../scala/org/apache/spark/sql/Dataset.scala  | 20 +---
 .../spark/sql/execution/SQLExecution.scala    | 34 +++++--
 .../spark/sql/execution/ui/SQLListener.scala  | 13 ++-
 .../internal/BaseSessionStateBuilder.scala    |  4 +-
 .../sql/util/QueryExecutionListener.scala     | 93 +++++++------------
 .../sql/execution/SQLJsonProtocolSuite.scala  | 33 ++++++-
 .../sql/util/DataFrameCallbackSuite.scala     | 13 +++
 .../util/ExecutionListenerManagerSuite.scala  | 14 +--
 10 files changed, 134 insertions(+), 105 deletions(-)

diff --git a/project/MimaExcludes.scala b/project/MimaExcludes.scala
index 0b074fbf64eda..1a932ed7e05f8 100644
--- a/project/MimaExcludes.scala
+++ b/project/MimaExcludes.scala
@@ -36,6 +36,8 @@ object MimaExcludes {
   // Exclude rules for 3.0.x
   lazy val v30excludes = v24excludes ++ Seq(
+    ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.sql.util.ExecutionListenerManager.clone"),
+    ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.sql.util.ExecutionListenerManager.this")
   )
 
   // Exclude rules for 2.4.x
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala
index 188fce72efac5..f0572706e6560 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala
@@ -667,17 +667,8 @@ final class DataFrameWriter[T] private[sql](ds: Dataset[T]) {
    */
   private def runCommand(session: SparkSession, name: String)(command: LogicalPlan): Unit = {
     val qe = session.sessionState.executePlan(command)
-    try {
-      val start = System.nanoTime()
-      // call `QueryExecution.toRDD` to trigger the execution of commands.
-      SQLExecution.withNewExecutionId(session, qe)(qe.toRdd)
-      val end = System.nanoTime()
-      session.listenerManager.onSuccess(name, qe, end - start)
-    } catch {
-      case e: Exception =>
-        session.listenerManager.onFailure(name, qe, e)
-        throw e
-    }
+    // call `QueryExecution.toRdd` to trigger the execution of commands.
+    SQLExecution.withNewExecutionId(session, qe, Some(name))(qe.toRdd)
   }
 
   ///////////////////////////////////////////////////////////////////////////////////////
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
index fa14aa14ee968..0b2b3c2a4aee2 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
@@ -3356,21 +3356,11 @@ class Dataset[T] private[sql](
    * user-registered callback functions.
    */
   private def withAction[U](name: String, qe: QueryExecution)(action: SparkPlan => U) = {
-    try {
-      qe.executedPlan.foreach { plan =>
-        plan.resetMetrics()
-      }
-      val start = System.nanoTime()
-      val result = SQLExecution.withNewExecutionId(sparkSession, qe) {
-        action(qe.executedPlan)
-      }
-      val end = System.nanoTime()
-      sparkSession.listenerManager.onSuccess(name, qe, end - start)
-      result
-    } catch {
-      case e: Exception =>
-        sparkSession.listenerManager.onFailure(name, qe, e)
-        throw e
+    qe.executedPlan.foreach { plan =>
+      plan.resetMetrics()
+    }
+    SQLExecution.withNewExecutionId(sparkSession, qe, Some(name)) {
+      action(qe.executedPlan)
     }
   }
 
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SQLExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SQLExecution.scala
index 439932b0cc3ac..183ad2cc39fee 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SQLExecution.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SQLExecution.scala
@@ -58,7 +58,8 @@ object SQLExecution {
    */
   def withNewExecutionId[T](
       sparkSession: SparkSession,
-      queryExecution: QueryExecution)(body: => T): T = {
+      queryExecution: QueryExecution,
+      name: Option[String] = None)(body: => T): T = {
     val sc = sparkSession.sparkContext
     val oldExecutionId = sc.getLocalProperty(EXECUTION_ID_KEY)
     val executionId = SQLExecution.nextExecutionId
@@ -71,14 +72,35 @@ object SQLExecution {
     val callSite = sc.getCallSite()
 
     withSQLConfPropagated(sparkSession) {
-      sc.listenerBus.post(SparkListenerSQLExecutionStart(
-        executionId, callSite.shortForm, callSite.longForm, queryExecution.toString,
-        SparkPlanInfo.fromSparkPlan(queryExecution.executedPlan), System.currentTimeMillis()))
+      var ex: Option[Exception] = None
+      val startTime = System.currentTimeMillis()
       try {
+        sc.listenerBus.post(SparkListenerSQLExecutionStart(
+          executionId = executionId,
+          description = callSite.shortForm,
+          details = callSite.longForm,
+          physicalPlanDescription = queryExecution.toString,
+          // `queryExecution.executedPlan` triggers query planning. If it fails, the exception
+          // will be caught and reported in the `SparkListenerSQLExecutionEnd` event.
+          sparkPlanInfo = SparkPlanInfo.fromSparkPlan(queryExecution.executedPlan),
+          time = startTime))
         body
+      } catch {
+        case e: Exception =>
+          ex = Some(e)
+          throw e
       } finally {
-        sc.listenerBus.post(SparkListenerSQLExecutionEnd(
-          executionId, System.currentTimeMillis()))
+        val endTime = System.currentTimeMillis()
+        val event = SparkListenerSQLExecutionEnd(executionId, endTime)
+        // Currently only `Dataset.withAction` and `DataFrameWriter.runCommand` specify the `name`
+        // parameter. The `ExecutionListenerManager` only watches SQL executions with a name. We
+        // can specify the execution name in more places in the future, so that
+        // `QueryExecutionListener` can track more cases.
+        event.executionName = name
+        event.duration = endTime - startTime
+        event.qe = queryExecution
+        event.executionFailure = ex
+        sc.listenerBus.post(event)
       }
     }
   } finally {
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SQLListener.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SQLListener.scala
index b58b8c6d45e5b..9a66b1d955ab8 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SQLListener.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SQLListener.scala
@@ -17,6 +17,7 @@ package org.apache.spark.sql.execution.ui
 
+import com.fasterxml.jackson.annotation.JsonIgnore
 import com.fasterxml.jackson.databind.JavaType
 import com.fasterxml.jackson.databind.`type`.TypeFactory
 import com.fasterxml.jackson.databind.annotation.JsonDeserialize
@@ -24,8 +25,7 @@ import com.fasterxml.jackson.databind.util.Converter
 
 import org.apache.spark.annotation.DeveloperApi
 import org.apache.spark.scheduler._
-import org.apache.spark.sql.execution.SparkPlanInfo
-import org.apache.spark.sql.execution.metric._
+import org.apache.spark.sql.execution.{QueryExecution, SparkPlanInfo}
 
 @DeveloperApi
 case class SparkListenerSQLExecutionStart(
@@ -39,7 +39,14 @@ case class SparkListenerSQLExecutionStart(
 
 @DeveloperApi
 case class SparkListenerSQLExecutionEnd(executionId: Long, time: Long)
-  extends SparkListenerEvent
+  extends SparkListenerEvent {
+
+  @JsonIgnore private[sql] var executionName: Option[String] = None
+  // These 3 fields are only accessed when `executionName` is defined.
+  @JsonIgnore private[sql] var duration: Long = 0L
+  @JsonIgnore private[sql] var qe: QueryExecution = null
+  @JsonIgnore private[sql] var executionFailure: Option[Exception] = None
+}
 
 /**
  * A message used to update SQL metric value for driver-side updates (which doesn't get reflected
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/internal/BaseSessionStateBuilder.scala b/sql/core/src/main/scala/org/apache/spark/sql/internal/BaseSessionStateBuilder.scala
index 3a0db7e16c23a..60bba5e10703c 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/internal/BaseSessionStateBuilder.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/internal/BaseSessionStateBuilder.scala
@@ -266,8 +266,8 @@ abstract class BaseSessionStateBuilder(
    * This gets cloned from parent if available, otherwise a new instance is created.
    */
   protected def listenerManager: ExecutionListenerManager = {
-    parentState.map(_.listenerManager.clone()).getOrElse(
-      new ExecutionListenerManager(session.sparkContext.conf))
+    parentState.map(_.listenerManager.clone(session)).getOrElse(
+      new ExecutionListenerManager(session, loadExtensions = true))
   }
 
   /**
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/util/QueryExecutionListener.scala b/sql/core/src/main/scala/org/apache/spark/sql/util/QueryExecutionListener.scala
index 2b46233e1a5df..3ea3fae65b948 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/util/QueryExecutionListener.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/util/QueryExecutionListener.scala
@@ -17,15 +17,16 @@ package org.apache.spark.sql.util
 
-import java.util.concurrent.locks.ReentrantReadWriteLock
+import java.util.concurrent.CopyOnWriteArrayList
 
-import scala.collection.mutable.ListBuffer
-import scala.util.control.NonFatal
+import scala.collection.JavaConverters._
 
-import org.apache.spark.SparkConf
 import org.apache.spark.annotation.{DeveloperApi, Experimental, InterfaceStability}
 import org.apache.spark.internal.Logging
+import org.apache.spark.scheduler.{SparkListener, SparkListenerEvent}
+import org.apache.spark.sql.SparkSession
 import org.apache.spark.sql.execution.QueryExecution
+import org.apache.spark.sql.execution.ui.SparkListenerSQLExecutionEnd
 import org.apache.spark.sql.internal.StaticSQLConf._
 import org.apache.spark.util.Utils
 
@@ -75,95 +76,69 @@ trait QueryExecutionListener {
  */
 @Experimental
 @InterfaceStability.Evolving
-class ExecutionListenerManager private extends Logging {
+class ExecutionListenerManager private[sql](session: SparkSession, loadExtensions: Boolean)
+  extends SparkListener with Logging {
 
-  private[sql] def this(conf: SparkConf) = {
-    this()
+  private[this] val listeners = new CopyOnWriteArrayList[QueryExecutionListener]
+
+  if (loadExtensions) {
+    val conf = session.sparkContext.conf
     conf.get(QUERY_EXECUTION_LISTENERS).foreach { classNames =>
       Utils.loadExtensions(classOf[QueryExecutionListener], classNames, conf).foreach(register)
     }
   }
 
+  session.sparkContext.listenerBus.addToSharedQueue(this)
+
   /**
    * Registers the specified [[QueryExecutionListener]].
    */
   @DeveloperApi
-  def register(listener: QueryExecutionListener): Unit = writeLock {
-    listeners += listener
+  def register(listener: QueryExecutionListener): Unit = {
+    listeners.add(listener)
   }
 
   /**
    * Unregisters the specified [[QueryExecutionListener]].
    */
   @DeveloperApi
-  def unregister(listener: QueryExecutionListener): Unit = writeLock {
-    listeners -= listener
+  def unregister(listener: QueryExecutionListener): Unit = {
+    listeners.remove(listener)
   }
 
   /**
    * Removes all the registered [[QueryExecutionListener]].
    */
   @DeveloperApi
-  def clear(): Unit = writeLock {
+  def clear(): Unit = {
     listeners.clear()
   }
 
   /**
    * Get an identical copy of this listener manager.
    */
-  @DeveloperApi
-  override def clone(): ExecutionListenerManager = writeLock {
-    val newListenerManager = new ExecutionListenerManager
-    listeners.foreach(newListenerManager.register)
+  private[sql] def clone(session: SparkSession): ExecutionListenerManager = {
+    val newListenerManager = new ExecutionListenerManager(session, loadExtensions = false)
+    listeners.iterator().asScala.foreach(newListenerManager.register)
     newListenerManager
   }
 
-  private[sql] def onSuccess(funcName: String, qe: QueryExecution, duration: Long): Unit = {
-    readLock {
-      withErrorHandling { listener =>
-        listener.onSuccess(funcName, qe, duration)
+  override def onOtherEvent(event: SparkListenerEvent): Unit = event match {
+    case e: SparkListenerSQLExecutionEnd if shouldCatchEvent(e) =>
+      val funcName = e.executionName.get
+      e.executionFailure match {
+        case Some(ex) =>
+          listeners.iterator().asScala.foreach(_.onFailure(funcName, e.qe, ex))
+        case _ =>
+          listeners.iterator().asScala.foreach(_.onSuccess(funcName, e.qe, e.duration))
       }
-    }
-  }
 
-  private[sql] def onFailure(funcName: String, qe: QueryExecution, exception: Exception): Unit = {
-    readLock {
-      withErrorHandling { listener =>
-        listener.onFailure(funcName, qe, exception)
-      }
-    }
+    case _ => // Ignore
   }
 
-  private[this] val listeners = ListBuffer.empty[QueryExecutionListener]
-
-  /** A lock to prevent updating the list of listeners while we are traversing through them. */
-  private[this] val lock = new ReentrantReadWriteLock()
-
-  private def withErrorHandling(f: QueryExecutionListener => Unit): Unit = {
-    for (listener <- listeners) {
-      try {
-        f(listener)
-      } catch {
-        case NonFatal(e) => logWarning("Error executing query execution listener", e)
-      }
-    }
-  }
-
-  /** Acquires a read lock on the cache for the duration of `f`. */
-  private def readLock[A](f: => A): A = {
-    val rl = lock.readLock()
-    rl.lock()
-    try f finally {
-      rl.unlock()
-    }
-  }
-
-  /** Acquires a write lock on the cache for the duration of `f`. */
-  private def writeLock[A](f: => A): A = {
-    val wl = lock.writeLock()
-    wl.lock()
-    try f finally {
-      wl.unlock()
-    }
+  private def shouldCatchEvent(e: SparkListenerSQLExecutionEnd): Boolean = {
+    // Only catch SQL executions that have a name and are triggered by the same Spark session
+    // that this listener manager belongs to.
+    e.executionName.isDefined && e.qe.sparkSession.eq(this.session)
+  }
 }
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/SQLJsonProtocolSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/SQLJsonProtocolSuite.scala
index 08e40e28d3d57..08789e63fa7f1 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/SQLJsonProtocolSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/SQLJsonProtocolSuite.scala
@@ -17,13 +17,15 @@ package org.apache.spark.sql.execution
 
-import org.json4s.jackson.JsonMethods.parse
+import org.json4s.jackson.JsonMethods._
 
 import org.apache.spark.SparkFunSuite
-import org.apache.spark.sql.execution.ui.SparkListenerSQLExecutionStart
+import org.apache.spark.sql.LocalSparkSession
+import org.apache.spark.sql.execution.ui.{SparkListenerSQLExecutionEnd, SparkListenerSQLExecutionStart}
+import org.apache.spark.sql.test.TestSparkSession
 import org.apache.spark.util.JsonProtocol
 
-class SQLJsonProtocolSuite extends SparkFunSuite {
+class SQLJsonProtocolSuite extends SparkFunSuite with LocalSparkSession {
 
   test("SparkPlanGraph backward compatibility: metadata") {
     val SQLExecutionStartJsonString =
@@ -49,4 +51,29 @@ class SQLJsonProtocolSuite extends SparkFunSuite {
       new SparkPlanInfo("TestNode", "test string", Nil, Map(), Nil), 0)
     assert(reconstructedEvent == expectedEvent)
   }
+
+  test("SparkListenerSQLExecutionEnd backward compatibility") {
+    spark = new TestSparkSession()
+    val qe = spark.sql("select 1").queryExecution
+    val event = SparkListenerSQLExecutionEnd(1, 10)
+    event.duration = 1000
+    event.executionName = Some("test")
+    event.qe = qe
+    event.executionFailure = Some(new RuntimeException("test"))
+    val json = JsonProtocol.sparkEventToJson(event)
+    assert(json == parse(
+      """
+        |{
+        |  "Event" : "org.apache.spark.sql.execution.ui.SparkListenerSQLExecutionEnd",
+        |  "executionId" : 1,
+        |  "time" : 10
+        |}
+      """.stripMargin))
+    val readBack = JsonProtocol.sparkEventFromJson(json)
+    event.duration = 0
+    event.executionName = None
+    event.qe = null
+    event.executionFailure = None
+    assert(readBack == event)
+  }
 }
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/util/DataFrameCallbackSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/util/DataFrameCallbackSuite.scala
index a239e39d9c5a3..e8710aeb40bd4 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/util/DataFrameCallbackSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/util/DataFrameCallbackSuite.scala
@@ -48,6 +48,7 @@ class DataFrameCallbackSuite extends QueryTest with SharedSQLContext {
     df.select("i").collect()
     df.filter($"i" > 0).count()
 
+    sparkContext.listenerBus.waitUntilEmpty(1000)
     assert(metrics.length == 2)
 
     assert(metrics(0)._1 == "collect")
@@ -78,6 +79,7 @@ class DataFrameCallbackSuite extends QueryTest with SharedSQLContext {
 
     val e = intercept[SparkException](df.select(errorUdf($"i")).collect())
 
+    sparkContext.listenerBus.waitUntilEmpty(1000)
     assert(metrics.length == 1)
     assert(metrics(0)._1 == "collect")
     assert(metrics(0)._2.analyzed.isInstanceOf[Project])
@@ -103,10 +105,16 @@ class DataFrameCallbackSuite extends QueryTest with SharedSQLContext {
     spark.listenerManager.register(listener)
 
     val df = Seq(1 -> "a").toDF("i", "j").groupBy("i").count()
+    df.collect()
+    // Wait for the first `collect` to be caught by our listener. Otherwise the next `collect` will
+    // reset the plan metrics.
+    sparkContext.listenerBus.waitUntilEmpty(1000)
     df.collect()
+    Seq(1 -> "a", 2 -> "a").toDF("i", "j").groupBy("i").count().collect()
+    sparkContext.listenerBus.waitUntilEmpty(1000)
     assert(metrics.length == 3)
     assert(metrics(0) === 1)
     assert(metrics(1) === 1)
@@ -154,6 +162,7 @@ class DataFrameCallbackSuite extends QueryTest with SharedSQLContext {
     // For this simple case, the peakExecutionMemory of a stage should be the data size of the
     // aggregate operator, as we only have one memory consuming operator per stage.
+    sparkContext.listenerBus.waitUntilEmpty(1000)
     assert(metrics.length == 2)
     assert(metrics(0) == topAggDataSize)
     assert(metrics(1) == bottomAggDataSize)
@@ -177,6 +186,7 @@ class DataFrameCallbackSuite extends QueryTest with SharedSQLContext {
     withTempPath { path =>
       spark.range(10).write.format("json").save(path.getCanonicalPath)
+      sparkContext.listenerBus.waitUntilEmpty(1000)
       assert(commands.length == 1)
       assert(commands.head._1 == "save")
       assert(commands.head._2.isInstanceOf[InsertIntoHadoopFsRelationCommand])
@@ -187,6 +197,7 @@ class DataFrameCallbackSuite extends QueryTest with SharedSQLContext {
     withTable("tab") {
       sql("CREATE TABLE tab(i long) using parquet") // adds commands(1) via onSuccess
       spark.range(10).write.insertInto("tab")
+      sparkContext.listenerBus.waitUntilEmpty(1000)
       assert(commands.length == 3)
       assert(commands(2)._1 == "insertInto")
       assert(commands(2)._2.isInstanceOf[InsertIntoTable])
@@ -197,6 +208,7 @@ class DataFrameCallbackSuite extends QueryTest with SharedSQLContext {
     withTable("tab") {
       spark.range(10).select($"id", $"id" % 5 as "p").write.partitionBy("p").saveAsTable("tab")
+      sparkContext.listenerBus.waitUntilEmpty(1000)
       assert(commands.length == 5)
       assert(commands(4)._1 == "saveAsTable")
       assert(commands(4)._2.isInstanceOf[CreateTable])
@@ -208,6 +220,7 @@ class DataFrameCallbackSuite extends QueryTest with SharedSQLContext {
       val e = intercept[AnalysisException] {
         spark.range(10).select($"id", $"id").write.insertInto("tab")
       }
+      sparkContext.listenerBus.waitUntilEmpty(1000)
       assert(exceptions.length == 1)
       assert(exceptions.head._1 == "insertInto")
       assert(exceptions.head._2 == e)
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/util/ExecutionListenerManagerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/util/ExecutionListenerManagerSuite.scala
index 4205e23ae240a..da414f4311e57 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/util/ExecutionListenerManagerSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/util/ExecutionListenerManagerSuite.scala
@@ -20,26 +20,28 @@ package org.apache.spark.sql.util
 import java.util.concurrent.atomic.AtomicInteger
 
 import org.apache.spark._
+import org.apache.spark.sql.{LocalSparkSession, SparkSession}
 import org.apache.spark.sql.execution.QueryExecution
 import org.apache.spark.sql.internal.StaticSQLConf._
 
-class ExecutionListenerManagerSuite extends SparkFunSuite {
+class ExecutionListenerManagerSuite extends SparkFunSuite with LocalSparkSession {
 
   import CountingQueryExecutionListener._
 
   test("register query execution listeners using configuration") {
     val conf = new SparkConf(false)
       .set(QUERY_EXECUTION_LISTENERS, Seq(classOf[CountingQueryExecutionListener].getName()))
+    spark = SparkSession.builder().master("local").appName("test").config(conf).getOrCreate()
 
-    val mgr = new ExecutionListenerManager(conf)
+    spark.sql("select 1").collect()
+    spark.sparkContext.listenerBus.waitUntilEmpty(1000)
     assert(INSTANCE_COUNT.get() === 1)
-    mgr.onSuccess(null, null, 42L)
    assert(CALLBACK_COUNT.get() === 1)
 
-    val clone = mgr.clone()
+    val cloned = spark.cloneSession()
+    cloned.sql("select 1").collect()
+    spark.sparkContext.listenerBus.waitUntilEmpty(1000)
     assert(INSTANCE_COUNT.get() === 1)
-
-    clone.onSuccess(null, null, 42L)
     assert(CALLBACK_COUNT.get() === 2)
   }

From 436197b4a395323a5ea26d194389d1c0c41cb578 Mon Sep 17 00:00:00 2001
From: Wenchen Fan
Date: Tue, 9 Oct 2018 14:35:51 +0800
Subject: [PATCH 2/6] fix tests

---
 .../test/scala/org/apache/spark/sql/SessionStateSuite.scala | 3 +++
 sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala | 3 +++
 2 files changed, 6 insertions(+)

diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SessionStateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SessionStateSuite.scala
index e1b5eba53f06a..6317cd28bcc65 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/SessionStateSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/SessionStateSuite.scala
@@ -155,6 +155,7 @@ class SessionStateSuite extends SparkFunSuite {
       assert(forkedSession ne activeSession)
       assert(forkedSession.listenerManager ne activeSession.listenerManager)
       runCollectQueryOn(forkedSession)
+      activeSession.sparkContext.listenerBus.waitUntilEmpty(1000)
       assert(collectorA.commands.length == 1) // forked should callback to A
       assert(collectorA.commands(0) == "collect")
 
@@ -162,12 +163,14 @@ class SessionStateSuite extends SparkFunSuite {
       // => changes to forked do not affect original
       forkedSession.listenerManager.register(collectorB)
       runCollectQueryOn(activeSession)
+      activeSession.sparkContext.listenerBus.waitUntilEmpty(1000)
       assert(collectorB.commands.isEmpty) // original should not callback to B
       assert(collectorA.commands.length == 2) // original should still callback to A
       assert(collectorA.commands(1) == "collect")
 
       // <= changes to original do not affect forked
       activeSession.listenerManager.register(collectorC)
       runCollectQueryOn(forkedSession)
+      activeSession.sparkContext.listenerBus.waitUntilEmpty(1000)
       assert(collectorC.commands.isEmpty) // forked should not callback to C
       assert(collectorA.commands.length == 3) // forked should still callback to A
       assert(collectorB.commands.length == 1) // forked should still callback to B
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala
index 30dca9497ddde..269600dd59cb4 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala
@@ -356,10 +356,13 @@ class UDFSuite extends QueryTest with SharedSQLContext {
         .withColumn("b", udf1($"a", lit(10)))
       df.cache()
       df.write.saveAsTable("t")
+      sparkContext.listenerBus.waitUntilEmpty(1000)
       assert(numTotalCachedHit == 1, "expected to be cached in saveAsTable")
       df.write.insertInto("t")
+      sparkContext.listenerBus.waitUntilEmpty(1000)
       assert(numTotalCachedHit == 2, "expected to be cached in insertInto")
       df.write.save(path.getCanonicalPath)
+      sparkContext.listenerBus.waitUntilEmpty(1000)
       assert(numTotalCachedHit == 3, "expected to be cached in save for native")
     }
   }

From 642ddd321dd4aa196329308ba16195b8bf35a4bf Mon Sep 17 00:00:00 2001
From: Wenchen Fan
Date: Tue, 9 Oct 2018 23:21:17 +0800
Subject: [PATCH 3/6] address comment

---
 .../org/apache/spark/sql/util/QueryExecutionListener.scala | 5 +++++
 1 file changed, 5 insertions(+)

diff --git a/sql/core/src/main/scala/org/apache/spark/sql/util/QueryExecutionListener.scala b/sql/core/src/main/scala/org/apache/spark/sql/util/QueryExecutionListener.scala
index 3ea3fae65b948..633f56993484d 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/util/QueryExecutionListener.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/util/QueryExecutionListener.scala
@@ -76,6 +76,11 @@ trait QueryExecutionListener {
  */
 @Experimental
 @InterfaceStability.Evolving
+// The `session` is used to indicate which session carries this listener manager, and we only
+// catch SQL executions which are launched by the same session.
+// The `loadExtensions` flag is used to indicate whether we should load the pre-defined,
+// user-specified listeners during construction. We should not do it when cloning this listener
+// manager, as we will copy all listeners to the cloned listener manager.
 class ExecutionListenerManager private[sql](session: SparkSession, loadExtensions: Boolean)
   extends SparkListener with Logging {

From a25524b36c2fe85d3ee443e3bd0c6e7b447085fe Mon Sep 17 00:00:00 2001
From: Wenchen Fan
Date: Wed, 10 Oct 2018 21:51:38 +0800
Subject: [PATCH 4/6] address comments

---
 .../org/apache/spark/util/ListenerBus.scala   |  8 ++++
 .../scala/org/apache/spark/sql/Dataset.scala  |  6 +--
 .../spark/sql/execution/SQLExecution.scala    |  8 ++--
 .../spark/sql/execution/ui/SQLListener.scala  | 10 +++-
 .../sql/util/QueryExecutionListener.scala     | 47 ++++++++++---------
 5 files changed, 50 insertions(+), 29 deletions(-)

diff --git a/core/src/main/scala/org/apache/spark/util/ListenerBus.scala b/core/src/main/scala/org/apache/spark/util/ListenerBus.scala
index a8f10684d5a2c..2e517707ff774 100644
--- a/core/src/main/scala/org/apache/spark/util/ListenerBus.scala
+++ b/core/src/main/scala/org/apache/spark/util/ListenerBus.scala
@@ -60,6 +60,14 @@ private[spark] trait ListenerBus[L <: AnyRef, E] extends Logging {
     }
   }
 
+  /**
+   * Remove all listeners, so that they no longer receive any events. This method is thread-safe
+   * and can be called from any thread.
+   */
+  final def removeAllListeners(): Unit = {
+    listenersPlusTimers.clear()
+  }
+
   /**
    * This can be overridden by subclasses if there is any extra cleanup to do when removing a
    * listener. In particular AsyncEventQueues can clean up queues in the LiveListenerBus.
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
index 0b2b3c2a4aee2..0fb3301b36162 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
@@ -3356,10 +3356,10 @@ class Dataset[T] private[sql](
    * user-registered callback functions.
    */
   private def withAction[U](name: String, qe: QueryExecution)(action: SparkPlan => U) = {
-    qe.executedPlan.foreach { plan =>
-      plan.resetMetrics()
-    }
     SQLExecution.withNewExecutionId(sparkSession, qe, Some(name)) {
+      qe.executedPlan.foreach { plan =>
+        plan.resetMetrics()
+      }
       action(qe.executedPlan)
     }
   }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SQLExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SQLExecution.scala
index 183ad2cc39fee..dda7cb55f5395 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SQLExecution.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SQLExecution.scala
@@ -73,7 +73,7 @@ object SQLExecution {
     withSQLConfPropagated(sparkSession) {
       var ex: Option[Exception] = None
-      val startTime = System.currentTimeMillis()
+      val startTime = System.nanoTime()
       try {
         sc.listenerBus.post(SparkListenerSQLExecutionStart(
           executionId = executionId,
@@ -83,15 +83,15 @@ object SQLExecution {
           // `queryExecution.executedPlan` triggers query planning. If it fails, the exception
           // will be caught and reported in the `SparkListenerSQLExecutionEnd` event.
           sparkPlanInfo = SparkPlanInfo.fromSparkPlan(queryExecution.executedPlan),
-          time = startTime))
+          time = System.currentTimeMillis()))
         body
       } catch {
         case e: Exception =>
           ex = Some(e)
           throw e
       } finally {
-        val endTime = System.currentTimeMillis()
-        val event = SparkListenerSQLExecutionEnd(executionId, endTime)
+        val endTime = System.nanoTime()
+        val event = SparkListenerSQLExecutionEnd(executionId, System.currentTimeMillis())
         // Currently only `Dataset.withAction` and `DataFrameWriter.runCommand` specify the `name`
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SQLListener.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SQLListener.scala
index 9a66b1d955ab8..c04a31c428d11 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SQLListener.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SQLListener.scala
@@ -41,10 +41,18 @@ case class SparkListenerSQLExecutionStart(
 case class SparkListenerSQLExecutionEnd(executionId: Long, time: Long)
   extends SparkListenerEvent {
 
+  // The name of the execution, e.g. `df.collect` will trigger a SQL execution with name "collect".
   @JsonIgnore private[sql] var executionName: Option[String] = None
+
+  // The following 3 fields are only accessed when `executionName` is defined.
+
+  // The duration of the SQL execution, in nanoseconds.
   @JsonIgnore private[sql] var duration: Long = 0L
+
+  // The `QueryExecution` instance that represents the SQL execution.
   @JsonIgnore private[sql] var qe: QueryExecution = null
+
+  // The exception object that caused this execution to fail. None if the execution doesn't fail.
   @JsonIgnore private[sql] var executionFailure: Option[Exception] = None
 }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/util/QueryExecutionListener.scala b/sql/core/src/main/scala/org/apache/spark/sql/util/QueryExecutionListener.scala
index 633f56993484d..9c0030d6a9363 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/util/QueryExecutionListener.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/util/QueryExecutionListener.scala
@@ -17,18 +17,15 @@ package org.apache.spark.sql.util
 
-import java.util.concurrent.CopyOnWriteArrayList
-
 import scala.collection.JavaConverters._
 
 import org.apache.spark.annotation.{DeveloperApi, Experimental, InterfaceStability}
-import org.apache.spark.internal.Logging
 import org.apache.spark.scheduler.{SparkListener, SparkListenerEvent}
 import org.apache.spark.sql.SparkSession
 import org.apache.spark.sql.execution.QueryExecution
 import org.apache.spark.sql.execution.ui.SparkListenerSQLExecutionEnd
 import org.apache.spark.sql.internal.StaticSQLConf._
-import org.apache.spark.util.Utils
+import org.apache.spark.util.{ListenerBus, Utils}
 
 /**
  * :: Experimental ::
@@ -81,10 +78,9 @@ trait QueryExecutionListener {
 // The `loadExtensions` flag is used to indicate whether we should load the pre-defined,
 // user-specified listeners during construction. We should not do it when cloning this listener
 // manager, as we will copy all listeners to the cloned listener manager.
-class ExecutionListenerManager private[sql](session: SparkSession, loadExtensions: Boolean)
-  extends SparkListener with Logging {
+class ExecutionListenerManager private[sql](session: SparkSession, loadExtensions: Boolean) {
 
-  private[this] val listeners = new CopyOnWriteArrayList[QueryExecutionListener]
+  private val listenerBus = new ExecutionListenerBus(session)
 
   if (loadExtensions) {
     val conf = session.sparkContext.conf
@@ -93,14 +89,12 @@ class ExecutionListenerManager private[sql](session: SparkSession, loadExtension
     }
   }
 
-  session.sparkContext.listenerBus.addToSharedQueue(this)
-
   /**
    * Registers the specified [[QueryExecutionListener]].
    */
   @DeveloperApi
   def register(listener: QueryExecutionListener): Unit = {
-    listeners.add(listener)
+    listenerBus.addListener(listener)
   }
 
   /**
@@ -108,7 +102,7 @@ class ExecutionListenerManager private[sql](session: SparkSession, loadExtension
    */
   @DeveloperApi
   def unregister(listener: QueryExecutionListener): Unit = {
-    listeners.remove(listener)
+    listenerBus.removeListener(listener)
   }
 
   /**
@@ -116,7 +110,7 @@ class ExecutionListenerManager private[sql](session: SparkSession, loadExtension
    */
   @DeveloperApi
   def clear(): Unit = {
-    listeners.clear()
+    listenerBus.removeAllListeners()
   }
 
   /**
@@ -124,24 +118,35 @@ class ExecutionListenerManager private[sql](session: SparkSession, loadExtension
    */
   private[sql] def clone(session: SparkSession): ExecutionListenerManager = {
     val newListenerManager = new ExecutionListenerManager(session, loadExtensions = false)
-    listeners.iterator().asScala.foreach(newListenerManager.register)
+    listenerBus.listeners.asScala.foreach(newListenerManager.register)
     newListenerManager
   }
+}
+
+private[sql] class ExecutionListenerBus(session: SparkSession)
+  extends SparkListener with ListenerBus[QueryExecutionListener, SparkListenerSQLExecutionEnd] {
+
+  session.sparkContext.listenerBus.addToSharedQueue(this)
 
   override def onOtherEvent(event: SparkListenerEvent): Unit = event match {
-    case e: SparkListenerSQLExecutionEnd if shouldCatchEvent(e) =>
-      val funcName = e.executionName.get
-      e.executionFailure match {
+    case e: SparkListenerSQLExecutionEnd => postToAll(e)
+  }
+
+  override protected def doPostEvent(
+      listener: QueryExecutionListener,
+      event: SparkListenerSQLExecutionEnd): Unit = {
+    if (shouldReport(event)) {
+      val funcName = event.executionName.get
+      event.executionFailure match {
         case Some(ex) =>
-          listeners.iterator().asScala.foreach(_.onFailure(funcName, e.qe, ex))
+          listener.onFailure(funcName, event.qe, ex)
         case _ =>
-          listeners.iterator().asScala.foreach(_.onSuccess(funcName, e.qe, e.duration))
+          listener.onSuccess(funcName, event.qe, event.duration)
       }
-
-    case _ => // Ignore
+    }
   }
 
-  private def shouldCatchEvent(e: SparkListenerSQLExecutionEnd): Boolean = {
+  private def shouldReport(e: SparkListenerSQLExecutionEnd): Boolean = {
     // Only catch SQL executions that have a name and are triggered by the same Spark session
     // that this listener manager belongs to.
     e.executionName.isDefined && e.qe.sparkSession.eq(this.session)
   }
 }

From 3ffa536f3c29f6655843a4d45c215393f51e23c9 Mon Sep 17 00:00:00 2001
From: Wenchen Fan
Date: Wed, 10 Oct 2018 22:13:57 +0800
Subject: [PATCH 5/6] add back Logging

---
 .../org/apache/spark/sql/util/QueryExecutionListener.scala | 4 +++-
 1 file changed, 3 insertions(+), 1 deletion(-)

diff --git a/sql/core/src/main/scala/org/apache/spark/sql/util/QueryExecutionListener.scala b/sql/core/src/main/scala/org/apache/spark/sql/util/QueryExecutionListener.scala
index 9c0030d6a9363..296df25753198 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/util/QueryExecutionListener.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/util/QueryExecutionListener.scala
@@ -20,6 +20,7 @@ package org.apache.spark.sql.util
 import scala.collection.JavaConverters._
 
 import org.apache.spark.annotation.{DeveloperApi, Experimental, InterfaceStability}
+import org.apache.spark.internal.Logging
 import org.apache.spark.scheduler.{SparkListener, SparkListenerEvent}
 import org.apache.spark.sql.SparkSession
 import org.apache.spark.sql.execution.QueryExecution
@@ -79,7 +79,8 @@ trait QueryExecutionListener {
 // The `loadExtensions` flag is used to indicate whether we should load the pre-defined,
 // user-specified listeners during construction. We should not do it when cloning this listener
 // manager, as we will copy all listeners to the cloned listener manager.
-class ExecutionListenerManager private[sql](session: SparkSession, loadExtensions: Boolean) {
+class ExecutionListenerManager private[sql](session: SparkSession, loadExtensions: Boolean)
+  extends Logging {
 
   private val listenerBus = new ExecutionListenerBus(session)

From 0bfc2408a5941d7da8d93582668ba77a7394ac66 Mon Sep 17 00:00:00 2001
From: Wenchen Fan
Date: Thu, 11 Oct 2018 23:41:08 +0800
Subject: [PATCH 6/6] fix a mistake

---
 .../scala/org/apache/spark/sql/util/QueryExecutionListener.scala | 1 +
 1 file changed, 1 insertion(+)

diff --git a/sql/core/src/main/scala/org/apache/spark/sql/util/QueryExecutionListener.scala b/sql/core/src/main/scala/org/apache/spark/sql/util/QueryExecutionListener.scala
index 296df25753198..1310fdfa1356b 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/util/QueryExecutionListener.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/util/QueryExecutionListener.scala
@@ -132,6 +132,7 @@ private[sql] class ExecutionListenerBus(session: SparkSession)
 
   override def onOtherEvent(event: SparkListenerEvent): Unit = event match {
     case e: SparkListenerSQLExecutionEnd => postToAll(e)
+    case _ =>
   }
 
   override protected def doPostEvent(
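
Usage sketch (illustrative, not part of the patch series): after these commits, `QueryExecutionListener` callbacks are delivered from the shared listener-bus queue rather than from the thread that ran the query, which is why the test diffs above drain the bus before asserting. The sketch below assumes only APIs exercised in the diffs (`listenerManager.register`, the `onSuccess`/`onFailure` signatures, `waitUntilEmpty`); `CountingListener` and the session settings are hypothetical, and `sparkContext.listenerBus` is `private[spark]`, so the drain step compiles only from Spark's own packages, as in the suites above.

import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.execution.QueryExecution
import org.apache.spark.sql.util.QueryExecutionListener

// Hypothetical listener: counts the successful executions it observes.
class CountingListener extends QueryExecutionListener {
  @volatile var successes = 0
  override def onSuccess(funcName: String, qe: QueryExecution, durationNs: Long): Unit = {
    successes += 1
  }
  override def onFailure(funcName: String, qe: QueryExecution, exception: Exception): Unit = ()
}

val spark = SparkSession.builder().master("local").appName("listener-demo").getOrCreate()
val listener = new CountingListener
spark.listenerManager.register(listener)

// Triggers a SQL execution whose end event carries executionName = Some("collect").
spark.range(10).collect()

// The callback now arrives asynchronously; drain the bus before asserting.
spark.sparkContext.listenerBus.waitUntilEmpty(1000)
assert(listener.successes == 1)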