Skip to content

Commit 338e54d

Browse files
authored
1 parent 3ab6a9b commit 338e54d

File tree

1 file changed

+50
-0
lines changed

1 file changed

+50
-0
lines changed
Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,50 @@
1+
diff --git a/csrc/custom_all_reduce.cuh b/csrc/custom_all_reduce.cuh
2+
index 632b579c55af..a2f7e4330000 100644
3+
--- a/csrc/custom_all_reduce.cuh
4+
+++ b/csrc/custom_all_reduce.cuh
5+
@@ -131,15 +131,26 @@ DINLINE O downcast(array_t<float, O::size> val) {
6+
}
7+
8+
static DINLINE void st_flag_release(FlagType* flag_addr, FlagType flag) {
9+
+#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 700  // when __CUDA_ARCH__ is defined and >= 700: write `flag` to `flag_addr` with a single st.release.sys store
10+
asm volatile("st.release.sys.global.u32 [%1], %0;" ::"r"(flag),
11+
"l"(flag_addr));
12+
+#else  // otherwise: issue a membar.sys fence first, then write `flag` with a volatile store
13+
+ asm volatile("membar.sys; st.volatile.global.u32 [%1], %0;" ::"r"(flag),
14+
+ "l"(flag_addr));
15+
+#endif  // __CUDA_ARCH__ >= 700
16+
}
17+
18+
static DINLINE FlagType ld_flag_acquire(FlagType* flag_addr) {
19+
FlagType flag;
20+
+#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 700  // when __CUDA_ARCH__ is defined and >= 700: read the flag from `flag_addr` with a single ld.acquire.sys load
21+
asm volatile("ld.acquire.sys.global.u32 %0, [%1];"
22+
: "=r"(flag)
23+
: "l"(flag_addr));
24+
+#else  // otherwise: volatile load followed by a membar.gl fence. NOTE(review): the store-side fallback in st_flag_release fences with membar.sys; confirm membar.gl is the intended scope here
25+
+ asm volatile("ld.volatile.global.u32 %0, [%1]; membar.gl;"
26+
+ : "=r"(flag)
27+
+ : "l"(flag_addr));
28+
+#endif  // __CUDA_ARCH__ >= 700
29+
return flag;
30+
}
31+
32+
diff --git a/csrc/custom_all_reduce_test.cu b/csrc/custom_all_reduce_test.cu
33+
index c8b5d0a013f6..9d83a001c77a 100644
34+
--- a/csrc/custom_all_reduce_test.cu
35+
+++ b/csrc/custom_all_reduce_test.cu
36+
@@ -44,7 +44,14 @@
37+
} while (0)
38+
39+
__global__ void dummy_kernel() {
40+
+#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 700  // use the __nanosleep-based delay only when __CUDA_ARCH__ is defined and >= 700
41+
for (int i = 0; i < 100; i++) __nanosleep(1000000); // 100ms
42+
+#else  // otherwise: busy-wait on clock64() instead of calling __nanosleep
43+
+ for (int i = 0; i < 100; i++) {
44+
+ long long int start = clock64();
45+
+ while (clock64() - start < 1000000); // spin until clock64() has advanced by 1000000 since `start`; the 100 rounds together are meant to take something like 100ms (actual time depends on the clock rate, TODO confirm)
46+
+ }
47+
+#endif  // __CUDA_ARCH__ >= 700
48+
}
49+
50+
template <typename T>

0 commit comments

Comments
 (0)