diff --git a/core/src/main/scala/org/apache/spark/deploy/mesos/MesosExternalShuffleService.scala b/core/src/main/scala/org/apache/spark/deploy/mesos/MesosExternalShuffleService.scala index 8ffcfc0878a42..c0f9129a423f9 100644 --- a/core/src/main/scala/org/apache/spark/deploy/mesos/MesosExternalShuffleService.scala +++ b/core/src/main/scala/org/apache/spark/deploy/mesos/MesosExternalShuffleService.scala @@ -17,69 +17,89 @@ package org.apache.spark.deploy.mesos -import java.net.SocketAddress import java.nio.ByteBuffer +import java.util.concurrent.{ConcurrentHashMap, TimeUnit} -import scala.collection.mutable +import scala.collection.JavaConverters._ import org.apache.spark.{Logging, SecurityManager, SparkConf} import org.apache.spark.deploy.ExternalShuffleService import org.apache.spark.network.client.{RpcResponseCallback, TransportClient} import org.apache.spark.network.shuffle.ExternalShuffleBlockHandler import org.apache.spark.network.shuffle.protocol.BlockTransferMessage -import org.apache.spark.network.shuffle.protocol.mesos.RegisterDriver +import org.apache.spark.network.shuffle.protocol.mesos.{RegisterDriver, ShuffleServiceHeartbeat} import org.apache.spark.network.util.TransportConf +import org.apache.spark.util.ThreadUtils /** * An RPC endpoint that receives registration requests from Spark drivers running on Mesos. * It detects driver termination and calls the cleanup callback to [[ExternalShuffleService]]. */ -private[mesos] class MesosExternalShuffleBlockHandler(transportConf: TransportConf) +private[mesos] class MesosExternalShuffleBlockHandler( + transportConf: TransportConf, + cleanerIntervalS: Long) extends ExternalShuffleBlockHandler(transportConf, null) with Logging { - // Stores a map of driver socket addresses to app ids - private val connectedApps = new mutable.HashMap[SocketAddress, String] + ThreadUtils.newDaemonSingleThreadScheduledExecutor("shuffle-cleaner-watcher") + .scheduleAtFixedRate(new CleanerThread(), 0, cleanerIntervalS, TimeUnit.SECONDS) + + // Stores a map of app id to app state (timeout value and last heartbeat) + private val connectedApps = new ConcurrentHashMap[String, AppState]() protected override def handleMessage( message: BlockTransferMessage, client: TransportClient, callback: RpcResponseCallback): Unit = { message match { - case RegisterDriverParam(appId) => + case RegisterDriverParam(appId, appState) => val address = client.getSocketAddress - logDebug(s"Received registration request from app $appId (remote address $address).") - if (connectedApps.contains(address)) { - val existingAppId = connectedApps(address) - if (!existingAppId.equals(appId)) { - logError(s"A new app '$appId' has connected to existing address $address, " + - s"removing previously registered app '$existingAppId'.") - applicationRemoved(existingAppId, true) - } + val timeout = appState.heartbeatTimeout + logInfo(s"Received registration request from app $appId (remote address $address, " + + s"heartbeat timeout $timeout ms).") + if (connectedApps.containsKey(appId)) { + logWarning(s"Received a registration request from app $appId, but it was already " + + s"registered") } - connectedApps(address) = appId + connectedApps.put(appId, appState) callback.onSuccess(ByteBuffer.allocate(0)) + case Heartbeat(appId) => + val address = client.getSocketAddress + Option(connectedApps.get(appId)) match { + case Some(existingAppState) => + logTrace(s"Received ShuffleServiceHeartbeat from app '$appId' (remote " + + s"address $address).") + existingAppState.lastHeartbeat = System.nanoTime() + 
case None => + logWarning(s"Received ShuffleServiceHeartbeat from an unknown app (remote " + + s"address $address, appId '$appId').") + } case _ => super.handleMessage(message, client, callback) } } - /** - * On connection termination, clean up shuffle files written by the associated application. - */ - override def connectionTerminated(client: TransportClient): Unit = { - val address = client.getSocketAddress - if (connectedApps.contains(address)) { - val appId = connectedApps(address) - logInfo(s"Application $appId disconnected (address was $address).") - applicationRemoved(appId, true /* cleanupLocalDirs */) - connectedApps.remove(address) - } else { - logWarning(s"Unknown $address disconnected.") - } - } - /** An extractor object for matching [[RegisterDriver]] message. */ private object RegisterDriverParam { - def unapply(r: RegisterDriver): Option[String] = Some(r.getAppId) + def unapply(r: RegisterDriver): Option[(String, AppState)] = + Some((r.getAppId, new AppState(r.getHeartbeatTimeoutMs, System.nanoTime()))) + } + + private object Heartbeat { + def unapply(h: ShuffleServiceHeartbeat): Option[String] = Some(h.getAppId) + } + + private class AppState(val heartbeatTimeout: Long, @volatile var lastHeartbeat: Long) + + private class CleanerThread extends Runnable { + override def run(): Unit = { + val now = System.nanoTime() + connectedApps.asScala.foreach { case (appId, appState) => + if (now - appState.lastHeartbeat > appState.heartbeatTimeout * 1000 * 1000) { + logInfo(s"Application $appId timed out. Removing shuffle files.") + connectedApps.remove(appId) + applicationRemoved(appId, true) + } + } + } } } @@ -93,7 +113,8 @@ private[mesos] class MesosExternalShuffleService(conf: SparkConf, securityManage protected override def newShuffleBlockHandler( conf: TransportConf): ExternalShuffleBlockHandler = { - new MesosExternalShuffleBlockHandler(conf) + val cleanerIntervalS = this.conf.getTimeAsSeconds("spark.shuffle.cleaner.interval", "30s") + new MesosExternalShuffleBlockHandler(conf, cleanerIntervalS) } } diff --git a/core/src/main/scala/org/apache/spark/executor/Executor.scala b/core/src/main/scala/org/apache/spark/executor/Executor.scala index ab5bde55e683a..a3ebaff2ee697 100644 --- a/core/src/main/scala/org/apache/spark/executor/Executor.scala +++ b/core/src/main/scala/org/apache/spark/executor/Executor.scala @@ -114,6 +114,19 @@ private[spark] class Executor( private val heartbeatReceiverRef = RpcUtils.makeDriverRef(HeartbeatReceiver.ENDPOINT_NAME, conf, env.rpcEnv) + /** + * When an executor is unable to send heartbeats to the driver more than `HEARTBEAT_MAX_FAILURES` + * times, it should kill itself. The default value is 60. It means we will retry sending + * heartbeats for about 10 minutes because the heartbeat interval is 10s. + */ + private val HEARTBEAT_MAX_FAILURES = conf.getInt("spark.executor.heartbeat.maxFailures", 60) + + /** + * Count the failure times of heartbeat. It should only be accessed in the heartbeat thread. Each + * successful heartbeat will reset it to 0. 
+ */ + private var heartbeatFailures = 0 + startDriverHeartbeater() def launchTask( @@ -218,6 +231,7 @@ private[spark] class Executor( threwException = false res } finally { + val releasedLocks = env.blockManager.releaseAllLocksForTask(taskId) val freedMemory = taskMemoryManager.cleanUpAllAllocatedMemory() if (freedMemory > 0) { val errMsg = s"Managed memory leak detected; size = $freedMemory bytes, TID = $taskId" @@ -227,6 +241,17 @@ private[spark] class Executor( logError(errMsg) } } + + if (releasedLocks.nonEmpty) { + val errMsg = + s"${releasedLocks.size} block locks were not released by TID = $taskId:\n" + + releasedLocks.mkString("[", ", ", "]") + if (conf.getBoolean("spark.storage.exceptionOnPinLeak", false) && !threwException) { + throw new SparkException(errMsg) + } else { + logError(errMsg) + } + } } val taskFinish = System.currentTimeMillis() @@ -452,8 +477,16 @@ private[spark] class Executor( logInfo("Told to re-register on heartbeat") env.blockManager.reregister() } + heartbeatFailures = 0 } catch { - case NonFatal(e) => logWarning("Issue communicating with driver in heartbeater", e) + case NonFatal(e) => + logWarning("Issue communicating with driver in heartbeater", e) + heartbeatFailures += 1 + if (heartbeatFailures >= HEARTBEAT_MAX_FAILURES) { + logError(s"Exit as unable to send heartbeats to driver " + + s"more than $HEARTBEAT_MAX_FAILURES times") + System.exit(ExecutorExitCode.HEARTBEAT_FAILURE) + } } } diff --git a/core/src/main/scala/org/apache/spark/executor/ExecutorExitCode.scala b/core/src/main/scala/org/apache/spark/executor/ExecutorExitCode.scala index ea36fb60bd540..99858f785600d 100644 --- a/core/src/main/scala/org/apache/spark/executor/ExecutorExitCode.scala +++ b/core/src/main/scala/org/apache/spark/executor/ExecutorExitCode.scala @@ -39,6 +39,12 @@ object ExecutorExitCode { /** ExternalBlockStore failed to create a local temporary directory after many attempts. */ val EXTERNAL_BLOCK_STORE_FAILED_TO_CREATE_DIR = 55 + /** + * Executor is unable to send heartbeats to the driver more than + * "spark.executor.heartbeat.maxFailures" times. + */ + val HEARTBEAT_FAILURE = 56 + def explainExitCode(exitCode: Int): String = { exitCode match { case UNCAUGHT_EXCEPTION => "Uncaught exception" @@ -51,6 +57,8 @@ object ExecutorExitCode { // TODO: replace external block store with concrete implementation name case EXTERNAL_BLOCK_STORE_FAILED_TO_CREATE_DIR => "ExternalBlockStore failed to create a local temporary directory." + case HEARTBEAT_FAILURE => + "Unable to send heartbeats to driver." case _ => "Unknown executor exit code (" + exitCode + ")" + ( if (exitCode > 128) { diff --git a/core/src/main/scala/org/apache/spark/memory/StorageMemoryPool.scala b/core/src/main/scala/org/apache/spark/memory/StorageMemoryPool.scala index 70af83b5ee092..89edaf58ebc29 100644 --- a/core/src/main/scala/org/apache/spark/memory/StorageMemoryPool.scala +++ b/core/src/main/scala/org/apache/spark/memory/StorageMemoryPool.scala @@ -119,13 +119,13 @@ private[memory] class StorageMemoryPool(lock: Object) extends MemoryPool(lock) w } /** - * Try to shrink the size of this storage memory pool by `spaceToFree` bytes. Return the number - * of bytes removed from the pool's capacity. + * Free space to shrink the size of this storage memory pool by `spaceToFree` bytes. + * Note: this method doesn't actually reduce the pool size but relies on the caller to do so. + * + * @return number of bytes to be removed from the pool's capacity. 
*/ - def shrinkPoolToFreeSpace(spaceToFree: Long): Long = lock.synchronized { - // First, shrink the pool by reclaiming free memory: + def freeSpaceToShrinkPool(spaceToFree: Long): Long = lock.synchronized { val spaceFreedByReleasingUnusedMemory = math.min(spaceToFree, memoryFree) - decrementPoolSize(spaceFreedByReleasingUnusedMemory) val remainingSpaceToFree = spaceToFree - spaceFreedByReleasingUnusedMemory if (remainingSpaceToFree > 0) { // If reclaiming free memory did not adequately shrink the pool, begin evicting blocks: @@ -134,7 +134,6 @@ private[memory] class StorageMemoryPool(lock: Object) extends MemoryPool(lock) w val spaceFreedByEviction = evictedBlocks.map(_._2.memSize).sum // When a block is released, BlockManager.dropFromMemory() calls releaseMemory(), so we do // not need to decrement _memoryUsed here. However, we do need to decrement the pool size. - decrementPoolSize(spaceFreedByEviction) spaceFreedByReleasingUnusedMemory + spaceFreedByEviction } else { spaceFreedByReleasingUnusedMemory diff --git a/core/src/main/scala/org/apache/spark/memory/UnifiedMemoryManager.scala b/core/src/main/scala/org/apache/spark/memory/UnifiedMemoryManager.scala index 829f054dba0e9..802087c82b713 100644 --- a/core/src/main/scala/org/apache/spark/memory/UnifiedMemoryManager.scala +++ b/core/src/main/scala/org/apache/spark/memory/UnifiedMemoryManager.scala @@ -57,8 +57,12 @@ private[spark] class UnifiedMemoryManager private[memory] ( storageRegionSize, maxMemory - storageRegionSize) { + assertInvariant() + // We always maintain this invariant: - assert(onHeapExecutionMemoryPool.poolSize + storageMemoryPool.poolSize == maxMemory) + private def assertInvariant(): Unit = { + assert(onHeapExecutionMemoryPool.poolSize + storageMemoryPool.poolSize == maxMemory) + } override def maxStorageMemory: Long = synchronized { maxMemory - onHeapExecutionMemoryPool.memoryUsed @@ -77,7 +81,7 @@ private[spark] class UnifiedMemoryManager private[memory] ( numBytes: Long, taskAttemptId: Long, memoryMode: MemoryMode): Long = synchronized { - assert(onHeapExecutionMemoryPool.poolSize + storageMemoryPool.poolSize == maxMemory) + assertInvariant() assert(numBytes >= 0) memoryMode match { case MemoryMode.ON_HEAP => @@ -99,9 +103,10 @@ private[spark] class UnifiedMemoryManager private[memory] ( math.max(storageMemoryPool.memoryFree, storageMemoryPool.poolSize - storageRegionSize) if (memoryReclaimableFromStorage > 0) { // Only reclaim as much space as is necessary and available: - val spaceReclaimed = storageMemoryPool.shrinkPoolToFreeSpace( + val spaceToReclaim = storageMemoryPool.freeSpaceToShrinkPool( math.min(extraMemoryNeeded, memoryReclaimableFromStorage)) - onHeapExecutionMemoryPool.incrementPoolSize(spaceReclaimed) + storageMemoryPool.decrementPoolSize(spaceToReclaim) + onHeapExecutionMemoryPool.incrementPoolSize(spaceToReclaim) } } } @@ -137,7 +142,7 @@ private[spark] class UnifiedMemoryManager private[memory] ( blockId: BlockId, numBytes: Long, evictedBlocks: mutable.Buffer[(BlockId, BlockStatus)]): Boolean = synchronized { - assert(onHeapExecutionMemoryPool.poolSize + storageMemoryPool.poolSize == maxMemory) + assertInvariant() assert(numBytes >= 0) if (numBytes > maxStorageMemory) { // Fail fast if the block simply won't fit diff --git a/core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcEnv.scala b/core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcEnv.scala index a53bc5ef4ffae..47f2f9d3cb647 100644 --- a/core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcEnv.scala +++ 
b/core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcEnv.scala @@ -581,7 +581,7 @@ private[netty] class NettyRpcHandler( private def internalReceive(client: TransportClient, message: ByteBuffer): RequestMessage = { val addr = client.getChannel().remoteAddress().asInstanceOf[InetSocketAddress] assert(addr != null) - val clientAddr = RpcAddress(addr.getHostName, addr.getPort) + val clientAddr = RpcAddress(addr.getHostString, addr.getPort) if (clients.putIfAbsent(client, JBoolean.TRUE) == null) { dispatcher.postToAll(RemoteProcessConnected(clientAddr)) } @@ -605,7 +605,7 @@ private[netty] class NettyRpcHandler( override def exceptionCaught(cause: Throwable, client: TransportClient): Unit = { val addr = client.getChannel.remoteAddress().asInstanceOf[InetSocketAddress] if (addr != null) { - val clientAddr = RpcAddress(addr.getHostName, addr.getPort) + val clientAddr = RpcAddress(addr.getHostString, addr.getPort) dispatcher.postToAll(RemoteProcessConnectionError(cause, clientAddr)) // If the remove RpcEnv listens to some address, we should also fire a // RemoteProcessConnectionError for the remote RpcEnv listening address @@ -625,7 +625,7 @@ private[netty] class NettyRpcHandler( val addr = client.getChannel.remoteAddress().asInstanceOf[InetSocketAddress] if (addr != null) { clients.remove(client) - val clientAddr = RpcAddress(addr.getHostName, addr.getPort) + val clientAddr = RpcAddress(addr.getHostString, addr.getPort) nettyEnv.removeOutbox(clientAddr) dispatcher.postToAll(RemoteProcessDisconnected(clientAddr)) val remoteEnvAddress = remoteAddresses.remove(clientAddr) diff --git a/core/src/main/scala/org/apache/spark/rpc/netty/RpcEndpointAddress.scala b/core/src/main/scala/org/apache/spark/rpc/netty/RpcEndpointAddress.scala index d2e94f943aba5..65e9b56f81a40 100644 --- a/core/src/main/scala/org/apache/spark/rpc/netty/RpcEndpointAddress.scala +++ b/core/src/main/scala/org/apache/spark/rpc/netty/RpcEndpointAddress.scala @@ -29,7 +29,7 @@ import org.apache.spark.rpc.RpcAddress * @param rpcAddress The socket address of the endpint. * @param name Name of the endpoint. */ -private[netty] case class RpcEndpointAddress(val rpcAddress: RpcAddress, val name: String) { +private[spark] case class RpcEndpointAddress(val rpcAddress: RpcAddress, val name: String) { require(name != null, "RpcEndpoint name must be provided.") @@ -44,7 +44,11 @@ private[netty] case class RpcEndpointAddress(val rpcAddress: RpcAddress, val nam } } -private[netty] object RpcEndpointAddress { +private[spark] object RpcEndpointAddress { + + def apply(host: String, port: Int, name: String): RpcEndpointAddress = { + new RpcEndpointAddress(host, port, name) + } def apply(sparkUrl: String): RpcEndpointAddress = { try { diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala index 832eef38eef56..77a8a195ffeb1 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala @@ -720,7 +720,16 @@ private[spark] class TaskSetManager( failedExecutors.getOrElseUpdate(index, new HashMap[String, Long]()). 
put(info.executorId, clock.getTimeMillis()) sched.dagScheduler.taskEnded(tasks(index), reason, null, null, info, taskMetrics) - addPendingTask(index) + + if (successful(index)) { + logInfo( + s"Task ${info.id} in stage ${taskSet.id} (TID $tid) failed, " + + "but another instance of the task has already succeeded, " + + "so not re-queuing the task to be re-executed.") + } else { + addPendingTask(index) + } + if (!isZombie && state != TaskState.KILLED && reason.isInstanceOf[TaskFailedReason] && reason.asInstanceOf[TaskFailedReason].countTowardsTaskFailures) { diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala index 505c161141c88..87f2dbf6cb9b6 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala @@ -179,6 +179,10 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val rpcEnv: Rp context.reply(true) case RemoveExecutor(executorId, reason) => + // We will remove the executor's state and cannot restore it. However, the connection + // between the driver and the executor may be still alive so that the executor won't exit + // automatically, so try to tell the executor to stop itself. See SPARK-13519. + executorDataMap.get(executorId).foreach(_.executorEndpoint.send(StopExecutor)) removeExecutor(executorId, reason) context.reply(true) @@ -263,7 +267,14 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val rpcEnv: Rp scheduler.executorLost(executorId, if (killed) ExecutorKilled else reason) listenerBus.post( SparkListenerExecutorRemoved(System.currentTimeMillis(), executorId, reason.toString)) - case None => logInfo(s"Asked to remove non-existent executor $executorId") + case None => + // SPARK-15262: If an executor is still alive even after the scheduler has removed + // its metadata, we may receive a heartbeat from that executor and tell its block + // manager to reregister itself. If that happens, the block manager master will know + // about the executor, but the scheduler will not. Therefore, we should remove the + // executor from the block manager when we hit this case. 
+ scheduler.sc.env.blockManager.master.removeExecutor(executorId) + logInfo(s"Asked to remove non-existent executor $executorId") } } diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/CoarseMesosSchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/CoarseMesosSchedulerBackend.scala index 7d08eae0b4871..1a0864270f582 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/CoarseMesosSchedulerBackend.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/CoarseMesosSchedulerBackend.scala @@ -18,20 +18,20 @@ package org.apache.spark.scheduler.cluster.mesos import java.io.File -import java.util.concurrent.locks.ReentrantLock import java.util.{Collections, List => JList} +import java.util.concurrent.locks.ReentrantLock import scala.collection.JavaConverters._ -import scala.collection.mutable.{HashMap, HashSet} +import scala.collection.mutable +import scala.collection.mutable.{Buffer, HashMap, HashSet} -import com.google.common.collect.HashBiMap -import org.apache.mesos.Protos.{TaskInfo => MesosTaskInfo, _} import org.apache.mesos.{Scheduler => MScheduler, SchedulerDriver} +import org.apache.mesos.Protos.{TaskInfo => MesosTaskInfo, _} -import org.apache.spark.{SecurityManager, SparkContext, SparkEnv, SparkException, TaskState} +import org.apache.spark.{SecurityManager, SparkContext, SparkException, TaskState} import org.apache.spark.network.netty.SparkTransportConf import org.apache.spark.network.shuffle.mesos.MesosExternalShuffleClient -import org.apache.spark.rpc.RpcAddress +import org.apache.spark.rpc.netty.RpcEndpointAddress import org.apache.spark.scheduler.{SlaveLost, TaskSchedulerImpl} import org.apache.spark.scheduler.cluster.CoarseGrainedSchedulerBackend import org.apache.spark.util.Utils @@ -60,28 +60,38 @@ private[spark] class CoarseMesosSchedulerBackend( // Maximum number of cores to acquire (TODO: we'll need more flexible controls here) val maxCores = conf.get("spark.cores.max", Int.MaxValue.toString).toInt + private[this] val shutdownTimeoutMS = + conf.getTimeAsMs("spark.mesos.coarse.shutdownTimeout", "10s") + .ensuring(_ >= 0, "spark.mesos.coarse.shutdownTimeout must be >= 0") + + // Synchronization protected by stateLock + private[this] var stopCalled: Boolean = false + // If shuffle service is enabled, the Spark driver will register with the shuffle service. // This is for cleaning up shuffle files reliably. private val shuffleServiceEnabled = conf.getBoolean("spark.shuffle.service.enabled", false) // Cores we have acquired with each Mesos task ID - val coresByTaskId = new HashMap[Int, Int] + val coresByTaskId = new HashMap[String, Int] var totalCoresAcquired = 0 - val slaveIdsWithExecutors = new HashSet[String] - - // Maping from slave Id to hostname - private val slaveIdToHost = new HashMap[String, String] - - val taskIdToSlaveId: HashBiMap[Int, String] = HashBiMap.create[Int, String] - // How many times tasks on each slave failed - val failuresBySlaveId: HashMap[String, Int] = new HashMap[String, Int] + // SlaveID -> Slave + // This map accumulates entries for the duration of the job. Slaves are never deleted, because + // we need to maintain e.g. failure state and connection state. + private val slaves = new HashMap[String, Slave] /** - * The total number of executors we aim to have. Undefined when not using dynamic allocation - * and before the ExecutorAllocatorManager calls [[doRequestTotalExecutors]]. + * The total number of executors we aim to have. 
Undefined when not using dynamic allocation. + * Initially set to 0 when using dynamic allocation, the executor allocation manager will send + * the real initial limit later. */ - private var executorLimitOption: Option[Int] = None + private var executorLimitOption: Option[Int] = { + if (Utils.isDynamicAllocationEnabled(conf)) { + Some(0) + } else { + None + } + } /** * Return the current executor limit, which may be [[Int.MaxValue]] @@ -89,13 +99,11 @@ private[spark] class CoarseMesosSchedulerBackend( */ private[mesos] def executorLimit: Int = executorLimitOption.getOrElse(Int.MaxValue) - private val pendingRemovedSlaveIds = new HashSet[String] - // private lock object protecting mutable state above. Using the intrinsic lock // may lead to deadlocks since the superclass might also try to lock private val stateLock = new ReentrantLock - val extraCoresPerSlave = conf.getInt("spark.mesos.extra.cores", 0) + val extraCoresPerExecutor = conf.getInt("spark.mesos.extra.cores", 0) // Offer constraints private val slaveOfferConstraints = @@ -105,27 +113,32 @@ private[spark] class CoarseMesosSchedulerBackend( private val rejectOfferDurationForUnmetConstraints = getRejectOfferDurationForUnmetConstraints(sc) - // A client for talking to the external shuffle service, if it is a + // A client for talking to the external shuffle service private val mesosExternalShuffleClient: Option[MesosExternalShuffleClient] = { if (shuffleServiceEnabled) { - Some(new MesosExternalShuffleClient( - SparkTransportConf.fromSparkConf(conf, "shuffle"), - securityManager, - securityManager.isAuthenticationEnabled(), - securityManager.isSaslEncryptionEnabled())) + Some(getShuffleClient()) } else { None } } + // This method is factored out for testability + protected def getShuffleClient(): MesosExternalShuffleClient = { + new MesosExternalShuffleClient( + SparkTransportConf.fromSparkConf(conf, "shuffle"), + securityManager, + securityManager.isAuthenticationEnabled(), + securityManager.isSaslEncryptionEnabled()) + } + var nextMesosTaskId = 0 @volatile var appId: String = _ - def newMesosTaskId(): Int = { + def newMesosTaskId(): String = { val id = nextMesosTaskId nextMesosTaskId += 1 - id + id.toString } override def start() { @@ -136,11 +149,12 @@ private[spark] class CoarseMesosSchedulerBackend( sc.sparkUser, sc.appName, sc.conf, - sc.ui.map(_.appUIAddress)) + sc.conf.getOption("spark.mesos.driver.webui.url").orElse(sc.ui.map(_.appUIAddress)) + ) startScheduler(driver) } - def createCommand(offer: Offer, numCores: Int, taskId: Int): CommandInfo = { + def createCommand(offer: Offer, numCores: Int, taskId: String): CommandInfo = { val executorSparkHome = conf.getOption("spark.mesos.executor.home") .orElse(sc.getSparkHome()) .getOrElse { @@ -179,12 +193,12 @@ private[spark] class CoarseMesosSchedulerBackend( .orElse(Option(System.getenv("SPARK_EXECUTOR_URI"))) if (uri.isEmpty) { - val runScript = new File(executorSparkHome, "./bin/spark-class").getCanonicalPath + val runScript = new File(executorSparkHome, "./bin/spark-class").getPath command.setValue( "%s \"%s\" org.apache.spark.executor.CoarseGrainedExecutorBackend" .format(prefixEnv, runScript) + s" --driver-url $driverURL" + - s" --executor-id ${offer.getSlaveId.getValue}" + + s" --executor-id $taskId" + s" --hostname ${offer.getHostname}" + s" --cores $numCores" + s" --app-id $appId") @@ -192,12 +206,11 @@ private[spark] class CoarseMesosSchedulerBackend( // Grab everything to the first '.'. We'll use that and '*' to // glob the directory "correctly". 
val basename = uri.get.split('/').last.split('.').head - val executorId = sparkExecutorId(offer.getSlaveId.getValue, taskId.toString) command.setValue( s"cd $basename*; $prefixEnv " + - "./bin/spark-class org.apache.spark.executor.CoarseGrainedExecutorBackend" + + "./bin/spark-class org.apache.spark.executor.CoarseGrainedExecutorBackend" + s" --driver-url $driverURL" + - s" --executor-id $executorId" + + s" --executor-id $taskId" + s" --hostname ${offer.getHostname}" + s" --cores $numCores" + s" --app-id $appId") @@ -215,10 +228,10 @@ private[spark] class CoarseMesosSchedulerBackend( if (conf.contains("spark.testing")) { "driverURL" } else { - sc.env.rpcEnv.uriOf( - SparkEnv.driverActorSystemName, - RpcAddress(conf.get("spark.driver.host"), conf.get("spark.driver.port").toInt), - CoarseGrainedSchedulerBackend.ENDPOINT_NAME) + RpcEndpointAddress( + conf.get("spark.driver.host"), + conf.get("spark.driver.port").toInt, + CoarseGrainedSchedulerBackend.ENDPOINT_NAME).toString } } @@ -245,113 +258,221 @@ private[spark] class CoarseMesosSchedulerBackend( */ override def resourceOffers(d: SchedulerDriver, offers: JList[Offer]) { stateLock.synchronized { - val filters = Filters.newBuilder().setRefuseSeconds(5).build() - for (offer <- offers.asScala) { + if (stopCalled) { + logDebug("Ignoring offers during shutdown") + // Driver should simply return a stopped status on race + // condition between this.stop() and completing here + offers.asScala.map(_.getId).foreach(d.declineOffer) + return + } + + logDebug(s"Received ${offers.size} resource offers.") + + val (matchedOffers, unmatchedOffers) = offers.asScala.partition { offer => val offerAttributes = toAttributeMap(offer.getAttributesList) - val meetsConstraints = matchesAttributeRequirements(slaveOfferConstraints, offerAttributes) + matchesAttributeRequirements(slaveOfferConstraints, offerAttributes) + } + + declineUnmatchedOffers(d, unmatchedOffers) + handleMatchedOffers(d, matchedOffers) + } + } + + private def declineUnmatchedOffers(d: SchedulerDriver, offers: Buffer[Offer]): Unit = { + for (offer <- offers) { + val id = offer.getId.getValue + val offerAttributes = toAttributeMap(offer.getAttributesList) + val mem = getResource(offer.getResourcesList, "mem") + val cpus = getResource(offer.getResourcesList, "cpus") + val filters = Filters.newBuilder() + .setRefuseSeconds(rejectOfferDurationForUnmetConstraints).build() + + logDebug(s"Declining offer: $id with attributes: $offerAttributes mem: $mem cpu: $cpus" + + s" for $rejectOfferDurationForUnmetConstraints seconds") + + d.declineOffer(offer.getId, filters) + } + } + + /** + * Launches executors on accepted offers, and declines unused offers. Executors are launched + * round-robin on offers. + * + * @param d SchedulerDriver + * @param offers Mesos offers that match attribute constraints + */ + private def handleMatchedOffers(d: SchedulerDriver, offers: Buffer[Offer]): Unit = { + val tasks = buildMesosTasks(offers) + for (offer <- offers) { + val offerAttributes = toAttributeMap(offer.getAttributesList) + val offerMem = getResource(offer.getResourcesList, "mem") + val offerCpus = getResource(offer.getResourcesList, "cpus") + val id = offer.getId.getValue + + if (tasks.contains(offer.getId)) { // accept + val offerTasks = tasks(offer.getId) + + logDebug(s"Accepting offer: $id with attributes: $offerAttributes " + + s"mem: $offerMem cpu: $offerCpus. 
Launching ${offerTasks.size} Mesos tasks.") + + for (task <- offerTasks) { + val taskId = task.getTaskId + val mem = getResource(task.getResourcesList, "mem") + val cpus = getResource(task.getResourcesList, "cpus") + + logDebug(s"Launching Mesos task: ${taskId.getValue} with mem: $mem cpu: $cpus.") + } + + d.launchTasks( + Collections.singleton(offer.getId), + offerTasks.asJava) + } else { // decline + logDebug(s"Declining offer: $id with attributes: $offerAttributes " + + s"mem: $offerMem cpu: $offerCpus") + + d.declineOffer(offer.getId) + } + } + } + + /** + * Returns a map from OfferIDs to the tasks to launch on those offers. In order to maximize + * per-task memory and IO, tasks are round-robin assigned to offers. + * + * @param offers Mesos offers that match attribute constraints + * @return A map from OfferID to a list of Mesos tasks to launch on that offer + */ + private def buildMesosTasks(offers: Buffer[Offer]): Map[OfferID, List[MesosTaskInfo]] = { + // offerID -> tasks + val tasks = new HashMap[OfferID, List[MesosTaskInfo]].withDefaultValue(Nil) + + // offerID -> resources + val remainingResources = mutable.Map(offers.map(offer => + (offer.getId.getValue, offer.getResourcesList)): _*) + + var launchTasks = true + + // TODO(mgummelt): combine offers for a single slave + // + // round-robin create executors on the available offers + while (launchTasks) { + launchTasks = false + + for (offer <- offers) { val slaveId = offer.getSlaveId.getValue - val mem = getResource(offer.getResourcesList, "mem") - val cpus = getResource(offer.getResourcesList, "cpus").toInt - val id = offer.getId.getValue - if (meetsConstraints) { - if (taskIdToSlaveId.size < executorLimit && - totalCoresAcquired < maxCores && - mem >= calculateTotalMemory(sc) && - cpus >= 1 && - failuresBySlaveId.getOrElse(slaveId, 0) < MAX_SLAVE_FAILURES && - !slaveIdsWithExecutors.contains(slaveId)) { - // Launch an executor on the slave - val cpusToUse = math.min(cpus, maxCores - totalCoresAcquired) - totalCoresAcquired += cpusToUse - val taskId = newMesosTaskId() - taskIdToSlaveId.put(taskId, slaveId) - slaveIdsWithExecutors += slaveId - coresByTaskId(taskId) = cpusToUse - // Gather cpu resources from the available resources and use them in the task. 
- val (remainingResources, cpuResourcesToUse) = - partitionResources(offer.getResourcesList, "cpus", cpusToUse) - val (_, memResourcesToUse) = - partitionResources(remainingResources.asJava, "mem", calculateTotalMemory(sc)) - val taskBuilder = MesosTaskInfo.newBuilder() - .setTaskId(TaskID.newBuilder().setValue(taskId.toString).build()) - .setSlaveId(offer.getSlaveId) - .setCommand(createCommand(offer, cpusToUse + extraCoresPerSlave, taskId)) - .setName("Task " + taskId) - .addAllResources(cpuResourcesToUse.asJava) - .addAllResources(memResourcesToUse.asJava) - - sc.conf.getOption("spark.mesos.executor.docker.image").foreach { image => - MesosSchedulerBackendUtil - .setupContainerBuilderDockerInfo(image, sc.conf, taskBuilder.getContainerBuilder()) - } - - // Accept the offer and launch the task - logDebug(s"Accepting offer: $id with attributes: $offerAttributes mem: $mem cpu: $cpus") - slaveIdToHost(offer.getSlaveId.getValue) = offer.getHostname - d.launchTasks( - Collections.singleton(offer.getId), - Collections.singleton(taskBuilder.build()), filters) - } else { - // Decline the offer - logDebug(s"Declining offer: $id with attributes: $offerAttributes mem: $mem cpu: $cpus") - d.declineOffer(offer.getId) + val offerId = offer.getId.getValue + val resources = remainingResources(offerId) + + if (canLaunchTask(slaveId, resources)) { + // Create a task + launchTasks = true + val taskId = newMesosTaskId() + val offerCPUs = getResource(resources, "cpus").toInt + + val taskCPUs = executorCores(offerCPUs) + val taskMemory = executorMemory(sc) + + slaves.getOrElseUpdate(slaveId, new Slave(offer.getHostname)).taskIDs.add(taskId) + + val (afterCPUResources, cpuResourcesToUse) = + partitionResources(resources, "cpus", taskCPUs) + val (resourcesLeft, memResourcesToUse) = + partitionResources(afterCPUResources.asJava, "mem", taskMemory) + + val taskBuilder = MesosTaskInfo.newBuilder() + .setTaskId(TaskID.newBuilder().setValue(taskId.toString).build()) + .setSlaveId(offer.getSlaveId) + .setCommand(createCommand(offer, taskCPUs + extraCoresPerExecutor, taskId)) + .setName("Task " + taskId) + .addAllResources(cpuResourcesToUse.asJava) + .addAllResources(memResourcesToUse.asJava) + + sc.conf.getOption("spark.mesos.executor.docker.image").foreach { image => + MesosSchedulerBackendUtil + .setupContainerBuilderDockerInfo(image, sc.conf, taskBuilder.getContainerBuilder) } - } else { - // This offer does not meet constraints. We don't need to see it again. - // Decline the offer for a long period of time. 
- logDebug(s"Declining offer: $id with attributes: $offerAttributes mem: $mem cpu: $cpus" - + s" for $rejectOfferDurationForUnmetConstraints seconds") - d.declineOffer(offer.getId, Filters.newBuilder() - .setRefuseSeconds(rejectOfferDurationForUnmetConstraints).build()) + + tasks(offer.getId) ::= taskBuilder.build() + remainingResources(offerId) = resourcesLeft.asJava + totalCoresAcquired += taskCPUs + coresByTaskId(taskId) = taskCPUs } } } + tasks.toMap } + private def canLaunchTask(slaveId: String, resources: JList[Resource]): Boolean = { + val offerMem = getResource(resources, "mem") + val offerCPUs = getResource(resources, "cpus").toInt + val cpus = executorCores(offerCPUs) + val mem = executorMemory(sc) + + cpus > 0 && + cpus <= offerCPUs && + cpus + totalCoresAcquired <= maxCores && + mem <= offerMem && + numExecutors() < executorLimit && + slaves.get(slaveId).map(_.taskFailures).getOrElse(0) < MAX_SLAVE_FAILURES + } + + private def executorCores(offerCPUs: Int): Int = { + sc.conf.getInt("spark.executor.cores", + math.min(offerCPUs, maxCores - totalCoresAcquired)) + } override def statusUpdate(d: SchedulerDriver, status: TaskStatus) { - val taskId = status.getTaskId.getValue.toInt - val state = status.getState - logInfo(s"Mesos task $taskId is now $state") - val slaveId: String = status.getSlaveId.getValue + val taskId = status.getTaskId.getValue + val slaveId = status.getSlaveId.getValue + val state = TaskState.fromMesos(status.getState) + + logInfo(s"Mesos task $taskId is now ${status.getState}") + stateLock.synchronized { + val slave = slaves(slaveId) + // If the shuffle service is enabled, have the driver register with each one of the // shuffle services. This allows the shuffle services to clean up state associated with // this application when the driver exits. There is currently not a great way to detect // this through Mesos, since the shuffle services are set up independently. - if (TaskState.fromMesos(state).equals(TaskState.RUNNING) && - slaveIdToHost.contains(slaveId) && - shuffleServiceEnabled) { + if (state.equals(TaskState.RUNNING) && + shuffleServiceEnabled && + !slave.shuffleRegistered) { assume(mesosExternalShuffleClient.isDefined, "External shuffle client was not instantiated even though shuffle service is enabled.") // TODO: Remove this and allow the MesosExternalShuffleService to detect // framework termination when new Mesos Framework HTTP API is available. 
val externalShufflePort = conf.getInt("spark.shuffle.service.port", 7337) - val hostname = slaveIdToHost.remove(slaveId).get + logDebug(s"Connecting to shuffle service on slave $slaveId, " + - s"host $hostname, port $externalShufflePort for app ${conf.getAppId}") + s"host ${slave.hostname}, port $externalShufflePort for app ${conf.getAppId}") + mesosExternalShuffleClient.get - .registerDriverWithShuffleService(hostname, externalShufflePort) + .registerDriverWithShuffleService( + slave.hostname, + externalShufflePort, + sc.conf.getTimeAsMs("spark.storage.blockManagerSlaveTimeoutMs", + s"${sc.conf.getTimeAsMs("spark.network.timeout", "120s")}ms"), + sc.conf.getTimeAsMs("spark.executor.heartbeatInterval", "10s")) + slave.shuffleRegistered = true } - if (TaskState.isFinished(TaskState.fromMesos(state))) { - val slaveId = taskIdToSlaveId.get(taskId) - slaveIdsWithExecutors -= slaveId - taskIdToSlaveId.remove(taskId) + if (TaskState.isFinished(state)) { // Remove the cores we have remembered for this task, if it's in the hashmap for (cores <- coresByTaskId.get(taskId)) { totalCoresAcquired -= cores coresByTaskId -= taskId } // If it was a failure, mark the slave as failed for blacklisting purposes - if (TaskState.isFailed(TaskState.fromMesos(state))) { - failuresBySlaveId(slaveId) = failuresBySlaveId.getOrElse(slaveId, 0) + 1 - if (failuresBySlaveId(slaveId) >= MAX_SLAVE_FAILURES) { + if (TaskState.isFailed(state)) { + slave.taskFailures += 1 + + if (slave.taskFailures >= MAX_SLAVE_FAILURES) { logInfo(s"Blacklisting Mesos slave $slaveId due to too many failures; " + "is Spark installed on it?") } } - executorTerminated(d, slaveId, s"Executor finished with state $state") + executorTerminated(d, slaveId, taskId, s"Executor finished with state $state") // In case we'd rejected everything before but have now lost a node d.reviveOffers() } @@ -364,7 +485,35 @@ private[spark] class CoarseMesosSchedulerBackend( } override def stop() { - super.stop() + // Make sure we're not launching tasks during shutdown + stateLock.synchronized { + if (stopCalled) { + logWarning("Stop called multiple times, ignoring") + return + } + stopCalled = true + super.stop() + } + + // Wait for executors to report done, or else mesosDriver.stop() will forcefully kill them. + // See SPARK-12330 + val startTime = System.nanoTime() + + // slaveIdsWithExecutors has no memory barrier, so this is eventually consistent + while (numExecutors() > 0 && + System.nanoTime() - startTime < shutdownTimeoutMS * 1000L * 1000L) { + Thread.sleep(100) + } + + if (numExecutors() > 0) { + logWarning(s"Timed out waiting for ${numExecutors()} remaining executors " + + s"to terminate within $shutdownTimeoutMS ms. This may leave temporary files " + + "on the mesos nodes.") + } + + // Close the mesos external shuffle client if used + mesosExternalShuffleClient.foreach(_.close()) + if (mesosDriver != null) { mesosDriver.stop() } @@ -373,40 +522,26 @@ private[spark] class CoarseMesosSchedulerBackend( override def frameworkMessage(d: SchedulerDriver, e: ExecutorID, s: SlaveID, b: Array[Byte]) {} /** - * Called when a slave is lost or a Mesos task finished. Update local view on - * what tasks are running and remove the terminated slave from the list of pending - * slave IDs that we might have asked to be killed. It also notifies the driver - * that an executor was removed. + * Called when a slave is lost or a Mesos task finished. Updates local view on + * what tasks are running. It also notifies the driver that an executor was removed. 
*/ - private def executorTerminated(d: SchedulerDriver, slaveId: String, reason: String): Unit = { + private def executorTerminated( + d: SchedulerDriver, + slaveId: String, + taskId: String, + reason: String): Unit = { stateLock.synchronized { - if (slaveIdsWithExecutors.contains(slaveId)) { - val slaveIdToTaskId = taskIdToSlaveId.inverse() - if (slaveIdToTaskId.containsKey(slaveId)) { - val taskId: Int = slaveIdToTaskId.get(slaveId) - taskIdToSlaveId.remove(taskId) - removeExecutor(sparkExecutorId(slaveId, taskId.toString), SlaveLost(reason)) - } - // TODO: This assumes one Spark executor per Mesos slave, - // which may no longer be true after SPARK-5095 - pendingRemovedSlaveIds -= slaveId - slaveIdsWithExecutors -= slaveId - } + removeExecutor(taskId, SlaveLost(reason)) + slaves(slaveId).taskIDs.remove(taskId) } } - private def sparkExecutorId(slaveId: String, taskId: String): String = { - s"$slaveId/$taskId" - } - override def slaveLost(d: SchedulerDriver, slaveId: SlaveID): Unit = { logInfo(s"Mesos slave lost: ${slaveId.getValue}") - executorTerminated(d, slaveId.getValue, "Mesos slave lost: " + slaveId.getValue) } override def executorLost(d: SchedulerDriver, e: ExecutorID, s: SlaveID, status: Int): Unit = { - logInfo("Executor lost: %s, marking slave %s as lost".format(e.getValue, s.getValue)) - slaveLost(d, s) + logInfo("Mesos executor lost: %s".format(e.getValue)) } override def applicationId(): String = @@ -426,23 +561,26 @@ private[spark] class CoarseMesosSchedulerBackend( override def doKillExecutors(executorIds: Seq[String]): Boolean = { if (mesosDriver == null) { logWarning("Asked to kill executors before the Mesos driver was started.") - return false - } - - val slaveIdToTaskId = taskIdToSlaveId.inverse() - for (executorId <- executorIds) { - val slaveId = executorId.split("/")(0) - if (slaveIdToTaskId.containsKey(slaveId)) { - mesosDriver.killTask( - TaskID.newBuilder().setValue(slaveIdToTaskId.get(slaveId).toString).build()) - pendingRemovedSlaveIds += slaveId - } else { - logWarning("Unable to find executor Id '" + executorId + "' in Mesos scheduler") + false + } else { + for (executorId <- executorIds) { + val taskId = TaskID.newBuilder().setValue(executorId).build() + mesosDriver.killTask(taskId) } + // no need to adjust `executorLimitOption` since the AllocationManager already communicated + // the desired limit through a call to `doRequestTotalExecutors`. + // See [[o.a.s.scheduler.cluster.CoarseGrainedSchedulerBackend.killExecutors]] + true } - // no need to adjust `executorLimitOption` since the AllocationManager already communicated - // the desired limit through a call to `doRequestTotalExecutors`. 
- // See [[o.a.s.scheduler.cluster.CoarseGrainedSchedulerBackend.killExecutors]] - true } + + private def numExecutors(): Int = { + slaves.values.map(_.taskIDs.size).sum + } +} + +private class Slave(val hostname: String) { + val taskIDs = new HashSet[String]() + var taskFailures = 0 + var shuffleRegistered = false } diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerUtils.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerUtils.scala index 573355ba58132..be710f9361b7c 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerUtils.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerUtils.scala @@ -25,7 +25,7 @@ import scala.collection.mutable.ArrayBuffer import scala.util.control.NonFatal import com.google.common.base.Splitter -import org.apache.mesos.{MesosSchedulerDriver, SchedulerDriver, Scheduler, Protos} +import org.apache.mesos.{MesosSchedulerDriver, Protos, Scheduler, SchedulerDriver} import org.apache.mesos.Protos._ import org.apache.mesos.protobuf.{ByteString, GeneratedMessage} import org.apache.spark.{SparkException, SparkConf, Logging, SparkContext} @@ -34,7 +34,7 @@ import org.apache.spark.util.Utils /** * Shared trait for implementing a Mesos Scheduler. This holds common state and helper - * methods the Mesos scheduler will use. + * methods and Mesos scheduler will use. */ private[mesos] trait MesosSchedulerUtils extends Logging { // Lock used to wait for scheduler to be registered @@ -106,44 +106,56 @@ private[mesos] trait MesosSchedulerUtils extends Logging { registerLatch.await() return } + @volatile + var error: Option[Exception] = None + // We create a new thread that will block inside `mesosDriver.run` + // until the scheduler exits new Thread(Utils.getFormattedClassName(this) + "-mesos-driver") { setDaemon(true) - override def run() { - mesosDriver = newDriver try { + mesosDriver = newDriver val ret = mesosDriver.run() logInfo("driver.run() returned with code " + ret) if (ret != null && ret.equals(Status.DRIVER_ABORTED)) { - System.exit(1) + error = Some(new SparkException("Error starting driver, DRIVER_ABORTED")) + markErr() } } catch { - case e: Exception => { + case e: Exception => logError("driver.run() failed", e) - System.exit(1) - } + error = Some(e) + markErr() } } }.start() registerLatch.await() + + // propagate any error to the calling thread. This ensures that SparkContext creation fails + // without leaving a broken context that won't be able to schedule any tasks + error.foreach(throw _) } } - /** - * Signal that the scheduler has registered with Mesos. - */ - protected def getResource(res: JList[Resource], name: String): Double = { + def getResource(res: JList[Resource], name: String): Double = { // A resource can have multiple values in the offer since it can either be from // a specific role or wildcard. res.asScala.filter(_.getName == name).map(_.getScalar.getValue).sum } + /** + * Signal that the scheduler has registered with Mesos. 
+ */ protected def markRegistered(): Unit = { registerLatch.countDown() } + protected def markErr(): Unit = { + registerLatch.countDown() + } + def createResource(name: String, amount: Double, role: Option[String] = None): Resource = { val builder = Resource.newBuilder() .setName(name) @@ -170,7 +182,7 @@ private[mesos] trait MesosSchedulerUtils extends Logging { var remain = amountToUse var requestedResources = new ArrayBuffer[Resource] val remainingResources = resources.asScala.map { - case r => { + case r => if (remain > 0 && r.getType == Value.Type.SCALAR && r.getScalar.getValue > 0.0 && @@ -182,7 +194,6 @@ private[mesos] trait MesosSchedulerUtils extends Logging { } else { r } - } } // Filter any resource that has depleted. @@ -214,7 +225,7 @@ private[mesos] trait MesosSchedulerUtils extends Logging { * @return */ protected def toAttributeMap(offerAttributes: JList[Attribute]): Map[String, GeneratedMessage] = { - offerAttributes.asScala.map(attr => { + offerAttributes.asScala.map { attr => val attrValue = attr.getType match { case Value.Type.SCALAR => attr.getScalar case Value.Type.RANGES => attr.getRanges @@ -222,7 +233,7 @@ private[mesos] trait MesosSchedulerUtils extends Logging { case Value.Type.TEXT => attr.getText } (attr.getName, attrValue) - }).toMap + }.toMap } @@ -269,11 +280,11 @@ private[mesos] trait MesosSchedulerUtils extends Logging { * are separated by ':'. The ':' implies equality (for singular values) and "is one of" for * multiple values (comma separated). For example: * {{{ - * parseConstraintString("tachyon:true;zone:us-east-1a,us-east-1b") + * parseConstraintString("os:centos7;zone:us-east-1a,us-east-1b") * // would result in * * Map( - * "tachyon" -> Set("true"), + * "os" -> Set("centos7"), * "zone": -> Set("us-east-1a", "us-east-1b") * ) * }}} @@ -317,6 +328,8 @@ private[mesos] trait MesosSchedulerUtils extends Logging { private val MEMORY_OVERHEAD_FRACTION = 0.10 private val MEMORY_OVERHEAD_MINIMUM = 384 + def calculateTotalMemory(sc: SparkContext): Int = executorMemory(sc) + /** * Return the amount of memory to allocate to each executor, taking into account * container overheads. 
@@ -324,7 +337,7 @@ private[mesos] trait MesosSchedulerUtils extends Logging { * @return memory requirement as (0.1 * ) or MEMORY_OVERHEAD_MINIMUM * (whichever is larger) */ - def calculateTotalMemory(sc: SparkContext): Int = { + def executorMemory(sc: SparkContext): Int = { sc.conf.getInt("spark.mesos.executor.memoryOverhead", math.max(MEMORY_OVERHEAD_FRACTION * sc.executorMemory, MEMORY_OVERHEAD_MINIMUM).toInt) + sc.executorMemory diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala index 538272dc00db6..288f756bca39b 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala @@ -19,12 +19,14 @@ package org.apache.spark.storage import java.io._ import java.nio.{ByteBuffer, MappedByteBuffer} +import java.util.concurrent.ConcurrentHashMap import scala.collection.mutable.{ArrayBuffer, HashMap} import scala.concurrent.duration._ import scala.concurrent.{Await, ExecutionContext, Future} import scala.util.Random import scala.util.control.NonFatal +import scala.collection.JavaConverters._ import sun.nio.ch.DirectBuffer @@ -65,7 +67,7 @@ private[spark] class BlockManager( val master: BlockManagerMaster, defaultSerializer: Serializer, val conf: SparkConf, - memoryManager: MemoryManager, + val memoryManager: MemoryManager, mapOutputTracker: MapOutputTracker, shuffleManager: ShuffleManager, blockTransferService: BlockTransferService, @@ -164,6 +166,11 @@ private[spark] class BlockManager( * loaded yet. */ private lazy val compressionCodec: CompressionCodec = CompressionCodec.createCodec(conf) + // Blocks that are being removed by another thread + val pendingToRemove = new ConcurrentHashMap[BlockId, Long]() + + private val NON_TASK_WRITER = -1024L + /** * Initializes the BlockManager with the given appId. This is not performed in the constructor as * the appId may not be known at BlockManager instantiation time (in particular for the driver, @@ -1025,54 +1032,58 @@ private[spark] class BlockManager( val info = blockInfo.get(blockId).orNull // If the block has not already been dropped - if (info != null) { - info.synchronized { - // required ? As of now, this will be invoked only for blocks which are ready - // But in case this changes in future, adding for consistency sake. - if (!info.waitForReady()) { - // If we get here, the block write failed. - logWarning(s"Block $blockId was marked as failure. Nothing to drop") - return None - } else if (blockInfo.get(blockId).isEmpty) { - logWarning(s"Block $blockId was already dropped.") - return None - } - var blockIsUpdated = false - val level = info.level + if (info != null && pendingToRemove.putIfAbsent(blockId, currentTaskAttemptId) == 0L) { + try { + info.synchronized { + // required ? As of now, this will be invoked only for blocks which are ready + // But in case this changes in future, adding for consistency sake. + if (!info.waitForReady()) { + // If we get here, the block write failed. + logWarning(s"Block $blockId was marked as failure. 
Nothing to drop") + return None + } else if (blockInfo.get(blockId).isEmpty) { + logWarning(s"Block $blockId was already dropped.") + return None + } + var blockIsUpdated = false + val level = info.level - // Drop to disk, if storage level requires - if (level.useDisk && !diskStore.contains(blockId)) { - logInfo(s"Writing block $blockId to disk") - data() match { - case Left(elements) => - diskStore.putArray(blockId, elements, level, returnValues = false) - case Right(bytes) => - diskStore.putBytes(blockId, bytes, level) + // Drop to disk, if storage level requires + if (level.useDisk && !diskStore.contains(blockId)) { + logInfo(s"Writing block $blockId to disk") + data() match { + case Left(elements) => + diskStore.putArray(blockId, elements, level, returnValues = false) + case Right(bytes) => + diskStore.putBytes(blockId, bytes, level) + } + blockIsUpdated = true } - blockIsUpdated = true - } - // Actually drop from memory store - val droppedMemorySize = - if (memoryStore.contains(blockId)) memoryStore.getSize(blockId) else 0L - val blockIsRemoved = memoryStore.remove(blockId) - if (blockIsRemoved) { - blockIsUpdated = true - } else { - logWarning(s"Block $blockId could not be dropped from memory as it does not exist") - } + // Actually drop from memory store + val droppedMemorySize = + if (memoryStore.contains(blockId)) memoryStore.getSize(blockId) else 0L + val blockIsRemoved = memoryStore.remove(blockId) + if (blockIsRemoved) { + blockIsUpdated = true + } else { + logWarning(s"Block $blockId could not be dropped from memory as it does not exist") + } - val status = getCurrentBlockStatus(blockId, info) - if (info.tellMaster) { - reportBlockStatus(blockId, info, status, droppedMemorySize) - } - if (!level.useDisk) { - // The block is completely gone from this node; forget it so we can put() it again later. - blockInfo.remove(blockId) - } - if (blockIsUpdated) { - return Some(status) + val status = getCurrentBlockStatus(blockId, info) + if (info.tellMaster) { + reportBlockStatus(blockId, info, status, droppedMemorySize) + } + if (!level.useDisk) { + // The block is completely gone from this node;forget it so we can put() it again later. + blockInfo.remove(blockId) + } + if (blockIsUpdated) { + return Some(status) + } } + } finally { + pendingToRemove.remove(blockId) } } None @@ -1108,27 +1119,32 @@ private[spark] class BlockManager( def removeBlock(blockId: BlockId, tellMaster: Boolean = true): Unit = { logDebug(s"Removing block $blockId") val info = blockInfo.get(blockId).orNull - if (info != null) { - info.synchronized { - // Removals are idempotent in disk store and memory store. At worst, we get a warning. 
- val removedFromMemory = memoryStore.remove(blockId) - val removedFromDisk = diskStore.remove(blockId) - val removedFromExternalBlockStore = - if (externalBlockStoreInitialized) externalBlockStore.remove(blockId) else false - if (!removedFromMemory && !removedFromDisk && !removedFromExternalBlockStore) { - logWarning(s"Block $blockId could not be removed as it was not found in either " + - "the disk, memory, or external block store") - } - blockInfo.remove(blockId) - val status = getCurrentBlockStatus(blockId, info) - if (tellMaster && info.tellMaster) { - reportBlockStatus(blockId, info, status) - } - Option(TaskContext.get()).foreach { tc => - val metrics = tc.taskMetrics() - val lastUpdatedBlocks = metrics.updatedBlocks.getOrElse(Seq[(BlockId, BlockStatus)]()) - metrics.updatedBlocks = Some(lastUpdatedBlocks ++ Seq((blockId, status))) + if (info != null && pendingToRemove.putIfAbsent(blockId, currentTaskAttemptId) == 0L) { + try { + info.synchronized { + val level = info.level + // Removals are idempotent in disk store and memory store. At worst, we get a warning. + val removedFromMemory = if (level.useMemory) memoryStore.remove(blockId) else false + val removedFromDisk = if (level.useDisk) diskStore.remove(blockId) else false + val removedFromExternalBlockStore = + if (externalBlockStoreInitialized) externalBlockStore.remove(blockId) else false + if (!removedFromMemory && !removedFromDisk && !removedFromExternalBlockStore) { + logWarning(s"Block $blockId could not be removed as it was not found in either " + + "the disk, memory, or external block store") + } + blockInfo.remove(blockId) + val status = getCurrentBlockStatus(blockId, info) + if (tellMaster && info.tellMaster) { + reportBlockStatus(blockId, info, status) + } + Option(TaskContext.get()).foreach { tc => + val metrics = tc.taskMetrics() + val lastUpdatedBlocks = metrics.updatedBlocks.getOrElse(Seq[(BlockId, BlockStatus)]()) + metrics.updatedBlocks = Some(lastUpdatedBlocks ++ Seq((blockId, status))) + } } + } finally { + pendingToRemove.remove(blockId) } } else { // The block has already been removed; do nothing. @@ -1151,14 +1167,19 @@ private[spark] class BlockManager( while (iterator.hasNext) { val entry = iterator.next() val (id, info, time) = (entry.getKey, entry.getValue.value, entry.getValue.timestamp) - if (time < cleanupTime && shouldDrop(id)) { - info.synchronized { - val level = info.level - if (level.useMemory) { memoryStore.remove(id) } - if (level.useDisk) { diskStore.remove(id) } - if (level.useOffHeap) { externalBlockStore.remove(id) } - iterator.remove() - logInfo(s"Dropped block $id") + if (time < cleanupTime && shouldDrop(id) && + pendingToRemove.putIfAbsent(id, currentTaskAttemptId) == 0L) { + try { + info.synchronized { + val level = info.level + if (level.useMemory) { memoryStore.remove(id) } + if (level.useDisk) { diskStore.remove(id) } + if (level.useOffHeap) { externalBlockStore.remove(id) } + iterator.remove() + logInfo(s"Dropped block $id") + } + } finally { + pendingToRemove.remove(id) } val status = getCurrentBlockStatus(id, info) reportBlockStatus(id, info, status) @@ -1166,6 +1187,32 @@ private[spark] class BlockManager( } } + private def currentTaskAttemptId: Long = { + Option(TaskContext.get()).map(_.taskAttemptId()).getOrElse(NON_TASK_WRITER) + } + + def getBlockInfo(blockId: BlockId): BlockInfo = { + blockInfo.get(blockId).orNull + } + + /** + * Release all locks held by the given task, clearing that task's pin bookkeeping + * structures and updating the global pin counts. 
This method should be called at the + * end of a task (either by a task completion handler or in `TaskRunner.run()`). + * + * @return the ids of blocks whose pins were released + */ + def releaseAllLocksForTask(taskAttemptId: Long): ArrayBuffer[BlockId] = { + var selectLocks = ArrayBuffer[BlockId]() + pendingToRemove.entrySet().asScala.foreach { entry => + if (entry.getValue == taskAttemptId) { + pendingToRemove.remove(entry.getKey) + selectLocks += entry.getKey + } + } + selectLocks + } + private def shouldCompress(blockId: BlockId): Boolean = { blockId match { case _: ShuffleBlockId => compressShuffle @@ -1239,6 +1286,7 @@ private[spark] class BlockManager( rpcEnv.stop(slaveEndpoint) blockInfo.clear() memoryStore.clear() + pendingToRemove.clear() diskStore.clear() if (externalBlockStoreInitialized) { externalBlockStore.clear() diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/AllJobsPage.scala b/core/src/main/scala/org/apache/spark/ui/jobs/AllJobsPage.scala index 47d6c3646c331..08dc17d5887e9 100644 --- a/core/src/main/scala/org/apache/spark/ui/jobs/AllJobsPage.scala +++ b/core/src/main/scala/org/apache/spark/ui/jobs/AllJobsPage.scala @@ -23,6 +23,8 @@ import javax.servlet.http.HttpServletRequest import scala.collection.mutable.{HashMap, ListBuffer} import scala.xml._ +import org.apache.commons.lang3.StringEscapeUtils + import org.apache.spark.JobExecutionStatus import org.apache.spark.ui.jobs.UIData.{ExecutorUIData, JobUIData} import org.apache.spark.ui.{ToolTips, UIUtils, WebUIPage} @@ -82,9 +84,10 @@ private[ui] class AllJobsPage(parent: JobsTab) extends WebUIPage("") { case JobExecutionStatus.UNKNOWN => "unknown" } - // The timeline library treats contents as HTML, so we have to escape them; for the - // data-title attribute string we have to escape them twice since that's in a string. + // The timeline library treats contents as HTML, so we have to escape them. We need to add + // extra layers of escaping in order to embed this in a Javascript string literal. val escapedDesc = Utility.escape(displayJobDescription) + val jsEscapedDesc = StringEscapeUtils.escapeEcmaScript(escapedDesc) val jobEventJsonAsStr = s""" |{ @@ -94,7 +97,7 @@ private[ui] class AllJobsPage(parent: JobsTab) extends WebUIPage("") { | 'end': new Date(${completionTime}), | 'content': '
' + | 'Status: ${status}
' + | 'Submitted: ${UIUtils.formatDate(new Date(submissionTime))}' + | '${ @@ -104,7 +107,7 @@ private[ui] class AllJobsPage(parent: JobsTab) extends WebUIPage("") { "" } }">' + - | '${escapedDesc} (Job ${jobId})
' + | '${jsEscapedDesc} (Job ${jobId})' |} """.stripMargin jobEventJsonAsStr diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/JobPage.scala b/core/src/main/scala/org/apache/spark/ui/jobs/JobPage.scala index 6a35f0e0a87a0..8c6a6681eabbc 100644 --- a/core/src/main/scala/org/apache/spark/ui/jobs/JobPage.scala +++ b/core/src/main/scala/org/apache/spark/ui/jobs/JobPage.scala @@ -24,6 +24,8 @@ import scala.xml.{NodeSeq, Node, Unparsed, Utility} import javax.servlet.http.HttpServletRequest +import org.apache.commons.lang3.StringEscapeUtils + import org.apache.spark.JobExecutionStatus import org.apache.spark.scheduler.StageInfo import org.apache.spark.ui.{ToolTips, UIUtils, WebUIPage} @@ -64,9 +66,10 @@ private[ui] class JobPage(parent: JobsTab) extends WebUIPage("job") { val submissionTime = stage.submissionTime.get val completionTime = stage.completionTime.getOrElse(System.currentTimeMillis()) - // The timeline library treats contents as HTML, so we have to escape them; for the - // data-title attribute string we have to escape them twice since that's in a string. + // The timeline library treats contents as HTML, so we have to escape them. We need to add + // extra layers of escaping in order to embed this in a Javascript string literal. val escapedName = Utility.escape(name) + val jsEscapedName = StringEscapeUtils.escapeEcmaScript(escapedName) s""" |{ | 'className': 'stage job-timeline-object ${status}', @@ -75,7 +78,7 @@ private[ui] class JobPage(parent: JobsTab) extends WebUIPage("job") { | 'end': new Date(${completionTime}), | 'content': '
' + | 'Status: ${status.toUpperCase}
' + | 'Submitted: ${UIUtils.formatDate(new Date(submissionTime))}' + | '${ @@ -85,7 +88,7 @@ private[ui] class JobPage(parent: JobsTab) extends WebUIPage("job") { "" } }">' + - | '${escapedName} (Stage ${stageId}.${attemptId})
', + | '${jsEscapedName} (Stage ${stageId}.${attemptId})', |} """.stripMargin } diff --git a/core/src/test/scala/org/apache/spark/memory/MemoryManagerSuite.scala b/core/src/test/scala/org/apache/spark/memory/MemoryManagerSuite.scala index 555b640cb4244..6a195ef7fe5b3 100644 --- a/core/src/test/scala/org/apache/spark/memory/MemoryManagerSuite.scala +++ b/core/src/test/scala/org/apache/spark/memory/MemoryManagerSuite.scala @@ -76,6 +76,21 @@ private[memory] trait MemoryManagerSuite extends SparkFunSuite with BeforeAndAft ms } + /** + * Make a mocked [[MemoryStore]] whose [[MemoryStore.evictBlocksToFreeSpace]] method is + * stubbed to always throw [[RuntimeException]]. + */ + protected def makeBadMemoryStore(mm: MemoryManager): MemoryStore = { + val ms = mock(classOf[MemoryStore], RETURNS_SMART_NULLS) + when(ms.evictBlocksToFreeSpace(any(), anyLong(), any())).thenAnswer(new Answer[Long] { + override def answer(invocation: InvocationOnMock): Long = { + throw new RuntimeException("bad memory store!") + } + }) + mm.setMemoryStore(ms) + ms + } + /** * Simulate the part of [[MemoryStore.evictBlocksToFreeSpace]] that releases storage memory. * diff --git a/core/src/test/scala/org/apache/spark/memory/UnifiedMemoryManagerSuite.scala b/core/src/test/scala/org/apache/spark/memory/UnifiedMemoryManagerSuite.scala index 6cc48597d38f9..46b6916a12fc2 100644 --- a/core/src/test/scala/org/apache/spark/memory/UnifiedMemoryManagerSuite.scala +++ b/core/src/test/scala/org/apache/spark/memory/UnifiedMemoryManagerSuite.scala @@ -255,4 +255,27 @@ class UnifiedMemoryManagerSuite extends MemoryManagerSuite with PrivateMethodTes assert(evictedBlocks.nonEmpty) } + test("SPARK-15260: atomically resize memory pools") { + val conf = new SparkConf() + .set("spark.memory.fraction", "1") + .set("spark.memory.storageFraction", "0") + .set("spark.testing.memory", "1000") + val mm = UnifiedMemoryManager(conf, numCores = 2) + makeBadMemoryStore(mm) + val memoryMode = MemoryMode.ON_HEAP + // Acquire 1000 then release 600 bytes of storage memory, leaving the + // storage memory pool at 1000 bytes but only 400 bytes of which are used. + assert(mm.acquireStorageMemory(dummyBlock, 1000L, evictedBlocks)) + mm.releaseStorageMemory(600L) + // Before the fix for SPARK-15260, we would first shrink the storage pool by the amount of + // unused storage memory (600 bytes), try to evict blocks, then enlarge the execution pool + // by the same amount. If the eviction threw an exception, then we would shrink one pool + // without enlarging the other, resulting in an assertion failure. 
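// Editor's aside -- a sketch of the post-fix ordering with hypothetical pool types (this is
// not the real UnifiedMemoryManager API): eviction runs before either pool size changes, so
// a memory store that throws (like the mocked "bad memory store" above) cannot leave the
// storage and execution pool sizes out of sync, and assertInvariant still holds afterwards.
object AtomicResizeSketch {
  final class Pool(var poolSize: Long, var memoryUsed: Long) {
    def memoryFree: Long = poolSize - memoryUsed
  }

  def borrowFromStorage(
      storage: Pool,
      execution: Pool,
      needed: Long)(evictToFree: Long => Long): Unit = synchronized {
    val target = math.min(needed, storage.memoryFree)
    // May throw; nothing has been resized yet, so the pool sizes are still consistent.
    val freed = evictToFree(target)
    storage.poolSize -= freed               // only reached once eviction succeeded
    execution.poolSize += freed
  }
}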
+ intercept[RuntimeException] { + mm.acquireExecutionMemory(1000L, 0, memoryMode) + } + val assertInvariant = PrivateMethod[Unit]('assertInvariant) + mm.invokePrivate[Unit](assertInvariant()) + } + } diff --git a/core/src/test/scala/org/apache/spark/scheduler/cluster/mesos/CoarseMesosSchedulerBackendSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/cluster/mesos/CoarseMesosSchedulerBackendSuite.scala index 525ee0d3bdc5a..db4dc2d3105d8 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/cluster/mesos/CoarseMesosSchedulerBackendSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/cluster/mesos/CoarseMesosSchedulerBackendSuite.scala @@ -17,26 +17,248 @@ package org.apache.spark.scheduler.cluster.mesos -import java.util import java.util.Collections -import org.apache.mesos.Protos.Value.Scalar -import org.apache.mesos.Protos._ +import scala.collection.JavaConverters._ +import scala.collection.mutable.ArrayBuffer + import org.apache.mesos.{Protos, Scheduler, SchedulerDriver} +import org.apache.mesos.Protos._ +import org.apache.mesos.Protos.Value.Scalar +import org.mockito.{ArgumentCaptor, Matchers} import org.mockito.Matchers._ import org.mockito.Mockito._ -import org.mockito.Matchers import org.scalatest.mock.MockitoSugar import org.scalatest.BeforeAndAfter +import org.apache.spark.{LocalSparkContext, SecurityManager, SparkConf, SparkContext, SparkFunSuite} +import org.apache.spark.network.shuffle.mesos.MesosExternalShuffleClient +import org.apache.spark.rpc.RpcEndpointRef import org.apache.spark.scheduler.TaskSchedulerImpl -import org.apache.spark.{LocalSparkContext, SparkConf, SparkContext, SecurityManager, SparkFunSuite} class CoarseMesosSchedulerBackendSuite extends SparkFunSuite with LocalSparkContext with MockitoSugar with BeforeAndAfter { + private var sparkConf: SparkConf = _ + private var driver: SchedulerDriver = _ + private var taskScheduler: TaskSchedulerImpl = _ + private var backend: CoarseMesosSchedulerBackend = _ + private var externalShuffleClient: MesosExternalShuffleClient = _ + private var driverEndpoint: RpcEndpointRef = _ + + test("mesos supports killing and limiting executors") { + setBackend() + sparkConf.set("spark.driver.host", "driverHost") + sparkConf.set("spark.driver.port", "1234") + + val minMem = backend.executorMemory(sc) + val minCpu = 4 + val offers = List((minMem, minCpu)) + + // launches a task on a valid offer + offerResources(offers) + verifyTaskLaunched("o1") + + // kills executors + backend.doRequestTotalExecutors(0) + assert(backend.doKillExecutors(Seq("0"))) + val taskID0 = createTaskId("0") + verify(driver, times(1)).killTask(taskID0) + + // doesn't launch a new task when requested executors == 0 + offerResources(offers, 2) + verifyDeclinedOffer(driver, createOfferId("o2")) + + // Launches a new task when requested executors is positive + backend.doRequestTotalExecutors(2) + offerResources(offers, 2) + verifyTaskLaunched("o2") + } + + test("mesos supports killing and relaunching tasks with executors") { + setBackend() + + // launches a task on a valid offer + val minMem = backend.executorMemory(sc) + 1024 + val minCpu = 4 + val offer1 = (minMem, minCpu) + val offer2 = (minMem, 1) + offerResources(List(offer1, offer2)) + verifyTaskLaunched("o1") + + // accounts for a killed task + val status = createTaskStatus("0", "s1", TaskState.TASK_KILLED) + backend.statusUpdate(driver, status) + verify(driver, times(1)).reviveOffers() + + // Launches a new task on a valid offer from the same slave + offerResources(List(offer2)) + 
verifyTaskLaunched("o2") + } + + test("mesos supports spark.executor.cores") { + val executorCores = 4 + setBackend(Map("spark.executor.cores" -> executorCores.toString)) + + val executorMemory = backend.executorMemory(sc) + val offers = List((executorMemory * 2, executorCores + 1)) + offerResources(offers) + + val taskInfos = verifyTaskLaunched("o1") + assert(taskInfos.size() == 1) + + val cpus = backend.getResource(taskInfos.iterator().next().getResourcesList, "cpus") + assert(cpus == executorCores) + } + + test("mesos supports unset spark.executor.cores") { + setBackend() + + val executorMemory = backend.executorMemory(sc) + val offerCores = 10 + offerResources(List((executorMemory * 2, offerCores))) + + val taskInfos = verifyTaskLaunched("o1") + assert(taskInfos.size() == 1) + + val cpus = backend.getResource(taskInfos.iterator().next().getResourcesList, "cpus") + assert(cpus == offerCores) + } + + test("mesos does not acquire more than spark.cores.max") { + val maxCores = 10 + setBackend(Map("spark.cores.max" -> maxCores.toString)) + + val executorMemory = backend.executorMemory(sc) + offerResources(List((executorMemory, maxCores + 1))) + + val taskInfos = verifyTaskLaunched("o1") + assert(taskInfos.size() == 1) + + val cpus = backend.getResource(taskInfos.iterator().next().getResourcesList, "cpus") + assert(cpus == maxCores) + } + + test("mesos declines offers that violate attribute constraints") { + setBackend(Map("spark.mesos.constraints" -> "x:true")) + offerResources(List((backend.executorMemory(sc), 4))) + verifyDeclinedOffer(driver, createOfferId("o1"), true) + } + + test("mesos assigns tasks round-robin on offers") { + val executorCores = 4 + val maxCores = executorCores * 2 + setBackend(Map("spark.executor.cores" -> executorCores.toString, + "spark.cores.max" -> maxCores.toString)) + + val executorMemory = backend.executorMemory(sc) + offerResources(List( + (executorMemory * 2, executorCores * 2), + (executorMemory * 2, executorCores * 2))) + + verifyTaskLaunched("o1") + verifyTaskLaunched("o2") + } + + test("mesos creates multiple executors on a single slave") { + val executorCores = 4 + setBackend(Map("spark.executor.cores" -> executorCores.toString)) + + // offer with room for two executors + val executorMemory = backend.executorMemory(sc) + offerResources(List((executorMemory * 2, executorCores * 2))) + + // verify two executors were started on a single offer + val taskInfos = verifyTaskLaunched("o1") + assert(taskInfos.size() == 2) + } + + test("mesos doesn't register twice with the same shuffle service") { + setBackend(Map("spark.shuffle.service.enabled" -> "true")) + val (mem, cpu) = (backend.executorMemory(sc), 4) + + val offer1 = createOffer("o1", "s1", mem, cpu) + backend.resourceOffers(driver, List(offer1).asJava) + verifyTaskLaunched("o1") + + val offer2 = createOffer("o2", "s1", mem, cpu) + backend.resourceOffers(driver, List(offer2).asJava) + verifyTaskLaunched("o2") + + val status1 = createTaskStatus("0", "s1", TaskState.TASK_RUNNING) + backend.statusUpdate(driver, status1) + + val status2 = createTaskStatus("1", "s1", TaskState.TASK_RUNNING) + backend.statusUpdate(driver, status2) + verify(externalShuffleClient, times(1)) + .registerDriverWithShuffleService(anyString, anyInt, anyLong, anyLong) + } + + test("mesos kills an executor when told") { + setBackend() + + val (mem, cpu) = (backend.executorMemory(sc), 4) + + val offer1 = createOffer("o1", "s1", mem, cpu) + backend.resourceOffers(driver, List(offer1).asJava) + verifyTaskLaunched("o1") + + 
backend.doKillExecutors(List("0")) + verify(driver, times(1)).killTask(createTaskId("0")) + } + + private def verifyDeclinedOffer(driver: SchedulerDriver, + offerId: OfferID, + filter: Boolean = false): Unit = { + if (filter) { + verify(driver, times(1)).declineOffer(Matchers.eq(offerId), anyObject[Filters]) + } else { + verify(driver, times(1)).declineOffer(Matchers.eq(offerId)) + } + } + + private def offerResources(offers: List[(Int, Int)], startId: Int = 1): Unit = { + val mesosOffers = offers.zipWithIndex.map {case (offer, i) => + createOffer(s"o${i + startId}", s"s${i + startId}", offer._1, offer._2)} + + backend.resourceOffers(driver, mesosOffers.asJava) + } + + private def verifyTaskLaunched(offerId: String): java.util.Collection[TaskInfo] = { + val captor = ArgumentCaptor.forClass(classOf[java.util.Collection[TaskInfo]]) + verify(driver, times(1)).launchTasks( + Matchers.eq(Collections.singleton(createOfferId(offerId))), + captor.capture()) + captor.getValue + } + + private def createTaskStatus(taskId: String, slaveId: String, state: TaskState): TaskStatus = { + TaskStatus.newBuilder() + .setTaskId(TaskID.newBuilder().setValue(taskId).build()) + .setSlaveId(SlaveID.newBuilder().setValue(slaveId).build()) + .setState(state) + .build + } + + + private def createOfferId(offerId: String): OfferID = { + OfferID.newBuilder().setValue(offerId).build() + } + + private def createSlaveId(slaveId: String): SlaveID = { + SlaveID.newBuilder().setValue(slaveId).build() + } + + private def createExecutorId(executorId: String): ExecutorID = { + ExecutorID.newBuilder().setValue(executorId).build() + } + + private def createTaskId(taskId: String): TaskID = { + TaskID.newBuilder().setValue(taskId).build() + } + private def createOffer(offerId: String, slaveId: String, mem: Int, cpu: Int): Offer = { val builder = Offer.newBuilder() builder.addResourcesBuilder() @@ -47,8 +269,7 @@ class CoarseMesosSchedulerBackendSuite extends SparkFunSuite .setName("cpus") .setType(Value.Type.SCALAR) .setScalar(Scalar.newBuilder().setValue(cpu)) - builder.setId(OfferID.newBuilder() - .setValue(offerId).build()) + builder.setId(createOfferId(offerId)) .setFrameworkId(FrameworkID.newBuilder() .setValue("f1")) .setSlaveId(SlaveID.newBuilder().setValue(slaveId)) @@ -58,130 +279,61 @@ class CoarseMesosSchedulerBackendSuite extends SparkFunSuite private def createSchedulerBackend( taskScheduler: TaskSchedulerImpl, - driver: SchedulerDriver): CoarseMesosSchedulerBackend = { + driver: SchedulerDriver, + shuffleClient: MesosExternalShuffleClient, + endpoint: RpcEndpointRef): CoarseMesosSchedulerBackend = { val securityManager = mock[SecurityManager] + val backend = new CoarseMesosSchedulerBackend(taskScheduler, sc, "master", securityManager) { override protected def createSchedulerDriver( - masterUrl: String, - scheduler: Scheduler, - sparkUser: String, - appName: String, - conf: SparkConf, - webuiUrl: Option[String] = None, - checkpoint: Option[Boolean] = None, - failoverTimeout: Option[Double] = None, - frameworkId: Option[String] = None): SchedulerDriver = driver + masterUrl: String, + scheduler: Scheduler, + sparkUser: String, + appName: String, + conf: SparkConf, + webuiUrl: Option[String] = None, + checkpoint: Option[Boolean] = None, + failoverTimeout: Option[Double] = None, + frameworkId: Option[String] = None): SchedulerDriver = driver + + override protected def getShuffleClient(): MesosExternalShuffleClient = shuffleClient + + protected def createDriverEndpointRef( + properties: ArrayBuffer[(String, String)]): 
RpcEndpointRef = endpoint + + // override to avoid race condition with the driver thread on `mesosDriver` + override def startScheduler(newDriver: SchedulerDriver): Unit = { + mesosDriver = newDriver + } + markRegistered() } backend.start() backend } - var sparkConf: SparkConf = _ - - before { + private def setBackend(sparkConfVars: Map[String, String] = null) { sparkConf = (new SparkConf) .setMaster("local[*]") .setAppName("test-mesos-dynamic-alloc") .setSparkHome("/path") + .set("spark.mesos.driver.webui.url", "http://webui") - sc = new SparkContext(sparkConf) - } - - test("mesos supports killing and limiting executors") { - val driver = mock[SchedulerDriver] - when(driver.start()).thenReturn(Protos.Status.DRIVER_RUNNING) - val taskScheduler = mock[TaskSchedulerImpl] - when(taskScheduler.sc).thenReturn(sc) - - sparkConf.set("spark.driver.host", "driverHost") - sparkConf.set("spark.driver.port", "1234") - - val backend = createSchedulerBackend(taskScheduler, driver) - val minMem = backend.calculateTotalMemory(sc) - val minCpu = 4 - - val mesosOffers = new java.util.ArrayList[Offer] - mesosOffers.add(createOffer("o1", "s1", minMem, minCpu)) - - val taskID0 = TaskID.newBuilder().setValue("0").build() - - backend.resourceOffers(driver, mesosOffers) - verify(driver, times(1)).launchTasks( - Matchers.eq(Collections.singleton(mesosOffers.get(0).getId)), - any[util.Collection[TaskInfo]], - any[Filters]) - - // simulate the allocation manager down-scaling executors - backend.doRequestTotalExecutors(0) - assert(backend.doKillExecutors(Seq("s1/0"))) - verify(driver, times(1)).killTask(taskID0) - - val mesosOffers2 = new java.util.ArrayList[Offer] - mesosOffers2.add(createOffer("o2", "s2", minMem, minCpu)) - backend.resourceOffers(driver, mesosOffers2) - - verify(driver, times(1)) - .declineOffer(OfferID.newBuilder().setValue("o2").build()) - - // Verify we didn't launch any new executor - assert(backend.slaveIdsWithExecutors.size === 1) - - backend.doRequestTotalExecutors(2) - backend.resourceOffers(driver, mesosOffers2) - verify(driver, times(1)).launchTasks( - Matchers.eq(Collections.singleton(mesosOffers2.get(0).getId)), - any[util.Collection[TaskInfo]], - any[Filters]) + if (sparkConfVars != null) { + for (attr <- sparkConfVars) { + sparkConf.set(attr._1, attr._2) + } + } - assert(backend.slaveIdsWithExecutors.size === 2) - backend.slaveLost(driver, SlaveID.newBuilder().setValue("s1").build()) - assert(backend.slaveIdsWithExecutors.size === 1) - } + sc = new SparkContext(sparkConf) - test("mesos supports killing and relaunching tasks with executors") { - val driver = mock[SchedulerDriver] + driver = mock[SchedulerDriver] when(driver.start()).thenReturn(Protos.Status.DRIVER_RUNNING) - val taskScheduler = mock[TaskSchedulerImpl] + taskScheduler = mock[TaskSchedulerImpl] when(taskScheduler.sc).thenReturn(sc) + externalShuffleClient = mock[MesosExternalShuffleClient] + driverEndpoint = mock[RpcEndpointRef] - val backend = createSchedulerBackend(taskScheduler, driver) - val minMem = backend.calculateTotalMemory(sc) + 1024 - val minCpu = 4 - - val mesosOffers = new java.util.ArrayList[Offer] - val offer1 = createOffer("o1", "s1", minMem, minCpu) - mesosOffers.add(offer1) - - val offer2 = createOffer("o2", "s1", minMem, 1); - - backend.resourceOffers(driver, mesosOffers) - - verify(driver, times(1)).launchTasks( - Matchers.eq(Collections.singleton(offer1.getId)), - anyObject(), - anyObject[Filters]) - - // Simulate task killed, executor no longer running - val status = TaskStatus.newBuilder() - 
.setTaskId(TaskID.newBuilder().setValue("0").build()) - .setSlaveId(SlaveID.newBuilder().setValue("s1").build()) - .setState(TaskState.TASK_KILLED) - .build - - backend.statusUpdate(driver, status) - assert(!backend.slaveIdsWithExecutors.contains("s1")) - - mesosOffers.clear() - mesosOffers.add(offer2) - backend.resourceOffers(driver, mesosOffers) - assert(backend.slaveIdsWithExecutors.contains("s1")) - - verify(driver, times(1)).launchTasks( - Matchers.eq(Collections.singleton(offer2.getId)), - anyObject(), - anyObject[Filters]) - - verify(driver, times(1)).reviveOffers() + backend = createSchedulerBackend(taskScheduler, driver, externalShuffleClient, driverEndpoint) } } diff --git a/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala b/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala index c00591fa371aa..4e66714ecbbb4 100644 --- a/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala +++ b/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala @@ -19,6 +19,7 @@ package org.apache.spark.storage import java.nio.{ByteBuffer, MappedByteBuffer} import java.util.Arrays +import java.util.concurrent.CountDownLatch import scala.collection.mutable.ArrayBuffer import scala.concurrent.duration._ @@ -424,6 +425,43 @@ class BlockManagerSuite extends SparkFunSuite with Matchers with BeforeAndAfterE } } + test("deadlock between dropFromMemory and removeBlock") { + store = makeBlockManager(2000) + val a1 = new Array[Byte](400) + store.putSingle("a1", a1, StorageLevel.MEMORY_ONLY) + val lock1 = new CountDownLatch(1) + val lock2 = new CountDownLatch(1) + + val t2 = new Thread { + override def run() = { + val info = store.getBlockInfo("a1") + info.synchronized { + store.pendingToRemove.put("a1", 1L) + lock1.countDown() + lock2.await() + store.pendingToRemove.remove("a1") + } + } + } + + val t1 = new Thread { + override def run() = { + store.memoryManager.synchronized { + t2.start() + lock1.await() + val status = store.dropFromMemory("a1", null: Either[Array[Any], ByteBuffer]) + assert(status == None, "this thread can not get block a1") + lock2.countDown() + } + } + } + + t1.start() + t1.join() + t2.join() + store.removeBlock("a1", tellMaster = false) + } + test("correct BlockResult returned from get() calls") { store = makeBlockManager(12000) val list1 = List(new Array[Byte](2000), new Array[Byte](2000)) diff --git a/docs/configuration.md b/docs/configuration.md index 64a1899d69769..195fa3c953224 100644 --- a/docs/configuration.md +++ b/docs/configuration.md @@ -284,7 +284,7 @@ Apart from these, the following properties are also available, and may be useful spark.executor.logs.rolling.maxSize (none) - Set the max size of the file by which the executor logs will be rolled over. + Set the max size of the file in bytes by which the executor logs will be rolled over. Rolling is disabled by default. See spark.executor.logs.rolling.maxRetainedFiles for automatic cleaning of old logs. @@ -296,7 +296,7 @@ Apart from these, the following properties are also available, and may be useful Set the strategy of rolling of executor logs. By default it is disabled. It can be set to "time" (time-based rolling) or "size" (size-based rolling). For "time", use spark.executor.logs.rolling.time.interval to set the rolling interval. - For "size", use spark.executor.logs.rolling.size.maxBytes to set + For "size", use spark.executor.logs.rolling.maxSize to set the maximum file size for rolling. 
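An editor's illustration of the corrected property names in the documentation hunk above (the
values are arbitrary; only the keys come from the documentation): size-based rolling is
selected with spark.executor.logs.rolling.strategy, and the threshold is given in bytes via
spark.executor.logs.rolling.maxSize.

import org.apache.spark.SparkConf

// Roll executor logs by size; maxSize is a byte count, per the corrected description.
val conf = new SparkConf()
  .set("spark.executor.logs.rolling.strategy", "size")
  .set("spark.executor.logs.rolling.maxSize", (128L * 1024 * 1024).toString) // 128 MiB
  .set("spark.executor.logs.rolling.maxRetainedFiles", "5")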
diff --git a/network/shuffle/src/main/java/org/apache/spark/network/shuffle/mesos/MesosExternalShuffleClient.java b/network/shuffle/src/main/java/org/apache/spark/network/shuffle/mesos/MesosExternalShuffleClient.java index 675820308bd4c..2add9c83a73d2 100644 --- a/network/shuffle/src/main/java/org/apache/spark/network/shuffle/mesos/MesosExternalShuffleClient.java +++ b/network/shuffle/src/main/java/org/apache/spark/network/shuffle/mesos/MesosExternalShuffleClient.java @@ -19,7 +19,12 @@ import java.io.IOException; import java.nio.ByteBuffer; +import java.util.concurrent.Executors; +import java.util.concurrent.ScheduledExecutorService; +import java.util.concurrent.TimeUnit; +import com.google.common.util.concurrent.ThreadFactoryBuilder; +import org.apache.spark.network.shuffle.protocol.mesos.ShuffleServiceHeartbeat; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -41,6 +46,13 @@ public class MesosExternalShuffleClient extends ExternalShuffleClient { private final Logger logger = LoggerFactory.getLogger(MesosExternalShuffleClient.class); + private final ScheduledExecutorService heartbeaterThread = + Executors.newSingleThreadScheduledExecutor( + new ThreadFactoryBuilder() + .setDaemon(true) + .setNameFormat("mesos-external-shuffle-client-heartbeater") + .build()); + /** * Creates an Mesos external shuffle client that wraps the {@link ExternalShuffleClient}. * Please refer to docs on {@link ExternalShuffleClient} for more information. @@ -53,21 +65,59 @@ public MesosExternalShuffleClient( super(conf, secretKeyHolder, saslEnabled, saslEncryptionEnabled); } - public void registerDriverWithShuffleService(String host, int port) throws IOException { + public void registerDriverWithShuffleService( + String host, + int port, + long heartbeatTimeoutMs, + long heartbeatIntervalMs) throws IOException { + checkInit(); - ByteBuffer registerDriver = new RegisterDriver(appId).toByteBuffer(); + ByteBuffer registerDriver = new RegisterDriver(appId, heartbeatTimeoutMs).toByteBuffer(); TransportClient client = clientFactory.createClient(host, port); - client.sendRpc(registerDriver, new RpcResponseCallback() { - @Override - public void onSuccess(ByteBuffer response) { - logger.info("Successfully registered app " + appId + " with external shuffle service."); - } - - @Override - public void onFailure(Throwable e) { - logger.warn("Unable to register app " + appId + " with external shuffle service. " + + client.sendRpc(registerDriver, new RegisterDriverCallback(client, heartbeatIntervalMs)); + } + + private class RegisterDriverCallback implements RpcResponseCallback { + private final TransportClient client; + private final long heartbeatIntervalMs; + + private RegisterDriverCallback(TransportClient client, long heartbeatIntervalMs) { + this.client = client; + this.heartbeatIntervalMs = heartbeatIntervalMs; + } + + @Override + public void onSuccess(ByteBuffer response) { + heartbeaterThread.scheduleAtFixedRate( + new Heartbeater(client), 0, heartbeatIntervalMs, TimeUnit.MILLISECONDS); + logger.info("Successfully registered app " + appId + " with external shuffle service."); + } + + @Override + public void onFailure(Throwable e) { + logger.warn("Unable to register app " + appId + " with external shuffle service. " + "Please manually remove shuffle data after driver exit. 
Error: " + e); - } - }); + } + } + + @Override + public void close() { + heartbeaterThread.shutdownNow(); + super.close(); + } + + private class Heartbeater implements Runnable { + + private final TransportClient client; + + private Heartbeater(TransportClient client) { + this.client = client; + } + + @Override + public void run() { + // TODO: Stop sending heartbeats if the shuffle service has lost the app due to timeout + client.send(new ShuffleServiceHeartbeat(appId).toByteBuffer()); + } } } diff --git a/network/shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/BlockTransferMessage.java b/network/shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/BlockTransferMessage.java index 7fbe3384b4d4f..21c0ff4136aa8 100644 --- a/network/shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/BlockTransferMessage.java +++ b/network/shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/BlockTransferMessage.java @@ -24,6 +24,7 @@ import org.apache.spark.network.protocol.Encodable; import org.apache.spark.network.shuffle.protocol.mesos.RegisterDriver; +import org.apache.spark.network.shuffle.protocol.mesos.ShuffleServiceHeartbeat; /** * Messages handled by the {@link org.apache.spark.network.shuffle.ExternalShuffleBlockHandler}, or @@ -40,7 +41,8 @@ public abstract class BlockTransferMessage implements Encodable { /** Preceding every serialized message is its type, which allows us to deserialize it. */ public static enum Type { - OPEN_BLOCKS(0), UPLOAD_BLOCK(1), REGISTER_EXECUTOR(2), STREAM_HANDLE(3), REGISTER_DRIVER(4); + OPEN_BLOCKS(0), UPLOAD_BLOCK(1), REGISTER_EXECUTOR(2), STREAM_HANDLE(3), REGISTER_DRIVER(4), + HEARTBEAT(5); private final byte id; @@ -64,6 +66,7 @@ public static BlockTransferMessage fromByteBuffer(ByteBuffer msg) { case 2: return RegisterExecutor.decode(buf); case 3: return StreamHandle.decode(buf); case 4: return RegisterDriver.decode(buf); + case 5: return ShuffleServiceHeartbeat.decode(buf); default: throw new IllegalArgumentException("Unknown message type: " + type); } } diff --git a/network/shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/mesos/RegisterDriver.java b/network/shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/mesos/RegisterDriver.java index 94a61d6caadc4..b0b6d73ee9cd6 100644 --- a/network/shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/mesos/RegisterDriver.java +++ b/network/shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/mesos/RegisterDriver.java @@ -31,33 +31,39 @@ */ public class RegisterDriver extends BlockTransferMessage { private final String appId; + private final long heartbeatTimeoutMs; - public RegisterDriver(String appId) { + public RegisterDriver(String appId, long heartbeatTimeoutMs) { this.appId = appId; + this.heartbeatTimeoutMs = heartbeatTimeoutMs; } public String getAppId() { return appId; } + public long getHeartbeatTimeoutMs() { return heartbeatTimeoutMs; } + @Override protected Type type() { return Type.REGISTER_DRIVER; } @Override public int encodedLength() { - return Encoders.Strings.encodedLength(appId); + return Encoders.Strings.encodedLength(appId) + Long.SIZE / Byte.SIZE; } @Override public void encode(ByteBuf buf) { Encoders.Strings.encode(buf, appId); + buf.writeLong(heartbeatTimeoutMs); } @Override public int hashCode() { - return Objects.hashCode(appId); + return Objects.hashCode(appId, heartbeatTimeoutMs); } public static RegisterDriver decode(ByteBuf buf) { String appId = Encoders.Strings.decode(buf); - return new 
RegisterDriver(appId); + long heartbeatTimeout = buf.readLong(); + return new RegisterDriver(appId, heartbeatTimeout); } } diff --git a/network/shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/mesos/ShuffleServiceHeartbeat.java b/network/shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/mesos/ShuffleServiceHeartbeat.java new file mode 100644 index 0000000000000..b30bb9aed55b6 --- /dev/null +++ b/network/shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/mesos/ShuffleServiceHeartbeat.java @@ -0,0 +1,53 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.network.shuffle.protocol.mesos; + +import io.netty.buffer.ByteBuf; +import org.apache.spark.network.protocol.Encoders; +import org.apache.spark.network.shuffle.protocol.BlockTransferMessage; + +// Needed by ScalaDoc. See SPARK-7726 +import static org.apache.spark.network.shuffle.protocol.BlockTransferMessage.Type; + +/** + * A heartbeat sent from the driver to the MesosExternalShuffleService. 
+ */ +public class ShuffleServiceHeartbeat extends BlockTransferMessage { + private final String appId; + + public ShuffleServiceHeartbeat(String appId) { + this.appId = appId; + } + + public String getAppId() { return appId; } + + @Override + protected Type type() { return Type.HEARTBEAT; } + + @Override + public int encodedLength() { return Encoders.Strings.encodedLength(appId); } + + @Override + public void encode(ByteBuf buf) { + Encoders.Strings.encode(buf, appId); + } + + public static ShuffleServiceHeartbeat decode(ByteBuf buf) { + return new ShuffleServiceHeartbeat(Encoders.Strings.decode(buf)); + } +} diff --git a/pom.xml b/pom.xml index 7df73a9f022fe..996a30a4e4e22 100644 --- a/pom.xml +++ b/pom.xml @@ -144,7 +144,7 @@ 1.7.0 1.6.0 1.2.4 - 8.1.14.v20131031 + 8.1.19.v20160209 3.0.0.v201112011016 0.5.0 2.4.0 diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index bc62c7fc6a920..04f62d78ea91a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -80,7 +80,6 @@ class Analyzer( ExtractWindowExpressions :: GlobalAggregates :: ResolveAggregateFunctions :: - DistinctAggregationRewriter(conf) :: HiveTypeCoercion.typeCoercionRules ++ extendedResolutionRules : _*), Batch("Nondeterministic", Once, diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/DistinctAggregationRewriter.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/DistinctAggregationRewriter.scala index 9c78f6d4cc71b..47d6d3640b1af 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/DistinctAggregationRewriter.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/DistinctAggregationRewriter.scala @@ -102,11 +102,8 @@ import org.apache.spark.sql.types.IntegerType */ case class DistinctAggregationRewriter(conf: CatalystConf) extends Rule[LogicalPlan] { - def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators { - case p if !p.resolved => p - // We need to wait until this Aggregate operator is resolved. + def apply(plan: LogicalPlan): LogicalPlan = plan transformUp { case a: Aggregate => rewrite(a) - case p => p } def rewrite(a: Aggregate): Aggregate = { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala index 6d807c9ecf302..09bf2a71a8b7b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala @@ -95,15 +95,17 @@ abstract class Expression extends TreeNode[Expression] { ctx.subExprEliminationExprs.get(this).map { subExprState => // This expression is repeated meaning the code to evaluated has already been added // as a function and called in advance. Just use it. - val code = s"/* ${this.toCommentSafeString} */" - GeneratedExpressionCode(code, subExprState.isNull, subExprState.value) + GeneratedExpressionCode( + ctx.registerComment(this.toString), + subExprState.isNull, + subExprState.value) }.getOrElse { val isNull = ctx.freshName("isNull") val primitive = ctx.freshName("primitive") val ve = GeneratedExpressionCode("", isNull, primitive) ve.code = genCode(ctx, ve) // Add `this` in the comment. 
- ve.copy(s"/* ${this.toCommentSafeString} */\n" + ve.code.trim) + ve.copy(code = s"${ctx.registerComment(this.toString)}\n" + ve.code.trim) } } @@ -215,14 +217,6 @@ abstract class Expression extends TreeNode[Expression] { override def simpleString: String = toString override def toString: String = prettyName + flatArguments.mkString("(", ",", ")") - - /** - * Returns the string representation of this expression that is safe to be put in - * code comments of generated code. - */ - protected def toCommentSafeString: String = this.toString - .replace("*/", "\\*\\/") - .replace("\\u", "\\\\u") } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeFormatter.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeFormatter.scala index 9b8b6382d753d..84a456684871f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeFormatter.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeFormatter.scala @@ -17,6 +17,8 @@ package org.apache.spark.sql.catalyst.expressions.codegen +import org.apache.commons.lang3.StringUtils + /** * An utility class that indents a block of code based on the curly braces and parentheses. * This is used to prettify generated code when in debug mode (or exceptions). @@ -24,7 +26,14 @@ package org.apache.spark.sql.catalyst.expressions.codegen * Written by Matei Zaharia. */ object CodeFormatter { - def format(code: String): String = new CodeFormatter().addLines(code).result() + def format(code: CodeAndComment): String = { + new CodeFormatter().addLines( + StringUtils.replaceEach( + code.body, + code.comment.keys.toArray, + code.comment.values.toArray) + ).result + } } private class CodeFormatter { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala index 440c7d2fc1156..0dec50e543fd4 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala @@ -125,6 +125,11 @@ class CodeGenContext { private val curId = new java.util.concurrent.atomic.AtomicInteger() + /** + * The map from a place holder to a corresponding comment + */ + private val placeHolderToComments = new mutable.HashMap[String, String] + /** * Returns a term name that is unique within this instance of a `CodeGenerator`. 
* @@ -458,6 +463,35 @@ class CodeGenContext { if (doSubexpressionElimination) subexpressionElimination(expressions) expressions.map(e => e.gen(this)) } + + /** + * get a map of the pair of a place holder and a corresponding comment + */ + def getPlaceHolderToComments(): collection.Map[String, String] = placeHolderToComments + + /** + * Register a multi-line comment and return the corresponding place holder + */ + private def registerMultilineComment(text: String): String = { + val placeHolder = s"/*${freshName("c")}*/" + val comment = text.split("(\r\n)|\r|\n").mkString("/**\n * ", "\n * ", "\n */") + placeHolderToComments += (placeHolder -> comment) + placeHolder + } + + /** + * Register a comment and return the corresponding place holder + */ + def registerComment(text: String): String = { + if (text.contains("\n") || text.contains("\r")) { + registerMultilineComment(text) + } else { + val placeHolder = s"/*${freshName("c")}*/" + val safeComment = s"// $text" + placeHolderToComments += (placeHolder -> safeComment) + placeHolder + } + } } /** @@ -468,6 +502,19 @@ abstract class GeneratedClass { def generate(expressions: Array[Expression]): Any } +/** + * A wrapper for the source code to be compiled by [[CodeGenerator]]. + */ +class CodeAndComment(val body: String, val comment: collection.Map[String, String]) + extends Serializable { + override def equals(that: Any): Boolean = that match { + case t: CodeAndComment if t.body == body => true + case _ => false + } + + override def hashCode(): Int = body.hashCode +} + /** * A base class for generators of byte code to perform expression evaluation. Includes a set of * helpers for referring to Catalyst types and building trees that perform evaluation of individual @@ -511,14 +558,14 @@ abstract class CodeGenerator[InType <: AnyRef, OutType <: AnyRef] extends Loggin /** * Compile the Java source code into a Java class, using Janino. */ - protected def compile(code: String): GeneratedClass = { + protected def compile(code: CodeAndComment): GeneratedClass = { cache.get(code) } /** * Compile the Java source code into a Java class, using Janino. */ - private[this] def doCompile(code: String): GeneratedClass = { + private[this] def doCompile(code: CodeAndComment): GeneratedClass = { val evaluator = new ClassBodyEvaluator() evaluator.setParentClassLoader(Utils.getContextOrSparkClassLoader) // Cannot be under package codegen, or fail with java.lang.InstantiationException @@ -538,7 +585,7 @@ abstract class CodeGenerator[InType <: AnyRef, OutType <: AnyRef] extends Loggin )) evaluator.setExtendedClass(classOf[GeneratedClass]) - def formatted = CodeFormatter.format(code) + lazy val formatted = CodeFormatter.format(code) logDebug({ // Only add extra debugging info to byte code when we are going to print the source code. 
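An editor's note on the placeholder scheme introduced above: because an expression's string
form may itself contain "*/" or unicode escape sequences, pasting it verbatim into a
/* ... */ block could terminate the comment early and alter the compiled source.
registerComment therefore emits only an opaque token such as /*c0*/ into the generated body,
and CodeFormatter.format substitutes the human-readable text back in purely for debug output.
A minimal sketch of that substitution, with hypothetical sample values:

import org.apache.commons.lang3.StringUtils

// Only the opaque token reaches the compiler; the comment-unsafe text appears in logs only.
val comments = Map("/*c0*/" -> "// literal: */ would otherwise end the surrounding comment")
val body     = "/*c0*/\nreturn value;"
val readable = StringUtils.replaceEach(body, comments.keys.toArray, comments.values.toArray)
// readable: "// literal: */ would otherwise end the surrounding comment\nreturn value;"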
@@ -547,7 +594,7 @@ abstract class CodeGenerator[InType <: AnyRef, OutType <: AnyRef] extends Loggin }) try { - evaluator.cook("generated.java", code) + evaluator.cook("generated.java", code.body) } catch { case e: Exception => val msg = s"failed to compile: $e\n$formatted" @@ -569,8 +616,8 @@ abstract class CodeGenerator[InType <: AnyRef, OutType <: AnyRef] extends Loggin private val cache = CacheBuilder.newBuilder() .maximumSize(100) .build( - new CacheLoader[String, GeneratedClass]() { - override def load(code: String): GeneratedClass = { + new CacheLoader[CodeAndComment, GeneratedClass]() { + override def load(code: CodeAndComment): GeneratedClass = { val startTime = System.nanoTime() val result = doCompile(code) val endTime = System.nanoTime() diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodegenFallback.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodegenFallback.scala index 26fb143d1e45c..e78ae7da8a909 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodegenFallback.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodegenFallback.scala @@ -32,8 +32,9 @@ trait CodegenFallback extends Expression { ctx.references += this val objectTerm = ctx.freshName("obj") + val placeHolder = ctx.registerComment(this.toString) s""" - /* expression: ${this.toCommentSafeString} */ + $placeHolder java.lang.Object $objectTerm = expressions[${ctx.references.size - 1}].eval(${ctx.INPUT_ROW}); boolean ${ev.isNull} = $objectTerm == null; ${ctx.javaType(this.dataType)} ${ev.value} = ${ctx.defaultValue(this.dataType)}; diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala index 40189f0877764..764fbf417b393 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala @@ -81,7 +81,7 @@ object GenerateMutableProjection extends CodeGenerator[Seq[Expression], () => Mu val allProjections = ctx.splitExpressions(ctx.INPUT_ROW, projectionCodes) val allUpdates = ctx.splitExpressions(ctx.INPUT_ROW, updates) - val code = s""" + val codeBody = s""" public java.lang.Object generate($exprType[] expr) { return new SpecificMutableProjection(expr); } @@ -119,6 +119,7 @@ object GenerateMutableProjection extends CodeGenerator[Seq[Expression], () => Mu } """ + val code = new CodeAndComment(codeBody, ctx.getPlaceHolderToComments()) logDebug(s"code for ${expressions.mkString(",")}:\n${CodeFormatter.format(code)}") val c = compile(code) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateOrdering.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateOrdering.scala index 1af7c73cd4bf5..1ecebbaeec8ad 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateOrdering.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateOrdering.scala @@ -110,7 +110,7 @@ object GenerateOrdering extends CodeGenerator[Seq[SortOrder], Ordering[InternalR protected def create(ordering: Seq[SortOrder]): BaseOrdering = { val ctx = newCodeGenContext() val comparisons = 
genComparisons(ctx, ordering) - val code = s""" + val codeBody = s""" public SpecificOrdering generate($exprType[] expr) { return new SpecificOrdering(expr); } @@ -133,6 +133,7 @@ object GenerateOrdering extends CodeGenerator[Seq[SortOrder], Ordering[InternalR } }""" + val code = new CodeAndComment(codeBody, ctx.getPlaceHolderToComments()) logDebug(s"Generated Ordering: ${CodeFormatter.format(code)}") compile(code).generate(ctx.references.toArray).asInstanceOf[BaseOrdering] diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GeneratePredicate.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GeneratePredicate.scala index 457b4f08424a6..639736749a361 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GeneratePredicate.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GeneratePredicate.scala @@ -40,7 +40,7 @@ object GeneratePredicate extends CodeGenerator[Expression, (InternalRow) => Bool protected def create(predicate: Expression): ((InternalRow) => Boolean) = { val ctx = newCodeGenContext() val eval = predicate.gen(ctx) - val code = s""" + val codeBody = s""" public SpecificPredicate generate($exprType[] expr) { return new SpecificPredicate(expr); } @@ -61,6 +61,7 @@ object GeneratePredicate extends CodeGenerator[Expression, (InternalRow) => Bool } }""" + val code = new CodeAndComment(codeBody, ctx.getPlaceHolderToComments()) logDebug(s"Generated predicate '$predicate':\n${CodeFormatter.format(code)}") val p = compile(code).generate(ctx.references.toArray).asInstanceOf[Predicate] diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateProjection.scala index f229f2000d8e1..f8a3b5489086a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateProjection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateProjection.scala @@ -152,7 +152,7 @@ object GenerateProjection extends CodeGenerator[Seq[Expression], Projection] { s"""if (!nullBits[$i]) arr[$i] = c$i;""" }.mkString("\n") - val code = s""" + val codeBody = s""" public SpecificProjection generate($exprType[] expr) { return new SpecificProjection(expr); } @@ -230,6 +230,7 @@ object GenerateProjection extends CodeGenerator[Seq[Expression], Projection] { } """ + val code = new CodeAndComment(codeBody, ctx.getPlaceHolderToComments()) logDebug(s"MutableRow, initExprs: ${expressions.mkString(",")} code:\n" + CodeFormatter.format(code)) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateSafeProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateSafeProjection.scala index b7926bda3de19..3ae1450922794 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateSafeProjection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateSafeProjection.scala @@ -147,7 +147,7 @@ object GenerateSafeProjection extends CodeGenerator[Seq[Expression], Projection] """ } val allExpressions = ctx.splitExpressions(ctx.INPUT_ROW, expressionCodes) - val code = s""" + val codeBody = s""" public java.lang.Object generate($exprType[] expr) { return new SpecificSafeProjection(expr); } @@ -173,6 +173,7 @@ object 
GenerateSafeProjection extends CodeGenerator[Seq[Expression], Projection] } """ + val code = new CodeAndComment(codeBody, ctx.getPlaceHolderToComments()) logDebug(s"code for ${expressions.mkString(",")}:\n${CodeFormatter.format(code)}") val c = compile(code) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala index 68005afb21d2e..6e0b5d198e500 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala @@ -323,7 +323,7 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro val ctx = newCodeGenContext() val eval = createCode(ctx, expressions, subexpressionEliminationEnabled) - val code = s""" + val codeBody = s""" public java.lang.Object generate($exprType[] exprs) { return new SpecificUnsafeProjection(exprs); } @@ -353,6 +353,7 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro } """ + val code = new CodeAndComment(codeBody, ctx.getPlaceHolderToComments()) logDebug(s"code for ${expressions.mkString(",")}:\n${CodeFormatter.format(code)}") val c = compile(code) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeRowJoiner.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeRowJoiner.scala index fb3c7b1bb4f72..b3ebc9cfac2e1 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeRowJoiner.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeRowJoiner.scala @@ -158,7 +158,7 @@ object GenerateUnsafeRowJoiner extends CodeGenerator[(StructType, StructType), U }.mkString("\n") // ------------------------ Finally, put everything together --------------------------- // - val code = s""" + val codeBody = s""" |public java.lang.Object generate($exprType[] exprs) { | return new SpecificUnsafeRowJoiner(); |} @@ -195,6 +195,7 @@ object GenerateUnsafeRowJoiner extends CodeGenerator[(StructType, StructType), U |} """.stripMargin + val code = new CodeAndComment(codeBody, Map.empty) logDebug(s"SpecificUnsafeRowJoiner($schema1, $schema2):\n${CodeFormatter.format(code)}") val c = compile(code) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala index 682b860672b2d..676e0b7aceb67 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala @@ -18,7 +18,8 @@ package org.apache.spark.sql.catalyst.optimizer import scala.collection.immutable.HashSet -import org.apache.spark.sql.catalyst.analysis.{CleanupAliases, EliminateSubQueries} +import org.apache.spark.sql.catalyst.{CatalystConf, SimpleCatalystConf} +import org.apache.spark.sql.catalyst.analysis.{CleanupAliases, DistinctAggregationRewriter, EliminateSubQueries} import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.aggregate._ import org.apache.spark.sql.catalyst.plans.Inner @@ -30,14 +31,13 @@ import org.apache.spark.sql.catalyst.plans.logical._ import 
org.apache.spark.sql.catalyst.rules._ import org.apache.spark.sql.types._ -abstract class Optimizer extends RuleExecutor[LogicalPlan] - -object DefaultOptimizer extends Optimizer { +abstract class Optimizer(conf: CatalystConf) extends RuleExecutor[LogicalPlan] { val batches = // SubQueries are only needed for analysis and can be removed before execution. Batch("Remove SubQueries", FixedPoint(100), EliminateSubQueries) :: Batch("Aggregate", FixedPoint(100), + DistinctAggregationRewriter(conf), ReplaceDistinctWithAggregate, RemoveLiteralFromGroupExpressions) :: Batch("Operator Optimizations", FixedPoint(100), @@ -68,6 +68,18 @@ object DefaultOptimizer extends Optimizer { Batch("LocalRelation", FixedPoint(100), ConvertToLocalRelation) :: Nil } +case class DefaultOptimizer(conf: CatalystConf) extends Optimizer(conf) + +/** + * An optimizer used in test code. + * + * To ensure extendability, we leave the standard rules in the abstract optimizer rules, while + * specific rules go to the subclasses + */ +object SimpleTestOptimizer extends SimpleTestOptimizer + +class SimpleTestOptimizer extends Optimizer( + new SimpleCatalystConf(caseSensitiveAnalysis = true)) /** * Pushes operations down into a Sample. diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala index 66a7e228f84bb..e35a1b2d7c9a4 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala @@ -138,4 +138,47 @@ class CodeGenerationSuite extends SparkFunSuite with ExpressionEvalHelper { true, InternalRow(UTF8String.fromString("\\u"))) } + + test("check compilation error doesn't occur caused by specific literal") { + // The end of comment (*/) should be escaped. + GenerateUnsafeProjection.generate( + Literal.create("*/Compilation error occurs/*", StringType) :: Nil) + + // `\u002A` is `*` and `\u002F` is `/` + // so if the end of comment consists of those characters in queries, we need to escape them. + GenerateUnsafeProjection.generate( + Literal.create("\\u002A/Compilation error occurs/*", StringType) :: Nil) + GenerateUnsafeProjection.generate( + Literal.create("\\\\u002A/Compilation error occurs/*", StringType) :: Nil) + GenerateUnsafeProjection.generate( + Literal.create("\\u002a/Compilation error occurs/*", StringType) :: Nil) + GenerateUnsafeProjection.generate( + Literal.create("\\\\u002a/Compilation error occurs/*", StringType) :: Nil) + GenerateUnsafeProjection.generate( + Literal.create("*\\u002FCompilation error occurs/*", StringType) :: Nil) + GenerateUnsafeProjection.generate( + Literal.create("*\\\\u002FCompilation error occurs/*", StringType) :: Nil) + GenerateUnsafeProjection.generate( + Literal.create("*\\002fCompilation error occurs/*", StringType) :: Nil) + GenerateUnsafeProjection.generate( + Literal.create("*\\\\002fCompilation error occurs/*", StringType) :: Nil) + GenerateUnsafeProjection.generate( + Literal.create("\\002A\\002FCompilation error occurs/*", StringType) :: Nil) + GenerateUnsafeProjection.generate( + Literal.create("\\\\002A\\002FCompilation error occurs/*", StringType) :: Nil) + GenerateUnsafeProjection.generate( + Literal.create("\\002A\\\\002FCompilation error occurs/*", StringType) :: Nil) + + // \ u002X is an invalid unicode literal so it should be escaped. 
+ GenerateUnsafeProjection.generate( + Literal.create("\\u002X/Compilation error occurs", StringType) :: Nil) + GenerateUnsafeProjection.generate( + Literal.create("\\\\u002X/Compilation error occurs", StringType) :: Nil) + + // \ u001 is an invalid unicode literal so it should be escaped. + GenerateUnsafeProjection.generate( + Literal.create("\\u001/Compilation error occurs", StringType) :: Nil) + GenerateUnsafeProjection.generate( + Literal.create("\\\\u001/Compilation error occurs", StringType) :: Nil) + } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala index 465f7d08aa142..074785eb467d2 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala @@ -24,7 +24,7 @@ import org.scalatest.prop.GeneratorDrivenPropertyChecks import org.apache.spark.SparkFunSuite import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow} import org.apache.spark.sql.catalyst.expressions.codegen._ -import org.apache.spark.sql.catalyst.optimizer.DefaultOptimizer +import org.apache.spark.sql.catalyst.optimizer.SimpleTestOptimizer import org.apache.spark.sql.catalyst.plans.logical.{OneRowRelation, Project} import org.apache.spark.sql.types.DataType @@ -189,7 +189,7 @@ trait ExpressionEvalHelper extends GeneratorDrivenPropertyChecks { expected: Any, inputRow: InternalRow = EmptyRow): Unit = { val plan = Project(Alias(expression, s"Optimized($expression)")() :: Nil, OneRowRelation) - val optimizedPlan = DefaultOptimizer.execute(plan) + val optimizedPlan = SimpleTestOptimizer.execute(plan) checkEvaluationWithoutCodegen(optimizedPlan.expressions.head, expected, inputRow) } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathFunctionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathFunctionsSuite.scala index aacc56fc44186..90f90965697a9 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathFunctionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathFunctionsSuite.scala @@ -23,7 +23,7 @@ import org.apache.spark.SparkFunSuite import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.expressions.codegen.{GenerateProjection, GenerateMutableProjection} import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.optimizer.DefaultOptimizer +import org.apache.spark.sql.catalyst.optimizer.SimpleTestOptimizer import org.apache.spark.sql.catalyst.plans.logical.{OneRowRelation, Project} import org.apache.spark.sql.types._ @@ -150,7 +150,7 @@ class MathFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper { expression: Expression, inputRow: InternalRow = EmptyRow): Unit = { val plan = Project(Alias(expression, s"Optimized($expression)")() :: Nil, OneRowRelation) - val optimizedPlan = DefaultOptimizer.execute(plan) + val optimizedPlan = SimpleTestOptimizer.execute(plan) checkNaNWithoutCodegen(optimizedPlan.expressions.head, inputRow) } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeFormatterSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeFormatterSuite.scala index 9da1068e9ca1d..55e4f7561e7b5 100644 
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeFormatterSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeFormatterSuite.scala @@ -24,7 +24,8 @@ class CodeFormatterSuite extends SparkFunSuite { def testCase(name: String)(input: String)(expected: String): Unit = { test(name) { - assert(CodeFormatter.format(input).trim === expected.trim) + val sourceCode = new CodeAndComment(input, Map.empty) + assert(CodeFormatter.format(sourceCode).trim === expected.trim) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala index 47fd7fc1178a9..8a07cee909bc6 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala @@ -202,7 +202,7 @@ class SQLContext private[sql]( } @transient - protected[sql] lazy val optimizer: Optimizer = DefaultOptimizer + protected[sql] lazy val optimizer: Optimizer = DefaultOptimizer(conf) @transient protected[sql] val ddlParser = new DDLParser(sqlParser.parse(_)) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/GenerateColumnAccessor.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/GenerateColumnAccessor.scala index 4d01b78c3c10f..4f6f58bde5c54 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/GenerateColumnAccessor.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/GenerateColumnAccessor.scala @@ -22,7 +22,7 @@ import scala.collection.mutable import org.apache.spark.Logging import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.expressions.codegen.{UnsafeRowWriter, CodeFormatter, CodeGenerator} +import org.apache.spark.sql.catalyst.expressions.codegen.{CodeAndComment, CodeFormatter, CodeGenerator, UnsafeRowWriter} import org.apache.spark.sql.types._ /** @@ -152,7 +152,7 @@ object GenerateColumnAccessor extends CodeGenerator[Seq[DataType], ColumnarItera (0 to groupedAccessorsLength - 1).map { i => s"extractors$i();" }.mkString("\n")) } - val code = s""" + val codeBody = s""" import java.nio.ByteBuffer; import java.nio.ByteOrder; import scala.collection.Iterator; @@ -226,6 +226,7 @@ object GenerateColumnAccessor extends CodeGenerator[Seq[DataType], ColumnarItera } }""" + val code = new CodeAndComment(codeBody, ctx.getPlaceHolderToComments()) logDebug(s"Generated ColumnarIterator: ${CodeFormatter.format(code)}") compile(code).generate(ctx.references.toArray).asInstanceOf[ColumnarIterator] diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala index f630eb5a1292b..2be834359089d 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala @@ -2037,4 +2037,269 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext { checkAnswer(sql("SELECT value['cba'] FROM maptest where key = 1"), Row(null)) } } + + test("check code injection is prevented") { + // The end of comment (*/) should be escaped. 
+ var literal = + """|*/ + |{ + | new Object() { + | void f() { throw new RuntimeException("This exception is injected."); } + | }.f(); + |} + |/*""".stripMargin.replaceAll("\n", "") + var expected = + """|*/ + |{ + | new Object() { + | void f() { throw new RuntimeException("This exception is injected."); } + | }.f(); + |} + |/*""".stripMargin.replaceAll("\n", "") + checkAnswer( + sql(s"SELECT '$literal' AS DUMMY"), + Row(s"$expected") :: Nil) + + // `\u002A` is `*` and `\u002F` is `/` + // so if the end of comment consists of those characters in queries, we need to escape them. + literal = + """|\\u002A/ + |{ + | new Object() { + | void f() { throw new RuntimeException("This exception is injected."); } + | }.f(); + |} + |/*""".stripMargin.replaceAll("\n", "") + expected = + """|\\u002A/ + |{ + | new Object() { + | void f() { throw new RuntimeException("This exception is injected."); } + | }.f(); + |} + |/*""".stripMargin.replaceAll("\n", "") + checkAnswer( + sql(s"SELECT '$literal' AS DUMMY"), + Row(s"$expected") :: Nil) + + literal = + """|\\\\u002A/ + |{ + | new Object() { + | void f() { throw new RuntimeException("This exception is injected."); } + | }.f(); + |} + |/*""".stripMargin.replaceAll("\n", "") + expected = + """|\\\\u002A/ + |{ + | new Object() { + | void f() { throw new RuntimeException("This exception is injected."); } + | }.f(); + |} + |/*""".stripMargin.replaceAll("\n", "") + checkAnswer( + sql(s"SELECT '$literal' AS DUMMY"), + Row(s"$expected") :: Nil) + + literal = + """|\\u002a/ + |{ + | new Object() { + | void f() { throw new RuntimeException("This exception is injected."); } + | }.f(); + |} + |/*""".stripMargin.replaceAll("\n", "") + expected = + """|\\u002a/ + |{ + | new Object() { + | void f() { throw new RuntimeException("This exception is injected."); } + | }.f(); + |} + |/*""".stripMargin.replaceAll("\n", "") + checkAnswer( + sql(s"SELECT '$literal' AS DUMMY"), + Row(s"$expected") :: Nil) + + literal = + """|\\\\u002a/ + |{ + | new Object() { + | void f() { throw new RuntimeException("This exception is injected."); } + | }.f(); + |} + |/*""".stripMargin.replaceAll("\n", "") + expected = + """|\\\\u002a/ + |{ + | new Object() { + | void f() { throw new RuntimeException("This exception is injected."); } + | }.f(); + |} + |/*""".stripMargin.replaceAll("\n", "") + checkAnswer( + sql(s"SELECT '$literal' AS DUMMY"), + Row(s"$expected") :: Nil) + + literal = + """|*\\u002F + |{ + | new Object() { + | void f() { throw new RuntimeException("This exception is injected."); } + | }.f(); + |} + |/*""".stripMargin.replaceAll("\n", "") + expected = + """|*\\u002F + |{ + | new Object() { + | void f() { throw new RuntimeException("This exception is injected."); } + | }.f(); + |} + |/*""".stripMargin.replaceAll("\n", "") + checkAnswer( + sql(s"SELECT '$literal' AS DUMMY"), + Row(s"$expected") :: Nil) + + literal = + """|*\\\\u002F + |{ + | new Object() { + | void f() { throw new RuntimeException("This exception is injected."); } + | }.f(); + |} + |/*""".stripMargin.replaceAll("\n", "") + expected = + """|*\\\\u002F + |{ + | new Object() { + | void f() { throw new RuntimeException("This exception is injected."); } + | }.f(); + |} + |/*""".stripMargin.replaceAll("\n", "") + checkAnswer( + sql(s"SELECT '$literal' AS DUMMY"), + Row(s"$expected") :: Nil) + + literal = + """|*\\u002f + |{ + | new Object() { + | void f() { throw new RuntimeException("This exception is injected."); } + | }.f(); + |} + |/*""".stripMargin.replaceAll("\n", "") + expected = + """|*\\u002f + |{ + | new 
Object() { + | void f() { throw new RuntimeException("This exception is injected."); } + | }.f(); + |} + |/*""".stripMargin.replaceAll("\n", "") + checkAnswer( + sql(s"SELECT '$literal' AS DUMMY"), + Row(s"$expected") :: Nil) + + literal = + """|*\\\\u002f + |{ + | new Object() { + | void f() { throw new RuntimeException("This exception is injected."); } + | }.f(); + |} + |/*""".stripMargin.replaceAll("\n", "") + expected = + """|*\\\\u002f + |{ + | new Object() { + | void f() { throw new RuntimeException("This exception is injected."); } + | }.f(); + |} + |/*""".stripMargin.replaceAll("\n", "") + checkAnswer( + sql(s"SELECT '$literal' AS DUMMY"), + Row(s"$expected") :: Nil) + + literal = + """|\\u002A\\u002F + |{ + | new Object() { + | void f() { throw new RuntimeException("This exception is injected."); } + | }.f(); + |} + |/*""".stripMargin.replaceAll("\n", "") + expected = + """|\\u002A\\u002F + |{ + | new Object() { + | void f() { throw new RuntimeException("This exception is injected."); } + | }.f(); + |} + |/*""".stripMargin.replaceAll("\n", "") + checkAnswer( + sql(s"SELECT '$literal' AS DUMMY"), + Row(s"$expected") :: Nil) + + literal = + """|\\\\u002A\\u002F + |{ + | new Object() { + | void f() { throw new RuntimeException("This exception is injected."); } + | }.f(); + |} + |/*""".stripMargin.replaceAll("\n", "") + expected = + """|\\\\u002A\\u002F + |{ + | new Object() { + | void f() { throw new RuntimeException("This exception is injected."); } + | }.f(); + |} + |/*""".stripMargin.replaceAll("\n", "") + checkAnswer( + sql(s"SELECT '$literal' AS DUMMY"), + Row(s"$expected") :: Nil) + + literal = + """|\\u002A\\\\u002F + |{ + | new Object() { + | void f() { throw new RuntimeException("This exception is injected."); } + | }.f(); + |} + |/*""".stripMargin.replaceAll("\n", "") + expected = + """|\\u002A\\\\u002F + |{ + | new Object() { + | void f() { throw new RuntimeException("This exception is injected."); } + | }.f(); + |} + |/*""".stripMargin.replaceAll("\n", "") + checkAnswer( + sql(s"SELECT '$literal' AS DUMMY"), + Row(s"$expected") :: Nil) + + literal = + """|\\\\u002A\\\\u002F + |{ + | new Object() { + | void f() { throw new RuntimeException("This exception is injected."); } + | }.f(); + |} + |/*""".stripMargin.replaceAll("\n", "") + expected = + """|\\\\u002A\\\\u002F + |{ + | new Object() { + | void f() { throw new RuntimeException("This exception is injected."); } + | }.f(); + |} + |/*""".stripMargin.replaceAll("\n", "") + checkAnswer( + sql(s"SELECT '$literal' AS DUMMY"), + Row(s"$expected") :: Nil) + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala index 2fb439f50117a..7d7c39c9d78b1 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala @@ -44,10 +44,12 @@ class PlannerSuite extends SharedSQLContext { fail(s"Could query play aggregation query $query. Is it an aggregation query?")) val aggregations = planned.collect { case n if n.nodeName contains "Aggregate" => n } - // For the new aggregation code path, there will be four aggregate operator for - // distinct aggregations. + // For the new aggregation code path, there will be three aggregate operator for + // distinct aggregations. There used to be four aggregate operators because single + // distinct aggregate used to trigger DistinctAggregationRewriter rewrite. 
Now the + // the rewrite only happens when there are multiple distinct aggregations. assert( - aggregations.size == 2 || aggregations.size == 4, + aggregations.size == 2 || aggregations.size == 3, s"The plan of query $query does not have partial aggregations.") } diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/ClientWrapper.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/ClientWrapper.scala index e9b2e20236da1..075b4cb0e23eb 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/ClientWrapper.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/ClientWrapper.scala @@ -496,12 +496,14 @@ private[hive] class ClientWrapper( // Throw an exception if there is an error in query processing. if (response.getResponseCode != 0) { driver.close() + CommandProcessorFactory.clean(conf) throw new QueryExecutionException(response.getErrorMessage) } driver.setMaxRows(maxRows) val results = shim.getDriverResults(driver) driver.close() + CommandProcessorFactory.clean(conf) results case _ => diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala index 64bff827aead9..d21227a00fb76 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala @@ -930,6 +930,22 @@ abstract class AggregationQuerySuite extends QueryTest with SQLTestUtils with Te Row(11) :: Nil) } } + + test("SPARK-14495: distinct aggregate in having clause") { + checkAnswer( + sqlContext.sql( + """ + |select key, count(distinct value1), count(distinct value2) + |from agg2 group by key + |having count(distinct value1) > 0 + """.stripMargin), + Seq( + Row(null, 3, 3), + Row(1, 2, 3), + Row(2, 2, 1) + ) + ) + } }