|
| 1 | +diff --git a/csrc/custom_all_reduce.cuh b/csrc/custom_all_reduce.cuh |
| 2 | +index 632b579c55af..a2f7e4330000 100644 |
| 3 | +--- a/csrc/custom_all_reduce.cuh |
| 4 | ++++ b/csrc/custom_all_reduce.cuh |
| 5 | +@@ -131,15 +131,26 @@ DINLINE O downcast(array_t<float, O::size> val) { |
| 6 | + } |
| 7 | + |
| 8 | + static DINLINE void st_flag_release(FlagType* flag_addr, FlagType flag) { |
| 9 | ++#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 700 |
| 10 | + asm volatile("st.release.sys.global.u32 [%1], %0;" ::"r"(flag), |
| 11 | + "l"(flag_addr)); |
| 12 | ++#else |
| 13 | ++ asm volatile("membar.sys; st.volatile.global.u32 [%1], %0;" ::"r"(flag), |
| 14 | ++ "l"(flag_addr)); |
| 15 | ++#endif |
| 16 | + } |
| 17 | + |
| 18 | + static DINLINE FlagType ld_flag_acquire(FlagType* flag_addr) { |
| 19 | + FlagType flag; |
| 20 | ++#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 700 |
| 21 | + asm volatile("ld.acquire.sys.global.u32 %0, [%1];" |
| 22 | + : "=r"(flag) |
| 23 | + : "l"(flag_addr)); |
| 24 | ++#else |
| 25 | ++ asm volatile("ld.volatile.global.u32 %0, [%1]; membar.gl;" |
| 26 | ++ : "=r"(flag) |
| 27 | ++ : "l"(flag_addr)); |
| 28 | ++#endif |
| 29 | + return flag; |
| 30 | + } |
| 31 | + |
| 32 | +diff --git a/csrc/custom_all_reduce_test.cu b/csrc/custom_all_reduce_test.cu |
| 33 | +index c8b5d0a013f6..9d83a001c77a 100644 |
| 34 | +--- a/csrc/custom_all_reduce_test.cu |
| 35 | ++++ b/csrc/custom_all_reduce_test.cu |
| 36 | +@@ -44,7 +44,14 @@ |
| 37 | + } while (0) |
| 38 | + |
| 39 | + __global__ void dummy_kernel() { |
| 40 | ++#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 700 |
| 41 | + for (int i = 0; i < 100; i++) __nanosleep(1000000); // 100ms |
| 42 | ++#else |
| 43 | ++ for (int i = 0; i < 100; i++) { |
| 44 | ++ long long int start = clock64(); |
| 45 | ++ while (clock64() - start < 1000000); // something like 100ms |
| 46 | ++ } |
| 47 | ++#endif |
| 48 | + } |
| 49 | + |
| 50 | + template <typename T> |
0 commit comments