Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
29 commits
Select commit Hold shift + click to select a range
9ea0932
add rename guard
jacquesqiao Jan 2, 2018
1169d48
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
jacquesqiao Jan 2, 2018
9d23ab5
add device_data_transform
jacquesqiao Jan 3, 2018
abc86fb
add device_data_transform_test
jacquesqiao Jan 3, 2018
2ce1783
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
jacquesqiao Jan 3, 2018
4fc3a26
modify GetExpectedKernelType
jacquesqiao Jan 3, 2018
285218b
update operator.run
jacquesqiao Jan 3, 2018
8363365
support test test_label_semantic_roles
jacquesqiao Jan 4, 2018
dbcf818
optimize code
jacquesqiao Jan 4, 2018
012c345
optimize code
jacquesqiao Jan 4, 2018
a01e310
rename GetActualKernelType to GetExpectedKernelType
jacquesqiao Jan 4, 2018
a8862a9
fix chunk_eval_op and device_data_transform_test
jacquesqiao Jan 4, 2018
b06cd27
add is_same_place to place
jacquesqiao Jan 4, 2018
3db3e6f
optimize code, refine rename_guard
jacquesqiao Jan 4, 2018
28e55f8
refine rename guard, add GetKernelTypeForVar
jacquesqiao Jan 5, 2018
c60e058
optimize code
jacquesqiao Jan 5, 2018
72851f3
add some log
jacquesqiao Jan 5, 2018
e771d22
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
jacquesqiao Jan 5, 2018
a253f39
rename guard
jacquesqiao Jan 5, 2018
152c24a
use sub scope to create var
jacquesqiao Jan 5, 2018
85a4262
fix compile
jacquesqiao Jan 5, 2018
ae580c7
add IsInitialized for Tensor
jacquesqiao Jan 5, 2018
3c43984
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
jacquesqiao Jan 6, 2018
b56c61d
add VarIsTensor
jacquesqiao Jan 6, 2018
92f5c8d
fix op_registry_test
jacquesqiao Jan 6, 2018
82787a1
test
jacquesqiao Jan 7, 2018
87c35ea
tmp disable priority
jacquesqiao Jan 8, 2018
9c355ae
restore switch_kernel.md
jacquesqiao Jan 8, 2018
ec76669
code clean
jacquesqiao Jan 8, 2018
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 6 additions & 1 deletion paddle/framework/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,9 @@ cc_test(threadpool_test SRCS threadpool_test.cc DEPS threadpool)
cc_library(scope SRCS scope.cc DEPS glog threadpool)
cc_test(scope_test SRCS scope_test.cc DEPS scope)

cc_library(data_transform SRCS data_transform.cc DEPS math_function tensor framework_proto)
cc_library(device_data_transform SRCS device_data_transform.cc DEPS tensor)

cc_library(data_transform SRCS data_transform.cc DEPS math_function tensor framework_proto selected_rows device_data_transform)
cc_test(data_transform_test SRCS data_transform_test.cc DEPS data_transform device_context)

