18 changes: 9 additions & 9 deletions paddle/fluid/operators/fused/fused_feedforward_op.cu
@@ -518,14 +518,14 @@ void FusedFeedForwardGradKernel(
bool add_residual,
int ring_id,
DenseTensor* x_grad,
- DenseTensor* ln1_scale_grad,
- DenseTensor* ln1_bias_grad,
- DenseTensor* ln2_scale_grad,
- DenseTensor* ln2_bias_grad,
DenseTensor* linear1_weight_grad,
DenseTensor* linear1_bias_grad,
DenseTensor* linear2_weight_grad,
- DenseTensor* linear2_bias_grad) {
+ DenseTensor* linear2_bias_grad,
+ DenseTensor* ln1_scale_grad,
+ DenseTensor* ln1_bias_grad,
+ DenseTensor* ln2_scale_grad,
+ DenseTensor* ln2_bias_grad) {
using U = phi::funcs::LayerNormParamType<T>;

auto* ln1_out_ptr = pre_layer_norm ? ln1_out.get_ptr() : nullptr;
@@ -672,9 +672,9 @@ PD_REGISTER_KERNEL(fused_feedforward_grad,
double,
phi::dtype::float16) {
if (kernel_key.dtype() == phi::DataType::FLOAT16) {
- kernel->OutputAt(1).SetDataType(phi::DataType::FLOAT32);
- kernel->OutputAt(2).SetDataType(phi::DataType::FLOAT32);
- kernel->OutputAt(3).SetDataType(phi::DataType::FLOAT32);
- kernel->OutputAt(4).SetDataType(phi::DataType::FLOAT32);
+ kernel->OutputAt(5).SetDataType(phi::DataType::FLOAT32);
+ kernel->OutputAt(6).SetDataType(phi::DataType::FLOAT32);
+ kernel->OutputAt(7).SetDataType(phi::DataType::FLOAT32);
+ kernel->OutputAt(8).SetDataType(phi::DataType::FLOAT32);
}
}
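The reorder moves the four LayerNorm parameter grads behind the linear-layer grads, which is why the FP16 registration now pins outputs 5-8 to FLOAT32 instead of 1-4. A minimal standalone sketch of the resulting index-to-output mapping, inferred from the reordered signature above (the listing itself is illustrative, not Paddle code):

#include <array>
#include <cstdio>

int main() {
  // Grad-output order of fused_feedforward_grad after the reorder (index -> name).
  constexpr std::array<const char*, 9> kGradOutputs = {
      "x_grad",               // 0
      "linear1_weight_grad",  // 1
      "linear1_bias_grad",    // 2
      "linear2_weight_grad",  // 3
      "linear2_bias_grad",    // 4
      "ln1_scale_grad",       // 5  LayerNorm parameter grads now sit at 5-8,
      "ln1_bias_grad",        // 6  which is why the FP16 branch registers
      "ln2_scale_grad",       // 7  OutputAt(5)..OutputAt(8) as FLOAT32.
      "ln2_bias_grad",        // 8
  };
  for (std::size_t i = 0; i < kGradOutputs.size(); ++i) {
    std::printf("OutputAt(%zu) -> %s\n", i, kGradOutputs[i]);
  }
  return 0;
}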
2 changes: 2 additions & 0 deletions paddle/fluid/operators/generator/parse_utils.py
@@ -643,6 +643,8 @@ def validate_backward_inputs(


def validate_backward_outputs(op, forward_inputs, backward_outputs):
+ if op in ['fused_attention_grad']:
+     return
assert len(backward_outputs) <= len(
forward_inputs
), f"{op } has too many outputs"
10 changes: 5 additions & 5 deletions paddle/fluid/operators/ops_signature/fused_feedforward_sig.cc
@@ -84,14 +84,14 @@ KernelSignature FeedForwardGradFuseOpArgumentMapping(
"add_residual",
"ring_id"},
{"X@GRAD",
"Ln1Scale@GRAD",
"Ln1Bias@GRAD",
"Ln2Scale@GRAD",
"Ln2Bias@GRAD",
"Linear1Weight@GRAD",
"Linear1Bias@GRAD",
"Linear2Weight@GRAD",
"Linear2Bias@GRAD"});
"Linear2Bias@GRAD",
"Ln1Scale@GRAD",
"Ln1Bias@GRAD",
"Ln2Scale@GRAD",
"Ln2Bias@GRAD"});
}
} // namespace phi

6 changes: 4 additions & 2 deletions paddle/fluid/pir/dialect/op_generator/op_gen.py
@@ -1119,10 +1119,11 @@ def get_input_grad_semantic(op_info, op_info_items):

# get backward op
bwd_op_name = op_info.backward_name
+ sparse_op_name_suffix = '_sp' if op_info.is_sparse_op else ''
if (bwd_op_name is None) or (bwd_op_name not in op_info_items.keys()):
input_grad_semantics = ["false" for i in range(num_inputs)]
else:
- bwd_op_info = op_info_items[bwd_op_name]
+ bwd_op_info = op_info_items[bwd_op_name + sparse_op_name_suffix]

# cut "_grad" of each output of bwd_op, and then compare each modified output with corresponding input
# thus determine whether each input has grad semantic
@@ -1153,12 +1154,13 @@ def get_mutable_attribute_grad_semantic(op_info, op_info_items):

# get backward op
bwd_op_name = op_info.backward_name
+ sparse_op_name_suffix = '_sp' if op_info.is_sparse_op else ''
if (bwd_op_name is None) or (bwd_op_name not in op_info_items.keys()):
mutable_attribute_grad_semantics = [
"false" for i in range(len(fwd_mutable_attribute_list))
]
else:
- bwd_op_info = op_info_items[bwd_op_name]
+ bwd_op_info = op_info_items[bwd_op_name + sparse_op_name_suffix]

# cut "_grad" of each output of bwd_op, and then compare each modified output with corresponding attribute
# thus determine whether each attribute has grad semantic
1 change: 1 addition & 0 deletions paddle/fluid/primitive/codegen/gen.py
@@ -40,6 +40,7 @@
VJPS_BLACK_LIST = [
'reshape_grad',
'add_n_grad',
+ 'fused_attention_grad',
]

BACKENDS_BLACK_LIST = [
119 changes: 119 additions & 0 deletions paddle/fluid/primitive/rule/vjp/manual/manual_vjp.cc
@@ -67,5 +67,124 @@ std::vector<std::vector<paddle::Tensor>> reshape_vjp(
return vjp_res;
}

std::vector<std::vector<paddle::Tensor>> fused_attention_vjp(
// Review comment (Contributor): This looks like it could be handled by codegen?
const Tensor& out_grad,
const Tensor& x,
const Tensor& qkv_weight,
const paddle::optional<Tensor>& qkv_bias,
const paddle::optional<Tensor>& qkv_bias_out,
const paddle::optional<Tensor>& src_mask,
const paddle::optional<Tensor>& src_mask_out,
const Tensor& out_linear_weight,
const paddle::optional<Tensor>& out_linear_bias,
const paddle::optional<Tensor>& ln_scale,
const paddle::optional<Tensor>& ln_bias,
const paddle::optional<Tensor>& ln_scale_2,
const paddle::optional<Tensor>& ln_bias_2,
const paddle::optional<Tensor>& ln_out,
const paddle::optional<Tensor>& ln_mean,
const paddle::optional<Tensor>& ln_var,
const paddle::optional<Tensor>& ln_mean_2,
const paddle::optional<Tensor>& ln_var_2,
const paddle::optional<Tensor>& bias_dropout_residual_out,
const Tensor& qkv_out,
const Tensor& transpose_out_2,
const Tensor& qk_out,
const Tensor& qktv_out,
const Tensor& softmax_out,
const Tensor& attn_dropout_mask_out,
const Tensor& attn_dropout_out,
const Tensor& fmha_out,
const Tensor& out_linear_out,
const Tensor& dropout_mask_out,
int num_heads,
bool transpose_qkv_wb,
bool pre_layer_norm,
float epsilon,
float attn_dropout_rate,
bool is_test,
bool attn_dropout_fix_seed,
int attn_dropout_seed,
const std::string& attn_dropout_implementation,
float dropout_rate,
bool dropout_fix_seed,
int dropout_seed,
const std::string& dropout_implementation,
float ln_epsilon,
bool add_residual,
int ring_id,
const std::vector<std::vector<bool>>& stop_gradients) {
std::vector<std::vector<paddle::Tensor>> vjp_res;
for (auto arg : stop_gradients) {
vjp_res.push_back(std::vector<paddle::Tensor>(arg.size()));
}
auto op_res =
backend::fused_attention_grad<LazyTensor>(out_grad,
x,
qkv_weight,
qkv_bias,
qkv_bias_out,
src_mask,
src_mask_out,
out_linear_weight,
out_linear_bias,
ln_scale,
ln_bias,
ln_scale_2,
ln_bias_2,
ln_out,
ln_mean,
ln_var,
ln_mean_2,
ln_var_2,
bias_dropout_residual_out,
qkv_out,
transpose_out_2,
qk_out,
qktv_out,
softmax_out,
attn_dropout_mask_out,
attn_dropout_out,
fmha_out,
out_linear_out,
dropout_mask_out,
num_heads,
transpose_qkv_wb,
pre_layer_norm,
epsilon,
attn_dropout_rate,
is_test,
attn_dropout_fix_seed,
attn_dropout_seed,
attn_dropout_implementation,
dropout_rate,
dropout_fix_seed,
dropout_seed,
dropout_implementation,
ln_epsilon,
add_residual,
ring_id);
// x_grad
vjp_res[0][0] = std::get<8>(op_res);
// ln_scale_grad
vjp_res[1][0] = std::get<4>(op_res);
// ln_bias_grad
vjp_res[2][0] = std::get<5>(op_res);
// qkv_weight_grad
vjp_res[3][0] = std::get<9>(op_res);
// qkv_bias_grad
vjp_res[4][0] = std::get<0>(op_res);
// out_linear_weight_grad
vjp_res[5][0] = std::get<10>(op_res);
// out_linear_bias_grad
vjp_res[6][0] = std::get<3>(op_res);
// ln_scale_2_grad
vjp_res[7][0] = std::get<6>(op_res);
// ln_bias_2_grad
vjp_res[8][0] = std::get<7>(op_res);
vjp_res = ConstructVjpResultByStopGradients(vjp_res, stop_gradients);
return vjp_res;
}

} // namespace primitive
} // namespace paddle
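The std::get indices above pick grads out of the fused_attention_grad result tuple, while the vjp_res slots follow the assignment order documented in the comments (x, ln_scale, ln_bias, qkv_weight, qkv_bias, out_linear_weight, out_linear_bias, ln_scale_2, ln_bias_2). A compact restatement of that mapping, derived only from the assignments above (the constants are illustrative, not part of the Paddle API):

// fused_attention_grad tuple index consumed by each vjp_res slot.
namespace fused_attention_vjp_sketch {
constexpr int kXGrad = 8;                 // vjp_res[0][0]
constexpr int kLnScaleGrad = 4;           // vjp_res[1][0]
constexpr int kLnBiasGrad = 5;            // vjp_res[2][0]
constexpr int kQkvWeightGrad = 9;         // vjp_res[3][0]
constexpr int kQkvBiasGrad = 0;           // vjp_res[4][0]
constexpr int kOutLinearWeightGrad = 10;  // vjp_res[5][0]
constexpr int kOutLinearBiasGrad = 3;     // vjp_res[6][0]
constexpr int kLnScale2Grad = 6;          // vjp_res[7][0]
constexpr int kLnBias2Grad = 7;           // vjp_res[8][0]
}  // namespace fused_attention_vjp_sketch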
48 changes: 48 additions & 0 deletions paddle/fluid/primitive/rule/vjp/manual/manual_vjp.h
@@ -33,5 +33,53 @@ std::vector<std::vector<paddle::Tensor>> reshape_vjp(
const Tensor& out_grad,
const std::vector<std::vector<bool>>& stop_gradients);

std::vector<std::vector<paddle::Tensor>> fused_attention_vjp(
const Tensor& out_grad,
const Tensor& x,
const Tensor& qkv_weight,
const paddle::optional<Tensor>& qkv_bias,
const paddle::optional<Tensor>& qkv_bias_out,
const paddle::optional<Tensor>& src_mask,
const paddle::optional<Tensor>& src_mask_out,
const Tensor& out_linear_weight,
const paddle::optional<Tensor>& out_linear_bias,
const paddle::optional<Tensor>& ln_scale,
const paddle::optional<Tensor>& ln_bias,
const paddle::optional<Tensor>& ln_scale_2,
const paddle::optional<Tensor>& ln_bias_2,
const paddle::optional<Tensor>& ln_out,
const paddle::optional<Tensor>& ln_mean,
const paddle::optional<Tensor>& ln_var,
const paddle::optional<Tensor>& ln_mean_2,
const paddle::optional<Tensor>& ln_var_2,
const paddle::optional<Tensor>& bias_dropout_residual_out,
const Tensor& qkv_out,
const Tensor& transpose_out_2,
const Tensor& qk_out,
const Tensor& qktv_out,
const Tensor& softmax_out,
const Tensor& attn_dropout_mask_out,
const Tensor& attn_dropout_out,
const Tensor& fmha_out,
const Tensor& out_linear_out,
const Tensor& dropout_mask_out,
int num_heads,
bool transpose_qkv_wb,
bool pre_layer_norm,
float epsilon,
float attn_dropout_rate,
bool is_test,
bool attn_dropout_fix_seed,
int attn_dropout_seed,
const std::string& attn_dropout_implementation,
float dropout_rate,
bool dropout_fix_seed,
int dropout_seed,
const std::string& dropout_implementation,
float ln_epsilon,
bool add_residual,
int ring_id,
const std::vector<std::vector<bool>>& stop_gradients);

} // namespace primitive
} // namespace paddle
2 changes: 1 addition & 1 deletion paddle/phi/infermeta/backward.cc
@@ -498,7 +498,7 @@ void GeneralTernaryGradInferMeta(const MetaTensor& x,
if (dx) {
dx->share_meta(x);
}
- if (dy) {
+ if (dy && y) {
dy->share_meta(y);
}
if (dz) {
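This one-line change only shares y's meta into the grad slot when y itself is present; the same guard recurs throughout the fusion.cc changes below. A minimal sketch of the pattern, using a hypothetical Meta type in place of phi::MetaTensor (the real class exposes an equivalent boolean validity test, which is what the dy && y condition relies on):

#include <iostream>

// Stand-in for phi::MetaTensor: empty-constructed means "input not provided".
struct Meta {
  bool initialized = false;
  explicit operator bool() const { return initialized; }
};

// share_meta(y) reads from y, so only touch the grad slot when the output
// slot exists *and* the optional forward input was actually passed.
void ShareIfPresent(Meta* dy, const Meta& y) {
  if (dy && y) {  // previously just `if (dy)`, which ignored whether y exists
    dy->initialized = true;
    std::cout << "meta shared\n";
  }
}

int main() {
  Meta provided{true};
  Meta absent;
  Meta grad_slot;
  ShareIfPresent(&grad_slot, absent);    // skipped: optional input missing
  ShareIfPresent(&grad_slot, provided);  // prints "meta shared"
  return 0;
}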
37 changes: 19 additions & 18 deletions paddle/phi/infermeta/fusion.cc
@@ -924,6 +924,7 @@ void FusedAttentionInferMeta(const MetaTensor& x,
}

out->set_dims(x.dims());
+ out->set_dtype(x.dtype());
}

void FusedAttentionGradInferMeta(const MetaTensor& out_grad,
@@ -998,19 +999,19 @@ void FusedAttentionGradInferMeta(const MetaTensor& out_grad,
"GradOp is only callable when is_test is false"));

if (!pre_layer_norm) {
- if (ln_scale_2_grad) {
+ if (ln_scale_2_grad && ln_scale_2) {
ln_scale_2_grad->set_dims(ln_scale_2.dims());
}
- if (ln_bias_2_grad) {
+ if (ln_bias_2_grad && ln_bias_2) {
ln_bias_2_grad->set_dims(ln_bias_2.dims());
}
}

- if (pre_layer_norm) {
+ if (pre_layer_norm && ln_scale) {
if (ln_scale_grad) {
ln_scale_grad->set_dims(ln_scale.dims());
}
- if (ln_bias_grad) {
+ if (ln_bias_grad && ln_bias) {
ln_bias_grad->set_dims(ln_bias.dims());
}
}
@@ -1019,7 +1020,7 @@ void FusedAttentionGradInferMeta(const MetaTensor& out_grad,
x_grad->set_dims(x.dims());
}

- if (out_linear_bias_grad) {
+ if (out_linear_bias_grad && out_linear_bias) {
out_linear_bias_grad->set_dims(out_linear_bias.dims());
}

@@ -1031,7 +1032,7 @@ void FusedAttentionGradInferMeta(const MetaTensor& out_grad,
qkv_weight_grad->set_dims(qkv_weight.dims());
}

- if (qkv_bias_grad) {
+ if (qkv_bias_grad && qkv_bias) {
qkv_bias_grad->set_dims(qkv_bias.dims());
}

@@ -1040,7 +1041,7 @@ void FusedAttentionGradInferMeta(const MetaTensor& out_grad,
ln_out_grad->set_dims(ln_out.dims());
}
} else {
- if (bias_dropout_residual_out_grad) {
+ if (bias_dropout_residual_out_grad && bias_dropout_residual_out) {
bias_dropout_residual_out_grad->set_dims(
bias_dropout_residual_out.dims());
}
@@ -1556,36 +1557,36 @@ void FusedFeedForwardGradInferMeta(const MetaTensor& out_grad,
bool add_residual,
int ring_id,
MetaTensor* x_grad,
- MetaTensor* ln1_scale_grad,
- MetaTensor* ln1_bias_grad,
- MetaTensor* ln2_scale_grad,
- MetaTensor* ln2_bias_grad,
MetaTensor* linear1_weight_grad,
MetaTensor* linear1_bias_grad,
MetaTensor* linear2_weight_grad,
- MetaTensor* linear2_bias_grad) {
+ MetaTensor* linear2_bias_grad,
+ MetaTensor* ln1_scale_grad,
+ MetaTensor* ln1_bias_grad,
+ MetaTensor* ln2_scale_grad,
+ MetaTensor* ln2_bias_grad) {
auto d_out_dim = out_grad.dims();
x_grad->set_dims(d_out_dim);
- if (ln1_scale_grad) {
+ if (ln1_scale_grad && ln1_scale) {
ln1_scale_grad->set_dims(ln1_scale.dims());
}
- if (ln1_bias_grad) {
+ if (ln1_bias_grad && ln1_bias) {
ln1_bias_grad->set_dims(ln1_bias.dims());
}
- if (ln2_scale_grad) {
+ if (ln2_scale_grad && ln2_scale) {
ln2_scale_grad->set_dims(ln2_scale.dims());
}
- if (ln2_bias_grad) {
+ if (ln2_bias_grad && ln2_bias) {
ln2_bias_grad->set_dims(ln2_bias.dims());
}

linear1_weight_grad->set_dims(linear1_weight.dims());
- if (linear1_bias_grad) {
+ if (linear1_bias_grad && linear1_bias) {
linear1_bias_grad->set_dims(linear1_bias.dims());
}

linear2_weight_grad->set_dims(linear2_weight.dims());
- if (linear2_bias_grad) {
+ if (linear2_bias_grad && linear2_bias) {
linear2_bias_grad->set_dims(linear2_bias.dims());
}
}