Skip to content

Commit da0f977

Browse files
committed
Add some regression tests.
1 parent 6ab9a0f commit da0f977

File tree

1 file changed

+22
-8
lines changed

1 file changed

+22
-8
lines changed

core/src/test/scala/org/apache/spark/deploy/master/MasterSuite.scala

Lines changed: 22 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@ import java.util.concurrent.ConcurrentLinkedQueue
2222
import java.util.concurrent.atomic.AtomicInteger
2323

2424
import scala.collection.JavaConverters._
25+
import scala.collection.mutable
2526
import scala.collection.mutable.HashMap
2627
import scala.concurrent.duration._
2728
import scala.io.Source
@@ -47,28 +48,37 @@ class MockWorker(master: RpcEndpointRef, conf: SparkConf = new SparkConf) extend
4748
val id = seq.toString
4849
override val rpcEnv: RpcEnv = RpcEnv.create("worker", "localhost", seq,
4950
conf, new SecurityManager(conf))
50-
var appRegistered = false
51-
def newDriver(): RpcEndpointRef = {
51+
var apps = new mutable.HashMap[String, String]()
52+
val driverIdToAppId = new mutable.HashMap[String, String]()
53+
def newDriver(driverId: String): RpcEndpointRef = {
5254
val name = s"driver_${drivers.size}"
5355
rpcEnv.setupEndpoint(name, new RpcEndpoint {
5456
override val rpcEnv: RpcEnv = MockWorker.this.rpcEnv
5557
override def receive: PartialFunction[Any, Unit] = {
56-
case RegisteredApplication(_, _) => appRegistered = true
58+
case RegisteredApplication(appId, _) =>
59+
apps(appId) = appId
60+
driverIdToAppId(driverId) = appId
5761
}
5862
})
5963
}
6064

6165
val appDesc = DeployTestUtils.createAppDesc()
62-
val drivers = new HashMap[String, String]
66+
val drivers = new mutable.HashMap[String, String]
6367
override def receive: PartialFunction[Any, Unit] = {
6468
case RegisteredWorker(masterRef, _, _) =>
65-
masterRef.send(WorkerLatestState("1", Nil, drivers.keys.toSeq))
69+
masterRef.send(WorkerLatestState(id, Nil, drivers.keys.toSeq))
6670
case LaunchDriver(driverId, desc) =>
6771
drivers(driverId) = driverId
68-
master.send(RegisterApplication(appDesc, newDriver()))
72+
master.send(RegisterApplication(appDesc, newDriver(driverId)))
6973
case KillDriver(driverId) =>
7074
master.send(DriverStateChanged(driverId, DriverState.KILLED, None))
7175
drivers.remove(driverId)
76+
driverIdToAppId.get(driverId) match {
77+
case Some(appId) =>
78+
apps.remove(appId)
79+
master.send(UnregisterApplication(appId))
80+
}
81+
driverIdToAppId.remove(driverId)
7282
}
7383
}
7484

@@ -560,7 +570,7 @@ class MasterSuite extends SparkFunSuite
560570
master.self.askSync[SubmitDriverResponse](RequestSubmitDriver(driver))
561571

562572
eventually(timeout(10.seconds)) {
563-
assert(!worker1.appRegistered)
573+
assert(worker1.apps.nonEmpty)
564574
}
565575

566576
eventually(timeout(10.seconds)) {
@@ -580,7 +590,7 @@ class MasterSuite extends SparkFunSuite
580590
"http://localhost:8081",
581591
RpcAddress("localhost", 10001)))
582592
eventually(timeout(10.seconds)) {
583-
assert(!worker2.appRegistered)
593+
assert(worker2.apps.nonEmpty)
584594
}
585595

586596
master.self.send(worker1Reg)
@@ -591,6 +601,10 @@ class MasterSuite extends SparkFunSuite
591601
assert(worker.length == 1)
592602
// make sure the `DriverStateChanged` arrives at Master.
593603
assert(worker(0).drivers.isEmpty)
604+
assert(worker1.apps.isEmpty)
605+
assert(worker1.drivers.isEmpty)
606+
assert(worker2.apps.size == 1)
607+
assert(worker2.drivers.size == 1)
594608
assert(masterState.activeDrivers.length == 1)
595609
assert(masterState.activeApps.length == 1)
596610
}

0 commit comments

Comments
 (0)