Skip to content

Commit 94096ae

Browse files
authored
add memory switch mechanism in operator kernel switch (#6991)
* add memory switch mechanism in operator kernel switch
1 parent bff0cbf commit 94096ae

File tree

3 files changed

+44
-13
lines changed

3 files changed

+44
-13
lines changed

paddle/framework/CMakeLists.txt

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,8 @@ cc_test(variable_test SRCS variable_test.cc)
2121
cc_library(scope SRCS scope.cc DEPS glog)
2222
cc_test(scope_test SRCS scope_test.cc DEPS scope)
2323

24+
cc_library(data_transform SRCS data_transform.cc DEPS tensor framework_proto)
25+
cc_test(data_transform_test SRCS data_transform_test.cc DEPS data_transform device_context)
2426

2527
cc_library(attribute SRCS attribute.cc DEPS framework_proto)
2628
cc_test(program_desc_test SRCS program_desc_test.cc DEPS proto_desc
@@ -29,7 +31,8 @@ cc_library(op_proto_maker SRCS op_proto_maker.cc DEPS framework_proto attribute)
2931
cc_test(op_proto_maker_test SRCS op_proto_maker_test.cc DEPS op_proto_maker)
3032
cc_library(op_info SRCS op_info.cc DEPS attribute framework_proto)
3133
cc_library(shape_inference SRCS shape_inference.cc DEPS ddim attribute)
32-
cc_library(operator SRCS operator.cc DEPS op_info device_context tensor scope glog shape_inference)
34+
cc_library(operator SRCS operator.cc DEPS op_info device_context tensor scope glog
35+
shape_inference data_transform)
3336
cc_test(operator_test SRCS operator_test.cc DEPS operator op_registry init)
3437
cc_library(proto_desc SRCS var_desc.cc op_desc.cc block_desc.cc program_desc.cc DEPS shape_inference op_info operator glog)
3538

@@ -65,6 +68,3 @@ cc_library(init SRCS init.cc DEPS gflags device_context place stringpiece)
6568
cc_test(init_test SRCS init_test.cc DEPS init)
6669

6770
cc_test(op_kernel_type_test SRCS op_kernel_type_test.cc DEPS place device_context framework_proto)
68-
69-
cc_library(data_transform SRCS data_transform.cc DEPS tensor framework_proto)
70-
cc_test(data_transform_test SRCS data_transform_test.cc DEPS data_transform device_context)

paddle/framework/data_transform.h

Lines changed: 7 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -32,17 +32,16 @@ using DataTransformFN =
3232
const Variable& in, Variable* out)>;
3333
using KernelTypePair = std::pair<OpKernelType, OpKernelType>;
3434

35-
static void hash_combine(std::size_t& seed, const OpKernelType& t) {
36-
OpKernelType::Hash kernel_type_hasher;
37-
seed ^= kernel_type_hasher(t) + 0x9e3779b9 + (seed << 6) + (seed >> 2);
38-
}
39-
4035
struct KernelTypePairHash {
36+
static void HashCombine(const OpKernelType& t, std::size_t* seed) {
37+
OpKernelType::Hash kernel_type_hasher;
38+
(*seed) ^= kernel_type_hasher(t) + 0x9e3779b9 + (*seed << 6) + (*seed >> 2);
39+
}
40+
4141
size_t operator()(const KernelTypePair& kernel_pair) const {
4242
std::size_t seed = 0;
43-
hash_combine(seed, kernel_pair.first);
44-
hash_combine(seed, kernel_pair.second);
45-
43+
HashCombine(kernel_pair.first, &seed);
44+
HashCombine(kernel_pair.second, &seed);
4645
return seed;
4746
}
4847
};

paddle/framework/operator.cc

Lines changed: 33 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@ limitations under the License. */
1515
#include <algorithm>
1616
#include <atomic>
1717

18+
#include "paddle/framework/data_transform.h"
1819
#include "paddle/framework/executor.h"
1920
#include "paddle/framework/lod_tensor_array.h"
2021
#include "paddle/framework/operator.h"
@@ -411,7 +412,38 @@ void OperatorWithKernel::Run(const Scope& scope,
411412
expected_kernel_key);
412413
}
413414

414-
kernel_iter->second->Compute(ctx);
415+
if (actual_kernel_key == expected_kernel_key) {
416+
kernel_iter->second->Compute(ctx);
417+
} else {
418+
Scope& op_scope = scope.NewScope();
419+
auto input_vars = this->InputVars();
420+
for (auto var_name : input_vars) {
421+
op_scope.Var(var_name);
422+
}
423+
424+
// TODO(qijun) get appropriate DeviceContext from DeviceContext pool
425+
platform::DeviceContext* trans_dev_ctx = nullptr;
426+
std::vector<platform::DeviceContext*> trans_dev_ctx_vec{trans_dev_ctx};
427+
428+
// TODO(qijun) get appropriate DataTransformFN from global map
429+
framework::DataTransformFN trans_fun = nullptr;
430+
431+
// Wait for transform starting
432+
dev_ctx->Wait();
433+
434+
for (auto var_name : input_vars) {
435+
trans_fun(trans_dev_ctx_vec, *(scope.FindVar(var_name)),
436+
op_scope.FindVar(var_name));
437+
}
438+
// Wait for data transform finishing
439+
for (auto ctx : trans_dev_ctx_vec) {
440+
ctx->Wait();
441+
}
442+
443+
// Create a new ExecutionContext
444+
ExecutionContext op_ctx(*this, op_scope, *dev_ctx);
445+
kernel_iter->second->Compute(op_ctx);
446+
}
415447
}
416448

417449
OpKernelType OperatorWithKernel::GetActualKernelType(

0 commit comments

Comments
 (0)