Skip to content

Commit 610246f

Browse files
authored
use PADDLE_ENFORCE_EQ to replace PADDLE_ENFORCE after review (#74124)
1 parent 376f4c8 commit 610246f

2 files changed

Lines changed: 17 additions & 16 deletions

File tree

paddle/phi/api/generator/dist_api_gen.py

Lines changed: 16 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -593,23 +593,30 @@
593593
}}
594594
"""
595595

596+
# Note: After unifying the expand and expand_as ops and their grad kernels for all devices,
597+
# this logic has no practical effect. It is kept only for semantic correctness and may be removed.
596598
CALCULATE_LOCAL_SHAPE_KERNEL_TEMPLATE = """
597599
598600
auto out_grad_shape = out_grad.dims();
599601
std::vector<{dtype}> local_kernel_shape;
600602
const auto& out_grad_dist_attr = {out_grad_dist_attr};
603+
const auto& grad_mesh_shape = out_grad_dist_attr.process_mesh().shape();
601604
for (int i = 0; i < out_grad_shape.size(); i++) {{
602-
if (out_grad_dist_attr.dims_mapping()[i] >= 0) {{
605+
const auto& dims = out_grad_dist_attr.multi_dims_mapping()[i];
606+
if (dims.size() > 0) {{
603607
{dtype} shape_i = out_grad_shape[i];
604-
int64_t dim = out_grad_dist_attr.dims_mapping()[i];
605-
int64_t mesh_dim = out_grad_dist_attr.process_mesh().shape()[dim];
608+
int64_t num_shard = 1;
609+
for (auto dim : dims) {{
610+
num_shard *= grad_mesh_shape[dim];
611+
}}
606612
// TODO: Support aliquant condition.
607-
PADDLE_ENFORCE(shape_i % mesh_dim == 0,
608-
common::errors::InvalidArgument(
609-
"{op_name} only support local shape dim is divisible "
610-
"by the mesh dim, however local_kernel_shape[%lld] is %lld "
611-
"and shard mesh dims is %lld.", i, shape_i, mesh_dim));
612-
local_kernel_shape.push_back(shape_i / mesh_dim);
613+
PADDLE_ENFORCE_EQ(
614+
shape_i % num_shard, 0,
615+
common::errors::InvalidArgument(
616+
"{op_name} only support local shape dim is divisible "
617+
"by the mesh dim, however local_kernel_shape[%lld] is %lld "
618+
"and shard mesh dims is %lld.",
619+
i, shape_i, num_shard));
613620
}} else {{
614621
local_kernel_shape.push_back(out_grad_shape[i]);
615622
}}

test/cpp/auto_parallel/spmd_rule_test.cc

Lines changed: 1 addition & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -3544,13 +3544,7 @@ TEST(LabelSmooth, Ctor) {
35443544
EXPECT_EQ(forward_info.first.size(), 2UL);
35453545
EXPECT_EQ(forward_info.second.size(), 1UL);
35463546
check_dim_mapping(forward_info.first[0], {0, 1, -1});
3547-
const phi::distributed::ArgDistAttr& attr = forward_info.first[1];
3548-
if (paddle::holds_alternative<phi::distributed::TensorDistAttr>(attr)) {
3549-
EXPECT_EQ(paddle::get<phi::distributed::TensorDistAttr>(attr),
3550-
phi::distributed::TensorDistAttr());
3551-
} else {
3552-
FAIL() << "forward_info.first[1] is not TensorDistAttr";
3553-
}
3547+
check_empty_dist_attr(forward_info.first[1]);
35543548
check_dim_mapping(forward_info.second[0], {0, 1, -1});
35553549

35563550
// shape: [16, 16, 16], [1, 16]. [0, 1, -1], [-1, -1] --> [0, 1, -1], [-1,

0 commit comments

Comments
 (0)