12 changes: 12 additions & 0 deletions core/src/main/scala/org/apache/spark/Dependency.scala
@@ -102,6 +102,12 @@ class ShuffleDependency[K: ClassTag, V: ClassTag, C: ClassTag](
*/
private[spark] var mergerLocs: Seq[BlockManagerId] = Nil

/**
* Stores whether the shuffle merge has been finalized for the shuffle map stage
* associated with this shuffle dependency.
*/
private[this] var shuffleMergedFinalized: Boolean = false

def setMergerLocs(mergerLocs: Seq[BlockManagerId]): Unit = {
if (mergerLocs != null) {
this.mergerLocs = mergerLocs
@@ -110,6 +116,12 @@ class ShuffleDependency[K: ClassTag, V: ClassTag, C: ClassTag](

def getMergerLocs: Seq[BlockManagerId] = mergerLocs

def markShuffleMergeFinalized(): Unit = {
shuffleMergedFinalized = true
}

def shuffleMergeFinalized: Boolean = shuffleMergedFinalized

_rdd.sparkContext.cleaner.foreach(_.registerShuffleForCleanup(this))
_rdd.sparkContext.shuffleDriverComponents.registerShuffle(shuffleId)
}
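As a reading aid, here is a minimal standalone sketch of the flag protocol introduced above, using a hypothetical stand-in class rather than Spark's ShuffleDependency: the flag starts out false, is flipped once when shuffle merge finalization completes, and is read when deciding whether downstream stages may be submitted.

```scala
// Hypothetical illustration only; in the PR the flag lives on ShuffleDependency.
class MergeFinalizationFlag {
  private[this] var shuffleMergedFinalized: Boolean = false

  // Mirrors ShuffleDependency.markShuffleMergeFinalized
  def markShuffleMergeFinalized(): Unit = {
    shuffleMergedFinalized = true
  }

  // Mirrors ShuffleDependency.shuffleMergeFinalized
  def shuffleMergeFinalized: Boolean = shuffleMergedFinalized
}

val flag = new MergeFinalizationFlag
assert(!flag.shuffleMergeFinalized) // before finalization
flag.markShuffleMergeFinalized()    // set once finalization completes
assert(flag.shuffleMergeFinalized)  // child stages may now be submitted
```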
@@ -2075,6 +2075,25 @@ package object config {
.booleanConf
.createWithDefault(false)

private[spark] val PUSH_BASED_SHUFFLE_MERGE_RESULTS_TIMEOUT =
ConfigBuilder("spark.shuffle.push.merge.results.timeout")
.doc("Specify the max amount of time DAGScheduler waits for the merge results from " +
"all remote shuffle services for a given shuffle. DAGScheduler will start to submit " +
"following stages if not all results are received within the timeout.")
.version("3.1.0")
.stringConf
.createWithDefault("10s")

private[spark] val PUSH_BASED_SHUFFLE_MERGE_FINALIZE_TIMEOUT =
ConfigBuilder("spark.shuffle.push.merge.finalize.timeout")
.doc("Specify the amount of time DAGScheduler waits after all mappers finish for " +
"a given shuffle map stage before it starts sending merge finalize requests to " +
"remote shuffle services. This allows the shuffle services some extra time to " +
"merge as many blocks as possible.")
.version("3.1.0")
.stringConf
.createWithDefault("10s")
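For context, a hedged sketch of how an application could adjust the two timeouts defined above through SparkConf; the keys are taken from the config definitions, while the values are arbitrary placeholders rather than recommended settings.

```scala
import org.apache.spark.SparkConf

// Illustrative only: lengthen both timeouts for a job with very large shuffles.
val conf = new SparkConf()
  .set("spark.shuffle.push.merge.results.timeout", "30s")
  .set("spark.shuffle.push.merge.finalize.timeout", "20s")
```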

private[spark] val SHUFFLE_MERGER_MAX_RETAINED_LOCATIONS =
ConfigBuilder("spark.shuffle.push.maxRetainedMergerLocations")
.doc("Maximum number of shuffle push merger locations cached for push based shuffle. " +
@@ -2108,7 +2127,7 @@ package object config {
s"${SHUFFLE_MERGER_LOCATIONS_MIN_THRESHOLD_RATIO.key} set to 0.05, we would need " +
"at least 50 mergers to enable push based shuffle for that stage.")
.version("3.1.0")
.doubleConf
.intConf
.createWithDefault(5)

private[spark] val SHUFFLE_NUM_PUSH_THREADS =
229 changes: 185 additions & 44 deletions core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
@@ -19,8 +19,8 @@ package org.apache.spark.scheduler

import java.io.NotSerializableException
import java.util.Properties
import java.util.concurrent.{ConcurrentHashMap, TimeUnit}
import java.util.concurrent.atomic.AtomicInteger
import java.util.concurrent.{ConcurrentHashMap, TimeoutException, TimeUnit}
import java.util.concurrent.atomic.{AtomicBoolean, AtomicInteger}

import scala.annotation.tailrec
import scala.collection.Map
@@ -29,12 +29,17 @@ import scala.collection.mutable.{HashMap, HashSet, ListBuffer}
import scala.concurrent.duration._
import scala.util.control.NonFatal

import com.google.common.util.concurrent.{Futures, SettableFuture}

import org.apache.spark._
import org.apache.spark.broadcast.Broadcast
import org.apache.spark.executor.{ExecutorMetrics, TaskMetrics}
import org.apache.spark.internal.Logging
import org.apache.spark.internal.config
import org.apache.spark.internal.config.Tests.TEST_NO_STAGE_RETRY
import org.apache.spark.network.netty.SparkTransportConf
import org.apache.spark.network.shuffle.{ExternalBlockStoreClient, MergeFinalizerListener}
import org.apache.spark.network.shuffle.protocol.MergeStatuses
import org.apache.spark.network.util.JavaUtils
import org.apache.spark.partial.{ApproximateActionListener, ApproximateEvaluator, PartialResult}
import org.apache.spark.rdd.{RDD, RDDCheckpointData}
@@ -254,6 +259,28 @@ private[spark] class DAGScheduler(
private val blockManagerMasterDriverHeartbeatTimeout =
sc.getConf.get(config.STORAGE_BLOCKMANAGER_MASTER_DRIVER_HEARTBEAT_TIMEOUT).millis

private val shuffleMergeResultsTimeoutSec =
JavaUtils.timeStringAsSec(sc.getConf.get(config.PUSH_BASED_SHUFFLE_MERGE_RESULTS_TIMEOUT))

private val shuffleMergeFinalizeWaitSec =
JavaUtils.timeStringAsSec(sc.getConf.get(config.PUSH_BASED_SHUFFLE_MERGE_FINALIZE_TIMEOUT))

// Lazily initialized so that the shuffle client can be properly initialized
private lazy val externalShuffleClient: Option[ExternalBlockStoreClient] =
if (pushBasedShuffleEnabled) {
val transConf = SparkTransportConf.fromSparkConf(sc.conf, "shuffle", 1)
val shuffleClient = new ExternalBlockStoreClient(transConf, env.securityManager,
env.securityManager.isAuthenticationEnabled(),
sc.conf.get(config.SHUFFLE_REGISTRATION_TIMEOUT))
shuffleClient.init(sc.conf.getAppId)
Some(shuffleClient)
} else {
None
}

private val shuffleMergeFinalizeScheduler =
ThreadUtils.newDaemonSingleThreadScheduledExecutor("shuffle-merge-finalizer")

/**
* Called by the TaskSetManager to report task's starting.
*/
@@ -689,7 +716,7 @@ private[spark] class DAGScheduler(
dep match {
case shufDep: ShuffleDependency[_, _, _] =>
val mapStage = getOrCreateShuffleMapStage(shufDep, stage.firstJobId)
if (!mapStage.isAvailable) {
if (!mapStage.isAvailable || !mapStage.isMergeFinalized) {
missing += mapStage
}
case narrowDep: NarrowDependency[_] =>
@@ -1271,21 +1298,19 @@ private[spark] class DAGScheduler(
* locations for block push/merge by getting the historical locations of past executors.
*/
private def prepareShuffleServicesForShuffleMapStage(stage: ShuffleMapStage): Unit = {
// TODO(SPARK-32920) Handle stage reuse/retry cases separately as without finalize
// TODO changes we cannot disable shuffle merge for the retry/reuse cases
val mergerLocs = sc.schedulerBackend.getShufflePushMergerLocations(
stage.shuffleDep.partitioner.numPartitions, stage.resourceProfileId)

if (mergerLocs.nonEmpty) {
stage.shuffleDep.setMergerLocs(mergerLocs)
logInfo(s"Push-based shuffle enabled for $stage (${stage.name}) with" +
s" ${stage.shuffleDep.getMergerLocs.size} merger locations")

logDebug("List of shuffle push merger locations " +
s"${stage.shuffleDep.getMergerLocs.map(_.host).mkString(", ")}")
} else {
logInfo("No available merger locations." +
s" Push-based shuffle disabled for $stage (${stage.name})")
if (!stage.shuffleDep.shuffleMergeFinalized) {
val mergerLocs = sc.schedulerBackend.getShufflePushMergerLocations(
stage.shuffleDep.partitioner.numPartitions, stage.resourceProfileId)
if (mergerLocs.nonEmpty) {
stage.shuffleDep.setMergerLocs(mergerLocs)
logInfo(s"Push-based shuffle enabled for $stage (${stage.name}) with" +
s" ${stage.shuffleDep.getMergerLocs.size} merger locations")

logDebug("List of shuffle push merger locations " +
s"${stage.shuffleDep.getMergerLocs.map(_.host).mkString(", ")}")
} else {
logInfo("Push-based shuffle disabled for $stage (${stage.name})")
}
}
}

@@ -1678,33 +1703,10 @@ private[spark] class DAGScheduler(
}

if (runningStages.contains(shuffleStage) && shuffleStage.pendingPartitions.isEmpty) {
markStageAsFinished(shuffleStage)
logInfo("looking for newly runnable stages")
logInfo("running: " + runningStages)
logInfo("waiting: " + waitingStages)
logInfo("failed: " + failedStages)

// This call to increment the epoch may not be strictly necessary, but it is retained
// for now in order to minimize the changes in behavior from an earlier version of the
// code. This existing behavior of always incrementing the epoch following any
// successful shuffle map stage completion may have benefits by causing unneeded
// cached map outputs to be cleaned up earlier on executors. In the future we can
// consider removing this call, but this will require some extra investigation.
// See https://github.com/apache/spark/pull/17955/files#r117385673 for more details.
mapOutputTracker.incrementEpoch()

clearCacheLocs()

if (!shuffleStage.isAvailable) {
// Some tasks had failed; let's resubmit this shuffleStage.
// TODO: Lower-level scheduler should also deal with this
logInfo("Resubmitting " + shuffleStage + " (" + shuffleStage.name +
") because some of its tasks had failed: " +
shuffleStage.findMissingPartitions().mkString(", "))
submitStage(shuffleStage)
if (pushBasedShuffleEnabled) {
scheduleShuffleMergeFinalize(shuffleStage)
} else {
markMapStageJobsAsFinished(shuffleStage)
submitWaitingChildStages(shuffleStage)
processShuffleMapStageCompletion(shuffleStage)
}
}
}
@@ -2004,6 +2006,142 @@ private[spark] class DAGScheduler(
}
}

/**
* Schedules shuffle merge finalization for the given shuffle map stage.
*/
private[scheduler] def scheduleShuffleMergeFinalize(stage: ShuffleMapStage): Unit = {
// TODO Use the default single threaded scheduler or extend ThreadUtils to
// TODO support the multi-threaded scheduler?
logInfo(("%s (%s) scheduled for finalizing" +
" shuffle merge in %s s").format(stage, stage.name, shuffleMergeFinalizeWaitSec))
shuffleMergeFinalizeScheduler.schedule(
new Runnable {
override def run(): Unit = finalizeShuffleMerge(stage)
},
shuffleMergeFinalizeWaitSec,
TimeUnit.SECONDS
)
}

/**
* DAGScheduler notifies all the remote shuffle services chosen to serve shuffle merge requests for
* the given shuffle map stage to finalize the shuffle merge process for this shuffle. This is
* invoked in a separate thread to reduce the impact on the DAGScheduler main thread, as the
* scheduler might need to talk to thousands of shuffle services to finalize the shuffle merge.
*/
private[scheduler] def finalizeShuffleMerge(stage: ShuffleMapStage): Unit = {
logInfo("%s (%s) finalizing the shuffle merge".format(stage, stage.name))
externalShuffleClient.foreach { shuffleClient =>
val shuffleId = stage.shuffleDep.shuffleId
val numMergers = stage.shuffleDep.getMergerLocs.length
val numResponses = new AtomicInteger()
val results = (0 until numMergers).map(_ => SettableFuture.create[Boolean]())
val timedOut = new AtomicBoolean()

// NOTE: This is a defensive check to post the finalize event if numMergers is 0 (i.e. no
// shuffle services are available).
if (numMergers == 0) {
eventProcessLoop.post(ShuffleMergeFinalized(stage))
return
}

def increaseAndCheckResponseCount(): Unit = {
if (numResponses.incrementAndGet() == numMergers) {
// Since this runs in the netty client thread and is outside of DAGScheduler
// event loop, we only post ShuffleMergeFinalized event into the event queue.
// The processing of this event should be done inside the event loop, so it
// can safely modify scheduler's internal state.
logInfo("%s (%s) shuffle merge finalized".format(stage, stage.name))
eventProcessLoop.post(ShuffleMergeFinalized(stage))
}
}

stage.shuffleDep.getMergerLocs.zipWithIndex.foreach {
case (shuffleServiceLoc, index) =>
// Sends async request to shuffle service to finalize shuffle merge on that host
shuffleClient.finalizeShuffleMerge(shuffleServiceLoc.host,
shuffleServiceLoc.port, shuffleId,
new MergeFinalizerListener {
override def onShuffleMergeSuccess(statuses: MergeStatuses): Unit = {
assert(shuffleId == statuses.shuffleId)
// Register the merge results even if already timed out, in case the reducer
// needing this merged block starts after dag scheduler receives this response.
mapOutputTracker.registerMergeResults(statuses.shuffleId,
MergeStatus.convertMergeStatusesToMergeStatusArr(statuses, shuffleServiceLoc))
if (!timedOut.get()) {
increaseAndCheckResponseCount()
results(index).set(true)
}
}

override def onShuffleMergeFailure(e: Throwable): Unit = {
if (!timedOut.get()) {
logWarning(s"Exception encountered when trying to finalize shuffle " +
s"merge on ${shuffleServiceLoc.host} for shuffle $shuffleId", e)
increaseAndCheckResponseCount()
// Do not fail the future as this would cause dag scheduler to prematurely
// give up on waiting for merge results from the remaining shuffle services
// if one fails
results(index).set(false)
}
}
})
}
// DAGScheduler only waits for a limited amount of time for the merge results.
// It will attempt to submit the next stage(s) irrespective of whether merge results
// from all shuffle services are received or not.
// TODO what are the reasonable configurations for the 2 timeouts? When # mappers
// TODO and # reducers for a shuffle is really large, and if the merge ratio is not
// TODO high enough, the MergeStatuses to be retrieved from 1 shuffle service could
// TODO be pretty large (10s MB to 100s MB). How to properly handle this scenario?
try {
Futures.allAsList(results: _*).get(shuffleMergeResultsTimeoutSec, TimeUnit.SECONDS)
} catch {
case _: TimeoutException =>
logInfo(s"Timed out on waiting for merge results from all " +
s"$numMergers mergers for shuffle $shuffleId")
timedOut.set(true)
eventProcessLoop.post(ShuffleMergeFinalized(stage))
}
}
}
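To make the bounded wait above easier to follow in isolation, here is a small self-contained sketch of the same Guava pattern: one SettableFuture per merger, completed from asynchronous callbacks, and a single timed get over the aggregate future. The merger count and timeout below are illustrative, not values from the PR.

```scala
import java.util.concurrent.{TimeUnit, TimeoutException}
import com.google.common.util.concurrent.{Futures, SettableFuture}

// One future per merger; in DAGScheduler these are completed from the
// MergeFinalizerListener callbacks running on netty client threads.
val results = (0 until 3).map(_ => SettableFuture.create[Boolean]())
results.foreach(_.set(true)) // simulate every merger responding in time

try {
  // Wait for all responses, but only up to the configured timeout.
  Futures.allAsList(results: _*).get(10, TimeUnit.SECONDS)
} catch {
  case _: TimeoutException =>
    // Proceed anyway: downstream stages are submitted with whatever merge
    // results have been registered so far.
    ()
}
```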

private def processShuffleMapStageCompletion(shuffleStage: ShuffleMapStage): Unit = {
(Review comment, Contributor: no changes here. Method extracted from handleTaskCompletion.)

markStageAsFinished(shuffleStage)
logInfo("looking for newly runnable stages")
logInfo("running: " + runningStages)
logInfo("waiting: " + waitingStages)
logInfo("failed: " + failedStages)

// This call to increment the epoch may not be strictly necessary, but it is retained
// for now in order to minimize the changes in behavior from an earlier version of the
// code. This existing behavior of always incrementing the epoch following any
// successful shuffle map stage completion may have benefits by causing unneeded
// cached map outputs to be cleaned up earlier on executors. In the future we can
// consider removing this call, but this will require some extra investigation.
// See https://github.com/apache/spark/pull/17955/files#r117385673 for more details.
mapOutputTracker.incrementEpoch()

clearCacheLocs()

if (!shuffleStage.isAvailable) {
// Some tasks had failed; let's resubmit this shuffleStage.
// TODO: Lower-level scheduler should also deal with this
logInfo("Resubmitting " + shuffleStage + " (" + shuffleStage.name +
") because some of its tasks had failed: " +
shuffleStage.findMissingPartitions().mkString(", "))
submitStage(shuffleStage)
} else {
markMapStageJobsAsFinished(shuffleStage)
submitWaitingChildStages(shuffleStage)
}
}

private[scheduler] def handleShuffleMergeFinalized(stage: ShuffleMapStage): Unit = {
stage.shuffleDep.markShuffleMergeFinalized()
processShuffleMapStageCompletion(stage)
}

private def handleResubmittedFailure(task: Task[_], stage: Stage): Unit = {
logInfo(s"Resubmitted $task, so marking it as still running.")
stage match {
@@ -2451,6 +2589,9 @@ private[scheduler] class DAGSchedulerEventProcessLoop(dagScheduler: DAGScheduler

case ResubmitFailedStages =>
dagScheduler.resubmitFailedStages()

case ShuffleMergeFinalized(stage) =>
dagScheduler.handleShuffleMergeFinalized(stage)
}

override def onError(e: Throwable): Unit = {
@@ -105,3 +105,5 @@ private[scheduler]
case class UnschedulableTaskSetRemoved(stageId: Int, stageAttemptId: Int)
extends DAGSchedulerEvent

private[scheduler] case class ShuffleMergeFinalized(stage: ShuffleMapStage)
extends DAGSchedulerEvent
@@ -94,4 +94,17 @@ private[spark] class ShuffleMapStage(
.findMissingPartitions(shuffleDep.shuffleId)
.getOrElse(0 until numPartitions)
}

/**
* Returns true if push-based shuffle is disabled for this stage, or if the shuffle merge for
* this stage is finalized, i.e. the shuffle merge results for all partitions are available.
*/
def isMergeFinalized: Boolean = {
// A stage with zero partitions (e.g. based on an EmptyRDD) is never computed, so treat it as finalized
if (numPartitions > 0 && shuffleDep.mergerLocs.nonEmpty) {
shuffleDep.shuffleMergeFinalized
} else {
true
}
}
}
@@ -1,3 +1,4 @@
org.apache.spark.scheduler.DummyExternalClusterManager
org.apache.spark.scheduler.MockExternalClusterManager
org.apache.spark.scheduler.CSMockExternalClusterManager
org.apache.spark.scheduler.PushBasedClusterManager