Skip to content
Merged
Show file tree
Hide file tree
Changes from 18 commits
Commits
Show all changes
22 commits
Select commit Hold shift + click to select a range
08d70cb
add unused input vars check for OpWithKernel, test=develop
zhiqiu Nov 13, 2019
dc41961
remove unused vars in some ops, test=develop
zhiqiu Nov 19, 2019
032ef2f
fix batch_norm, test=develop
zhiqiu Nov 20, 2019
7aff856
add white list, test=develop
zhiqiu Nov 20, 2019
e69b905
add CI check for white list, test=develop
zhiqiu Nov 20, 2019
afc8da5
Move white list to c++, test=develop
zhiqiu Nov 21, 2019
1db659b
solve failure of CI, test=develop
zhiqiu Nov 22, 2019
5b1047d
add unittest for unused_var_check, test=develop
zhiqiu Nov 22, 2019
e97a64d
refine code, enable check in operator_test, test=develop
zhiqiu Nov 22, 2019
d45ae40
skip mkldnn, test=develop
zhiqiu Nov 25, 2019
7f2b8ca
extend white list, test=develop
zhiqiu Nov 25, 2019
67b9f4f
refine condition of mkldnn, test=develop
zhiqiu Nov 27, 2019
234dc12
test=develop, Merge branch 'develop' of https://github.com/PaddlePadd…
zhiqiu Nov 27, 2019
a180079
fix paddle_build, test=develop
zhiqiu Nov 27, 2019
a470fbc
follow comments, test=develop
zhiqiu Nov 28, 2019
0fe1a68
test=develop, Merge branch 'develop' of https://github.com/PaddlePadd…
zhiqiu Nov 28, 2019
98f1509
fix GetExpectedKernelType
zhiqiu Nov 28, 2019
6ee0ab0
add wiki ref to err_msg, test=develop
zhiqiu Nov 28, 2019
301f835
test=develop, Merge branch 'develop' of https://github.com/PaddlePadd…
zhiqiu Nov 28, 2019
1d9bb16
test=develop, Merge branch 'develop' of https://github.com/PaddlePadd…
zhiqiu Nov 28, 2019
c8dff60
test=develop, Merge branch 'develop' of https://github.com/PaddlePadd…
zhiqiu Nov 29, 2019
cfc3a6e
follow comment, test=develop
zhiqiu Nov 29, 2019
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 5 additions & 1 deletion paddle/fluid/framework/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,7 @@ endif()
cc_test(var_type_traits_test SRCS var_type_traits_test.cc DEPS var_type_traits)

cc_library(scope SRCS scope.cc DEPS glog threadpool xxhash var_type_traits)

cc_library(scope_pool SRCS scope_pool.cc DEPS scope)
cc_test(scope_test SRCS scope_test.cc DEPS scope)
cc_test(variable_test SRCS variable_test.cc DEPS tensor var_type_traits)
Expand Down Expand Up @@ -126,8 +127,11 @@ cc_test(no_need_buffer_vars_inference_test SRCS no_need_buffer_vars_inference_te

cc_library(transfer_scope_cache SRCS transfer_scope_cache.cc DEPS scope framework_proto device_context)
cc_library(op_kernel_type SRCS op_kernel_type.cc DEPS device_context place)

cc_library(unused_var_check SRCS unused_var_check.cc DEPS glog no_need_buffer_vars_inference)

cc_library(operator SRCS operator.cc DEPS op_info device_context tensor scope glog trainer_desc_proto data_feed_proto
shape_inference data_transform lod_tensor profiler transfer_scope_cache op_kernel_type op_call_stack)
shape_inference data_transform lod_tensor profiler transfer_scope_cache op_kernel_type op_call_stack unused_var_check)

cc_test(operator_test SRCS operator_test.cc DEPS operator op_registry device_context)
cc_test(operator_exception_test SRCS operator_exception_test.cc DEPS operator op_registry device_context)
Expand Down
19 changes: 19 additions & 0 deletions paddle/fluid/framework/operator.cc
Original file line number Diff line number Diff line change
Expand Up @@ -28,11 +28,13 @@ limitations under the License. */
#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/framework/shape_inference.h"
#include "paddle/fluid/framework/transfer_scope_cache.h"
#include "paddle/fluid/framework/unused_var_check.h"
#include "paddle/fluid/framework/var_type.h"
#include "paddle/fluid/platform/profiler.h"

