6 changes: 5 additions & 1 deletion paddle/fluid/framework/custom_operator.cc
@@ -284,7 +284,7 @@ static void RunKernelFunc(const framework::ExecutionContext& ctx,
auto* true_out = true_out_ptrs.at(i);
auto calc_out =
std::dynamic_pointer_cast<phi::DenseTensor>(calc_outs->at(i).impl());
-  // assgin meta info
+  // assign meta info
auto* true_out_meta = phi::DenseTensorUtils::GetMutableMeta(true_out);
true_out_meta->dims = calc_out->dims();
true_out_meta->dtype = calc_out->dtype();
@@ -708,6 +708,10 @@ static void RegisterOperatorKernel(const std::string& name,
RegisterOperatorKernelWithPlace(
name, op_kernel_func, proto::VarType::RAW, platform::CUDAPlace());
#endif
+#if defined(PADDLE_WITH_XPU)
+  RegisterOperatorKernelWithPlace(
+      name, op_kernel_func, proto::VarType::RAW, platform::XPUPlace());
+#endif
}

void RegisterOperatorWithMetaInfo(const std::vector<OpMetaInfo>& op_meta_infos,
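With the XPU registration above, a custom op kernel is also registered for XPUPlace. A hedged usage sketch (assumes a Paddle build with PADDLE_WITH_XPU, an XPU device, and the custom_relu_xpu_module_setup module from later in this PR already installed):

import paddle
# Hypothetical usage; custom_relu_xpu_module_setup comes from this PR's setup.py.
import custom_relu_xpu_module_setup as custom_module

paddle.set_device('xpu')            # selects XPUPlace for subsequent ops
x = paddle.randn([4, 8])
out = custom_module.custom_relu(x)  # dispatched to the kernel registered above
print(out)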
6 changes: 6 additions & 0 deletions python/paddle/fluid/tests/custom_op/CMakeLists.txt
@@ -21,6 +21,12 @@ if(WITH_GPU OR APPLE)
endif()
endif()

+if(WITH_XPU)
+  set(CUSTOM_XPU_ENVS FLAGS_init_allocated_mem=0)
Contributor:

Did you find the root cause of this init issue?

Contributor Author:

py_test is a function implemented by Paddle itself (see cmake/generic.cmake), not a CMake built-in. Tests added through py_test run with FLAGS_init_allocated_mem=true by default, so we need to manually set FLAGS_init_allocated_mem=0 for the XPU unit test to avoid errors. (The sketch below shows what this amounts to at runtime.)
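For reference, ENVS in py_test exports environment variables to the test process, and Paddle parses FLAGS_* from the environment at startup. A minimal Python sketch of the same effect outside CMake (hedged; assumes the flag is read at import time):

import os

# Must be set before paddle is imported, since FLAGS_* are parsed at startup.
os.environ['FLAGS_init_allocated_mem'] = '0'

import paddle  # imported after setting the flag on purpose
print(paddle.is_compiled_with_xpu())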

+  py_test(test_custom_relu_op_xpu_setup SRCS test_custom_relu_op_xpu_setup.py
+          ENVS ${CUSTOM_XPU_ENVS})
+endif()

py_test(test_custom_raw_op_kernel_op SRCS test_custom_raw_op_kernel_op.py)
set_tests_properties(test_custom_raw_op_kernel_op PROPERTIES TIMEOUT 180)

66 changes: 66 additions & 0 deletions python/paddle/fluid/tests/custom_op/custom_relu_op_xpu.cc
@@ -0,0 +1,66 @@
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
Contributor:

2021 -> 2022

Contributor Author:

Thanks, will fix this in the next PR.

//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include <iostream>
#include <vector>

#include "paddle/extension.h"

#define CHECK_CPU_INPUT(x) PD_CHECK(x.is_cpu(), #x " must be a CPU Tensor.")
#define CHECK_XPU_INPUT(x) PD_CHECK(x.is_xpu(), #x " must be an XPU Tensor.")

template <typename data_t>
void relu_cpu_forward_kernel(const data_t* x_data,
data_t* out_data,
int64_t x_numel) {
PD_CHECK(x_data != nullptr, "x_data is nullptr.");
PD_CHECK(out_data != nullptr, "out_data is nullptr.");
for (int64_t i = 0; i < x_numel; ++i) {
out_data[i] = std::max(static_cast<data_t>(0.), x_data[i]);
}
}

std::vector<paddle::Tensor> relu_cpu_forward(const paddle::Tensor& x) {
CHECK_CPU_INPUT(x);
auto out = paddle::empty_like(x);

PD_DISPATCH_FLOATING_TYPES(
x.type(), "relu_cpu_forward", ([&] {
relu_cpu_forward_kernel<data_t>(
x.data<data_t>(), out.data<data_t>(), x.numel());
}));

return {out};
}

std::vector<paddle::Tensor> relu_xpu_forward(const paddle::Tensor& x) {
CHECK_XPU_INPUT(x);
auto out = paddle::relu(x);
Contributor:

Can the XPU API be called directly here? Paddle ships the XPU API headers and .so.

Contributor Author:

In theory, yes. The load function's extra_include_paths and build_directory parameters can specify the include directory and the dynamic-library directory (ref. https://www.paddlepaddle.org.cn/documentation/docs/zh/api/paddle/utils/cpp_extension/load_cn.html). A sketch follows below.
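A hedged sketch of the reply above (the concrete paths are hypothetical placeholders for a local XPU toolkit):

from paddle.utils.cpp_extension import load

# extra_include_paths / build_directory are documented load() parameters;
# the paths below are hypothetical and depend on the local setup.
custom_ops = load(
    name='custom_relu_xpu_jit',
    sources=['custom_relu_op_xpu.cc'],
    extra_include_paths=['/opt/xpu/include'],  # headers for the raw XPU API
    build_directory='./build_xpu',             # where the generated .so lands
    verbose=True,
)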

return {out};
}

std::vector<paddle::Tensor> ReluForward(const paddle::Tensor& x) {
if (x.is_cpu()) {
return relu_cpu_forward(x);
} else if (x.is_xpu()) {
return relu_xpu_forward(x);
} else {
PD_THROW("Not implemented.");
}
}

PD_BUILD_OP(custom_relu)
Contributor:

How should the backward be added?

Contributor Author:

The grad op can be built from the forward API; the next PR will add more test cases here. See the sketch after this file for one way to check it.

.Inputs({"X"})
.Outputs({"Out"})
.SetKernelFn(PD_KERNEL(ReluForward));
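Once a grad kernel is registered, a dynamic-mode test could compare gradients against paddle.nn.functional.relu. A hedged sketch of such a check (assumes the backward lands in the follow-up PR; check_relu_grad is a hypothetical helper):

import numpy as np
import paddle

def check_relu_grad(custom_relu, device='xpu', dtype='float32'):
    paddle.set_device(device)
    np_x = np.random.uniform(-1, 1, [4, 8]).astype(dtype)

    def grad_of(func):
        t = paddle.to_tensor(np_x, dtype=dtype, stop_gradient=False)
        func(t).sum().backward()
        return t.grad.numpy()

    # custom op gradient vs. built-in relu gradient on identical input
    np.testing.assert_array_equal(grad_of(custom_relu),
                                  grad_of(paddle.nn.functional.relu))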
27 changes: 27 additions & 0 deletions python/paddle/fluid/tests/custom_op/custom_relu_xpu_setup.py
@@ -0,0 +1,27 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
Contributor:

2021 -> 2022

Contributor Author:

Thanks, will fix this in the next PR.

#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from utils import extra_compile_args, paddle_includes

from paddle.utils.cpp_extension import CppExtension, setup

setup(
name='custom_relu_xpu_module_setup',
ext_modules=CppExtension(  # XPU does not support GPU
sources=['custom_relu_op_xpu.cc'],
include_dirs=paddle_includes,
extra_compile_args=extra_compile_args,
verbose=True,
),
)
136 changes: 136 additions & 0 deletions python/paddle/fluid/tests/custom_op/test_custom_relu_op_xpu_setup.py
@@ -0,0 +1,136 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
Contributor:

2021 -> 2022

Contributor Author:

Thanks, will fix this in the next PR.

#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import site
import sys
import unittest

import numpy as np

import paddle
import paddle.static as static
from paddle.fluid.framework import _test_eager_guard
from paddle.utils.cpp_extension.extension_utils import run_cmd


def custom_relu_dynamic(func, device, dtype, np_x, use_func=True):
paddle.set_device(device)

t = paddle.to_tensor(np_x, dtype=dtype)
out = func(t) if use_func else paddle.nn.functional.relu(t)

return out.numpy()


def custom_relu_static(
func, device, dtype, np_x, use_func=True, test_infer=False
):
paddle.enable_static()
paddle.set_device(device)

with static.scope_guard(static.Scope()):
with static.program_guard(static.Program()):
x = static.data(name='X', shape=[None, 8], dtype=dtype)
out = func(x) if use_func else paddle.nn.functional.relu(x)

exe = static.Executor()
exe.run(static.default_startup_program())
# in static mode, x's data may have been overwritten by out
out_v = exe.run(
static.default_main_program(),
feed={'X': np_x},
fetch_list=[out.name],
)

paddle.disable_static()
return out_v


class TestNewCustomOpSetUpInstall(unittest.TestCase):
def setUp(self):
cur_dir = os.path.dirname(os.path.abspath(__file__))
# compile and install the custom op egg into site-packages in the background
# Currently custom XPU op does not support Windows
if os.name == 'nt':
return
cmd = 'cd {} && {} custom_relu_xpu_setup.py install'.format(
cur_dir, sys.executable
)
run_cmd(cmd)

site_dir = site.getsitepackages()[0]
custom_egg_path = [
x
for x in os.listdir(site_dir)
if 'custom_relu_xpu_module_setup' in x
]
assert len(custom_egg_path) == 1, "Matched egg number is %d." % len(
custom_egg_path
)
sys.path.append(os.path.join(site_dir, custom_egg_path[0]))

# usage: import the package directly
import custom_relu_xpu_module_setup

self.custom_op = custom_relu_xpu_module_setup.custom_relu

self.dtypes = ['float32', 'float64']
self.devices = ['xpu']

# config seed
SEED = 2021
paddle.seed(SEED)
paddle.framework.random._manual_program_seed(SEED)

def test_static(self):
for device in self.devices:
for dtype in self.dtypes:
x = np.random.uniform(-1, 1, [4, 8]).astype(dtype)
out = custom_relu_static(self.custom_op, device, dtype, x)
pd_out = custom_relu_static(
self.custom_op, device, dtype, x, False
)
np.testing.assert_array_equal(
out,
pd_out,
err_msg='custom op out: {},\n paddle api out: {}'.format(
out, pd_out
),
)

def func_dynamic(self):
for device in self.devices:
for dtype in self.dtypes:
x = np.random.uniform(-1, 1, [4, 8]).astype(dtype)
out = custom_relu_dynamic(self.custom_op, device, dtype, x)
pd_out = custom_relu_dynamic(
self.custom_op, device, dtype, x, False
)
np.testing.assert_array_equal(
out,
pd_out,
err_msg='custom op out: {},\n paddle api out: {}'.format(
out, pd_out
),
)

def test_dynamic(self):
with _test_eager_guard():
self.func_dynamic()
self.func_dynamic()


if __name__ == '__main__':
unittest.main()
@@ -2,7 +2,7 @@
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
-# You may obtaina copy of the License at
+# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#