Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 18 additions & 5 deletions paddle/fluid/distributed/collective/deep_ep/kernels/utils.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -418,9 +418,6 @@ __device__ __forceinline__ float exp2f_approx(const float &x) {
return ret;
}

// TMA PTX instructions
#ifndef DISABLE_SM90_FEATURES

__device__ __forceinline__ uint32_t elect_one_sync(int lane_id) {
uint32_t pred = 0;
asm volatile(
Expand All @@ -437,23 +434,30 @@ __device__ __forceinline__ uint32_t elect_one_sync(int lane_id) {
}

// Emits the PTX instruction `fence.proxy.async.shared::cta`.
//
// The instruction is only assembled when compiling device code for
// __CUDA_ARCH__ >= 900 (requires SM90+). In every other compilation pass the
// body is empty, so the call is a silent no-op and provides no ordering.
//
// The "memory" clobber tells the compiler that this statement has memory
// side effects, so it does not cache values in registers or move ordinary
// loads/stores across the fence while generating PTX. See the PTX ISA for the
// hardware-level semantics of the fence itself.
__device__ __forceinline__ void fence_view_async_shared() {
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)
  asm volatile("fence.proxy.async.shared::cta; \n" : : : "memory");
#endif
}

// Emits the PTX instruction `fence.mbarrier_init.release.cluster`.
//
// The instruction is only assembled when compiling device code for
// __CUDA_ARCH__ >= 900 (requires SM90+). In every other compilation pass the
// body is empty, so the call is a silent no-op.
//
// NOTE(review): by its name this is meant to be issued after mbarrier_init()
// so the initialization is visible before the barrier is used -- confirm
// against the call sites.
//
// The "memory" clobber keeps the compiler from moving or caching ordinary
// memory accesses across the fence while generating PTX.
__device__ __forceinline__ void fence_barrier_init() {
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)
  asm volatile("fence.mbarrier_init.release.cluster; \n" : : : "memory");
#endif
}

// Initializes the 64-bit mbarrier object at `mbar_ptr` with
// `mbarrier.init.shared::cta.b64`.
//
// Parameters:
//   mbar_ptr      - generic pointer to the barrier object; it is converted to
//                   a shared-window address, so it must point into shared
//                   memory.
//   arrive_count  - value passed as the count operand of `mbarrier.init`.
//
// The instruction is only assembled when compiling device code for
// __CUDA_ARCH__ >= 900 (requires SM90+). In every other compilation pass the
// function is a silent no-op.
__device__ __forceinline__ void mbarrier_init(uint64_t *mbar_ptr,
                                              uint32_t arrive_count) {
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)
  // The PTX operand is a 32-bit shared-window address, not a generic pointer.
  // Computed inside the guard so no unused local is left on other passes.
  const auto mbar_int_ptr =
      static_cast<uint32_t>(__cvta_generic_to_shared(mbar_ptr));
  // "memory": the instruction writes the barrier object through the address
  // operand, a side effect the compiler cannot see from the operand list.
  asm volatile("mbarrier.init.shared::cta.b64 [%1], %0;"
               :
               : "r"(arrive_count), "r"(mbar_int_ptr)
               : "memory");
#else
  // Nothing is emitted on this pass; consume the arguments to avoid warnings.
  (void)mbar_ptr;
  (void)arrive_count;
#endif
}

__device__ __forceinline__ void mbarrier_wait(uint64_t *mbar_ptr,
uint32_t &phase) {
auto mbar_int_ptr = static_cast<uint32_t>(__cvta_generic_to_shared(mbar_ptr));
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)
asm volatile(
"{\n\t"
".reg .pred P1; \n\t"
Expand All @@ -466,19 +470,24 @@ __device__ __forceinline__ void mbarrier_wait(uint64_t *mbar_ptr,
"r"(phase),
"r"(0x989680));
phase ^= 1;
#endif
}

