@@ -22,6 +22,7 @@ import java.util.concurrent.ConcurrentLinkedQueue
2222import java .util .concurrent .atomic .AtomicInteger
2323
2424import scala .collection .JavaConverters ._
25+ import scala .collection .mutable
2526import scala .collection .mutable .HashMap
2627import scala .concurrent .duration ._
2728import scala .io .Source
@@ -47,28 +48,37 @@ class MockWorker(master: RpcEndpointRef, conf: SparkConf = new SparkConf) extend
4748 val id = seq.toString
4849 override val rpcEnv : RpcEnv = RpcEnv .create(" worker" , " localhost" , seq,
4950 conf, new SecurityManager (conf))
50- var appRegistered = false
51- def newDriver (): RpcEndpointRef = {
51+ var apps = new mutable.HashMap [String , String ]()
52+ val driverIdToAppId = new mutable.HashMap [String , String ]()
53+ def newDriver (driverId : String ): RpcEndpointRef = {
5254 val name = s " driver_ ${drivers.size}"
5355 rpcEnv.setupEndpoint(name, new RpcEndpoint {
5456 override val rpcEnv : RpcEnv = MockWorker .this .rpcEnv
5557 override def receive : PartialFunction [Any , Unit ] = {
56- case RegisteredApplication (_, _) => appRegistered = true
58+ case RegisteredApplication (appId, _) =>
59+ apps(appId) = appId
60+ driverIdToAppId(driverId) = appId
5761 }
5862 })
5963 }
6064
6165 val appDesc = DeployTestUtils .createAppDesc()
62- val drivers = new HashMap [String , String ]
66+ val drivers = new mutable. HashMap [String , String ]
6367 override def receive : PartialFunction [Any , Unit ] = {
6468 case RegisteredWorker (masterRef, _, _) =>
65- masterRef.send(WorkerLatestState (" 1 " , Nil , drivers.keys.toSeq))
69+ masterRef.send(WorkerLatestState (id , Nil , drivers.keys.toSeq))
6670 case LaunchDriver (driverId, desc) =>
6771 drivers(driverId) = driverId
68- master.send(RegisterApplication (appDesc, newDriver()))
72+ master.send(RegisterApplication (appDesc, newDriver(driverId )))
6973 case KillDriver (driverId) =>
7074 master.send(DriverStateChanged (driverId, DriverState .KILLED , None ))
7175 drivers.remove(driverId)
76+ driverIdToAppId.get(driverId) match {
77+ case Some (appId) =>
78+ apps.remove(appId)
79+ master.send(UnregisterApplication (appId))
80+ }
81+ driverIdToAppId.remove(driverId)
7282 }
7383}
7484
@@ -560,7 +570,7 @@ class MasterSuite extends SparkFunSuite
560570 master.self.askSync[SubmitDriverResponse ](RequestSubmitDriver (driver))
561571
562572 eventually(timeout(10 .seconds)) {
563- assert(! worker1.appRegistered )
573+ assert(worker1.apps.nonEmpty )
564574 }
565575
566576 eventually(timeout(10 .seconds)) {
@@ -580,7 +590,7 @@ class MasterSuite extends SparkFunSuite
580590 " http://localhost:8081" ,
581591 RpcAddress (" localhost" , 10001 )))
582592 eventually(timeout(10 .seconds)) {
583- assert(! worker2.appRegistered )
593+ assert(worker2.apps.nonEmpty )
584594 }
585595
586596 master.self.send(worker1Reg)
@@ -591,6 +601,10 @@ class MasterSuite extends SparkFunSuite
591601 assert(worker.length == 1 )
592602 // make sure the `DriverStateChanged` arrives at Master.
593603 assert(worker(0 ).drivers.isEmpty)
604+ assert(worker1.apps.isEmpty)
605+ assert(worker1.drivers.isEmpty)
606+ assert(worker2.apps.size == 1 )
607+ assert(worker2.drivers.size == 1 )
594608 assert(masterState.activeDrivers.length == 1 )
595609 assert(masterState.activeApps.length == 1 )
596610 }
0 commit comments