Closed
Description
Implementation
`paddle::platform::Transform` is implemented as follows:
namespace paddle {
namespace platform {
// Transform on host or device. It provides the same API in std library.
template <typename InputIter, typename OutputIter, typename UnaryOperation>
void Transform(const DeviceContext& context, InputIter first, InputIter last,
               OutputIter result, UnaryOperation op) {
  auto place = context.GetPlace();
  if (is_cpu_place(place)) {
    std::transform(first, last, result, op);
  } else {
#ifdef __NVCC__
    auto& ctx = reinterpret_cast<const CUDADeviceContext&>(context);
    using namespace details;
    thrust::transform(thrust::cuda::par.on(ctx.stream()), DevPtrCast(first),
                      DevPtrCast(last), DevPtrCast(result), op);
#else
    PADDLE_THROW("Do not invoke `Transform<GPUPlace>` in .cc file");
#endif
  }
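}
- Purpose of the implementation:
  - Use of the transform function: the CPU path uses std::transform and the GPU path uses thrust::transform.
  - The goal is for CPU and GPU kernels to share the same paddle::platform::Transform, so that some kernels need neither separate CPU and GPU implementations nor if-else branches and can share a single kernel directly.
  - To make the code compile correctly, the __NVCC__ macro is used (#ifdef __NVCC__ #else #endif): when nvcc compiles this code, the thrust::transform code is included, and when g++/gcc compiles it, the thrust::transform code is excluded.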
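The dispatch described above can be sketched in a self-contained form. This is only an illustration: `Place`, `DeviceContext`, `is_cpu_place`, and the simplified `PReluFunctor` below are hypothetical stand-ins, not Paddle's real types, and the GPU branch is left as a placeholder where the real code calls thrust::transform.

```cpp
#include <algorithm>
#include <stdexcept>
#include <vector>

// Hypothetical stand-ins for Paddle's place and device-context types.
enum class Place { kCPU, kGPU };

struct DeviceContext {
  Place place;
  Place GetPlace() const { return place; }
};

inline bool is_cpu_place(Place p) { return p == Place::kCPU; }

// Same shape as the Transform in this issue: one entry point, with the
// CPU/GPU choice made at run time and the GPU body guarded by __NVCC__.
template <typename InputIter, typename OutputIter, typename UnaryOperation>
void Transform(const DeviceContext& context, InputIter first, InputIter last,
               OutputIter result, UnaryOperation op) {
  if (is_cpu_place(context.GetPlace())) {
    std::transform(first, last, result, op);
  } else {
#ifdef __NVCC__
    // Placeholder: the real code calls thrust::transform on the CUDA stream.
#else
    // A host compiler never sees the GPU body, so it can only report an error.
    throw std::runtime_error("Do not invoke `Transform<GPUPlace>` in .cc file");
#endif
  }
}

// Simplified PReLU-like functor that holds alpha by value (an assumption).
struct PReluFunctor {
  float alpha;
  float operator()(float x) const { return x > 0 ? x : x * alpha; }
};

// Runs the functor through the CPU path and returns the result.
std::vector<float> RunOnCpu(const std::vector<float>& x, float alpha) {
  std::vector<float> out(x.size());
  DeviceContext cpu{Place::kCPU};
  Transform(cpu, x.begin(), x.end(), out.begin(), PReluFunctor{alpha});
  return out;
}

// Returns true if the GPU path throws, which is what a host-only build does.
bool GpuPathThrows(const std::vector<float>& x, float alpha) {
  std::vector<float> out(x.size());
  DeviceContext gpu{Place::kGPU};
  try {
    Transform(gpu, x.begin(), x.end(), out.begin(), PReluFunctor{alpha});
  } catch (const std::runtime_error&) {
    return true;
  }
  return false;
}
```

When built with g++, `RunOnCpu` applies the functor through std::transform, while `GpuPathThrows` reports that the GPU branch can only raise an error, because the preprocessor has removed the device code from that build.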
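Problem
Take the use in prelu_op as an example.
- For example, Transform is called directly in prelu_op.h:
Transform(context.device_context(), x_ptr, x_ptr + numel, o_ptr,
          PReluFunctor<T>(alpha_ptr));
- When the op is built, prelu_op.cc is compiled first. It is compiled with g++/gcc, so the paddle::platform::Transform instantiation it produces does not contain the thrust::transform implementation.
- After prelu_op.cu is compiled, both object files contain an instantiation of paddle::platform::Transform with the same signature, and at link time the one without the thrust::transform implementation is used. The GPU kernel's call therefore takes the #else branch and fails. (Thanks to @hedaoyuan for the guidance; I learned a lot.)
- Unit test error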
5: ERROR: test_check_grad (__main__.PReluTest)
135: ----------------------------------------------------------------------
135: Traceback (most recent call last):
135: File "test_prelu_op.py", line 24, in test_check_grad
135: self.check_grad(['X'], 'Out')
135: File "/home/dangqingqing/github/myfork/Paddle/python/paddle/v2/framework/tests/op_test.py", line 291, in check_grad
135: for grad_name in grad_names
135: File "/home/dangqingqing/github/myfork/Paddle/python/paddle/v2/framework/tests/op_test.py", line 164, in get_gradient
135: op.run(scope, ctx)
135: RuntimeError: Do not invoke `Transform<GPUPlace>` in .cc file at [/home/dangqingqing/github/myfork/Paddle/paddle/platform/transform.h:46]
135: PaddlePaddle Call Stacks:
135: 0 0x7ffa70cb8ed4p paddle::platform::EnforceNotMet::EnforceNotMet(std::__exception_ptr::exception_ptr, char const*, int) + 576
135: 1 0x7ffa70d5022cp void paddle::platform::Transform<float const*, float*, paddle::operators::PReluFunctor<float> >(paddle::platform::DeviceContext const&, float const*, float const*, float*, paddle::operators::PReluFunctor<float>) + 203
135: 2 0x7ffa70d52f85p paddle::operators::PReluKernel<paddle::platform::GPUPlace, float>::Compute(paddle::framework::ExecutionContext const&) const + 487
135: 3 0x7ffa70d1d87dp paddle::framework::OperatorWithKernel::Run(paddle::framework::Scope const&, paddle::platform::DeviceContext const&) const + 155
135: 4 0x7ffa70cd9146p pybind11::cpp_function::cpp_function<void, paddle::framework::OperatorBase, paddle::framework::Scope const&, paddle::platform::DeviceContext const&, pybind11::name, pybind11::is_method, pybind11::sibling>(void (paddle::framework::OperatorBase::*)(paddle::framework::Scope const&, paddle::platform::DeviceContext const&) const, pybind11::name const&, pybind11::is_method const&, pybind11::sibling const&)::{lambda(paddle::framework::OperatorBase const*, paddle::framework::Scope const&, paddle::platform::DeviceContext const&)#1}::operator()(paddle::framework::OperatorBase const*, paddle::framework::Scope const&, paddle::platform::DeviceContext const&) const + 118
Improvement
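One possible direction, sketched here under assumptions and not taken from the issue itself, is to make the place a template parameter so that the host and device variants become different symbols. The `CPUPlace`/`GPUPlace` tags and the functor-style `Transform` below are illustrative stand-ins, and the device specialization uses the generic `thrust::device` policy rather than the stream-bound policy shown earlier.

```cpp
#include <algorithm>
#include <vector>

#ifdef __NVCC__
#include <thrust/execution_policy.h>
#include <thrust/transform.h>
#endif

// Hypothetical place tags, named after CPUPlace/GPUPlace for illustration only.
struct CPUPlace {};
struct GPUPlace {};

// Primary template is declared but not defined, so using a place that has no
// specialization in the current build fails at compile time, not at run time.
template <typename Place>
struct Transform;

// Host specialization: safe to include from both .cc and .cu files.
template <>
struct Transform<CPUPlace> {
  template <typename InputIter, typename OutputIter, typename UnaryOperation>
  void operator()(InputIter first, InputIter last, OutputIter result,
                  UnaryOperation op) const {
    std::transform(first, last, result, op);
  }
};

#ifdef __NVCC__
// Device specialization: only exists when nvcc compiles the file, so a .cc
// file cannot emit a conflicting host-only body under the same symbol name.
template <>
struct Transform<GPUPlace> {
  template <typename InputIter, typename OutputIter, typename UnaryOperation>
  void operator()(InputIter first, InputIter last, OutputIter result,
                  UnaryOperation op) const {
    thrust::transform(thrust::device, first, last, result, op);
  }
};
#endif

// Small host-side usage example: doubles every element on the CPU path.
std::vector<int> DoubleOnCpu(const std::vector<int>& x) {
  std::vector<int> out(x.size());
  Transform<CPUPlace> trans;
  trans(x.begin(), x.end(), out.begin(), [](int v) { return v * 2; });
  return out;
}
```

With this layout a kernel templated on the place would call `Transform<Place>`, so the CPU kernel in the .cc file and the GPU kernel in the .cu file no longer instantiate a function with the same mangled name, and a GPU call made from a host-only file would be rejected at build time instead of failing in the unit test.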
wanghaoshuang