Skip to content
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions paddle/fluid/framework/custom_operator_utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,13 @@ inline static const OpMetaInfo& GetOpInfoByPirName(
const std::string& pir_op_name) {
auto custom_name = pir_op_name.substr(strlen(kCustomDialectPrefix));
int pos = custom_name.length();

if (custom_name[pos - 1] == '_') {
// deal with inplace name
custom_name = custom_name.substr(0, pos - 1);
}

pos = custom_name.length();
if (custom_name.find("_grad_grad") != custom_name.npos) {
pos = custom_name.find("_grad_grad") + 1;
} else if (custom_name.find("_grad") != custom_name.npos) {
Expand Down

Large diffs are not rendered by default.

Original file line number Diff line number Diff line change
Expand Up @@ -45,15 +45,26 @@ class CustomKernelInstruction : public InstructionBase {
void BuildCustomContext(
const paddle::dialect::OpYamlInfoParser& op_yaml_info);

void BuildShapeDtype();

void UpdateOutputMeta(const std::vector<std::vector<int64_t>>& output_shapes,
const std::vector<DataType>& output_dtypes);

std::vector<std::vector<int64_t>> RunDefaultInferShape();
std::vector<DataType> RunDefaultInferDtype();
void CheckDefaultInferShapeDtype(
const paddle::dialect::OpYamlInfoParser& op_yaml_info);

paddle::CustomOpKernelContext custom_kernel_ctx_;

paddle::InferShapeFunc infershape_func_ = nullptr;
paddle::InferDtypeFunc inferdtype_func_ = nullptr;
paddle::KernelFunc kernel_func_ = nullptr;

// key is input name, value is a index in input_shapes_ or vec_input_shapes_
std::unordered_map<std::string, int> input_name2id_map_;
std::unordered_map<std::string, int> vec_input_name2id_map_;

// use for runing infershape
std::vector<std::vector<int64_t>> input_shapes_;
std::vector<std::vector<std::vector<int64_t>>> vec_input_shapes_;
Expand All @@ -63,6 +74,10 @@ class CustomKernelInstruction : public InstructionBase {
std::vector<DataType> input_dtypes_;
std::vector<std::vector<DataType>> vec_input_dtypes_;

// use for calculate input shapes and dtypes in runtime
std::vector<phi::DenseTensor*> input_ptrs_;
std::vector<std::vector<phi::DenseTensor*>> vec_input_ptrs_;

// use for update output
std::vector<phi::DenseTensor*> cache_out_ptrs_;

Expand Down
9 changes: 4 additions & 5 deletions paddle/fluid/inference/api/demo_ci/custom_relu_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -38,12 +38,11 @@ void relu_cpu_backward_kernel(const data_t* grad_out_data,
}

std::vector<paddle::Tensor> relu_cpu_forward(const paddle::Tensor& x) {
auto out = paddle::Tensor(paddle::PlaceType::kCPU, x.shape());

auto out = paddle::empty_like(x);
PD_DISPATCH_FLOATING_TYPES(
x.type(), "relu_cpu_forward", ([&] {
relu_cpu_forward_kernel<data_t>(
x.data<data_t>(), out.mutable_data<data_t>(x.place()), x.size());
x.data<data_t>(), out.data<data_t>(), x.size());
}));

return {out};
Expand All @@ -52,13 +51,13 @@ std::vector<paddle::Tensor> relu_cpu_forward(const paddle::Tensor& x) {
std::vector<paddle::Tensor> relu_cpu_backward(const paddle::Tensor& x,
const paddle::Tensor& out,
const paddle::Tensor& grad_out) {
auto grad_x = paddle::Tensor(paddle::PlaceType::kCPU, x.shape());
auto grad_x = paddle::empty_like(x);

PD_DISPATCH_FLOATING_TYPES(out.type(), "relu_cpu_backward", ([&] {
relu_cpu_backward_kernel<data_t>(
grad_out.data<data_t>(),
out.data<data_t>(),
grad_x.mutable_data<data_t>(x.place()),
grad_x.data<data_t>(),
out.size());
}));

Expand Down
8 changes: 4 additions & 4 deletions paddle/fluid/inference/api/demo_ci/custom_relu_op.cu
Original file line number Diff line number Diff line change
Expand Up @@ -36,15 +36,15 @@ __global__ void relu_cuda_backward_kernel(const data_t* dy,
}

std::vector<paddle::Tensor> relu_cuda_forward(const paddle::Tensor& x) {
auto out = paddle::Tensor(paddle::PlaceType::kGPU, x.shape());
auto out = paddle::empty_like(x);

int numel = x.size();
int block = 512;
int grid = (numel + block - 1) / block;
PD_DISPATCH_FLOATING_TYPES(
x.type(), "relu_cuda_forward_kernel", ([&] {
relu_cuda_forward_kernel<data_t><<<grid, block, 0, x.stream()>>>(
x.data<data_t>(), out.mutable_data<data_t>(x.place()), numel);
x.data<data_t>(), out.data<data_t>(), numel);
}));

return {out};
Expand All @@ -53,7 +53,7 @@ std::vector<paddle::Tensor> relu_cuda_forward(const paddle::Tensor& x) {
std::vector<paddle::Tensor> relu_cuda_backward(const paddle::Tensor& x,
const paddle::Tensor& out,
const paddle::Tensor& grad_out) {
auto grad_x = paddle::Tensor(paddle::PlaceType::kGPU, x.shape());
auto grad_x = paddle::empty_like(x);

int numel = out.size();
int block = 512;
Expand All @@ -63,7 +63,7 @@ std::vector<paddle::Tensor> relu_cuda_backward(const paddle::Tensor& x,
<<<grid, block, 0, x.stream()>>>(
grad_out.data<data_t>(),
out.data<data_t>(),
grad_x.mutable_data<data_t>(x.place()),
grad_x.data<data_t>(),
numel);
}));

Expand Down
27 changes: 25 additions & 2 deletions paddle/fluid/pir/dialect/operator/ir/op_dialect.cc
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
#include "paddle/fluid/pir/dialect/operator/ir/op_type.h"
#include "paddle/fluid/pir/dialect/operator/ir/pd_op.h"
#include "paddle/fluid/pir/dialect/operator/ir/type_storage.h"
#include "paddle/fluid/pir/dialect/operator/trait/inplace.h"
#include "paddle/fluid/pir/dialect/operator/transforms/param_to_variable.h"
#include "paddle/pir/core/builtin_type_interfaces.h"
#include "paddle/pir/core/interface_value.h"
Expand Down Expand Up @@ -356,9 +357,25 @@ struct CustomOpInfoInterfaceModel : public OpYamlInfoInterface::Concept {
output_name, "paddle::dialect::DenseTensorType", is_optional, false});
}

auto& inplace_maps = OpMetaInfoHelper::GetInplaceReverseMap(op_meta);

if (!inplace_maps.empty()) {
VLOG(3) << "Register Custom Operator: op inplace_map: "
<< string::join_strings(inplace_maps, ',', [](auto& pair) {
return pair.first + ": " + pair.second;
});
}

std::vector<std::pair<std::string, std::string>> vec_inplace;
for (auto inplace_map : inplace_maps) {
vec_inplace.push_back(inplace_map);
}

// we only need kernel params name in run_time_info
paddle::dialect::OpRunTimeInfo run_time_info =
paddle::dialect::OpRunTimeInfo("", {}, "", param_names, {}, {}, {}, {});
paddle::dialect::OpRunTimeInfo(
"", {}, "", param_names, {}, {}, vec_inplace, {});

return std::make_tuple(
inputs_info, attributes_info, outputs_info, run_time_info, "");
}
Expand Down Expand Up @@ -387,6 +404,13 @@ void CustomOpDialect::RegisterCustomOp(const paddle::OpMetaInfo& op_meta) {
pir::TypeId id = IdManager::Instance().CreateId();
std::string op_name = paddle::framework::kCustomDialectPrefix +
OpMetaInfoHelper::GetOpName(op_meta);
std::vector<pir::TypeId> traits;

auto& inplace_map = OpMetaInfoHelper::GetInplaceMap(op_meta);
if (!inplace_map.empty()) {
op_name += "_";
traits.push_back(pir::TypeId::get<paddle::dialect::InplaceTrait>());
}
op_names_.push_back(op_name);

auto& op_attrs = OpMetaInfoHelper::GetAttrs(op_meta);
Expand All @@ -400,7 +424,6 @@ void CustomOpDialect::RegisterCustomOp(const paddle::OpMetaInfo& op_meta) {
AttributeManager::Instance().ToCharPointers(attr_names);
uint32_t attr_num = attr_names.size();

std::vector<pir::TypeId> traits;
std::set<pir::InterfaceValue> interface_values;
pir::InterfaceValue op_info_interface =
pir::InterfaceValue::Get<OpYamlInfoInterface,
Expand Down
5 changes: 5 additions & 0 deletions test/custom_op/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,11 @@ if(WITH_TESTING)
set_tests_properties(test_context_pool PROPERTIES TIMEOUT 180)
endif()

if(WITH_GPU)
py_test(test_inference_inplace SRCS test_inference_inplace.py)
set_tests_properties(test_inference_inplace PROPERTIES TIMEOUT 180)
endif()

# custom OP support TensorRT inference
if(WITH_GPU
AND WITH_TENSORRT
Expand Down
50 changes: 50 additions & 0 deletions test/custom_op/custom_inplace.cu
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

2024,下同

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done,3q

//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WIdata_tHOUdata_t WARRANdata_tIES OR CONDIdata_tIONS OF ANY KIND, either
// express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include <iostream>
#include <vector>

#include "paddle/extension.h"

#define CHECK_GPU_INPUT(x) PD_CHECK(x.is_gpu(), #x " must be a GPU Tensor.")

template <typename data_t>
__global__ void relu_cuda_forward_kernel(data_t* x, int64_t num) {
int64_t gid = blockIdx.x * blockDim.x + threadIdx.x;
for (int64_t i = gid; i < num; i += blockDim.x * gridDim.x) {
x[i] = x[i] > static_cast<data_t>(0.) ? x[i] : static_cast<data_t>(0.);
}
}

void ReluForwardInplace(paddle::Tensor& x) { // NOLINT
CHECK_GPU_INPUT(x);

PD_CHECK(x.place() == paddle::DefaultGPUPlace());

int64_t numel = x.numel();
int64_t block = 512;
int64_t grid = (numel + block - 1) / block;
PD_DISPATCH_FLOATING_AND_HALF_TYPES(
x.type(), "relu_cuda_forward_kernel", ([&] {
relu_cuda_forward_kernel<data_t>
<<<grid, block, 0, x.stream()>>>(x.data<data_t>(), numel);
}));
}

PD_BUILD_OP(custom_relu_inplace)
.Inputs({"X"})
.Outputs({"Out"})
.SetInplaceMap({{"X", "Out"}})
.SetKernelFn(PD_KERNEL(ReluForwardInplace));
138 changes: 138 additions & 0 deletions test/custom_op/test_inference_inplace.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,138 @@
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import tempfile
import unittest

import numpy as np
from utils import (
extra_cc_args,
extra_nvcc_args,
paddle_includes,
)

import paddle
from paddle.inference import Config, create_predictor
from paddle.utils.cpp_extension import get_build_directory, load
from paddle.utils.cpp_extension.extension_utils import run_cmd

# Because Windows don't use docker, the shared lib already exists in the
# cache dir, it will not be compiled again unless the shared lib is removed.
file = f'{get_build_directory()}\\infer_custom\\infer_custom.pyd'
if os.name == 'nt' and os.path.isfile(file):
cmd = f'del {file}'
run_cmd(cmd, True)

# Compile and load custom op Just-In-Time.
custom_inplace = load(
name='infer_custom',
sources=['custom_inplace.cu'],
extra_include_paths=paddle_includes, # add for Coverage CI
extra_cxx_cflags=extra_cc_args, # test for cflags
extra_cuda_cflags=extra_nvcc_args, # test for cflags
verbose=True,
)


class TestInplaceNet(paddle.nn.Layer):
def __init__(self):
super().__init__()
self.fc = paddle.nn.Linear(4, 4)

def forward(self, x):
fc_out = self.fc(x)
out = custom_inplace.custom_relu_inplace(fc_out)
mean_out = paddle.mean(out)
return mean_out


@unittest.skipIf(
not paddle.is_compiled_with_cuda(), 'should compile with cuda.'
)
class TestPredictorRunWithTensor(unittest.TestCase):
def setUp(self):
self.temp_dir = tempfile.TemporaryDirectory()
net = TestInplaceNet()
model = paddle.jit.to_static(
net,
input_spec=[
paddle.static.InputSpec(
shape=[None, 4], dtype='float32', name='x'
),
],
)
paddle.jit.save(
model,
os.path.join(
self.temp_dir.name, 'test_predictor_run_model/inference'
),
)

def tearDown(self):
self.temp_dir.cleanup()

def enable_pir(self, flag: bool):
paddle.set_flags({'FLAGS_enable_pir_in_executor': flag})

def init_predictor(self):
config = Config(
os.path.join(
self.temp_dir.name,
'test_predictor_run_model/inference.pdmodel',
),
os.path.join(
self.temp_dir.name,
'test_predictor_run_model/inference.pdiparams',
),
)
config.enable_use_gpu(256, 0)
config.switch_ir_optim(False)
config.enable_new_executor()
predictor = create_predictor(config)
return predictor

def get_inputs(self):
x = np.array([[1, 2, 3, 4], [2, 3, 4, 5]]).astype(np.float32)

x_tensor = paddle.to_tensor(x)

return [x_tensor]

def get_outputs(self, predictor):
[x_tensor] = self.get_inputs()

input_names = predictor.get_input_names()
x_tensor.name = input_names[0]

# disorder
inputs = [x_tensor]
outputs = predictor.run(inputs)

return outputs[0]

def test_output(self):
self.enable_pir(True)
pir_predictor = self.init_predictor()
pir_output = self.get_outputs(pir_predictor)
self.enable_pir(False)
predictor = self.init_predictor()
output = self.get_outputs(predictor)
np.testing.assert_allclose(
output.numpy().flatten(), pir_output.numpy().flatten()
)


if __name__ == "__main__":
unittest.main()