@@ -289,10 +289,20 @@ struct DeviceIndependenceTensorOperations {
   framework::Tensor Div(const framework::Tensor& x,
                         const framework::Tensor& y) {
     framework::Tensor ret;
-    std::vector<int> out_shape = GetBroadcastShape({&x, &y});
-    ret.Resize(framework::make_ddim(out_shape));
-    ElementwiseComputeEx<DivFunctor<T>, DeviceContext, T>(
-        context, &x, &y, -1, DivFunctor<T>(), &ret);
+    if (x.type() != y.type()) {
+      ret.mutable_data<T>(x.dims(), context.GetPlace());
+      auto x_vector = EigenVector<T>::Flatten(x);
+      auto y_vector = EigenVector<ValueType>::Flatten(y);
+      auto out_vector = EigenVector<T>::Flatten(ret);
+      auto& place =
+          *context.template device_context<DeviceContext>().eigen_device();
+      out_vector.device(place) = x_vector / y_vector;
+    } else {
+      std::vector<int> out_shape = GetBroadcastShape({&x, &y});
+      ret.Resize(framework::make_ddim(out_shape));
+      ElementwiseComputeEx<DivFunctor<T>, DeviceContext, T>(
+          context, &x, &y, -1, DivFunctor<T>(), &ret);
+    }
     return ret;
   }
   framework::Tensor Add(const framework::Tensor& x,
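
The new first branch of Div handles operands whose element types differ (for example a complex-valued x divided by a real-valued y): both tensors are flattened and divided element-wise on the Eigen device, so it evidently assumes x and y hold the same number of elements and does no broadcasting. Below is a minimal standalone sketch of that pattern using plain Eigen rather than Paddle's EigenVector wrapper, assuming T is std::complex<float> and ValueType is float; the explicit cast stands in for the mixed-type arithmetic the patched code relies on.

```cpp
#include <complex>
#include <iostream>
#include <vector>

#include <unsupported/Eigen/CXX11/Tensor>

int main() {
  using C = std::complex<float>;
  std::vector<C> x = {{1.f, 2.f}, {3.f, 4.f}, {5.f, 6.f}};
  std::vector<float> y = {2.f, 4.f, 8.f};
  std::vector<C> out(x.size());

  // Flatten the buffers into rank-1 Eigen tensor maps, mirroring
  // EigenVector<T>::Flatten in the patched Div.
  Eigen::TensorMap<Eigen::Tensor<C, 1>> x_vector(
      x.data(), static_cast<Eigen::Index>(x.size()));
  Eigen::TensorMap<Eigen::Tensor<float, 1>> y_vector(
      y.data(), static_cast<Eigen::Index>(y.size()));
  Eigen::TensorMap<Eigen::Tensor<C, 1>> out_vector(
      out.data(), static_cast<Eigen::Index>(out.size()));

  // Element-wise mixed-type division; cast the real operand to complex so
  // the expression is well-formed regardless of Eigen version.
  out_vector = x_vector / y_vector.cast<C>();

  for (const auto& v : out) std::cout << v << ' ';
  std::cout << '\n';  // prints (0.5,1) (0.75,1) (0.625,0.75)
  return 0;
}
```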
@@ -330,7 +340,8 @@ struct DeviceIndependenceTensorOperations {
     NameInTensorMap inputs({{"X", {&x}}});
     return CreateOpRunAndReturnTensor("reduce_max", inputs, attrs, out_dim);
   }
-
+  // Support float and complex type subtraction; the default element type is T
+  template <typename InT = T>
   framework::Tensor Sub(const framework::Tensor& x,
                         const framework::Tensor& y) {
     framework::Tensor ret;
@@ -340,18 +351,18 @@ struct DeviceIndependenceTensorOperations {
 #if defined(__NVCC__) || defined(__HIPCC__)
       // For GPU, there is no need to define XxxInverseFunctor and call
       // ElementwiseComputeEx in two branches.
-      ElementwiseComputeEx<SubFunctor<T>, DeviceContext, T>(
-          context, &x, &y, -1, SubFunctor<T>(), &ret);
+      ElementwiseComputeEx<SubFunctor<InT>, DeviceContext, InT>(
+          context, &x, &y, -1, SubFunctor<InT>(), &ret);
 #endif
     } else {
       if (x.dims().size() >= y.dims().size()) {
-        ElementwiseComputeEx<SubFunctor<T>, DeviceContext, T>(
-            context, &x, &y, -1, SubFunctor<T>(), &ret);
+        ElementwiseComputeEx<SubFunctor<InT>, DeviceContext, InT>(
+            context, &x, &y, -1, SubFunctor<InT>(), &ret);
       } else {
-        ElementwiseComputeEx<InverseSubFunctor<T>, DeviceContext, T>(
-            // This is copyed from elementwise_sub, which means we
-            // need reverse will xrank < yrank
-            context, &x, &y, -1, InverseSubFunctor<T>(), &ret);
+        // This is copied from elementwise_sub, which means we need to
+        // reverse the operands when x's rank < y's rank.
+        ElementwiseComputeEx<InverseSubFunctor<InT>, DeviceContext, InT>(
+            context, &x, &y, -1, InverseSubFunctor<InT>(), &ret);
       }
     }
     return ret;
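
Sub is now templated on the element type, with InT defaulting to T, so existing call sites compile unchanged while callers subtracting real-valued tensors can instantiate Sub<ValueType> explicitly; this is what lets the separate Sub_ helper be deleted in the hunk below. Here is a small self-contained sketch of the idiom, with hypothetical names (Ops, the stand-in Tensor) rather than Paddle's types, assuming T is a complex type and ValueType its real counterpart.

```cpp
#include <complex>
#include <cstdio>
#include <typeinfo>

// Stand-in for framework::Tensor: the element type is erased, so it cannot
// be deduced from the function arguments.
struct Tensor {};

template <typename T>
struct Ops {
  // InT defaults to T, so existing callers stay source-compatible; callers
  // working on real-valued tensors instantiate Sub<ValueType> explicitly.
  template <typename InT = T>
  Tensor Sub(const Tensor& x, const Tensor& y) const {
    (void)x;
    (void)y;
    // The real code would dispatch ElementwiseComputeEx<SubFunctor<InT>, ...>.
    std::printf("Sub instantiated with InT = %s\n", typeid(InT).name());
    return Tensor{};
  }
};

int main() {
  Ops<std::complex<float>> ops;
  Tensor a, b;
  ops.Sub(a, b);         // InT defaults to T = std::complex<float>
  ops.Sub<float>(a, b);  // explicit InT, e.g. ValueType for a complex T
  return 0;
}
```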
@@ -461,37 +472,6 @@ struct DeviceIndependenceTensorOperations {
     return out;
   }
 
-  // Support x and y are different data types
-  Tensor Div_(const Tensor& x, const Tensor& y) {
-    Tensor out;
-    out.mutable_data<T>(x.dims(), context.GetPlace());
-    auto x_vector = EigenVector<T>::Flatten(x);
-    auto y_vector = EigenVector<ValueType>::Flatten(y);
-    auto out_vector = EigenVector<T>::Flatten(out);
-    auto& place =
-        *context.template device_context<DeviceContext>().eigen_device();
-    out_vector.device(place) = x_vector / y_vector;
-    return out;
-  }
-
-  framework::Tensor Sub_(const framework::Tensor& x,
-                         const framework::Tensor& y) {
-    framework::Tensor ret;
-    std::vector<int> out_shape = GetBroadcastShape({&x, &y});
-    ret.Resize(framework::make_ddim(out_shape));
-    if (x.dims().size() >= y.dims().size()) {
-      ElementwiseComputeEx<SubFunctor<ValueType>, DeviceContext, ValueType>(
-          context, &x, &y, -1, SubFunctor<ValueType>(), &ret);
-    } else {
-      ElementwiseComputeEx<InverseSubFunctor<ValueType>, DeviceContext,
-                           ValueType>(
-          // This is copyed from elementwise_sub, which means we
-          // need reverse will xrank < yrank
-          context, &x, &y, -1, InverseSubFunctor<ValueType>(), &ret);
-    }
-    return ret;
-  }
-
  private:
   const framework::ExecutionContext& context;
   BlasT<DeviceContext, T> GetBlas() {