Skip to content

Commit 5b9f827

Browse files
committed
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into update_loss_scaling, test=kunlun
2 parents 261e47f + 0710f05 commit 5b9f827

75 files changed

Lines changed: 2419 additions & 846 deletions

File tree

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

paddle/fluid/eager/auto_code_generator/generator/eager_gen.py

Lines changed: 63 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -292,6 +292,7 @@ class {} : public egr::GradNodeBase {{
292292
#include "paddle/fluid/eager/utils.h"
293293
#include "paddle/fluid/eager/api/utils/global_utils.h"
294294
#include "paddle/fluid/eager/api/generated/eager_generated/backwards/nodes.h"
295+
#include "paddle/fluid/eager/api/generated/eager_generated/forwards/dygraph_functions.h"
295296
#include "paddle/fluid/eager/to_static/run_program_op_node.h"
296297
#include "paddle/fluid/eager/nan_inf_utils.h"
297298
@@ -436,6 +437,20 @@ class {} : public egr::GradNodeBase {{
436437
}
437438

438439

440+
def ExtractForwardApiNameFormInvoke(invoke_config):
441+
api_name = invoke_config.split('(')[0]
442+
if api_name[-1] == '_':
443+
api_name = api_name[:-1]
444+
return re.search(
445+
r"(?P<api_name>[a-zA-Z0-9_]+)(?P<intermediate>_intermediate)?",
446+
api_name).group('api_name')
447+
448+
449+
def IsInvokeForwardApi(api_contents, forward_api_name_list):
450+
return 'invoke' in api_contents and ExtractForwardApiNameFormInvoke(
451+
api_contents['invoke']) in forward_api_name_list
452+
453+
439454
#######################
440455
## Generator Helpers ##
441456
#######################
@@ -478,7 +493,8 @@ def GenerateCoreOpInfoDefinition():
478493
#####################
479494
class DygraphFunctionGeneratorBase(FunctionGeneratorBase):
480495

481-
def __init__(self, forward_api_contents, grad_api_contents, namespace):
496+
def __init__(self, forward_api_contents, grad_api_contents,
497+
forward_apis_dict, namespace):
482498
self.forward_api_contents = forward_api_contents
483499
# Members from Parent:
484500
#self.namespace
@@ -495,6 +511,7 @@ def __init__(self, forward_api_contents, grad_api_contents, namespace):
495511
#self.forward_inplace_map
496512
FunctionGeneratorBase.__init__(self, forward_api_contents, namespace)
497513

514+
self.forward_apis_dict = forward_apis_dict
498515
self.grad_api_contents = grad_api_contents
499516

500517
# Raw Contents
@@ -935,9 +952,11 @@ def run(self):
935952

936953
class DygraphForwardFunctionGenerator(DygraphFunctionGeneratorBase):
937954

938-
def __init__(self, forward_api_contents, grad_api_contents, namespace):
955+
def __init__(self, forward_api_contents, grad_api_contents,
956+
forward_apis_dict, namespace):
939957
DygraphFunctionGeneratorBase.__init__(self, forward_api_contents,
940-
grad_api_contents, namespace)
958+
grad_api_contents,
959+
forward_apis_dict, namespace)
941960

942961
# Generated Results
943962
self.forward_definition_str = ""
@@ -1299,10 +1318,12 @@ class DygraphNodeGenerator(DygraphFunctionGeneratorBase):
12991318
def __init__(self,
13001319
forward_api_contents,
13011320
grad_api_contents,
1321+
forward_apis_dict,
13021322
namespace,
13031323
next_grad_api_contents=None):
13041324
DygraphFunctionGeneratorBase.__init__(self, forward_api_contents,
1305-
grad_api_contents, namespace)
1325+
grad_api_contents,
1326+
forward_apis_dict, namespace)
13061327

13071328
# Record name mapping from forward_var_name to grad_var_names
13081329
self.to_next_grad_name_mapping = {} # {name : name}
@@ -1346,6 +1367,7 @@ def RecordGrad2NextGradNameMapping(self, next_node_generator):
13461367
def GenerateHigherOrderNodeCreationCode(self):
13471368
namespace = self.namespace
13481369
grad_api_contents = self.grad_api_contents
1370+
forward_apis_dict = self.forward_apis_dict
13491371
next_grad_api_contents = self.next_grad_api_contents
13501372

13511373
next_grad_node_creation_str = ""
@@ -1358,7 +1380,8 @@ def GenerateHigherOrderNodeCreationCode(self):
13581380
backward_api_contents = next_grad_api_contents
13591381

13601382
next_node_generator = DygraphFunctionGeneratorBase(
1361-
forward_api_contents, backward_api_contents, namespace)
1383+
forward_api_contents, backward_api_contents, forward_apis_dict,
1384+
namespace)
13621385
next_node_generator.run()
13631386
next_node_generator.GenerateNodeCreationCodes(True)
13641387

