@@ -21,7 +21,8 @@ import java.nio.ByteBuffer
2121import scala .collection .mutable .ArrayBuffer
2222import scala .collection .mutable .HashMap
2323
24- import org .apache .spark .{ExceptionFailure , Logging , SparkEnv , SparkException , Success , TaskState }
24+ import org .apache .spark .{ExceptionFailure , Logging , SparkEnv , SparkException , Success ,
25+ TaskEndReason , TaskResultLost , TaskState }
2526import org .apache .spark .TaskState .TaskState
2627import org .apache .spark .scheduler .{DirectTaskResult , IndirectTaskResult , Pool , Schedulable , Task ,
2728 TaskDescription , TaskInfo , TaskLocality , TaskResult , TaskSet , TaskSetManager }
@@ -144,7 +145,18 @@ private[spark] class LocalTaskSetManager(sched: LocalScheduler, val taskSet: Tas
144145 val result = ser.deserialize[TaskResult [_]](serializedData, getClass.getClassLoader) match {
145146 case directResult : DirectTaskResult [_] => directResult
146147 case IndirectTaskResult (blockId) => {
147- throw new SparkException (" Expect only DirectTaskResults when using LocalScheduler" )
148+ logDebug(" Fetching indirect task result for TID %s" .format(tid))
149+ val serializedTaskResult = env.blockManager.getRemoteBytes(blockId)
150+ if (! serializedTaskResult.isDefined) {
151+ /* We won't be able to get the task result if the block manager had to flush the
152+ * result. */
153+ taskFailed(tid, state, serializedData)
154+ return
155+ }
156+ val deserializedResult = ser.deserialize[DirectTaskResult [_]](
157+ serializedTaskResult.get)
158+ env.blockManager.master.removeBlock(blockId)
159+ deserializedResult
148160 }
149161 }
150162 result.metrics.resultSize = serializedData.limit()
@@ -164,18 +176,28 @@ private[spark] class LocalTaskSetManager(sched: LocalScheduler, val taskSet: Tas
164176 val task = taskSet.tasks(index)
165177 info.markFailed()
166178 decreaseRunningTasks(1 )
167- val reason : ExceptionFailure = ser.deserialize[ExceptionFailure ](
168- serializedData, getClass.getClassLoader)
169- sched.dagScheduler.taskEnded(task, reason, null , null , info, reason.metrics.getOrElse(null ))
179+ var failureReason = " unknown"
180+ ser.deserialize[TaskEndReason ](serializedData, getClass.getClassLoader) match {
181+ case ef : ExceptionFailure =>
182+ failureReason = " Exception failure: %s" .format(ef.description)
183+ val locs = ef.stackTrace.map(loc => " \t at %s" .format(loc.toString))
184+ logInfo(" Task loss due to %s\n %s\n %s" .format(
185+ ef.className, ef.description, locs.mkString(" \n " )))
186+ sched.dagScheduler.taskEnded(task, ef, null , null , info, ef.metrics.getOrElse(null ))
187+
188+ case TaskResultLost =>
189+ failureReason = " Lost result for TID %s" .format(tid)
190+ logWarning(failureReason)
191+ sched.dagScheduler.taskEnded(task, TaskResultLost , null , null , info, null )
192+
193+ case _ => {}
194+ }
170195 if (! finished(index)) {
171196 copiesRunning(index) -= 1
172197 numFailures(index) += 1
173- val locs = reason.stackTrace.map(loc => " \t at %s" .format(loc.toString))
174- logInfo(" Loss was due to %s\n %s\n %s" .format(
175- reason.className, reason.description, locs.mkString(" \n " )))
176198 if (numFailures(index) > MAX_TASK_FAILURES ) {
177- val errorMessage = " Task %s:%d failed more than %d times; aborting job %s " .format(
178- taskSet.id, index, MAX_TASK_FAILURES , reason.description )
199+ val errorMessage = ( " Task %s:%d failed more than %d times; aborting job" +
200+ " (most recent failure: %s " ).format( taskSet.id, index, MAX_TASK_FAILURES , failureReason )
179201 decreaseRunningTasks(runningTasks)
180202 sched.dagScheduler.taskSetFailed(taskSet, errorMessage)
181203 // need to delete failed Taskset from schedule queue
0 commit comments