@@ -15,7 +15,9 @@ limitations under the License. */
 #include <memory>
 #include <string>
 
+#include "paddle/fluid/framework/tensor_util.h"
 #include "paddle/fluid/operators/elementwise/elementwise_min_op.h"
+#include "paddle/fluid/operators/elementwise/elementwise_npu.h"
 #include "paddle/fluid/operators/npu_op_runner.h"
 
 namespace paddle {
@@ -27,31 +29,199 @@ template <typename DeviceContext, typename T>
 class ElementwiseMinNPUKernel : public framework::OpKernel<T> {
  public:
   void Compute(const framework::ExecutionContext& ctx) const override {
+    auto& dev_ctx =
+        ctx.template device_context<paddle::platform::NPUDeviceContext>();
     auto* x = ctx.Input<Tensor>("X");
     auto* y = ctx.Input<Tensor>("Y");
 
     auto* out = ctx.Output<Tensor>("Out");
-
     auto place = ctx.GetPlace();
 
     out->mutable_data<T>(place);
 
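+    // With axis == -1 the trailing dimensions of x and y are aligned, so the
+    // broadcast offset is the rank difference between the two inputs.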
+    int axis = ctx.Attr<int>("axis");
+    bool direct_compute = false;
+    auto x_dims = x->dims();
+    auto y_dims = y->dims();
+    axis = (axis == -1 ? std::abs(x_dims.size() - y_dims.size()) : axis);
+    if (x_dims.size() >= y_dims.size()) {
+      direct_compute =
+          y_dims == framework::slice_ddim(x_dims, axis, x_dims.size());
+    } else {
+      direct_compute =
+          x_dims == framework::slice_ddim(y_dims, axis, y_dims.size());
+    }
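+    // Fast path: the smaller input's shape already matches a trailing slice
+    // of the larger one starting at `axis`, so the NPU op can broadcast it
+    // directly.  Otherwise materialize both inputs in the common shape first.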
+    Tensor transformed_x, transformed_y;
+    if (direct_compute) {
+      transformed_x.ShareDataWith(*x);
+      transformed_y.ShareDataWith(*y);
+    } else {
+      NpuElementWiseOpBroadcast<T>(dev_ctx, x, y, axis, &transformed_x,
+                                   &transformed_y);
+    }
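+    // Elementwise minimum of the (possibly broadcast) operands via the CANN
+    // "Minimum" operator.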
+    const auto& runner =
+        NpuOpRunner("Minimum", {transformed_x, transformed_y}, {*out}, {});
     auto stream =
         ctx.template device_context<paddle::platform::NPUDeviceContext>()
             .stream();
-
-    const auto& runner = NpuOpRunner("Minimum", {*x, *y}, {*out}, {});
     runner.Run(stream);
   }
 };
 
+template <typename DeviceContext, typename T>
+class ElementwiseMinGradNPUKernel : public framework::OpKernel<T> {
+ public:
+  void Compute(const framework::ExecutionContext& ctx) const override {
+    auto& dev_ctx =
+        ctx.template device_context<paddle::platform::NPUDeviceContext>();
+    auto* x = ctx.Input<Tensor>("X");
+    auto* y = ctx.Input<Tensor>("Y");
+    auto* dout = ctx.Input<Tensor>(framework::GradVarName("Out"));
+    auto* dx = ctx.Output<Tensor>(framework::GradVarName("X"));
+    auto* dy = ctx.Output<Tensor>(framework::GradVarName("Y"));
+    int axis = ctx.Attr<int>("axis");
+    axis = (axis == -1 ? std::abs(x->dims().size() - y->dims().size()) : axis);
+    auto stream = dev_ctx.stream();
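+    // "MinimumGrad" produces gradients for both inputs at once, shaped like
+    // dout.  Handle the three cases (both grads wanted, only dx, only dy);
+    // a scratch tensor absorbs any gradient the caller did not request.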
+    if (dx && dy) {
+      // dx
+      dx->mutable_data<T>(ctx.GetPlace());
+      Tensor tmp_x;
+      tmp_x.ShareDataWith(*dx);
+      if (dx->dims() != dout->dims()) {
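+        // dx was broadcast along some axes to produce dout.  Collect those
+        // axes: either they lie outside dx's rank entirely, or dx has extent
+        // 1 where dout is larger.  tmp_x shares dx's memory but is viewed
+        // with the reduced shape so the op writes the reduced gradient.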
+        std::vector<int> dst_dims_vec_x;
+        std::vector<int> reduce_axes_x;
+        auto src_dims_x = dx->dims();
+        auto dout_dims = dout->dims();
+
+        int src_axis_x = (src_dims_x.size() < dout_dims.size() ? axis : 0);
+        for (int ax = 0; ax < dout_dims.size(); ++ax) {
+          if ((ax < src_axis_x || ax >= src_axis_x + src_dims_x.size()) ||
+              (dout_dims[ax] > 1 && src_dims_x[ax - src_axis_x] == 1)) {
+            reduce_axes_x.push_back(ax);
+          } else {
+            dst_dims_vec_x.push_back(dout_dims[ax]);
+          }
+        }
+        if (!reduce_axes_x.empty()) {
+          tmp_x.Resize(framework::make_ddim(dst_dims_vec_x));
+        }
+      }
+      // dy
+      dy->mutable_data<T>(ctx.GetPlace());
+      Tensor tmp_y;
+      tmp_y.ShareDataWith(*dy);
+      if (dy->dims() != dout->dims()) {
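+        // Same bookkeeping as for dx: find the axes along which dy was
+        // broadcast and view tmp_y with the reduced shape.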
+        std::vector<int> dst_dims_vec_y;
+        std::vector<int> reduce_axes_y;
+        auto src_dims_y = dy->dims();
+        auto dout_dims = dout->dims();
+
+        int src_axis_y = (src_dims_y.size() < dout_dims.size() ? axis : 0);
+        for (int ax = 0; ax < dout_dims.size(); ++ax) {
+          if ((ax < src_axis_y || ax >= src_axis_y + src_dims_y.size()) ||
+              (dout_dims[ax] > 1 && src_dims_y[ax - src_axis_y] == 1)) {
+            reduce_axes_y.push_back(ax);
+          } else {
+            dst_dims_vec_y.push_back(dout_dims[ax]);
+          }
+        }
+        if (!reduce_axes_y.empty()) {
+          tmp_y.Resize(framework::make_ddim(dst_dims_vec_y));
+        }
+      }
+
+      const auto& runner =
+          NpuOpRunner("MinimumGrad", {*dout, *x, *y}, {tmp_x, tmp_y},
+                      {{"grad_x", true}, {"grad_y", true}});
+      runner.Run(stream);
+
+    } else if (dx) {
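+      // Only dx is requested.  MinimumGrad still computes a gradient for y,
+      // so hand it a zero-filled scratch tensor shaped like y to discard.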
+      Tensor zero_tensor(dout->type());
+      zero_tensor.mutable_data<T>(y->dims(), ctx.GetPlace());
+      FillNpuTensorWithConstant<T>(&zero_tensor, static_cast<T>(0));
+      // dx
+      dx->mutable_data<T>(ctx.GetPlace());
+      Tensor tmp_x;
+      tmp_x.ShareDataWith(*dx);
+      if (dx->dims() != dout->dims()) {
+        std::vector<int> dst_dims_vec_x;
+        std::vector<int> reduce_axes_x;
+        auto src_dims_x = dx->dims();
+        auto dout_dims = dout->dims();
+
+        int src_axis_x = (src_dims_x.size() < dout_dims.size() ? axis : 0);
+        for (int ax = 0; ax < dout_dims.size(); ++ax) {
+          if ((ax < src_axis_x || ax >= src_axis_x + src_dims_x.size()) ||
+              (dout_dims[ax] > 1 && src_dims_x[ax - src_axis_x] == 1)) {
+            reduce_axes_x.push_back(ax);
+          } else {
+            dst_dims_vec_x.push_back(dout_dims[ax]);
+          }
+        }
+        if (!reduce_axes_x.empty()) {
+          tmp_x.Resize(framework::make_ddim(dst_dims_vec_x));
+        }
+      }
+
+      const auto& runner =
+          NpuOpRunner("MinimumGrad", {*dout, *x, *y}, {tmp_x, zero_tensor},
+                      {{"grad_x", true}, {"grad_y", true}});
+      runner.Run(stream);
+
+    } else if (dy) {
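+      // Only dy is requested; a zero-filled scratch tensor shaped like x
+      // receives (and discards) the unused x gradient.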
+      Tensor zero_tensor(dout->type());
+      zero_tensor.mutable_data<T>(x->dims(), ctx.GetPlace());
+      FillNpuTensorWithConstant<T>(&zero_tensor, static_cast<T>(0));
+
+      // dy
+      dy->mutable_data<T>(ctx.GetPlace());
+      Tensor tmp_y;
+      tmp_y.ShareDataWith(*dy);
+      if (dy->dims() != dout->dims()) {
+        std::vector<int> dst_dims_vec_y;
+        std::vector<int> reduce_axes_y;
+        auto src_dims_y = dy->dims();
+        auto dout_dims = dout->dims();
+
+        int src_axis_y = (src_dims_y.size() < dout_dims.size() ? axis : 0);
+        for (int ax = 0; ax < dout_dims.size(); ++ax) {
+          if ((ax < src_axis_y || ax >= src_axis_y + src_dims_y.size()) ||
+              (dout_dims[ax] > 1 && src_dims_y[ax - src_axis_y] == 1)) {
+            reduce_axes_y.push_back(ax);
+          } else {
+            dst_dims_vec_y.push_back(dout_dims[ax]);
+          }
+        }
+        if (!reduce_axes_y.empty()) {
+          tmp_y.Resize(framework::make_ddim(dst_dims_vec_y));
+        }
+      }
+
+      const auto& runner =
+          NpuOpRunner("MinimumGrad", {*dout, *x, *y}, {zero_tensor, tmp_y},
+                      {{"grad_x", true}, {"grad_y", true}});
+      runner.Run(stream);
+
+    } else {
+      PADDLE_THROW(platform::errors::Unavailable(
+          "Do not support all outputs to be empty."));
+    }
+  }
+};
+
 }  // namespace operators
 }  // namespace paddle
 
 namespace ops = paddle::operators;
+namespace plat = paddle::platform;
 
 REGISTER_OP_NPU_KERNEL(
     elementwise_min,
     ops::ElementwiseMinNPUKernel<paddle::platform::NPUDeviceContext, float>,
     ops::ElementwiseMinNPUKernel<paddle::platform::NPUDeviceContext,
                                  paddle::platform::float16>);
+
+REGISTER_OP_NPU_KERNEL(
+    elementwise_min_grad,
+    ops::ElementwiseMinGradNPUKernel<paddle::platform::NPUDeviceContext,
+                                     float>,
+    ops::ElementwiseMinGradNPUKernel<paddle::platform::NPUDeviceContext,
+                                     paddle::platform::float16>);