
Commit c16c32e

Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into tile_op_npu

2 parents: 849bd23 + fa16c21

31 files changed (+1258, -278 lines)

cmake/external/xpu.cmake

Lines changed: 1 addition & 1 deletion

@@ -35,7 +35,7 @@ ELSE ()
 ENDIF()
 
 SET(XPU_BASE_URL_WITHOUT_DATE "https://baidu-kunlun-product.cdn.bcebos.com/KL-SDK/klsdk-dev")
-SET(XPU_BASE_URL "${XPU_BASE_URL_WITHOUT_DATE}/20210729")
+SET(XPU_BASE_URL "${XPU_BASE_URL_WITHOUT_DATE}/20210804")
 SET(XPU_XRE_URL "${XPU_BASE_URL}/${XPU_XRE_DIR_NAME}.tar.gz" CACHE STRING "" FORCE)
 SET(XPU_XDNN_URL "${XPU_BASE_URL}/${XPU_XDNN_DIR_NAME}.tar.gz" CACHE STRING "" FORCE)
 SET(XPU_XCCL_URL "${XPU_BASE_URL_WITHOUT_DATE}/20210623/${XPU_XCCL_DIR_NAME}.tar.gz" CACHE STRING "" FORCE)

paddle/fluid/framework/ir/CMakeLists.txt

Lines changed: 1 addition & 1 deletion

@@ -59,7 +59,7 @@ cc_library(coalesce_grad_tensor_pass SRCS coalesce_grad_tensor_pass.cc DEPS grap
 
 pass_library(graph_to_program_pass base)
 pass_library(graph_viz_pass base)
-pass_library(lock_free_optimize_pass base)
+pass_library(lock_free_optimize_pass base DEPS string_helper)
 pass_library(fc_fuse_pass inference)
 pass_library(map_matmul_to_mul_pass inference)
 pass_library(attention_lstm_fuse_pass inference)

paddle/fluid/imperative/tracer.cc

Lines changed: 2 additions & 0 deletions

@@ -30,6 +30,8 @@ DECLARE_string(tracer_mkldnn_ops_off);
 namespace paddle {
 namespace imperative {
 
+thread_local bool Tracer::has_grad_ = true;
+
 static std::shared_ptr<Tracer> g_current_tracer(nullptr);
 
 const std::shared_ptr<Tracer>& GetCurrentTracer() { return g_current_tracer; }

paddle/fluid/imperative/tracer.h

Lines changed: 1 addition & 1 deletion

@@ -118,9 +118,9 @@ class Tracer {
   bool enable_program_desc_tracing_{false};
   std::unique_ptr<UniqueNameGenerator> generator_;
   platform::Place expected_place_;
-  bool has_grad_{true};
   bool enable_autocast_{false};
   GarbageCollectorMap gcs_;
+  static thread_local bool has_grad_;
 };
 
 // To access static variable current_tracer
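
The pair of changes above (the definition in tracer.cc, the declaration in tracer.h) turns has_grad_ from a per-instance field into a static thread_local member, so each thread that traces ops sees its own flag regardless of which Tracer object it goes through. A minimal standalone sketch of that C++ pattern; the accessor names and the demo main are illustrative only, not Paddle code:

#include <iostream>
#include <thread>

class Tracer {
 public:
  static void SetHasGrad(bool v) { has_grad_ = v; }
  static bool HasGrad() { return has_grad_; }

 private:
  // Declared in the class, defined exactly once outside it; every thread
  // gets its own copy, initialized to true.
  static thread_local bool has_grad_;
};

thread_local bool Tracer::has_grad_ = true;

int main() {
  Tracer::SetHasGrad(false);  // only changes the main thread's copy
  std::thread worker([] {
    std::cout << "worker has_grad = " << Tracer::HasGrad() << std::endl;  // prints 1
  });
  worker.join();
  std::cout << "main has_grad = " << Tracer::HasGrad() << std::endl;  // prints 0
  return 0;
}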

paddle/fluid/inference/tensorrt/op_teller.cc

Lines changed: 3 additions & 2 deletions

@@ -703,8 +703,9 @@ bool OpTeller::Tell(const framework::ir::Node* node, bool use_no_calib_int8,
         return false;
       }
       // Paddle-TRT does not support the input tensors: Shape and ShapeTensor
-      if (desc.Input("Shape").size() >= 1 ||
-          desc.Input("ShapeTensor").size() >= 1) {
+      auto reshape_inputs = desc.Inputs();
+      if (reshape_inputs.find("Shape") != reshape_inputs.end() ||
+          reshape_inputs.find("ShapeTensor") != reshape_inputs.end()) {
         return false;
       }
       std::vector<int> shape =
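
The reshape check above stops calling desc.Input(...) for slots the op may not have and instead asks whether the slot name is present in desc.Inputs(). A small illustration of the difference, assuming the input table behaves like a std::map from slot name to argument names; the helper names here are hypothetical, not the Paddle API:

#include <iostream>
#include <map>
#include <string>
#include <vector>

// Stand-in for an op description's input table: slot name -> argument names.
using InputMap = std::map<std::string, std::vector<std::string>>;

// Key-existence check, as in the fix: a missing "Shape" slot is simply
// "not present" rather than a failed lookup.
bool HasShapeInput(const InputMap& inputs) {
  return inputs.find("Shape") != inputs.end() ||
         inputs.find("ShapeTensor") != inputs.end();
}

int main() {
  InputMap reshape_inputs{{"X", {"reshape_in"}}};  // built without a "Shape" slot
  // inputs.at("Shape") would throw std::out_of_range here; find() does not.
  std::cout << HasShapeInput(reshape_inputs) << std::endl;  // prints 0
  return 0;
}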

paddle/fluid/operators/activation_op_npu.cc

Lines changed: 44 additions & 0 deletions

@@ -527,6 +527,39 @@ class CosGradNPUKernel : public framework::OpKernel<T> {
   }
 };
 
+template <typename DeviceContext, typename T>
+class AtanNPUKernel : public framework::OpKernel<T> {
+ public:
+  void Compute(const framework::ExecutionContext& ctx) const override {
+    auto* x = ctx.Input<Tensor>("X");
+    auto* out = ctx.Output<Tensor>("Out");
+    auto place = ctx.GetPlace();
+    out->mutable_data<T>(place);
+    const auto& runner = NpuOpRunner("Atan", {*x}, {*out}, {});
+    auto stream =
+        ctx.template device_context<paddle::platform::NPUDeviceContext>()
+            .stream();
+    runner.Run(stream);
+  }
+};
+
+template <typename DeviceContext, typename T>
+class AtanGradNPUKernel : public framework::OpKernel<T> {
+ public:
+  void Compute(const framework::ExecutionContext& ctx) const override {
+    auto* dout = ctx.Input<Tensor>(framework::GradVarName("Out"));
+    auto* x = ctx.Input<Tensor>("X");
+    auto* dx = ctx.Output<Tensor>(framework::GradVarName("X"));
+    auto place = ctx.GetPlace();
+    dx->mutable_data<T>(place);
+    auto stream =
+        ctx.template device_context<paddle::platform::NPUDeviceContext>()
+            .stream();
+    const auto& runner_dx = NpuOpRunner("AtanGrad", {*x, *dout}, {*dx}, {});
+    runner_dx.Run(stream);
+  }
+};
+
 }  // namespace operators
 }  // namespace paddle
 

@@ -648,3 +681,14 @@ REGISTER_OP_NPU_KERNEL(
     cos_grad, ops::CosGradNPUKernel<paddle::platform::NPUDeviceContext, float>,
     ops::CosGradNPUKernel<paddle::platform::NPUDeviceContext,
                           paddle::platform::float16>);
+
+REGISTER_OP_NPU_KERNEL(
+    atan, ops::AtanNPUKernel<paddle::platform::NPUDeviceContext, float>,
+    ops::AtanNPUKernel<paddle::platform::NPUDeviceContext,
+                       paddle::platform::float16>);
+
+REGISTER_OP_NPU_KERNEL(
+    atan_grad,
+    ops::AtanGradNPUKernel<paddle::platform::NPUDeviceContext, float>,
+    ops::AtanGradNPUKernel<paddle::platform::NPUDeviceContext,
+                           paddle::platform::float16>);
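
For reference, the backward kernel registered above forwards x and dout to the NPU "AtanGrad" operator; mathematically the expected gradient is dout / (1 + x^2). A tiny standalone check of that derivative against a central difference (plain C++, independent of Paddle and the NPU op):

#include <cmath>
#include <cstdio>

int main() {
  // d/dx atan(x) = 1 / (1 + x^2); compare against a central difference.
  const double x = 0.7;
  const double eps = 1e-6;
  const double analytic = 1.0 / (1.0 + x * x);
  const double numeric = (std::atan(x + eps) - std::atan(x - eps)) / (2.0 * eps);
  std::printf("analytic = %.9f, numeric = %.9f\n", analytic, numeric);
  return 0;
}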

paddle/fluid/operators/collective/CMakeLists.txt

Lines changed: 2 additions & 0 deletions

@@ -59,6 +59,8 @@ if(WITH_ASCEND_CL)
             DEPS send_v2_op ${COLLECTIVE_DEPS} ${COMMON_TEST_DEPS_FOR_HCOM})
     cc_test(recv_v2_op_npu_test SRCS recv_v2_op_npu_test.cc
             DEPS recv_v2_op ${COLLECTIVE_DEPS} ${COMMON_TEST_DEPS_FOR_HCOM})
+    cc_test(checknumeric SRCS checknumeric_npu_test.cc
+            DEPS c_allreduce_sum_op ${COLLECTIVE_DEPS} ${COMMON_TEST_DEPS_FOR_HCOM})
     cc_test(c_sync_comm_stream_op_npu_test SRCS c_sync_comm_stream_op_npu_test.cc
             DEPS op_registry c_broadcast_op c_comm_init_hccl_op c_sync_comm_stream_op c_gen_hccl_id_op gen_hccl_id_op_helper ${COLLECTIVE_DEPS} ascend_hccl dynamic_loader dynload_warpctc scope device_context enforce executor)
     cc_test(c_sync_calc_stream_op_npu_test SRCS c_sync_calc_stream_op_npu_test.cc

paddle/fluid/operators/collective/c_allreduce_op.h

Lines changed: 36 additions & 25 deletions

@@ -121,35 +121,44 @@ class CAllReduceOpCPUKernel : public framework::OpKernel<T> {
 };
 
 #if defined(PADDLE_WITH_ASCEND_CL)
-// return true if found_inf_or_nan or return false;
-template <typename T>
-bool CheckNumerics(const framework::ExecutionContext& exe_ctx,
-                   aclrtStream stream, const paddle::framework::Tensor* in) {
-  auto& dev_ctx =
-      exe_ctx.template device_context<paddle::platform::NPUDeviceContext>();
+// return true if found_nan or return false;
+inline bool ContainsNan(const paddle::platform::NPUDeviceContext& dev_ctx,
+                        aclrtStream stream,
+                        const paddle::framework::Tensor* in) {
   using Tensor = paddle::framework::Tensor;
   Tensor out(in->type());
-  out.Resize(in->dims());
-  out.mutable_data<T>(dev_ctx.GetPlace());
 
-  bool found_inf_data = false;
+  Tensor mean(in->type());
+  mean.Resize({1});
+  mean.mutable_data<float>(dev_ctx.GetPlace());
+  std::vector<int> axes;
+  for (int i = 0; i < in->dims().size(); ++i) {
+    axes.push_back(i);
+  }
 
+  std::vector<float> vec;
   try {
-    const auto& runner =
-        NpuOpRunner("CheckNumerics", {*in}, {out},
-                    {{"message", std::string("check_numberics")}});
-    runner.Run(stream);
-    dev_ctx.Wait();
-  } catch (platform::EnforceNotMet& exception) {
-    LOG(WARNING) << "[check_nan_and_inf] detected contains NaN or INF!!!";
-    found_inf_data = true;
+    const auto& runner_mean = paddle::operators::NpuOpRunner(
+        "ReduceMeanD", {*in}, {mean}, {{"axes", axes}, {"keep_dims", false}});
+    TensorToVector(mean, dev_ctx, &vec);
   } catch (...) {
-    LOG(WARNING) << "[check_nan_and_inf] detected contains NaN or INF!!!";
-    found_inf_data = true;
+    LOG(WARNING) << "ContainsNan catch exception";
+    return true;
+  }
+
+  VLOG(4) << "reducemeand result:" << vec[0];
+  if (std::isnan(static_cast<float>(vec[0]))) {
+    LOG(WARNING) << "ContainsNan detects nan";
+    return true;
+  }
+
+  if (std::isinf(static_cast<float>(vec[0]))) {
+    LOG(WARNING) << "ContainsNan detects inf";
   }
 
-  return found_inf_data;
+  return false;
 }
+
 #endif
 
 template <ReduceType red_type, typename T>

@@ -216,22 +225,24 @@ class CAllReduceOpASCENDKernel : public framework::OpKernel<T> {
     framework::Tensor tmp;
     tmp.mutable_data<float>({8}, ctx.GetPlace());
 
-    bool check_numerics = false;
+    bool found_nan = false;
 
     auto d_type = in->type();
     switch (d_type) {
-      case framework::proto::VarType::FP16:
+      case framework::proto::VarType::FP16: {
+        break;
+      }
       case framework::proto::VarType::FP32: {
         VLOG(4) << "prepare to FoundNanInf";
-        check_numerics = CheckNumerics<T>(ctx, dev_ctx->stream(), in);
-        VLOG(4) << "check_numerics:" << check_numerics;
+        found_nan = ContainsNan(*dev_ctx, dev_ctx->stream(), in);
+        VLOG(4) << "check_numerics:" << found_nan;
         break;
       }
       default:
         break;
     }
 
-    if (check_numerics) {
+    if (found_nan) {
       T inf = static_cast<T>(std::numeric_limits<float>::infinity());
       VLOG(4) << "fill input data constant inf";
       auto dims = in->dims();
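
Instead of relying on the CheckNumerics NPU op throwing, the rewritten helper builds a ReduceMeanD runner over all axes, copies the reduced scalar to the host, and inspects it: a NaN anywhere in the input propagates into the mean and makes the function return true, while an Inf mean is only logged. A minimal host-side sketch of the same idea (plain C++, not the NPU code path):

#include <cmath>
#include <iostream>
#include <limits>
#include <numeric>
#include <vector>

// Mirrors the ContainsNan logic on the host: a NaN element poisons the mean,
// so checking one reduced scalar is enough to flag the whole tensor.
bool ContainsNanByMean(const std::vector<float>& data) {
  const double mean = std::accumulate(data.begin(), data.end(), 0.0) /
                      static_cast<double>(data.size());
  if (std::isnan(mean)) return true;
  if (std::isinf(mean)) {
    std::cerr << "mean is inf, input may contain inf" << std::endl;  // warn only
  }
  return false;
}

int main() {
  std::vector<float> clean(3 * 128, 1.0f);
  std::vector<float> poisoned = clean;
  poisoned[0] = std::numeric_limits<float>::quiet_NaN();
  std::cout << ContainsNanByMean(clean) << " "           // prints 0
            << ContainsNanByMean(poisoned) << std::endl;  // prints 1
  return 0;
}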

paddle/fluid/operators/collective/c_allreduce_sum_op_npu_test.cc

Lines changed: 23 additions & 14 deletions

@@ -38,6 +38,11 @@ limitations under the License. */
 #include "paddle/fluid/platform/hccl_helper.h"
 #endif
 
+// Node1: HCCL_WHITELIST_DISABLE=1 FLAGS_selected_npus=1 GLOG_v=4 RANK_ID=1
+// DEVICE_ID=1 ./paddle/fluid/operators/collective/c_allreduce_sum_op_npu_test
+// Node2: HCCL_WHITELIST_DISABLE=1 FLAGS_selected_npus=0 GLOG_v=4 RANK_ID=0
+// DEVICE_ID=0 ./paddle/fluid/operators/collective/c_allreduce_sum_op_npu_test
+
 namespace f = paddle::framework;
 namespace p = paddle::platform;
 namespace m = paddle::operators::math;

@@ -52,10 +57,11 @@ DECLARE_string(selected_npus);
 template <typename T>
 void PrintDebugInfo(const std::string preStr, const std::vector<T>& data) {
   std::string debugstring = "";
+  std::cout << preStr << ":" << std::endl << debugstring;
   for (auto ele : data) {
-    debugstring += std::to_string(ele) + std::string(",");
+    std::cout << ele << " ";
   }
-  VLOG(3) << preStr << ":" << std::endl << debugstring;
+  std::cout << std::endl;
 }
 
 void PrepareUniqueId(f::Scope* scope, const p::DeviceContext& ctx,

@@ -120,6 +126,7 @@ void Prepare(f::Scope* scope, const p::DeviceContext& ctx,
   ctx.Wait();
 }
 
+template <typename T>
 void TestHCCLAllReduceOp(f::Scope* scope, const p::DeviceContext& ctx,
                          int iter) {
   // init

@@ -130,10 +137,11 @@ void TestHCCLAllReduceOp(f::Scope* scope, const p::DeviceContext& ctx,
   int num1 = 3;
   int num2 = 128;
 
-  std::vector<float> init;
+  std::vector<T> init;
   for (int64_t i = 0; i < num1 * num2; ++i) {
-    init.push_back(1.0 + rank_id);
+    init.push_back(static_cast<T>(1.0 + rank_id));
   }
+  init[0] = static_cast<T>(std::numeric_limits<float>::quiet_NaN());
   PrintDebugInfo("input data", init);
 
   auto place = ctx.GetPlace();

@@ -145,31 +153,33 @@ void TestHCCLAllReduceOp(f::Scope* scope, const p::DeviceContext& ctx,
   auto out = scope->Var("OutData");
   auto tensor_out = out->GetMutable<f::LoDTensor>();
   tensor_out->Resize({num1, num2});
-  tensor_out->mutable_data<float>(place);  // allocate
+  tensor_out->mutable_data<T>(place);  // allocate
   ctx.Wait();
 
   // run
   f::AttributeMap attrs;
   attrs["tag"] = std::string("tagx_" + std::to_string(iter));
   attrs["ring_id"] = 0;
+  attrs["use_calc_stream"] = 1;
 
   auto op = f::OpRegistry::CreateOp("c_allreduce_sum", {{"X", {"Data"}}},
                                     {{"Out", {"OutData"}}}, attrs);
-
-  for (int i = 0; i < 10; i++) {
+  for (int i = 0; i < 1; i++) {
     op->Run(*scope, place);
   }
   ctx.Wait();
 
-  std::vector<float> out_vec;
+  std::vector<T> out_vec;
   TensorToVector(*tensor_out, ctx, &out_vec);
   ctx.Wait();
 
   PrintDebugInfo("output data", out_vec);
 
+  float diff = static_cast<float>(out_vec[0]) - 65504;
+  EXPECT_TRUE(diff < 0.1 && diff > -0.1);
   EXPECT_EQ(out_vec.size(), init.size());
-  for (uint32_t i = 0; i < out_vec.size(); i++) {
-    EXPECT_EQ(out_vec[i], 3.0);
+  for (uint32_t i = 1; i < 10; i++) {
+    EXPECT_EQ(out_vec[i], static_cast<paddle::platform::float16>(3.0));
   }
 }
 

@@ -182,8 +192,7 @@ TEST(c_allreduce_sum, NPU) {
   // only support one device, if more than one device, use first default
   PrepareUniqueId(&scope, ctx, &hccl_id);
   Prepare(&scope, ctx, &hccl_id);
-  for (int i = 0; i < 1; i++) {
-    VLOG(2) << "iter num: " << i;
-    TestHCCLAllReduceOp(&scope, ctx, i);
-  }
+
+  TestHCCLAllReduceOp<paddle::platform::float16>(&scope, ctx, 1);
+  // TestHCCLAllReduceOp<float>(&scope, ctx, 0);
 }
