Skip to content

Commit 0b1e3df

Browse files
authored
only compile sparse conv kernel on cuda device (#63864)
1 parent 5afb403 commit 0b1e3df

3 files changed

Lines changed: 20 additions & 0 deletions

File tree

paddle/phi/kernels/sparse/gpu/conv_kernel_igemm.cu

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -152,13 +152,18 @@ void Conv3dImplicitGemmGPUKernel(const GPUContext& dev_ctx,
152152
phi::funcs::TransposeGPUKernelDriver<T>(
153153
dev_ctx, kernel, perm, &kernel_transpose);
154154

155+
#ifdef PADDLE_WITH_CUDA
155156
conv_forward_implicit_gemm_cuda(dev_ctx,
156157
x.values(),
157158
kernel_transpose,
158159
*(out_kmap_cache_ptr->out_in_map),
159160
out->nnz(),
160161
out_channels,
161162
*(out->mutable_values()));
163+
#else
164+
PADDLE_THROW(phi::errors::Unimplemented(
165+
"conv_forward_implicit_gemm_cuda is only supported on CUDA."));
166+
#endif
162167
}
163168

164169
/**
@@ -179,6 +184,7 @@ void Conv3dImplicitGemmKernel(const Context& dev_ctx,
179184
const bool subm,
180185
const std::string& key,
181186
SparseCooTensor* out) {
187+
#ifdef PADDLE_WITH_CUDA
182188
PD_VISIT_BASE_INTEGRAL_TYPES(
183189
x.indices().dtype(), "Conv3dImplicitGemmGPUKernel", ([&] {
184190
// Conv3dImplicitGemmGPUKernel<T, data_t>(dev_ctx,
@@ -193,6 +199,10 @@ void Conv3dImplicitGemmKernel(const Context& dev_ctx,
193199
key,
194200
out);
195201
}));
202+
#else
203+
PADDLE_THROW(phi::errors::Unimplemented(
204+
"Conv3dImplicitGemmKernel is only supported on CUDA."));
205+
#endif
196206
}
197207
} // namespace sparse
198208
} // namespace phi

paddle/phi/kernels/sparse/gpu/conv_kernel_impl.cuh

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,7 @@
1+
#pragma once
2+
3+
#ifdef PADDLE_WITH_CUDA
4+
15
#include <cuda_fp16.h>
26
#include "paddle/phi/common/float16.h"
37
#include "paddle/phi/kernels/sparse/gpu/conv_memory_utils.cuh"
@@ -1271,3 +1275,5 @@ void conv_forward_implicit_gemm_cuda(
12711275
}
12721276
}
12731277
}
1278+
1279+
#endif //PADDLE_WITH_CUDA

paddle/phi/kernels/sparse/gpu/conv_memory_utils.cuh

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
#pragma once
22

3+
#ifdef PADDLE_WITH_CUDA
4+
35
template <int bytes>
46
struct global_load;
57

@@ -93,3 +95,5 @@ struct global_load<2>
9395
}
9496
}
9597
};
98+
99+
#endif // PADDLE_WITH_CUDA

0 commit comments

Comments
 (0)