@@ -14,18 +14,17 @@ limitations under the License. */
1414
1515#include " paddle/framework/executor.h"
1616
17- #include < algorithm>
18- #include < iostream>
19- #include < memory>
2017#include < set>
21- #include < vector>
2218
19+ #include " gflags/gflags.h"
2320#include " paddle/framework/feed_fetch_type.h"
2421#include " paddle/framework/lod_rank_table.h"
25- #include " paddle/framework/lod_tensor.h"
2622#include " paddle/framework/lod_tensor_array.h"
2723#include " paddle/framework/op_registry.h"
28- #include " paddle/framework/scope.h"
24+
25+ DEFINE_bool (check_nan_inf, false ,
26+ " Checking whether operator produce NAN/INF or not. It will be "
27+ " extremely slow so please use this flag wisely." );
2928
3029namespace paddle {
3130namespace framework {
@@ -58,6 +57,19 @@ static void CreateTensor(Variable* var, proto::VarDesc::VarType var_type) {
5857 }
5958}
6059
60+ static void CheckTensorNANOrInf (const std::string& name,
61+ const framework::Tensor& tensor) {
62+ if (tensor.memory_size () == 0 ) {
63+ return ;
64+ }
65+ if (tensor.type ().hash_code () != typeid (float ).hash_code () &&
66+ tensor.type ().hash_code () != typeid (double ).hash_code ()) {
67+ return ;
68+ }
69+ PADDLE_ENFORCE (!framework::HasInf (tensor), " Tensor %s has Inf" , name);
70+ PADDLE_ENFORCE (!framework::HasNAN (tensor), " Tensor %s has NAN" , name);
71+ }
72+
6173void Executor::Run (const ProgramDesc& pdesc, Scope* scope, int block_id,
6274 bool create_local_scope, bool create_vars) {
6375 // TODO(tonyyang-svail):
@@ -101,6 +113,15 @@ void Executor::Run(const ProgramDesc& pdesc, Scope* scope, int block_id,
101113 auto op = paddle::framework::OpRegistry::CreateOp (*op_desc);
102114 VLOG (3 ) << op->DebugString ();
103115 op->Run (*local_scope, place_);
116+ if (FLAGS_check_nan_inf) {
117+ for (auto & vname : op->OutputVars (true )) {
118+ auto * var = local_scope->FindVar (vname);
119+ if (var == nullptr ) continue ;
120+ if (var->IsType <framework::LoDTensor>()) {
121+ CheckTensorNANOrInf (vname, var->Get <framework::LoDTensor>());
122+ }
123+ }
124+ }
104125 }
105126 if (create_vars && create_local_scope) {
106127 scope->DeleteScope (local_scope);
0 commit comments