From 8875679f5fac62419a253482867cce4fb7790f91 Mon Sep 17 00:00:00 2001 From: YuanRisheng Date: Wed, 3 Jan 2024 07:12:54 +0000 Subject: [PATCH 01/12] support inplace in pir --- .../fluid/framework/custom_operator_utils.h | 7 + .../instruction/custom_kernel_instruction.cc | 274 +++++++++++++----- .../instruction/custom_kernel_instruction.h | 15 + .../inference/api/demo_ci/custom_relu_op.cc | 9 +- .../inference/api/demo_ci/custom_relu_op.cu | 8 +- .../pir/dialect/operator/ir/op_dialect.cc | 27 +- test/custom_op/CMakeLists.txt | 2 + test/custom_op/custom_inplace.cc | 21 +- test/custom_op/custom_inplace.cu | 44 +++ test/custom_op/test_inference_inplace.py | 138 +++++++++ 10 files changed, 450 insertions(+), 95 deletions(-) create mode 100644 test/custom_op/custom_inplace.cu create mode 100644 test/custom_op/test_inference_inplace.py diff --git a/paddle/fluid/framework/custom_operator_utils.h b/paddle/fluid/framework/custom_operator_utils.h index bf1750dfdbbb50..4cd9b91330d143 100644 --- a/paddle/fluid/framework/custom_operator_utils.h +++ b/paddle/fluid/framework/custom_operator_utils.h @@ -86,6 +86,13 @@ inline static const OpMetaInfo& GetOpInfoByPirName( const std::string& pir_op_name) { auto custom_name = pir_op_name.substr(strlen(kCustomDialectPrefix)); int pos = custom_name.length(); + + if (custom_name[pos - 1] == '_') { + // deal with inplace name + custom_name = custom_name.substr(0, pos - 1); + } + + pos = custom_name.length(); if (custom_name.find("_grad_grad") != custom_name.npos) { pos = custom_name.find("_grad_grad") + 1; } else if (custom_name.find("_grad") != custom_name.npos) { diff --git a/paddle/fluid/framework/new_executor/instruction/custom_kernel_instruction.cc b/paddle/fluid/framework/new_executor/instruction/custom_kernel_instruction.cc index a585976fd6b9af..378a5d2d90ee3a 100644 --- a/paddle/fluid/framework/new_executor/instruction/custom_kernel_instruction.cc +++ b/paddle/fluid/framework/new_executor/instruction/custom_kernel_instruction.cc @@ -32,10 +32,13 @@ void CustomKernelInstruction::BuildCustomContext( << "]"; auto attr_map = op_->attributes(); - + CheckDefaultInferShapeDtype(op_yaml_info); // EmplaceBackInputs auto& vec_input_tensor_params = op_yaml_info.TensorParams(true); auto& name2id = op_yaml_info.InputName2Id(); + auto inplace_id_map = op_yaml_info.GetInplaceIdMap(); + int input_index = 0; + int vec_input_index = 0; for (auto& t : vec_input_tensor_params) { PADDLE_ENFORCE_EQ( name2id.count(t), @@ -43,12 +46,12 @@ void CustomKernelInstruction::BuildCustomContext( phi::errors::NotFound("param [%s] MUST in name2id map", t)); pir::Value ptr = op_->operand_source(op_yaml_info.InputName2Id().at(t)); - if (!IsInvalid(ptr)) { if (op_yaml_info.GetInputType(op_yaml_info.InputName2Id().at(t)) == "pir::VectorType") { - vec_input_shapes_.emplace_back(); - vec_input_dtypes_.emplace_back(); + vec_input_name2id_map_[t] = vec_input_index; + vec_input_index++; + vec_input_ptrs_.emplace_back(); // NOTE(YuanRisheng): In dygraph mode, we can not distinguish Tensor and // vector when user inputs None, so dygraph mode appends one // un-initialized Tensor to CustomOpKernelContext. To be compatible with @@ -58,15 +61,16 @@ void CustomKernelInstruction::BuildCustomContext( custom_vec_in.emplace_back(paddle::Tensor()); custom_kernel_ctx_.EmplaceBackInputs(std::move(custom_vec_in)); } else { - input_shapes_.emplace_back(); - input_dtypes_.emplace_back(); + input_name2id_map_[t] = vec_input_index; + input_index++; + input_ptrs_.emplace_back(nullptr); custom_kernel_ctx_.EmplaceBackInput(std::move(paddle::Tensor())); } VLOG(8) << "ctx->EmplaceBackInput : an optioanl input " << t; continue; } - auto in_var_name = value_exec_info_.GetVarName(ptr); + VLOG(6) << "ctx->EmplaceBackInput: " << t << "\t" << in_var_name; PADDLE_ENFORCE_NOT_NULL(inner_scope->FindVar(in_var_name), @@ -75,18 +79,19 @@ void CustomKernelInstruction::BuildCustomContext( auto var = inner_scope->FindVar(in_var_name); if (var->IsType()) { auto dense_tensor_in = var->GetMutable(); + std::shared_ptr tensor_in( dense_tensor_in, [](phi::DenseTensor* ptr) { VLOG(6) << ptr << " ptr will not be deleted by shared_ptr"; }); - input_shapes_.push_back(phi::vectorize(tensor_in->dims())); - input_dtypes_.push_back(tensor_in->dtype()); + input_name2id_map_[t] = vec_input_index; + input_index++; + input_ptrs_.push_back(dense_tensor_in); paddle::Tensor custom_in; custom_in.set_impl(tensor_in); custom_kernel_ctx_.EmplaceBackInput(std::move(custom_in)); } else if (var->IsType()) { - std::vector> vec_input_shape; - std::vector vec_input_dtype; + std::vector vec_input_ptrs; std::vector vec_custom_in; auto& variable_array = var->Get(); for (size_t i = 0; i < variable_array.size(); ++i) { @@ -97,8 +102,7 @@ void CustomKernelInstruction::BuildCustomContext( dense_tensor_in, [](phi::DenseTensor* ptr) { VLOG(6) << ptr << " ptr will not be deleted by shared_ptr"; }); - vec_input_shape.push_back(phi::vectorize(tensor_in->dims())); - vec_input_dtype.push_back(tensor_in->dtype()); + vec_input_ptrs.push_back(dense_tensor_in); paddle::Tensor custom_in; custom_in.set_impl(tensor_in); vec_custom_in.push_back(std::move(custom_in)); @@ -109,15 +113,15 @@ void CustomKernelInstruction::BuildCustomContext( variable_array[i]->Type())); } } - vec_input_shapes_.push_back(vec_input_shape); - vec_input_dtypes_.push_back(vec_input_dtype); + vec_input_name2id_map_[t] = vec_input_index; + vec_input_index++; + vec_input_ptrs_.push_back(vec_input_ptrs); custom_kernel_ctx_.EmplaceBackInputs(vec_custom_in); } else { PADDLE_THROW(phi::errors::Unimplemented("Not support var type [%d] ", var->Type())); } } - // EmplaceBackAttributes auto& vec_attr_params = op_yaml_info.AttrParams(true); for (auto& t : vec_attr_params) { @@ -239,17 +243,25 @@ void CustomKernelInstruction::BuildCustomContext( VLOG(8) << "ctx->EmplaceBackOutput: "; for (size_t i = 0; i < op_->num_results(); ++i) { pir::Value out_ptr = op_->result(i); + auto out_name = op_yaml_info.OutputNames()[i]; if (!IsInvalid(out_ptr)) { - if (op_yaml_info.GetOutputType(i) == - "pir::VectorType") { - std::vector custom_vec_out; - custom_vec_out.emplace_back(); - cache_out_ptrs_.emplace_back(nullptr); - custom_kernel_ctx_.EmplaceBackOutputs(std::move(custom_vec_out)); - } else { - cache_out_ptrs_.emplace_back(nullptr); - custom_kernel_ctx_.EmplaceBackOutput(std::move(paddle::Tensor())); - } + PADDLE_ENFORCE( + paddle::framework::detail::IsOptionalVar(out_name) && + !inplace_id_map.empty(), + phi::errors::InvalidArgument( + "Custom operator couldn't find custom output for name %s. If " + "you " + "are using inplace optional inputs & outputs, please check " + "your " + "InplaceMap and `Outputs` again and make sure %s is wrapped by " + "`paddle::Optional`", + out_name, + out_name)); + VLOG(3) << "Custom Operator: BuildContext - inplace optional outputs : " + << out_name << " is None."; + cache_out_ptrs_.emplace_back(nullptr); + custom_kernel_ctx_.EmplaceBackOutput(std::move(paddle::Tensor())); + VLOG(8) << "ctx->EmplaceBackOutput : an optioanl output"; continue; } @@ -276,10 +288,15 @@ void CustomKernelInstruction::BuildCustomContext( inner_scope->FindVar(value_exec_info_.GetVarName(out_ptr)) ->Get(); std::vector custom_vec_out; - for (size_t i = 0; i < variable_array.size(); ++i) { - if (variable_array[i]->IsType()) { + PADDLE_ENFORCE( + !inplace_id_map.empty() || (i == 0UL && op_->num_results() == 1UL), + phi::errors::PreconditionNotMet( + "If custom operator's outputs contains `paddle::Vec()` type " + "without setting InplaceMap, it only can hold one output.")); + for (size_t j = 0; j < variable_array.size(); ++j) { + if (variable_array[j]->IsType()) { auto dense_tensor_out = const_cast( - &(variable_array[i]->Get())); + &(variable_array[j]->Get())); cache_out_ptrs_.emplace_back(dense_tensor_out); std::shared_ptr tensor_out( dense_tensor_out, [](phi::DenseTensor* ptr) { @@ -290,9 +307,9 @@ void CustomKernelInstruction::BuildCustomContext( custom_vec_out.push_back(std::move(custom_out)); } else { PADDLE_THROW(phi::errors::Unimplemented( - "Only support Vector and vector now, " + "Only support Vector now, " "not support vector<%d>.", - variable_array[i]->Type())); + variable_array[j]->Type())); } } VLOG(8) << "ctx->EmplaceBackOutput VariableRefArray: " @@ -303,6 +320,7 @@ void CustomKernelInstruction::BuildCustomContext( phi::errors::Unimplemented("only support DenseTensor and vector ")); } } + auto& op_inputs = OpMetaInfoHelper::GetInputs(*custom_op_meta_); auto& op_outputs = OpMetaInfoHelper::GetOutputs(*custom_op_meta_); auto& op_inplace_map = OpMetaInfoHelper::GetInplaceMap(*custom_op_meta_); @@ -408,37 +426,110 @@ void CustomKernelInstruction::UpdateOutputMeta( } } -void CustomKernelInstruction::Run() { - VLOG(3) << "Custom Operator: InferShape - calc output ddim."; - std::vector> output_shapes; - std::vector output_dtypes; - if (infershape_func_) { - output_shapes = - infershape_func_(input_shapes_, vec_input_shapes_, custom_attrs_); - } else { +void CustomKernelInstruction::CheckDefaultInferShapeDtype( + const paddle::dialect::OpYamlInfoParser& op_yaml_info) { + if (infershape_func_ && inferdtype_func_) { + return; + } + auto& inplace_map = OpMetaInfoHelper::GetInplaceMap(*custom_op_meta_); + if (inplace_map.empty()) { // general case, assure single input and output PADDLE_ENFORCE_EQ( OpMetaInfoHelper::GetInputs(*custom_op_meta_).size(), 1UL, phi::errors::Unavailable( "Your custom operator contains multiple inputs. " "We only allow a custom operator that contains only one input " - "and only one output without setting the InferShapeFn. " - "At this time, the input shape will be directly set to " - "the output shape.\n" - "Please set the InferShapeFn of custom " - "operator by .SetInferShapeFn(PD_INFER_SHAPE(...))")); + "and only one output without setting the " + "InferShapeFn/InferDtypeFn. " + "At this time, the input shape/dtype will be directly set to " + "the output shape/dtype.\n" + "Please set the InferShapeFn/InferDtypeFn of custom " + "operator by .SetInferShapeFn(PD_INFER_SHAPE(...)) / " + ".SetInferDtypeFn(PD_INFER_DTYPE(...))")); PADDLE_ENFORCE_EQ( OpMetaInfoHelper::GetOutputs(*custom_op_meta_).size(), 1UL, phi::errors::Unavailable( "Your custom operator contains multiple outputs. " "We only allow a custom operator that contains only one input " - "and only one output without setting the InferShapeFn. " - "At this time, the input shape will be directly set to " - "the output shape.\n" - "Please set the InferShapeFn of custom " - "operator by .SetInferShapeFn(PD_INFER_SHAPE(...))")); + "and only one output without setting the " + "InferShapeFn/InferDtypeFn. " + "At this time, the input shape/dtype will be directly set to " + "the output shape/dtype.\n" + "Please set the InferShapeFn/InferDtypeFn of custom " + "operator by .SetInferShapeFn(PD_INFER_SHAPE(...)) / " + ".SetInferDtypeFn(PD_INFER_DTYPE(...))")); + } else { // inplace case + PADDLE_ENFORCE_EQ( + inplace_map.size(), + op_yaml_info.OutputNames().size(), + phi::errors::Unavailable( + "Your custom operator uses `SetInplaceMap` without setting the " + "InferShapeFn/InferDtypeFn. However, `Outputs` size = %d does not " + "match the " + "`InplaceMap` size = %d. Please check `SetInplaceMap` again or set " + "the InferShapeFn/InferDtypeFn of custom operator by " + ".SetInferShapeFn(PD_INFER_SHAPE(...)) / " + ".SetInferDtypeFn(PD_INFER_DTYPE(...))", + op_yaml_info.OutputNames().size(), + inplace_map.size())); + + for (auto const& pair : inplace_map) { + pir::Value output_value = + op_->result(op_yaml_info.OutputName2Id().at(pair.second)); + if (paddle::framework::detail::IsDuplicableVar(pair.first) && + !IsInvalid(output_value)) { + // make sure ctx has valid inplace optional outputs + PADDLE_ENFORCE( + paddle::framework::detail::IsOptionalVar(pair.second), + phi::errors::InvalidArgument( + "Custom operator couldn't find custom output name for %s. If " + "you are using inplace optional inputs & outputs, please " + "check " + "your InplaceMap and `Outputs` again and make sure %s is " + "wrapped by `paddle::Optional`", + pair.second, + pair.second)); + } + } + } +} + +void CustomKernelInstruction::BuildShapeDtype() { + input_shapes_.clear(); + input_dtypes_.clear(); + vec_input_shapes_.clear(); + vec_input_dtypes_.clear(); + for (auto in_tensor : input_ptrs_) { + if (in_tensor) { + input_shapes_.push_back(phi::vectorize(in_tensor->dims())); + input_dtypes_.push_back(in_tensor->dtype()); + } else { + input_shapes_.emplace_back(); + input_dtypes_.emplace_back(); + } + } + for (auto in_tensors : vec_input_ptrs_) { + std::vector> input_shapes; + std::vector input_dtypes; + if (in_tensors.size() > 0) { + for (auto in_tensor : in_tensors) { + input_shapes.push_back(phi::vectorize(in_tensor->dims())); + input_dtypes.push_back(in_tensor->dtype()); + } + } + vec_input_shapes_.push_back(input_shapes); + vec_input_dtypes_.push_back(input_dtypes); + } +} +std::vector> +CustomKernelInstruction::RunDefaultInferShape() { + std::vector> output_shapes; + auto& inplace_map = OpMetaInfoHelper::GetInplaceMap(*custom_op_meta_); + auto& inplace_reverse_map = + OpMetaInfoHelper::GetInplaceReverseMap(*custom_op_meta_); + if (inplace_map.empty()) { // general case, assure single input and output VLOG(3) << "Custom Operator: Default InferShape - share ddim."; if (input_shapes_.size() == 1) { output_shapes = input_shapes_; @@ -449,36 +540,30 @@ void CustomKernelInstruction::Run() { "We only allow a custom operator that contains only one input " "and only one output without setting the InferShapeFn. ")); } + } else { // inplace case + for (auto const& pair : inplace_reverse_map) { + if (paddle::framework::detail::IsDuplicableVar(pair.first)) { + int input_index = vec_input_name2id_map_[pair.second]; + auto input_shape = vec_input_shapes_[input_index]; + output_shapes.insert( + output_shapes.end(), input_shape.begin(), input_shape.end()); + } else { + int input_index = input_name2id_map_[pair.second]; + auto input_shape = input_shapes_[input_index]; + output_shapes.push_back(input_shape); + } + } } + return output_shapes; +} - if (inferdtype_func_) { - output_dtypes = - inferdtype_func_(input_dtypes_, vec_input_dtypes_, custom_attrs_); - } else { - PADDLE_ENFORCE_EQ( - OpMetaInfoHelper::GetInputs(*custom_op_meta_).size(), - 1UL, - phi::errors::Unavailable( - "Your custom operator contains multiple inputs. " - "We only allow a custom operator that contains only one input " - "and only one output without setting the InferDtypeFn. " - "At this time, the input dtype will be directly set to " - "the output dtype.\n" - "Please set the InferDtypeFn of custom " - "operator by `.SetInferDtypeFn(PD_INFER_DTYPE(...))`")); - PADDLE_ENFORCE_EQ( - OpMetaInfoHelper::GetOutputs(*custom_op_meta_).size(), - 1UL, - phi::errors::Unavailable( - "Your custom operator contains multiple outputs. " - "We only allow a custom operator that contains only one input " - "and only one output without setting the InferDtypeFn. " - "At this time, the input dtype will be directly set to " - "the output dtype.\n" - "Please set the InferDtypeFn of custom " - "operator by `.SetInferDtypeFn(PD_INFER_DTYPE(...))`")); - - VLOG(3) << "Custom Operator: InferDtype - share dtype."; +std::vector CustomKernelInstruction::RunDefaultInferDtype() { + std::vector output_dtypes; + auto& inplace_map = OpMetaInfoHelper::GetInplaceMap(*custom_op_meta_); + auto& inplace_reverse_map = + OpMetaInfoHelper::GetInplaceReverseMap(*custom_op_meta_); + if (inplace_map.empty()) { // general case, assure single input and output + VLOG(3) << "Custom Operator: Default InferDtype - share ddim."; if (input_dtypes_.size() == 1) { output_dtypes = input_dtypes_; } else if (vec_input_dtypes_.size() == 1) { @@ -488,12 +573,45 @@ void CustomKernelInstruction::Run() { "We only allow a custom operator that contains only one input " "and only one output without setting the InferDtypeFn. ")); } + } else { // inplace case + for (auto const& pair : inplace_reverse_map) { + if (paddle::framework::detail::IsDuplicableVar(pair.first)) { + int input_index = vec_input_name2id_map_[pair.second]; + auto input_dtype = vec_input_dtypes_[input_index]; + output_dtypes.insert( + output_dtypes.end(), input_dtype.begin(), input_dtype.end()); + } else { + int input_index = input_name2id_map_[pair.second]; + auto input_dtype = input_dtypes_[input_index]; + output_dtypes.push_back(input_dtype); + } + } + } + return output_dtypes; +} + +void CustomKernelInstruction::Run() { + VLOG(3) << "Custom Operator: InferShape - calc output ddim."; + BuildShapeDtype(); + std::vector> output_shapes; + std::vector output_dtypes; + if (infershape_func_) { + output_shapes = + infershape_func_(input_shapes_, vec_input_shapes_, custom_attrs_); + } else { + output_shapes = RunDefaultInferShape(); + } + + if (inferdtype_func_) { + output_dtypes = + inferdtype_func_(input_dtypes_, vec_input_dtypes_, custom_attrs_); + } else { + output_dtypes = RunDefaultInferDtype(); } UpdateOutputMeta(output_shapes, output_dtypes); VLOG(6) << "Run custom op " << custom_op_name_ << " kernel."; kernel_func_(&custom_kernel_ctx_); - custom_kernel_ctx_.AssignInplaceOutputs(); } } // namespace framework } // namespace paddle diff --git a/paddle/fluid/framework/new_executor/instruction/custom_kernel_instruction.h b/paddle/fluid/framework/new_executor/instruction/custom_kernel_instruction.h index 6c6a7d90ae8f0f..5c24236509c9c6 100644 --- a/paddle/fluid/framework/new_executor/instruction/custom_kernel_instruction.h +++ b/paddle/fluid/framework/new_executor/instruction/custom_kernel_instruction.h @@ -45,15 +45,26 @@ class CustomKernelInstruction : public InstructionBase { void BuildCustomContext( const paddle::dialect::OpYamlInfoParser& op_yaml_info); + void BuildShapeDtype(); + void UpdateOutputMeta(const std::vector>& output_shapes, const std::vector& output_dtypes); + std::vector> RunDefaultInferShape(); + std::vector RunDefaultInferDtype(); + void CheckDefaultInferShapeDtype( + const paddle::dialect::OpYamlInfoParser& op_yaml_info); + paddle::CustomOpKernelContext custom_kernel_ctx_; paddle::InferShapeFunc infershape_func_ = nullptr; paddle::InferDtypeFunc inferdtype_func_ = nullptr; paddle::KernelFunc kernel_func_ = nullptr; + // key is input name, value is a index in input_shapes_ or vec_input_shapes_ + std::unordered_map input_name2id_map_; + std::unordered_map vec_input_name2id_map_; + // use for runing infershape std::vector> input_shapes_; std::vector>> vec_input_shapes_; @@ -63,6 +74,10 @@ class CustomKernelInstruction : public InstructionBase { std::vector input_dtypes_; std::vector> vec_input_dtypes_; + // use for calculate input shapes and dtypes in runtime + std::vector input_ptrs_; + std::vector> vec_input_ptrs_; + // use for update output std::vector cache_out_ptrs_; diff --git a/paddle/fluid/inference/api/demo_ci/custom_relu_op.cc b/paddle/fluid/inference/api/demo_ci/custom_relu_op.cc index e55b943a5568f8..43be102d97865a 100755 --- a/paddle/fluid/inference/api/demo_ci/custom_relu_op.cc +++ b/paddle/fluid/inference/api/demo_ci/custom_relu_op.cc @@ -38,12 +38,11 @@ void relu_cpu_backward_kernel(const data_t* grad_out_data, } std::vector relu_cpu_forward(const paddle::Tensor& x) { - auto out = paddle::Tensor(paddle::PlaceType::kCPU, x.shape()); - + auto out = paddle::empty_like(x); PD_DISPATCH_FLOATING_TYPES( x.type(), "relu_cpu_forward", ([&] { relu_cpu_forward_kernel( - x.data(), out.mutable_data(x.place()), x.size()); + x.data(), out.data(x.place()), x.size()); })); return {out}; @@ -52,13 +51,13 @@ std::vector relu_cpu_forward(const paddle::Tensor& x) { std::vector relu_cpu_backward(const paddle::Tensor& x, const paddle::Tensor& out, const paddle::Tensor& grad_out) { - auto grad_x = paddle::Tensor(paddle::PlaceType::kCPU, x.shape()); + auto grad_x = paddle::empty_like(x); PD_DISPATCH_FLOATING_TYPES(out.type(), "relu_cpu_backward", ([&] { relu_cpu_backward_kernel( grad_out.data(), out.data(), - grad_x.mutable_data(x.place()), + grad_x.data(x.place()), out.size()); })); diff --git a/paddle/fluid/inference/api/demo_ci/custom_relu_op.cu b/paddle/fluid/inference/api/demo_ci/custom_relu_op.cu index a4b7fcf06bce6c..fd24b1dd150ef1 100644 --- a/paddle/fluid/inference/api/demo_ci/custom_relu_op.cu +++ b/paddle/fluid/inference/api/demo_ci/custom_relu_op.cu @@ -36,7 +36,7 @@ __global__ void relu_cuda_backward_kernel(const data_t* dy, } std::vector relu_cuda_forward(const paddle::Tensor& x) { - auto out = paddle::Tensor(paddle::PlaceType::kGPU, x.shape()); + auto out = paddle::empty_like(x); int numel = x.size(); int block = 512; @@ -44,7 +44,7 @@ std::vector relu_cuda_forward(const paddle::Tensor& x) { PD_DISPATCH_FLOATING_TYPES( x.type(), "relu_cuda_forward_kernel", ([&] { relu_cuda_forward_kernel<<>>( - x.data(), out.mutable_data(x.place()), numel); + x.data(), out.data(x.place()), numel); })); return {out}; @@ -53,7 +53,7 @@ std::vector relu_cuda_forward(const paddle::Tensor& x) { std::vector relu_cuda_backward(const paddle::Tensor& x, const paddle::Tensor& out, const paddle::Tensor& grad_out) { - auto grad_x = paddle::Tensor(paddle::PlaceType::kGPU, x.shape()); + auto grad_x = paddle::empty_like(x); int numel = out.size(); int block = 512; @@ -63,7 +63,7 @@ std::vector relu_cuda_backward(const paddle::Tensor& x, <<>>( grad_out.data(), out.data(), - grad_x.mutable_data(x.place()), + grad_x.data(x.place()), numel); })); diff --git a/paddle/fluid/pir/dialect/operator/ir/op_dialect.cc b/paddle/fluid/pir/dialect/operator/ir/op_dialect.cc index 80f6e598f967c2..ab8a6123b9d94d 100644 --- a/paddle/fluid/pir/dialect/operator/ir/op_dialect.cc +++ b/paddle/fluid/pir/dialect/operator/ir/op_dialect.cc @@ -19,6 +19,7 @@ #include "paddle/fluid/pir/dialect/operator/ir/op_type.h" #include "paddle/fluid/pir/dialect/operator/ir/pd_op.h" #include "paddle/fluid/pir/dialect/operator/ir/type_storage.h" +#include "paddle/fluid/pir/dialect/operator/trait/inplace.h" #include "paddle/fluid/pir/dialect/operator/transforms/param_to_variable.h" #include "paddle/pir/core/builtin_type_interfaces.h" #include "paddle/pir/core/interface_value.h" @@ -356,9 +357,25 @@ struct CustomOpInfoInterfaceModel : public OpYamlInfoInterface::Concept { output_name, "paddle::dialect::DenseTensorType", is_optional, false}); } + auto& inplace_maps = OpMetaInfoHelper::GetInplaceReverseMap(op_meta); + + if (!inplace_maps.empty()) { + VLOG(3) << "Register Custom Operator: op inplace_map: " + << string::join_strings(inplace_maps, ',', [](auto& pair) { + return pair.first + ": " + pair.second; + }); + } + + std::vector> vec_inplace; + for (auto inplace_map : inplace_maps) { + vec_inplace.push_back(inplace_map); + } + // we only need kernel params name in run_time_info paddle::dialect::OpRunTimeInfo run_time_info = - paddle::dialect::OpRunTimeInfo("", {}, "", param_names, {}, {}, {}, {}); + paddle::dialect::OpRunTimeInfo( + "", {}, "", param_names, {}, {}, vec_inplace, {}); + return std::make_tuple( inputs_info, attributes_info, outputs_info, run_time_info, ""); } @@ -387,6 +404,13 @@ void CustomOpDialect::RegisterCustomOp(const paddle::OpMetaInfo& op_meta) { pir::TypeId id = IdManager::Instance().CreateId(); std::string op_name = paddle::framework::kCustomDialectPrefix + OpMetaInfoHelper::GetOpName(op_meta); + std::vector traits; + + auto& inplace_map = OpMetaInfoHelper::GetInplaceMap(op_meta); + if (!inplace_map.empty()) { + op_name += "_"; + traits.push_back(pir::TypeId::get()); + } op_names_.push_back(op_name); auto& op_attrs = OpMetaInfoHelper::GetAttrs(op_meta); @@ -400,7 +424,6 @@ void CustomOpDialect::RegisterCustomOp(const paddle::OpMetaInfo& op_meta) { AttributeManager::Instance().ToCharPointers(attr_names); uint32_t attr_num = attr_names.size(); - std::vector traits; std::set interface_values; pir::InterfaceValue op_info_interface = pir::InterfaceValue::Get -void relu_forward_kernel(data_t* x_data, int64_t numel) { +void relu_cpu_forward_kernel(data_t* x_data, int64_t numel) { for (size_t i = 0; i < numel; ++i) { x_data[i] = x_data[i] > 0 ? x_data[i] : 0; } @@ -200,15 +200,24 @@ PD_BUILD_GRAD_OP(custom_multi_inplace) {paddle::Grad("OutAB"), paddle::Grad("A")}}) .SetKernelFn(PD_KERNEL(MultiInplaceBackward)); -void ReluForwardInplace(paddle::Tensor& x) { // NOLINT +void relu_cpu_forward(paddle::Tensor& x) { // NOLINT CHECK_INPUT(x); - - PD_DISPATCH_FLOATING_TYPES(x.type(), "ReluForward", ([&] { - relu_forward_kernel(x.data(), - x.size()); + PD_DISPATCH_FLOATING_TYPES(x.type(), "relu_cpu_forward", ([&] { + relu_cpu_forward_kernel(x.data(), + x.size()); })); } +void relu_cuda_forward(paddle::Tensor& x); // NOLINT + +void ReluForwardInplace(paddle::Tensor& x) { // NOLINT + if (x.is_cpu()) { + relu_cpu_forward(x); + } else { + relu_cuda_forward(x); + } +} + void ReluBackwardInplace(const paddle::Tensor& x, const paddle::Tensor& out, paddle::Tensor& grad_out) { // NOLINT diff --git a/test/custom_op/custom_inplace.cu b/test/custom_op/custom_inplace.cu new file mode 100644 index 00000000000000..53379788af5833 --- /dev/null +++ b/test/custom_op/custom_inplace.cu @@ -0,0 +1,44 @@ +// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WIdata_tHOUdata_t WARRANdata_tIES OR CONDIdata_tIONS OF ANY KIND, either +// express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include + +#include "paddle/extension.h" + +#define CHECK_GPU_INPUT(x) PD_CHECK(x.is_gpu(), #x " must be a GPU Tensor.") + +template +__global__ void relu_cuda_forward_kernel(data_t* x, int64_t num) { + int64_t gid = blockIdx.x * blockDim.x + threadIdx.x; + for (int64_t i = gid; i < num; i += blockDim.x * gridDim.x) { + x[i] = x[i] > static_cast(0.) ? x[i] : static_cast(0.); + } +} + +void relu_cuda_forward(paddle::Tensor& x) { // NOLINT + CHECK_GPU_INPUT(x); + + PD_CHECK(x.place() == paddle::DefaultGPUPlace()); + + int64_t numel = x.numel(); + int64_t block = 512; + int64_t grid = (numel + block - 1) / block; + PD_DISPATCH_FLOATING_AND_HALF_TYPES( + x.type(), "relu_cuda_forward_kernel", ([&] { + relu_cuda_forward_kernel + <<>>(x.data(), numel); + })); +} diff --git a/test/custom_op/test_inference_inplace.py b/test/custom_op/test_inference_inplace.py new file mode 100644 index 00000000000000..76e3e54f8f1db0 --- /dev/null +++ b/test/custom_op/test_inference_inplace.py @@ -0,0 +1,138 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +import tempfile +import unittest + +import numpy as np +from utils import ( + extra_cc_args, + extra_nvcc_args, + paddle_includes, +) + +import paddle +from paddle.inference import Config, create_predictor +from paddle.utils.cpp_extension import get_build_directory, load +from paddle.utils.cpp_extension.extension_utils import run_cmd + +# Because Windows don't use docker, the shared lib already exists in the +# cache dir, it will not be compiled again unless the shared lib is removed. +file = f'{get_build_directory()}\\custom_inplace\\custom_inplace.pyd' +if os.name == 'nt' and os.path.isfile(file): + cmd = f'del {file}' + run_cmd(cmd, True) + +# Compile and load custom op Just-In-Time. +custom_inplace = load( + name='custom_inplace', + sources=['custom_inplace.cc', 'custom_inplace.cu'], + extra_include_paths=paddle_includes, # add for Coverage CI + extra_cxx_cflags=extra_cc_args, # test for cflags + extra_cuda_cflags=extra_nvcc_args, # test for cflags + verbose=True, +) + + +class TestInplaceNet(paddle.nn.Layer): + def __init__(self): + super().__init__() + self.fc = paddle.nn.Linear(4, 4) + + def forward(self, x): + fc_out = self.fc(x) + out = custom_inplace.custom_relu_inplace(fc_out) + mean_out = paddle.mean(out) + return mean_out + + +@unittest.skipIf( + not paddle.is_compiled_with_cuda(), 'should compile with cuda.' +) +class TestPredictorRunWithTensor(unittest.TestCase): + def setUp(self): + self.temp_dir = tempfile.TemporaryDirectory() + net = TestInplaceNet() + model = paddle.jit.to_static( + net, + input_spec=[ + paddle.static.InputSpec( + shape=[None, 4], dtype='float32', name='x' + ), + ], + ) + paddle.jit.save( + model, + os.path.join( + self.temp_dir.name, 'test_predictor_run_model/inference' + ), + ) + + def tearDown(self): + self.temp_dir.cleanup() + + def enable_pir(self, flag: bool): + paddle.set_flags({'FLAGS_enable_pir_in_executor': flag}) + + def init_predictor(self): + config = Config( + os.path.join( + self.temp_dir.name, + 'test_predictor_run_model/inference.pdmodel', + ), + os.path.join( + self.temp_dir.name, + 'test_predictor_run_model/inference.pdiparams', + ), + ) + config.enable_use_gpu(256, 0) + config.switch_ir_optim(False) + config.enable_new_executor() + predictor = create_predictor(config) + return predictor + + def get_inputs(self): + x = np.array([[1, 2, 3, 4], [2, 3, 4, 5]]).astype(np.float32) + + x_tensor = paddle.to_tensor(x) + + return [x_tensor] + + def get_outputs(self, predictor): + [x_tensor] = self.get_inputs() + + input_names = predictor.get_input_names() + x_tensor.name = input_names[0] + + # disorder + inputs = [x_tensor] + outputs = predictor.run(inputs) + + return outputs[0] + + def test_output(self): + self.enable_pir(True) + pir_predictor = self.init_predictor() + pir_output = self.get_outputs(pir_predictor) + self.enable_pir(False) + predictor = self.init_predictor() + output = self.get_outputs(predictor) + np.testing.assert_allclose( + output.numpy().flatten(), pir_output.numpy().flatten() + ) + + +if __name__ == "__main__": + unittest.main() From 7ddc18066c7e118f086f25313ff78bde8c40bdbc Mon Sep 17 00:00:00 2001 From: YuanRisheng Date: Wed, 3 Jan 2024 07:20:15 +0000 Subject: [PATCH 02/12] fix inference ut --- paddle/fluid/inference/api/demo_ci/custom_relu_op.cc | 4 ++-- paddle/fluid/inference/api/demo_ci/custom_relu_op.cu | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/paddle/fluid/inference/api/demo_ci/custom_relu_op.cc b/paddle/fluid/inference/api/demo_ci/custom_relu_op.cc index 43be102d97865a..603a9bc4cefd6a 100755 --- a/paddle/fluid/inference/api/demo_ci/custom_relu_op.cc +++ b/paddle/fluid/inference/api/demo_ci/custom_relu_op.cc @@ -42,7 +42,7 @@ std::vector relu_cpu_forward(const paddle::Tensor& x) { PD_DISPATCH_FLOATING_TYPES( x.type(), "relu_cpu_forward", ([&] { relu_cpu_forward_kernel( - x.data(), out.data(x.place()), x.size()); + x.data(), out.data(), x.size()); })); return {out}; @@ -57,7 +57,7 @@ std::vector relu_cpu_backward(const paddle::Tensor& x, relu_cpu_backward_kernel( grad_out.data(), out.data(), - grad_x.data(x.place()), + grad_x.data(), out.size()); })); diff --git a/paddle/fluid/inference/api/demo_ci/custom_relu_op.cu b/paddle/fluid/inference/api/demo_ci/custom_relu_op.cu index fd24b1dd150ef1..ddd5f103ac1283 100644 --- a/paddle/fluid/inference/api/demo_ci/custom_relu_op.cu +++ b/paddle/fluid/inference/api/demo_ci/custom_relu_op.cu @@ -44,7 +44,7 @@ std::vector relu_cuda_forward(const paddle::Tensor& x) { PD_DISPATCH_FLOATING_TYPES( x.type(), "relu_cuda_forward_kernel", ([&] { relu_cuda_forward_kernel<<>>( - x.data(), out.data(x.place()), numel); + x.data(), out.data(), numel); })); return {out}; @@ -63,7 +63,7 @@ std::vector relu_cuda_backward(const paddle::Tensor& x, <<>>( grad_out.data(), out.data(), - grad_x.data(x.place()), + grad_x.data(), numel); })); From 3bda90108cec3bf3290fb12c3eb06726bcdaa0a8 Mon Sep 17 00:00:00 2001 From: YuanRisheng Date: Thu, 4 Jan 2024 03:36:02 +0000 Subject: [PATCH 03/12] fix win bugs --- test/custom_op/CMakeLists.txt | 5 ++++- test/custom_op/test_inference_inplace.py | 4 ++-- 2 files changed, 6 insertions(+), 3 deletions(-) diff --git a/test/custom_op/CMakeLists.txt b/test/custom_op/CMakeLists.txt index c49d25ab815255..458208d0a8abb1 100644 --- a/test/custom_op/CMakeLists.txt +++ b/test/custom_op/CMakeLists.txt @@ -5,13 +5,16 @@ if(WITH_TESTING) py_test(test_custom_relu_op_jit SRCS test_custom_relu_op_jit.py) py_test(test_custom_relu_model SRCS test_custom_relu_model.py) py_test(test_context_pool SRCS test_context_pool.py) - py_test(test_inference_inplace SRCS test_inference_inplace.py) # Compiling shared library will cost some time, but running process is very fast. set_tests_properties(test_custom_relu_op_setup PROPERTIES TIMEOUT 250) set_tests_properties(test_custom_relu_op_jit PROPERTIES TIMEOUT 180) set_tests_properties(test_custom_relu_model PROPERTIES TIMEOUT 180) set_tests_properties(test_context_pool PROPERTIES TIMEOUT 180) + endif() + + if(WITH_GPU) + py_test(test_inference_inplace SRCS test_inference_inplace.py) set_tests_properties(test_inference_inplace PROPERTIES TIMEOUT 180) endif() diff --git a/test/custom_op/test_inference_inplace.py b/test/custom_op/test_inference_inplace.py index 76e3e54f8f1db0..161761a57e5412 100644 --- a/test/custom_op/test_inference_inplace.py +++ b/test/custom_op/test_inference_inplace.py @@ -30,14 +30,14 @@ # Because Windows don't use docker, the shared lib already exists in the # cache dir, it will not be compiled again unless the shared lib is removed. -file = f'{get_build_directory()}\\custom_inplace\\custom_inplace.pyd' +file = f'{get_build_directory()}\\infer_custom\\infer_custom.pyd' if os.name == 'nt' and os.path.isfile(file): cmd = f'del {file}' run_cmd(cmd, True) # Compile and load custom op Just-In-Time. custom_inplace = load( - name='custom_inplace', + name='infer_custom', sources=['custom_inplace.cc', 'custom_inplace.cu'], extra_include_paths=paddle_includes, # add for Coverage CI extra_cxx_cflags=extra_cc_args, # test for cflags From b888da380e11ae2b73bda473447e38db957d3dae Mon Sep 17 00:00:00 2001 From: YuanRisheng Date: Thu, 4 Jan 2024 06:11:43 +0000 Subject: [PATCH 04/12] fix win bug --- test/custom_op/custom_inplace.cc | 21 ++++++--------------- test/custom_op/custom_inplace.cu | 8 +++++++- test/custom_op/test_inference_inplace.py | 2 +- 3 files changed, 14 insertions(+), 17 deletions(-) diff --git a/test/custom_op/custom_inplace.cc b/test/custom_op/custom_inplace.cc index bc8c467c97952d..f7db7922bf3f72 100644 --- a/test/custom_op/custom_inplace.cc +++ b/test/custom_op/custom_inplace.cc @@ -37,7 +37,7 @@ void assign_data_pointer(const data_t* x_data, } template -void relu_cpu_forward_kernel(data_t* x_data, int64_t numel) { +void relu_forward_kernel(data_t* x_data, int64_t numel) { for (size_t i = 0; i < numel; ++i) { x_data[i] = x_data[i] > 0 ? x_data[i] : 0; } @@ -200,22 +200,13 @@ PD_BUILD_GRAD_OP(custom_multi_inplace) {paddle::Grad("OutAB"), paddle::Grad("A")}}) .SetKernelFn(PD_KERNEL(MultiInplaceBackward)); -void relu_cpu_forward(paddle::Tensor& x) { // NOLINT +void ReluForwardInplace(paddle::Tensor& x) { // NOLINT CHECK_INPUT(x); - PD_DISPATCH_FLOATING_TYPES(x.type(), "relu_cpu_forward", ([&] { - relu_cpu_forward_kernel(x.data(), - x.size()); - })); -} - -void relu_cuda_forward(paddle::Tensor& x); // NOLINT -void ReluForwardInplace(paddle::Tensor& x) { // NOLINT - if (x.is_cpu()) { - relu_cpu_forward(x); - } else { - relu_cuda_forward(x); - } + PD_DISPATCH_FLOATING_TYPES(x.type(), "ReluForward", ([&] { + relu_forward_kernel(x.data(), + x.size()); + })); } void ReluBackwardInplace(const paddle::Tensor& x, diff --git a/test/custom_op/custom_inplace.cu b/test/custom_op/custom_inplace.cu index 53379788af5833..eb77a06b8ab44f 100644 --- a/test/custom_op/custom_inplace.cu +++ b/test/custom_op/custom_inplace.cu @@ -28,7 +28,7 @@ __global__ void relu_cuda_forward_kernel(data_t* x, int64_t num) { } } -void relu_cuda_forward(paddle::Tensor& x) { // NOLINT +void ReluForwardInplace(paddle::Tensor& x) { // NOLINT CHECK_GPU_INPUT(x); PD_CHECK(x.place() == paddle::DefaultGPUPlace()); @@ -42,3 +42,9 @@ void relu_cuda_forward(paddle::Tensor& x) { // NOLINT <<>>(x.data(), numel); })); } + +PD_BUILD_OP(custom_relu_inplace) + .Inputs({"X"}) + .Outputs({"Out"}) + .SetInplaceMap({{"X", "Out"}}) + .SetKernelFn(PD_KERNEL(ReluForwardInplace)); diff --git a/test/custom_op/test_inference_inplace.py b/test/custom_op/test_inference_inplace.py index 161761a57e5412..ac66f67142927c 100644 --- a/test/custom_op/test_inference_inplace.py +++ b/test/custom_op/test_inference_inplace.py @@ -38,7 +38,7 @@ # Compile and load custom op Just-In-Time. custom_inplace = load( name='infer_custom', - sources=['custom_inplace.cc', 'custom_inplace.cu'], + sources=['custom_inplace.cu'], extra_include_paths=paddle_includes, # add for Coverage CI extra_cxx_cflags=extra_cc_args, # test for cflags extra_cuda_cflags=extra_nvcc_args, # test for cflags From 76c61a9f814efa4bfd0edbf08c3725d67ae7ff8e Mon Sep 17 00:00:00 2001 From: YuanRisheng Date: Thu, 4 Jan 2024 11:00:03 +0000 Subject: [PATCH 05/12] fix --- .../new_executor/instruction/custom_kernel_instruction.cc | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/paddle/fluid/framework/new_executor/instruction/custom_kernel_instruction.cc b/paddle/fluid/framework/new_executor/instruction/custom_kernel_instruction.cc index 378a5d2d90ee3a..ee5d1989eff4e7 100644 --- a/paddle/fluid/framework/new_executor/instruction/custom_kernel_instruction.cc +++ b/paddle/fluid/framework/new_executor/instruction/custom_kernel_instruction.cc @@ -61,7 +61,7 @@ void CustomKernelInstruction::BuildCustomContext( custom_vec_in.emplace_back(paddle::Tensor()); custom_kernel_ctx_.EmplaceBackInputs(std::move(custom_vec_in)); } else { - input_name2id_map_[t] = vec_input_index; + input_name2id_map_[t] = input_index; input_index++; input_ptrs_.emplace_back(nullptr); custom_kernel_ctx_.EmplaceBackInput(std::move(paddle::Tensor())); @@ -84,7 +84,7 @@ void CustomKernelInstruction::BuildCustomContext( dense_tensor_in, [](phi::DenseTensor* ptr) { VLOG(6) << ptr << " ptr will not be deleted by shared_ptr"; }); - input_name2id_map_[t] = vec_input_index; + input_name2id_map_[t] = input_index; input_index++; input_ptrs_.push_back(dense_tensor_in); paddle::Tensor custom_in; From 14055c4b33e93f2ff58068863f8834f757e41bf8 Mon Sep 17 00:00:00 2001 From: YuanRisheng Date: Thu, 4 Jan 2024 12:44:51 +0000 Subject: [PATCH 06/12] polish code --- .../instruction/custom_kernel_instruction.cc | 20 ++++++++----------- 1 file changed, 8 insertions(+), 12 deletions(-) diff --git a/paddle/fluid/framework/new_executor/instruction/custom_kernel_instruction.cc b/paddle/fluid/framework/new_executor/instruction/custom_kernel_instruction.cc index ee5d1989eff4e7..aaffb41b0f609e 100644 --- a/paddle/fluid/framework/new_executor/instruction/custom_kernel_instruction.cc +++ b/paddle/fluid/framework/new_executor/instruction/custom_kernel_instruction.cc @@ -527,8 +527,6 @@ std::vector> CustomKernelInstruction::RunDefaultInferShape() { std::vector> output_shapes; auto& inplace_map = OpMetaInfoHelper::GetInplaceMap(*custom_op_meta_); - auto& inplace_reverse_map = - OpMetaInfoHelper::GetInplaceReverseMap(*custom_op_meta_); if (inplace_map.empty()) { // general case, assure single input and output VLOG(3) << "Custom Operator: Default InferShape - share ddim."; if (input_shapes_.size() == 1) { @@ -541,14 +539,14 @@ CustomKernelInstruction::RunDefaultInferShape() { "and only one output without setting the InferShapeFn. ")); } } else { // inplace case - for (auto const& pair : inplace_reverse_map) { - if (paddle::framework::detail::IsDuplicableVar(pair.first)) { - int input_index = vec_input_name2id_map_[pair.second]; + for (auto const& pair : inplace_map) { + if (paddle::framework::detail::IsDuplicableVar(pair.second)) { + int input_index = vec_input_name2id_map_[pair.first]; auto input_shape = vec_input_shapes_[input_index]; output_shapes.insert( output_shapes.end(), input_shape.begin(), input_shape.end()); } else { - int input_index = input_name2id_map_[pair.second]; + int input_index = input_name2id_map_[pair.first]; auto input_shape = input_shapes_[input_index]; output_shapes.push_back(input_shape); } @@ -560,8 +558,6 @@ CustomKernelInstruction::RunDefaultInferShape() { std::vector CustomKernelInstruction::RunDefaultInferDtype() { std::vector output_dtypes; auto& inplace_map = OpMetaInfoHelper::GetInplaceMap(*custom_op_meta_); - auto& inplace_reverse_map = - OpMetaInfoHelper::GetInplaceReverseMap(*custom_op_meta_); if (inplace_map.empty()) { // general case, assure single input and output VLOG(3) << "Custom Operator: Default InferDtype - share ddim."; if (input_dtypes_.size() == 1) { @@ -574,14 +570,14 @@ std::vector CustomKernelInstruction::RunDefaultInferDtype() { "and only one output without setting the InferDtypeFn. ")); } } else { // inplace case - for (auto const& pair : inplace_reverse_map) { - if (paddle::framework::detail::IsDuplicableVar(pair.first)) { - int input_index = vec_input_name2id_map_[pair.second]; + for (auto const& pair : inplace_map) { + if (paddle::framework::detail::IsDuplicableVar(pair.second)) { + int input_index = vec_input_name2id_map_[pair.first]; auto input_dtype = vec_input_dtypes_[input_index]; output_dtypes.insert( output_dtypes.end(), input_dtype.begin(), input_dtype.end()); } else { - int input_index = input_name2id_map_[pair.second]; + int input_index = input_name2id_map_[pair.first]; auto input_dtype = input_dtypes_[input_index]; output_dtypes.push_back(input_dtype); } From b635a8dbd14989a728c42c6cc26cebda772049a5 Mon Sep 17 00:00:00 2001 From: YuanRisheng Date: Thu, 4 Jan 2024 12:47:34 +0000 Subject: [PATCH 07/12] polish code --- test/custom_op/custom_inplace.cu | 2 +- test/custom_op/test_inference_inplace.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/test/custom_op/custom_inplace.cu b/test/custom_op/custom_inplace.cu index eb77a06b8ab44f..9891045f3bd479 100644 --- a/test/custom_op/custom_inplace.cu +++ b/test/custom_op/custom_inplace.cu @@ -1,4 +1,4 @@ -// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. diff --git a/test/custom_op/test_inference_inplace.py b/test/custom_op/test_inference_inplace.py index ac66f67142927c..303b2b21d15dc8 100644 --- a/test/custom_op/test_inference_inplace.py +++ b/test/custom_op/test_inference_inplace.py @@ -1,4 +1,4 @@ -# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. From 84ae78b9c056e993e3373f8a183d5499053a1a78 Mon Sep 17 00:00:00 2001 From: YuanRisheng Date: Mon, 8 Jan 2024 13:03:11 +0000 Subject: [PATCH 08/12] print log --- paddle/fluid/ir_adaptor/translator/op_translator.cc | 2 ++ 1 file changed, 2 insertions(+) diff --git a/paddle/fluid/ir_adaptor/translator/op_translator.cc b/paddle/fluid/ir_adaptor/translator/op_translator.cc index 0227091e0aa531..738256fe6c40e0 100644 --- a/paddle/fluid/ir_adaptor/translator/op_translator.cc +++ b/paddle/fluid/ir_adaptor/translator/op_translator.cc @@ -234,6 +234,8 @@ inline bool HasOpInfo(pir::IrContext* ctx, const OpDesc& op_desc, std::string prefix) { std::string target_op_name = prefix + OpNameCompatibleMapping(op_desc.Type()); + LOG(INFO) << "target_op_name in has opinfo:" << target_op_name; + LOG(INFO) << "is inplace:" << IsInplace(op_desc); if (IsInplace(op_desc) && *target_op_name.rbegin() != '_') { target_op_name += "_"; } From e65bd22853fbafb72b7013ae3bc2b90674795024 Mon Sep 17 00:00:00 2001 From: YuanRisheng Date: Tue, 9 Jan 2024 03:27:08 +0000 Subject: [PATCH 09/12] print log --- paddle/fluid/ir_adaptor/translator/op_translator.cc | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/paddle/fluid/ir_adaptor/translator/op_translator.cc b/paddle/fluid/ir_adaptor/translator/op_translator.cc index 738256fe6c40e0..f3762b1ff538b3 100644 --- a/paddle/fluid/ir_adaptor/translator/op_translator.cc +++ b/paddle/fluid/ir_adaptor/translator/op_translator.cc @@ -97,6 +97,7 @@ static const std::unordered_set SpecialInplaceOps = { inline bool IsInplace(const OpDesc& op_desc) { if (SpecialNonInplaceOps.count(op_desc.Type())) { + LOG(INFO) << "come into special non inplace ops"; return false; } if (SpecialInplaceOps.count(op_desc.Type())) { @@ -105,6 +106,8 @@ inline bool IsInplace(const OpDesc& op_desc) { bool inplace = false; auto input_names = op_desc.InputArgumentNames(); auto output_names = op_desc.OutputArgumentNames(); + LOG(INFO) << "input_names empty:" << input_names.empty(); + LOG(INFO) << "output_names empty:" << output_names.empty(); if (input_names.empty() || output_names.empty()) { return inplace; } @@ -117,7 +120,7 @@ inline bool IsInplace(const OpDesc& op_desc) { output_names.begin(), output_names.end(), std::back_inserter(name_intersection)); - + LOG(INFO) << "name_intersection empty:" << name_intersection.empty(); if (!name_intersection.empty()) { std::string redundant_variables = std::accumulate( std::next(name_intersection.begin()), From 0c248278f1f9350a5a543ef0be275dd23d47e639 Mon Sep 17 00:00:00 2001 From: YuanRisheng Date: Tue, 9 Jan 2024 05:48:11 +0000 Subject: [PATCH 10/12] debug --- paddle/fluid/ir_adaptor/translator/op_translator.cc | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/paddle/fluid/ir_adaptor/translator/op_translator.cc b/paddle/fluid/ir_adaptor/translator/op_translator.cc index f3762b1ff538b3..d76863da7d3882 100644 --- a/paddle/fluid/ir_adaptor/translator/op_translator.cc +++ b/paddle/fluid/ir_adaptor/translator/op_translator.cc @@ -107,7 +107,15 @@ inline bool IsInplace(const OpDesc& op_desc) { auto input_names = op_desc.InputArgumentNames(); auto output_names = op_desc.OutputArgumentNames(); LOG(INFO) << "input_names empty:" << input_names.empty(); + for (auto input_name : input_names) { + LOG(INFO) << "input name:" << input_name; + } + LOG(INFO) << "output_names empty:" << output_names.empty(); + for (auto output_name : output_names) { + LOG(INFO) << "output name:" << output_name; + } + if (input_names.empty() || output_names.empty()) { return inplace; } From 68cfe20359ebc13f3d2d55ea878c71cf5f3a9410 Mon Sep 17 00:00:00 2001 From: YuanRisheng Date: Tue, 9 Jan 2024 09:49:22 +0000 Subject: [PATCH 11/12] fix win bugs --- .../ir_adaptor/translator/op_translator.cc | 13 ----------- .../utils/cpp_extension/extension_utils.py | 23 +++---------------- 2 files changed, 3 insertions(+), 33 deletions(-) diff --git a/paddle/fluid/ir_adaptor/translator/op_translator.cc b/paddle/fluid/ir_adaptor/translator/op_translator.cc index d76863da7d3882..c646d78e7e03e1 100644 --- a/paddle/fluid/ir_adaptor/translator/op_translator.cc +++ b/paddle/fluid/ir_adaptor/translator/op_translator.cc @@ -97,7 +97,6 @@ static const std::unordered_set SpecialInplaceOps = { inline bool IsInplace(const OpDesc& op_desc) { if (SpecialNonInplaceOps.count(op_desc.Type())) { - LOG(INFO) << "come into special non inplace ops"; return false; } if (SpecialInplaceOps.count(op_desc.Type())) { @@ -106,15 +105,6 @@ inline bool IsInplace(const OpDesc& op_desc) { bool inplace = false; auto input_names = op_desc.InputArgumentNames(); auto output_names = op_desc.OutputArgumentNames(); - LOG(INFO) << "input_names empty:" << input_names.empty(); - for (auto input_name : input_names) { - LOG(INFO) << "input name:" << input_name; - } - - LOG(INFO) << "output_names empty:" << output_names.empty(); - for (auto output_name : output_names) { - LOG(INFO) << "output name:" << output_name; - } if (input_names.empty() || output_names.empty()) { return inplace; @@ -128,7 +118,6 @@ inline bool IsInplace(const OpDesc& op_desc) { output_names.begin(), output_names.end(), std::back_inserter(name_intersection)); - LOG(INFO) << "name_intersection empty:" << name_intersection.empty(); if (!name_intersection.empty()) { std::string redundant_variables = std::accumulate( std::next(name_intersection.begin()), @@ -245,8 +234,6 @@ inline bool HasOpInfo(pir::IrContext* ctx, const OpDesc& op_desc, std::string prefix) { std::string target_op_name = prefix + OpNameCompatibleMapping(op_desc.Type()); - LOG(INFO) << "target_op_name in has opinfo:" << target_op_name; - LOG(INFO) << "is inplace:" << IsInplace(op_desc); if (IsInplace(op_desc) && *target_op_name.rbegin() != '_') { target_op_name += "_"; } diff --git a/python/paddle/utils/cpp_extension/extension_utils.py b/python/paddle/utils/cpp_extension/extension_utils.py index e64f5e6a25b3f6..e56f8464e5b34a 100644 --- a/python/paddle/utils/cpp_extension/extension_utils.py +++ b/python/paddle/utils/cpp_extension/extension_utils.py @@ -1074,13 +1074,7 @@ def _gen_output_content( {indent} start_idx += len({lower_in_names}) {indent}else: {indent} res.append(None) -{indent} start_idx += 1""" - if IS_WINDOWS: - static_content += f""" -{indent}if {lower_in_names} is not None: -{indent} outs['{out_name}'] = [helper.create_variable(dtype='float32') for _ in range(len({lower_in_names}))]""" - else: - static_content += f""" +{indent} start_idx += 1 {indent}if {lower_in_names} is not None: {indent} outs['{out_name}'] = {lower_in_names}""" @@ -1090,12 +1084,7 @@ def _gen_output_content( lower_in_names = in_names[in_idx].split("@")[0].lower() dynamic_content += f""" {indent}res.append(outs[start_idx: start_idx + len({lower_in_names})]) -{indent}start_idx += len({lower_in_names})""" - if IS_WINDOWS: - static_content += f""" -{indent}outs['{out_name}'] = [helper.create_variable(dtype='float32') for _ in range(len({lower_in_names}))]""" - else: - static_content += f""" +{indent}start_idx += len({lower_in_names}) {indent}outs['{out_name}'] = {lower_in_names}""" elif ( in_idx != -1 and "@OPTIONAL" in in_names[in_idx] @@ -1106,13 +1095,7 @@ def _gen_output_content( {indent} res.append(outs[start_idx]) {indent}else: {indent} res.append(None) -{indent}start_idx += 1""" - if IS_WINDOWS: - static_content += f""" -{indent}if {lower_in_names} is not None: -{indent} outs['{out_name}'] = helper.create_variable(dtype='float32')""" - else: - static_content += f""" +{indent}start_idx += 1 {indent}if {lower_in_names} is not None: {indent} outs['{out_name}'] = {lower_in_names}""" elif ( From 0935d7e414321e1357234466dd170fffd76d08d0 Mon Sep 17 00:00:00 2001 From: YuanRisheng Date: Tue, 9 Jan 2024 12:35:46 +0000 Subject: [PATCH 12/12] fix windows --- .../utils/cpp_extension/extension_utils.py | 23 ++++++++++++++++--- test/custom_op/CMakeLists.txt | 10 +++++--- 2 files changed, 27 insertions(+), 6 deletions(-) diff --git a/python/paddle/utils/cpp_extension/extension_utils.py b/python/paddle/utils/cpp_extension/extension_utils.py index e56f8464e5b34a..e64f5e6a25b3f6 100644 --- a/python/paddle/utils/cpp_extension/extension_utils.py +++ b/python/paddle/utils/cpp_extension/extension_utils.py @@ -1074,7 +1074,13 @@ def _gen_output_content( {indent} start_idx += len({lower_in_names}) {indent}else: {indent} res.append(None) -{indent} start_idx += 1 +{indent} start_idx += 1""" + if IS_WINDOWS: + static_content += f""" +{indent}if {lower_in_names} is not None: +{indent} outs['{out_name}'] = [helper.create_variable(dtype='float32') for _ in range(len({lower_in_names}))]""" + else: + static_content += f""" {indent}if {lower_in_names} is not None: {indent} outs['{out_name}'] = {lower_in_names}""" @@ -1084,7 +1090,12 @@ def _gen_output_content( lower_in_names = in_names[in_idx].split("@")[0].lower() dynamic_content += f""" {indent}res.append(outs[start_idx: start_idx + len({lower_in_names})]) -{indent}start_idx += len({lower_in_names}) +{indent}start_idx += len({lower_in_names})""" + if IS_WINDOWS: + static_content += f""" +{indent}outs['{out_name}'] = [helper.create_variable(dtype='float32') for _ in range(len({lower_in_names}))]""" + else: + static_content += f""" {indent}outs['{out_name}'] = {lower_in_names}""" elif ( in_idx != -1 and "@OPTIONAL" in in_names[in_idx] @@ -1095,7 +1106,13 @@ def _gen_output_content( {indent} res.append(outs[start_idx]) {indent}else: {indent} res.append(None) -{indent}start_idx += 1 +{indent}start_idx += 1""" + if IS_WINDOWS: + static_content += f""" +{indent}if {lower_in_names} is not None: +{indent} outs['{out_name}'] = helper.create_variable(dtype='float32')""" + else: + static_content += f""" {indent}if {lower_in_names} is not None: {indent} outs['{out_name}'] = {lower_in_names}""" elif ( diff --git a/test/custom_op/CMakeLists.txt b/test/custom_op/CMakeLists.txt index 458208d0a8abb1..abed612162a1e5 100644 --- a/test/custom_op/CMakeLists.txt +++ b/test/custom_op/CMakeLists.txt @@ -13,9 +13,13 @@ if(WITH_TESTING) set_tests_properties(test_context_pool PROPERTIES TIMEOUT 180) endif() - if(WITH_GPU) - py_test(test_inference_inplace SRCS test_inference_inplace.py) - set_tests_properties(test_inference_inplace PROPERTIES TIMEOUT 180) + if(NOT WIN32) + # TODO(YuanRisheng) : Currently, we run this unittest by translating old ir to new ir, and it has bug that can't judge whether op_desc is a inplace op in windows. + # We will fix it when abandoning translation in final state. + if(WITH_GPU) + py_test(test_inference_inplace SRCS test_inference_inplace.py) + set_tests_properties(test_inference_inplace PROPERTIES TIMEOUT 180) + endif() endif() # custom OP support TensorRT inference