Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 4 additions & 8 deletions paddle/phi/api/lib/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -370,12 +370,6 @@ cc_library(
SRCS api_custom_impl.cc
DEPS phi_tensor_raw phi kernel_dispatch api_gen_utils backward_infermeta
phi_data_transform)
cc_library(
sparse_api_custom_impl
SRCS sparse_api_custom_impl.cc
DEPS phi_tensor_raw phi kernel_dispatch api_gen_utils phi_data_transform
tensor_copy)

cc_library(
phi_function_api
SRCS ${api_source_file}
Expand All @@ -389,19 +383,20 @@ cc_library(
kernel_dispatch
api_gen_utils
backward_infermeta
sparse_backward_infermeta
phi_data_transform
phi_function_api
api_custom_impl
global_utils)
cc_library(
sparse_api
SRCS ${sparse_api_source_file}
DEPS phi_tensor_raw phi kernel_dispatch api_gen_utils sparse_api_custom_impl)
DEPS phi_tensor_raw phi kernel_dispatch api_gen_utils)
cc_library(
sparse_bw_api
SRCS ${sparse_bw_api_source_file}
DEPS phi_tensor_raw phi kernel_dispatch api_gen_utils sparse_api
sparse_api_custom_impl)
sparse_backward_infermeta)
cc_library(
phi_dygraph_api
SRCS ${dygraph_api_source_file}
Expand All @@ -424,6 +419,7 @@ cc_library(
api_gen_utils
kernel_dispatch
infermeta
sparse_infermeta
sparse_api
strings_api)
cc_library(
Expand Down
20 changes: 20 additions & 0 deletions paddle/phi/api/lib/api_gen_utils.cc
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,10 @@ std::shared_ptr<phi::StringTensor> TensorToStringTensor(const Tensor& tensor) {
return std::dynamic_pointer_cast<phi::StringTensor>(tensor.impl());
}

std::shared_ptr<phi::SparseCooTensor> TensorToSparseCooTensor(
const Tensor& tensor) {
return std::static_pointer_cast<phi::SparseCooTensor>(tensor.impl());
}
/* ----------------- for infer_meta --------------------- */

phi::MetaTensor MakeMetaTensor(const phi::TensorBase& tensor) {
Expand Down Expand Up @@ -130,6 +134,22 @@ phi::MetaTensor MakeMetaTensor(
return phi::MetaTensor();
}

phi::MetaTensor MakeMetaTensor(
const paddle::optional<phi::SparseCooTensor>& tensor) {
if (tensor) {
return {phi::MetaTensor(*tensor)};
}
return phi::MetaTensor();
}

phi::MetaTensor MakeMetaTensor(
const paddle::optional<phi::SparseCsrTensor>& tensor) {
if (tensor) {
return {phi::MetaTensor(*tensor)};
}
return phi::MetaTensor();
}

std::vector<phi::MetaTensor> MakeMetaTensor(
const paddle::optional<std::vector<const phi::DenseTensor*>>& tensors) {
std::vector<phi::MetaTensor> meta_tensors;
Expand Down
8 changes: 8 additions & 0 deletions paddle/phi/api/lib/api_gen_utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,8 @@ paddle::optional<phi::SelectedRows> TensorToSelectedRows(

std::shared_ptr<phi::StringTensor> TensorToStringTensor(const Tensor& tensor);

std::shared_ptr<phi::SparseCooTensor> TensorToSparseCooTensor(
const Tensor& tensor);
/* ----------------- for infer_meta --------------------- */

phi::MetaTensor MakeMetaTensor(const phi::TensorBase& tensor);
Expand All @@ -68,6 +70,12 @@ std::vector<phi::MetaTensor> MakeMetaTensor(
phi::MetaTensor MakeMetaTensor(
const paddle::optional<phi::SelectedRows>& tensor);

phi::MetaTensor MakeMetaTensor(
const paddle::optional<phi::SparseCooTensor>& tensor);

phi::MetaTensor MakeMetaTensor(
const paddle::optional<phi::SparseCsrTensor>& tensor);

std::vector<phi::MetaTensor> MakeMetaTensor(
const paddle::optional<std::vector<const phi::DenseTensor*>>& tensors);

Expand Down
202 changes: 0 additions & 202 deletions paddle/phi/api/lib/sparse_api_custom_impl.cc

This file was deleted.

13 changes: 11 additions & 2 deletions paddle/phi/api/lib/tensor_method.cc
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ limitations under the License. */
#include "paddle/phi/api/lib/kernel_dispatch.h"
#include "paddle/phi/core/tensor_utils.h"
#include "paddle/phi/infermeta/unary.h"
// clang-format off

namespace paddle {
namespace experimental {
Expand Down Expand Up @@ -165,15 +166,23 @@ void Tensor::copy_(const Tensor &src,
static_cast<phi::SelectedRows *>(impl_.get()));
} else if (kernel_type == KernelType::SPARSE_COO_KERNEL) {
SetSparseKernelOutput(this, TensorType::SPARSE_COO);
// TODO(zhangkaihuo) add sparse infer_meta
phi::MetaTensor meta_out(impl_.get());
phi::UnchangedInferMeta(
MakeMetaTensor(
*(std::static_pointer_cast<phi::SparseCooTensor>(src.impl_))),
&meta_out);
phi::Copy(*dev_ctx,
(*(std::static_pointer_cast<phi::SparseCooTensor>(src.impl_))),
target_place,
blocking,
static_cast<phi::SparseCooTensor *>(impl_.get()));
} else if (kernel_type == KernelType::SPARSE_CSR_KERNEL) {
SetSparseKernelOutput(this, TensorType::SPARSE_CSR);
// TODO(zhangkaihuo) add sparse infer_meta
phi::MetaTensor meta_out(impl_.get());
phi::UnchangedInferMeta(
MakeMetaTensor(
*(std::static_pointer_cast<phi::SparseCsrTensor>(src.impl_))),
&meta_out);
phi::Copy(*dev_ctx,
(*(std::static_pointer_cast<phi::SparseCsrTensor>(src.impl_))),
target_place,
Expand Down
5 changes: 4 additions & 1 deletion paddle/phi/api/yaml/generator/intermediate_api_gen.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,14 +43,17 @@ def source_include(header_file_path):
#include "paddle/phi/api/lib/api_gen_utils.h"
#include "paddle/phi/api/lib/data_transform.h"
#include "paddle/phi/api/lib/kernel_dispatch.h"
#include "paddle/phi/api/lib/sparse_api_custom_impl.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/infermeta/binary.h"
#include "paddle/phi/infermeta/multiary.h"
#include "paddle/phi/infermeta/nullary.h"
#include "paddle/phi/infermeta/unary.h"
#include "paddle/phi/infermeta/ternary.h"

#include "paddle/phi/infermeta/sparse/unary.h"
#include "paddle/phi/infermeta/sparse/binary.h"
#include "paddle/phi/infermeta/sparse/multiary.h"

#include "paddle/fluid/platform/profiler/event_tracing.h"
#include "paddle/fluid/platform/profiler/supplement_tracing.h"
"""
Expand Down
Loading