Skip to content
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion paddle/fluid/pir/dialect/operator/ir/manual_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -2741,7 +2741,7 @@ const char *IncrementOp::attributes_name[1] = {"value"};
OpInfoTuple IncrementOp::GetOpInfo() {
std::vector<paddle::dialect::OpInputInfo> inputs = {
paddle::dialect::OpInputInfo(
"x", "paddle::dialect::DenseTensorType", false, false, false, false)};
"x", "paddle::dialect::DenseTensorType", false, false, false, true)};
std::vector<paddle::dialect::OpAttributeInfo> attributes = {
paddle::dialect::OpAttributeInfo("value", "pir::FloatAttribute", "")};
std::vector<paddle::dialect::OpOutputInfo> outputs = {
Expand Down
4 changes: 1 addition & 3 deletions paddle/fluid/pir/dialect/operator/ir/manual_op_vjp.cc
Original file line number Diff line number Diff line change
Expand Up @@ -149,11 +149,9 @@ std::vector<std::vector<pir::OpResult>> IncrementOp::Vjp(

VLOG(6) << "Vjp prepare Prepare attributes of increment_grad";

float value = op->attribute("value").dyn_cast<pir::FloatAttribute>().data();

VLOG(6) << "Vjp prepare call increment's vjp inteface";

pir::OpResult tensor_res = paddle::dialect::increment(inputs_[0][0], -value);
pir::OpResult tensor_res = paddle::dialect::scale(out_grads[0][0]);

std::vector<std::vector<pir::OpResult>> res{{tensor_res}};

Expand Down
33 changes: 28 additions & 5 deletions python/paddle/autograd/ir_backward.py
Original file line number Diff line number Diff line change
Expand Up @@ -219,6 +219,7 @@ def prune_ops(total_ops, inputs_set, outputs_set, no_grad_set):
'''
intersection_op_flags = [True] * len(total_ops)
union_op_flags = [False] * len(total_ops)

# from input to output
if inputs_set:
for i, op in enumerate(total_ops):
Expand Down Expand Up @@ -1047,7 +1048,11 @@ def calc_gradient_helper(outputs, inputs, grad_outputs, no_grad_set):
)

inputs_set = ValueSet(inputs)
outputs_set = ValueSet(complete_outputs)
stop_gradient_false_outputs = []
for output in complete_outputs:
if output not in no_grad_set:
stop_gradient_false_outputs.append(output)
outputs_set = ValueSet(stop_gradient_false_outputs)

if inplace_net(total_ops):
effective_forward_ops = total_ops
Expand Down Expand Up @@ -1080,7 +1085,8 @@ def calc_gradient_helper(outputs, inputs, grad_outputs, no_grad_set):
outputs_set, inputs_set, no_gradvar_set = create_backward_prune_set(
outputs_fwd_set, inputs_fwd_set, no_grad_set, state
)
if not inplace_net(backward_ops):

if not inplace_net(backward_ops) and inputs:
_, remove_ops = prune_ops(
backward_ops, inputs_set, outputs_set, no_gradvar_set
)
Expand Down Expand Up @@ -1291,10 +1297,27 @@ def append_backward(loss, parameter_list=None, no_grad_set=None):
loss.get_defining_op().get_parent_block().all_parameters()
)

inputs_grad = paddle.autograd.ir_backward.grad(loss, parameter_list)
if no_grad_set is None:
no_grad_set_ = ValueSet()
else:
no_grad_set_ = ValueSet(no_grad_set)

input_to_inputgrad_map = calc_gradient_helper(
_as_list(loss),
[],
grad_outputs=[],
no_grad_set=ValueSet(no_grad_set_),
)

input_inputs_grad = []
for input, input_grad in zip(parameter_list, inputs_grad):
input_inputs_grad.append((input, input_grad))
for input in parameter_list:
input_inputs_grad.append(
(
input,
input_to_inputgrad_map[input][0][0]
if input_to_inputgrad_map[input] != []
else None,
)
)

return input_inputs_grad
20 changes: 19 additions & 1 deletion test/legacy_test/test_increment.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@

class TestIncrement(unittest.TestCase):
@test_with_pir_api
def test_api(self):
def _test_api(self):
with base.program_guard(base.Program(), base.Program()):
input = paddle.tensor.fill_constant(
shape=[1], dtype='int64', value=5
Expand All @@ -41,6 +41,24 @@ def test_api(self):
output = paddle.tensor.math.increment(input, value=1)
self.assertEqual((output.numpy() == expected_result).all(), True)

def test_no_inplace_increment(self):
with paddle.pir_utils.IrGuard():
with base.program_guard(base.Program(), base.Program()):
x = paddle.tensor.fill_constant(
shape=[1], dtype='int64', value=1
)
x.stop_gradient = False
input = paddle._pir_ops.increment(x, 1.0)
input = paddle._pir_ops.increment(input, 1.0)
input = paddle._pir_ops.increment(input, 1.0)
out = paddle._pir_ops.increment(input, 1.0)

dx = paddle.base.gradients(out, x)
exe = base.Executor(base.CPUPlace())
result = exe.run(fetch_list=[out, dx])

print(result)


class TestInplaceApiWithDataTransform(unittest.TestCase):
@test_with_pir_api
Expand Down