@@ -31,34 +31,38 @@ template <typename Place, typename T>
3131class MulKernel : public framework ::OpKernel {
3232 public:
3333 void Compute (const framework::ExecutionContext& context) const override {
34- auto * X = context.Input <Tensor>(" X" );
35- auto * Y = context.Input <Tensor>(" Y" );
36- auto * Z = context.Output <Tensor>(" Out" );
37- Z ->mutable_data <T>(context.GetPlace ());
34+ auto * x = context.Input <Tensor>(" X" );
35+ auto * y = context.Input <Tensor>(" Y" );
36+ auto * z = context.Output <Tensor>(" Out" );
37+ z ->mutable_data <T>(context.GetPlace ());
3838 auto * device_context =
3939 const_cast <platform::DeviceContext*>(context.device_context_ );
40- math::matmul<Place, T>(*X , false , *Y , false , 1 , Z , 0 , device_context);
40+ math::matmul<Place, T>(*x , false , *y , false , 1 , z , 0 , device_context);
4141 }
4242};
4343
4444template <typename Place, typename T>
4545class MulGradKernel : public framework ::OpKernel {
4646 public:
4747 void Compute (const framework::ExecutionContext& ctx) const override {
48- auto * X = ctx.Input <Tensor>(" X" );
49- auto * Y = ctx.Input <Tensor>(" Y" );
50- auto * dOut = ctx.Input <Tensor>(framework::GradVarName (" Out" ));
48+ auto * x = ctx.Input <Tensor>(" X" );
49+ auto * y = ctx.Input <Tensor>(" Y" );
50+ auto * dout = ctx.Input <Tensor>(framework::GradVarName (" Out" ));
5151
52- auto * dX = ctx.Output <Tensor>(framework::GradVarName (" X" ));
53- auto * dY = ctx.Output <Tensor>(framework::GradVarName (" Y" ));
54- dX->mutable_data <T>(ctx.GetPlace ());
55- dY->mutable_data <T>(ctx.GetPlace ());
52+ auto * dx = ctx.Output <Tensor>(framework::GradVarName (" X" ));
53+ auto * dy = ctx.Output <Tensor>(framework::GradVarName (" Y" ));
5654 auto * device_context =
5755 const_cast <platform::DeviceContext*>(ctx.device_context_ );
58- // dX = dOut * Y'. dX: M x K, dOut : M x N, Y : K x N
59- math::matmul<Place, T>(*dOut, false , *Y, true , 1 , dX, 0 , device_context);
60- // dY = X' * dOut. dY: K x N, dOut : M x N, X : M x K
61- math::matmul<Place, T>(*X, true , *dOut, false , 1 , dY, 0 , device_context);
56+ if (dx) {
57+ dx->mutable_data <T>(ctx.GetPlace ());
58+ // dx = dout * y'. dx: M x K, dout : M x N, y : K x N
59+ math::matmul<Place, T>(*dout, false , *y, true , 1 , dx, 0 , device_context);
60+ }
61+ if (dy) {
62+ dy->mutable_data <T>(ctx.GetPlace ());
63+ // dy = x' * dout. dy K x N, dout : M x N, x : M x K
64+ math::matmul<Place, T>(*x, true , *dout, false , 1 , dy, 0 , device_context);
65+ }
6266 }
6367};
6468
0 commit comments