Skip to content

Commit 5b1a5c1

Browse files
author
xutianbing
committed
Daoyuan's comments.
1 parent 999cd14 commit 5b1a5c1

File tree

5 files changed

+48
-100
lines changed

5 files changed

+48
-100
lines changed

paddle/function/FunctionTest.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -70,7 +70,7 @@ class FunctionCompare {
7070
}
7171

7272
// output only needs to contain the shape; it does not need to contain data.
73-
void addOutputs(const BufferArg& output, ArgType argType = ADD_TO) {
73+
void addOutputs(const BufferArg& output, ArgType argType = ASSIGN_TO) {
7474
size_t size =
7575
output.shape().getElements() * sizeOfValuType(output.valueType());
7676
cpuMemory_.emplace_back(std::make_shared<CpuMemoryHandle>(size));

paddle/function/MulOp.cpp

Lines changed: 21 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -49,8 +49,7 @@ void MulOp<DEVICE_TYPE_CPU>(CpuSparseMatrix& out,
4949
real scaleAB,
5050
real scaleT,
5151
bool aTrans,
52-
bool bTrans,
53-
bool cTrans) {
52+
bool bTrans) {
5453
CHECK_EQ(out.getValueType(), FLOAT_VALUE);
5554
if (scaleT == 0) {
5655
out.zeroMem();
@@ -114,8 +113,7 @@ void MulOp<DEVICE_TYPE_CPU>(CpuMatrix& out,
114113
real scaleAB,
115114
real scaleT,
116115
bool aTrans,
117-
bool bTrans,
118-
bool cTrans) {
116+
bool bTrans) {
119117
GEMM(aTrans ? CblasTrans : CblasNoTrans,
120118
bTrans ? CblasTrans : CblasNoTrans,
121119
out.getHeight(),
@@ -139,8 +137,7 @@ void MulOp<DEVICE_TYPE_CPU>(CpuMatrix& out,
139137
real scaleAB,
140138
real scaleT,
141139
bool aTrans,
142-
bool bTrans,
143-
bool cTrans) {
140+
bool bTrans) {
144141
if (scaleT == 0) {
145142
out.zeroMem();
146143
}
@@ -174,8 +171,7 @@ void MulOp<DEVICE_TYPE_CPU>(CpuMatrix& out,
174171
real scaleAB,
175172
real scaleT,
176173
bool aTrans,
177-
bool bTrans,
178-
bool cTrans) {
174+
bool bTrans) {
179175
if (scaleT == 0) {
180176
out.zeroMem();
181177
}
@@ -222,10 +218,10 @@ void MulOp<DEVICE_TYPE_CPU>(CpuMatrix& out,
222218

223219
/**
224220
* mul operator
225-
* out = scaleT * out + scaleAB * (in1 * in2)
221+
* out = scaleT * out + scaleAB * (A * B)
226222
* here, scaleT in {0, 1}, scaleAB == 1,
227-
* out = in1 (A) * in2 (B), ASSIGN_TO
228-
* out += in1 (A) * in2 (B), ADD_TO
223+
* out = A * B, ASSIGN_TO
224+
* out += A * B, ADD_TO
229225
*
230226
*
231227
* \param outputs[0] output matrix (out), M * N,
@@ -253,15 +249,11 @@ template <DeviceType Device>
253249
class MulFunc : public FunctionBase {
254250
public:
255251
void init(const FuncConfig& config) override {
256-
alpha_ = config.get<real>("scaleAB");
257-
beta_ = config.get<real>("scaleT");
258252
aTrans_ = config.get<bool>("aTrans");
259253
bTrans_ = config.get<bool>("bTrans");
260-
cTrans_ = config.get<bool>("cTrans");
261254
}
262255

263256
void calc(const BufferArgs& inputs, const BufferArgs& outputs) override {
264-
CHECK(!cTrans_) << "output matrix should not be transposed";
265257
CHECK(!aTrans_ || !bTrans_)
266258
<< "Not support both a and b are transpose matrices";
267259

@@ -281,10 +273,8 @@ class MulFunc : public FunctionBase {
281273
CHECK_EQ(aRow, outputs[0].shape()[0]);
282274
CHECK_EQ(bCol, outputs[0].shape()[1]);
283275

284-
/// only support C = A * B or C += A * B
285-
CHECK_EQ(alpha_, static_cast<real>(1.0));
286-
CHECK((beta_ == 0 && outputs[0].getArgType() == ASSIGN_TO) ||
287-
(beta_ == 1 && outputs[0].getArgType() == ADD_TO));
276+
/// only support C = A * B (ASSIGN_TO) or C += A * B (ADD_TO)
277+
real scaleT = (outputs[0].getArgType() == ADD_TO) ? 1.0 : 0.0;
288278

289279
/// support dense = not both sparse * sparse
290280
/// or sparse = dense * dense
@@ -300,11 +290,10 @@ class MulFunc : public FunctionBase {
300290
MulOp<Device>(outMat,
301291
inputs[0].matrix<Device>(),
302292
inputs[1].matrix<Device>(),
303-
alpha_,
304-
beta_,
293+
1.0, // scaleAB
294+
scaleT,
305295
aTrans_,
306-
bTrans_,
307-
cTrans_);
296+
bTrans_);
308297
return;
309298
}
310299

@@ -315,11 +304,10 @@ class MulFunc : public FunctionBase {
315304
MulOp<Device>(outMat,
316305
inputs[0].matrix<Device>(),
317306
inputs[1].sparse().SparseMatrix<Device>(),
318-
alpha_,
319-
beta_,
307+
1.0, // scaleAB
308+
scaleT,
320309
aTrans_,
321-
bTrans_,
322-
cTrans_);
310+
bTrans_);
323311
return;
324312
}
325313

@@ -332,11 +320,10 @@ class MulFunc : public FunctionBase {
332320
MulOp<Device>(outMat,
333321
inputs[0].sparse().SparseMatrix<Device>(),
334322
inputs[1].matrix<Device>(),
335-
alpha_,
336-
beta_,
323+
1.0, // scaleAB
324+
scaleT,
337325
aTrans_,
338-
bTrans_,
339-
cTrans_);
326+
bTrans_);
340327
return;
341328
}
342329

@@ -347,21 +334,17 @@ class MulFunc : public FunctionBase {
347334
MulOp<Device>(outSparseMat,
348335
inputs[0].matrix<Device>(),
349336
inputs[1].matrix<Device>(),
350-
alpha_,
351-
beta_,
337+
1.0, // scaleAB
338+
scaleT,
352339
aTrans_,
353-
bTrans_,
354-
cTrans_);
340+
bTrans_);
355341
return;
356342
}
357343
}
358344

359345
private:
360-
real alpha_;
361-
real beta_;
362346
bool aTrans_;
363347
bool bTrans_;
364-
bool cTrans_;
365348
};
366349

367350
REGISTER_TYPED_FUNC(MulOp, CPU, MulFunc);

paddle/function/MulOp.h

Lines changed: 9 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -27,8 +27,7 @@ void MulOp(CpuMatrix& out,
2727
real scaleAB,
2828
real scaleT,
2929
bool aTrans,
30-
bool bTrans,
31-
bool cTrans);
30+
bool bTrans);
3231

3332
/// CPU, dense matrix (+)= sparse matrix * dense matrix
3433
template <DeviceType DType>
@@ -38,8 +37,7 @@ void MulOp(CpuMatrix& out,
3837
real scaleAB,
3938
real scaleT,
4039
bool aTrans,
41-
bool bTrans,
42-
bool cTrans);
40+
bool bTrans);
4341

4442
/// CPU, dense matrix (+)= dense matrix * sparse matrix
4543
template <DeviceType DType>
@@ -49,8 +47,7 @@ void MulOp(CpuMatrix& out,
4947
real scaleAB,
5048
real scaleT,
5149
bool aTrans,
52-
bool bTrans,
53-
bool cTrans);
50+
bool bTrans);
5451

5552
/// CPU, sparse matrix (+)= dense matrix * dense matrix
5653
template <DeviceType DType>
@@ -60,8 +57,7 @@ void MulOp(CpuSparseMatrix& out,
6057
real scaleAB,
6158
real scaleT,
6259
bool aTrans,
63-
bool bTrans,
64-
bool cTrans);
60+
bool bTrans);
6561

6662
/// GPU, dense matrix (+)= dense matrix * dense matrix
6763
template <DeviceType DType>
@@ -71,8 +67,7 @@ void MulOp(GpuMatrix& out,
7167
real scaleAB,
7268
real scaleT,
7369
bool aTrans,
74-
bool bTrans,
75-
bool cTrans);
70+
bool bTrans);
7671

7772
/// GPU, dense matrix (+)= sparse matrix * dense matrix
7873
template <DeviceType DType>
@@ -82,8 +77,7 @@ void MulOp(GpuMatrix& out,
8277
real scaleAB,
8378
real scaleT,
8479
bool aTrans,
85-
bool bTrans,
86-
bool cTrans);
80+
bool bTrans);
8781

8882
/// GPU, dense matrix (+)= dense matrix * sparse matrix
8983
template <DeviceType DType>
@@ -93,8 +87,8 @@ void MulOp(GpuMatrix& out,
9387
real scaleAB,
9488
real scaleT,
9589
bool aTrans,
96-
bool bTrans,
97-
bool cTrans);
90+
bool bTrans);
91+
9892
/// GPU, sparse matrix (+)= dense matrix * dense matrix
9993
template <DeviceType DType>
10094
void MulOp(GpuSparseMatrix& out,
@@ -103,7 +97,6 @@ void MulOp(GpuSparseMatrix& out,
10397
real scaleAB,
10498
real scaleT,
10599
bool aTrans,
106-
bool bTrans,
107-
bool cTrans);
100+
bool bTrans);
108101

109102
} // namespace paddle

paddle/function/MulOpGpu.cu

Lines changed: 4 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -26,8 +26,7 @@ void MulOp<DEVICE_TYPE_GPU>(GpuMatrix& out,
2626
real scaleAB,
2727
real scaleT,
2828
bool aTrans,
29-
bool bTrans,
30-
bool cTrans) {
29+
bool bTrans) {
3130
CHECK(a.useGpu_ && b.useGpu_) << "matrix device type not match";
3231
hl_matrix_mul(const_cast<real*>(a.getData()),
3332
!aTrans ? HPPL_OP_N : HPPL_OP_T,
@@ -52,8 +51,7 @@ void MulOp<DEVICE_TYPE_GPU>(GpuMatrix& out,
5251
real scaleAB,
5352
real scaleT,
5453
bool aTrans,
55-
bool bTrans,
56-
bool cTrans) {
54+
bool bTrans) {
5755
CHECK(out.isContiguous());
5856
CHECK(b.isContiguous());
5957
CHECK(a.useGpu_ && b.useGpu_) << "matrix device type not match";
@@ -77,8 +75,7 @@ void MulOp<DEVICE_TYPE_GPU>(GpuMatrix& out,
7775
real scaleAB,
7876
real scaleT,
7977
bool aTrans,
80-
bool bTrans,
81-
bool cTrans) {
78+
bool bTrans) {
8279
CHECK(out.isContiguous());
8380
CHECK(a.isContiguous());
8481
CHECK(a.useGpu_ && b.useGpu_) << "matrix device type not match";
@@ -116,8 +113,7 @@ void MulOp<DEVICE_TYPE_GPU>(GpuSparseMatrix& out,
116113
real scaleAB,
117114
real scaleT,
118115
bool aTrans,
119-
bool bTrans,
120-
bool cTrans) {
116+
bool bTrans) {
121117
CHECK(a.useGpu_ && b.useGpu_) << "matrix device type not match";
122118
hl_sparse_matrix_mul(const_cast<real*>(a.getData()),
123119
aTrans ? HPPL_OP_T : HPPL_OP_N,

0 commit comments

Comments
 (0)