Skip to content

Commit 21bd1c3

Browse files
author
liyuanjian
committed
[SPARK-25017][Core] Add test suite for BarrierCoordinator and ContextBarrierState
1 parent ea63a7a commit 21bd1c3

2 files changed

Lines changed: 156 additions & 3 deletions

File tree

core/src/main/scala/org/apache/spark/BarrierCoordinator.scala

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -65,7 +65,7 @@ private[spark] class BarrierCoordinator(
6565

6666
// Record all active stage attempts that make barrier() call(s), and the corresponding internal
6767
// state.
68-
private val states = new ConcurrentHashMap[ContextBarrierId, ContextBarrierState]
68+
private[spark] val states = new ConcurrentHashMap[ContextBarrierId, ContextBarrierState]
6969

7070
override def onStart(): Unit = {
7171
super.onStart()
@@ -90,14 +90,14 @@ private[spark] class BarrierCoordinator(
9090
* @param numTasks Number of tasks of the barrier stage, all barrier() calls from the stage shall
9191
* collect `numTasks` requests to succeed.
9292
*/
93-
private class ContextBarrierState(
93+
private[spark] class ContextBarrierState(
9494
val barrierId: ContextBarrierId,
9595
val numTasks: Int) {
9696

9797
// There may be multiple barrier() calls from a barrier stage attempt, `barrierEpoch` is used
9898
// to identify each barrier() call. It shall get increased when a barrier() call succeeds, or
9999
// reset when a barrier() call fails due to timeout.
100-
private var barrierEpoch: Int = 0
100+
private[spark] var barrierEpoch: Int = 0
101101

102102
// An array of RPCCallContexts for barrier tasks that are waiting for reply of a barrier()
103103
// call.
Lines changed: 153 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,153 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one or more
3+
* contributor license agreements. See the NOTICE file distributed with
4+
* this work for additional information regarding copyright ownership.
5+
* The ASF licenses this file to You under the Apache License, Version 2.0
6+
* (the "License"); you may not use this file except in compliance with
7+
* the License. You may obtain a copy of the License at
8+
*
9+
* http://www.apache.org/licenses/LICENSE-2.0
10+
*
11+
* Unless required by applicable law or agreed to in writing, software
12+
* distributed under the License is distributed on an "AS IS" BASIS,
13+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
* See the License for the specific language governing permissions and
15+
* limitations under the License.
16+
*/
17+
18+
package org.apache.spark.scheduler
19+
20+
import java.util.concurrent.TimeoutException
21+
22+
import scala.concurrent.duration._
23+
import scala.language.postfixOps
24+
25+
import org.apache.spark._
26+
import org.apache.spark.rpc.RpcTimeout
27+
28+
class BarrierCoordinatorSuite extends SparkFunSuite with LocalSparkContext {

  // Upper bound (ms) for polling asynchronous state updates and joining task threads.
  private val maxWaitMillis = 10000L

  /**
   * Get the current barrierEpoch from barrierCoordinator.states by ContextBarrierId.
   */
  def getCurrentBarrierEpoch(
      stageId: Int, stageAttemptId: Int, barrierCoordinator: BarrierCoordinator): Int = {
    val barrierId = ContextBarrierId(stageId, stageAttemptId)
    barrierCoordinator.states.get(barrierId).barrierEpoch
  }

  /**
   * Current barrierEpoch as an Option; None while the coordinator has not yet created the
   * ContextBarrierState for this (stageId, stageAttemptId). Avoids the NPE that a raw
   * `states.get(...).barrierEpoch` would throw when polling before the first request lands.
   */
  private def currentEpoch(
      stageId: Int, stageAttemptId: Int, barrierCoordinator: BarrierCoordinator): Option[Int] = {
    Option(barrierCoordinator.states.get(ContextBarrierId(stageId, stageAttemptId)))
      .map(_.barrierEpoch)
  }

  /**
   * Poll until the coordinator's barrierEpoch reaches `expected`, then assert it.
   *
   * The epoch is bumped on the coordinator's RPC message loop, which may still be running
   * when askSync returns to the caller, so a bounded poll replaces a fixed Thread.sleep —
   * this removes both the flakiness (sleep too short) and the needless delay (sleep too long).
   */
  private def awaitBarrierEpoch(
      stageId: Int,
      stageAttemptId: Int,
      barrierCoordinator: BarrierCoordinator,
      expected: Int): Unit = {
    val deadline = System.currentTimeMillis() + maxWaitMillis
    while (!currentEpoch(stageId, stageAttemptId, barrierCoordinator).contains(expected) &&
        System.currentTimeMillis() < deadline) {
      Thread.sleep(20)
    }
    assert(currentEpoch(stageId, stageAttemptId, barrierCoordinator).contains(expected))
  }

  /**
   * Fire `numTasks` concurrent RequestToSync calls (one daemon thread per task, epoch 0) and
   * join all threads before returning, so callers never race against thread startup.
   */
  private def syncFromAllTasks(
      rpcEndpointRef: org.apache.spark.rpc.RpcEndpointRef,
      numTasks: Int,
      stageId: Int,
      stageAttemptNumber: Int,
      timeout: RpcTimeout): Unit = {
    val threads = (0 until numTasks).map { taskId =>
      new Thread(s"task-$taskId-thread") {
        setDaemon(true)
        override def run(): Unit = {
          rpcEndpointRef.askSync[Unit](
            message = RequestToSync(numTasks, stageId, stageAttemptNumber,
              taskAttemptId = taskId, barrierEpoch = 0),
            timeout = timeout)
        }
      }
    }
    threads.foreach(_.start())
    // Bounded join: a hung barrier() call fails the subsequent epoch assertion instead of
    // hanging the whole suite.
    threads.foreach(_.join(maxWaitMillis))
  }

  test("normal test for single task") {
    sc = new SparkContext("local", "test")
    val barrierCoordinator = new BarrierCoordinator(5, sc.listenerBus, sc.env.rpcEnv)
    val rpcEndpointRef = sc.env.rpcEnv.setupEndpoint("barrierCoordinator", barrierCoordinator)
    val stageId = 0
    val stageAttemptNumber = 0
    // A single-task barrier succeeds immediately: numTasks == 1 means this one request
    // completes the sync and bumps the epoch from 0 to 1.
    rpcEndpointRef.askSync[Unit](
      message = RequestToSync(numTasks = 1, stageId, stageAttemptNumber, taskAttemptId = 0,
        barrierEpoch = 0),
      timeout = new RpcTimeout(5.seconds, "rpcTimeOut"))
    awaitBarrierEpoch(stageId, stageAttemptNumber, barrierCoordinator, expected = 1)
  }

  test("normal test for multi tasks") {
    sc = new SparkContext("local", "test")
    val barrierCoordinator = new BarrierCoordinator(5, sc.listenerBus, sc.env.rpcEnv)
    val rpcEndpointRef = sc.env.rpcEnv.setupEndpoint("barrierCoordinator", barrierCoordinator)
    val numTasks = 3
    val stageId = 0
    val stageAttemptNumber = 0
    val rpcTimeOut = new RpcTimeout(5.seconds, "rpcTimeOut")
    // All 3 tasks sync concurrently; once the coordinator has collected numTasks requests it
    // replies to every task and increments the epoch.
    syncFromAllTasks(rpcEndpointRef, numTasks, stageId, stageAttemptNumber, rpcTimeOut)
    awaitBarrierEpoch(stageId, stageAttemptNumber, barrierCoordinator, expected = 1)
  }

  test("abnormal test for syncing with illegal barrierId") {
    sc = new SparkContext("local", "test")
    val barrierCoordinator = new BarrierCoordinator(5, sc.listenerBus, sc.env.rpcEnv)
    val rpcEndpointRef = sc.env.rpcEnv.setupEndpoint("barrierCoordinator", barrierCoordinator)
    val numTasks = 3
    val stageId = 0
    val stageAttemptNumber = 0
    val rpcTimeOut = new RpcTimeout(5.seconds, "rpcTimeOut")
    // The coordinator rejects a request carrying an epoch that cannot exist.
    intercept[SparkException](
      rpcEndpointRef.askSync[Unit](
        message = RequestToSync(numTasks, stageId, stageAttemptNumber, taskAttemptId = 0,
          barrierEpoch = -1), // illegal barrierEpoch = -1
        timeout = rpcTimeOut))
  }

  test("abnormal test for syncing with old barrierId") {
    sc = new SparkContext("local", "test")
    val barrierCoordinator = new BarrierCoordinator(5, sc.listenerBus, sc.env.rpcEnv)
    val rpcEndpointRef = sc.env.rpcEnv.setupEndpoint("barrierCoordinator", barrierCoordinator)
    val numTasks = 3
    val stageId = 0
    val stageAttemptNumber = 0
    val rpcTimeOut = new RpcTimeout(5.seconds, "rpcTimeOut")
    // First, complete a full barrier() round so the coordinator's epoch advances to 1.
    syncFromAllTasks(rpcEndpointRef, numTasks, stageId, stageAttemptNumber, rpcTimeOut)
    awaitBarrierEpoch(stageId, stageAttemptNumber, barrierCoordinator, expected = 1)
    // A straggler still carrying the stale epoch 0 must be rejected.
    intercept[SparkException](
      rpcEndpointRef.askSync[Unit](
        message = RequestToSync(numTasks, stageId, stageAttemptNumber, taskAttemptId = 0,
          barrierEpoch = 0),
        timeout = rpcTimeOut))
  }

  test("abnormal test for timeout when rpcTimeOut < barrierTimeOut") {
    sc = new SparkContext("local", "test")
    val barrierCoordinator = new BarrierCoordinator(2, sc.listenerBus, sc.env.rpcEnv)
    val rpcEndpointRef = sc.env.rpcEnv.setupEndpoint("barrierCoordinator", barrierCoordinator)
    val numTasks = 3
    val stageId = 0
    val stageAttemptNumber = 0
    // Only 1 of 3 tasks syncs, so the barrier never completes; the 1s RPC timeout fires
    // before the coordinator's 2s barrier timeout, surfacing as a client-side TimeoutException.
    val rpcTimeOut = new RpcTimeout(1.seconds, "rpcTimeOut")
    intercept[TimeoutException](
      rpcEndpointRef.askSync[Unit](
        message = RequestToSync(numTasks, stageId, stageAttemptNumber, taskAttemptId = 0,
          barrierEpoch = 0),
        timeout = rpcTimeOut))
  }

  test("abnormal test for timeout when rpcTimeOut > barrierTimeOut") {
    sc = new SparkContext("local", "test")
    val barrierCoordinator = new BarrierCoordinator(2, sc.listenerBus, sc.env.rpcEnv)
    val rpcEndpointRef = sc.env.rpcEnv.setupEndpoint("barrierCoordinator", barrierCoordinator)
    val numTasks = 3
    val stageId = 0
    val stageAttemptNumber = 0
    // Here the coordinator's 2s barrier timeout fires first and it fails the pending
    // requester with a SparkException before the 4s RPC timeout can expire.
    val rpcTimeOut = new RpcTimeout(4.seconds, "rpcTimeOut")
    intercept[SparkException](
      rpcEndpointRef.askSync[Unit](
        message = RequestToSync(numTasks, stageId, stageAttemptNumber, taskAttemptId = 0,
          barrierEpoch = 0),
        timeout = rpcTimeOut))
  }
}

0 commit comments

Comments
 (0)