@@ -49,8 +49,7 @@ void MulOp<DEVICE_TYPE_CPU>(CpuSparseMatrix& out,
4949 real scaleAB,
5050 real scaleT,
5151 bool aTrans,
52- bool bTrans,
53- bool cTrans) {
52+ bool bTrans) {
5453 CHECK_EQ (out.getValueType (), FLOAT_VALUE);
5554 if (scaleT == 0 ) {
5655 out.zeroMem ();
@@ -114,8 +113,7 @@ void MulOp<DEVICE_TYPE_CPU>(CpuMatrix& out,
114113 real scaleAB,
115114 real scaleT,
116115 bool aTrans,
117- bool bTrans,
118- bool cTrans) {
116+ bool bTrans) {
119117 GEMM (aTrans ? CblasTrans : CblasNoTrans,
120118 bTrans ? CblasTrans : CblasNoTrans,
121119 out.getHeight (),
@@ -139,8 +137,7 @@ void MulOp<DEVICE_TYPE_CPU>(CpuMatrix& out,
139137 real scaleAB,
140138 real scaleT,
141139 bool aTrans,
142- bool bTrans,
143- bool cTrans) {
140+ bool bTrans) {
144141 if (scaleT == 0 ) {
145142 out.zeroMem ();
146143 }
@@ -174,8 +171,7 @@ void MulOp<DEVICE_TYPE_CPU>(CpuMatrix& out,
174171 real scaleAB,
175172 real scaleT,
176173 bool aTrans,
177- bool bTrans,
178- bool cTrans) {
174+ bool bTrans) {
179175 if (scaleT == 0 ) {
180176 out.zeroMem ();
181177 }
@@ -222,10 +218,10 @@ void MulOp<DEVICE_TYPE_CPU>(CpuMatrix& out,
222218
223219/**
224220 * mul operator
225- * out = scaleT * out + scaleAB * (in1 * in2 )
221+ * out = scaleT * out + scaleAB * (A * B )
226222 * here, scaleT in {0, 1}, scaleAB == 1,
227- * out = in1 (A) * in2 (B) , ASSIGN_TO
228- * out += in1 (A) * in2 (B) , ADD_TO
223+ * out = A * B , ASSIGN_TO
224+ * out += A * B , ADD_TO
229225 *
230226 *
231227 * \param outputs[0] output matrix (out), M * N,
@@ -253,15 +249,11 @@ template <DeviceType Device>
253249class MulFunc : public FunctionBase {
254250public:
255251 void init (const FuncConfig& config) override {
256- alpha_ = config.get <real>(" scaleAB" );
257- beta_ = config.get <real>(" scaleT" );
258252 aTrans_ = config.get <bool >(" aTrans" );
259253 bTrans_ = config.get <bool >(" bTrans" );
260- cTrans_ = config.get <bool >(" cTrans" );
261254 }
262255
263256 void calc (const BufferArgs& inputs, const BufferArgs& outputs) override {
264- CHECK (!cTrans_) << " output matrix should not be transposed" ;
265257 CHECK (!aTrans_ || !bTrans_)
266258 << " Not support both a and b are transpose matrices" ;
267259
@@ -281,10 +273,8 @@ class MulFunc : public FunctionBase {
281273 CHECK_EQ (aRow, outputs[0 ].shape ()[0 ]);
282274 CHECK_EQ (bCol, outputs[0 ].shape ()[1 ]);
283275
284- // / only support C = A * B or C += A * B
285- CHECK_EQ (alpha_, static_cast <real>(1.0 ));
286- CHECK ((beta_ == 0 && outputs[0 ].getArgType () == ASSIGN_TO) ||
287- (beta_ == 1 && outputs[0 ].getArgType () == ADD_TO));
276+ /// Only C = A * B (ASSIGN_TO) and C += A * B (ADD_TO) are supported.
277+ real scaleT = (outputs[0 ].getArgType () == ADD_TO) ? 1.0 : 0.0 ;
288278
289279 /// Supported: dense = A * B, where A and B are not both sparse,
290280 /// or sparse = dense * dense.
@@ -300,11 +290,10 @@ class MulFunc : public FunctionBase {
300290 MulOp<Device>(outMat,
301291 inputs[0 ].matrix <Device>(),
302292 inputs[1 ].matrix <Device>(),
303- alpha_,
304- beta_ ,
293+ 1.0 , // scaleAB
294+ scaleT ,
305295 aTrans_,
306- bTrans_,
307- cTrans_);
296+ bTrans_);
308297 return ;
309298 }
310299
@@ -315,11 +304,10 @@ class MulFunc : public FunctionBase {
315304 MulOp<Device>(outMat,
316305 inputs[0 ].matrix <Device>(),
317306 inputs[1 ].sparse ().SparseMatrix <Device>(),
318- alpha_,
319- beta_ ,
307+ 1.0 , // scaleAB
308+ scaleT ,
320309 aTrans_,
321- bTrans_,
322- cTrans_);
310+ bTrans_);
323311 return ;
324312 }
325313
@@ -332,11 +320,10 @@ class MulFunc : public FunctionBase {
332320 MulOp<Device>(outMat,
333321 inputs[0 ].sparse ().SparseMatrix <Device>(),
334322 inputs[1 ].matrix <Device>(),
335- alpha_,
336- beta_ ,
323+ 1.0 , // scaleAB
324+ scaleT ,
337325 aTrans_,
338- bTrans_,
339- cTrans_);
326+ bTrans_);
340327 return ;
341328 }
342329
@@ -347,21 +334,17 @@ class MulFunc : public FunctionBase {
347334 MulOp<Device>(outSparseMat,
348335 inputs[0 ].matrix <Device>(),
349336 inputs[1 ].matrix <Device>(),
350- alpha_,
351- beta_ ,
337+ 1.0 , // scaleAB
338+ scaleT ,
352339 aTrans_,
353- bTrans_,
354- cTrans_);
340+ bTrans_);
355341 return ;
356342 }
357343 }
358344
359345private:
360- real alpha_;
361- real beta_;
362346 bool aTrans_;
363347 bool bTrans_;
364- bool cTrans_;
365348};
366349
367350REGISTER_TYPED_FUNC (MulOp, CPU, MulFunc);
0 commit comments