Commit afbc022

Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into zyf_slice
2 parents: 77160ea + dc439a1

File tree

10 files changed: +780 -12 lines

paddle/fluid/operators/activation_op_npu.cc

Lines changed: 156 additions & 0 deletions
@@ -16,6 +16,7 @@ limitations under the Licnse. */
 #include <string>

 #include "paddle/fluid/framework/ddim.h"
+#include "paddle/fluid/framework/framework.pb.h"
 #include "paddle/fluid/framework/tensor_util.h"
 #include "paddle/fluid/operators/activation_op.h"
 #include "paddle/fluid/operators/npu_op_runner.h"
@@ -388,6 +389,155 @@ class SigmoidGradNPUKernel : public framework::OpKernel<T> {
   }
 };

+// HardSwish = min(max(0, x+offset), threshold) * x / scale
+template <typename T>
+class HardSwishNPUKernel : public framework::OpKernel<T> {
+ public:
+  void Compute(const framework::ExecutionContext& ctx) const override {
+    auto* x = ctx.Input<Tensor>("X");
+    auto* out = ctx.Output<Tensor>("Out");
+
+    float threshold = ctx.Attr<float>("threshold");
+    float scale = ctx.Attr<float>("scale");
+    float offset = ctx.Attr<float>("offset");
+
+    auto place = ctx.GetPlace();
+
+    out->mutable_data<T>(place);
+
+    auto stream =
+        ctx.template device_context<paddle::platform::NPUDeviceContext>()
+            .stream();
+
+    Tensor tensor_offset(x->type());
+    tensor_offset.mutable_data<T>({1}, place);
+    FillNpuTensorWithConstant<T>(&tensor_offset, static_cast<T>(offset));
+
+    Tensor add_offset_val(x->type());
+    add_offset_val.mutable_data<T>(x->dims(), place);
+    const auto& runner_add =
+        NpuOpRunner("AddV2", {*x, tensor_offset}, {add_offset_val});
+    runner_add.Run(stream);
+
+    Tensor tensor_threshold(x->type());
+    tensor_threshold.mutable_data<T>({1}, place);
+    FillNpuTensorWithConstant<T>(&tensor_threshold, static_cast<T>(threshold));
+
+    Tensor tensor_zero(x->type());
+    tensor_zero.mutable_data<T>({1}, place);
+    FillNpuTensorWithConstant<T>(&tensor_zero, static_cast<T>(0.0));
+
+    Tensor clip_val(x->type());
+    clip_val.mutable_data<T>(x->dims(), place);
+    const auto& runner_clip = NpuOpRunner(
+        "ClipByValue", {add_offset_val, tensor_zero, tensor_threshold},
+        {clip_val});
+    runner_clip.Run(stream);
+
+    Tensor tensor_scale_tmp(x->type());
+    tensor_scale_tmp.mutable_data<T>({1}, place);
+    FillNpuTensorWithConstant<T>(&tensor_scale_tmp, static_cast<T>(scale));
+    Tensor tensor_scale(x->type());
+    tensor_scale.mutable_data<T>(x->dims(), place);
+    const auto& runner_fill =
+        NpuOpRunner("FillD", {tensor_scale_tmp}, {tensor_scale},
+                    {{"dims", framework::vectorize(x->dims())}});
+    runner_fill.Run(stream);
+
+    Tensor div_val(x->type());
+    div_val.mutable_data<T>(x->dims(), place);
+    const auto& runner_div =
+        NpuOpRunner("Div", {clip_val, tensor_scale}, {div_val});
+    runner_div.Run(stream);
+
+    const auto& runner_mul = NpuOpRunner("Mul", {*x, div_val}, {*out});
+    runner_mul.Run(stream);
+  }
+};
+
+template <typename T>
+class HardSwishGradNPUKernel : public framework::OpKernel<T> {
+ public:
+  void Compute(const framework::ExecutionContext& ctx) const override {
+    auto* x = ctx.Input<Tensor>("X");
+    auto* dout = ctx.Input<Tensor>(framework::GradVarName("Out"));
+    auto* dx = ctx.Output<Tensor>(framework::GradVarName("X"));
+
+    float threshold = ctx.Attr<float>("threshold");
+    float scale = ctx.Attr<float>("scale");
+    float offset = ctx.Attr<float>("offset");
+
+    auto place = ctx.GetPlace();
+
+    dx->mutable_data<T>(place);
+
+    auto stream =
+        ctx.template device_context<paddle::platform::NPUDeviceContext>()
+            .stream();
+
+    Tensor tensor_offset(x->type());
+    tensor_offset.mutable_data<T>({1}, place);
+    FillNpuTensorWithConstant<T>(&tensor_offset, static_cast<T>(offset));
+
+    Tensor add_offset_val(x->type());
+    add_offset_val.mutable_data<T>(x->dims(), place);
+    const auto& runner_add =
+        NpuOpRunner("AddV2", {*x, tensor_offset}, {add_offset_val});
+    runner_add.Run(stream);
+
+    Tensor tmp1(x->type());
+    tmp1.mutable_data<T>(x->dims(), place);
+    const auto& runner_pow1 = NpuOpRunner("Power", {*x}, {tmp1},
+                                          {{"scale", 2.0f}, {"shift", offset}});
+    runner_pow1.Run(stream);
+
+    Tensor tmp2(x->type());
+    tmp2.mutable_data<T>(x->dims(), place);
+    const auto& runner_ht_grad =
+        NpuOpRunner("HardtanhGrad", {add_offset_val, tmp1}, {tmp2},
+                    {{"min_val", 0.0f}, {"max_val", threshold}});
+    runner_ht_grad.Run(stream);
+
+    Tensor tmp3(x->type());
+    tmp3.mutable_data<T>(x->dims(), place);
+    const auto& runner_pow2 = NpuOpRunner(
+        "Power", {tmp2}, {tmp3}, {{"scale", 1.0f / scale}, {"shift", 1.0f}});
+    runner_pow2.Run(stream);
+
+    Tensor tensor_threshold_tmp(x->type());
+    tensor_threshold_tmp.mutable_data<T>({1}, place);
+    FillNpuTensorWithConstant<T>(&tensor_threshold_tmp,
+                                 static_cast<T>(threshold));
+    Tensor tensor_threshold(x->type());
+    tensor_threshold.mutable_data<T>(x->dims(), place);
+    const auto& runner_fill =
+        NpuOpRunner("FillD", {tensor_threshold_tmp}, {tensor_threshold},
+                    {{"dims", framework::vectorize(x->dims())}});
+    runner_fill.Run(stream);
+
+    Tensor tmp_bool(framework::proto::VarType::BOOL);
+    tmp_bool.mutable_data<bool>(x->dims(), place);
+    const auto& runner_less =
+        NpuOpRunner("Less", {add_offset_val, tensor_threshold}, {tmp_bool});
+    runner_less.Run(stream);
+    Tensor tmp4(x->type());
+    tmp4.mutable_data<T>(x->dims(), place);
+    auto dst_dtype = ConvertToNpuDtype(x->type());
+    const auto& runner_cast =
+        NpuOpRunner("Cast", {tmp_bool}, {tmp4},
+                    {{"dst_type", static_cast<int>(dst_dtype)}});
+    runner_cast.Run(stream);
+
+    Tensor tmp5(x->type());
+    tmp5.mutable_data<T>(x->dims(), place);
+    const auto& runner_sub = NpuOpRunner("Sub", {tmp3, tmp4}, {tmp5});
+    runner_sub.Run(stream);
+
+    const auto& runner_final = NpuOpRunner("Mul", {tmp5, *dout}, {*dx});
+    runner_final.Run(stream);
+  }
+};
+
 template <typename DeviceContext, typename T>
 class HardSigmoidNPUKernel : public framework::OpKernel<T> {
  public:
@@ -677,6 +827,12 @@ REGISTER_OP_NPU_KERNEL(
     ops::SigmoidGradNPUKernel<paddle::platform::NPUDeviceContext,
                               paddle::platform::float16>);

+REGISTER_OP_NPU_KERNEL(hard_swish, ops::HardSwishNPUKernel<float>,
+                       ops::HardSwishNPUKernel<paddle::platform::float16>);
+
+REGISTER_OP_NPU_KERNEL(hard_swish_grad, ops::HardSwishGradNPUKernel<float>,
+                       ops::HardSwishGradNPUKernel<paddle::platform::float16>);
+
 REGISTER_OP_NPU_KERNEL(
     hard_sigmoid,
     ops::HardSigmoidNPUKernel<paddle::platform::NPUDeviceContext, float>,
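
For orientation (not part of the commit): a minimal NumPy sketch of what the two new kernels compute through the composed AddV2/ClipByValue/Div/Mul and Power/HardtanhGrad/Less/Cast/Sub/Mul op sequences, assuming Paddle's default hard_swish attributes threshold=6.0, scale=6.0, offset=3.0. The backward composition reproduces the upper-branch gradient exactly when threshold == scale, which holds for the defaults.

import numpy as np

def hard_swish(x, threshold=6.0, scale=6.0, offset=3.0):
    # Forward, as in the kernel comment: min(max(0, x + offset), threshold) * x / scale
    return np.clip(x + offset, 0.0, threshold) * x / scale

def hard_swish_grad(dout, x, threshold=6.0, scale=6.0, offset=3.0):
    # HardtanhGrad masks (2x + offset) to the region 0 < x + offset < threshold;
    # the Less/Cast/Sub steps contribute 1 where x + offset >= threshold
    # (exact only when threshold == scale, as with the default attributes).
    inner = np.where((x + offset > 0.0) & (x + offset < threshold),
                     (2.0 * x + offset) / scale, 0.0)
    upper = np.where(x + offset >= threshold, 1.0, 0.0)
    return dout * (inner + upper)

x = np.array([-4.0, -1.0, 0.5, 4.0], dtype=np.float32)
print(hard_swish(x))
print(hard_swish_grad(np.ones_like(x), x))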

paddle/fluid/operators/broadcast_tensors_op.cc

Lines changed: 9 additions & 0 deletions
@@ -38,6 +38,7 @@ class BroadcastTensorsOp : public framework::OperatorWithKernel {

     int target_rank = 0;
     const auto& input_dims = ctx->GetInputsDim("X");
+
     // 1. Find Output rank = max(Inputs rank)
     for (const auto& input_ddim : input_dims) {
       target_rank = std::max(target_rank, input_ddim.size());
@@ -64,6 +65,14 @@ class BroadcastTensorsOp : public framework::OperatorWithKernel {
         dim_size = input_ddim[axis];
       }

+      if (target_dim_size != 1 && dim_size != 1 &&
+          target_dim_size != dim_size) {
+        PADDLE_THROW(platform::errors::InvalidArgument(
+            "BroadcastTensorsOp inputs does not satisfy bcast semantics,"
+            "Please check axis = %d in reverse order",
+            index));
+      }
+
       // We performed bcast semantics check at python level
       // So input tensors should all have legal shape
       target_dim_size = std::max(target_dim_size, dim_size);
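
Not part of the commit: a small Python sketch of the broadcast rule the added check enforces. Shapes are aligned from the trailing axis; on each axis every input size must be 1 or equal to the common size, otherwise the op now throws InvalidArgument instead of silently proceeding.

def bcast_ok(shapes):
    # Returns (True, None) if shapes are broadcastable, else (False, offending axis),
    # with the axis counted 0-based from the trailing dimension ("reverse order").
    max_rank = max(len(s) for s in shapes)
    for axis in range(1, max_rank + 1):
        target = 1
        for s in shapes:
            dim = s[-axis] if axis <= len(s) else 1
            if target != 1 and dim != 1 and target != dim:
                return False, axis - 1
            target = max(target, dim)
    return True, None

print(bcast_ok([(2, 3, 1), (3, 4)]))  # (True, None): broadcastable
print(bcast_ok([(2, 3), (4, 3)]))     # (False, 1): mismatch one axis in from the end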

paddle/fluid/operators/dist_op.cc

Lines changed: 14 additions & 0 deletions
@@ -27,6 +27,20 @@ class DistOp : public framework::OperatorWithKernel {
     OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "Dist");
     OP_INOUT_CHECK(ctx->HasInput("Y"), "Input", "Y", "Dist");
     OP_INOUT_CHECK(ctx->HasOutput("Out"), "Output", "Out", "Dist");
+
+    auto x_dims = ctx->GetInputDim("X");
+    auto y_dims = ctx->GetInputDim("Y");
+
+    PADDLE_ENFORCE_NE(framework::product(x_dims), 0,
+                      platform::errors::InvalidArgument(
+                          "The Input(X) has not been initialized properly. The "
+                          "shape of Input(X) = [%s].",
+                          x_dims));
+    PADDLE_ENFORCE_NE(framework::product(y_dims), 0,
+                      platform::errors::InvalidArgument(
+                          "The Input(Y) has not been initialized properly. The "
+                          "shape of Input(Y) = [%s].",
+                          y_dims));
     ctx->SetOutputDim("Out", {1});
   }
 };
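
Not part of the commit: framework::product(dims) is the total element count, so the new checks reject any input whose shape contains a 0 (an empty or uninitialized tensor) at InferShape time. A sketch of the guarded condition:

from functools import reduce
from operator import mul

def product(dims):
    # Total element count, analogous to framework::product(dims).
    return reduce(mul, dims, 1)

print(product((3, 4)))     # 12 -> passes the new checks
print(product((3, 0, 2)))  # 0  -> would now raise InvalidArgument for Input(X)/Input(Y)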

paddle/fluid/operators/increment_op_npu.cc

Lines changed: 0 additions & 1 deletion
@@ -64,6 +64,5 @@ REGISTER_OP_NPU_KERNEL(
     ops::IncrementalNPUKernel<paddle::platform::NPUDeviceContext, float>,
     ops::IncrementalNPUKernel<paddle::platform::NPUDeviceContext, double>,
     ops::IncrementalNPUKernel<paddle::platform::NPUDeviceContext, int>,
-    ops::IncrementalNPUKernel<paddle::platform::NPUDeviceContext, int64_t>,
     ops::IncrementalNPUKernel<paddle::platform::NPUDeviceContext,
                               plat::float16>)

paddle/scripts/paddle_build.sh

Lines changed: 36 additions & 6 deletions
@@ -689,18 +689,18 @@ function get_precision_ut_mac() {
         on_precision=1
         re=$(cat ut_list|awk -F ' ' '{print }' | awk 'BEGIN{ all_str=""}{if (all_str==""){all_str=$1}else{all_str=all_str"$|^"$1}} END{print "^"all_str"$"}')
         UT_list_prec_1='ut_list_prec2'
-        for case in $UT_list; do
-            flag=$(echo $case|grep -oE $re)
+        for ut_case in $UT_list; do
+            flag=$(echo $ut_case|grep -oE $re)
             if [ -n "$flag" ];then
                 if [ -z "$UT_list_prec" ];then
-                    UT_list_prec="^$case$"
+                    UT_list_prec="^$ut_case$"
                 elif [[ "${#UT_list_prec}" -gt 10000 ]];then
-                    UT_list_prec_1="$UT_list_prec_1|^$case$"
+                    UT_list_prec_1="$UT_list_prec_1|^$ut_case$"
                 else
-                    UT_list_prec="$UT_list_prec|^$case$"
+                    UT_list_prec="$UT_list_prec|^$ut_case$"
                 fi
             else
-                echo ${case} "won't run in PRECISION_TEST mode."
+                echo ${ut_case} "won't run in PRECISION_TEST mode."
            fi
        done
    fi
@@ -722,6 +722,32 @@ function fetch_upstream_develop_if_not_exist() {
    fi
 }

+function check_whl_size() {
+    if [ ! "${pr_whl_size}" ];then
+        echo "pr whl size not found "
+        exit 1
+    fi
+
+    set +x
+    dev_whl_size=`du -m ${PADDLE_ROOT}/build/python/dist/*.whl|awk '{print $1}'`
+    echo "dev_whl_size: ${dev_whl_size}"
+
+    whldiffSize=`expr ${pr_whl_size} - ${dev_whl_size}`
+    if [ ${whldiffSize} -gt 10 ] ; then
+        approval_line=`curl -H "Authorization: token ${GITHUB_API_TOKEN}" https://api.github.com/repos/PaddlePaddle/Paddle/pulls/${GIT_PR_ID}/reviews?per_page=10000`
+        APPROVALS=`echo ${approval_line}|python ${PADDLE_ROOT}/tools/check_pr_approval.py 1 22334008 22361972`
+        echo "current pr ${GIT_PR_ID} got approvals: ${APPROVALS}"
+        if [ "${APPROVALS}" == "FALSE" ]; then
+            echo "=========================================================================================="
+            echo "This PR make the release paddlepaddle whl size growth exceeds 10 M."
+            echo "Then you must have one RD (jim19930609 (Recommend) or JiabinYang) approval for this PR\n"
+            echo "=========================================================================================="
+            exit 6
+        fi
+    fi
+    set -x
+}
+
 function generate_upstream_develop_api_spec() {
     fetch_upstream_develop_if_not_exist
     cur_branch=`git branch | grep \* | cut -d ' ' -f2`
@@ -730,6 +756,9 @@ function generate_upstream_develop_api_spec() {
     cmake_gen $1
     build $2
     cp ${PADDLE_ROOT}/python/requirements.txt /tmp
+    pr_whl_size=`du -m ${PADDLE_ROOT}/build/python/dist/*.whl|awk '{print $1}'`
+    echo "pr_whl_size: ${pr_whl_size}"
+

     git checkout $cur_branch
     generate_api_spec "$1" "DEV"
@@ -2234,6 +2263,7 @@ function main() {
        example_code=$?
        summary_check_problems $check_style_code $[${example_code_gpu} + ${example_code}] "$check_style_info" "${example_info_gpu}\n${example_info}"
        assert_api_spec_approvals
+       check_whl_size
        ;;
    build)
        cmake_gen ${PYTHON_ABI:-""}
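
Not part of the commit: the gate check_whl_size applies, restated in Python with the script's own variable names. Sizes come from du -m (MB); a wheel-size difference above 10 MB requires an approval recorded via tools/check_pr_approval.py, otherwise the CI step exits with status 6.

def check_whl_size(pr_whl_size_mb, dev_whl_size_mb, has_required_approval):
    # Mirrors: whldiffSize=`expr ${pr_whl_size} - ${dev_whl_size}`; -gt 10 triggers the gate.
    whl_diff_size = pr_whl_size_mb - dev_whl_size_mb
    if whl_diff_size > 10 and not has_required_approval:
        return 6   # mirrors `exit 6`: size growth exceeds 10 MB without approval
    return 0

print(check_whl_size(121, 105, has_required_approval=False))  # 6 -> blocked
print(check_whl_size(113, 105, has_required_approval=False))  # 0 -> passes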

python/paddle/fluid/layers/nn.py

Lines changed: 20 additions & 5 deletions
@@ -7105,11 +7105,11 @@ def dice_loss(input, label, epsilon=0.00001, name=None):


    Parameters:
-        input (Tensor): Tensor, rank>=2, shape is :math:`[N_1, N_2, ..., N_D]`, where :math:`N_1` is
-                          the batch_size, :math:`N_D` is 1. It is usually the output predictions of sigmoid activation.
-                          The data type can be float32 or float64.
-        label (Tensor): Tensor, the groud truth with the same rank as input, shape is :math:`[N_1, N_2, ..., N_D]`.
-                          where :math:`N_1` is the batch_size, :math:`N_D` is 1. The data type can be float32 or float64.
+        input (Tensor): Tensor, rank>=2, shape is :math:`[N_1, N_2, ..., N_k, D]`, where :math:`N_1` is
+                          the batch_size, :math:`D` is the number of categories. It is usually the output
+                          predictions of sigmoid activation. The data type can be float32 or float64.
+        label (Tensor): Tensor, the groud truth with the same rank as input, shape is :math:`[N_1, N_2, ..., N_k, 1]`.
+                          where :math:`N_1` is the batch_size. The data type can be int32 or int64.
        epsilon (float): The epsilon will be added to the numerator and denominator.
                         If both input and label are empty, it makes sure dice is 1.
                         Default: 0.00001
@@ -7131,6 +7131,21 @@ def dice_loss(input, label, epsilon=0.00001, name=None):
            predictions = F.softmax(x)
            loss = F.dice_loss(input=predictions, label=label)
    """
+    assert input.dtype in (paddle.float32, paddle.float64)
+    assert label.dtype in (paddle.int32, paddle.int64)
+    assert len(input.shape) >= 2, \
+        "The rank of input should be greater than or equal to 2."
+    assert len(input.shape) == len(label.shape), (
+        "The rank of input and label should be equal, "
+        "but received input: %d, label: %d." %
+        (len(input.shape), len(label.shape)))
+    assert label.shape[-1] == 1, ("The last dimension of label should be 1, "
+                                  "but received %d." % label.shape[-1])
+    assert input.shape[:-1] == label.shape[:-1], (
+        "All dimensions should be equal except the last one.")
+    assert input.numel() > 0 and label.numel() > 0, \
+        "Any dimension of input and label cannot be equal to 0."
+
    label = one_hot(label, depth=input.shape[-1])
    reduce_dim = list(range(1, len(input.shape)))
    inse = reduce_sum(input * label, dim=reduce_dim)
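
Not part of the commit: a minimal usage sketch whose shapes and dtypes satisfy the new assertions, following the docstring's own F.softmax / F.dice_loss pattern (assumes a Paddle build where paddle.nn.functional.dice_loss is available).

import paddle
import paddle.nn.functional as F

x = paddle.randn([4, 10], dtype='float32')                   # [batch_size, num_categories]
label = paddle.randint(0, 10, shape=[4, 1], dtype='int64')   # same leading dims as input, last dim 1
predictions = F.softmax(x)
loss = F.dice_loss(input=predictions, label=label)           # dice loss averaged over the batch
print(loss)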
