diff --git a/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala b/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala index 4df90d7b6b0b..7c0b30997ecd 100644 --- a/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala +++ b/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala @@ -384,19 +384,23 @@ private[spark] class ApplicationMaster( // been set by the Thread executing the user class. val sc = waitForSparkContextInitialized() - // If there is no SparkContext at this point, just fail the app. - if (sc == null) { - finish(FinalApplicationStatus.FAILED, - ApplicationMaster.EXIT_SC_NOT_INITED, - "Timed out waiting for SparkContext.") - } else { - rpcEnv = sc.env.rpcEnv - val driverRef = runAMEndpoint( - sc.getConf.get("spark.driver.host"), - sc.getConf.get("spark.driver.port"), - isClusterMode = true) - registerAM(rpcEnv, driverRef, sc.ui.map(_.appUIAddress).getOrElse(""), securityMgr) - userClassThread.join() + if (!finished) { + // If there is no SparkContext at this point, just fail the app. + if (!sc.isDefined) { + finish(FinalApplicationStatus.FAILED, + ApplicationMaster.EXIT_SC_NOT_INITED, + "Timed out waiting for SparkContext.") + } else { + val sparkContext = sc.get + rpcEnv = sparkContext.env.rpcEnv + val driverRef = runAMEndpoint( + sparkContext.getConf.get("spark.driver.host"), + sparkContext.getConf.get("spark.driver.port"), + isClusterMode = true) + registerAM(rpcEnv, driverRef, sparkContext.ui.map(_.appUIAddress).getOrElse(""), + securityMgr) + userClassThread.join() + } } } @@ -503,7 +507,7 @@ private[spark] class ApplicationMaster( } } - private def waitForSparkContextInitialized(): SparkContext = { + private def waitForSparkContextInitialized(): Option[SparkContext] = { logInfo("Waiting for spark context initialization") sparkContextRef.synchronized { val totalWaitTime = sparkConf.get(AM_MAX_WAIT_TIME) @@ -513,13 +517,16 @@ private[spark] class ApplicationMaster( logInfo("Waiting for spark context initialization ... ") sparkContextRef.wait(10000L) } - - val sparkContext = sparkContextRef.get() - if (sparkContext == null) { - logError(("SparkContext did not initialize after waiting for %d ms. Please check earlier" - + " log output for errors. Failing the application.").format(totalWaitTime)) + if (!finished) { + val sparkContext = sparkContextRef.get() + if (sparkContext == null) { + logError(("SparkContext did not initialize after waiting for %d ms. Please check earlier" + + " log output for errors. Failing the application.").format(totalWaitTime)) + } else { + return Some(sparkContext) + } } - sparkContext + None } }