Skip to content

Commit 565980a

Browse files
eee4017, Frank Lin (Engrg-Hardware 1), and Tom-Zheng
authored
Fix test_weight_decay and test_graph_reindex (#62707)
* fix test_graph_reindex
* Fix test_weight_decay

---------

Co-authored-by: Frank Lin (Engrg-Hardware 1) <[email protected]>
Co-authored-by: Tian Zheng (Engrg-Hardware 1) <[email protected]>
1 parent 6307361 commit 565980a

File tree

3 files changed

+57
-39
lines changed

3 files changed

+57
-39
lines changed

cmake/external/cccl.cmake

Lines changed: 6 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -15,12 +15,18 @@ set(CCCL_INCLUDE_DIR ${CCCL_SOURCE_DIR})
1515
message("CCCL_INCLUDE_DIR is ${CCCL_INCLUDE_DIR}")
1616
include_directories(${CCCL_INCLUDE_DIR})
1717

18+
file(TO_NATIVE_PATH ${PADDLE_SOURCE_DIR}/patches/cccl/util_device.cuh.patch
19+
native_src)
20+
set(CCCL_PATCH_COMMAND git checkout -- . && git checkout ${CCCL_TAG} && patch
21+
-p1 -Nd ${CCCL_SOURCE_DIR} < ${native_src})
22+
1823
ExternalProject_Add(
1924
extern_cccl
2025
${EXTERNAL_PROJECT_LOG_ARGS}
2126
SOURCE_DIR ${CCCL_SOURCE_DIR}
2227
PREFIX ${CCCL_PREFIX_DIR}
2328
UPDATE_COMMAND ""
29+
PATCH_COMMAND ${CCCL_PATCH_COMMAND}
2430
CONFIGURE_COMMAND ""
2531
BUILD_COMMAND ""
2632
INSTALL_COMMAND ""

paddle/phi/kernels/gpu/graph_reindex_kernel.cu

Lines changed: 20 additions & 39 deletions
Original file line number | Diff line number | Diff line change
@@ -67,53 +67,34 @@ std::shared_ptr<phi::Allocation> FillHashTable(const Context& dev_ctx,
6767
input, num_input, len_hashtable, keys, key_index);
6868

6969
// Get item index count.
70-
auto item_count =
71-
phi::memory_utils::Alloc(place, (num_input + 1) * sizeof(int));
72-
int* item_count_ptr = reinterpret_cast<int*>(item_count->ptr());
73-
#ifdef PADDLE_WITH_HIP
74-
hipMemset(item_count_ptr, 0, sizeof(int) * (num_input + 1));
75-
#else
76-
cudaMemset(item_count_ptr, 0, sizeof(int) * (num_input + 1));
77-
#endif
70+
thrust::device_vector<int> item_count(num_input + 1, 0);
7871
GetItemIndexCount<T><<<grid, block, 0, dev_ctx.stream()>>>(
79-
input, item_count_ptr, num_input, len_hashtable, keys, key_index);
80-
81-
size_t temp_storage_bytes = 0;
82-
cub::DeviceScan::ExclusiveSum(
83-
NULL, temp_storage_bytes, item_count_ptr, item_count_ptr, num_input + 1);
84-
auto d_temp_storage = phi::memory_utils::Alloc(place, temp_storage_bytes);
85-
cub::DeviceScan::ExclusiveSum(d_temp_storage->ptr(),
86-
temp_storage_bytes,
87-
item_count_ptr,
88-
item_count_ptr,
89-
num_input + 1);
90-
int total_unique_items = 0;
91-
#ifdef PADDLE_WITH_HIP
92-
hipMemcpy(&total_unique_items,
93-
item_count_ptr + num_input,
94-
sizeof(int),
95-
hipMemcpyDeviceToHost);
96-
#else
97-
cudaMemcpy(&total_unique_items,
98-
item_count_ptr + num_input,
99-
sizeof(int),
100-
cudaMemcpyDeviceToHost);
101-
#endif
72+
input,
73+
thrust::raw_pointer_cast(item_count.data()),
74+
num_input,
75+
len_hashtable,
76+
keys,
77+
key_index);
10278

79+
thrust::exclusive_scan(
80+
item_count.begin(), item_count.end(), item_count.begin());
81+
82+
int total_unique_items = item_count[num_input];
10383
auto unique_items =
10484
phi::memory_utils::AllocShared(place, total_unique_items * sizeof(T));
10585
T* unique_items_data = reinterpret_cast<T*>(unique_items->ptr());
10686
*final_nodes_len = total_unique_items;
10787

10888
// Get unique items
109-
FillUniqueItems<T><<<grid, block, 0, dev_ctx.stream()>>>(input,
110-
num_input,
111-
len_hashtable,
112-
unique_items_data,
113-
item_count_ptr,
114-
keys,
115-
values,
116-
key_index);
89+
FillUniqueItems<T><<<grid, block, 0, dev_ctx.stream()>>>(
90+
input,
91+
num_input,
92+
len_hashtable,
93+
unique_items_data,
94+
thrust::raw_pointer_cast(item_count.data()),
95+
keys,
96+
values,
97+
key_index);
11798
return unique_items;
11899
}
119100

patches/cccl/util_device.cuh.patch

Lines changed: 31 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -0,0 +1,31 @@
1+
diff --git a/cub/cub/util_device.cuh b/cub/cub/util_device.cuh
2+
index c7e15cafe..756336914 100644
3+
--- a/cub/cub/util_device.cuh
4+
+++ b/cub/cub/util_device.cuh
5+
@@ -278,7 +278,7 @@ public:
6+
/**
7+
* \brief Retrieves the PTX version that will be used on the current device (major * 100 + minor * 10).
8+
*/
9+
-CUB_RUNTIME_FUNCTION inline cudaError_t PtxVersionUncached(int& ptx_version)
10+
+CUB_RUNTIME_FUNCTION __forceinline__ cudaError_t PtxVersionUncached(int& ptx_version)
11+
{
12+
// Instantiate `EmptyKernel<void>` in both host and device code to ensure
13+
// it can be called.
14+
@@ -375,7 +375,7 @@ __host__ inline cudaError_t PtxVersion(int& ptx_version, int device)
15+
*
16+
* \note This function is thread safe.
17+
*/
18+
-CUB_RUNTIME_FUNCTION inline cudaError_t PtxVersion(int &ptx_version)
19+
+CUB_RUNTIME_FUNCTION __forceinline__ cudaError_t PtxVersion(int &ptx_version)
20+
{
21+
cudaError_t result = cudaErrorUnknown;
22+
NV_IF_TARGET(
23+
@@ -593,7 +593,7 @@ CUB_RUNTIME_FUNCTION inline cudaError_t HasUVA(bool& has_uva)
24+
*
25+
*/
26+
template <typename KernelPtr>
27+
-CUB_RUNTIME_FUNCTION inline
28+
+CUB_RUNTIME_FUNCTION __forceinline__
29+
cudaError_t MaxSmOccupancy(
30+
int& max_sm_occupancy, ///< [out] maximum number of thread blocks that can reside on a single SM
31+
KernelPtr kernel_ptr, ///< [in] Kernel pointer for which to compute SM occupancy

0 commit comments

Comments
 (0)