// Issues `mbarrier.arrive.expect_tx.shared::cta.b64` on the barrier at
// `mbar_ptr`, discarding the returned state (the `_` sink operand).
//
// Parameters:
//   mbar_ptr   - generic pointer to the barrier object; it is converted to a
//                shared-window address, so it must point into shared memory.
//   num_bytes  - value passed as the expected-transaction operand.
//
// The instruction is only assembled when compiling device code for
// __CUDA_ARCH__ >= 900 (requires SM90+). In every other compilation pass the
// function is a silent no-op.
__device__ __forceinline__ void mbarrier_arrive_and_expect_tx(
    uint64_t *mbar_ptr, int num_bytes) {
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)
  // The PTX operand is a 32-bit shared-window address, not a generic pointer.
  // Computed inside the guard so no unused local is left on other passes.
  const auto mbar_int_ptr =
      static_cast<uint32_t>(__cvta_generic_to_shared(mbar_ptr));
  // "memory": the instruction updates the barrier object through the address
  // operand, a side effect the compiler cannot see from the operand list.
  asm volatile(
      "mbarrier.arrive.expect_tx.shared::cta.b64 _, [%1], %0; \n\t"
      :
      : "r"(num_bytes), "r"(mbar_int_ptr)
      : "memory");
#else
  // Nothing is emitted on this pass; consume the arguments to avoid warnings.
  (void)mbar_ptr;
  (void)num_bytes;
#endif
}

// Emits the PTX instruction `fence.proxy.async.shared::cta`.
//
// NOTE(review): by its name this is meant to be issued after writing shared
// memory and before a bulk (TMA) store reads it -- confirm against the call
// sites.
//
// The instruction is only assembled when compiling device code for
// __CUDA_ARCH__ >= 900 (requires SM90+). In every other compilation pass the
// body is empty, so the call is a silent no-op and provides no ordering.
//
// The "memory" clobber keeps the compiler from moving or caching ordinary
// memory accesses across the fence while generating PTX.
__device__ __forceinline__ void tma_store_fence() {
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)
  asm volatile("fence.proxy.async.shared::cta;" : : : "memory");
#endif
}

constexpr uint64_t kEvictFirst = 0x12f0000000000000;
Expand All @@ -492,6 +501,7 @@ __device__ __forceinline__ void tma_load_1d(const void *smem_ptr,
auto mbar_int_ptr = static_cast<uint32_t>(__cvta_generic_to_shared(mbar_ptr));
auto smem_int_ptr = static_cast<uint32_t>(__cvta_generic_to_shared(smem_ptr));
const auto cache_hint = evict_first ? kEvictFirst : kEvictNormal;
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)
asm volatile(
"cp.async.bulk.shared::cluster.global.mbarrier::complete_tx::bytes.L2::"
"cache_hint [%0], [%1], %2, [%3], %4;\n" ::"r"(smem_int_ptr),
Expand All @@ -500,6 +510,7 @@ __device__ __forceinline__ void tma_load_1d(const void *smem_ptr,
"r"(mbar_int_ptr),
"l"(cache_hint)
: "memory");
#endif
}

__device__ __forceinline__ void tma_store_1d(const void *smem_ptr,
Expand All @@ -508,6 +519,7 @@ __device__ __forceinline__ void tma_store_1d(const void *smem_ptr,
bool evict_first = true) {
auto smem_int_ptr = static_cast<uint32_t>(__cvta_generic_to_shared(smem_ptr));
const auto cache_hint = evict_first ? kEvictFirst : kEvictNormal;
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)
asm volatile(
"cp.async.bulk.global.shared::cta.bulk_group.L2::cache_hint [%0], [%1], "
"%2, %3;\n" ::"l"(gmem_ptr),
Expand All @@ -516,14 +528,15 @@ __device__ __forceinline__ void tma_store_1d(const void *smem_ptr,
"l"(cache_hint)
: "memory");
asm volatile("cp.async.bulk.commit_group;");
#endif
}

// Emits `cp.async.bulk.wait_group.read N`, with N supplied as a compile-time
// immediate operand.
//
// Template parameters:
//   N - immediate passed to the instruction (defaults to 0).
//       NOTE(review): per the PTX ISA this should be the number of most
//       recent bulk async-groups allowed to remain pending -- confirm.
//
// The instruction is only assembled when compiling device code for
// __CUDA_ARCH__ >= 900 (requires SM90+). In every other compilation pass the
// body is empty, so the call is a silent no-op and does not wait.
//
// The preprocessor guard sits entirely inside the function body so the braces
// stay balanced whichever branch is taken.
template <int N = 0>
__device__ __forceinline__ void tma_store_wait() {
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)
  asm volatile("cp.async.bulk.wait_group.read %0;" ::"n"(N) : "memory");
#endif
}

template <typename dtype_t>
__host__ __device__ constexpr dtype_t ceil_div(dtype_t a, dtype_t b) {
Expand Down