Skip to content
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 21 additions & 4 deletions paddle/fluid/framework/operator.cc
Original file line number Diff line number Diff line change
Expand Up @@ -2715,11 +2715,28 @@ proto::VarType::Type OperatorWithKernel::IndicateDataType(
proto::VarType::Type dafault_data_type =
static_cast<proto::VarType::Type>(-1);
proto::VarType::Type data_type = dafault_data_type;
for (auto* name : ctx.InNameList()) {
if (ctx.InputSize(*name) == 1UL) {
ParseInputDataType(ctx.InputVar(*name), *name, &data_type);

auto tmp_list = ctx.InNameList();
std::vector<std::string> in_name_list;
std::transform(tmp_list.begin(),
tmp_list.end(),
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这里可以去掉

std::back_inserter(in_name_list),
[](const std::string* name) { return *name; });
if (Info().HasOpProtoAndChecker()) {
for (auto& attr : Info().Proto().attrs()) {
auto it =
std::find(in_name_list.begin(), in_name_list.end(), attr.name());
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这里使用std::find_if,第三个参数用lambda函数即可,会更简洁一些。用法参考:https://en.cppreference.com/w/cpp/algorithm/find

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

已修改感谢

if (attr.support_tensor() && it != in_name_list.end()) {
in_name_list.erase(it);
}
}
}

for (auto& name : in_name_list) {
if (ctx.InputSize(name) == 1UL) {
ParseInputDataType(ctx.InputVar(name), name, &data_type);
} else {
ParseMultiInputDataType(ctx.MultiInputVar(*name), *name, &data_type);
ParseMultiInputDataType(ctx.MultiInputVar(name), name, &data_type);
}
}
PADDLE_ENFORCE_NE(
Expand Down
16 changes: 16 additions & 0 deletions python/paddle/fluid/tests/unittests/test_pad_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -178,6 +178,22 @@ def call_func(self, x):
return out


class TestPaddingValueTensor3(unittest.TestCase):
def test_static(self):
np_x = np.random.random((16, 16)).astype('float32')
main_prog = Program()
starup_prog = Program()
with program_guard(main_prog, starup_prog):
x = paddle.assign(np_x).astype('float32')
pad_value = paddle.assign([0.0]).astype('float64')
y = paddle.nn.functional.pad(x, [0, 1, 2, 3], value=pad_value)

exe = paddle.static.Executor(paddle.CPUPlace())
[pd_out] = exe.run(main_prog, fetch_list=[y])
np_out = np.pad(np_x, [(0, 1), (2, 3)], constant_values=0.0)
np.testing.assert_allclose(pd_out, np_out)


if __name__ == '__main__':
paddle.enable_static()
unittest.main()