Skip to content

Commit 7ab4a71

Browse files
authored
Optimize fuse_seqpool_cvm to support nncross, and add an nncross sync-stats configuration option (#52)
1 parent 9ca71a8 commit 7ab4a71

File tree

7 files changed

+141
-137
lines changed

7 files changed

+141
-137
lines changed

paddle/fluid/operators/batch_fc_op.cu

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@ using framework::Tensor;
2929
for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < (n); \
3030
i += blockDim.x * gridDim.x)
3131

32-
const int CUDA_NUM_THREADS = 1024;
32+
const int CUDA_NUM_THREADS = paddle::platform::PADDLE_CUDA_NUM_THREADS;
3333
static inline int GET_BLOCKS(const int N) {
3434
return (N + CUDA_NUM_THREADS - 1) / CUDA_NUM_THREADS;
3535
}

paddle/fluid/operators/cross_norm_hadamard.cu.h

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,8 @@ limitations under the License. */
1616
#include <memory.h>
1717
#include "cub/cub.cuh"
1818
#include "paddle/fluid/operators/math/math_function.h"
19+
#include "paddle/fluid/platform/cuda_primitives.h"
20+
#include "paddle/fluid/platform/gpu_info.h"
1921

2022
#define NORM_POS(idx, row, col) (((idx)*block_cols + (col)) * ins_num + (row))
2123
#define SCALE_MEAN_POS(idx, col) ((idx)*block_cols + (col))
@@ -29,7 +31,7 @@ limitations under the License. */
2931
for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < (n); \
3032
i += blockDim.x * gridDim.x)
3133

32-
const int CUDA_NUM_THREADS = 1024;
34+
const int CUDA_NUM_THREADS = paddle::platform::PADDLE_CUDA_NUM_THREADS;
3335
static inline int GET_BLOCKS(const int N) {
3436
return (N + CUDA_NUM_THREADS - 1) / CUDA_NUM_THREADS;
3537
}

paddle/fluid/operators/cross_norm_hadamard_op.cc

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -113,6 +113,8 @@ class CrossNormHadamardOpMaker : public framework::OpProtoAndCheckerMaker {
113113
PADDLE_ENFORCE(epsilon >= 0.0f && epsilon <= 0.001f,
114114
"'epsilon' should be between 0.0 and 0.001.");
115115
});
116+
AddAttr<bool>("sync_stats", "(bool, default false) only used in multi-GPU")
117+
.SetDefault(false);
116118
AddOutput("Out", "Output tensor of cross_norm_hadamard_op operator.");
117119
AddOutput("CudaMeans", "Output tensor of cross_norm_hadamard_op operator.");
118120
AddOutput("CudaScales",

paddle/fluid/operators/cross_norm_hadamard_op.cu

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -85,6 +85,7 @@ class CrossNormHadamardOpCUDAKernel : public framework::OpKernel<T> {
8585
auto embed_dim = ctx.Attr<int64_t>("embed_dim");
8686
const float epsilon = ctx.Attr<float>("epsilon");
8787
const float dr = ctx.Attr<float>("summary_decay_rate");
88+
const bool need_sync_stats = ctx.Attr<bool>("sync_stats");
8889

8990
auto* input_grad = ctx.Output<Tensor>(framework::GradVarName("Input"));
9091
auto* summary_grad =
@@ -173,7 +174,6 @@ class CrossNormHadamardOpCUDAKernel : public framework::OpKernel<T> {
173174
T* summary_input_data =
174175
ctx.Output<Tensor>("SummaryInput")->mutable_data<T>(ctx.GetPlace());
175176

176-
bool need_sync_stats = true;
177177
if (need_sync_stats) {
178178
#if defined(PADDLE_WITH_NCCL)
179179
auto comm = platform::NCCLCommContext::Instance().Get(0, ctx.GetPlace());

paddle/fluid/operators/fused/fused_seqpool_cvm_op.cc

Lines changed: 9 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@ See the License for the specific language governing permissions and
1313
limitations under the License. */
1414

1515
#include "paddle/fluid/operators/fused/fused_seqpool_cvm_op.h"
16-
16+
#include <string>
1717
namespace paddle {
1818
namespace operators {
1919

@@ -30,12 +30,12 @@ class FusedSeqpoolCVMOp : public framework::OperatorWithKernel {
3030
PADDLE_ENFORCE_EQ(
3131
cvm_dims.size(), 2UL,
3232
platform::errors::InvalidArgument("Input(CVM)'s rank should be 2."));
33-
PADDLE_ENFORCE_EQ(
34-
cvm_dims[1], 2UL,
35-
platform::errors::InvalidArgument("The 2nd dimension of "
36-
"Input(CVM) should be 2."));
33+
PADDLE_ENFORCE_EQ(cvm_dims[1], 2UL, platform::errors::InvalidArgument(
34+
"The 2nd dimension of "
35+
"Input(CVM) should be 2."));
3736

3837
auto ins_dims = ctx->GetInputsDim("X");
38+
const int cvm_offset = ctx->Attrs().Get<int>("cvm_offset");
3939
const size_t num_inputs = ins_dims.size();
4040
std::vector<framework::DDim> outs_dims;
4141
outs_dims.resize(num_inputs);
@@ -69,7 +69,7 @@ class FusedSeqpoolCVMOp : public framework::OperatorWithKernel {
6969
if (ctx->Attrs().Get<bool>("use_cvm")) {
7070
out_dim = {-1, dims[rank - 1]};
7171
} else {
72-
out_dim = {-1, dims[rank - 1] - 2};
72+
out_dim = {-1, dims[rank - 1] - cvm_offset};
7373
}
7474
outs_dims[i] = framework::make_ddim(out_dim);
7575
}
@@ -111,6 +111,7 @@ class FusedSeqpoolCVMOpMaker : public framework::OpProtoAndCheckerMaker {
111111
AddAttr<float>("show_coeff", "(float, default 0.2)").SetDefault(0.2);
112112
AddAttr<float>("clk_coeff", "(float, default 1)").SetDefault(1);
113113
AddAttr<float>("threshold", "(float, default 0.96)").SetDefault(0.96);
114+
AddAttr<int>("cvm_offset", "(int, default 2)").SetDefault(2);
114115

115116
AddComment(R"DOC(
116117
Fuse multiple pairs of Sequence Pool and CVM Operator.
@@ -127,6 +128,7 @@ class FusedSeqpoolCVMGradOp : public framework::OperatorWithKernel {
127128
auto og_dims = ctx->GetInputsDim(framework::GradVarName("Out"));
128129
auto x_dims = ctx->GetInputsDim("X");
129130
auto cvm_dims = ctx->GetInputDim("CVM");
131+
const int cvm_offset = ctx->Attrs().Get<int>("cvm_offset");
130132

131133
PADDLE_ENFORCE_EQ(
132134
cvm_dims.size(), 2,
@@ -151,7 +153,7 @@ class FusedSeqpoolCVMGradOp : public framework::OperatorWithKernel {
151153
} else {
152154
PADDLE_ENFORCE_EQ(
153155
og_dims[i][og_dims[i].size() - 1],
154-
x_dims[i][og_dims[i].size() - 1] - 2,
156+
x_dims[i][og_dims[i].size() - 1] - cvm_offset,
155157
platform::errors::InvalidArgument(
156158
"The dimension mismatch between Input(OUT@GRAD) and "
157159
"Input(X). Received Input(OUT@GRAD): input rank %u, "

0 commit comments

Comments
 (0)