|
| 1 | +/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. |
| 2 | +
|
| 3 | +Licensed under the Apache License, Version 2.0 (the "License"); |
| 4 | +you may not use this file except in compliance with the License. |
| 5 | +You may obtain a copy of the License at |
| 6 | +
|
| 7 | + http://www.apache.org/licenses/LICENSE-2.0 |
| 8 | +
|
| 9 | +Unless required by applicable law or agreed to in writing, software |
| 10 | +distributed under the License is distributed on an "AS IS" BASIS, |
| 11 | +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| 12 | +See the License for the specific language governing permissions and |
| 13 | +limitations under the License. */ |
| 14 | + |
| 15 | +#include <thrust/equal.h> |
| 16 | +#include <thrust/execution_policy.h> |
| 17 | + |
| 18 | +#include "paddle/phi/kernels/elementwise_add_kernel.h" |
| 19 | +#include "paddle/phi/kernels/sparse/elementwise_kernel.h" |
| 20 | +#include "paddle/phi/kernels/sparse/empty_kernel.h" |
| 21 | + |
| 22 | +#include "paddle/phi/core/enforce.h" |
| 23 | +#include "paddle/phi/core/kernel_registry.h" |
| 24 | +#include "paddle/phi/core/visit_type.h" |
| 25 | + |
| 26 | +namespace phi { |
| 27 | +namespace sparse { |
| 28 | + |
| 29 | +template <typename T, typename IntT> |
| 30 | +void ElementWiseAddCooGPUKernel(const GPUContext& dev_ctx, |
| 31 | + const SparseCooTensor& x, |
| 32 | + const SparseCooTensor& y, |
| 33 | + SparseCooTensor* out) { |
| 34 | + const auto& x_indices = x.indices(); |
| 35 | + const auto& y_indices = y.indices(); |
| 36 | + PADDLE_ENFORCE_EQ( |
| 37 | + x_indices.numel(), |
| 38 | + y_indices.numel(), |
| 39 | + phi::errors::PreconditionNotMet( |
| 40 | + "The numel of x.indices() and y.indices() should be equal")); |
| 41 | + const IntT* x_indices_ptr = x_indices.data<IntT>(); |
| 42 | + const IntT* y_indices_ptr = y_indices.data<IntT>(); |
| 43 | +#ifdef PADDLE_WITH_HIP |
| 44 | + bool is_same = thrust::equal(thrust::hip::par.on(dev_ctx.stream()), |
| 45 | +#else |
| 46 | + bool is_same = thrust::equal(thrust::cuda::par.on(dev_ctx.stream()), |
| 47 | +#endif |
| 48 | + x_indices_ptr, |
| 49 | + x_indices_ptr + x_indices.numel(), |
| 50 | + y_indices_ptr); |
| 51 | + PADDLE_ENFORCE_EQ( |
| 52 | + is_same, |
| 53 | + true, |
| 54 | + phi::errors::PreconditionNotMet( |
| 55 | + "Currently, ElementWiseAddCooKernel only supports the case " |
| 56 | + "where x and y have the same indices")); |
| 57 | + EmptyLikeCooKernel<T, GPUContext>(dev_ctx, x, out); |
| 58 | + phi::AddKernel<T, GPUContext>( |
| 59 | + dev_ctx, x.values(), y.values(), out->mutable_values()); |
| 60 | +} |
| 61 | + |
| 62 | +template <typename T, typename Context> |
| 63 | +void ElementWiseAddCooKernel(const Context& dev_ctx, |
| 64 | + const SparseCooTensor& x, |
| 65 | + const SparseCooTensor& y, |
| 66 | + SparseCooTensor* out) { |
| 67 | + PD_VISIT_BASE_INTEGRAL_TYPES(x.indices().dtype(), "VerifyIndices", ([&] { |
| 68 | + ElementWiseAddCooGPUKernel<T, data_t>( |
| 69 | + dev_ctx, x, y, out); |
| 70 | + })); |
| 71 | +} |
| 72 | + |
| 73 | +} // namespace sparse |
| 74 | +} // namespace phi |
| 75 | + |
| 76 | +PD_REGISTER_KERNEL(add_coo_coo, |
| 77 | + GPU, |
| 78 | + ALL_LAYOUT, |
| 79 | + phi::sparse::ElementWiseAddCooKernel, |
| 80 | + float, |
| 81 | + double, |
| 82 | + int16_t, |
| 83 | + int, |
| 84 | + int64_t, |
| 85 | + phi::dtype::float16) { |
| 86 | + kernel->InputAt(0).SetDataLayout(phi::DataLayout::SPARSE_COO); |
| 87 | + kernel->InputAt(1).SetDataLayout(phi::DataLayout::SPARSE_COO); |
| 88 | +} |
0 commit comments