8 changes: 4 additions & 4 deletions paddle/framework/CMakeLists.txt
@@ -21,6 +21,8 @@ cc_test(variable_test SRCS variable_test.cc)
cc_library(scope SRCS scope.cc DEPS glog)
cc_test(scope_test SRCS scope_test.cc DEPS scope)

cc_library(data_transform SRCS data_transform.cc DEPS tensor framework_proto)
cc_test(data_transform_test SRCS data_transform_test.cc DEPS data_transform device_context)

cc_library(attribute SRCS attribute.cc DEPS framework_proto)
cc_test(program_desc_test SRCS program_desc_test.cc DEPS proto_desc
@@ -29,7 +31,8 @@ cc_library(op_proto_maker SRCS op_proto_maker.cc DEPS framework_proto attribute)
cc_test(op_proto_maker_test SRCS op_proto_maker_test.cc DEPS op_proto_maker)
cc_library(op_info SRCS op_info.cc DEPS attribute framework_proto)
cc_library(shape_inference SRCS shape_inference.cc DEPS ddim attribute)
cc_library(operator SRCS operator.cc DEPS op_info device_context tensor scope glog shape_inference)
cc_library(operator SRCS operator.cc DEPS op_info device_context tensor scope glog
    shape_inference data_transform)
cc_test(operator_test SRCS operator_test.cc DEPS operator op_registry init)
cc_library(proto_desc SRCS var_desc.cc op_desc.cc block_desc.cc program_desc.cc DEPS shape_inference op_info operator glog)

@@ -65,6 +68,3 @@ cc_library(init SRCS init.cc DEPS gflags device_context place stringpiece)
cc_test(init_test SRCS init_test.cc DEPS init)

cc_test(op_kernel_type_test SRCS op_kernel_type_test.cc DEPS place device_context framework_proto)

cc_library(data_transform SRCS data_transform.cc DEPS tensor framework_proto)
cc_test(data_transform_test SRCS data_transform_test.cc DEPS data_transform device_context)
15 changes: 7 additions & 8 deletions paddle/framework/data_transform.h
@@ -32,17 +32,16 @@ using DataTransformFN =
const Variable& in, Variable* out)>;
using KernelTypePair = std::pair<OpKernelType, OpKernelType>;

static void hash_combine(std::size_t& seed, const OpKernelType& t) {
  OpKernelType::Hash kernel_type_hasher;
  seed ^= kernel_type_hasher(t) + 0x9e3779b9 + (seed << 6) + (seed >> 2);
}

struct KernelTypePairHash {
  static void HashCombine(const OpKernelType& t, std::size_t* seed) {
    OpKernelType::Hash kernel_type_hasher;
    (*seed) ^= kernel_type_hasher(t) + 0x9e3779b9 + (*seed << 6) + (*seed >> 2);
  }

  size_t operator()(const KernelTypePair& kernel_pair) const {
    std::size_t seed = 0;
    hash_combine(seed, kernel_pair.first);
    hash_combine(seed, kernel_pair.second);

    HashCombine(kernel_pair.first, &seed);
    HashCombine(kernel_pair.second, &seed);
    return seed;
  }
};
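For reference, the hasher above follows the boost-style hash_combine pattern: each member's hash is mixed into the seed with the golden-ratio constant 0x9e3779b9 and shifted seed bits. Below is a minimal, self-contained sketch of the same idea, using std::pair<int, int> and std::string as stand-ins for the framework's OpKernelType and transform function so that it compiles without Paddle headers.

#include <cstddef>
#include <functional>
#include <string>
#include <unordered_map>
#include <utility>

// Boost-style hash_combine: mixes the hash of `v` into `seed` with the
// golden-ratio constant and shift terms, mirroring KernelTypePairHash above.
template <typename T>
void HashCombine(const T& v, std::size_t* seed) {
  std::hash<T> hasher;
  (*seed) ^= hasher(v) + 0x9e3779b9 + (*seed << 6) + (*seed >> 2);
}

// A pair hasher built the same way as KernelTypePairHash, but over
// std::pair<int, int> so the example has no framework dependencies.
struct PairHash {
  std::size_t operator()(const std::pair<int, int>& p) const {
    std::size_t seed = 0;
    HashCombine(p.first, &seed);
    HashCombine(p.second, &seed);
    return seed;
  }
};

int main() {
  // The custom hasher lets the pair serve as an unordered_map key,
  // analogous to mapping a KernelTypePair to a transform function.
  std::unordered_map<std::pair<int, int>, std::string, PairHash> transform_map;
  transform_map[{0, 1}] = "transform from kernel type 0 to kernel type 1";
  return 0;
}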
34 changes: 33 additions & 1 deletion paddle/framework/operator.cc
@@ -15,6 +15,7 @@ limitations under the License. */
#include <algorithm>
#include <atomic>

#include "paddle/framework/data_transform.h"
#include "paddle/framework/executor.h"
#include "paddle/framework/lod_tensor_array.h"
#include "paddle/framework/operator.h"
@@ -411,7 +412,38 @@ void OperatorWithKernel::Run(const Scope& scope,
                 expected_kernel_key);
  }

  kernel_iter->second->Compute(ctx);
  if (actual_kernel_key == expected_kernel_key) {
    kernel_iter->second->Compute(ctx);
  } else {
    Scope& op_scope = scope.NewScope();
Member Author:

@reyoung @dzhwinter @jacquesqiao I find that we cannot cache the transformed result variables in the current scope to reduce the number of transforms. Here is an example:

       / op2
op1 ---
       \ op3

The output of op1 is the input of op2 and op3.

If we cache in the current scope (see the sketch after this list):

  • In the first training batch:
    op2 runs first, creates a new variable (var_name + KernelType), and performs the data transform.
    Then op3 checks whether this variable has already been created. Since op2 has already created it, op3 uses it directly and does not need to transform the data again.

  • In the second training batch:
    We have to perform the data transform again. But because we only check whether the new variable exists, the data transform would be skipped incorrectly.
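A rough, hypothetical sketch of this pitfall (the names and the float payload are illustrative only, not Paddle APIs): when the cache key is just the variable name plus kernel type, the second batch hits the stale entry and skips the transform even though the input data has changed.

#include <string>
#include <unordered_map>

// Hypothetical per-scope cache keyed only by "var_name/kernel_type".
struct TransformCache {
  std::unordered_map<std::string, float> transformed;

  float GetOrTransform(const std::string& var_name,
                       const std::string& kernel_type, float fresh_input) {
    const std::string key = var_name + "/" + kernel_type;
    auto it = transformed.find(key);
    if (it != transformed.end()) {
      // In the second batch this branch is taken even though fresh_input
      // has changed, so the stale result from the first batch is returned.
      return it->second;
    }
    const float out = fresh_input * 2.0f;  // stand-in for the real transform
    transformed[key] = out;
    return out;
  }
};

int main() {
  TransformCache cache;
  float batch1 = cache.GetOrTransform("op1_out", "gpu_float32", 1.0f);  // 2.0
  float batch2 = cache.GetOrTransform("op1_out", "gpu_float32", 3.0f);  // still 2.0
  return (batch1 == batch2) ? 0 : 1;  // returns 0: the stale value was reused
}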

Member Author:

I checked the Executor: in every batch, the local scope is deleted, so this problem will not happen. I will change the cache to use the local scope instead of the op scope.

Member:

Since each batch creates a new local_scope, adding a cache seems workable for our framework.

    auto input_vars = this->InputVars();
    for (auto var_name : input_vars) {
      op_scope.Var(var_name);
    }

    // TODO(qijun) get appropriate DeviceContext from DeviceContext pool
    platform::DeviceContext* trans_dev_ctx = nullptr;
    std::vector<platform::DeviceContext*> trans_dev_ctx_vec{trans_dev_ctx};

    // TODO(qijun) get appropriate DataTransformFN from global map
    framework::DataTransformFN trans_fun = nullptr;

    // Wait for transform starting
    dev_ctx->Wait();

    for (auto var_name : input_vars) {
      trans_fun(trans_dev_ctx_vec, *(scope.FindVar(var_name)),
                op_scope.FindVar(var_name));
    }
    // Wait for data transform finishing
    for (auto ctx : trans_dev_ctx_vec) {
      ctx->Wait();
    }

    // Create a new ExecutionContext
    ExecutionContext op_ctx(*this, op_scope, *dev_ctx);
    kernel_iter->second->Compute(op_ctx);
  }
}

OpKernelType OperatorWithKernel::GetActualKernelType(
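The two TODO(qijun) comments in the diff above are left open: the DeviceContext should eventually come from a DeviceContext pool, and the DataTransformFN from a global map. One plausible shape for that global map, sketched with simplified stand-in types (OpKernelType, Variable, DeviceContext, and DataTransformFN are reduced to minimal local definitions here, so this is an assumption about the eventual design rather than the framework's actual API):

#include <cstddef>
#include <functional>
#include <unordered_map>
#include <utility>
#include <vector>

// Simplified stand-ins for the framework types referenced in the diff above.
struct OpKernelType {
  int place;
  int data_type;
  bool operator==(const OpKernelType& o) const {
    return place == o.place && data_type == o.data_type;
  }
};
struct Variable {};
struct DeviceContext {
  void Wait() {}
};

using DataTransformFN = std::function<void(const std::vector<DeviceContext*>&,
                                           const Variable&, Variable*)>;
using KernelTypePair = std::pair<OpKernelType, OpKernelType>;

// Same hash_combine idea as KernelTypePairHash in data_transform.h,
// re-expressed over the simplified OpKernelType.
struct KernelTypePairHash {
  std::size_t operator()(const KernelTypePair& p) const {
    auto one = [](const OpKernelType& t) {
      std::size_t seed = std::hash<int>()(t.place);
      seed ^= std::hash<int>()(t.data_type) + 0x9e3779b9 + (seed << 6) + (seed >> 2);
      return seed;
    };
    std::size_t seed = one(p.first);
    seed ^= one(p.second) + 0x9e3779b9 + (seed << 6) + (seed >> 2);
    return seed;
  }
};

// Hypothetical global registry: Run() could look up (actual, expected)
// and call the registered function for each input variable to transform.
std::unordered_map<KernelTypePair, DataTransformFN, KernelTypePairHash>&
TransformMap() {
  static std::unordered_map<KernelTypePair, DataTransformFN, KernelTypePairHash> m;
  return m;
}

int main() {
  KernelTypePair cpu_to_gpu{{/*place=*/0, /*data_type=*/0},
                            {/*place=*/1, /*data_type=*/0}};
  TransformMap()[cpu_to_gpu] = [](const std::vector<DeviceContext*>& ctxs,
                                  const Variable& in, Variable* out) {
    // Real code would copy or cast tensor data between devices here.
  };
  auto it = TransformMap().find(cpu_to_gpu);
  if (it != TransformMap().end()) {
    Variable in, out;
    it->second({}, in, &out);
  }
  return 0;
}

With such a registry, the nullptr placeholders in the Run() branch above would be replaced by a lookup on (actual_kernel_key, expected_kernel_key) plus a device context fetched from the pool mentioned in the first TODO.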