62 changes: 60 additions & 2 deletions python/paddle/autograd/backward_utils.py
@@ -379,7 +379,7 @@ def inverse_sort_op(ops):
return sorted_list


def inplace_net(op_list):
def is_inplace_net(op_list):
'''
    when the program has an inplace op, it's difficult to find the actual pending_count.
'''
@@ -388,7 +388,7 @@ def inplace_net(op_list):
return True
if is_control_flow(op):
for block in op.blocks():
if inplace_net(block.ops):
if is_inplace_net(block.ops):
return True

return False
@@ -452,3 +452,61 @@ def parent_total_ops(block):
total_ops += block.ops

return total_ops


# only for control_flow: find the corresponding value or value_list
def return_map_value(value, map):
output = value
while output in map:
output = map[output]
return output


def return_map_value_list(value, map):
output = []
for i in range(len(value)):
if value[i] in map:
output.append(map[value[i]])
else:
output.append(value[i])
return output
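# A minimal usage sketch (illustrative only): plain strings stand in for the
# pir.Value objects that are the real keys in these maps. return_map_value
# follows a chain of replacements to its end, while return_map_value_list
# applies a single lookup per element.
#
#     replace_map = {"v1": "v2", "v2": "v3"}
#     return_map_value("v1", replace_map)               # -> "v3" (v1 -> v2 -> v3)
#     return_map_value("v9", replace_map)               # -> "v9" (unmapped passes through)
#     return_map_value_list(["v1", "v9"], replace_map)  # -> ["v2", "v9"] (one hop only)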


def argument_to_value(while_op):
'''
    return the while op's mappings: (block_argument -> input value) and (input value -> block_argument).
'''
if while_op.name() != "pd_op.while":
return ValueDict(), ValueDict()

assert len(while_op.as_while_op().block_arguments()) + 1 == len(
while_op.operands_source()
), "while op's block_arguments size + 1 should same to while op's operands_source size"
arg_to_value_map = ValueDict()
value_to_arg_map = ValueDict()
for arg, value in zip(
while_op.as_while_op().block_arguments(),
while_op.operands_source()[1:],
):
arg_to_value_map[arg] = value
value_to_arg_map[value] = arg
return arg_to_value_map, value_to_arg_map
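# A sketch of the resulting pairing (hypothetical names; the first operand of
# the while op is its loop condition, which has no matching block argument,
# hence the [1:] slice above):
#
#     operands_source() -> [cond, x, y]
#     block_arguments() -> [arg_x, arg_y]
#     arg_to_value_map  -> {arg_x: x, arg_y: y}
#     value_to_arg_map  -> {x: arg_x, y: arg_y}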


def get_grad_semantic_info(op):
'''
    return whether each of the op's inputs has grad semantics, usually read from the op's yaml definition.
    some ops have a variable number of inputs and need special handling.
'''
if op.name() in [
"builtin.combine",
"pd_op.if",
"pd_op.while",
"cf.tuple_push",
]:
grad_semantic_info = [True for _ in range(len(get_real_op_inputs(op)))]
if op.name() == "pd_op.if":
grad_semantic_info[0] = False
else:
grad_semantic_info = op.get_input_grad_semantics()
return grad_semantic_info
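# A stand-alone re-statement of the rule above (assumed semantics, decoupled
# from pir ops so it can be read in isolation): among the special-cased ops,
# only the condition input of "pd_op.if" lacks grad semantics.
#
#     def _grad_semantics(op_name, num_inputs):
#         if op_name in ("builtin.combine", "pd_op.if", "pd_op.while", "cf.tuple_push"):
#             info = [True] * num_inputs
#             if op_name == "pd_op.if":
#                 info[0] = False  # the condition carries no gradient
#             return info
#         return None  # the real code falls back to op.get_input_grad_semantics()
#
#     _grad_semantics("pd_op.if", 3)  # -> [False, True, True]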
176 changes: 86 additions & 90 deletions python/paddle/autograd/ir_backward.py
@@ -23,15 +23,19 @@
ValueSet,
_as_list,
all_stop_gradient_true,
argument_to_value,
check_type,
dynamic_shape_prim_vjp_guard,
get_grad_semantic_info,
get_real_op_inputs,
inplace_net,
inverse_sort_op,
is_control_flow,
is_inplace_net,
parent_total_ops,
remove_op,
remove_useless_full_like_ops,
return_map_value,
return_map_value_list,
some_in_set,
update_no_grad_set_by_stopgradient,
)
@@ -77,6 +81,34 @@ def append_full_like(float_value, copy_value, value, state, backward_ops):
return value_grad


def append_add_n(
op, value, state, backward_ops, bwd_value_to_block_argument_map
):
    # value is an input of more than one fwd_op,
    # so more than one bwd_op creates an input_grad;
    # a sum op is needed to accumulate the gradients
add_n_list = []
for item in state.value_to_valuegrad[value]:
add_n_list.append(
return_map_value(item[0], bwd_value_to_block_argument_map)
)

if value.is_tensorarray():
add_n_value = paddle._pir_ops.add_n_array(add_n_list)
else:
add_n_value = paddle.add_n(add_n_list)

add_n_op = add_n_value.get_defining_op()
combine_op = add_n_op.operand_source(0).get_defining_op()
update_bwdop_structure(
backward_ops, state.op_to_opgrad[op], [combine_op, add_n_op]
)

for tmp in state.value_to_valuegrad[value]:
state.value_to_sumvaluegrad[value].append(tmp)
state.value_to_valuegrad[value] = [[add_n_value]]
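# The accumulation itself is a plain add_n over the collected partial grads;
# a minimal sketch with concrete tensors, independent of the backward machinery:
#
#     import paddle
#     g1 = paddle.to_tensor([1.0, 2.0])  # grad from the first consumer of x
#     g2 = paddle.to_tensor([3.0, 4.0])  # grad from the second consumer of x
#     paddle.add_n([g1, g2])             # -> [4.0, 6.0], the single grad kept for x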


def update_bwdop_structure(backward_ops, op_to_opgrad_list, grad_op_list):
for grad_op in grad_op_list:
backward_ops.append(grad_op)
@@ -272,45 +304,6 @@ def append_backward_ops(
else continue to next op.
'''

def return_map_value(value, map):
output = value
while output in map:
output = map[output]
return output

def return_map_value_list(grad_value, map):
output = []
for i in range(len(grad_value)):
if grad_value[i] in map:
output.append(map[grad_value[i]])
else:
output.append(grad_value[i])
return output

def append_add_n(value):
# value is input of more than one fwd_op,
# so more than one bwd_op create input_grad,
# need add sum op to accumulate gradient
add_n_list = []
for item in state.value_to_valuegrad[value]:
add_n_list.append(
return_map_value(item[0], bwd_value_to_block_argument_map)
)
if value.is_tensorarray():
add_n_value = paddle._pir_ops.add_n_array(add_n_list)
else:
add_n_value = paddle.add_n(add_n_list)

add_n_op = add_n_value.get_defining_op()
combine_op = add_n_op.operand_source(0).get_defining_op()
update_bwdop_structure(
backward_ops, state.op_to_opgrad[op], [combine_op, add_n_op]
)

for tmp in state.value_to_valuegrad[value]:
state.value_to_sumvaluegrad[value].append(tmp)
state.value_to_valuegrad[value] = [[add_n_value]]

def make_output_with_output_grad(op):
zero_flag = [False] * op.num_results()
outputs = []
@@ -326,14 +319,22 @@ def make_output_with_output_grad(op):
new_value = [
return_map_value(value, control_flow_value_to_copyvalue_map)
]
while value in state.inside_value_to_outside_value_map:
value = state.inside_value_to_outside_value_map[value]

value = return_map_value(
value, state.inside_value_to_outside_value_map
)

if (
value in state.value_to_valuegrad
and len(state.value_to_valuegrad[value]) > 1
):
append_add_n(value)
append_add_n(
op,
value,
state,
backward_ops,
bwd_value_to_block_argument_map,
)

if (
value not in state.value_to_valuegrad
@@ -379,12 +380,19 @@ def make_output_with_output_grad(op):

if op.name() == "pd_op.array_read":
value = op.operand_source(0)
while value in state.inside_value_to_outside_value_map:
value = state.inside_value_to_outside_value_map[value]
value = return_map_value(
value, state.inside_value_to_outside_value_map
)

if value in state.value_to_valuegrad:
if len(state.value_to_valuegrad[value]) > 1:
append_add_n(value)
append_add_n(
op,
value,
state,
backward_ops,
bwd_value_to_block_argument_map,
)

if (
value not in state.value_to_valuegrad
@@ -409,22 +417,6 @@ def make_output_with_output_grad(op):

return zero_flag, outputs, output_grads

def get_grad_semantic_info(op):
if op.name() in [
"builtin.combine",
"pd_op.if",
"pd_op.while",
"cf.tuple_push",
]:
grad_semantic_info = [
True for _ in range(len(get_real_op_inputs(op)))
]
if op.name() == "pd_op.if":
grad_semantic_info[0] = False
else:
grad_semantic_info = op.get_input_grad_semantics()
return grad_semantic_info

def make_input_with_input_stopgradient(op):
inputs = []
input_grad_stopgradients = []
@@ -516,28 +508,37 @@ def append_yield(
fwd_block_argument_to_value_map,
fwd_value_to_block_argument_map,
) = argument_to_value(base_op)

with block:
inputs_grad = []
if base_op.name() == "pd_op.while":
new_cond = paddle.base.libpaddle.pir.cf_has_elements(base_op)
inputs_grad.append(new_cond)

        # the while op uses block_args to create its grad_op
for idx in range(len(base_inputs[: base_op.num_operands()])):
operands = base_inputs[idx]
if operands in fwd_value_to_block_argument_map:
operands = fwd_value_to_block_argument_map[operands]
operands = return_map_value(
operands, fwd_value_to_block_argument_map
)
base_inputs[idx] = operands

for value, value_grad in zip(base_inputs, base_inputs_grad):
if value_grad is None:
continue

while value in state.inside_value_to_outside_value_map:
value = state.inside_value_to_outside_value_map[value]
value = return_map_value(
value, state.inside_value_to_outside_value_map
)

if value in state.value_to_valuegrad:
if len(state.value_to_valuegrad[value]) > 1:
append_add_n(value)
append_add_n(
base_op,
value,
state,
backward_ops,
bwd_value_to_block_argument_map,
)
else:
new_value = return_map_value(
value, control_flow_value_to_copyvalue_map
@@ -550,23 +551,6 @@ def append_yield(

paddle.base.libpaddle.pir.cf_yield(inputs_grad)

def argument_to_value(while_op):
if while_op.name() != "pd_op.while":
return ValueDict(), ValueDict()

assert len(while_op.as_while_op().block_arguments()) + 1 == len(
while_op.operands_source()
), "while op's block_arguments size + 1 should same to whiel op's operands_source"
arg_to_value_map = ValueDict()
value_to_arg_map = ValueDict()
for arg, value in zip(
while_op.as_while_op().block_arguments(),
while_op.operands_source()[1:],
):
arg_to_value_map[arg] = value
value_to_arg_map[value] = arg
return arg_to_value_map, value_to_arg_map

# there are four patterns:
# [builtin.combine , op1] (op1's one input is vectorType, outputs are not vectorType)
# [op2 , builtin.split] (op2's inputs are not vectorType, one output is vectorType)
@@ -600,7 +584,7 @@ def argument_to_value(while_op):
else:
forward_ops = effective_forward_ops

if inplace_net(forward_ops):
if is_inplace_net(forward_ops):
inverse_effective_forward_ops = reversed(forward_ops)
else:
inverse_effective_forward_ops = inverse_sort_op(forward_ops)
@@ -716,7 +700,13 @@ def argument_to_value(while_op):
):
if input in sub_state.value_to_valuegrad:
if len(sub_state.value_to_valuegrad[input]) > 1:
append_add_n(input)
append_add_n(
op,
input,
state,
backward_ops,
bwd_value_to_block_argument_map,
)

if (
input not in sub_state.value_to_valuegrad
@@ -805,7 +795,13 @@ def argument_to_value(while_op):
if op.num_operands() == 0 and op.num_results() != 0:
for value in op.results():
if len(state.value_to_valuegrad[value]) > 1:
append_add_n(value)
append_add_n(
op,
value,
state,
backward_ops,
bwd_value_to_block_argument_map,
)
else:
state.op_to_opgrad[op] = []
else:
@@ -896,7 +892,7 @@ def calc_gradient_helper(outputs, inputs, grad_outputs, no_grad_set):
stop_gradient_false_outputs.append(output)
outputs_set = ValueSet(stop_gradient_false_outputs)

if inplace_net(total_ops):
if is_inplace_net(total_ops):
effective_forward_ops = total_ops
else:
effective_forward_ops, _ = prune_ops(
@@ -926,7 +922,7 @@ def calc_gradient_helper(outputs, inputs, grad_outputs, no_grad_set):
)

remove_ops = []
if not inplace_net(backward_ops) and inputs:
if not is_inplace_net(backward_ops) and inputs:
_, remove_ops = prune_ops(
backward_ops, inputs_set, outputs_set, no_gradvar_set
)