@@ -1443,6 +1466,8 @@ def GenerateNodeDefinition(self, next_grad_node_creation_str,
14431466
backward_inplace_map = self.backward_inplace_map
14441467
indent = GetIndent(1)
14451468

1469+
is_invoke_forward_api = IsInvokeForwardApi(self.grad_api_contents,
1470+
self.forward_apis_dict)
14461471
# Construct grad_api function args
14471472
# Order: TensorWrappers, GradTensors, Attributes
14481473
grad_api_args_len = len(backward_forward_inputs_map.keys()) + len(
@@ -1575,11 +1600,18 @@ def GenerateNodeDefinition(self, next_grad_node_creation_str,
15751600
optional_inplace_str = ""
15761601
# Grad Outputs
15771602
out_index = -1
1603+
out_assign_str = ""
15781604
for name, (ttype, fwd_position,
15791605
grad_api_position) in backward_grad_outputs_map.items():
15801606
transformed_tensor_name = self.TransformToNextGradName(name)
15811607
out_index = out_index + 1
1582-
grad_api_args.append(f"api_output_{out_index}")
1608+
if is_invoke_forward_api:
1609+
if len(backward_grad_outputs_map) == 1:
1610+
out_assign_str += f"{indent}*api_output_{out_index} = api_output;\n"
1611+
else:
1612+
out_assign_str += f"{indent}*api_output_{out_index} = std::get<{out_index}>(api_output);\n"
1613+
else:
1614+
grad_api_args.append(f"api_output_{out_index}")
15831615
if inplace_grad_input_str in optional_inplace_var_name:
15841616
optional_inplace_str = "VLOG(6) << \"No Inplace should happend for wrappered input: {inplace_grad_input_str}\";"
15851617
else:
@@ -1621,7 +1653,24 @@ def GenerateNodeDefinition(self, next_grad_node_creation_str,
16211653

16221654
grad_api_args_str = ", ".join(grad_api_args)
16231655

1624-
grad_function_call_str = f"""
1656+
if is_invoke_forward_api:
1657+
autograd_api_out = "auto"
1658+
if len(self.backward_inplace_map) > 0 and len(
1659+
backward_grad_outputs_map) == 1:
1660+
autograd_api_out = "auto&"
1661+
forward_api_name = self.grad_api_contents['invoke'].split(
1662+
'(')[0].strip()
1663+
autograd_api = self.grad_api_contents['invoke'].replace(
1664+
forward_api_name, forward_api_name + '_dygraph_function', 1)
1665+
grad_function_call_str = f"""
1666+
if (trace_backward) {{
1667+
{indent}{autograd_api_out} api_output = {autograd_api};
1668+
{out_assign_str}}} else {{
1669+
{indent}{autograd_api_out} api_output = paddle::experimental::{self.namespace}{self.grad_api_contents['invoke']};
1670+
{out_assign_str}{indent}}}
1671+
"""
1672+
else:
1673+
grad_function_call_str = f"""
16251674
{indent}{grad_api_namespace}{backward_api_name}({grad_api_args_str});"""
16261675

16271676
# Check Nan and Inf
@@ -1631,7 +1680,7 @@ def GenerateNodeDefinition(self, next_grad_node_creation_str,
16311680
# Prepare for Node Creation if Necessary
16321681
outputs_autograd_meta_str = ""
16331682
compute_require_next_grad_str = ""
1634-
if len(next_grad_node_creation_str) > 0:
1683+
if len(next_grad_node_creation_str) > 0 or is_invoke_forward_api:
16351684
compute_require_next_grad_str = f"{indent}bool trace_backward = egr::Controller::Instance().HasGrad() && create_graph;\n"
16361685

16371686
# 3. Get Output AutoGradMeta
@@ -1754,6 +1803,9 @@ def GetBackwardAPIContents(self, forward_api_contents):
17541803
def GenerateCode(self):
17551804
forward_api_list = self.forward_api_list
17561805
grad_api_dict = self.grad_api_dict
1806+
forward_apis_dict = {}
1807+
for api_item in forward_api_list:
1808+
forward_apis_dict[api_item['api']] = api_item
17571809
namespace = self.namespace
17581810

17591811
for forward_api_contents in forward_api_list:
@@ -1769,7 +1821,8 @@ def GenerateCode(self):
17691821

17701822
# Generate Dygraph Forward Function
17711823
function_generator = DygraphForwardFunctionGenerator(
1772-
forward_api_contents, backward_api_contents, namespace)
1824+
forward_api_contents, backward_api_contents, forward_apis_dict,
1825+
namespace)
17731826
function_generator.run()
17741827

17751828
self.forward_definition_str += function_generator.forward_definition_str + "\n"
@@ -1784,6 +1837,7 @@ def GenerateCode(self):
17841837

17851838
node_generator = DygraphNodeGenerator(forward_api_contents,
17861839
backward_api_contents,
1840+
forward_apis_dict,
17871841
namespace,
17881842
next_grad_api_contents)
17891843
node_generator.run()

paddle/fluid/eager/pylayer/py_layer_node.cc

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -104,7 +104,10 @@ GradNodePyLayer::operator()(
104104
PADDLE_THROW(paddle::platform::errors::InvalidArgument(
105105
"Get backward function faild."));
106106
}
107+
bool need_grad_tmp = egr::Controller::Instance().HasGrad();
108+
egr::Controller::Instance().SetHasGrad(create_graph && need_grad_tmp);
107109
auto outputs = PyObject_CallObject(backward_fn, backward_args);
110+
egr::Controller::Instance().SetHasGrad(need_grad_tmp);
108111
if (!outputs) {
109112
PADDLE_THROW(paddle::platform::errors::External(
110113
pybind11::detail::error_string().c_str()));

paddle/fluid/framework/infershape_utils.cc

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -490,6 +490,10 @@ CompatInferMetaContext BuildInferMetaContext(InferShapeContext* ctx,
490490
infer_meta_context.EmplaceBackAttr(
491491
phi::Scalar(PADDLE_GET_CONST(int, attr)));
492492
break;
493+
case framework::proto::AttrType::LONG:
494+
infer_meta_context.EmplaceBackAttr(
495+
phi::Scalar(PADDLE_GET_CONST(int64_t, attr)));
496+
break;
493497
case framework::proto::AttrType::STRING:
494498
infer_meta_context.EmplaceBackAttr(
495499
phi::Scalar(PADDLE_GET_CONST(std::string, attr)));

paddle/fluid/framework/operator.cc

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2753,6 +2753,10 @@ void OperatorWithKernel::BuildPhiKernelContext(
27532753
phi_kernel_context->EmplaceBackAttr(std::move(
27542754
phi::Scalar(PADDLE_GET_CONST(int, attr_iter->second))));
27552755
break;
2756+
case proto::AttrType::LONG:
2757+
phi_kernel_context->EmplaceBackAttr(std::move(
2758+
phi::Scalar(PADDLE_GET_CONST(int64_t, attr_iter->second))));
2759+
break;
27562760
case proto::AttrType::STRING:
27572761
phi_kernel_context->EmplaceBackAttr(std::move(phi::Scalar(
27582762
PADDLE_GET_CONST(std::string, attr_iter->second))));

paddle/fluid/imperative/prepared_operator.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -420,6 +420,10 @@ void BuildDygraphPhiKernelContext(const phi::KernelSignature& kernel_signature,
420420
kernel_ctx->EmplaceBackAttr(
421421
std::move(phi::Scalar(PADDLE_GET_CONST(int, attr))));
422422
break;
423+
case framework::proto::AttrType::LONG:
424+
kernel_ctx->EmplaceBackAttr(
425+
std::move(phi::Scalar(PADDLE_GET_CONST(int64_t, attr))));
426+
break;
423427
case framework::proto::AttrType::STRING:
424428
kernel_ctx->EmplaceBackAttr(
425429
std::move(phi::Scalar(PADDLE_GET_CONST(std::string, attr))));

paddle/fluid/operators/assign_op_xpu.cc

Lines changed: 0 additions & 166 deletions
This file was deleted.

0 commit comments

Comments
 (0)