
Commit 452fe18

c_softmax_with_cross_entropy support bf16 for xpu (#60472)
1 parent 1b5c8f3 commit 452fe18

File tree (4 files changed: +75 −30 lines)

  paddle/fluid/operators/collective/c_softmax_with_cross_entropy_op_xpu.cc
  paddle/phi/backends/xpu/xpu3_op_list.cc
  test/xpu/collective_softmax_with_cross_entropy_op_xpu.py
  test/xpu/test_collective_softmax_with_cross_entropy_xpu.py


paddle/fluid/operators/collective/c_softmax_with_cross_entropy_op_xpu.cc

Lines changed: 36 additions & 18 deletions
@@ -102,13 +102,17 @@ struct CSoftmaxWithCrossEntropyProcessGroupFunctor<phi::XPUContext, T> {
     // reduce last dim
     int dims[1] = {1};
     auto f = [](xpu::Context* ctx,
-                const XPUType* x,
-                XPUType* y,
+                const T* x,
+                T* y,
                 const std::vector<int>& xdims,
                 const std::vector<int>& reduce_dims) {
-      return xpu::reduce_max<XPUType>(ctx, x, y, xdims, reduce_dims);
+      return xpu::reduce_max<XPUType>(ctx,
+                                      reinterpret_cast<const XPUType*>(x),
+                                      reinterpret_cast<XPUType*>(y),
+                                      xdims,
+                                      reduce_dims);
     };
-    ret = phi::XPUReduce<phi::XPUContext, XPUType>(
+    ret = phi::XPUReduce<phi::XPUContext, T>(
         dev_ctx,
         logits_2d,
         std::vector<int64_t>(dims, dims + 1),
@@ -194,13 +198,17 @@ struct CSoftmaxWithCrossEntropyProcessGroupFunctor<phi::XPUContext, T> {
     {
       int dims[1] = {1};
       auto f = [](xpu::Context* ctx,
-                  const XPUType* x,
-                  XPUType* y,
+                  const T* x,
+                  T* y,
                   const std::vector<int>& xdims,
                   const std::vector<int>& reduce_dims) {
-        return xpu::reduce_sum<XPUType>(ctx, x, y, xdims, reduce_dims);
+        return xpu::reduce_sum<XPUType>(ctx,
+                                        reinterpret_cast<const XPUType*>(x),
+                                        reinterpret_cast<XPUType*>(y),
+                                        xdims,
+                                        reduce_dims);
       };
-      ret = phi::XPUReduce<phi::XPUContext, XPUType>(
+      ret = phi::XPUReduce<phi::XPUContext, T>(
           dev_ctx,
           softmax_2d,
           std::vector<int64_t>(dims, dims + 1),
@@ -323,13 +331,17 @@ struct CSoftmaxWithCrossEntropyFunctor<phi::XPUContext, T> {
     {
       int dims[1] = {1};
       auto f = [](xpu::Context* ctx,
-                  const XPUType* x,
-                  XPUType* y,
+                  const T* x,
+                  T* y,
                   const std::vector<int>& xdims,
                   const std::vector<int>& reduce_dims) {
-        return xpu::reduce_max<XPUType>(ctx, x, y, xdims, reduce_dims);
+        return xpu::reduce_max<XPUType>(ctx,
+                                        reinterpret_cast<const XPUType*>(x),
+                                        reinterpret_cast<XPUType*>(y),
+                                        xdims,
+                                        reduce_dims);
       };
-      ret = phi::XPUReduce<phi::XPUContext, XPUType>(
+      ret = phi::XPUReduce<phi::XPUContext, T>(
          dev_ctx,
          logits_2d,
          std::vector<int64_t>(dims, dims + 1),
@@ -436,13 +448,17 @@ struct CSoftmaxWithCrossEntropyFunctor<phi::XPUContext, T> {
     {
       int dims[1] = {1};
       auto f = [](xpu::Context* ctx,
-                  const XPUType* x,
-                  XPUType* y,
+                  const T* x,
+                  T* y,
                   const std::vector<int>& xdims,
                   const std::vector<int>& reduce_dims) {
-        return xpu::reduce_sum<XPUType>(ctx, x, y, xdims, reduce_dims);
+        return xpu::reduce_sum<XPUType>(ctx,
+                                        reinterpret_cast<const XPUType*>(x),
+                                        reinterpret_cast<XPUType*>(y),
+                                        xdims,
+                                        reduce_dims);
       };
-      ret = phi::XPUReduce<phi::XPUContext, XPUType>(
+      ret = phi::XPUReduce<phi::XPUContext, T>(
          dev_ctx,
          softmax_2d,
          std::vector<int64_t>(dims, dims + 1),
@@ -567,9 +583,11 @@ PD_REGISTER_STRUCT_KERNEL(c_softmax_with_cross_entropy,
                           XPU,
                           ALL_LAYOUT,
                           ops::CSoftmaxWithCrossEntropyOp,
-                          float) {}
+                          float,
+                          phi::dtype::bfloat16) {}
 PD_REGISTER_STRUCT_KERNEL(c_softmax_with_cross_entropy_grad,
                           XPU,
                           ALL_LAYOUT,
                           ops::CSoftmaxWithCrossEntropyGrad,
-                          float) {}
+                          float,
+                          phi::dtype::bfloat16) {}
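The pattern in all four hunks is the same: phi::XPUReduce is now instantiated on the public element type T rather than the device type XPUType, so the reduction lambda takes T pointers and the reinterpret_casts move inside the calls to xpu::reduce_max / xpu::reduce_sum. The cast is a pure bit reinterpretation: a bfloat16 value is the upper 16 bits of an IEEE-754 float32, the same convention the uint16-based test helpers below rely on. A minimal numpy sketch of that relationship (plain truncation; a converter such as op_test.convert_float_to_uint16 may additionally round, so this illustrates the layout, not Paddle's exact code):

import numpy as np

def float32_to_bf16_bits(x):
    # Keep the sign, exponent, and top 7 mantissa bits: the upper
    # 16 bits of the float32 representation.
    x = np.asarray(x, dtype=np.float32)
    return (x.view(np.uint32) >> 16).astype(np.uint16)

def bf16_bits_to_float32(b):
    # Widening is exact: restore the 16 bits above 16 zero bits.
    b = np.asarray(b, dtype=np.uint16)
    return (b.astype(np.uint32) << 16).view(np.float32)

x = np.array([1.0, 3.14159, -40.0], dtype=np.float32)
rt = bf16_bits_to_float32(float32_to_bf16_bits(x))
# Only 8 significand bits survive, so the round-trip error is ~2**-8.
assert np.allclose(rt, x, rtol=2**-7)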

paddle/phi/backends/xpu/xpu3_op_list.cc

Lines changed: 3 additions & 2 deletions
@@ -143,9 +143,10 @@ XPUOpMap& get_kl3_ops() {
                     phi::DataType::BFLOAT16,
                     phi::DataType::INT32,
                     phi::DataType::INT64})},
-    {"c_softmax_with_cross_entropy", XPUKernelSet({phi::DataType::FLOAT32})},
+    {"c_softmax_with_cross_entropy",
+     XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::BFLOAT16})},
     {"c_softmax_with_cross_entropy_grad",
-     XPUKernelSet({phi::DataType::FLOAT32})},
+     XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::BFLOAT16})},
     {"c_reduce_sum", XPUKernelSet({phi::DataType::FLOAT32})},
     {"c_split",
      XPUKernelSet({phi::DataType::FLOAT16,
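Registering BFLOAT16 in the KL3 op list is what makes the change visible to the Python test harness: get_xpu_op_support_types reads this table, and create_test_class then emits one test class per supported dtype. A hedged usage sketch of that query (the helper is the one the test file below imports; the module path and exact returned strings are assumptions, not shown in this diff):

# Assumed import path for the helper used by the test below.
from get_test_cover_info import get_xpu_op_support_types

for op in ("c_softmax_with_cross_entropy",
           "c_softmax_with_cross_entropy_grad"):
    # After this change the returned types should include bfloat16
    # alongside float32 on KL3 devices.
    print(op, "->", get_xpu_op_support_types(op))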

test/xpu/collective_softmax_with_cross_entropy_op_xpu.py

Lines changed: 15 additions & 5 deletions
@@ -17,6 +17,7 @@
 import sys
 
 import numpy as np
+from op_test import convert_float_to_uint16
 from test_collective_base_xpu import (
     DataTypeCast,
     TestCollectiveRunnerBase,
@@ -44,7 +45,7 @@ def get_model(self, main_prog, startup_program, rank):
         logits = data(
             name="Logits",
             shape=[self.batch_size, self.local_elements],
-            dtype='float32',
+            dtype=self.dtype,
         )
         label = data(
             name="Label", shape=[self.batch_size, 1], dtype='int32'
@@ -110,6 +111,7 @@ def run_trainer(self, args):
         self.initCommunicator(
             startup_prog, rank, self.nranks, True, current_endpoint, endpoints
         )
+        self.dtype = args["dtype"]
         np_dtype = DataTypeCast(args["dtype"])
         loss, softmax = self.get_model(train_prog, startup_prog, rank)
         device_id = int(os.getenv("FLAGS_selected_xpus", "0"))
@@ -126,15 +128,23 @@
             dtype='int32',
         )
         # use FAKE loss_grad here, only to examine the correctness of grad func
-        loss_grad = np.random.uniform(
+        loss_grad_fp32 = np.random.uniform(
             low=-10.0, high=10.0, size=(self.batch_size, 1)
-        ).astype(np_dtype)
+        ).astype(np.float32)
+        if args["dtype"] == "bfloat16":
+            loss_grad = convert_float_to_uint16(loss_grad_fp32)
+        else:
+            loss_grad = loss_grad_fp32.astype(np_dtype)
 
         # each xpu uses own half of logits
         np.random.seed(os.getpid())
-        logits = np.random.uniform(
+        logits_fp32 = np.random.uniform(
             low=-40.0, high=40.0, size=(self.batch_size, self.local_elements)
-        ).astype(np_dtype)
+        ).astype(np.float32)
+        if args["dtype"] == "bfloat16":
+            logits = convert_float_to_uint16(logits_fp32)
+        else:
+            logits = logits_fp32.astype(np_dtype)
         out = exe.run(
             train_prog,
             feed={'Logits': logits, 'Label': label, 'Loss@GRAD': loss_grad},
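numpy has no native bfloat16 dtype, so the runner always samples in float32 and only branches afterwards: bf16 feeds go to the executor as their raw uint16 bit patterns via convert_float_to_uint16, while other dtypes take a plain astype. A compact sketch of that feed-preparation pattern (make_feed is a hypothetical helper for illustration; to_uint16 stands in for op_test.convert_float_to_uint16):

import numpy as np

def make_feed(shape, low, high, dtype, np_dtype, to_uint16):
    # Sample in float32 first; bfloat16 is fed as uint16 bit patterns.
    x = np.random.uniform(low=low, high=high, size=shape).astype(np.float32)
    if dtype == "bfloat16":
        return to_uint16(x)
    return x.astype(np_dtype)

# e.g. logits = make_feed((batch_size, local_elements), -40.0, 40.0,
#                         args["dtype"], np_dtype, convert_float_to_uint16)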

test/xpu/test_collective_softmax_with_cross_entropy_xpu.py

Lines changed: 21 additions & 5 deletions
@@ -21,6 +21,7 @@
     create_test_class,
     get_xpu_op_support_types,
 )
+from op_test import convert_uint16_to_float
 from test_collective_base_xpu import DataTypeCast, TestDistBase
 
 import paddle
@@ -154,15 +155,30 @@ def check_with_place(
         # get real result
         loss0, softmax0, logits_grad0 = tr0_out
         loss1, softmax1, logits_grad1 = tr1_out
+        if dtype == "bfloat16":
+            loss0 = convert_uint16_to_float(loss0)
+            softmax0 = convert_uint16_to_float(softmax0)
+            logits_grad0 = convert_uint16_to_float(logits_grad0)
+            loss1 = convert_uint16_to_float(loss1)
+            softmax1 = convert_uint16_to_float(softmax1)
+            logits_grad1 = convert_uint16_to_float(logits_grad1)
         softmax = np.concatenate((softmax0, softmax1), axis=1)
         logits_grad = np.concatenate((logits_grad0, logits_grad1), axis=1)
 
         # compare results
         rtol = 1e-6
-        np.testing.assert_allclose(loss0, need_loss, rtol=rtol)
-        np.testing.assert_allclose(loss1, need_loss, rtol=rtol)
-        np.testing.assert_allclose(softmax, need_softmax, rtol=rtol)
-        np.testing.assert_allclose(logits_grad, need_logits_grad, rtol=rtol)
+        atol = 0
+        if dtype == "bfloat16":
+            rtol = 0.1
+            atol = 0.1
+        np.testing.assert_allclose(loss0, need_loss, rtol=rtol, atol=atol)
+        np.testing.assert_allclose(loss1, need_loss, rtol=rtol, atol=atol)
+        np.testing.assert_allclose(
+            softmax, need_softmax, rtol=rtol, atol=atol
+        )
+        np.testing.assert_allclose(
+            logits_grad, need_logits_grad, rtol=rtol, atol=atol
+        )
 
 
 support_types = get_xpu_op_support_types('c_softmax_with_cross_entropy')
@@ -171,7 +187,7 @@ def check_with_place(
     globals(),
     XPUTestCSoftmaxWithCEOP,
     stype,
-    ignore_device_version=[core.XPUVersion.XPU1],
+    ignore_device_version=[core.XPUVersion.XPU1, core.XPUVersion.XPU3],
 )
 
 if __name__ == '__main__':
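On the loosened tolerances: bfloat16 keeps only 8 significand bits, so a single rounding step already costs up to about 2**-8 (~0.4%) relative error, and the softmax normalization plus cross-rank reductions compound it; rtol = atol = 0.1 is a deliberately loose envelope rather than a tight bound. A rough numpy illustration that simulates the bf16 pass by truncating float32 logits to their upper 16 bits (an approximation, not the kernel's actual arithmetic):

import numpy as np

def softmax(x):
    e = np.exp(x - x.max(axis=-1, keepdims=True))
    return e / e.sum(axis=-1, keepdims=True)

rng = np.random.default_rng(0)
logits = rng.uniform(-40.0, 40.0, size=(4, 8)).astype(np.float32)
# Truncate to bfloat16 precision: zero out the low 16 bits.
bf16 = ((logits.view(np.uint32) >> 16) << 16).view(np.float32)
err = np.abs(softmax(bf16) - softmax(logits)).max()
print(err)  # typically well below 0.1, yet orders of magnitude above 1e-6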
