
Commit 36afd97

0x45fco63oc authored and committed
[PIR AMP] Fix some errors for bert amp (PaddlePaddle#64497)
1 parent 6a0b7b9 commit 36afd97

13 files changed

Lines changed: 230 additions & 55 deletions


paddle/fluid/operators/fused/fused_feedforward_op.cu

Lines changed: 9 additions & 9 deletions
@@ -518,14 +518,14 @@ void FusedFeedForwardGradKernel(
     bool add_residual,
     int ring_id,
     DenseTensor* x_grad,
-    DenseTensor* ln1_scale_grad,
-    DenseTensor* ln1_bias_grad,
-    DenseTensor* ln2_scale_grad,
-    DenseTensor* ln2_bias_grad,
     DenseTensor* linear1_weight_grad,
     DenseTensor* linear1_bias_grad,
     DenseTensor* linear2_weight_grad,
-    DenseTensor* linear2_bias_grad) {
+    DenseTensor* linear2_bias_grad,
+    DenseTensor* ln1_scale_grad,
+    DenseTensor* ln1_bias_grad,
+    DenseTensor* ln2_scale_grad,
+    DenseTensor* ln2_bias_grad) {
   using U = phi::funcs::LayerNormParamType<T>;
 
   auto* ln1_out_ptr = pre_layer_norm ? ln1_out.get_ptr() : nullptr;
@@ -672,9 +672,9 @@ PD_REGISTER_KERNEL(fused_feedforward_grad,
                    double,
                    phi::dtype::float16) {
   if (kernel_key.dtype() == phi::DataType::FLOAT16) {
-    kernel->OutputAt(1).SetDataType(phi::DataType::FLOAT32);
-    kernel->OutputAt(2).SetDataType(phi::DataType::FLOAT32);
-    kernel->OutputAt(3).SetDataType(phi::DataType::FLOAT32);
-    kernel->OutputAt(4).SetDataType(phi::DataType::FLOAT32);
+    kernel->OutputAt(5).SetDataType(phi::DataType::FLOAT32);
+    kernel->OutputAt(6).SetDataType(phi::DataType::FLOAT32);
+    kernel->OutputAt(7).SetDataType(phi::DataType::FLOAT32);
+    kernel->OutputAt(8).SetDataType(phi::DataType::FLOAT32);
   }
 }
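Note on the index shift in PD_REGISTER_KERNEL above: once the layer-norm grads move to the tail of the kernel signature, the FLOAT32 overrides for the float16 kernel have to target output slots 5-8 instead of 1-4. A minimal sketch of the resulting slot order (the enum is illustrative only, not part of the patch):

// Illustrative only: output slot order of fused_feedforward_grad after the
// reorder above. Slots 5-8 are the layer-norm parameter grads that the
// float16 registration keeps in FLOAT32.
enum FusedFeedForwardGradOutputSlot {
  kXGrad = 0,
  kLinear1WeightGrad = 1,
  kLinear1BiasGrad = 2,
  kLinear2WeightGrad = 3,
  kLinear2BiasGrad = 4,
  kLn1ScaleGrad = 5,  // FLOAT32 under the float16 kernel
  kLn1BiasGrad = 6,   // FLOAT32 under the float16 kernel
  kLn2ScaleGrad = 7,  // FLOAT32 under the float16 kernel
  kLn2BiasGrad = 8    // FLOAT32 under the float16 kernel
};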

paddle/fluid/operators/generator/parse_utils.py

Lines changed: 2 additions & 0 deletions
@@ -643,6 +643,8 @@ def validate_backward_inputs(
 
 
 def validate_backward_outputs(op, forward_inputs, backward_outputs):
+    if op in ['fused_attention_grad']:
+        return
     assert len(backward_outputs) <= len(
         forward_inputs
     ), f"{op } has too many outputs"

paddle/fluid/operators/ops_signature/fused_feedforward_sig.cc

Lines changed: 5 additions & 5 deletions
@@ -84,14 +84,14 @@ KernelSignature FeedForwardGradFuseOpArgumentMapping(
        "add_residual",
        "ring_id"},
       {"X@GRAD",
-       "Ln1Scale@GRAD",
-       "Ln1Bias@GRAD",
-       "Ln2Scale@GRAD",
-       "Ln2Bias@GRAD",
        "Linear1Weight@GRAD",
        "Linear1Bias@GRAD",
        "Linear2Weight@GRAD",
-       "Linear2Bias@GRAD"});
+       "Linear2Bias@GRAD",
+       "Ln1Scale@GRAD",
+       "Ln1Bias@GRAD",
+       "Ln2Scale@GRAD",
+       "Ln2Bias@GRAD"});
 }
 
 }  // namespace phi

paddle/fluid/pir/dialect/op_generator/op_gen.py

Lines changed: 4 additions & 2 deletions
@@ -1119,10 +1119,11 @@ def get_input_grad_semantic(op_info, op_info_items):
 
     # get backward op
     bwd_op_name = op_info.backward_name
+    sparse_op_name_suffix = '_sp' if op_info.is_sparse_op else ''
     if (bwd_op_name is None) or (bwd_op_name not in op_info_items.keys()):
         input_grad_semantics = ["false" for i in range(num_inputs)]
     else:
-        bwd_op_info = op_info_items[bwd_op_name]
+        bwd_op_info = op_info_items[bwd_op_name + sparse_op_name_suffix]
 
         # cut "_grad" of each output of bwd_op, and then compare each modified output with corresponding input
         # thus determine whether each input has grad semantic
@@ -1153,12 +1154,13 @@ def get_mutable_attribute_grad_semantic(op_info, op_info_items):
 
     # get backward op
     bwd_op_name = op_info.backward_name
+    sparse_op_name_suffix = '_sp' if op_info.is_sparse_op else ''
     if (bwd_op_name is None) or (bwd_op_name not in op_info_items.keys()):
         mutable_attribute_grad_semantics = [
             "false" for i in range(len(fwd_mutable_attribute_list))
         ]
     else:
-        bwd_op_info = op_info_items[bwd_op_name]
+        bwd_op_info = op_info_items[bwd_op_name + sparse_op_name_suffix]
 
         # cut "_grad" of each output of bwd_op, and then compare each modified output with corresponding attribute
         # thus determine whether each attribute has grad semantic

paddle/fluid/primitive/codegen/gen.py

Lines changed: 1 addition & 0 deletions
@@ -40,6 +40,7 @@
 VJPS_BLACK_LIST = [
     'reshape_grad',
     'add_n_grad',
+    'fused_attention_grad',
 ]
 
 BACKENDS_BLACK_LIST = [

paddle/fluid/primitive/rule/vjp/manual/manual_vjp.cc

Lines changed: 119 additions & 0 deletions
@@ -67,5 +67,124 @@ std::vector<std::vector<paddle::Tensor>> reshape_vjp(
   return vjp_res;
 }
 
+std::vector<std::vector<paddle::Tensor>> fused_attention_vjp(
+    const Tensor& out_grad,
+    const Tensor& x,
+    const Tensor& qkv_weight,
+    const paddle::optional<Tensor>& qkv_bias,
+    const paddle::optional<Tensor>& qkv_bias_out,
+    const paddle::optional<Tensor>& src_mask,
+    const paddle::optional<Tensor>& src_mask_out,
+    const Tensor& out_linear_weight,
+    const paddle::optional<Tensor>& out_linear_bias,
+    const paddle::optional<Tensor>& ln_scale,
+    const paddle::optional<Tensor>& ln_bias,
+    const paddle::optional<Tensor>& ln_scale_2,
+    const paddle::optional<Tensor>& ln_bias_2,
+    const paddle::optional<Tensor>& ln_out,
+    const paddle::optional<Tensor>& ln_mean,
+    const paddle::optional<Tensor>& ln_var,
+    const paddle::optional<Tensor>& ln_mean_2,
+    const paddle::optional<Tensor>& ln_var_2,
+    const paddle::optional<Tensor>& bias_dropout_residual_out,
+    const Tensor& qkv_out,
+    const Tensor& transpose_out_2,
+    const Tensor& qk_out,
+    const Tensor& qktv_out,
+    const Tensor& softmax_out,
+    const Tensor& attn_dropout_mask_out,
+    const Tensor& attn_dropout_out,
+    const Tensor& fmha_out,
+    const Tensor& out_linear_out,
+    const Tensor& dropout_mask_out,
+    int num_heads,
+    bool transpose_qkv_wb,
+    bool pre_layer_norm,
+    float epsilon,
+    float attn_dropout_rate,
+    bool is_test,
+    bool attn_dropout_fix_seed,
+    int attn_dropout_seed,
+    const std::string& attn_dropout_implementation,
+    float dropout_rate,
+    bool dropout_fix_seed,
+    int dropout_seed,
+    const std::string& dropout_implementation,
+    float ln_epsilon,
+    bool add_residual,
+    int ring_id,
+    const std::vector<std::vector<bool>>& stop_gradients) {
+  std::vector<std::vector<paddle::Tensor>> vjp_res;
+  for (auto arg : stop_gradients) {
+    vjp_res.push_back(std::vector<paddle::Tensor>(arg.size()));
+  }
+  auto op_res =
+      backend::fused_attention_grad<LazyTensor>(out_grad,
+                                                x,
+                                                qkv_weight,
+                                                qkv_bias,
+                                                qkv_bias_out,
+                                                src_mask,
+                                                src_mask_out,
+                                                out_linear_weight,
+                                                out_linear_bias,
+                                                ln_scale,
+                                                ln_bias,
+                                                ln_scale_2,
+                                                ln_bias_2,
+                                                ln_out,
+                                                ln_mean,
+                                                ln_var,
+                                                ln_mean_2,
+                                                ln_var_2,
+                                                bias_dropout_residual_out,
+                                                qkv_out,
+                                                transpose_out_2,
+                                                qk_out,
+                                                qktv_out,
+                                                softmax_out,
+                                                attn_dropout_mask_out,
+                                                attn_dropout_out,
+                                                fmha_out,
+                                                out_linear_out,
+                                                dropout_mask_out,
+                                                num_heads,
+                                                transpose_qkv_wb,
+                                                pre_layer_norm,
+                                                epsilon,
+                                                attn_dropout_rate,
+                                                is_test,
+                                                attn_dropout_fix_seed,
+                                                attn_dropout_seed,
+                                                attn_dropout_implementation,
+                                                dropout_rate,
+                                                dropout_fix_seed,
+                                                dropout_seed,
+                                                dropout_implementation,
+                                                ln_epsilon,
+                                                add_residual,
+                                                ring_id);
+  // x_grad
+  vjp_res[0][0] = std::get<8>(op_res);
+  // ln_scale_grad
+  vjp_res[1][0] = std::get<4>(op_res);
+  // ln_bias_grad
+  vjp_res[2][0] = std::get<5>(op_res);
+  // qkv_weight_grad
+  vjp_res[3][0] = std::get<9>(op_res);
+  // qkv_bias_grad
+  vjp_res[4][0] = std::get<0>(op_res);
+  // out_linear_weight_grad
+  vjp_res[5][0] = std::get<10>(op_res);
+  // out_linear_bias_grad
+  vjp_res[6][0] = std::get<3>(op_res);
+  // ln_scale_2_grad
+  vjp_res[7][0] = std::get<6>(op_res);
+  // ln_bias_2_grad
+  vjp_res[8][0] = std::get<7>(op_res);
+  vjp_res = ConstructVjpResultByStopGradients(vjp_res, stop_gradients);
+  return vjp_res;
+}
+
 }  // namespace primitive
 }  // namespace paddle
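For reference, a minimal sketch (not part of the patch) of the stop_gradients layout fused_attention_vjp expects: one inner vector per differentiable input, in the same order the vjp_res slots are filled above (x, ln_scale, ln_bias, qkv_weight, qkv_bias, out_linear_weight, out_linear_bias, ln_scale_2, ln_bias_2).

#include <vector>

// Illustrative only: nine single-entry grad slots, matching the vjp_res
// layout above. A `true` entry marks an input whose gradient is not needed,
// which ConstructVjpResultByStopGradients uses when assembling the result.
std::vector<std::vector<bool>> MakeFusedAttentionStopGradients() {
  std::vector<std::vector<bool>> stop_gradients(9, std::vector<bool>(1, false));
  stop_gradients[4][0] = true;  // e.g. skip the gradient for qkv_bias
  return stop_gradients;
}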

paddle/fluid/primitive/rule/vjp/manual/manual_vjp.h

Lines changed: 48 additions & 0 deletions
@@ -33,5 +33,53 @@ std::vector<std::vector<paddle::Tensor>> reshape_vjp(
     const Tensor& out_grad,
     const std::vector<std::vector<bool>>& stop_gradients);
 
+std::vector<std::vector<paddle::Tensor>> fused_attention_vjp(
+    const Tensor& out_grad,
+    const Tensor& x,
+    const Tensor& qkv_weight,
+    const paddle::optional<Tensor>& qkv_bias,
+    const paddle::optional<Tensor>& qkv_bias_out,
+    const paddle::optional<Tensor>& src_mask,
+    const paddle::optional<Tensor>& src_mask_out,
+    const Tensor& out_linear_weight,
+    const paddle::optional<Tensor>& out_linear_bias,
+    const paddle::optional<Tensor>& ln_scale,
+    const paddle::optional<Tensor>& ln_bias,
+    const paddle::optional<Tensor>& ln_scale_2,
+    const paddle::optional<Tensor>& ln_bias_2,
+    const paddle::optional<Tensor>& ln_out,
+    const paddle::optional<Tensor>& ln_mean,
+    const paddle::optional<Tensor>& ln_var,
+    const paddle::optional<Tensor>& ln_mean_2,
+    const paddle::optional<Tensor>& ln_var_2,
+    const paddle::optional<Tensor>& bias_dropout_residual_out,
+    const Tensor& qkv_out,
+    const Tensor& transpose_out_2,
+    const Tensor& qk_out,
+    const Tensor& qktv_out,
+    const Tensor& softmax_out,
+    const Tensor& attn_dropout_mask_out,
+    const Tensor& attn_dropout_out,
+    const Tensor& fmha_out,
+    const Tensor& out_linear_out,
+    const Tensor& dropout_mask_out,
+    int num_heads,
+    bool transpose_qkv_wb,
+    bool pre_layer_norm,
+    float epsilon,
+    float attn_dropout_rate,
+    bool is_test,
+    bool attn_dropout_fix_seed,
+    int attn_dropout_seed,
+    const std::string& attn_dropout_implementation,
+    float dropout_rate,
+    bool dropout_fix_seed,
+    int dropout_seed,
+    const std::string& dropout_implementation,
+    float ln_epsilon,
+    bool add_residual,
+    int ring_id,
+    const std::vector<std::vector<bool>>& stop_gradients);
+
 }  // namespace primitive
 }  // namespace paddle

paddle/phi/infermeta/backward.cc

Lines changed: 1 addition & 1 deletion
@@ -510,7 +510,7 @@ void GeneralTernaryGradInferMeta(const MetaTensor& x,
   if (dx) {
     dx->share_meta(x);
   }
-  if (dy) {
+  if (dy && y) {
     dy->share_meta(y);
   }
   if (dz) {
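The added `&& y` guards against an absent optional input: under the PIR AMP path some of these MetaTensor inputs can be uninitialized, and sharing meta from them would fail. The same guard recurs throughout paddle/phi/infermeta/fusion.cc below. A minimal sketch of the idiom with a hypothetical helper (not part of the patch):

// Illustrative only: set a grad output's meta only when both the output slot
// and the corresponding (possibly optional) input are present. An absent
// optional input shows up here as an uninitialized MetaTensor, which
// evaluates to false in the condition.
void ShareOptionalGradMeta(const phi::MetaTensor& input,
                           phi::MetaTensor* grad_out) {
  if (grad_out && input) {
    grad_out->share_meta(input);
  }
}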

paddle/phi/infermeta/fusion.cc

Lines changed: 19 additions & 18 deletions
@@ -924,6 +924,7 @@ void FusedAttentionInferMeta(const MetaTensor& x,
   }
 
   out->set_dims(x.dims());
+  out->set_dtype(x.dtype());
 }
 
 void FusedAttentionGradInferMeta(const MetaTensor& out_grad,
@@ -998,19 +999,19 @@ void FusedAttentionGradInferMeta(const MetaTensor& out_grad,
                         "GradOp is only callable when is_test is false"));
 
   if (!pre_layer_norm) {
-    if (ln_scale_2_grad) {
+    if (ln_scale_2_grad && ln_scale_2) {
       ln_scale_2_grad->set_dims(ln_scale_2.dims());
     }
-    if (ln_bias_2_grad) {
+    if (ln_bias_2_grad && ln_bias_2) {
       ln_bias_2_grad->set_dims(ln_bias_2.dims());
     }
   }
 
-  if (pre_layer_norm) {
+  if (pre_layer_norm && ln_scale) {
     if (ln_scale_grad) {
       ln_scale_grad->set_dims(ln_scale.dims());
     }
-    if (ln_bias_grad) {
+    if (ln_bias_grad && ln_bias) {
       ln_bias_grad->set_dims(ln_bias.dims());
     }
   }
@@ -1019,7 +1020,7 @@ void FusedAttentionGradInferMeta(const MetaTensor& out_grad,
     x_grad->set_dims(x.dims());
   }
 
-  if (out_linear_bias_grad) {
+  if (out_linear_bias_grad && out_linear_bias) {
     out_linear_bias_grad->set_dims(out_linear_bias.dims());
   }
 
@@ -1031,7 +1032,7 @@ void FusedAttentionGradInferMeta(const MetaTensor& out_grad,
     qkv_weight_grad->set_dims(qkv_weight.dims());
   }
 
-  if (qkv_bias_grad) {
+  if (qkv_bias_grad && qkv_bias) {
     qkv_bias_grad->set_dims(qkv_bias.dims());
   }
 
@@ -1040,7 +1041,7 @@ void FusedAttentionGradInferMeta(const MetaTensor& out_grad,
       ln_out_grad->set_dims(ln_out.dims());
     }
   } else {
-    if (bias_dropout_residual_out_grad) {
+    if (bias_dropout_residual_out_grad && bias_dropout_residual_out) {
       bias_dropout_residual_out_grad->set_dims(
           bias_dropout_residual_out.dims());
     }
@@ -1556,36 +1557,36 @@ void FusedFeedForwardGradInferMeta(const MetaTensor& out_grad,
                                    bool add_residual,
                                    int ring_id,
                                    MetaTensor* x_grad,
-                                   MetaTensor* ln1_scale_grad,
-                                   MetaTensor* ln1_bias_grad,
-                                   MetaTensor* ln2_scale_grad,
-                                   MetaTensor* ln2_bias_grad,
                                    MetaTensor* linear1_weight_grad,
                                    MetaTensor* linear1_bias_grad,
                                    MetaTensor* linear2_weight_grad,
-                                   MetaTensor* linear2_bias_grad) {
+                                   MetaTensor* linear2_bias_grad,
+                                   MetaTensor* ln1_scale_grad,
+                                   MetaTensor* ln1_bias_grad,
+                                   MetaTensor* ln2_scale_grad,
+                                   MetaTensor* ln2_bias_grad) {
   auto d_out_dim = out_grad.dims();
   x_grad->set_dims(d_out_dim);
-  if (ln1_scale_grad) {
+  if (ln1_scale_grad && ln1_scale) {
     ln1_scale_grad->set_dims(ln1_scale.dims());
   }
-  if (ln1_bias_grad) {
+  if (ln1_bias_grad && ln1_bias) {
     ln1_bias_grad->set_dims(ln1_bias.dims());
   }
-  if (ln2_scale_grad) {
+  if (ln2_scale_grad && ln2_scale) {
     ln2_scale_grad->set_dims(ln2_scale.dims());
   }
-  if (ln2_bias_grad) {
+  if (ln2_bias_grad && ln2_bias) {
     ln2_bias_grad->set_dims(ln2_bias.dims());
   }
 
   linear1_weight_grad->set_dims(linear1_weight.dims());
-  if (linear1_bias_grad) {
+  if (linear1_bias_grad && linear1_bias) {
     linear1_bias_grad->set_dims(linear1_bias.dims());
   }
 
   linear2_weight_grad->set_dims(linear2_weight.dims());
-  if (linear2_bias_grad) {
+  if (linear2_bias_grad && linear2_bias) {
     linear2_bias_grad->set_dims(linear2_bias.dims());
   }
 }
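The single line added to FusedAttentionInferMeta at the top of this file's diff is the AMP-relevant piece: without out->set_dtype(x.dtype()), the inferred output meta would not pick up the float16 dtype of the input. A minimal sketch of the idiom with a hypothetical op (not part of the patch):

// Illustrative only: propagate both dims and dtype from the input, so a
// float16 input under AMP yields a float16 output meta.
void MyFusedOpInferMeta(const phi::MetaTensor& x, phi::MetaTensor* out) {
  out->set_dims(x.dims());
  out->set_dtype(x.dtype());
}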
