Skip to content
Merged
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 6 additions & 3 deletions paddle/fluid/framework/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ cc_library(lod_tensor SRCS lod_tensor.cc DEPS ddim place tensor framework_proto
cc_test(lod_tensor_test SRCS lod_tensor_test.cc DEPS lod_tensor memory)
nv_test(lod_tensor_gpu_test SRCS lod_tensor_test.cu DEPS lod_tensor)

cc_library(garbage_collector SRCS garbage_collector.cc DEPS device_context memory)
cc_library(garbage_collector SRCS garbage_collector.cc DEPS device_context memory gflags glog)

cc_library(reader SRCS reader.cc DEPS lod_tensor ddim)
cc_test(reader_test SRCS reader_test.cc DEPS reader)
Expand Down Expand Up @@ -126,7 +126,7 @@ cc_test(operator_test SRCS operator_test.cc DEPS operator op_registry device_con
cc_library(version SRCS version.cc)
cc_test(version_test SRCS version_test.cc DEPS version)

cc_library(proto_desc SRCS var_desc.cc op_desc.cc block_desc.cc program_desc.cc DEPS shape_inference op_info operator glog version)
cc_library(proto_desc SRCS var_desc.cc op_desc.cc block_desc.cc program_desc.cc no_need_buffer_vars_inference.cc DEPS shape_inference op_info operator glog version)

cc_library(op_registry SRCS op_registry.cc DEPS op_proto_maker op_info operator glog proto_desc memory_optimize_helper)
nv_test(op_registry_test SRCS op_registry_test.cc DEPS op_registry)
Expand Down Expand Up @@ -164,6 +164,8 @@ else()
set(NGRAPH_EXE_DEPS)
endif()

cc_library(executor_gc_helper SRCS executor_gc_helper.cc DEPS scope proto_desc operator garbage_collector)

if(WITH_DISTRIBUTE)
cc_library(executor SRCS executor.cc DEPS op_registry device_context scope framework_proto glog
lod_rank_table feed_fetch_method sendrecvop_rpc ${GLOB_DISTRIBUTE_DEPS} graph_to_program_pass variable_helper ${NGRAPH_EXE_DEPS})
Expand All @@ -174,7 +176,7 @@ else()
cc_test(test_naive_executor SRCS naive_executor_test.cc DEPS naive_executor elementwise_add_op)
endif()

target_link_libraries(executor garbage_collector while_op_helper)
target_link_libraries(executor while_op_helper executor_gc_helper)

cc_library(parallel_executor SRCS parallel_executor.cc DEPS
threaded_ssa_graph_executor scope_buffered_ssa_graph_executor parallel_ssa_graph_executor
Expand All @@ -194,6 +196,7 @@ cc_test(prune_test SRCS prune_test.cc DEPS op_info prune recurrent_op device_con
cc_test(var_type_inference_test SRCS var_type_inference_test.cc DEPS op_registry
proto_desc)
cc_test(inplace_op_inference_test SRCS inplace_op_inference_test.cc DEPS op_registry proto_desc op_info memory_optimize_helper)

cc_library(selected_rows SRCS selected_rows.cc DEPS tensor)
cc_test(selected_rows_test SRCS selected_rows_test.cc DEPS selected_rows)

Expand Down
15 changes: 5 additions & 10 deletions paddle/fluid/framework/details/eager_deletion_pass.cc
Original file line number Diff line number Diff line change
Expand Up @@ -22,14 +22,9 @@
#include "paddle/fluid/framework/details/computation_op_handle.h"
#include "paddle/fluid/framework/details/eager_deletion_op_handle.h"
#include "paddle/fluid/framework/details/multi_devices_helper.h"
#include "paddle/fluid/framework/garbage_collector.h"
#include "paddle/fluid/framework/ir/graph_helper.h"

DEFINE_double(memory_fraction_of_eager_deletion, 1.0,
"Fraction of eager deletion. If less than 1.0, all variables in "
"the program would be sorted according to its memory size, and "
"only the FLAGS_memory_fraction_of_eager_deletion of the largest "
"variables would be deleted.");

namespace paddle {
namespace framework {
namespace details {
Expand Down Expand Up @@ -206,8 +201,9 @@ std::unique_ptr<ir::Graph> EagerDeletionPass::ApplyImpl(
}
}

op_vars_map = ShrinkGCVars(op_vars_map, vars, places,
FLAGS_memory_fraction_of_eager_deletion);
double memory_fraction = framework::GetEagerDeletionMemoryFraction();

op_vars_map = ShrinkGCVars(op_vars_map, vars, places, memory_fraction);

for (auto &pair : op_vars_map) {
auto *op = pair.first;
Expand Down Expand Up @@ -239,8 +235,7 @@ std::unique_ptr<ir::Graph> EagerDeletionPass::ApplyImpl(
eager_deletion_op->AddOutput(dummy_leaf);
}

VLOG(10) << "FLAGS_memory_fraction_of_eager_deletion = "
<< FLAGS_memory_fraction_of_eager_deletion;
VLOG(10) << "FLAGS_memory_fraction_of_eager_deletion = " << memory_fraction;
VLOG(10) << "Create " << op_vars_map.size() << " EagerDeletionOpHandle(s)";

auto while_op_eager_deletion_pass =
Expand Down
140 changes: 0 additions & 140 deletions paddle/fluid/framework/details/early_delete_op_handle.h

This file was deleted.

104 changes: 88 additions & 16 deletions paddle/fluid/framework/details/op_registry.h
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ limitations under the License. */
#include <vector>
#include "paddle/fluid/framework/grad_op_desc_maker.h"
#include "paddle/fluid/framework/inplace_op_inference.h"
#include "paddle/fluid/framework/no_need_buffer_vars_inference.h"
#include "paddle/fluid/framework/op_info.h"
#include "paddle/fluid/framework/op_proto_maker.h"
#include "paddle/fluid/framework/operator.h"
Expand All @@ -36,27 +37,86 @@ enum OpInfoFillType {
kGradOpDescMaker = 2,
kVarTypeInference = 3,
kShapeInference = 4,
kInplaceOpInference = 5
kInplaceOpInference = 5,
kNoNeedBufferVarsInference = 6,
kUnknown = -1
};

namespace internal {
template <typename T, OpInfoFillType kType>
struct TypePair {
using Type = T;
static constexpr OpInfoFillType kFillType = kType;
};

using OpRegistryClasses = std::tuple< // NOLINT
Copy link
Collaborator Author

@sneaxiy sneaxiy Mar 27, 2019

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ugly but scalable codes here. I rewrite OpInfoFillTypeID::ID() method because the character number limit is set to be 80 in a line.

TypePair<OperatorBase, kOperator>, // NOLINT
TypePair<OpProtoAndCheckerMaker, kOpProtoAndCheckerMaker>, // NOLINT
TypePair<GradOpDescMakerBase, kGradOpDescMaker>, // NOLINT
TypePair<VarTypeInference, kVarTypeInference>, // NOLINT
TypePair<InferShapeBase, kShapeInference>, // NOLINT
TypePair<InplaceOpInference, kInplaceOpInference>, // NOLINT
TypePair<NoNeedBufferVarsInference, kNoNeedBufferVarsInference> // NOLINT
>;

static constexpr int kOpRegistryClassNumber =
std::tuple_size<OpRegistryClasses>::value;

template <typename T, int kPos, bool kIsBounded /* = true*/>
struct IsMatchedBaseTypeImpl {
using PairType = typename std::tuple_element<kPos, OpRegistryClasses>::type;
static constexpr bool kValue =
std::is_base_of<typename PairType::Type, T>::value;
};

template <typename T, int kPos>
struct IsMatchedBaseTypeImpl<T, kPos, false> {
static constexpr bool kValue = false;
};

template <typename T, int kPos>
static inline constexpr bool IsMatchedBaseType() {
return IsMatchedBaseTypeImpl<
T, kPos, (kPos >= 0 && kPos < kOpRegistryClassNumber)>::kValue;
}

template <typename T, int kStart, int kEnd, bool kIsEnd, bool kIsMatched>
struct OpInfoFillTypeGetterImpl {};

// This case should not happen
template <typename T, int kStart, int kEnd>
struct OpInfoFillTypeGetterImpl<T, kStart, kEnd, true, true> {};

template <typename T, int kStart, int kEnd>
struct OpInfoFillTypeGetterImpl<T, kStart, kEnd, true, false> {
static constexpr OpInfoFillType kType = kUnknown;
};

template <typename T, int kStart, int kEnd>
struct OpInfoFillTypeGetterImpl<T, kStart, kEnd, false, false> {
static constexpr OpInfoFillType kType =
OpInfoFillTypeGetterImpl<T, kStart + 1, kEnd, kStart + 1 == kEnd,
IsMatchedBaseType<T, kStart + 1>()>::kType;
};

template <typename T, int kStart, int kEnd>
struct OpInfoFillTypeGetterImpl<T, kStart, kEnd, false, true> {
using PairType = typename std::tuple_element<kStart, OpRegistryClasses>::type;
static constexpr OpInfoFillType kType = PairType::kFillType;
};

template <typename T>
using OpInfoFillTypeGetter =
OpInfoFillTypeGetterImpl<T, 0, kOpRegistryClassNumber,
kOpRegistryClassNumber == 0,
IsMatchedBaseType<T, 0>()>;

} // namespace internal

template <typename T>
struct OpInfoFillTypeID {
static constexpr OpInfoFillType ID() {
return std::is_base_of<OperatorBase, T>::value
? kOperator
: (std::is_base_of<OpProtoAndCheckerMaker, T>::value
? kOpProtoAndCheckerMaker
: (std::is_base_of<GradOpDescMakerBase, T>::value
? kGradOpDescMaker
: (std::is_base_of<VarTypeInference, T>::value
? kVarTypeInference
: (std::is_base_of<InferShapeBase, T>::value
? kShapeInference
: (std::is_base_of<
InplaceOpInference, T>::value
? kInplaceOpInference
: static_cast<OpInfoFillType>(
-1))))));
return internal::OpInfoFillTypeGetter<T>::kType;
}
};

Expand Down Expand Up @@ -156,6 +216,18 @@ struct OpInfoFiller<T, kInplaceOpInference> {
}
};

template <typename T>
struct OpInfoFiller<T, kNoNeedBufferVarsInference> {
void operator()(const char* op_type, OpInfo* info) const {
info->infer_no_need_buffer_vars_ = [](const VariableNameMap& inputs,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

curious about when will these three parameters be used to get the NoNeedBufferVars, seems now we just return the parameters specified in the macro as an unordered_set?

Copy link
Collaborator Author

@sneaxiy sneaxiy Mar 27, 2019

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I reserve these parameters for future use. Some ops may not need some forward inputs or outputs when some attribute is true/false. For example, batch_norm_grad_op does not need Bias when use_mkldnn is false.

const VariableNameMap& outputs,
const AttributeMap& attrs) {
T infer(inputs, outputs, attrs);
return infer();
};
}
};

} // namespace details

} // namespace framework
Expand Down
Loading