DECLARE_bool(benchmark);
DECLARE_bool(check_nan_inf);
DECLARE_bool(enable_unused_var_check);
DEFINE_int32(inner_op_parallelism, 0, "number of threads for inner op");
DEFINE_bool(fast_check_nan_inf, false,
"Fast checking NAN/INF after each operation. It will be a little"
Expand Down Expand Up @@ -428,6 +430,8 @@ bool ExecutionContext::HasOutput(const std::string& name) const {
}

const Variable* ExecutionContext::InputVar(const std::string& name) const {
LogVarUsageIfUnusedVarCheckEnabled(name);
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Call this function with if (FLAGS_enable_unused_var_check)?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks. To reduce code duplication, I will put the if statement inside LogVarUsageIfUnusedVarCheckEnabled.


auto it = ctx_.inputs.find(name);
if (it == ctx_.inputs.end()) return nullptr;

Expand Down Expand Up @@ -457,6 +461,8 @@ const Tensor* ExecutionContext::Input<Tensor>(const std::string& name) const {
template <>
const std::vector<const Tensor*> ExecutionContext::MultiInput<Tensor>(
const std::string& name) const {
LogVarUsageIfUnusedVarCheckEnabled(name);
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Call this function with if (FLAGS_enable_unused_var_check)?

Copy link
Contributor Author

@zhiqiu zhiqiu Nov 28, 2019

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks. To reduce code duplication, I will put the if statement inside LogVarUsageIfUnusedVarCheckEnabled.


auto it = ctx_.inputs.find(name);
if (it == ctx_.inputs.end()) {
return {};
Expand Down Expand Up @@ -910,6 +916,11 @@ void OperatorWithKernel::RunImpl(const Scope& scope,
RuntimeInferShapeContext infer_shape_ctx(*this, exec_scope, *runtime_ctx);
this->InferShape(&infer_shape_ctx);
}

if (FLAGS_enable_unused_var_check) {
GetThreadLocalUsedVarNameSet()->clear();
}

// TODO(panyx0718): ExecutionContext should only depend on RuntimeContext
// not Scope. Imperative mode only pass inputs and get outputs.
(*kernel_func_)(ExecutionContext(*this, exec_scope, *dev_ctx, *runtime_ctx,
Expand All @@ -919,6 +930,14 @@ void OperatorWithKernel::RunImpl(const Scope& scope,
// there is inplace variable has been transfered.
TransferInplaceVarsBack(scope, transfered_inplace_vars, *transfer_scope);
}
if (FLAGS_enable_unused_var_check) {
// skip op that uses mkldnn because it has different memory reuse strategy.
// use attr here because some GradMakers (like ActivationGradOpMaker) add
// input when use_mkldnn=true;
if (!(HasAttr("use_mkldnn") && Attr<bool>("use_mkldnn"))) {
CheckUnusedVar(*this, scope);
}
}

/*For profiling/benchmark only*/
if (FLAGS_benchmark) {
Expand Down
6 changes: 6 additions & 0 deletions paddle/fluid/framework/operator.h
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ limitations under the License. */
#include "paddle/fluid/framework/scope.h"
#include "paddle/fluid/framework/selected_rows.h"
#include "paddle/fluid/framework/tensor.h"
#include "paddle/fluid/framework/unused_var_check.h"
#include "paddle/fluid/memory/malloc.h"
#include "paddle/fluid/platform/device_context.h"
#include "paddle/fluid/platform/variant.h"
Expand Down Expand Up @@ -268,6 +269,8 @@ class ExecutionContext {

const std::vector<const Variable*> MultiInputVar(
const std::string& name) const {
LogVarUsageIfUnusedVarCheckEnabled(name);
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Same above.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done, thanks.


auto it = ctx_.inputs.find(name);
if (it == ctx_.inputs.end()) {
return {};
Expand Down Expand Up @@ -298,6 +301,8 @@ class ExecutionContext {

template <typename T>
const std::vector<const T*> MultiInput(const std::string& name) const {
LogVarUsageIfUnusedVarCheckEnabled(name);
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Same above.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done, thanks.


auto it = ctx_.inputs.find(name);
if (it == ctx_.inputs.end()) {
return {};
Expand Down Expand Up @@ -349,6 +354,7 @@ class ExecutionContext {

//! Get actual name vector for this input.
const std::vector<std::string>& Inputs(const std::string& name) const {
LogVarUsageIfUnusedVarCheckEnabled(name);
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Same above.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done, thanks.

return op_.Inputs(name);
}

Expand Down
118 changes: 118 additions & 0 deletions paddle/fluid/framework/operator_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,8 @@ limitations under the License. */
#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/platform/init.h"

DECLARE_bool(enable_unused_var_check);

namespace paddle {
namespace framework {

Expand Down Expand Up @@ -139,6 +141,8 @@ class CPUKernelTest : public OpKernel<float> {
cpu_kernel_run_num++;
ASSERT_EQ(ctx.op().Input("x"), "IN1");
ASSERT_EQ(ctx.op().Output("y"), "OUT1");
auto* x = ctx.Input<Tensor>("X");
ASSERT_EQ(x, nullptr);
}
};

Expand Down Expand Up @@ -591,3 +595,117 @@ void SetGetLoDLevelTestMain(std::string op_type) {
TEST(GetLoDLevelTest, base) { SetGetLoDLevelTestMain("get_lod_level_test"); }

TEST(SetLoDLevelTest, base) { SetGetLoDLevelTestMain("set_lod_level_test"); }

namespace paddle {
namespace framework {

// Minimal OperatorWithKernel used to exercise the unused-input-variable
// check: InferShape is a no-op and the expected kernel is always FP32 on
// the execution context's place.
class OpUnusedVarTest : public OperatorWithKernel {
 public:
  using OperatorWithKernel::OperatorWithKernel;

 protected:
  // No shapes to infer for this test-only op.
  void InferShape(framework::InferShapeContext* ctx) const override {}
  // Pick a float32 kernel regardless of the actual input data types.
  OpKernelType GetExpectedKernelType(
      const ExecutionContext& ctx) const override {
    return OpKernelType(proto::VarType::FP32, ctx.GetPlace(),
                        framework::DataLayout::kAnyLayout);
  }
};

// Declares the proto for the test op: one input slot "X" and one output
// slot "Y". Shared by both the op-with-unused-var and op-without-unused-var
// registrations below.
class OpUnusedVarTestProtoAndCheckerMaker : public OpProtoAndCheckerMaker {
 public:
  void Make() {
    AddInput("X", "input of test op");
    AddOutput("Y", "output of test op");
    AddComment("This is test op for unused var check.");
  }
};

// Kernel that deliberately never reads its input tensor "X" (it only checks
// the slot names), so running it with FLAGS_enable_unused_var_check on must
// trigger the unused-input error.
template <typename T>
class OpWithUnusedVarKernelTest : public OpKernel<T> {
 public:
  // `override` added: Compute overrides OpKernel<T>'s virtual method, and
  // marking it lets the compiler catch signature drift.
  void Compute(const ExecutionContext& ctx) const override {
    ASSERT_EQ(ctx.op().Input("X"), "X");
    ASSERT_EQ(ctx.op().Output("Y"), "Y");
  }
};

// Kernel that reads both its input "X" and output "Y" through the
// ExecutionContext, so every declared variable is marked "used" and the
// unused-input check must pass.
template <typename T>
class OpWithoutUnusedVarKernelTest : public OpKernel<T> {
 public:
  // `override` added: Compute overrides OpKernel<T>'s virtual method, and
  // marking it lets the compiler catch signature drift.
  void Compute(const ExecutionContext& ctx) const override {
    ASSERT_EQ(ctx.op().Input("X"), "X");
    ASSERT_EQ(ctx.op().Output("Y"), "Y");
    // Touching the variables via Input/Output records them in the
    // thread-local used-variable set.
    auto* x = ctx.Input<Tensor>("X");
    auto* y = ctx.Output<Tensor>("Y");
    ASSERT_NE(x, y);
    ASSERT_NE(y, nullptr);
  }
};

} // namespace framework
} // namespace paddle

// Register the two test ops (no gradients needed) and their CPU float
// kernels: "op_with_unused_var" ignores its input X, while
// "op_without_unused_var" reads every declared input/output.
REGISTER_OP_WITHOUT_GRADIENT(
    op_with_unused_var, paddle::framework::OpUnusedVarTest,
    paddle::framework::OpUnusedVarTestProtoAndCheckerMaker);

REGISTER_OP_CPU_KERNEL(op_with_unused_var,
                       paddle::framework::OpWithUnusedVarKernelTest<float>);

REGISTER_OP_WITHOUT_GRADIENT(
    op_without_unused_var, paddle::framework::OpUnusedVarTest,
    paddle::framework::OpUnusedVarTestProtoAndCheckerMaker);

REGISTER_OP_CPU_KERNEL(op_without_unused_var,
                       paddle::framework::OpWithoutUnusedVarKernelTest<float>);

// test with single input
// With the check enabled, running an op whose kernel never touches its
// input "X" must raise EnforceNotMet.
TEST(OpWithUnusedVar, all) {
  FLAGS_enable_unused_var_check = true;  // turn the check on for this test
  paddle::framework::InitDevices(true);

  paddle::framework::proto::OpDesc op_desc;
  op_desc.set_type("op_with_unused_var");
  BuildVar("X", {"X"}, op_desc.add_inputs());
  BuildVar("Y", {"Y"}, op_desc.add_outputs());

  paddle::platform::CPUPlace place;
  paddle::framework::Scope scope;
  // Allocate both variables as 32x64 float LoDTensors.
  for (const char* var_name : {"X", "Y"}) {
    auto* tensor =
        scope.Var(var_name)->GetMutable<paddle::framework::LoDTensor>();
    tensor->Resize({32, 64});
    tensor->mutable_data<float>(place);
  }

  auto op = paddle::framework::OpRegistry::CreateOp(op_desc);
  // The untouched input "X" should make the run throw.
  ASSERT_THROW(op->Run(scope, place), paddle::platform::EnforceNotMet);
  FLAGS_enable_unused_var_check = false;  // restore default for other tests
}

// With the check enabled, an op whose kernel reads all of its declared
// inputs/outputs must run without throwing.
TEST(OpWithoutUnusedVar, all) {
  FLAGS_enable_unused_var_check = true;  // turn the check on for this test
  paddle::framework::InitDevices(true);

  paddle::framework::proto::OpDesc op_desc;
  op_desc.set_type("op_without_unused_var");
  BuildVar("X", {"X"}, op_desc.add_inputs());
  BuildVar("Y", {"Y"}, op_desc.add_outputs());

  paddle::platform::CPUPlace place;
  paddle::framework::Scope scope;
  // Allocate both variables as 32x64 float LoDTensors.
  for (const char* var_name : {"X", "Y"}) {
    auto* tensor =
        scope.Var(var_name)->GetMutable<paddle::framework::LoDTensor>();
    tensor->Resize({32, 64});
    tensor->mutable_data<float>(place);
  }

  auto op = paddle::framework::OpRegistry::CreateOp(op_desc);
  // Every input is used, so the check must not fire.
  ASSERT_NO_THROW(op->Run(scope, place));
  FLAGS_enable_unused_var_check = false;  // restore default for other tests
}
130 changes: 130 additions & 0 deletions paddle/fluid/framework/unused_var_check.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,130 @@
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include <gflags/gflags.h>
#include <glog/logging.h>

#include <string>
#include <unordered_set>
#include <vector>
#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/framework/unused_var_check.h"
#include "paddle/fluid/platform/enforce.h"

DEFINE_bool(enable_unused_var_check, false,
"Checking whether operator contains unused inputs, "
"especially for grad operator. It should be in unittest.");

// Operators known to declare inputs their kernels never read; CheckUnusedVar
// skips them so CI keeps passing. Each entry is technical debt: fix the op
// (drop the input or register NoNeedBufferVarsInference) and remove it here.
// NOTE(review): the identifier misspells "unused" as "unsed"; renaming it
// requires updating its use in CheckUnusedVar as well — TODO in a follow-up.
const std::unordered_set<std::string> op_has_unsed_vars_white_list = {
    "auc",
    "batch_norm",
    "batch_norm_grad",
    "sync_batch_norm_grad",
    "center_loss_grad",
    "crop",
    "cvm",
    "cos_sim_grad",
    "dgc_momentum",
    "fake_quantize_range_abs_max",
    "fill_zeros_like",
    "fusion_seqpool_cvm_concat",
    "reshape2_grad_grad",
    "reshape2_grad",
    "gru_grad",
    "hierarchical_sigmoid_grad",
    "nce_grad",
    "roi_perspective_transform_grad",
    "sequence_conv_grad",
    "gru_unit_grad",
    "affine_grid_grad",
    "fill_any_like",
    "precision_recall",
    "unsqueeze_grad",
    "kldiv_loss_grad",
    "cvm_grad",
    "stack_grad",
    "warpctc_grad",
    "sync_batch_norm",
    "match_matrix_tensor_grad",
    "ngraph_engine"};

namespace paddle {
namespace framework {

// Returns this thread's set of input-slot names that were recorded as
// "used" during the current kernel run. One set exists per thread
// (thread_local); the caller clears it before each run.
std::unordered_set<std::string> *GetThreadLocalUsedVarNameSet() {
  thread_local std::unordered_set<std::string> used_vars_of_current_thread;
  return &used_vars_of_current_thread;
}

// Records `name` in the per-thread used-variable set, but only when the
// check is switched on via FLAGS_enable_unused_var_check; keeping the flag
// test here spares every call site from repeating it.
void LogVarUsageIfUnusedVarCheckEnabled(const std::string &name) {
  if (!FLAGS_enable_unused_var_check) {
    return;
  }
  VLOG(6) << "Variable used:" << name;
  GetThreadLocalUsedVarNameSet()->insert(name);
}

// Verifies that every input slot of `op` holding an initialized tensor was
// actually read during the kernel run (i.e. was recorded in the thread-local
// used-variable set). Throws platform::EnforceNotMet (PermissionDenied)
// naming the unused slots. Ops in the white list and inputs declared via
// NoNeedBufferVarsInference are skipped.
void CheckUnusedVar(const OperatorBase &op, const Scope &scope) {
  // Skip ops in the white list; those should be fixed in the future.
  if (op_has_unsed_vars_white_list.count(op.Type()) != 0) {
    return;
  }
  auto *used_set = GetThreadLocalUsedVarNameSet();
  std::vector<std::string> unused_input_var_names;
  auto &inferer = op.Info().NoNeedBufferVarsInferer();
  std::unordered_set<std::string> no_need_buffer_ins = {};
  if (inferer) {
    no_need_buffer_ins = inferer(op.Inputs(), op.Outputs(), op.Attrs());
  }

  for (auto &pair : op.Inputs()) {
    // Skip input slots whose buffers are declared as not needed.
    if (no_need_buffer_ins.count(pair.first) != 0) {
      VLOG(6) << op.Type() << " " << pair.first;
      continue;
    }
    if (used_set->count(pair.first) == 0) {
      // Report the slot only if at least one of its variables actually
      // holds an initialized tensor; empty/uninitialized inputs are
      // legitimately untouched.
      for (auto &in_var_name : pair.second) {
        auto *in_var = scope.FindVar(in_var_name);
        if (in_var != nullptr && in_var->IsInitialized()) {
          auto *tensor = &in_var->Get<LoDTensor>();
          if (tensor != nullptr && tensor->IsInitialized()) {
            unused_input_var_names.emplace_back(pair.first);
            break;
          }
        }
      }
    }
  }
  if (!unused_input_var_names.empty()) {
    // Fixed typo in the user-facing message: "uesed" -> "used".
    std::string err_msg = "Operator " + op.Type() + " has input(s) not used: ";
    for (auto &in_var_name : unused_input_var_names) {
      err_msg += in_var_name;
      err_msg += ", ";
    }
    err_msg +=
        "please make sure it(they) is(are) needed. If not, remove it(them) "
        "from inputs of the operator; if yes, register "
        "NoNeedBufferVarsInference or add "
        "the operator to "
        "white list in unused_var_check.cc. See more details at "
        "[https://github.com/PaddlePaddle/Paddle/wiki/"
        "OP-Should-Not-Have-Unused-Input]";
    PADDLE_ENFORCE_EQ(unused_input_var_names.size(), 0,
                      platform::errors::PermissionDenied(
                          "Unused input variables check failed: %s", err_msg));
  }
}

} // namespace framework
} // namespace paddle
Loading