Merged
Changes from 6 commits
6 changes: 4 additions & 2 deletions paddle/phi/api/lib/CMakeLists.txt
@@ -370,7 +370,6 @@ cc_library(
SRCS api_custom_impl.cc
DEPS phi_tensor_raw phi kernel_dispatch api_gen_utils backward_infermeta
phi_data_transform)

cc_library(
phi_function_api
SRCS ${api_source_file}
@@ -384,6 +383,7 @@ cc_library(
kernel_dispatch
api_gen_utils
backward_infermeta
sparse_backward_infermeta
phi_data_transform
phi_function_api
api_custom_impl
@@ -395,7 +395,8 @@ cc_library(
cc_library(
sparse_bw_api
SRCS ${sparse_bw_api_source_file}
DEPS phi_tensor_raw phi kernel_dispatch api_gen_utils sparse_api)
DEPS phi_tensor_raw phi kernel_dispatch api_gen_utils sparse_api
sparse_backward_infermeta)
cc_library(
phi_dygraph_api
SRCS ${dygraph_api_source_file}
@@ -418,6 +419,7 @@ cc_library(
api_gen_utils
kernel_dispatch
infermeta
sparse_infermeta
sparse_api
strings_api)
cc_library(
20 changes: 20 additions & 0 deletions paddle/phi/api/lib/api_gen_utils.cc
@@ -88,6 +88,10 @@ std::shared_ptr<phi::StringTensor> TensorToStringTensor(const Tensor& tensor) {
return std::dynamic_pointer_cast<phi::StringTensor>(tensor.impl());
}

std::shared_ptr<phi::SparseCooTensor> TensorToSparseCooTensor(
const Tensor& tensor) {
return std::static_pointer_cast<phi::SparseCooTensor>(tensor.impl());
}
/* ----------------- for infer_meta --------------------- */

phi::MetaTensor MakeMetaTensor(const phi::TensorBase& tensor) {
@@ -130,6 +134,22 @@ phi::MetaTensor MakeMetaTensor(
return phi::MetaTensor();
}

phi::MetaTensor MakeMetaTensor(
const paddle::optional<phi::SparseCooTensor>& tensor) {
if (tensor) {
return {phi::MetaTensor(*tensor)};
}
return phi::MetaTensor();
}

phi::MetaTensor MakeMetaTensor(
const paddle::optional<phi::SparseCsrTensor>& tensor) {
if (tensor) {
return {phi::MetaTensor(*tensor)};
}
return phi::MetaTensor();
}

std::vector<phi::MetaTensor> MakeMetaTensor(
const paddle::optional<std::vector<const phi::DenseTensor*>>& tensors) {
std::vector<phi::MetaTensor> meta_tensors;
8 changes: 8 additions & 0 deletions paddle/phi/api/lib/api_gen_utils.h
@@ -52,6 +52,8 @@ paddle::optional<phi::SelectedRows> TensorToSelectedRows(

std::shared_ptr<phi::StringTensor> TensorToStringTensor(const Tensor& tensor);

std::shared_ptr<phi::SparseCooTensor> TensorToSparseCooTensor(
const Tensor& tensor);
/* ----------------- for infer_meta --------------------- */

phi::MetaTensor MakeMetaTensor(const phi::TensorBase& tensor);
@@ -68,6 +70,12 @@ std::vector<phi::MetaTensor> MakeMetaTensor(
phi::MetaTensor MakeMetaTensor(
const paddle::optional<phi::SelectedRows>& tensor);

phi::MetaTensor MakeMetaTensor(
const paddle::optional<phi::SparseCooTensor>& tensor);

phi::MetaTensor MakeMetaTensor(
const paddle::optional<phi::SparseCsrTensor>& tensor);

std::vector<phi::MetaTensor> MakeMetaTensor(
const paddle::optional<std::vector<const phi::DenseTensor*>>& tensors);

14 changes: 12 additions & 2 deletions paddle/phi/api/lib/tensor_method.cc
@@ -23,7 +23,9 @@ limitations under the License. */
#include "paddle/phi/api/lib/api_gen_utils.h"
#include "paddle/phi/api/lib/kernel_dispatch.h"
#include "paddle/phi/core/tensor_utils.h"
#include "paddle/phi/infermeta/sparse/unary.h"
#include "paddle/phi/infermeta/unary.h"
// clang-format off

namespace paddle {
namespace experimental {
@@ -165,15 +167,23 @@ void Tensor::copy_(const Tensor &src,
static_cast<phi::SelectedRows *>(impl_.get()));
} else if (kernel_type == KernelType::SPARSE_COO_KERNEL) {
SetSparseKernelOutput(this, TensorType::SPARSE_COO);
// TODO(zhangkaihuo) add sparse infer_meta
phi::MetaTensor meta_out(impl_.get());
phi::sparse::UnchangedInferMeta(
MakeMetaTensor(
*(std::static_pointer_cast<phi::SparseCooTensor>(src.impl_))),
Contributor comment on the line above: Can the cast be avoided here? (A sketch of the cast-free form follows this file's diff.)
&meta_out);
phi::Copy(*dev_ctx,
(*(std::static_pointer_cast<phi::SparseCooTensor>(src.impl_))),
target_place,
blocking,
static_cast<phi::SparseCooTensor *>(impl_.get()));
} else if (kernel_type == KernelType::SPARSE_CSR_KERNEL) {
SetSparseKernelOutput(this, TensorType::SPARSE_CSR);
// TODO(zhangkaihuo) add sparse infer_meta
phi::MetaTensor meta_out(impl_.get());
phi::sparse::UnchangedInferMeta(
MakeMetaTensor(
*(std::static_pointer_cast<phi::SparseCsrTensor>(src.impl_))),
&meta_out);
phi::Copy(*dev_ctx,
(*(std::static_pointer_cast<phi::SparseCsrTensor>(src.impl_))),
target_place,
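Regarding the reviewer's question about the cast in Tensor::copy_: a minimal sketch of the cast-free form, assuming Tensor::impl_ holds a std::shared_ptr<phi::TensorBase> and that the MakeMetaTensor(const phi::TensorBase&) overload declared in api_gen_utils.h is the intended match. Illustrative only, not part of this PR:

    // Hypothetical simplification: MakeMetaTensor already accepts any phi::TensorBase,
    // so dereferencing src.impl_ directly may make the static_pointer_cast unnecessary.
    phi::MetaTensor meta_out(impl_.get());
    phi::sparse::UnchangedInferMeta(MakeMetaTensor(*src.impl_), &meta_out);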
39 changes: 31 additions & 8 deletions paddle/phi/api/yaml/generator/api_base.py
@@ -477,24 +477,39 @@ def gene_kernel_select(self) -> str:

return kernel_select_code

def gene_infer_meta(self, kernel_output_names, code_indent) -> str:
def gene_infer_meta(self,
kernel_output_names,
code_indent,
is_sparse=False) -> str:
input_names = self.inputs['names']
attr_names = self.attrs['names']
infer_meta = self.infer_meta

prefix_tensor_name = PREFIX_TENSOR_NAME if is_sparse is False else ""

infer_meta_params = infer_meta['param'] if infer_meta[
'param'] is not None else input_names + attr_names
# generate meta tensors
meta_tensor_code = ""
param_code = ""
for param in infer_meta_params:
if is_sparse is False:
make_input_meta_tensor = f"""
{code_indent} auto {param}_meta_vec = MakeMetaTensor({prefix_tensor_name}{param});
"""
else:
make_input_meta_tensor = f"""
{code_indent} auto {param}_meta_vec = MakeMetaTensor(*{param}.impl());
"""
if param in input_names:
if self.inputs['input_info'][param] == "const Tensor&":
param_code = param_code + "MakeMetaTensor(*" + PREFIX_TENSOR_NAME + param + "), "
if is_sparse is False:
param_code = param_code + "MakeMetaTensor(*" + prefix_tensor_name + param + "), "
else:
param_code = param_code + "MakeMetaTensor(*" + param + ".impl()), "
elif self.inputs['input_info'][
param] == "const std::vector<Tensor>&":
meta_tensor_code = meta_tensor_code + f"""
{code_indent} auto {param}_meta_vec = MakeMetaTensor({PREFIX_TENSOR_NAME}{param});
meta_tensor_code = meta_tensor_code + make_input_meta_tensor + f"""
{code_indent} std::vector<const phi::MetaTensor*> {param}_metas({param}_meta_vec.size());
{code_indent} for (size_t i = 0; i < {param}_meta_vec.size(); ++i) {{
{code_indent} {param}_metas[i] = &{param}_meta_vec[i];
@@ -503,16 +518,19 @@ def gene_infer_meta(self, kernel_output_names, code_indent) -> str:
param_code = param_code + param + "_metas, "
elif self.inputs['input_info'][
param] == "const paddle::optional<std::vector<Tensor>>&":
meta_tensor_code = meta_tensor_code + f"""
{code_indent} auto {param}_meta_vec = MakeMetaTensor({PREFIX_TENSOR_NAME}{param});
meta_tensor_code = meta_tensor_code + make_input_meta_tensor + f"""
{code_indent} paddle::optional<std::vector<const phi::MetaTensor*>> {param}_metas({param}_meta_vec.size());
{code_indent} for (size_t i = 0; i < {param}_meta_vec.size(); ++i) {{
{code_indent} {param}_metas->at(i) = &{param}_meta_vec[i];
{code_indent} }}
"""
param_code = param_code + param + "_metas, "
elif param in self.optional_vars:
param_code = param_code + "MakeMetaTensor(" + PREFIX_TENSOR_NAME + param + "), "
if is_sparse is False:
param_code = param_code + "MakeMetaTensor(" + prefix_tensor_name + param + "), "
else:
param_code = param_code + f"""{param}""" + "? MakeMetaTensor(*(*" + param + ").impl()) : phi::MetaTensor(), "

else:
raise ValueError(
f"{self.api} : Param of infer_meta error : {self.inputs['input_info'][param]} type is not supported."
@@ -546,7 +564,12 @@ def gene_infer_meta(self, kernel_output_names, code_indent) -> str:
param_code = param_code + f"{out_name} ? &{out_name.replace('kernel_', PREFIX_META_TENSOR_NAME)} : nullptr, "

param_code = param_code[:-2]
return f"""{meta_tensor_code}
if is_sparse:
return f"""{meta_tensor_code}
{code_indent} phi::sparse::{infer_meta['func']}({param_code});
"""
else:
return f"""{meta_tensor_code}
{code_indent} phi::{infer_meta['func']}({param_code});
"""

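To illustrate the effect of the new is_sparse branch in gene_infer_meta: for a hypothetical single-input sparse op with input x, infer_meta func UnchangedInferMeta, and one output named kernel_out (these names are assumed for illustration, not taken from this diff), the generated C++ would look roughly like:

    // Sketch of generated code on the sparse path: the MetaTensor is built from the
    // TensorBase impl directly, and the call is routed to the phi::sparse namespace.
    phi::MetaTensor meta_out(kernel_out);
    phi::sparse::UnchangedInferMeta(MakeMetaTensor(*x.impl()), kernel_out ? &meta_out : nullptr);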
5 changes: 5 additions & 0 deletions paddle/phi/api/yaml/generator/intermediate_api_gen.py
@@ -50,6 +50,11 @@ def source_include(header_file_path):
#include "paddle/phi/infermeta/unary.h"
#include "paddle/phi/infermeta/ternary.h"

#include "paddle/phi/infermeta/sparse/unary.h"
#include "paddle/phi/infermeta/sparse/binary.h"
#include "paddle/phi/infermeta/sparse/ternary.h"
#include "paddle/phi/infermeta/sparse/multiary.h"

#include "paddle/fluid/platform/profiler/event_tracing.h"
#include "paddle/fluid/platform/profiler/supplement_tracing.h"
"""
15 changes: 12 additions & 3 deletions paddle/phi/api/yaml/generator/sparse_api_gen.py
@@ -154,6 +154,7 @@ def gen_sparse_kernel_code(self, kernel_name, inplace_flag=False):
auto* dev_ctx = GetDeviceContextByBackend(kernel_result.has_fallback_cpu ? Backend::CPU : kernel_backend);
auto kernel_context = phi::KernelContext(dev_ctx);
{output_create}
{self.gene_infer_meta(kernel_output_names, '', True)}
{kernel_context_code}
phi_kernel(&kernel_context);
{return_code}"""
@@ -178,9 +179,13 @@ def get_condition_code(self, kernel_name):
f"phi::DenseTensor::classof({self.inputs['names'][i]}.impl().get())"
)
else:
condition_list.append(
f"{self.inputs['names'][i]}.layout() == {sparse_type_map[in_type]}"
)
if in_type == 'sparse_coo':
condition_list.append(
f"{self.inputs['names'][i]}.is_sparse_coo_tensor()")
else:
condition_list.append(
f"{self.inputs['names'][i]}.is_sparse_csr_tensor()")

return " && ".join(condition_list)

def gene_dispatch_code(self, kernel_name, inplace_flag=False):
@@ -230,6 +235,10 @@ def source_include(header_file_path):
#include "paddle/phi/api/lib/data_transform.h"
#include "paddle/phi/api/lib/kernel_dispatch.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/infermeta/sparse/unary.h"
#include "paddle/phi/infermeta/sparse/binary.h"
#include "paddle/phi/infermeta/sparse/ternary.h"
#include "paddle/phi/infermeta/sparse/multiary.h"
"""


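For reference, a sketch of the dispatch condition string the updated get_condition_code would now emit for a hypothetical kernel whose inputs are (x: sparse_coo, y: dense); the input names are illustrative only:

    // Generated dispatch condition (sketch): explicit sparse-format checks replace
    // the previous layout comparison for sparse inputs.
    x.is_sparse_coo_tensor() && phi::DenseTensor::classof(y.impl().get())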
4 changes: 4 additions & 0 deletions paddle/phi/api/yaml/generator/sparse_bw_api_gen.py
@@ -112,6 +112,10 @@ def source_include(header_file_path):
#include "paddle/phi/api/lib/api_gen_utils.h"
#include "paddle/phi/api/lib/kernel_dispatch.h"
#include "paddle/phi/core/kernel_registry.h"

#include "paddle/phi/infermeta/sparse/unary.h"
#include "paddle/phi/infermeta/sparse/binary.h"
#include "paddle/phi/infermeta/sparse/backward.h"
"""

