@@ -285,6 +285,8 @@ template struct SelectedRowsAddToTensor<platform::CPUDeviceContext, float>;
285285template struct SelectedRowsAddToTensor <platform::CPUDeviceContext, double >;
286286template struct SelectedRowsAddToTensor <platform::CPUDeviceContext, int >;
287287template struct SelectedRowsAddToTensor <platform::CPUDeviceContext, int64_t >;
288+ template struct SelectedRowsAddToTensor <platform::CPUDeviceContext,
289+ platform::bfloat16>;
288290
289291// This is a separate namespace for manipulating SelectedRows typed
290292// data, e.g. merging duplicated rows, adding two SelectedRows, etc.
@@ -294,21 +296,17 @@ template struct SelectedRowsAddToTensor<platform::CPUDeviceContext, int64_t>;
294296// add or mul.
295297namespace scatter {
296298
297- template <typename DeviceContext, typename T>
298- typename std::enable_if<
299- std::is_floating_point<T>::value &&
300- std::is_same<DeviceContext, platform::CPUDeviceContext>::value>::type
301- elementwise_add_to (const DeviceContext& ctx, BlasT<DeviceContext, T>* blas,
302- size_t data_len, const T* in, T* out) {
303- blas->AXPY (data_len, 1 ., in, out);
// Floating-point overload of elementwise_add_to: selected only when
// std::is_floating_point<T>::value is true. Forwards to
// blas->AXPY(data_len, T(1.f), in, out) over `data_len` elements;
// presumably this computes out[i] += 1 * in[i] per the BLAS axpy
// convention -- confirm against the BlasT wrapper.
// NOTE(review): the sibling overload is constrained on std::is_integral<T>,
// so a T that satisfies neither trait has no matching overload. MergeAdd is
// instantiated for complex64/complex128/bfloat16 in this file; confirm those
// types specialize one of the two traits.
299+ template <typename T>
300+ typename std::enable_if<std::is_floating_point<T>::value>::type
301+ elementwise_add_to (BlasT<platform::CPUDeviceContext, T>* blas, size_t data_len,
302+ const T* in, T* out) {
303+ blas->AXPY (data_len, T (1 .f ), in, out);
304304}
305305
306- template <typename DeviceContext, typename T>
307- typename std::enable_if<
308- !std::is_floating_point<T>::value &&
309- std::is_same<DeviceContext, platform::CPUDeviceContext>::value>::type
310- elementwise_add_to (const DeviceContext& ctx, BlasT<DeviceContext, T>* blas,
311- size_t data_len, const T* in, T* out) {
306+ template <typename T>
307+ typename std::enable_if<std::is_integral<T>::value>::type elementwise_add_to (
308+ BlasT<platform::CPUDeviceContext, T>* blas, size_t data_len, const T* in,
309+ T* out) {
312310 for (size_t i = 0 ; i < data_len; i++) {
313311 out[i] += in[i];
314312 }
@@ -412,7 +410,7 @@ struct MergeAdd<platform::CPUDeviceContext, T> {
412410 out.set_rows (merge_rows);
413411
414412 math::SetConstant<platform::CPUDeviceContext, T> constant_functor;
415- constant_functor (context, out.mutable_value (), 0.0 );
413+ constant_functor (context, out.mutable_value (), static_cast <T>( 0 . f ) );
416414
417415 std::unordered_map<int64_t , size_t > rows_to_id;
418416 for (size_t i = 0 ; i < merge_rows.size (); ++i) {
@@ -429,9 +427,9 @@ struct MergeAdd<platform::CPUDeviceContext, T> {
429427
430428 for (size_t i = 0 ; i < input_rows.size (); i++) {
431429 size_t out_i = rows_to_id[input_rows[i]];
432- elementwise_add_to<platform::CPUDeviceContext, T>(
433- context, &blas, static_cast < size_t >( input_width) ,
434- &input_data[i * input_width], &out_data[out_i * input_width]);
430+ elementwise_add_to<T>(&blas, static_cast < size_t >(input_width),
431+ &input_data[i * input_width] ,
432+ &out_data[out_i * input_width]);
435433 }
436434 }
437435 }
@@ -524,9 +522,9 @@ struct MergeAverage<platform::CPUDeviceContext, T> {
524522
525523 for (size_t i = 0 ; i < input_rows.size (); i++) {
526524 size_t out_i = rows_to_id[input_rows[i]];
527- elementwise_add_to<platform::CPUDeviceContext, T>(
528- context, &blas, static_cast < size_t >( input_width) ,
529- &input_data[i * input_width], &out_data[out_i * input_width]);
525+ elementwise_add_to<T>(&blas, static_cast < size_t >(input_width),
526+ &input_data[i * input_width] ,
527+ &out_data[out_i * input_width]);
530528 }
531529 }
532530 size_t input_width_cast = static_cast <size_t >(input_width);
@@ -547,6 +545,8 @@ template struct MergeAdd<platform::CPUDeviceContext,
547545 paddle::platform::complex64>;
548546template struct MergeAdd <platform::CPUDeviceContext,
549547 paddle::platform::complex128>;
548+ template struct MergeAdd <platform::CPUDeviceContext,
549+ paddle::platform::bfloat16>;
550550
551551template struct MergeAverage <platform::CPUDeviceContext, int >;
552552template struct MergeAverage <platform::CPUDeviceContext, int64_t >;
0 commit comments