|
593 | 593 | }} |
594 | 594 | """ |
595 | 595 |
|
| 596 | +# Note: After the expand, expand_as and their grad kernels are unified for all devices, |
| 597 | +# this logic has no practical effect. It is kept only for semantic correctness and can be removed. |
596 | 598 | CALCULATE_LOCAL_SHAPE_KERNEL_TEMPLATE = """ |
597 | 599 |
|
598 | 600 | auto out_grad_shape = out_grad.dims(); |
599 | 601 | std::vector<{dtype}> local_kernel_shape; |
600 | 602 | const auto& out_grad_dist_attr = {out_grad_dist_attr}; |
| 603 | + const auto& grad_mesh_shape = out_grad_dist_attr.process_mesh().shape(); |
601 | 604 | for (int i = 0; i < out_grad_shape.size(); i++) {{ |
602 | | - if (out_grad_dist_attr.dims_mapping()[i] >= 0) {{ |
| 605 | + const auto& dims = out_grad_dist_attr.multi_dims_mapping()[i]; |
| 606 | + if (dims.size() > 0) {{ |
603 | 607 | {dtype} shape_i = out_grad_shape[i]; |
604 | | - int64_t dim = out_grad_dist_attr.dims_mapping()[i]; |
605 | | - int64_t mesh_dim = out_grad_dist_attr.process_mesh().shape()[dim]; |
| 608 | + int64_t num_shard = 1; |
| 609 | + for (auto dim : dims) {{ |
| 610 | + num_shard *= grad_mesh_shape[dim]; |
| 611 | + }} |
606 | 612 | // TODO: Support aliquant condition. |
607 | | - PADDLE_ENFORCE(shape_i % mesh_dim == 0, |
608 | | - common::errors::InvalidArgument( |
609 | | - "{op_name} only support local shape dim is divisible " |
610 | | - "by the mesh dim, however local_kernel_shape[%lld] is %lld " |
611 | | - "and shard mesh dims is %lld.", i, shape_i, mesh_dim)); |
612 | | - local_kernel_shape.push_back(shape_i / mesh_dim); |
| 613 | + PADDLE_ENFORCE_EQ( |
| 614 | + shape_i % num_shard, 0, |
| 615 | + common::errors::InvalidArgument( |
| 616 | + "{op_name} only support local shape dim is divisible " |
| 617 | + "by the mesh dim, however local_kernel_shape[%lld] is %lld " |
| 618 | + "and shard mesh dims is %lld.", |
| 619 | + i, shape_i, num_shard)); |
| | + local_kernel_shape.push_back(shape_i / num_shard); |
613 | 620 | }} else {{ |
614 | 621 | local_kernel_shape.push_back(out_grad_shape[i]); |
615 | 622 | }} |
|
0 commit comments