
Commit e5ccac2

advancedxy authored and Marcelo Vanzin committed
[SPARK-22897][CORE] Expose stageAttemptId in TaskContext
Added stageAttemptId to TaskContext, with the corresponding constructor changes.

Added a new test in TaskContextSuite covering two cases:
1. Normal case without failure
2. Exception case with resubmitted stages

Link to [SPARK-22897](https://issues.apache.org/jira/browse/SPARK-22897)

Author: Xianjin YE <advancedxy@gmail.com>

Closes apache#20082 from advancedxy/SPARK-22897.

(cherry picked from commit a6fc300)
Signed-off-by: Marcelo Vanzin <vanzin@cloudera.com>
1 parent 77d11df · commit e5ccac2

12 files changed: 51 additions & 10 deletions
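Before the per-file diffs, a quick orientation: the change makes the stage attempt visible to user code through TaskContext. Below is a minimal sketch, not part of this commit, of how a job might read the new field; the object name, app name, and sample data are illustrative only.

import org.apache.spark.{SparkConf, SparkContext, TaskContext}

// Minimal sketch, assuming a Spark build that includes this commit.
object StageAttemptNumberDemo {
  def main(args: Array[String]): Unit = {
    val sc = new SparkContext(
      new SparkConf().setMaster("local[2]").setAppName("stage-attempt-demo"))
    try {
      // Each task can now see which attempt of its stage it is running in:
      // 0 on the first submission, higher when the scheduler resubmits the
      // stage (e.g. after a shuffle fetch failure).
      val attempts = sc.parallelize(1 to 4, 2).mapPartitions { iter =>
        val ctx = TaskContext.get()
        Iterator((ctx.stageId(), ctx.stageAttemptNumber(), iter.length))
      }.collect()
      attempts.foreach(println) // expect stageAttemptNumber == 0 on a clean run
    } finally {
      sc.stop()
    }
  }
}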


core/src/main/scala/org/apache/spark/TaskContext.scala

Lines changed: 8 additions & 1 deletion
@@ -65,7 +65,7 @@ object TaskContext {
    * An empty task context that does not represent an actual task. This is only used in tests.
    */
   private[spark] def empty(): TaskContextImpl = {
-    new TaskContextImpl(0, 0, 0, 0, null, new Properties, null)
+    new TaskContextImpl(0, 0, 0, 0, 0, null, new Properties, null)
   }
 }

@@ -145,6 +145,13 @@ abstract class TaskContext extends Serializable {
    */
   def stageId(): Int

+  /**
+   * How many times the stage that this task belongs to has been attempted. The first stage attempt
+   * will be assigned stageAttemptNumber = 0, and subsequent attempts will have increasing attempt
+   * numbers.
+   */
+  def stageAttemptNumber(): Int
+
   /**
    * The ID of the RDD partition that is computed by this task.
    */

core/src/main/scala/org/apache/spark/TaskContextImpl.scala

Lines changed: 3 additions & 2 deletions
@@ -29,8 +29,9 @@ import org.apache.spark.metrics.source.Source
 import org.apache.spark.util._

 private[spark] class TaskContextImpl(
-    val stageId: Int,
-    val partitionId: Int,
+    override val stageId: Int,
+    override val stageAttemptNumber: Int,
+    override val partitionId: Int,
     override val taskAttemptId: Long,
     override val attemptNumber: Int,
     override val taskMemoryManager: TaskMemoryManager,

core/src/main/scala/org/apache/spark/scheduler/Task.scala

Lines changed: 1 addition & 0 deletions
@@ -78,6 +78,7 @@ private[spark] abstract class Task[T](
     SparkEnv.get.blockManager.registerTask(taskAttemptId)
     context = new TaskContextImpl(
       stageId,
+      stageAttemptId, // stageAttemptId and stageAttemptNumber are semantically equal
       partitionId,
       taskAttemptId,
       attemptNumber,
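A note on the naming above: inside the scheduler the value has historically been called stageAttemptId (it is a constructor parameter of Task), while the public getter introduced by this commit is stageAttemptNumber, matching the existing per-task attemptNumber(). As the inline comment says, the same value flows through both names.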

core/src/test/java/test/org/apache/spark/JavaTaskContextCompileCheck.java

Lines changed: 2 additions & 0 deletions
@@ -38,6 +38,7 @@ public static void test() {
     tc.attemptNumber();
     tc.partitionId();
     tc.stageId();
+    tc.stageAttemptNumber();
     tc.taskAttemptId();
   }

@@ -51,6 +52,7 @@ public void onTaskCompletion(TaskContext context) {
     context.isCompleted();
     context.isInterrupted();
     context.stageId();
+    context.stageAttemptNumber();
     context.partitionId();
     context.addTaskCompletionListener(this);
   }

core/src/test/scala/org/apache/spark/ShuffleSuite.scala

Lines changed: 3 additions & 3 deletions
@@ -336,14 +336,14 @@ abstract class ShuffleSuite extends SparkFunSuite with Matchers with LocalSparkC

     // first attempt -- its successful
     val writer1 = manager.getWriter[Int, Int](shuffleHandle, 0,
-      new TaskContextImpl(0, 0, 0L, 0, taskMemoryManager, new Properties, metricsSystem))
+      new TaskContextImpl(0, 0, 0, 0L, 0, taskMemoryManager, new Properties, metricsSystem))
     val data1 = (1 to 10).map { x => x -> x}

     // second attempt -- also successful. We'll write out different data,
     // just to simulate the fact that the records may get written differently
     // depending on what gets spilled, what gets combined, etc.
     val writer2 = manager.getWriter[Int, Int](shuffleHandle, 0,
-      new TaskContextImpl(0, 0, 1L, 0, taskMemoryManager, new Properties, metricsSystem))
+      new TaskContextImpl(0, 0, 0, 1L, 0, taskMemoryManager, new Properties, metricsSystem))
     val data2 = (11 to 20).map { x => x -> x}

     // interleave writes of both attempts -- we want to test that both attempts can occur
@@ -371,7 +371,7 @@ abstract class ShuffleSuite extends SparkFunSuite with Matchers with LocalSparkC
     }

     val reader = manager.getReader[Int, Int](shuffleHandle, 0, 1,
-      new TaskContextImpl(1, 0, 2L, 0, taskMemoryManager, new Properties, metricsSystem))
+      new TaskContextImpl(1, 0, 0, 2L, 0, taskMemoryManager, new Properties, metricsSystem))
     val readData = reader.read().toIndexedSeq
     assert(readData === data1.toIndexedSeq || readData === data2.toIndexedSeq)

core/src/test/scala/org/apache/spark/memory/MemoryTestingUtils.scala

Lines changed: 1 addition & 0 deletions
@@ -29,6 +29,7 @@ object MemoryTestingUtils {
     val taskMemoryManager = new TaskMemoryManager(env.memoryManager, 0)
     new TaskContextImpl(
       stageId = 0,
+      stageAttemptNumber = 0,
       partitionId = 0,
       taskAttemptId = 0,
       attemptNumber = 0,

core/src/test/scala/org/apache/spark/scheduler/TaskContextSuite.scala

Lines changed: 27 additions & 2 deletions
@@ -29,6 +29,7 @@ import org.apache.spark.memory.TaskMemoryManager
 import org.apache.spark.metrics.source.JvmSource
 import org.apache.spark.network.util.JavaUtils
 import org.apache.spark.rdd.RDD
+import org.apache.spark.shuffle.FetchFailedException
 import org.apache.spark.util._

 class TaskContextSuite extends SparkFunSuite with BeforeAndAfter with LocalSparkContext {
@@ -143,6 +144,30 @@ class TaskContextSuite extends SparkFunSuite with BeforeAndAfter with LocalSpark
     assert(attemptIdsWithFailedTask.toSet === Set(0, 1))
   }

+  test("TaskContext.stageAttemptNumber getter") {
+    sc = new SparkContext("local[1,2]", "test")
+
+    // Check stageAttemptNumbers are 0 for initial stage
+    val stageAttemptNumbers = sc.parallelize(Seq(1, 2), 2).mapPartitions { _ =>
+      Seq(TaskContext.get().stageAttemptNumber()).iterator
+    }.collect()
+    assert(stageAttemptNumbers.toSet === Set(0))
+
+    // Check stageAttemptNumbers that are resubmitted when tasks have FetchFailedException
+    val stageAttemptNumbersWithFailedStage =
+      sc.parallelize(Seq(1, 2, 3, 4), 4).repartition(1).mapPartitions { _ =>
+        val stageAttemptNumber = TaskContext.get().stageAttemptNumber()
+        if (stageAttemptNumber < 2) {
+          // Throw FetchFailedException to explicitly trigger stage resubmission. A normal exception
+          // will only trigger task resubmission in the same stage.
+          throw new FetchFailedException(null, 0, 0, 0, "Fake")
+        }
+        Seq(stageAttemptNumber).iterator
+      }.collect()
+
+    assert(stageAttemptNumbersWithFailedStage.toSet === Set(2))
+  }
+
   test("accumulators are updated on exception failures") {
     // This means use 1 core and 4 max task failures
     sc = new SparkContext("local[1,4]", "test")
@@ -175,7 +200,7 @@ class TaskContextSuite extends SparkFunSuite with BeforeAndAfter with LocalSpark
     // accumulator updates from it.
     val taskMetrics = TaskMetrics.empty
     val task = new Task[Int](0, 0, 0) {
-      context = new TaskContextImpl(0, 0, 0L, 0,
+      context = new TaskContextImpl(0, 0, 0, 0L, 0,
         new TaskMemoryManager(SparkEnv.get.memoryManager, 0L),
         new Properties,
         SparkEnv.get.metricsSystem,
@@ -198,7 +223,7 @@ class TaskContextSuite extends SparkFunSuite with BeforeAndAfter with LocalSpark
     // accumulator updates from it.
     val taskMetrics = TaskMetrics.empty
     val task = new Task[Int](0, 0, 0) {
-      context = new TaskContextImpl(0, 0, 0L, 0,
+      context = new TaskContextImpl(0, 0, 0, 0L, 0,
         new TaskMemoryManager(SparkEnv.get.memoryManager, 0L),
         new Properties,
         SparkEnv.get.metricsSystem,
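For context on the test setup above: the master string "local[1,2]" requests one worker thread with up to two failures allowed per task, and FetchFailedException is thrown deliberately because an ordinary exception would only retry the task within the same stage attempt, whereas a fetch failure makes the DAGScheduler resubmit the stage with an incremented attempt number. After two forced failures, the surviving attempt therefore reports stageAttemptNumber == 2.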

core/src/test/scala/org/apache/spark/storage/BlockInfoManagerSuite.scala

Lines changed: 1 addition & 1 deletion
@@ -62,7 +62,7 @@ class BlockInfoManagerSuite extends SparkFunSuite with BeforeAndAfterEach {
   private def withTaskId[T](taskAttemptId: Long)(block: => T): T = {
     try {
       TaskContext.setTaskContext(
-        new TaskContextImpl(0, 0, taskAttemptId, 0, null, new Properties, null))
+        new TaskContextImpl(0, 0, 0, taskAttemptId, 0, null, new Properties, null))
       block
     } finally {
       TaskContext.unset()

project/MimaExcludes.scala

Lines changed: 2 additions & 0 deletions
@@ -37,6 +37,8 @@ object MimaExcludes {
   // Exclude rules for 2.1.x
   lazy val v21excludes = v20excludes ++ {
     Seq(
+      // [SPARK-22897] Expose stageAttemptId in TaskContext
+      ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.TaskContext.stageAttemptNumber"),
       // [SPARK-19652][UI] Do auth checks for REST API access.
       ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.deploy.history.HistoryServer.withSparkUI"),
       ProblemFilters.exclude[IncompatibleTemplateDefProblem]("org.apache.spark.status.api.v1.UIRootFromServletContext"),
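The new exclusion is needed because adding an abstract method to the public TaskContext class trips MiMa's ReversedMissingMethodProblem check, which guards code outside Spark that might subclass the type; since TaskContext is only implemented inside Spark itself (by TaskContextImpl), suppressing the warning is considered safe here.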

sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMapSuite.scala

Lines changed: 1 addition & 0 deletions
@@ -69,6 +69,7 @@ class UnsafeFixedWidthAggregationMapSuite

     TaskContext.setTaskContext(new TaskContextImpl(
       stageId = 0,
+      stageAttemptNumber = 0,
       partitionId = 0,
       taskAttemptId = Random.nextInt(10000),
       attemptNumber = 0,
