@@ -19,8 +19,7 @@ limitations under the License. */
1919#include < unordered_map>
2020#include < vector>
2121
22- #include < boost/any.hpp>
23-
22+ #include " any.h"
2423#include " ext_dll_decl.h" // NOLINT
2524#include " ext_exception.h" // NOLINT
2625#include " ext_tensor.h" // NOLINT
@@ -83,7 +82,7 @@ inline std::string Vec(const std::string& t_name) {
8382using KernelFunc =
8483 std::vector<Tensor> (*)(const std::vector<Tensor>& inputs,
8584 const std::vector<std::vector<Tensor>>& vec_inputs,
86- const std::vector<boost ::any>& attrs);
85+ const std::vector<paddle ::any>& attrs);
8786
8887#define PD_SPECIALIZE_ComputeCallHelper (attr_type ) \
8988 template <typename ... Tail> \
@@ -92,14 +91,14 @@ using KernelFunc =
9291 typename ... PreviousArgs> \
9392 static Return Compute (const std::vector<Tensor>& inputs, \
9493 const std::vector<std::vector<Tensor>>& vec_inputs, \
95- const std::vector<boost ::any>& attrs, \
94+ const std::vector<paddle ::any>& attrs, \
9695 const PreviousArgs&... pargs) { \
9796 try { \
98- attr_type arg = boost ::any_cast<attr_type>(attrs[attr_idx]); \
97+ attr_type arg = paddle ::any_cast<attr_type>(attrs[attr_idx]); \
9998 return ComputeCallHelper<Tail...>::template Compute< \
10099 in_idx, vec_in_idx, attr_idx + 1 >(inputs, vec_inputs, attrs, \
101100 pargs..., arg); \
102- } catch (boost ::bad_any_cast&) { \
101+ } catch (paddle ::bad_any_cast&) { \
103102 PD_THROW ( \
104103 " Attribute cast error in custom operator. Expected " #attr_type \
105104 " value." ); \
@@ -117,7 +116,7 @@ template <typename Return, typename... Args, Return (*impl_fn)(Args...)>
117116struct KernelFuncImpl <Return (*)(Args...), impl_fn> {
118117 static Return Compute (const std::vector<Tensor>& inputs,
119118 const std::vector<std::vector<Tensor>>& vec_inputs,
120- const std::vector<boost ::any>& attrs) {
119+ const std::vector<paddle ::any>& attrs) {
121120 return ComputeCallHelper<Args..., TypeTag<int >>::template Compute<0 , 0 , 0 >(
122121 inputs, vec_inputs, attrs);
123122 }
@@ -132,7 +131,7 @@ struct KernelFuncImpl<Return (*)(Args...), impl_fn> {
132131 typename ... PreviousArgs>
133132 static Return Compute (const std::vector<Tensor>& inputs,
134133 const std::vector<std::vector<Tensor>>& vec_inputs,
135- const std::vector<boost ::any>& attrs,
134+ const std::vector<paddle ::any>& attrs,
136135 const PreviousArgs&... pargs) {
137136 const Tensor& arg = inputs[in_idx];
138137 return ComputeCallHelper<Tail...>::template Compute<in_idx + 1 ,
@@ -147,7 +146,7 @@ struct KernelFuncImpl<Return (*)(Args...), impl_fn> {
147146 typename ... PreviousArgs>
148147 static Return Compute (const std::vector<Tensor>& inputs,
149148 const std::vector<std::vector<Tensor>>& vec_inputs,
150- const std::vector<boost ::any>& attrs,
149+ const std::vector<paddle ::any>& attrs,
151150 const PreviousArgs&... pargs) {
152151 const std::vector<Tensor>& arg = vec_inputs[vec_in_idx];
153152 return ComputeCallHelper<Tail...>::template Compute<
@@ -189,7 +188,7 @@ struct KernelFuncImpl<Return (*)(Args...), impl_fn> {
189188 template <int in_idx, int vec_in_idx, int attr_idx>
190189 static Return Compute (const std::vector<Tensor>& inputs,
191190 const std::vector<std::vector<Tensor>>& vec_inputs,
192- const std::vector<boost ::any>& attrs,
191+ const std::vector<paddle ::any>& attrs,
193192 const Args&... args) {
194193 return impl_fn (args...);
195194 }
@@ -205,67 +204,67 @@ struct KernelFuncImpl<Return (*)(Args...), impl_fn> {
205204using InferShapeFunc = std::vector<std::vector<int64_t >> (*)(
206205 const std::vector<std::vector<int64_t >>& input_shapes,
207206 const std::vector<std::vector<std::vector<int64_t >>>& vec_input_shapes,
208- const std::vector<boost ::any>& attrs);
209-
210- #define PD_SPECIALIZE_InferShapeCallHelper_FOR_SHAPE (input_type ) \
211- template <typename ... Tail> \
212- struct InferShapeCallHelper <input_type, Tail...> { \
213- template <int in_idx, int vec_in_idx, int attr_idx, \
214- typename ... PreviousArgs> \
215- static Return InferShape ( \
216- const std::vector<std::vector<int64_t >>& input_shapes, \
217- const std::vector<std::vector<std::vector<int64_t >>>& \
218- vec_input_shapes, \
219- const std::vector<boost ::any>& attrs, const PreviousArgs&... pargs) { \
220- input_type arg = input_shapes[in_idx]; \
221- return InferShapeCallHelper<Tail...>::template InferShape< \
222- in_idx + 1 , vec_in_idx, attr_idx>(input_shapes, vec_input_shapes, \
223- attrs, pargs..., arg); \
224- } \
207+ const std::vector<paddle ::any>& attrs);
208+
209+ #define PD_SPECIALIZE_InferShapeCallHelper_FOR_SHAPE (input_type ) \
210+ template <typename ... Tail> \
211+ struct InferShapeCallHelper <input_type, Tail...> { \
212+ template <int in_idx, int vec_in_idx, int attr_idx, \
213+ typename ... PreviousArgs> \
214+ static Return InferShape ( \
215+ const std::vector<std::vector<int64_t >>& input_shapes, \
216+ const std::vector<std::vector<std::vector<int64_t >>>& \
217+ vec_input_shapes, \
218+ const std::vector<paddle ::any>& attrs, const PreviousArgs&... pargs) { \
219+ input_type arg = input_shapes[in_idx]; \
220+ return InferShapeCallHelper<Tail...>::template InferShape< \
221+ in_idx + 1 , vec_in_idx, attr_idx>(input_shapes, vec_input_shapes, \
222+ attrs, pargs..., arg); \
223+ } \
225224 }
226225
227- #define PD_SPECIALIZE_InferShapeCallHelper_FOR_SHAPES (input_type ) \
228- template <typename ... Tail> \
229- struct InferShapeCallHelper <input_type, Tail...> { \
230- template <int in_idx, int vec_in_idx, int attr_idx, \
231- typename ... PreviousArgs> \
232- static Return InferShape ( \
233- const std::vector<std::vector<int64_t >>& input_shapes, \
234- const std::vector<std::vector<std::vector<int64_t >>>& \
235- vec_input_shapes, \
236- const std::vector<boost ::any>& attrs, const PreviousArgs&... pargs) { \
237- input_type arg = vec_input_shapes[vec_in_idx]; \
238- return InferShapeCallHelper<Tail...>::template InferShape< \
239- in_idx, vec_in_idx + 1 , attr_idx>(input_shapes, vec_input_shapes, \
240- attrs, pargs..., arg); \
241- } \
226+ #define PD_SPECIALIZE_InferShapeCallHelper_FOR_SHAPES (input_type ) \
227+ template <typename ... Tail> \
228+ struct InferShapeCallHelper <input_type, Tail...> { \
229+ template <int in_idx, int vec_in_idx, int attr_idx, \
230+ typename ... PreviousArgs> \
231+ static Return InferShape ( \
232+ const std::vector<std::vector<int64_t >>& input_shapes, \
233+ const std::vector<std::vector<std::vector<int64_t >>>& \
234+ vec_input_shapes, \
235+ const std::vector<paddle ::any>& attrs, const PreviousArgs&... pargs) { \
236+ input_type arg = vec_input_shapes[vec_in_idx]; \
237+ return InferShapeCallHelper<Tail...>::template InferShape< \
238+ in_idx, vec_in_idx + 1 , attr_idx>(input_shapes, vec_input_shapes, \
239+ attrs, pargs..., arg); \
240+ } \
242241 }
243242
244- #define PD_SPECIALIZE_InferShapeCallHelper_FOR_ATTR (attr_type ) \
245- template <typename ... Tail> \
246- struct InferShapeCallHelper <attr_type, Tail...> { \
247- template <int in_idx, int vec_in_idx, int attr_idx, \
248- typename ... PreviousArgs> \
249- static Return InferShape ( \
250- const std::vector<std::vector<int64_t >>& input_shapes, \
251- const std::vector<std::vector<std::vector<int64_t >>>& \
252- vec_input_shapes, \
253- const std::vector<boost ::any>& attrs, const PreviousArgs&... pargs) { \
254- try { \
255- attr_type arg = boost ::any_cast<attr_type>(attrs[attr_idx]); \
256- return InferShapeCallHelper<Tail...>::template InferShape< \
257- in_idx, vec_in_idx, attr_idx + 1 >(input_shapes, vec_input_shapes, \
258- attrs, pargs..., arg); \
259- } catch (boost ::bad_any_cast&) { \
260- PD_THROW ( \
261- " Attribute cast error in custom operator InferShapeFn. " \
262- " Expected " #attr_type \
263- " value. InferShapeFn's attribute list must be exactly same as " \
264- " Forward " \
265- " KernelFn's attribute list except std::vector<int64_t> " \
266- " attribute." ); \
267- } \
268- } \
243+ #define PD_SPECIALIZE_InferShapeCallHelper_FOR_ATTR (attr_type ) \
244+ template <typename ... Tail> \
245+ struct InferShapeCallHelper <attr_type, Tail...> { \
246+ template <int in_idx, int vec_in_idx, int attr_idx, \
247+ typename ... PreviousArgs> \
248+ static Return InferShape ( \
249+ const std::vector<std::vector<int64_t >>& input_shapes, \
250+ const std::vector<std::vector<std::vector<int64_t >>>& \
251+ vec_input_shapes, \
252+ const std::vector<paddle ::any>& attrs, const PreviousArgs&... pargs) { \
253+ try { \
254+ attr_type arg = paddle ::any_cast<attr_type>(attrs[attr_idx]); \
255+ return InferShapeCallHelper<Tail...>::template InferShape< \
256+ in_idx, vec_in_idx, attr_idx + 1 >(input_shapes, vec_input_shapes, \
257+ attrs, pargs..., arg); \
258+ } catch (paddle ::bad_any_cast&) { \
259+ PD_THROW ( \
260+ " Attribute cast error in custom operator InferShapeFn. " \
261+ " Expected " #attr_type \
262+ " value. InferShapeFn's attribute list must be exactly same as " \
263+ " Forward " \
264+ " KernelFn's attribute list except std::vector<int64_t> " \
265+ " attribute." ); \
266+ } \
267+ } \
269268 }
270269
271270template <typename F, F f>
@@ -276,7 +275,7 @@ struct InferShapeFuncImpl<Return (*)(Args...), impl_fn> {
276275 static Return InferShape (
277276 const std::vector<std::vector<int64_t >>& input_shapes,
278277 const std::vector<std::vector<std::vector<int64_t >>>& vec_input_shapes,
279- const std::vector<boost ::any>& attrs) {
278+ const std::vector<paddle ::any>& attrs) {
280279 return InferShapeCallHelper<Args..., TypeTag<int >>::template InferShape<
281280 0 , 0 , 0 >(input_shapes, vec_input_shapes, attrs);
282281 }
@@ -314,7 +313,7 @@ struct InferShapeFuncImpl<Return (*)(Args...), impl_fn> {
314313 static Return InferShape (
315314 const std::vector<std::vector<int64_t >>& input_shapes,
316315 const std::vector<std::vector<std::vector<int64_t >>>& vec_input_shapes,
317- const std::vector<boost ::any>& attrs, const Args&... args) {
316+ const std::vector<paddle ::any>& attrs, const Args&... args) {
318317 return impl_fn (args...);
319318 }
320319 };
0 commit comments