Merged
20 changes: 20 additions & 0 deletions paddle/fluid/framework/details/nan_inf_utils.h
@@ -19,6 +19,7 @@
 
 #include "paddle/fluid/framework/operator.h"
 #include "paddle/fluid/framework/scope.h"
+#include "paddle/fluid/imperative/type_defs.h"
 #include "paddle/fluid/platform/place.h"
 
 namespace paddle {
@@ -30,9 +31,28 @@ void CheckVarHasNanOrInf(const std::string& op_type,
                          const std::string& var_name,
                          const platform::Place& place);
 
+void CheckVarHasNanOrInf(const std::string& op_type,
+                         const std::string& var_name,
+                         const framework::Variable* var,
+                         const platform::Place& place);
+
 void CheckOpHasNanOrInf(const framework::OperatorBase& op,
                         const framework::Scope& scope,
                         const platform::Place& place);
 
+template <typename VarType>
+void CheckOpHasNanOrInfInDygraph(const std::string& op_type,
+                                 const imperative::NameVarMap<VarType>& op_outs,
+                                 platform::Place place) {
+  for (const auto& pair : op_outs) {
+    for (const auto& ivar : pair.second) {
+      auto* var = ivar->MutableVar();
+      if (var == nullptr) continue;
+      CheckVarHasNanOrInf(op_type, ivar->Name(), var, place);
+    }
+  }
+}
+
 }  // namespace details
 }  // namespace framework
 }  // namespace paddle

Review thread on CheckOpHasNanOrInfInDygraph:

Contributor: However, this one no longer has the op/var filtering feature.

Contributor (Author): Oh, right. Let me add that.

Contributor: Merging this first and adding the filtering later also works.

Contributor (Author): As discussed, it will be added later if the need arises.
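
As the review thread notes, the dygraph checker does not yet filter ops or vars the way the static-mode path does. A minimal sketch of what an op-level guard could look like, reusing the op_type_nan_inf_white_list() helper that IsSkipOp() consults in nan_inf_utils_detail.cc; the guard is hypothetical, not part of this PR, and the helper would first have to be exposed to this header:

template <typename VarType>
void CheckOpHasNanOrInfInDygraph(const std::string& op_type,
                                 const imperative::NameVarMap<VarType>& op_outs,
                                 platform::Place place) {
  // Hypothetical: skip whitelisted ops, mirroring the static-mode IsSkipOp().
  if (op_type_nan_inf_white_list().count(op_type) != 0) return;
  for (const auto& pair : op_outs) {
    for (const auto& ivar : pair.second) {
      auto* var = ivar->MutableVar();
      if (var == nullptr) continue;
      CheckVarHasNanOrInf(op_type, ivar->Name(), var, place);
    }
  }
}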
15 changes: 11 additions & 4 deletions paddle/fluid/framework/details/nan_inf_utils_detail.cc
@@ -297,13 +297,12 @@ void tensor_check<platform::CPUDeviceContext>(const std::string& op_type,
 }
 
 void CheckVarHasNanOrInf(const std::string& op_type,
-                         const framework::Scope& scope,
                          const std::string& var_name,
+                         const framework::Variable* var,
                          const platform::Place& place) {
-  auto* var = scope.FindVar(var_name);
   PADDLE_ENFORCE_NOT_NULL(
-      var, platform::errors::NotFound("In op=%s, can't find var:%s", op_type,
-                                      var_name));
+      var, platform::errors::NotFound("Cannot find var: `%s` in op `%s`.",
+                                      var_name, op_type));
 
   const Tensor* tensor{nullptr};
   if (var->IsType<framework::LoDTensor>()) {
@@ -393,6 +392,14 @@
   tensor_check<platform::CPUDeviceContext>(op_type, var_name, *tensor, place);
 }
 
+void CheckVarHasNanOrInf(const std::string& op_type,
+                         const framework::Scope& scope,
+                         const std::string& var_name,
+                         const platform::Place& place) {
+  auto* var = scope.FindVar(var_name);
+  CheckVarHasNanOrInf(op_type, var_name, var, place);
+}
+
 bool IsSkipOp(const framework::OperatorBase& op) {
   if (op_type_nan_inf_white_list().count(op.Type()) != 0) return true;
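
The refactor above turns the scope-based overload into a thin wrapper around the pointer-based one, so callers that already hold a framework::Variable* can skip the Scope lookup entirely. An illustrative call site under that assumption (the surrounding op/ivar/dev_ctx objects are placeholders, mirroring the dygraph path in the header):

// Illustrative: the dygraph path resolves the Variable* itself via
// VarBase::MutableVar() and calls the new overload directly, no Scope needed.
framework::details::CheckVarHasNanOrInf(
    op.Type(), ivar->Name(), ivar->MutableVar(), dev_ctx->GetPlace());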
2 changes: 1 addition & 1 deletion paddle/fluid/imperative/CMakeLists.txt
@@ -1,6 +1,6 @@
 cc_library(imperative_flag SRCS flags.cc DEPS gflags)
 
-cc_library(prepared_operator SRCS prepared_operator.cc DEPS proto_desc operator device_context lod_tensor selected_rows var_type_traits op_kernel_type data_transform)
+cc_library(prepared_operator SRCS prepared_operator.cc DEPS proto_desc operator device_context lod_tensor selected_rows var_type_traits op_kernel_type data_transform nan_inf_utils)
 cc_library(layer SRCS layer.cc DEPS prepared_operator math_function imperative_flag variable_helper op_registry)
 add_subdirectory(jit)
 cc_library(amp SRCS amp_auto_cast.cc DEPS layer )
8 changes: 8 additions & 0 deletions paddle/fluid/imperative/prepared_operator.cc
@@ -15,8 +15,11 @@
 #include "paddle/fluid/imperative/prepared_operator.h"
 
 #include "paddle/fluid/framework/data_type_transform.h"
+#include "paddle/fluid/framework/details/nan_inf_utils.h"
 #include "paddle/fluid/imperative/infer_shape_context.h"
 
+DECLARE_bool(check_nan_inf);
+
 namespace paddle {
 namespace imperative {
 
@@ -175,6 +178,11 @@ static void PreparedOpRunImpl(
   func(DygraphExecutionContext<VarType>(op, scope, *dev_ctx, ctx, ins, outs,
                                         attrs));
 
+  if (FLAGS_check_nan_inf) {
+    framework::details::CheckOpHasNanOrInfInDygraph<VarType>(
+        op.Type(), outs, dev_ctx->GetPlace());
+  }
+
   /**
    * [ Why need handle complex gradient to real gradient? ]
    *
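
For context on the flag plumbing: DECLARE_bool only references a gflags flag whose single definition lives in another translation unit (paddle defines FLAGS_check_nan_inf elsewhere). A minimal standalone sketch of the gflags pattern, with illustrative file names:

// flags.cc -- the one DEFINE per flag, with default value and help text.
#include "gflags/gflags.h"
DEFINE_bool(check_nan_inf, false,
            "If true, scan operator outputs for NaN/Inf after each run.");

// consumer.cc -- any other translation unit gains access via DECLARE.
#include "gflags/gflags.h"
DECLARE_bool(check_nan_inf);

void MaybeCheck() {
  if (FLAGS_check_nan_inf) {
    // ... run the NaN/Inf scan here ...
  }
}

Paddle exposes these flags through FLAGS_* environment variables, which is how the new test below switches the check on before importing paddle.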
112 changes: 112 additions & 0 deletions python/paddle/fluid/tests/unittests/check_nan_inf_base_dygraph.py
@@ -0,0 +1,112 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import unicode_literals
from __future__ import print_function

import os
import sys
import time
import numpy as np

os.environ[str("FLAGS_check_nan_inf")] = str("1")
os.environ[str("GLOG_vmodule")] = str("nan_inf_utils_detail=10")

import paddle
import paddle.nn as nn

np.random.seed(0)


def generator():
    batch_size = 5
    for i in range(5):
        curr_train_x = np.random.randint(
            batch_size, size=(batch_size, 3)).astype("float32")
        if i >= 2:
            curr_train_x[0, :] = np.nan
            curr_train_x[-1, :] = np.inf
        res = []
        for i in range(batch_size):
            y = i % 3
            res.append([y])
        y_label = np.array(res).astype('int64')
        yield [curr_train_x, y_label]


class TestLayer(nn.Layer):
    def __init__(self):
        super(TestLayer, self).__init__()
        self.linear1 = nn.Linear(3, 400)
        self.linear2 = nn.Linear(400, 400)
        self.linear3 = nn.Linear(400, 3)

    def forward(self, x):
        x = self.linear1(x)
        x = nn.functional.sigmoid(x)
        x = self.linear2(x)
        x = nn.functional.sigmoid(x)
        x = self.linear3(x)
        x = nn.functional.softmax(x)

        return x


def check(use_cuda):
    paddle.set_device('gpu' if use_cuda else 'cpu')

    net = TestLayer()
    sgd = paddle.optimizer.SGD(learning_rate=0.05,
                               parameters=net.parameters())

    for step, (x, y) in enumerate(generator()):
        x = paddle.to_tensor(x)
        y = paddle.to_tensor(y)

        zero = paddle.zeros(shape=[1], dtype='int64')
        # fp16_zero is unused; the cast produces a float16 op output
        # for the checker to scan.
        fp16_zero = paddle.cast(zero, dtype='float16')

        y = y + zero

        y_pred = net(x)

        cost = nn.functional.cross_entropy(y_pred, y, use_softmax=False)
        avg_cost = paddle.mean(cost)

        acc_top1 = paddle.metric.accuracy(input=y_pred, label=y, k=1)

        print('iter={:.0f}, cost={}, acc1={}'.format(
            step, avg_cost.numpy(), acc_top1.numpy()))

        # backward() populates the gradients consumed by sgd.step().
        avg_cost.backward()
        sgd.step()
        sgd.clear_grad()


if __name__ == '__main__':
    if paddle.is_compiled_with_cuda():
        try:
            check(use_cuda=True)
            assert False
        except Exception as e:
            print(e)
            print(type(e))
            # NOTE: An enforce raised inside a CUDA kernel may not be caught
            # by paddle, in which case the exception type is OSError instead
            # of RuntimeError.
            assert type(e) == OSError or type(e) == RuntimeError
    try:
        check(use_cuda=False)
        assert False
    except Exception as e:
        print(e)
        print(type(e))
        assert type(e) == RuntimeError
11 changes: 9 additions & 2 deletions python/paddle/fluid/tests/unittests/test_nan_inf.py
@@ -29,11 +29,10 @@ def setUp(self):
         self._python_interp = sys.executable
         if os.getenv('WITH_COVERAGE', 'OFF') == 'ON':
             self._python_interp += " -m coverage run --branch -p"
-        self._python_interp += " check_nan_inf_base.py"
 
         self.env = os.environ.copy()
 
-    def test_nan_inf(self):
+    def check_nan_inf(self):
         cmd = self._python_interp
 
         proc = subprocess.Popen(
@@ -53,6 +52,14 @@
         assert (out + err
                 ).find('There are `nan` or `inf` in tensor'.encode()) != -1
 
+    def test_nan_inf_in_static_mode(self):
+        self._python_interp += " check_nan_inf_base.py"
+        self.check_nan_inf()
+
+    def test_nan_inf_in_dynamic_mode(self):
+        self._python_interp += " check_nan_inf_base_dygraph.py"
+        self.check_nan_inf()
+
 
 class TestNanInfEnv(TestNanInf):
     def setUp(self):