Merged
6 changes: 3 additions & 3 deletions paddle/fluid/pir/dialect/operator/ir/manual_op.cc
@@ -196,7 +196,7 @@ std::vector<pir::Type> AddNOp::InferMeta(
  paddle::dialect::IrTensor dense_out;
  paddle::dialect::IrMetaTensor meta_out(&dense_out);

-  phi::AddNInferMeta(meta_x, &meta_out);
+  phi::AddNInferMeta(meta_x, &meta_out, phi::MetaConfig(false, false));

  std::vector<pir::Type> argument_outputs;
  pir::Type out_dense_tensor_type = paddle::dialect::DenseTensorType::get(
@@ -358,7 +358,7 @@ std::vector<pir::Type> AddN_Op::InferMeta(
  paddle::dialect::IrTensor dense_out;
  paddle::dialect::IrMetaTensor meta_out(&dense_out);

-  phi::AddNInferMeta(meta_inputs, &meta_out);
+  phi::AddNInferMeta(meta_inputs, &meta_out, phi::MetaConfig(false, false));

  std::vector<pir::Type> argument_outputs;
  pir::Type out_dense_tensor_type = paddle::dialect::DenseTensorType::get(
@@ -548,7 +548,7 @@ std::vector<pir::Type> AddNWithKernelOp::InferMeta(
  paddle::dialect::IrTensor dense_out;
  paddle::dialect::IrMetaTensor meta_out(&dense_out);

-  phi::AddNInferMeta(meta_inputs, &meta_out);
+  phi::AddNInferMeta(meta_inputs, &meta_out, phi::MetaConfig(false, false));

  std::vector<pir::Type> argument_outputs;
  pir::Type out_dense_tensor_type = paddle::dialect::DenseTensorType::get(
7 changes: 7 additions & 0 deletions python/paddle/autograd/backward_utils.py
@@ -518,3 +518,10 @@ def get_grad_semantic_info(op):
    else:
        grad_semantic_info = op.get_input_grad_semantics()
    return grad_semantic_info
+
+
+def get_split_op(value):
+    for op in value.all_used_ops():
+        if op.name() == "builtin.split":
+            return op
+    return None
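
A minimal sketch (not part of the diff itself) of how the new get_split_op helper is expected to behave: it scans the ops that consume a value and returns the builtin.split consumer, or None if there is none. FakeOp and FakeValue are stand-ins for pir objects, assumed here only so the example runs without building a PIR program.

class FakeOp:
    def __init__(self, name):
        self._name = name

    def name(self):
        return self._name


class FakeValue:
    def __init__(self, used_ops):
        self._used_ops = used_ops

    def all_used_ops(self):
        return self._used_ops


def get_split_op(value):
    # same logic as the helper added to backward_utils.py above
    for op in value.all_used_ops():
        if op.name() == "builtin.split":
            return op
    return None


value = FakeValue([FakeOp("pd_op.add"), FakeOp("builtin.split")])
assert get_split_op(value).name() == "builtin.split"
assert get_split_op(FakeValue([FakeOp("pd_op.relu")])) is None
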
44 changes: 25 additions & 19 deletions python/paddle/autograd/ir_backward.py
Expand Up @@ -29,6 +29,7 @@
dynamic_shape_prim_vjp_guard,
get_grad_semantic_info,
get_real_op_inputs,
get_split_op,
inverse_sort_op,
is_control_flow,
is_inplace_net,
@@ -90,24 +91,30 @@ def append_add_n(
    # need add sum op to accumulate gradient
    add_n_list = []
    for item in state.value_to_valuegrad[value]:
-        add_n_list.append(
-            return_map_value(item[0], bwd_value_to_block_argument_map)
-        )
+        if item[0] is not None:
+            add_n_list.append(
+                return_map_value(item[0], bwd_value_to_block_argument_map)
+            )

-    if value.is_dense_tensor_array_type():
-        add_n_value = paddle._pir_ops.add_n_array(add_n_list)
+    if len(add_n_list) == 0:
+        for tmp in state.value_to_valuegrad[value]:
+            state.value_to_sumvaluegrad[value].append(tmp)
+        state.value_to_valuegrad[value] = []
    else:
-        add_n_value = paddle.add_n(add_n_list)
+        if value.is_dense_tensor_array_type():
+            add_n_value = paddle._pir_ops.add_n_array(add_n_list)
+        else:
+            add_n_value = paddle.add_n(add_n_list)

-    add_n_op = add_n_value.get_defining_op()
-    combine_op = add_n_op.operand_source(0).get_defining_op()
-    update_bwdop_structure(
-        backward_ops, state.op_to_opgrad[op], [combine_op, add_n_op]
-    )
+        add_n_op = add_n_value.get_defining_op()
+        combine_op = add_n_op.operand_source(0).get_defining_op()
+        update_bwdop_structure(
+            backward_ops, state.op_to_opgrad[op], [combine_op, add_n_op]
+        )

-    for tmp in state.value_to_valuegrad[value]:
-        state.value_to_sumvaluegrad[value].append(tmp)
-    state.value_to_valuegrad[value] = [[add_n_value]]
+        for tmp in state.value_to_valuegrad[value]:
+            state.value_to_sumvaluegrad[value].append(tmp)
+        state.value_to_valuegrad[value] = [[add_n_value]]


def update_bwdop_structure(backward_ops, op_to_opgrad_list, grad_op_list):
@@ -342,18 +349,15 @@ def make_output_with_output_grad(op):
                value not in state.value_to_valuegrad
                or state.value_to_valuegrad[value] == []
            ):
-                if (
-                    not value.use_empty()
-                    and value.first_use().owner().name() == "builtin.split"
-                ):
+                if not value.use_empty() and get_split_op(value) is not None:
                    # pattern case:
                    # this fwd_op's output is vectorType, it will split to
                    # Type by builtin_split op, so need get from split op's outputs.
                    (
                        split_zero_flag,
                        split_outputs,
                        split_output_grad,
-                    ) = make_output_with_output_grad(value.first_use().owner())
+                    ) = make_output_with_output_grad(get_split_op(value))
                    zero_flag[i] = all(split_zero_flag)
                    grad_values = [value[0] for value in split_output_grad]
                    state.value_to_valuegrad[value] = [grad_values]
@@ -374,6 +378,8 @@ def make_output_with_output_grad(op):

            outputs.append(new_value)
            grad_value = state.value_to_valuegrad[value][0]
+            if grad_value[0] is None:
+                zero_flag[i] = True
            output_grads.append(
                return_map_value_list(
                    grad_value, bwd_value_to_block_argument_map
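
Taken together, the changes above make gradient accumulation tolerant of missing gradients: append_add_n now skips None entries and, when every pending gradient is None, emits no add_n at all, while make_output_with_output_grad marks such outputs via zero_flag. A minimal sketch of that accumulation rule, using plain Python floats as stand-ins for pir values (accumulate_grads is a hypothetical helper for illustration, not a function added by this PR):

def accumulate_grads(pending_grads):
    # pending_grads: gradients recorded for one forward value, possibly None
    non_none = [g for g in pending_grads if g is not None]
    if not non_none:
        # every pending gradient is None -> nothing to sum, reset to empty
        return []
    # stand-in for paddle.add_n / paddle._pir_ops.add_n_array
    return [sum(non_none)]


assert accumulate_grads([None, None]) == []
assert accumulate_grads([1.0, None, 2.0]) == [3.0]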