Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion paddle/fluid/pir/dialect/operator/ir/manual_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -4246,7 +4246,7 @@ const char *IncrementOp::attributes_name[1] = {"value"};
OpInfoTuple IncrementOp::GetOpInfo() {
std::vector<paddle::dialect::OpInputInfo> inputs = {
paddle::dialect::OpInputInfo(
"x", "paddle::dialect::DenseTensorType", false, false, false, false)};
"x", "paddle::dialect::DenseTensorType", false, false, false, true)};
std::vector<paddle::dialect::OpAttributeInfo> attributes = {
paddle::dialect::OpAttributeInfo("value", "pir::FloatAttribute", "")};
std::vector<paddle::dialect::OpOutputInfo> outputs = {
Expand Down
4 changes: 1 addition & 3 deletions paddle/fluid/pir/dialect/operator/ir/manual_op_vjp.cc
Original file line number Diff line number Diff line change
Expand Up @@ -149,11 +149,9 @@ std::vector<std::vector<pir::OpResult>> IncrementOp::Vjp(

VLOG(6) << "Vjp prepare Prepare attributes of increment_grad";

float value = op->attribute("value").dyn_cast<pir::FloatAttribute>().data();

VLOG(6) << "Vjp prepare call increment's vjp inteface";

pir::OpResult tensor_res = paddle::dialect::increment(inputs_[0][0], -value);
pir::OpResult tensor_res = paddle::dialect::scale(out_grads[0][0]);

std::vector<std::vector<pir::OpResult>> res{{tensor_res}};

Expand Down
33 changes: 28 additions & 5 deletions python/paddle/autograd/ir_backward.py
Original file line number Diff line number Diff line change
Expand Up @@ -217,6 +217,7 @@ def prune_ops(total_ops, inputs_set, outputs_set, no_grad_set):
'''
intersection_op_flags = [True] * len(total_ops)
union_op_flags = [False] * len(total_ops)

# from input to output
if inputs_set:
for i, op in enumerate(total_ops):
Expand Down Expand Up @@ -1045,7 +1046,11 @@ def calc_gradient_helper(outputs, inputs, grad_outputs, no_grad_set):
)

inputs_set = ValueSet(inputs)
outputs_set = ValueSet(complete_outputs)
stop_gradient_false_outputs = []
for output in complete_outputs:
if output not in no_grad_set:
stop_gradient_false_outputs.append(output)
outputs_set = ValueSet(stop_gradient_false_outputs)

if inplace_net(total_ops):
effective_forward_ops = total_ops
Expand Down Expand Up @@ -1078,7 +1083,8 @@ def calc_gradient_helper(outputs, inputs, grad_outputs, no_grad_set):
outputs_set, inputs_set, no_gradvar_set = create_backward_prune_set(
outputs_fwd_set, inputs_fwd_set, no_grad_set, state
)
if not inplace_net(backward_ops):

if not inplace_net(backward_ops) and inputs:
_, remove_ops = prune_ops(
backward_ops, inputs_set, outputs_set, no_gradvar_set
)
Expand Down Expand Up @@ -1289,10 +1295,27 @@ def append_backward(loss, parameter_list=None, no_grad_set=None):
loss.get_defining_op().get_parent_block().all_parameters()
)

inputs_grad = paddle.autograd.ir_backward.grad(loss, parameter_list)
if no_grad_set is None:
no_grad_set_ = ValueSet()
else:
no_grad_set_ = ValueSet(no_grad_set)

input_to_inputgrad_map = calc_gradient_helper(
_as_list(loss),
[],
grad_outputs=[],
no_grad_set=ValueSet(no_grad_set_),
)

input_inputs_grad = []
for input, input_grad in zip(parameter_list, inputs_grad):
input_inputs_grad.append((input, input_grad))
for input in parameter_list:
input_inputs_grad.append(
(
input,
input_to_inputgrad_map[input][0][0]
if input_to_inputgrad_map[input] != []
else None,
)
)

return input_inputs_grad
20 changes: 20 additions & 0 deletions test/legacy_test/test_increment.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
class TestIncrement(unittest.TestCase):
@test_with_pir_api
def test_api(self):
paddle.enable_static()
with base.program_guard(base.Program(), base.Program()):
input = paddle.tensor.fill_constant(
shape=[1], dtype='int64', value=5
Expand All @@ -41,6 +42,25 @@ def test_api(self):
output = paddle.tensor.math.increment(input, value=1)
self.assertEqual((output.numpy() == expected_result).all(), True)

def test_no_inplace_increment(self):
    # Build a chain of four non-inplace increments in PIR mode:
    # starting from 1, the forward result is 5.0, and because each
    # increment is out = in + value the gradient d(out)/dx is 1.0.
    with paddle.pir_utils.IrGuard():
        with base.program_guard(base.Program(), base.Program()):
            start = paddle.tensor.fill_constant(
                shape=[1], dtype='int64', value=1
            )
            start.stop_gradient = False

            # Apply increment(+1.0) four times: 1 -> 2 -> 3 -> 4 -> 5.
            cur = start
            for _ in range(4):
                cur = paddle._pir_ops.increment(cur, 1.0)

            # Gradient of the final output w.r.t. the chain input.
            grads = paddle.base.gradients(cur, start)
            executor = base.Executor(base.CPUPlace())
            fetched = executor.run(fetch_list=[cur, grads])

            self.assertEqual(fetched[0], 5.0)
            self.assertEqual(fetched[1], 1.0)


class TestInplaceApiWithDataTransform(unittest.TestCase):
@test_with_pir_api
Expand Down