Skip to content

Commit 5040226

Browse files
committed
custom op spmd rule register
1 parent 8b2b953 commit 5040226

File tree

9 files changed

+878
-597
lines changed

9 files changed

+878
-597
lines changed

cmake/inference_lib.cmake

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -328,10 +328,18 @@ copy(
328328
inference_lib_dist
329329
SRCS ${PADDLE_SOURCE_DIR}/paddle/phi/core/visit_type.h
330330
DSTS ${PADDLE_INFERENCE_INSTALL_DIR}/paddle/include/paddle/phi/core/)
331+
331332
copy(
332333
inference_lib_dist
333334
SRCS ${PADDLE_SOURCE_DIR}/paddle/phi/core/hostdevice.h
334335
DSTS ${PADDLE_INFERENCE_INSTALL_DIR}/paddle/include/paddle/phi/core/)
336+
337+
copy(
338+
inference_lib_dist
339+
SRCS ${PADDLE_SOURCE_DIR}/paddle/phi/core/distributed/auto_parallel/*.h
340+
DSTS ${PADDLE_INFERENCE_INSTALL_DIR}/paddle/include/paddle/phi/core/distributed/auto_parallel/
341+
)
342+
335343
copy(
336344
inference_lib_dist
337345
SRCS ${PADDLE_SOURCE_DIR}/paddle/fluid/platform/init_phi.h

paddle/extension.h

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,5 +22,4 @@ limitations under the License. */
2222
#endif
2323
// For initialization of DeviceContextPool and MemoryMethod
2424
#include "paddle/fluid/platform/init_phi.h"
25-
2625
static paddle::InitPhi g_init_phi;

paddle/phi/api/ext/op_meta_info.h

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@ limitations under the License. */
2323
#include "paddle/common/exception.h"
2424
#include "paddle/phi/api/include/dll_decl.h"
2525
#include "paddle/phi/api/include/tensor.h"
26+
#include "paddle/phi/core/distributed/type_defs.h"
2627
#include "paddle/utils/any.h"
2728
#include "paddle/utils/none.h"
2829
#include "paddle/utils/optional.h"
@@ -995,6 +996,12 @@ struct TrtGetOutputDimsFuncImpl<Return (*)(Args...), impl_fn> {
995996
#endif
996997

997998
////////////////////// Op Meta Info //////////////////////
999+
class CustomSpmdInferTensorArg;
1000+
class CustomSpmdInferAttrArg;
1001+
1002+
using InferSpmdFunc = phi::distributed::SpmdInfo (*)(
1003+
const std::vector<CustomSpmdInferTensorArg>& inputs,
1004+
const std::vector<CustomSpmdInferAttrArg>& attrs);
9981005

9991006
class PADDLE_API OpMetaInfo {
10001007
public:
@@ -1023,6 +1030,9 @@ class PADDLE_API OpMetaInfo {
10231030
// format: PD_INFER_DTYPE(...)
10241031
OpMetaInfo& SetInferDtypeFn(InferDtypeFunc&& func);
10251032

1033+
// format: PD_INFER_SPMD_RULE(...)
1034+
OpMetaInfo& SetInferSpmdFn(InferSpmdFunc&& func);
1035+
10261036
#ifdef PADDLE_WITH_TENSORRT
10271037
// format: PD_TRT_INFER_SHAPE(...)
10281038
OpMetaInfo& SetTrtInferShapeFn(TrtGetOutputDimsFunc&& func);
@@ -1045,6 +1055,7 @@ class PADDLE_API OpMetaInfo {
10451055
KernelFunc kernel_fn_{nullptr};
10461056
InferShapeFunc infer_shape_fn_{nullptr};
10471057
InferDtypeFunc infer_dtype_fn_{nullptr};
1058+
InferSpmdFunc infer_spmd_fn_{nullptr};
10481059
#ifdef PADDLE_WITH_TENSORRT
10491060
TrtGetOutputDimsFunc trt_infer_shape_fn_{nullptr};
10501061
std::vector<std::string> trt_supports_format_config_;
@@ -1068,6 +1079,7 @@ class OpMetaInfoHelper {
10681079
static const KernelFunc& GetKernelFn(const paddle::OpMetaInfo& info);
10691080
static const InferShapeFunc& GetInferShapeFn(const paddle::OpMetaInfo& info);
10701081
static const InferDtypeFunc& GetInferDtypeFn(const paddle::OpMetaInfo& info);
1082+
static const InferSpmdFunc& GetInferSpmdFn(const paddle::OpMetaInfo& info);
10711083

10721084
#ifdef PADDLE_WITH_TENSORRT
10731085
static const TrtGetOutputDimsFunc& GetTrtInferShapeFn(
@@ -1108,6 +1120,7 @@ class PADDLE_API OpMetaInfoBuilder {
11081120
OpMetaInfoBuilder& SetKernelFn(KernelFunc func);
11091121
OpMetaInfoBuilder& SetInferShapeFn(InferShapeFunc func);
11101122
OpMetaInfoBuilder& SetInferDtypeFn(InferDtypeFunc func);
1123+
OpMetaInfoBuilder& SetInferSpmdFn(InferSpmdFunc func);
11111124

11121125
#ifdef PADDLE_WITH_TENSORRT
11131126
OpMetaInfoBuilder& SetTrtInferShapeFn(TrtGetOutputDimsFunc func);

paddle/phi/api/ext/spmd_infer.h

Lines changed: 140 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,140 @@
1+
/* Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
2+
3+
Licensed under the Apache License, Version 2.0 (the "License");
4+
you may not use this file except in compliance with the License.
5+
You may obtain a copy of the License at
6+
7+
http://www.apache.org/licenses/LICENSE-2.0
8+
9+
Unless required by applicable law or agreed to in writing, software
10+
distributed under the License is distributed on an "AS IS" BASIS,
11+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
See the License for the specific language governing permissions and
13+
limitations under the License. */
14+
15+
#pragma once
16+
#include "paddle/phi/core/distributed/auto_parallel/dist_meta_tensor.h"
17+
#include "paddle/phi/core/distributed/type_defs.h"
18+
19+
namespace paddle {
20+
21+
using CustomSpmdInferTensorArg =
22+
paddle::variant<phi::distributed::DistMetaTensor,
23+
std::vector<phi::distributed::DistMetaTensor>>;
24+
25+
using CustomSpmdInferAttrArg = paddle::any;
26+
template <typename T>
27+
struct SpmdInferHelperTypeEnd {};
28+
29+
#define PD_SPECIALIZE_SpmdInferHelper_FOR_AttrType(attr_type) \
30+
template <typename... Tail> \
31+
struct SpmdInferHelper<attr_type, Tail...> { \
32+
template <int in_idx, int attr_idx, typename... PreviousArgs> \
33+
static phi::distributed::SpmdInfo InferSpmd( \
34+
const std::vector<CustomSpmdInferTensorArg>& inputs, \
35+
const std::vector<CustomSpmdInferAttrArg>& attrs, \
36+
const PreviousArgs&... pargs) { \
37+
try { \
38+
attr_type arg = paddle::any_cast<attr_type>(attrs[attr_idx]); \
39+
return SpmdInferHelper<Tail...>::template InferSpmd<in_idx, \
40+
attr_idx + 1>( \
41+
inputs, attrs, pargs..., arg); \
42+
} catch (paddle::bad_any_cast&) { \
43+
PD_THROW( \
44+
"Attribute cast error in custom operator SpmdInferFunc " \
45+
"function. " \
46+
"Expected " #attr_type \
47+
" value. SpmdInferFunc's attribute list must be exactly " \
48+
"same " \
49+
"as " \
50+
"Forward " \
51+
"KernelFn's attribute list except std::vector<int64_t> " \
52+
"attribute."); \
53+
} \
54+
} \
55+
}
56+
57+
template <typename F, F f>
58+
struct SpmdInferImpl;
59+
60+
template <typename... Args, phi::distributed::SpmdInfo (*impl_fn)(Args...)>
61+
struct SpmdInferImpl<phi::distributed::SpmdInfo (*)(Args...), impl_fn> {
62+
static phi::distributed::SpmdInfo InferSpmd(
63+
const std::vector<CustomSpmdInferTensorArg>& inputs,
64+
const std::vector<CustomSpmdInferAttrArg>& attrs) {
65+
return SpmdInferHelper<Args..., SpmdInferHelperTypeEnd<int>>::
66+
template InferSpmd<0, 0>(inputs, attrs);
67+
}
68+
69+
private:
70+
template <typename... RemainingArgs>
71+
struct SpmdInferHelper;
72+
73+
// Handle args for general tensor input case
74+
template <typename... Tail>
75+
struct SpmdInferHelper<const phi::distributed::DistMetaTensor&, Tail...> {
76+
template <int in_idx, int attr_idx, typename... PreviousArgs>
77+
static phi::distributed::SpmdInfo InferSpmd(
78+
const std::vector<CustomSpmdInferTensorArg>& inputs,
79+
const std::vector<CustomSpmdInferAttrArg>& attrs,
80+
PreviousArgs&... pargs) {
81+
auto& arg =
82+
PADDLE_GET_CONST(phi::distributed::DistMetaTensor, inputs[in_idx]);
83+
return SpmdInferHelper<Tail...>::template InferSpmd<in_idx + 1, attr_idx>(
84+
inputs, attrs, pargs..., arg);
85+
}
86+
};
87+
88+
// Handle args for vector of Tensor input case
89+
template <typename... Tail>
90+
struct SpmdInferHelper<const std::vector<phi::distributed::DistMetaTensor>&,
91+
Tail...> {
92+
template <int in_idx, int attr_idx, typename... PreviousArgs>
93+
static phi::distributed::SpmdInfo InferSpmd(
94+
const std::vector<CustomSpmdInferTensorArg>& inputs,
95+
const std::vector<CustomSpmdInferAttrArg>& attrs,
96+
PreviousArgs&... pargs) {
97+
auto& arg = PADDLE_GET_CONST(
98+
std::vector<phi::distributed::DistMetaTensor>, inputs[in_idx]);
99+
return SpmdInferHelper<Tail...>::template InferSpmd<in_idx + 1, attr_idx>(
100+
inputs, attrs, pargs..., arg);
101+
}
102+
};
103+
104+
PD_SPECIALIZE_SpmdInferHelper_FOR_AttrType(bool);
105+
PD_SPECIALIZE_SpmdInferHelper_FOR_AttrType(int);
106+
PD_SPECIALIZE_SpmdInferHelper_FOR_AttrType(float);
107+
PD_SPECIALIZE_SpmdInferHelper_FOR_AttrType(int64_t);
108+
PD_SPECIALIZE_SpmdInferHelper_FOR_AttrType(const std::string&);
109+
PD_SPECIALIZE_SpmdInferHelper_FOR_AttrType(const std::vector<int>&);
110+
PD_SPECIALIZE_SpmdInferHelper_FOR_AttrType(const std::vector<float>&);
111+
PD_SPECIALIZE_SpmdInferHelper_FOR_AttrType(const std::vector<std::string>&);
112+
PD_SPECIALIZE_SpmdInferHelper_FOR_AttrType(const std::vector<int64_t>&);
113+
PD_SPECIALIZE_SpmdInferHelper_FOR_AttrType(const bool&);
114+
PD_SPECIALIZE_SpmdInferHelper_FOR_AttrType(const int&);
115+
PD_SPECIALIZE_SpmdInferHelper_FOR_AttrType(const float&);
116+
PD_SPECIALIZE_SpmdInferHelper_FOR_AttrType(const int64_t&);
117+
118+
PD_SPECIALIZE_SpmdInferHelper_FOR_AttrType(std::string);
119+
PD_SPECIALIZE_SpmdInferHelper_FOR_AttrType(std::vector<int>);
120+
PD_SPECIALIZE_SpmdInferHelper_FOR_AttrType(std::vector<float>);
121+
PD_SPECIALIZE_SpmdInferHelper_FOR_AttrType(std::vector<std::string>);
122+
PD_SPECIALIZE_SpmdInferHelper_FOR_AttrType(const std::vector<int64_t>);
123+
124+
// end: base template
125+
template <typename T>
126+
struct SpmdInferHelper<SpmdInferHelperTypeEnd<T>> {
127+
template <int in_idx, int attr_idx, typename... PreviousArgs>
128+
static phi::distributed::SpmdInfo InferSpmd(
129+
const std::vector<CustomSpmdInferTensorArg>& inputs,
130+
const std::vector<CustomSpmdInferAttrArg>& attrs,
131+
PreviousArgs&... pargs) {
132+
return impl_fn(pargs...);
133+
}
134+
};
135+
};
136+
137+
#define PD_INFER_SPMD_RULE(...) \
138+
::paddle::SpmdInferImpl<decltype(&__VA_ARGS__), &__VA_ARGS__>::InferSpmd
139+
140+
} // namespace paddle

paddle/phi/api/lib/op_meta_info.cc

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -358,6 +358,11 @@ OpMetaInfo& OpMetaInfo::SetInferDtypeFn(InferDtypeFunc&& func) {
358358
return *this;
359359
}
360360

361+
OpMetaInfo& OpMetaInfo::SetInferSpmdFn(InferSpmdFunc&& func) {
362+
infer_spmd_fn_ = std::forward<InferSpmdFunc>(func);
363+
return *this;
364+
}
365+
361366
#ifdef PADDLE_WITH_TENSORRT
362367
OpMetaInfo& OpMetaInfo::SetTrtInferShapeFn(TrtGetOutputDimsFunc&& func) {
363368
trt_infer_shape_fn_ = std::forward<TrtGetOutputDimsFunc>(func);
@@ -407,6 +412,11 @@ const InferDtypeFunc& OpMetaInfoHelper::GetInferDtypeFn(
407412
return info.infer_dtype_fn_;
408413
}
409414

415+
const InferSpmdFunc& OpMetaInfoHelper::GetInferSpmdFn(
416+
const paddle::OpMetaInfo& info) {
417+
return info.infer_spmd_fn_;
418+
}
419+
410420
#ifdef PADDLE_WITH_TENSORRT
411421
const TrtGetOutputDimsFunc& OpMetaInfoHelper::GetTrtInferShapeFn(
412422
const paddle::OpMetaInfo& info) {
@@ -559,6 +569,11 @@ OpMetaInfoBuilder& OpMetaInfoBuilder::SetInferDtypeFn(InferDtypeFunc func) {
559569
return *this;
560570
}
561571

572+
OpMetaInfoBuilder& OpMetaInfoBuilder::SetInferSpmdFn(InferSpmdFunc func) {
573+
info_ptr_->SetInferSpmdFn(std::forward<InferSpmdFunc>(func));
574+
return *this;
575+
}
576+
562577
#ifdef PADDLE_WITH_TENSORRT
563578
OpMetaInfoBuilder& OpMetaInfoBuilder::SetTrtInferShapeFn(
564579
TrtGetOutputDimsFunc func) {

0 commit comments

Comments
 (0)