// Defines the boolean flag `new_executor_use_inplace` with a default of true.
// NOTE(review): per the help string this presumably toggles in-place reuse in
// the new executor -- confirm against the site where the flag is read.
2323PADDLE_DEFINE_EXPORTED_bool (new_executor_use_inplace, true ,
2424 " Use inplace in new executor" );
2525
// Event name passed to main_thread_blocker_.RegisterEvent() in the
// InterpreterCore constructor, where it is paired with a callback that
// returns exception_holder_.IsCaught().
26+ constexpr const char * kExceptionCaught = " ExceptionCaught" ;
27+
2628namespace paddle {
2729namespace framework {
2830// NOTE(Aurelius84): Need a better strategy to determine it.
@@ -42,6 +44,9 @@ InterpreterCore::InterpreterCore(const platform::Place& place,
4244
4345 feed_names_ = feed_names;
4446
47+ exception_notifier_ = main_thread_blocker_.RegisterEvent (
48+ kExceptionCaught , [this ]() { return exception_holder_.IsCaught (); });
49+
4550 // Step1: add feedop and fetchop to main_program
4651 AddFetch (fetch_names);
4752
@@ -360,6 +365,8 @@ void InterpreterCore::ExecuteInstructionList(
360365 async_work_queue_.PrepareAtomicVarRef (vec_meta_info_);
361366 op_run_number_ = 0 ;
362367
368+ exception_holder_.Clear ();
369+
363370 for (size_t i = 0 ; i < dependecy_count_.size (); ++i) {
364371 if (dependecy_count_[i] == 0 ) {
365372 async_work_queue_.AddTask (vec_instr[i].type_ ,
@@ -370,6 +377,11 @@ void InterpreterCore::ExecuteInstructionList(
370377 auto event_id = main_thread_blocker_.WaitEvent ();
371378 VLOG (3 ) << " event_id " << event_id;
372379
380+ if (UNLIKELY (exception_holder_.IsCaught ())) {
381+ VLOG (4 ) << " Exception caught " << exception_holder_.Type ();
382+ exception_holder_.ReThrow ();
383+ }
384+
373385 PADDLE_ENFORCE_EQ (
374386 op_run_number_.load (), vec_instr.size (),
375387 platform::errors::Fatal (
@@ -441,11 +453,34 @@ void InterpreterCore::RunInstructionAsync(size_t instr_id) {
441453 instr_id = ready_ops.front ();
442454 ready_ops.pop ();
443455 auto & instr_node = vec_instruction_[instr_id];
444- platform::RecordEvent instruction_event (
445- instr_node. kernel_func_ . operator_base_ ->Type ());
456+ auto * op = instr_node. kernel_func_ . operator_base_ ;
457+ platform::RecordEvent instruction_event (op ->Type ());
446458 event_manager_.WaitEvent (instr_node, place_);
447459
448- RunInstruction (instr_node);
460+ try {
461+ RunInstruction (instr_node);
462+ } catch (platform::EnforceNotMet& ex) {
463+ framework::InsertCallStackInfo (op->Type (), op->Attrs (), &ex);
464+ exception_holder_.Catch (std::make_exception_ptr (std::move (ex)));
465+ } catch (platform::EOFException&) {
466+ exception_holder_.Catch (std::current_exception ());
467+ } catch (std::exception& ex) {
468+ LOG (WARNING) << op->Type () << " raises an exception "
469+ << platform::demangle (typeid (ex).name ()) << " , "
470+ << ex.what ();
471+ exception_holder_.Catch (std::current_exception ());
472+ } catch (...) {
473+ LOG (WARNING) << op->Type () << " raises an unknown exception" ;
474+ exception_holder_.Catch (std::current_exception ());
475+ }
476+
477+ if (UNLIKELY (exception_holder_.IsCaught ())) {
478+ VLOG (4 ) << " Exception caught" ;
479+ if (exception_notifier_ != nullptr ) {
480+ exception_notifier_->NotifyEvent ();
481+ }
482+ return ;
483+ }
449484
450485 event_manager_.RecordEvent (instr_node, place_);
451486 op_run_number_.fetch_add (1 , std::memory_order_relaxed);
0 commit comments