-
Notifications
You must be signed in to change notification settings - Fork 6k
Enhance gc to support deleting tensor buffer in advance #16409
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
a93a9ee
072d95d
f8ed2c2
7000ec8
a7d0ac5
78fb3a6
a0f4fef
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
This file was deleted.
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -21,6 +21,7 @@ limitations under the License. */ | |
| #include <vector> | ||
| #include "paddle/fluid/framework/grad_op_desc_maker.h" | ||
| #include "paddle/fluid/framework/inplace_op_inference.h" | ||
| #include "paddle/fluid/framework/no_need_buffer_vars_inference.h" | ||
| #include "paddle/fluid/framework/op_info.h" | ||
| #include "paddle/fluid/framework/op_proto_maker.h" | ||
| #include "paddle/fluid/framework/operator.h" | ||
|
|
@@ -36,27 +37,86 @@ enum OpInfoFillType { | |
| kGradOpDescMaker = 2, | ||
| kVarTypeInference = 3, | ||
| kShapeInference = 4, | ||
| kInplaceOpInference = 5 | ||
| kInplaceOpInference = 5, | ||
| kNoNeedBufferVarsInference = 6, | ||
| kUnknown = -1 | ||
| }; | ||
|
|
||
| namespace internal { | ||
| template <typename T, OpInfoFillType kType> | ||
| struct TypePair { | ||
| using Type = T; | ||
| static constexpr OpInfoFillType kFillType = kType; | ||
| }; | ||
|
|
||
| using OpRegistryClasses = std::tuple< // NOLINT | ||
| TypePair<OperatorBase, kOperator>, // NOLINT | ||
| TypePair<OpProtoAndCheckerMaker, kOpProtoAndCheckerMaker>, // NOLINT | ||
| TypePair<GradOpDescMakerBase, kGradOpDescMaker>, // NOLINT | ||
| TypePair<VarTypeInference, kVarTypeInference>, // NOLINT | ||
| TypePair<InferShapeBase, kShapeInference>, // NOLINT | ||
| TypePair<InplaceOpInference, kInplaceOpInference>, // NOLINT | ||
| TypePair<NoNeedBufferVarsInference, kNoNeedBufferVarsInference> // NOLINT | ||
| >; | ||
|
|
||
| static constexpr int kOpRegistryClassNumber = | ||
| std::tuple_size<OpRegistryClasses>::value; | ||
|
|
||
| template <typename T, int kPos, bool kIsBounded /* = true*/> | ||
| struct IsMatchedBaseTypeImpl { | ||
| using PairType = typename std::tuple_element<kPos, OpRegistryClasses>::type; | ||
| static constexpr bool kValue = | ||
| std::is_base_of<typename PairType::Type, T>::value; | ||
| }; | ||
|
|
||
| template <typename T, int kPos> | ||
| struct IsMatchedBaseTypeImpl<T, kPos, false> { | ||
| static constexpr bool kValue = false; | ||
| }; | ||
|
|
||
| template <typename T, int kPos> | ||
| static inline constexpr bool IsMatchedBaseType() { | ||
| return IsMatchedBaseTypeImpl< | ||
| T, kPos, (kPos >= 0 && kPos < kOpRegistryClassNumber)>::kValue; | ||
| } | ||
|
|
||
| template <typename T, int kStart, int kEnd, bool kIsEnd, bool kIsMatched> | ||
| struct OpInfoFillTypeGetterImpl {}; | ||
|
|
||
| // This case should not happen | ||
| template <typename T, int kStart, int kEnd> | ||
| struct OpInfoFillTypeGetterImpl<T, kStart, kEnd, true, true> {}; | ||
|
|
||
| template <typename T, int kStart, int kEnd> | ||
| struct OpInfoFillTypeGetterImpl<T, kStart, kEnd, true, false> { | ||
| static constexpr OpInfoFillType kType = kUnknown; | ||
| }; | ||
|
|
||
| template <typename T, int kStart, int kEnd> | ||
| struct OpInfoFillTypeGetterImpl<T, kStart, kEnd, false, false> { | ||
| static constexpr OpInfoFillType kType = | ||
| OpInfoFillTypeGetterImpl<T, kStart + 1, kEnd, kStart + 1 == kEnd, | ||
| IsMatchedBaseType<T, kStart + 1>()>::kType; | ||
| }; | ||
|
|
||
| template <typename T, int kStart, int kEnd> | ||
| struct OpInfoFillTypeGetterImpl<T, kStart, kEnd, false, true> { | ||
| using PairType = typename std::tuple_element<kStart, OpRegistryClasses>::type; | ||
| static constexpr OpInfoFillType kType = PairType::kFillType; | ||
| }; | ||
|
|
||
| template <typename T> | ||
| using OpInfoFillTypeGetter = | ||
| OpInfoFillTypeGetterImpl<T, 0, kOpRegistryClassNumber, | ||
| kOpRegistryClassNumber == 0, | ||
| IsMatchedBaseType<T, 0>()>; | ||
|
|
||
| } // namespace internal | ||
|
|
||
| template <typename T> | ||
| struct OpInfoFillTypeID { | ||
| static constexpr OpInfoFillType ID() { | ||
| return std::is_base_of<OperatorBase, T>::value | ||
| ? kOperator | ||
| : (std::is_base_of<OpProtoAndCheckerMaker, T>::value | ||
| ? kOpProtoAndCheckerMaker | ||
| : (std::is_base_of<GradOpDescMakerBase, T>::value | ||
| ? kGradOpDescMaker | ||
| : (std::is_base_of<VarTypeInference, T>::value | ||
| ? kVarTypeInference | ||
| : (std::is_base_of<InferShapeBase, T>::value | ||
| ? kShapeInference | ||
| : (std::is_base_of< | ||
| InplaceOpInference, T>::value | ||
| ? kInplaceOpInference | ||
| : static_cast<OpInfoFillType>( | ||
| -1)))))); | ||
| return internal::OpInfoFillTypeGetter<T>::kType; | ||
| } | ||
| }; | ||
|
|
||
|
|
@@ -156,6 +216,18 @@ struct OpInfoFiller<T, kInplaceOpInference> { | |
| } | ||
| }; | ||
|
|
||
| template <typename T> | ||
| struct OpInfoFiller<T, kNoNeedBufferVarsInference> { | ||
| void operator()(const char* op_type, OpInfo* info) const { | ||
| info->infer_no_need_buffer_vars_ = [](const VariableNameMap& inputs, | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. curious about when will these three parameters be used to get the NoNeedBufferVars, seems now we just return the parameters specified in the macro as an unordered_set?
Collaborator
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I reserve these parameters for future use. Some ops may not need some forward inputs or outputs when some attribute is true/false. For example, batch_norm_grad_op does not need Bias when use_mkldnn is false. |
||
| const VariableNameMap& outputs, | ||
| const AttributeMap& attrs) { | ||
| T infer(inputs, outputs, attrs); | ||
| return infer(); | ||
| }; | ||
| } | ||
| }; | ||
|
|
||
| } // namespace details | ||
|
|
||
| } // namespace framework | ||
|
|
||
Uh oh!
There was an error while loading. Please reload this page.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Ugly but scalable codes here. I rewrite
OpInfoFillTypeID::ID()method because the character number limit is set to be 80 in a line.