Commit d53e567

fix bug of recompute in hybridparallel (#35588)
1 parent: 652da1f

File tree (4 files changed, +8 -0 lines):

paddle/fluid/operators/flatten_op.cu.cc
python/paddle/distributed/fleet/meta_parallel/pp_utils/utils.py
python/paddle/fluid/contrib/mixed_precision/fp16_lists.py
python/paddle/fluid/dygraph/amp/auto_cast.py

paddle/fluid/operators/flatten_op.cu.cc

Lines changed: 5 additions & 0 deletions
@@ -15,6 +15,7 @@ limitations under the License. */
 #include "paddle/fluid/operators/flatten_op.h"
 
 namespace ops = paddle::operators;
+namespace plat = paddle::platform;
 
 REGISTER_OP_CUDA_KERNEL(
     flatten, ops::FlattenKernel<paddle::platform::CUDADeviceContext, float>,
@@ -50,6 +51,8 @@ REGISTER_OP_CUDA_KERNEL(
     flatten_contiguous_range,
     ops::FlattenContiguousRangeKernel<paddle::platform::CUDADeviceContext,
                                       float>,
+    ops::FlattenContiguousRangeKernel<paddle::platform::CUDADeviceContext,
+                                      plat::float16>,
     ops::FlattenContiguousRangeKernel<paddle::platform::CUDADeviceContext,
                                       double>,
     ops::FlattenContiguousRangeKernel<paddle::platform::CUDADeviceContext,
@@ -63,6 +66,8 @@ REGISTER_OP_CUDA_KERNEL(
     flatten_contiguous_range_grad,
     ops::FlattenContiguousRangeGradKernel<paddle::platform::CUDADeviceContext,
                                           float>,
+    ops::FlattenContiguousRangeGradKernel<paddle::platform::CUDADeviceContext,
+                                          plat::float16>,
     ops::FlattenContiguousRangeGradKernel<paddle::platform::CUDADeviceContext,
                                           double>,
     ops::FlattenContiguousRangeGradKernel<paddle::platform::CUDADeviceContext,
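
The hunks above register float16 variants of the flatten_contiguous_range forward and grad CUDA kernels, so flattening an fp16 tensor (as the recompute path does via flatten_, see the next file) no longer lacks a kernel. A minimal usage sketch, assuming a CUDA build of PaddlePaddle; shapes below are illustrative:

import paddle

# an fp16 activation, e.g. produced under AMP
x = paddle.randn([4, 8, 8]).astype('float16')
# paddle.flatten dispatches to the flatten_contiguous_range op
y = paddle.flatten(x, start_axis=0, stop_axis=-1)
print(y.shape, y.dtype)   # [256] paddle.float16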

python/paddle/distributed/fleet/meta_parallel/pp_utils/utils.py

Lines changed: 1 addition & 0 deletions
@@ -133,6 +133,7 @@ def _split_activation(tensor):
 
     # use inplace operation to save memory
     data = tensor.flatten_()
+
     part_size = tensor_numel // mp_degree
     start = part_size * mp_rank
     end = start + part_size
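
For context, _split_activation slices the flattened activation evenly across the model-parallel group. A minimal sketch of the arithmetic shown in the hunk (mp_degree, mp_rank and tensor_numel come from the surrounding function; the concrete values are illustrative, and any remainder handling elsewhere in the function is omitted):

# each model-parallel rank keeps one contiguous slice of the flattened tensor
tensor_numel = 1024          # number of elements after flatten_()
mp_degree = 4                # model-parallel world size
for mp_rank in range(mp_degree):
    part_size = tensor_numel // mp_degree
    start = part_size * mp_rank
    end = start + part_size
    print(mp_rank, start, end)   # rank 0 -> [0, 256), rank 1 -> [256, 512), ...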

python/paddle/fluid/contrib/mixed_precision/fp16_lists.py

Lines changed: 1 addition & 0 deletions
@@ -94,6 +94,7 @@ def _update_list(self):
     'softmax',
     'softmax_with_cross_entropy',
     'sigmoid_cross_entropy_with_logits',
+    'c_softmax_with_cross_entropy',
     'cross_entropy',
     'cross_entropy2',
     # fp16 is slower than fp32, though fp16 is supported.
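
Judging from the neighbouring entries, the hunk adds c_softmax_with_cross_entropy to the default fp32 ("black") op list consumed by AutoMixedPrecisionLists, the class defined in this file. A minimal sketch of checking that default, using the import path as of this commit:

from paddle.fluid.contrib.mixed_precision.fp16_lists import AutoMixedPrecisionLists

lists = AutoMixedPrecisionLists()   # default white/black/gray op lists
# after this commit the collective softmax-with-cross-entropy op stays in fp32
print('c_softmax_with_cross_entropy' in lists.black_list)   # True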

python/paddle/fluid/dygraph/amp/auto_cast.py

Lines changed: 1 addition & 0 deletions
@@ -45,6 +45,7 @@
     'softmax',
     'softmax_with_cross_entropy',
     'sigmoid_cross_entropy_with_logits',
+    'c_softmax_with_cross_entropy',
     'cross_entropy',
     'cross_entropy2',
     # default fp32 can avoid return inf when the sum value large than 65504
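
This mirrors the fp16_lists.py change for the dygraph AMP path: ops on this list are kept in fp32 inside paddle.amp.auto_cast. A minimal usage sketch, assuming a CUDA build of PaddlePaddle (on CPU-only builds AMP falls back to fp32); the layer and shapes are illustrative:

import paddle

linear = paddle.nn.Linear(16, 16)
x = paddle.randn([4, 16])
with paddle.amp.auto_cast():
    # white-listed ops such as matmul run in fp16; black-listed ops,
    # now including c_softmax_with_cross_entropy, stay in fp32
    y = linear(x)
print(y.dtype)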
