Merged
70 changes: 61 additions & 9 deletions paddle/fluid/operators/controlflow/while_op.cc
@@ -12,6 +12,8 @@
// See the License for the specific language governing permissions and
// limitations under the License.

#include <set>

#include "paddle/fluid/framework/executor.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/operator.h"
@@ -70,6 +72,23 @@ class WhileOp : public framework::OperatorBase {
auto *block = Attr<framework::BlockDesc *>(kStepBlock);

auto *program = block->Program();
bool is_test = Attr<bool>("is_test");

std::set<std::string> no_copy_var_names;
if (!is_test) {
const std::vector<framework::OpDesc *> &all_ops = block->AllOps();
for (const framework::OpDesc *op : all_ops) {
const framework::VariableNameMap &input_var_names = op->Inputs();
const framework::VariableNameMap &output_var_names = op->Outputs();
for (auto &ipt : input_var_names) {
for (const std::string &var_name : ipt.second) {
if (StrInVaraiableNameMap(var_name, output_var_names)) {
no_copy_var_names.insert(var_name);
}
}
}
}
}

auto step_scopes =
scope.FindVar(Output(kStepScopes))->GetMutable<StepScopeVar>();
@@ -89,7 +108,6 @@
"The Output(StepScope) of WhileOp should be empty."));

bool cond_data = GetCondData(cond);
bool is_test = Attr<bool>("is_test");
auto &skip_vars = Attr<std::vector<std::string>>(kSkipEagerDeletionVars);
VLOG(2) << GetSkipEagerDeletionVarsDebugString(skip_vars);

@@ -98,8 +116,32 @@
while (cond_data) {
auto &current_scope = scope.NewScope();
step_scopes->push_back(&current_scope);

std::vector<std::string> rename_vars;
for (const std::string &input_var_name : Inputs(kX)) {
if (no_copy_var_names.find(input_var_name) ==
no_copy_var_names.end()) {
std::string input_var_rename = input_var_name + kSuffix;
framework::Variable *input_var = scope.FindVar(input_var_name);
if (input_var->IsType<framework::LoDTensor>()) {
rename_vars.push_back(input_var_rename);
auto input_var_tensor = input_var->Get<LoDTensor>();
auto *rename_input_var_tensor =
current_scope.Var(input_var_rename)->GetMutable<LoDTensor>();
framework::TensorCopy(input_var_tensor, dev_place,
rename_input_var_tensor);
rename_input_var_tensor->set_lod(input_var_tensor.lod());
}
}
}
executor.RunPreparedContext(ctx.get(), &current_scope, false, true,
true);

for (auto &var_rename : rename_vars) {
std::string input_var_name =
var_rename.substr(0, var_rename.size() - strlen(kSuffix));
current_scope.Rename(var_rename, input_var_name);
}
cond_data =
GetCondData(scope.FindVar(Input(kCondition))->Get<LoDTensor>());
}
@@ -312,6 +354,10 @@ class WhileGradOp : public framework::OperatorBase {
// continue;
// }

auto var_iter =
std::find(outside_og_names.begin(), outside_og_names.end(),
pg_ig_names[param_id]);

// zero gradient variable in step 0
if (cur_scope_iter == step_scopes->rbegin()) {
auto *var = (*cur_scope_iter)->FindVar(inside_grad_name);
@@ -326,7 +372,8 @@
"or LoDTensor, but the received var[%s] is %s.",
inside_grad_name, framework::ToTypeName(var->Type())));

if (var->IsType<LoDTensor>()) {
if ((var_iter == outside_og_names.end()) &&
var->IsType<LoDTensor>()) {
auto &inside_tensor = var->Get<framework::LoDTensor>();
framework::AttributeMap attrs;
attrs["dtype"] = inside_tensor.type();
@@ -343,13 +390,18 @@
->set_lod(inside_tensor.lod());
}
}
auto new_inside_name = cur_scope.Rename(inside_grad_name);
auto sum_op = framework::OpRegistry::CreateOp(
"sum", {{"X", {pg_ig_names[param_id], new_inside_name}}},
{{"Out", {pg_ig_names[param_id]}}},
framework::AttributeMap{{"use_mkldnn", {false}}});
sum_op->Run(cur_scope, dev_place);
cur_scope.Rename(new_inside_name, inside_grad_name);
auto var_outside = scope.FindVar(pg_ig_names[param_id]);
if ((var_iter == outside_og_names.end()) ||
((var_iter != outside_og_names.end()) &&
var_outside->IsType<framework::LoDTensorArray>())) {
auto new_inside_name = cur_scope.Rename(inside_grad_name);
auto sum_op = framework::OpRegistry::CreateOp(
"sum", {{"X", {pg_ig_names[param_id], new_inside_name}}},
{{"Out", {pg_ig_names[param_id]}}},
framework::AttributeMap{{"use_mkldnn", {false}}});
sum_op->Run(cur_scope, dev_place);
cur_scope.Rename(new_inside_name, inside_grad_name);
}
}
dev_ctx.Wait();
const_cast<framework::Scope &>(scope).DeleteScope(&cur_scope);
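Taken together, the forward-pass hunks implement a per-step snapshot scheme: every `kX` input that the step block never writes back (tracked via `no_copy_var_names`) is copied into the step scope under a `@TMP_COPY` name before the block runs, then renamed back to the original name afterwards. Each step scope therefore retains the value that step actually consumed, instead of aliasing whatever the final iteration left behind. A minimal Python sketch of that control flow, with plain dicts standing in for scopes and `run_step`/`block` as illustrative names rather than Paddle API:

```python
SUFFIX = "@TMP_COPY"  # mirrors kSuffix in while_op_helper.h

def run_step(parent, step_scope, input_names, no_copy_vars, block):
    """One loop iteration under the snapshot scheme (sketch only)."""
    renamed = []
    for name in input_names:
        if name in no_copy_vars:
            continue  # the block rewrites this var itself; no snapshot needed
        # Copy under a temporary name so the block still resolves `name`
        # to the parent scope while it runs (stands in for TensorCopy).
        step_scope[name + SUFFIX] = list(parent[name])
        renamed.append(name + SUFFIX)
    block(parent, step_scope)  # may update `name` in the parent in place
    for tmp in renamed:
        # Rename the snapshot back: this step scope now records the value
        # the step actually read, which the backward pass later walks.
        step_scope[tmp[:-len(SUFFIX)]] = step_scope.pop(tmp)
```

The gradient-side hunks guard against double-counting under the same scheme: a parameter gradient whose name already arrives from outside (`outside_og_names`) as a plain `LoDTensor` is neither zero-initialized at step 0 nor re-accumulated through the `sum` op; accumulation still happens when the gradient is loop-local, or when the outside variable is a `LoDTensorArray`.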
11 changes: 11 additions & 0 deletions paddle/fluid/operators/controlflow/while_op_helper.cc
@@ -232,5 +232,16 @@ bool GetCondData(const framework::LoDTensor &cond) {
return cpu_cond->data<bool>()[0];
}

bool StrInVaraiableNameMap(const std::string &name,
const framework::VariableNameMap &var_names) {
for (auto &ipt : var_names) {
if (std::find(ipt.second.begin(), ipt.second.end(), name) !=
ipt.second.end()) {
return true;
}
}
return false;
}

} // namespace operators
} // namespace paddle
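`StrInVaraiableNameMap` is a plain membership test over a `VariableNameMap` (a map from slot name to a list of variable names); `WhileOp` uses it to build `no_copy_var_names` by asking, for each input name, whether the same name also appears among an op's outputs. In Python terms the check reduces to roughly this sketch, with a `dict[str, list[str]]` standing in for the map:

```python
def str_in_variable_name_map(name, var_name_map):
    # True if `name` occurs in any slot's name list.
    return any(name in names for names in var_name_map.values())
```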
4 changes: 4 additions & 0 deletions paddle/fluid/operators/controlflow/while_op_helper.h
@@ -38,6 +38,7 @@ static constexpr char kX[] = "X";
static constexpr char kXGRAD[] = "X@GRAD";
static constexpr char kOutputs[] = "Out";
static constexpr char kSkipEagerDeletionVars[] = "skip_eager_deletion_vars";
static constexpr char kSuffix[] = "@TMP_COPY";

void PrepareSafeEagerDeletionOnWhileOpAndWhileGradOp(
const framework::ProgramDesc &program, int block_id,
@@ -50,5 +51,8 @@ void PrepareSafeEagerDeletionOnWhileOpAndWhileGradOp(

bool GetCondData(const framework::LoDTensor &cond);

bool StrInVaraiableNameMap(const std::string &,
const framework::VariableNameMap &);

} // namespace operators
} // namespace paddle
29 changes: 15 additions & 14 deletions python/paddle/fluid/tests/unittests/test_while_loop_op.py
@@ -16,6 +16,7 @@
import numpy as np
import unittest

import paddle
import paddle.fluid as fluid
import paddle.fluid.core as core
import paddle.fluid.layers as layers
@@ -24,6 +25,8 @@
from paddle.fluid.framework import Program, program_guard
from paddle.fluid.backward import append_backward

paddle.enable_static()


class TestApiWhileLoop(unittest.TestCase):
def test_var_tuple(self):
@@ -199,16 +202,10 @@ def test_while_loop_backward(self):
def cond(i, x):
return layers.less_than(i, eleven)

def body(j, x):
# TODO: In while block, if the var created in parent block
# participates in the calculation of gradient, the result of gradient
# is incorrect because each step scope always returns the same value
# generated by last step.
# Here we call `assign` op in while block to avoid this bug, and working on fixing it in next PR.
i = layers.assign(j)
def body(i, x):
x = layers.elementwise_mul(x=i, y=i)
j = layers.increment(j)
return [j, x]
i = layers.increment(i)
return [i, x]

main_program = Program()
startup_program = Program()
@@ -244,10 +241,10 @@ def body(j, x):

def test_while_loop_backward2(self):
def cond(i, x):
return i < 5
return i < 3

def body(i, x):
x = x + i
x = x * i
i = i + 1
return [i, x]

@@ -269,17 +266,21 @@ def body(i, x):

feed_i = np.ones(1).astype('float32')
feed_x = np.ones(1).astype('float32')
data = np.asarray([11]).astype('float32')
i_grad = np.asarray([1]).astype('float32')
data = np.asarray([2]).astype('float32')
i_grad = np.asarray([3]).astype('float32')
x_grad = np.asarray([2]).astype('float32')

res = exe.run(main_program,
feed={'i': feed_i,
'x': feed_x},
fetch_list=[mean.name, i.grad_name])
fetch_list=[mean.name, i.grad_name, x.grad_name])
self.assertTrue(np.allclose(np.asarray(res[0]), data))
self.assertTrue(
np.allclose(np.asarray(res[1]), i_grad),
msg=" \nres = \n{} \n\n ans = \n{}".format(res[1], i_grad))
self.assertTrue(
np.allclose(np.asarray(res[2]), x_grad),
msg=" \nres = \n{} \n\n ans = \n{}".format(res[2], x_grad))


class TestApiWhileLoop_NestedWithBackwardAndLoDTensorArray(unittest.TestCase):
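The updated expectations in `test_while_loop_backward2` can be checked by hand: starting from `i = 1`, `x = 1`, the loop body runs twice. A plain-Python check of the arithmetic (not a Paddle snippet):

```python
i, x = 1.0, 1.0
while i < 3:       # executes with i = 1, then i = 2
    x = x * i      # x: 1 -> 1 -> 2
    i = i + 1
assert x == 2.0    # forward result, hence data = [2]

# With x_final = x0 * i0 * (i0 + 1) and mean() over a single element:
#   d(x_final)/d(x0) = i0 * (i0 + 1)     = 2  -> x_grad = [2]
#   d(x_final)/d(i0) = x0 * (2 * i0 + 1) = 3  -> i_grad = [3]
```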
2 changes: 2 additions & 0 deletions python/paddle/fluid/tests/unittests/test_while_op.py
@@ -24,6 +24,8 @@
import numpy
from paddle.fluid import compiler, Program, program_guard

paddle.enable_static()


class TestWhileOp(unittest.TestCase):
def simple_net(self):