
Commit 5753ee0

[SPARK-22897][CORE]: Expose stageAttemptId in TaskContext
1 parent 0e68330 commit 5753ee0

11 files changed: +42 −9 lines changed

core/src/main/scala/org/apache/spark/TaskContext.scala

Lines changed: 5 additions & 1 deletion

@@ -66,7 +66,7 @@ object TaskContext {
    * An empty task context that does not represent an actual task. This is only used in tests.
    */
   private[spark] def empty(): TaskContextImpl = {
-    new TaskContextImpl(0, 0, 0, 0, null, new Properties, null)
+    new TaskContextImpl(0, 0, 0, 0, 0, null, new Properties, null)
   }
 }

@@ -150,6 +150,10 @@ abstract class TaskContext extends Serializable {
    */
  def stageId(): Int

+  /**
+   * The attempt ID of the stage that this task belongs to.
+   */
+  def stageAttemptId(): Int
   /**
    * The ID of the RDD partition that is computed by this task.
    */
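
For context (not part of the diff): a minimal sketch of how job code could read the new getter, assuming an active SparkContext named sc.

import org.apache.spark.TaskContext

// Illustrative only: collect (stageId, stageAttemptId, partitionId) from each task.
val attempts = sc.parallelize(1 to 4, 2).mapPartitions { _ =>
  val ctx = TaskContext.get()
  Iterator((ctx.stageId(), ctx.stageAttemptId(), ctx.partitionId()))
}.collect()
// With no failures, every stageAttemptId is expected to be 0.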

core/src/main/scala/org/apache/spark/TaskContextImpl.scala

Lines changed: 1 addition & 0 deletions

@@ -42,6 +42,7 @@ import org.apache.spark.util._
  */
 private[spark] class TaskContextImpl(
     val stageId: Int,
+    val stageAttemptId: Int,
     val partitionId: Int,
     override val taskAttemptId: Long,
     override val attemptNumber: Int,

core/src/main/scala/org/apache/spark/scheduler/Task.scala

Lines changed: 1 addition & 0 deletions

@@ -79,6 +79,7 @@ private[spark] abstract class Task[T](
     SparkEnv.get.blockManager.registerTask(taskAttemptId)
     context = new TaskContextImpl(
       stageId,
+      stageAttemptId,
       partitionId,
       taskAttemptId,
       attemptNumber,

core/src/test/scala/org/apache/spark/ShuffleSuite.scala

Lines changed: 3 additions & 3 deletions

@@ -363,14 +363,14 @@ abstract class ShuffleSuite extends SparkFunSuite with Matchers with LocalSparkC

     // first attempt -- its successful
     val writer1 = manager.getWriter[Int, Int](shuffleHandle, 0,
-      new TaskContextImpl(0, 0, 0L, 0, taskMemoryManager, new Properties, metricsSystem))
+      new TaskContextImpl(0, 0, 0, 0L, 0, taskMemoryManager, new Properties, metricsSystem))
     val data1 = (1 to 10).map { x => x -> x}

     // second attempt -- also successful. We'll write out different data,
     // just to simulate the fact that the records may get written differently
     // depending on what gets spilled, what gets combined, etc.
     val writer2 = manager.getWriter[Int, Int](shuffleHandle, 0,
-      new TaskContextImpl(0, 0, 1L, 0, taskMemoryManager, new Properties, metricsSystem))
+      new TaskContextImpl(0, 0, 0, 1L, 0, taskMemoryManager, new Properties, metricsSystem))
     val data2 = (11 to 20).map { x => x -> x}

     // interleave writes of both attempts -- we want to test that both attempts can occur
@@ -398,7 +398,7 @@ abstract class ShuffleSuite extends SparkFunSuite with Matchers with LocalSparkC
     }

     val reader = manager.getReader[Int, Int](shuffleHandle, 0, 1,
-      new TaskContextImpl(1, 0, 2L, 0, taskMemoryManager, new Properties, metricsSystem))
+      new TaskContextImpl(1, 0, 0, 2L, 0, taskMemoryManager, new Properties, metricsSystem))
     val readData = reader.read().toIndexedSeq
     assert(readData === data1.toIndexedSeq || readData === data2.toIndexedSeq)

core/src/test/scala/org/apache/spark/memory/MemoryTestingUtils.scala

Lines changed: 1 addition & 0 deletions

@@ -29,6 +29,7 @@ object MemoryTestingUtils {
     val taskMemoryManager = new TaskMemoryManager(env.memoryManager, 0)
     new TaskContextImpl(
       stageId = 0,
+      stageAttemptId = 0,
       partitionId = 0,
       taskAttemptId = 0,
       attemptNumber = 0,

core/src/test/scala/org/apache/spark/scheduler/TaskContextSuite.scala

Lines changed: 25 additions & 2 deletions

@@ -29,6 +29,7 @@ import org.apache.spark.memory.TaskMemoryManager
 import org.apache.spark.metrics.source.JvmSource
 import org.apache.spark.network.util.JavaUtils
 import org.apache.spark.rdd.RDD
+import org.apache.spark.shuffle.FetchFailedException
 import org.apache.spark.util._

 class TaskContextSuite extends SparkFunSuite with BeforeAndAfter with LocalSparkContext {
@@ -158,6 +159,28 @@ class TaskContextSuite extends SparkFunSuite with BeforeAndAfter with LocalSpark
     assert(attemptIdsWithFailedTask.toSet === Set(0, 1))
   }

+  test("TaskContext.stageAttemptId getter") {
+    sc = new SparkContext("local[1,2]", "test")
+
+    // Check stage attemptIds are 0 for initial stage
+    val stageAttemptIds = sc.parallelize(Seq(1, 2), 2).mapPartitions { _ =>
+      Seq(TaskContext.get().stageAttemptId()).iterator
+    }.collect()
+    assert(stageAttemptIds.toSet === Set(0))
+
+    // Check stage attemptIds that are resubmitted when task fails
+    val stageAttemptIdsWithFailedStage =
+      sc.parallelize(Seq(1, 2, 3, 4), 4).repartition(1).mapPartitions { _ =>
+        val stageAttemptId = TaskContext.get().stageAttemptId()
+        if (stageAttemptId < 2) {
+          throw new FetchFailedException(null, 0, 0, 0, "Fake")
+        }
+        Seq(stageAttemptId).iterator
+      }.collect()
+
+    assert(stageAttemptIdsWithFailedStage.toSet === Set(2))
+  }
+
   test("accumulators are updated on exception failures") {
     // This means use 1 core and 4 max task failures
     sc = new SparkContext("local[1,4]", "test")
@@ -190,7 +213,7 @@ class TaskContextSuite extends SparkFunSuite with BeforeAndAfter with LocalSpark
     // accumulator updates from it.
     val taskMetrics = TaskMetrics.empty
     val task = new Task[Int](0, 0, 0) {
-      context = new TaskContextImpl(0, 0, 0L, 0,
+      context = new TaskContextImpl(0, 0, 0, 0L, 0,
         new TaskMemoryManager(SparkEnv.get.memoryManager, 0L),
         new Properties,
         SparkEnv.get.metricsSystem,
@@ -213,7 +236,7 @@ class TaskContextSuite extends SparkFunSuite with BeforeAndAfter with LocalSpark
     // accumulator updates from it.
     val taskMetrics = TaskMetrics.registered
     val task = new Task[Int](0, 0, 0) {
-      context = new TaskContextImpl(0, 0, 0L, 0,
+      context = new TaskContextImpl(0, 0, 0, 0L, 0,
         new TaskMemoryManager(SparkEnv.get.memoryManager, 0L),
         new Properties,
         SparkEnv.get.metricsSystem,
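
Note (not part of the diff): in the new test above, each FetchFailedException makes the scheduler resubmit the failed stage, so the attempt id observed by the retried tasks grows 0 → 1 → 2 and only attempt 2 completes, which is why the assertion expects Set(2). The new getter is also distinct from the per-task retry counter; a minimal illustrative sketch, assuming a live TaskContext inside a running task:

// Illustrative only: two different counters exposed by TaskContext.
val ctx = org.apache.spark.TaskContext.get()
val stageAttempt = ctx.stageAttemptId() // 0 on the stage's first submission, grows on stage resubmission
val taskAttempt = ctx.attemptNumber()   // 0 on this task's first attempt, grows on per-task retries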

core/src/test/scala/org/apache/spark/storage/BlockInfoManagerSuite.scala

Lines changed: 1 addition & 1 deletion

@@ -62,7 +62,7 @@ class BlockInfoManagerSuite extends SparkFunSuite with BeforeAndAfterEach {
   private def withTaskId[T](taskAttemptId: Long)(block: => T): T = {
     try {
       TaskContext.setTaskContext(
-        new TaskContextImpl(0, 0, taskAttemptId, 0, null, new Properties, null))
+        new TaskContextImpl(0, 0, 0, taskAttemptId, 0, null, new Properties, null))
       block
     } finally {
       TaskContext.unset()

sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMapSuite.scala

Lines changed: 1 addition & 0 deletions

@@ -70,6 +70,7 @@ class UnsafeFixedWidthAggregationMapSuite

     TaskContext.setTaskContext(new TaskContextImpl(
       stageId = 0,
+      stageAttemptId = 0,
       partitionId = 0,
       taskAttemptId = Random.nextInt(10000),
       attemptNumber = 0,

sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeKVExternalSorterSuite.scala

Lines changed: 1 addition & 0 deletions

@@ -116,6 +116,7 @@ class UnsafeKVExternalSorterSuite extends SparkFunSuite with SharedSQLContext {
     val taskMemMgr = new TaskMemoryManager(memoryManager, 0)
     TaskContext.setTaskContext(new TaskContextImpl(
       stageId = 0,
+      stageAttemptId = 0,
       partitionId = 0,
       taskAttemptId = 98456,
       attemptNumber = 0,

sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeRowSerializerSuite.scala

Lines changed: 1 addition & 1 deletion

@@ -114,7 +114,7 @@ class UnsafeRowSerializerSuite extends SparkFunSuite with LocalSparkContext {
       (i, converter(Row(i)))
     }
     val taskMemoryManager = new TaskMemoryManager(sc.env.memoryManager, 0)
-    val taskContext = new TaskContextImpl(0, 0, 0, 0, taskMemoryManager, new Properties, null)
+    val taskContext = new TaskContextImpl(0, 0, 0, 0, 0, taskMemoryManager, new Properties, null)

     val sorter = new ExternalSorter[Int, UnsafeRow, UnsafeRow](
       taskContext,
