Skip to content

Commit 549868a

Browse files
committed
custom op spmd rule register
1 parent d385c68 commit 549868a

File tree

4 files changed

+23
-2
lines changed

4 files changed

+23
-2
lines changed

cmake/inference_lib.cmake

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -328,6 +328,13 @@ copy(
328328
inference_lib_dist
329329
SRCS ${PADDLE_SOURCE_DIR}/paddle/phi/core/visit_type.h
330330
DSTS ${PADDLE_INFERENCE_INSTALL_DIR}/paddle/include/paddle/phi/core/)
331+
332+
copy(
333+
inference_lib_dist
334+
SRCS ${PADDLE_SOURCE_DIR}/paddle/phi/core/distributed/type_defs.h
335+
DSTS ${PADDLE_INFERENCE_INSTALL_DIR}/paddle/include/paddle/phi/core/distributed/
336+
)
337+
331338
copy(
332339
inference_lib_dist
333340
SRCS ${PADDLE_SOURCE_DIR}/paddle/phi/core/hostdevice.h

paddle/phi/api/ext/op_meta_info.h

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -996,8 +996,11 @@ struct TrtGetOutputDimsFuncImpl<Return (*)(Args...), impl_fn> {
996996
#endif
997997

998998
////////////////////// Op Meta Info //////////////////////
999-
class CustomSpmdInferTensorArg;
1000-
class CustomSpmdInferAttrArg;
999+
1000+
using CustomSpmdInferTensorArg =
1001+
paddle::variant<phi::distributed::DistMetaTensor,
1002+
std::vector<phi::distributed::DistMetaTensor>>;
1003+
using CustomSpmdInferAttrArg = paddle::any;
10011004

10021005
using InferSpmdFunc = phi::distributed::SpmdInfo (*)(
10031006
const std::vector<CustomSpmdInferTensorArg>& inputs,

paddle/phi/core/distributed/type_defs.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
namespace phi {
2424
namespace distributed {
2525
class TensorDistAttr;
26+
class DistMetaTensor;
2627

2728
using ArgDistAttr =
2829
paddle::variant<TensorDistAttr, std::vector<TensorDistAttr>>;

test/cpp/auto_parallel/custom_op_spmd_rule_test.cc

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1212
See the License for the specific language governing permissions and
1313
limitations under the License. */
1414

15+
#include "paddle/phi/api/ext/op_meta_info.h"
1516
#include "paddle/phi/api/ext/spmd_infer.h"
1617
#include "test/cpp/auto_parallel/spmd_rule_test_util.h"
1718

@@ -74,6 +75,15 @@ TEST(CustomOp, Ctor) {
7475
check_dim_mapping(infered_dist_attrs.second[0], {-1, 1, 0});
7576
check_partial_dims(infered_dist_attrs.second[0], {});
7677
}
78+
79+
TEST(CustomOp, Register) {
80+
OpMetaInfoBuilder builder("test_custom_op_smpd", 0);
81+
auto iter = OpMetaInfoMap::Instance().GetMap().find("test_custom_op_smpd");
82+
EXPECT_TRUE(iter != OpMetaInfoMap::Instance().GetMap().end());
83+
EXPECT_TRUE(OpMetaInfoHelper::GetInferSpmdFn(iter->second[0]) == nullptr);
84+
builder.SetInferSpmdFn(PD_INFER_SPMD_RULE(phi::distributed::ConcatInferSpmd));
85+
EXPECT_TRUE(OpMetaInfoHelper::GetInferSpmdFn(iter->second[0]) != nullptr);
86+
}
7787
} // namespace auto_parallel
7888
} // namespace distributed
7989
} // namespace paddle

0 commit comments

Comments
 (0)