@@ -27,7 +27,11 @@ import org.apache.spark.util.{RpcUtils, ThreadUtils}
2727private sealed trait OutputCommitCoordinationMessage extends Serializable
2828
2929private case object StopCoordinator extends OutputCommitCoordinationMessage
30- private case class AskPermissionToCommitOutput (stage : Int , partition : Int , attemptNumber : Int )
30+ private case class AskPermissionToCommitOutput (
31+ stage : Int ,
32+ stageAttempt : Int ,
33+ partition : Int ,
34+ attemptNumber : Int )
3135
3236/**
3337 * Authority that decides whether tasks can commit output to HDFS. Uses a "first committer wins"
@@ -45,13 +49,15 @@ private[spark] class OutputCommitCoordinator(conf: SparkConf, isDriver: Boolean)
4549 // Initialized by SparkEnv
4650 var coordinatorRef : Option [RpcEndpointRef ] = None
4751
48- private type StageId = Int
49- private type PartitionId = Int
50- private type TaskAttemptNumber = Int
51- private val NO_AUTHORIZED_COMMITTER : TaskAttemptNumber = - 1
52+ // Class used to identify a committer. The task ID for a committer is implicitly defined by
53+ // the partition being processed, but the coordinator need to keep track of both the stage
54+ // attempt and the task attempt, because in some situations the same task may be running
55+ // concurrently in two different attempts of the same stage.
56+ private case class TaskIdentifier (stageAttempt : Int , taskAttempt : Int )
57+
5258 private case class StageState (numPartitions : Int ) {
53- val authorizedCommitters = Array .fill[TaskAttemptNumber ](numPartitions)(NO_AUTHORIZED_COMMITTER )
54- val failures = mutable.Map [PartitionId , mutable.Set [TaskAttemptNumber ]]()
59+ val authorizedCommitters = Array .fill[TaskIdentifier ](numPartitions)(null )
60+ val failures = mutable.Map [Int , mutable.Set [TaskIdentifier ]]()
5561 }
5662
5763 /**
@@ -64,7 +70,7 @@ private[spark] class OutputCommitCoordinator(conf: SparkConf, isDriver: Boolean)
6470 *
6571 * Access to this map should be guarded by synchronizing on the OutputCommitCoordinator instance.
6672 */
67- private val stageStates = mutable.Map [StageId , StageState ]()
73+ private val stageStates = mutable.Map [Int , StageState ]()
6874
6975 /**
7076 * Returns whether the OutputCommitCoordinator's internal data structures are all empty.
@@ -87,10 +93,11 @@ private[spark] class OutputCommitCoordinator(conf: SparkConf, isDriver: Boolean)
8793 * @return true if this task is authorized to commit, false otherwise
8894 */
8995 def canCommit (
90- stage : StageId ,
91- partition : PartitionId ,
92- attemptNumber : TaskAttemptNumber ): Boolean = {
93- val msg = AskPermissionToCommitOutput (stage, partition, attemptNumber)
96+ stage : Int ,
97+ stageAttempt : Int ,
98+ partition : Int ,
99+ attemptNumber : Int ): Boolean = {
100+ val msg = AskPermissionToCommitOutput (stage, stageAttempt, partition, attemptNumber)
94101 coordinatorRef match {
95102 case Some (endpointRef) =>
96103 ThreadUtils .awaitResult(endpointRef.ask[Boolean ](msg),
@@ -109,20 +116,21 @@ private[spark] class OutputCommitCoordinator(conf: SparkConf, isDriver: Boolean)
109116 * @param maxPartitionId the maximum partition id that could appear in this stage's tasks (i.e.
110117 * the maximum possible value of `context.partitionId`).
111118 */
112- private [scheduler] def stageStart (stage : StageId , maxPartitionId : Int ): Unit = synchronized {
119+ private [scheduler] def stageStart (stage : Int , maxPartitionId : Int ): Unit = synchronized {
113120 stageStates(stage) = new StageState (maxPartitionId + 1 )
114121 }
115122
116123 // Called by DAGScheduler
117- private [scheduler] def stageEnd (stage : StageId ): Unit = synchronized {
124+ private [scheduler] def stageEnd (stage : Int ): Unit = synchronized {
118125 stageStates.remove(stage)
119126 }
120127
121128 // Called by DAGScheduler
122129 private [scheduler] def taskCompleted (
123- stage : StageId ,
124- partition : PartitionId ,
125- attemptNumber : TaskAttemptNumber ,
130+ stage : Int ,
131+ stageAttempt : Int ,
132+ partition : Int ,
133+ attemptNumber : Int ,
126134 reason : TaskEndReason ): Unit = synchronized {
127135 val stageState = stageStates.getOrElse(stage, {
128136 logDebug(s " Ignoring task completion for completed stage " )
@@ -132,15 +140,16 @@ private[spark] class OutputCommitCoordinator(conf: SparkConf, isDriver: Boolean)
132140 case Success =>
133141 // The task output has been committed successfully
134142 case denied : TaskCommitDenied =>
135- logInfo(s " Task was denied committing, stage: $stage, partition: $partition , " +
136- s " attempt: $attemptNumber" )
143+ logInfo(s " Task was denied committing, stage: $stage / $stageAttempt , " +
144+ s " partition: $partition , attempt: $attemptNumber" )
137145 case otherReason =>
138146 // Mark the attempt as failed to blacklist from future commit protocol
139- stageState.failures.getOrElseUpdate(partition, mutable.Set ()) += attemptNumber
140- if (stageState.authorizedCommitters(partition) == attemptNumber) {
147+ val taskId = TaskIdentifier (stageAttempt, attemptNumber)
148+ stageState.failures.getOrElseUpdate(partition, mutable.Set ()) += taskId
149+ if (stageState.authorizedCommitters(partition) == taskId) {
141150 logDebug(s " Authorized committer (attemptNumber= $attemptNumber, stage= $stage, " +
142151 s " partition= $partition) failed; clearing lock " )
143- stageState.authorizedCommitters(partition) = NO_AUTHORIZED_COMMITTER
152+ stageState.authorizedCommitters(partition) = null
144153 }
145154 }
146155 }
@@ -155,47 +164,41 @@ private[spark] class OutputCommitCoordinator(conf: SparkConf, isDriver: Boolean)
155164
156165 // Marked private[scheduler] instead of private so this can be mocked in tests
157166 private [scheduler] def handleAskPermissionToCommit (
158- stage : StageId ,
159- partition : PartitionId ,
160- attemptNumber : TaskAttemptNumber ): Boolean = synchronized {
167+ stage : Int ,
168+ stageAttempt : Int ,
169+ partition : Int ,
170+ attemptNumber : Int ): Boolean = synchronized {
161171 stageStates.get(stage) match {
162- case Some (state) if attemptFailed(state, partition, attemptNumber) =>
163- logInfo(s " Denying attemptNumber= $attemptNumber to commit for stage= $stage, " +
164- s " partition= $partition as task attempt $attemptNumber has already failed. " )
172+ case Some (state) if attemptFailed(state, stageAttempt, partition, attemptNumber) =>
173+ logInfo(s " Commit denied for stage= $stage/ $attemptNumber , partition= $partition : " +
174+ s " task attempt $attemptNumber already marked as failed. " )
165175 false
166176 case Some (state) =>
167- state.authorizedCommitters(partition) match {
168- case NO_AUTHORIZED_COMMITTER =>
169- logDebug(s " Authorizing attemptNumber= $attemptNumber to commit for stage= $stage, " +
170- s " partition= $partition" )
171- state.authorizedCommitters(partition) = attemptNumber
172- true
173- case existingCommitter =>
174- // Coordinator should be idempotent when receiving AskPermissionToCommit.
175- if (existingCommitter == attemptNumber) {
176- logWarning(s " Authorizing duplicate request to commit for " +
177- s " attemptNumber= $attemptNumber to commit for stage= $stage, " +
178- s " partition= $partition; existingCommitter = $existingCommitter. " +
179- s " This can indicate dropped network traffic. " )
180- true
181- } else {
182- logDebug(s " Denying attemptNumber= $attemptNumber to commit for stage= $stage, " +
183- s " partition= $partition; existingCommitter = $existingCommitter" )
184- false
185- }
177+ val existing = state.authorizedCommitters(partition)
178+ if (existing == null ) {
179+ logDebug(s " Commit allowed for stage= $stage/ $attemptNumber, partition= $partition: " +
180+ s " task attempt $attemptNumber" )
181+ state.authorizedCommitters(partition) = TaskIdentifier (stageAttempt, attemptNumber)
182+ true
183+ } else {
184+ logDebug(s " Commit denied for stage= $stage/ $attemptNumber, partition= $partition: " +
185+ s " already committed by $existing" )
186+ false
186187 }
187188 case None =>
188- logDebug(s " Stage $stage has completed, so not allowing " +
189- s " attempt number $attemptNumber of partition $partition to commit " )
189+ logDebug(s " Commit denied for stage= $stage / $attemptNumber , partition= $partition : " +
190+ " stage already marked as completed. " )
190191 false
191192 }
192193 }
193194
194195 private def attemptFailed (
195196 stageState : StageState ,
196- partition : PartitionId ,
197- attempt : TaskAttemptNumber ): Boolean = synchronized {
198- stageState.failures.get(partition).exists(_.contains(attempt))
197+ stageAttempt : Int ,
198+ partition : Int ,
199+ attempt : Int ): Boolean = synchronized {
200+ val failInfo = TaskIdentifier (stageAttempt, attempt)
201+ stageState.failures.get(partition).exists(_.contains(failInfo))
199202 }
200203}
201204
@@ -215,9 +218,10 @@ private[spark] object OutputCommitCoordinator {
215218 }
216219
217220 override def receiveAndReply (context : RpcCallContext ): PartialFunction [Any , Unit ] = {
218- case AskPermissionToCommitOutput (stage, partition, attemptNumber) =>
221+ case AskPermissionToCommitOutput (stage, stageAttempt, partition, attemptNumber) =>
219222 context.reply(
220- outputCommitCoordinator.handleAskPermissionToCommit(stage, partition, attemptNumber))
223+ outputCommitCoordinator.handleAskPermissionToCommit(stage, stageAttempt, partition,
224+ attemptNumber))
221225 }
222226 }
223227}
0 commit comments