@@ -23,7 +23,7 @@ import scala.collection.mutable
2323import scala .collection .mutable .ArrayBuffer
2424
2525import org .mockito .Matchers .{any , anyInt , anyString }
26- import org .mockito .Mockito .{mock , never , spy , verify , when }
26+ import org .mockito .Mockito .{mock , never , spy , verify , when , times }
2727import org .mockito .invocation .InvocationOnMock
2828import org .mockito .stubbing .Answer
2929
@@ -1172,6 +1172,51 @@ class TaskSetManagerSuite extends SparkFunSuite with LocalSparkContext with Logg
11721172 assert(blacklistTracker.isNodeBlacklisted(" host1" ))
11731173 }
11741174
1175+ test(" update blacklist before adding pending task to avoid race condition" ) {
1176+ // When a task fails, it should apply the blacklist policy prior to
1177+ // retrying the task otherwise there's a race condition where run on
1178+ // the same executor that it was intended to be black listed from.
1179+ val conf = new SparkConf ().
1180+ set(config.BLACKLIST_ENABLED , true ).
1181+ set(config.MAX_TASK_ATTEMPTS_PER_EXECUTOR , 1 )
1182+
1183+ // Create a task with two executors.
1184+ sc = new SparkContext (" local" , " test" , conf)
1185+ val exec = " executor1"
1186+ val host = " host1"
1187+ val exec2 = " executor2"
1188+ val host2 = " host2"
1189+ sched = new FakeTaskScheduler (sc, (exec, host), (exec2, host2))
1190+ val taskSet = FakeTask .createTaskSet(1 )
1191+
1192+ val clock = new ManualClock
1193+ val mockListenerBus = mock(classOf [LiveListenerBus ])
1194+ val blacklistTracker = new BlacklistTracker (mockListenerBus, conf, None , clock)
1195+ val taskSetManager = new TaskSetManager (sched, taskSet, 1 , Some (blacklistTracker))
1196+ val taskSetManagerSpy = spy(taskSetManager)
1197+
1198+ val taskDesc = taskSetManagerSpy.resourceOffer(exec, host, TaskLocality .ANY )
1199+
1200+ // Assert the task has been black listed on the executor it was last executed on.
1201+ when(taskSetManagerSpy.addPendingTask(anyInt())).thenAnswer(
1202+ new Answer [Unit ] {
1203+ override def answer (invocationOnMock : InvocationOnMock ): Unit = {
1204+ val task = invocationOnMock.getArgumentAt(0 , classOf [Int ])
1205+ assert(taskSetManager.taskSetBlacklistHelperOpt.get.
1206+ isExecutorBlacklistedForTask(exec, task))
1207+ null
1208+ }
1209+ }
1210+ )
1211+
1212+ // Simulate an out of memory error
1213+ val e = new OutOfMemoryError
1214+ taskSetManagerSpy.handleFailedTask(
1215+ taskDesc.get.taskId, TaskState .FAILED , new ExceptionFailure (e, Seq ()))
1216+
1217+ verify(taskSetManagerSpy, times(1 )).addPendingTask(anyInt())
1218+ }
1219+
11751220 private def createTaskResult (
11761221 id : Int ,
11771222 accumUpdates : Seq [AccumulatorV2 [_, _]] = Seq .empty): DirectTaskResult [Int ] = {
0 commit comments