cc_library(attribute SRCS attribute.cc DEPS framework_proto)
Expand Down Expand Up @@ -77,3 +79,6 @@ cc_library(init SRCS init.cc DEPS gflags device_context place stringpiece operat
cc_test(init_test SRCS init_test.cc DEPS init)

cc_test(op_kernel_type_test SRCS op_kernel_type_test.cc DEPS place device_context framework_proto)

nv_test(device_data_transform_test SRCS device_data_transform_test.cu
DEPS operator op_registry init math_function)
33 changes: 33 additions & 0 deletions paddle/framework/data_transform.cc
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,9 @@ limitations under the License. */
#include <functional>

#include "paddle/framework/data_transform.h"
#include "paddle/framework/device_data_transform.h"
#include "paddle/framework/lod_tensor.h"
#include "paddle/framework/selected_rows.h"
#include "paddle/platform/device_context.h"

namespace paddle {
Expand All @@ -25,6 +27,37 @@ DataTransformFnMap& DataTransformFnMap::Instance() {
return data_transform_map;
}

Tensor* DataTransform(const OpKernelType& expected_kernel_type,
const OpKernelType& kernel_type_for_var,
const Tensor& input_tensor) {
Tensor* out = nullptr;
if (!platform::is_same_place(kernel_type_for_var.place_,
expected_kernel_type.place_)) {
out = DeviceTransform(input_tensor, expected_kernel_type.place_);
}
PADDLE_ENFORCE_NOT_NULL(out, "out should not be null");
return out;
}

void CopyVariableWithTensor(const Variable& in_var, const Tensor& tensor,
Variable& out_var) {
if (in_var.IsType<LoDTensor>()) {
auto& in_lod_tensor = in_var.Get<LoDTensor>();
auto* tran_lod_tensor = out_var.GetMutable<LoDTensor>();
tran_lod_tensor->set_lod(in_lod_tensor.lod());
tran_lod_tensor->set_layout(in_lod_tensor.layout());
tran_lod_tensor->ShareDataWith(tensor);
} else if (in_var.IsType<SelectedRows>()) {
auto& in_selected_rows = in_var.Get<SelectedRows>();
auto* trans_selected_rows = out_var.GetMutable<SelectedRows>();
trans_selected_rows->set_height(in_selected_rows.height());
trans_selected_rows->set_rows(in_selected_rows.rows());
trans_selected_rows->mutable_value()->ShareDataWith(tensor);
} else {
PADDLE_THROW("unknown var type");
}
}

auto KernelFP32 = OpKernelType(proto::DataType::FP32, platform::CPUPlace(),
DataLayout::kNHWC, LibraryType::kPlain);

Expand Down
8 changes: 8 additions & 0 deletions paddle/framework/data_transform.h
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ limitations under the License. */
#include <vector>

#include "paddle/framework/op_kernel_type.h"
#include "paddle/framework/selected_rows.h"
#include "paddle/framework/tensor.h"
#include "paddle/framework/variable.h"
#include "paddle/operators/math/math_function.h"
Expand Down Expand Up @@ -49,6 +50,13 @@ struct KernelTypePairHash {
}
};

Tensor* DataTransform(const OpKernelType& expected_kernel_type,
const OpKernelType& kernel_type_for_var,
const Tensor& input_tensor);

void CopyVariableWithTensor(const Variable& in_var, const Tensor& tensor,
Variable& out_var);

template <typename InType, typename OutType>
struct CastDataTypeFunctor {
HOSTDEVICE inline OutType operator()(InType in) const {
Expand Down
46 changes: 46 additions & 0 deletions paddle/framework/device_data_transform.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/framework/device_data_transform.h"

namespace paddle {
namespace framework {

static const platform::DeviceContext* GetDeviceContext(
const platform::Place& src_place, const platform::Place& dst_place) {
platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance();

if (platform::is_gpu_place(src_place) && platform::is_cpu_place(dst_place)) {
return pool.Get(src_place);
} else if (platform::is_cpu_place(src_place) &&
platform::is_gpu_place(dst_place)) {
return pool.Get(dst_place);
} else {
PADDLE_THROW(
"Currently, model parallelism is only supported between CPU and CUDA");
}
}

Tensor* DeviceTransform(const Tensor& in, const platform::Place& dst_place) {
VLOG(3) << "DeviceTransform in, src_place " << in.place()
<< " dst_place: " << dst_place;
Tensor* out = new Tensor();
auto* dev_ctx = GetDeviceContext(in.place(), dst_place);
dev_ctx->Wait();
CopyFrom(in, dst_place, *dev_ctx, out);
dev_ctx->Wait();
return out;
}

} // namespace framework
} // namespace paddle
27 changes: 27 additions & 0 deletions paddle/framework/device_data_transform.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once

#include "paddle/framework/lod_tensor.h"
#include "paddle/framework/tensor.h"
#include "paddle/framework/tensor_util.h"
#include "paddle/platform/device_context.h"

namespace paddle {
namespace framework {

Tensor* DeviceTransform(const Tensor& in, const platform::Place& dst_place);

} // namespace framework
} // namespace paddle
168 changes: 168 additions & 0 deletions paddle/framework/device_data_transform_test.cu
Original file line number Diff line number Diff line change
@@ -0,0 +1,168 @@
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "gtest/gtest.h"

#include "paddle/framework/init.h"
#include "paddle/framework/lod_tensor.h"
#include "paddle/framework/op_info.h"
#include "paddle/framework/op_registry.h"
#include "paddle/operators/elementwise_op_function.h"
#include "paddle/operators/math/math_function.h"
#include "paddle/platform/device_context.h"

namespace paddle {
namespace framework {

template <typename T>
struct AddFunctor {
inline HOSTDEVICE T operator()(T a, T b) const { return a + b; }
};

class OpKernelTestProtoAndCheckerMaker : public OpProtoAndCheckerMaker {
public:
OpKernelTestProtoAndCheckerMaker(OpProto* proto, OpAttrChecker* op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("input", "input1 of test op");
AddOutput("output", "output of test op");
AddAttr<bool>("use_gpu", "force to use gpu kernel").SetDefault(false);
AddComment("This is test op");
}
};

class TestOpWithKernel : public OperatorWithKernel {
public:
using OperatorWithKernel::OperatorWithKernel;

protected:
void InferShape(framework::InferShapeContext* ctx) const override {}
OpKernelType GetExpectedKernelType(
const ExecutionContext& ctx) const override {
if (Attr<bool>("use_gpu")) {
VLOG(3) << "force use gpu kernel";
return OpKernelType(proto::DataType::FP32, platform::CUDAPlace(0));
} else {
VLOG(3) << "use default kernel";
return OpKernelType(proto::DataType::FP32,
ctx.Input<Tensor>("input")->place());
}
}
};

template <typename DeviceContext, typename T>
class TestKernel : public OpKernel<float> {
public:
void Compute(const ExecutionContext& ctx) const {
std::cout << ctx.op().DebugString() << std::endl;

const Tensor* input = ctx.Input<Tensor>("input");

std::cout << "input place:" << input->place() << std::endl;
auto* output = ctx.Output<framework::LoDTensor>("output");
output->Resize(input->dims());
output->mutable_data<T>(ctx.GetPlace());

operators::TransformFunctor<AddFunctor<T>, T, DeviceContext> functor(
input, input, output, ctx.template device_context<DeviceContext>(),
AddFunctor<T>());
functor.Run();
}
};

} // namespace framework
} // namespace paddle

REGISTER_OP_WITHOUT_GRADIENT(
test_op, paddle::framework::TestOpWithKernel,
paddle::framework::OpKernelTestProtoAndCheckerMaker);
REGISTER_OP_CPU_KERNEL(
test_op,
paddle::framework::TestKernel<paddle::platform::CPUDeviceContext, float>);
REGISTER_OP_CUDA_KERNEL(
test_op,
paddle::framework::TestKernel<paddle::platform::CUDADeviceContext, float>);

static void BuildVar(const std::string& param_name,
std::initializer_list<const char*> arguments,
paddle::framework::proto::OpDesc::Var* var) {
var->set_parameter(param_name);
for (auto& arg_name : arguments) {
*var->mutable_arguments()->Add() = arg_name;
}
}

TEST(Operator, CPUtoGPU) {
using namespace paddle::framework;
using namespace paddle::platform;

ASSERT_EQ(InitDevices({"CPU", "GPU:0"}), true);

paddle::framework::Scope scope;
paddle::platform::CPUPlace cpu_place;

// create an op to run on CPU
paddle::framework::proto::OpDesc cpu_op_desc;
cpu_op_desc.set_type("test_op");
BuildVar("input", {"IN1"}, cpu_op_desc.add_inputs());
BuildVar("output", {"OUT1"}, cpu_op_desc.add_outputs());

auto cpu_op = paddle::framework::OpRegistry::CreateOp(cpu_op_desc);
// prepare input
auto* in_t = scope.Var("IN1")->GetMutable<LoDTensor>();
auto* src_ptr = in_t->mutable_data<float>({2, 3}, CPUPlace());
for (int i = 0; i < 2 * 3; ++i) {
src_ptr[i] = static_cast<float>(i);
}

// get output
auto* output = scope.Var("OUT1");
cpu_op->Run(scope, cpu_place);

auto* output_ptr = output->Get<LoDTensor>().data<float>();
for (int i = 0; i < 2 * 3; ++i) {
ASSERT_EQ(output_ptr[i], static_cast<float>(i) * 2);
}

// create an op to run on GPU
paddle::framework::proto::OpDesc gpu_op_desc;
gpu_op_desc.set_type("test_op");
BuildVar("input", {"OUT1"}, gpu_op_desc.add_inputs());
BuildVar("output", {"OUT2"}, gpu_op_desc.add_outputs());

auto attr = gpu_op_desc.mutable_attrs()->Add();
attr->set_name("use_gpu");
attr->set_type(paddle::framework::proto::AttrType::BOOLEAN);
attr->set_b(true);

auto gpu_op = paddle::framework::OpRegistry::CreateOp(gpu_op_desc);

paddle::platform::CUDAPlace cuda_place(0);
// get output
auto* output2 = scope.Var("OUT2");
gpu_op->Run(scope, cuda_place);

// auto* output2_ptr = output2->Get<LoDTensor>().data<float>();
DeviceContextPool& pool = DeviceContextPool::Instance();
auto dev_ctx = pool.Get(cuda_place);

paddle::framework::Tensor output_tensor;
CopyFrom(output2->Get<LoDTensor>(), paddle::platform::CPUPlace(), *dev_ctx,
&output_tensor);

dev_ctx->Wait();
float* output2_ptr = output_tensor.data<float>();
for (int i = 0; i < 2 * 3; ++i) {
ASSERT_EQ(output2_ptr[i], static_cast<float>(i) * 4);
}
}
22 changes: 9 additions & 13 deletions paddle/framework/op_registry_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -218,7 +218,7 @@ class OpWithKernelTest : public OperatorWithKernel {
protected:
void InferShape(InferShapeContext* ctx) const override {}

framework::OpKernelType GetActualKernelType(
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType(proto::DataType::FP32, ctx.device_context());
}
Expand Down Expand Up @@ -282,16 +282,11 @@ class OpWithMultiKernelTest : public OperatorWithKernel {
protected:
void InferShape(InferShapeContext* ctx) const override {}

framework::OpKernelType GetActualKernelType(
const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType(proto::DataType::FP32, ctx.device_context());
}

framework::OpKernelType GetExpectedKernelType(
const framework::OpKernelType& kernel) const override {
return framework::OpKernelType(kernel.data_type_, platform::CUDAPlace(0),
kernel.data_layout_,
framework::LibraryType::kCUDNN);
const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType(
proto::DataType::FP32, platform::CUDAPlace(0), DataLayout::kAnyLayout,
framework::LibraryType::kCUDNN);
}
};

Expand Down Expand Up @@ -371,6 +366,7 @@ TEST(OperatorRegistrar, OpWithMultiKernel) {
op_desc.set_type("op_with_multi_kernel");
auto op = paddle::framework::OpRegistry::CreateOp(op_desc);

// TODO(qiao) add priority back
// use all available kernels
paddle::framework::UseALL();
op->Run(scope, cuda_place);
Expand All @@ -380,16 +376,16 @@ TEST(OperatorRegistrar, OpWithMultiKernel) {
paddle::framework::UseCPU();
op->Run(scope, cpu_place);

EXPECT_EQ(op_test_value, -9);
EXPECT_EQ(op_test_value, -20);

// add cuda kernels
paddle::framework::UseCUDA();
op->Run(scope, cuda_place);

EXPECT_EQ(op_test_value, -10);
EXPECT_EQ(op_test_value, -30);

// use cudnn kernel
paddle::framework::UseCUDNN();
op->Run(scope, cuda_place);
EXPECT_EQ(op_test_value, -20);
EXPECT_EQ(op_test_value, -40);
}
Loading