
Commit 27f4a8b

xavier-nvidia (Xavier Simmons) authored and committed
Fix GEMM+AR nvbugs 5219533,5127801,5072306
Update NVLS bootstrap to support MNNVL

Signed-off-by: xsimmons <[email protected]>
1 parent 9db769e commit 27f4a8b

27 files changed: +918, -480 lines

cpp/CMakeLists.txt

Lines changed: 18 additions & 2 deletions
@@ -44,6 +44,7 @@ option(ENABLE_MULTI_DEVICE
 option(ENABLE_UCX "Enable building with UCX (Uniform Communication X) support"
        ON)
 option(NVRTC_DYNAMIC_LINKING "Link against the dynamic NVRTC libraries" OFF)
+option(ENABLE_NVSHMEM "Enable building with NVSHMEM support" OFF)
 option(USING_OSS_CUTLASS_LOW_LATENCY_GEMM
        "Using open sourced Cutlass low latency gemm kernel" ON)
 option(USING_OSS_CUTLASS_FP4_GEMM "Using open sourced Cutlass fp4 gemm kernel"
@@ -53,6 +54,8 @@ option(USING_OSS_CUTLASS_MOE_GEMM "Using open sourced Cutlass moe gemm kernel"
 option(USING_OSS_CUTLASS_ALLREDUCE_GEMM
        "Using open sourced Cutlass AR gemm kernel" ON)
 
+message(STATUS "ENABLE_NVSHMEM is ${ENABLE_NVSHMEM}")
+
 if(NVTX_DISABLE)
   add_compile_definitions("NVTX_DISABLE")
   message(STATUS "NVTX is disabled")
@@ -165,6 +168,7 @@ message(STATUS "CUDA library status:")
 message(STATUS "  version: ${CUDAToolkit_VERSION}")
 message(STATUS "  libraries: ${CUDAToolkit_LIBRARY_DIR}")
 message(STATUS "  include path: ${CUDAToolkit_INCLUDE_DIRS}")
+message(STATUS "CUDA_NVML_LIB: ${CUDA_NVML_LIB}")
 
 # Prevent CMake from creating a response file for CUDA compiler, so clangd can
 # pick up on the includes
@@ -256,9 +260,21 @@ set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -DBUILD_SYSTEM=cmake_oss ")
 # note: cmake expr generation $<BOOL:${ENABLE_MULTI_DEVICE}> is a build time
 # evaluation so hard to debug at cmake time
 if(ENABLE_MULTI_DEVICE)
-  set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -DENABLE_MULTI_DEVICE=1")
+  # Add target definitions for both C++ and CUDA
+  add_compile_definitions($<$<COMPILE_LANGUAGE:CXX>:ENABLE_MULTI_DEVICE=1>
+                          $<$<COMPILE_LANGUAGE:CUDA>:ENABLE_MULTI_DEVICE=1>)
+else()
+  # Add target definitions for both C++ and CUDA
+  add_compile_definitions($<$<COMPILE_LANGUAGE:CXX>:ENABLE_MULTI_DEVICE=0>
+                          $<$<COMPILE_LANGUAGE:CUDA>:ENABLE_MULTI_DEVICE=0>)
+endif()
+
+if(ENABLE_NVSHMEM)
+  add_compile_definitions($<$<COMPILE_LANGUAGE:CXX>:ENABLE_NVSHMEM=1>
+                          $<$<COMPILE_LANGUAGE:CUDA>:ENABLE_NVSHMEM=1>)
 else()
-  set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -DENABLE_MULTI_DEVICE=0")
+  add_compile_definitions($<$<COMPILE_LANGUAGE:CXX>:ENABLE_NVSHMEM=0>
+                          $<$<COMPILE_LANGUAGE:CUDA>:ENABLE_NVSHMEM=0>)
 endif()
 
 # Fix linking issue with TRT 10, the detailed description about `--mcmodel` can
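The move from appending -DENABLE_MULTI_DEVICE to CMAKE_CXX_FLAGS to add_compile_definitions() with $<COMPILE_LANGUAGE:CXX> and $<COMPILE_LANGUAGE:CUDA> generator expressions makes the macro visible to CUDA translation units as well as host C++ ones. A minimal sketch of code that relies on this (hypothetical file, not part of the commit; only the #if guard pattern comes from the build change):

    // multi_device_probe.cu -- hypothetical example showing why the macro must be
    // defined for CUDA files too: the same guard now compiles consistently whether
    // the translation unit goes through the C++ or the CUDA compiler.
    #include <cstdio>

    #if ENABLE_MULTI_DEVICE
    #include <mpi.h>
    #endif

    int getWorldSize()
    {
    #if ENABLE_MULTI_DEVICE
        int worldSize = 1;
        MPI_Comm_size(MPI_COMM_WORLD, &worldSize); // assumes MPI_Init() already ran
        return worldSize;
    #else
        return 1; // single-device build: no MPI dependency
    #endif
    }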

cpp/tensorrt_llm/CMakeLists.txt

Lines changed: 10 additions & 1 deletion
@@ -72,6 +72,12 @@ if(ENABLE_MULTI_DEVICE)
   include_directories(${MPI_C_INCLUDE_DIRS})
 endif()
 
+if(ENABLE_NVSHMEM)
+  # Add hints for aarch64
+  find_package(NVSHMEM REQUIRED HINTS /usr/lib/sbsa-linux-gnu/cmake/nvshmem/)
+  include_directories(/usr/include/nvshmem/)
+endif()
+
 if(NOT WIN32)
   set(DECODER_SHARED_TARGET_0 decoder_attention_0)
   set(DECODER_SHARED_TARGET_1 decoder_attention_1)
@@ -231,7 +237,10 @@ if(ENABLE_MULTI_DEVICE)
   set(TRTLLM_LINK_LIBS ${TRTLLM_LINK_LIBS} ${MPI_C_LIBRARIES} ${NCCL_LIB})
 endif()
 
-message("TRTLLM_LINK_LIBS: ${TRTLLM_LINK_LIBS}")
+if(ENABLE_NVSHMEM)
+  set(TRTLLM_LINK_LIBS ${TRTLLM_LINK_LIBS} nvshmem::nvshmem_host
+                       nvshmem::nvshmem_device)
+endif()
 
 if(NOT WIN32) # Unix-like compilers
   set(UNDEFINED_FLAG "-Wl,--no-undefined")
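Since the nvshmem::nvshmem_host and nvshmem::nvshmem_device targets are only found and linked when ENABLE_NVSHMEM is ON, any call into the NVSHMEM host API has to sit behind the matching compile definition introduced above. A hedged sketch (hypothetical file, not part of the commit; the NVSHMEM calls shown are standard host API entry points):

    // nvshmem_probe.cpp -- hypothetical example of guarding NVSHMEM usage.
    #include <cstdio>

    #if ENABLE_NVSHMEM
    #include <nvshmem.h>
    #endif

    void reportNvshmemPe()
    {
    #if ENABLE_NVSHMEM
        nvshmem_init(); // assumes a valid bootstrap (e.g. launched under MPI/PMI)
        std::printf("PE %d of %d\n", nvshmem_my_pe(), nvshmem_n_pes());
        nvshmem_finalize();
    #else
        std::printf("built without NVSHMEM support\n");
    #endif
    }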

cpp/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm_configs.h

Lines changed: 2 additions & 2 deletions
@@ -332,12 +332,12 @@ enum class ClusterShape
     ClusterShape_1x2x1,
     ClusterShape_2x2x1,
     ClusterShape_1x4x1,
+    ClusterShape_4x1x1,
     ClusterShape_4x2x1,
     ClusterShape_2x4x1,
     ClusterShape_4x4x1,
     ClusterShape_1x8x1,
-    ClusterShape_8x1x1,
-    ClusterShape_4x1x1
+    ClusterShape_8x1x1
 };
 
 static auto get_cluster_shape_name(ClusterShape Shape_MNK)

cpp/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/system_barrier.h

Lines changed: 22 additions & 15 deletions
@@ -22,6 +22,8 @@
 
 #include "cutlass/barrier.h"
 
+#include <cuda/atomic>
+
 namespace cutlass
 {
 
@@ -43,7 +45,7 @@ __forceinline__ __device__ uint32_t atomicCAS_system_acq(uint32_t* p, uint32_t c
 
 } // namespace detail
 
-template <class Sync, bool SafeBetweenPhases, bool UseMembarGPU>
+template <class Sync, bool SafeBetweenPhases>
 struct MulticastSystemBarrier : public GenericBarrier<Sync>
 {
 
@@ -57,23 +59,27 @@ struct MulticastSystemBarrier : public GenericBarrier<Sync>
 
 protected:
     /// Reduce into flag, with release pattern (int specialization)
-    CUTLASS_DEVICE
-    static void red_release(T* mc_ptr, int val)
+    template <cuda::thread_scope Scope>
+    CUTLASS_DEVICE static void red_release(T* mc_ptr, int val)
     {
 #if defined(CUTE_ARCH_MULTIMEM_SM90_ENABLED)
         // atomic reduction to all replicas
        // this can be conceptually thought of as __threadfence_system(); atomicAdd_system(arrival_counter_mc, 1);
        // See
        // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-multimem-ld-reduce-multimem-st-multimem-red
        // for multimem PTX doc
-        if constexpr (UseMembarGPU)
+        if constexpr (Scope == cuda::thread_scope::thread_scope_device)
         {
             asm volatile("multimem.red.release.gpu.global.add.u32 [%0], %1;" ::"l"(mc_ptr), "r"(val) : "memory");
         }
-        else
+        else if constexpr (Scope == cuda::thread_scope::thread_scope_system)
         {
             asm volatile("multimem.red.release.sys.global.add.u32 [%0], %1;" ::"l"(mc_ptr), "r"(val) : "memory");
         }
+        else
+        {
+            CUTE_INVALID_CONTROL_PATH("Invalid thread scope for MulticastSystemBarrier.");
+        }
 
        // Need a fence between MC and UC access to the same memory:
        // - fence.proxy instructions establish an ordering between memory accesses that may happen through different
@@ -128,8 +134,8 @@ struct MulticastSystemBarrier : public GenericBarrier<Sync>
         Sync::sync();
     }
 
-    CUTLASS_DEVICE
-    static T arrive_inc_get(T* mc_ptr, T* uc_ptr, int thread_idx, int flag_idx, int rank, int world_size)
+    template <cuda::thread_scope Scope>
+    CUTLASS_DEVICE static T arrive_inc_get(T* mc_ptr, T* uc_ptr, int thread_idx, int flag_idx, int rank, int world_size)
     {
         T* mc_barrier_ptr = mc_ptr + flag_idx;
         T* uc_barrier_ptr = uc_ptr + flag_idx;
@@ -156,37 +162,38 @@ struct MulticastSystemBarrier : public GenericBarrier<Sync>
             // can be immediately reused.
             bool master = rank == 0;
             int val = master ? 0x80000000 - (world_size - 1) : 1;
-            red_release(mc_barrier_ptr, val);
+            red_release<Scope>(mc_barrier_ptr, val);
         }
         return old_arrive;
     }
 
-    CUTLASS_DEVICE
-    static void arrive_inc(Params const& params, int thread_idx, int flag_idx, int rank, int world_size)
+    template <cuda::thread_scope Scope = cuda::thread_scope::thread_scope_system>
+    CUTLASS_DEVICE static void arrive_inc(Params const& params, int thread_idx, int flag_idx, int rank, int world_size)
    {
        T* mc_barrier = params.mc_barrier_ptr + flag_idx;
 
        Sync::sync();
 
        if (thread_idx == 0)
        {
-            red_release(mc_barrier, 1);
+            red_release<Scope>(mc_barrier, 1);
        }
    }
 
-    CUTLASS_DEVICE
-    static void arrive_and_wait(Params const& params, int thread_idx, int flag_idx, int rank, int world_size)
+    template <cuda::thread_scope Scope = cuda::thread_scope::thread_scope_system>
+    CUTLASS_DEVICE static void arrive_and_wait(
+        Params const& params, int thread_idx, int flag_idx, int rank, int world_size)
    {
        auto mc_ptr = params.mc_barrier_ptr;
        auto uc_ptr = params.uc_barrier_ptr;
        if constexpr (SafeBetweenPhases)
        {
-            auto old_arrive = arrive_inc_get(mc_ptr, uc_ptr, thread_idx, flag_idx, rank, world_size);
+            auto old_arrive = arrive_inc_get<Scope>(mc_ptr, uc_ptr, thread_idx, flag_idx, rank, world_size);
            wait(old_arrive, uc_ptr, thread_idx, flag_idx);
        }
        else
        {
-            arrive_inc(params, thread_idx, flag_idx, rank, world_size);
+            arrive_inc<Scope>(params, thread_idx, flag_idx, rank, world_size);
            wait_eq_reset(uc_ptr, thread_idx, flag_idx, world_size);
        }
    }
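Replacing the UseMembarGPU boolean with a cuda::thread_scope template parameter (hence the new <cuda/atomic> include) lets each call site choose the release scope of the multicast reduction at compile time: .gpu for device scope, .sys for the system-wide scope that cross-node NVLink (MNNVL) synchronization relies on. A standalone sketch of the same dispatch pattern, under the assumption of an SM90+ target where multimem.red is available (illustrative helper, not the library's code):

    #include <cuda/atomic> // provides cuda::thread_scope

    // Compile-time selection of the PTX release scope for a multicast atomic add.
    template <cuda::thread_scope Scope>
    __device__ __forceinline__ void multimem_red_add(unsigned int* mc_ptr, unsigned int val)
    {
    #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)
        if constexpr (Scope == cuda::thread_scope_device)
        {
            // Release ordering visible to the issuing GPU only.
            asm volatile("multimem.red.release.gpu.global.add.u32 [%0], %1;" ::"l"(mc_ptr), "r"(val) : "memory");
        }
        else if constexpr (Scope == cuda::thread_scope_system)
        {
            // Release ordering visible to every device in the system.
            asm volatile("multimem.red.release.sys.global.add.u32 [%0], %1;" ::"l"(mc_ptr), "r"(val) : "memory");
        }
        else
        {
            static_assert(Scope == cuda::thread_scope_device || Scope == cuda::thread_scope_system,
                "unsupported thread scope for multimem_red_add");
        }
    #endif
    }

    // Usage (device code): multimem_red_add<cuda::thread_scope_system>(barrier_mc_ptr, 1u);

With the defaults chosen in the diff, callers that do not specify a scope keep the system-scope behavior, while the GEMM epilogues below opt into the narrower device scope explicitly.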

cpp/tensorrt_llm/kernels/cutlass_kernels/CMakeLists.txt

Lines changed: 6 additions & 2 deletions
@@ -181,8 +181,7 @@ endif()
 if(USING_OSS_CUTLASS_ALLREDUCE_GEMM)
   add_library(
     ar_gemm_src STATIC
-    ${ARGEMM_SRC_CU}
-    ${CMAKE_CURRENT_SOURCE_DIR}/../../runtime/ipcNvlsMemory.cpp)
+    ${ARGEMM_SRC_CU} ${CMAKE_CURRENT_SOURCE_DIR}/../../runtime/ipcNvlsMemory.cu)
   target_include_directories(
     ar_gemm_src
     PUBLIC ${CMAKE_CURRENT_SOURCE_DIR}/../internal_cutlass_kernels/include)
@@ -233,6 +232,11 @@ function(process_target target_name enable_hopper enable_blackwell)
     target_link_libraries(${target_name} PRIVATE ${MPI_C_LIBRARIES})
   endif()
 
+  if(ENABLE_NVSHMEM)
+    target_link_libraries(${target_name} PRIVATE nvshmem::nvshmem_host
+                                                 nvshmem::nvshmem_device)
+  endif()
+
 endfunction()
 
 set(TARGET_LIB

cpp/tensorrt_llm/kernels/cutlass_kernels/allreduce_gemm/allreduce_gemm_impl_sm100.h

Lines changed: 1 addition & 1 deletion
@@ -138,7 +138,7 @@ class GemmAllReduceImplTwoshot_Sm100 : public GemmAllReduceImplInterface
     // Epilogue
     ////////////////
     using FusionCallbacks = cutlass::epilogue::fusion::LinearCombination<ElementD, float, void, float>;
-    using TileBarrierType = cutlass::MulticastSystemBarrier<cutlass::detail::SyncNoOp, true, true>;
+    using TileBarrierType = cutlass::MulticastSystemBarrier<cutlass::detail::SyncNoOp, true>;
     using EpilogueScheduleType = typename MmaAdapter<MmaType, IsFP4>::EpilogueSchedule;
     using EpilogueTileType = cutlass::epilogue::collective::EpilogueTileAuto;
     using FusionOp

cpp/tensorrt_llm/kernels/cutlass_kernels/allreduce_gemm/allreduce_gemm_impl_sm90.h

Lines changed: 1 addition & 2 deletions
@@ -100,8 +100,7 @@ class GemmAllReduceImplTwoshot_Sm90 : public GemmAllReduceImplInterface
     using RasterOrderOptions =
         typename cutlass::gemm::kernel::detail::PersistentTileSchedulerSm90Params::RasterOrderOptions;
 
-    using TileBarrierType = cutlass::MulticastSystemBarrier<cutlass::detail::SyncNoOp, true /* Safe across phases */,
-        true /* membar.gpu */>;
+    using TileBarrierType = cutlass::MulticastSystemBarrier<cutlass::detail::SyncNoOp, true /* Safe across phases */>;
 
     // 16B alignment for TMA
     static constexpr int AlignmentA = 16 / sizeof(ElementA);

cpp/tensorrt_llm/kernels/cutlass_kernels/allreduce_gemm/communication/sm90_allreduce_nvls_warpspecialized.hpp

Lines changed: 27 additions & 12 deletions
@@ -201,7 +201,7 @@ class CollectiveAllReduceMulticastWarpSpecialized
         auto [M, N, K, L] = problem_shape;
         auto [m, n, k, l] = tile_coord;
 
-        if (!tile_valid(m, n) || params_ptr->world_size == 1)
+        if (!tile_valid(m, n) || params_ptr->world_size <= 2)
         {
             return; // nothing to do
         }
@@ -212,7 +212,7 @@ class CollectiveAllReduceMulticastWarpSpecialized
 
         // Wait for all multicast writes to be visible to us.
         // This is safe between phases.
-        SystemBarrier::arrive_and_wait(
+        SystemBarrier::arrive_and_wait<cuda::thread_scope::thread_scope_system>(
             params_ptr->barrier_params_final_sync, thread_idx, tile_index, params_ptr->rank, params_ptr->world_size);
     }
 
@@ -297,21 +297,28 @@ class CollectiveAllReduceMulticastWarpSpecialized
                 Tensor tGR_gD1_vec = zipped_divide(tGR_gD1(_, _, _, red_m, red_n), Vec);
                 Tensor tRG_gOut_vec = zipped_divide(tRG_gOut(_, _, _, red_m, red_n), Vec);
 
-                auto pred_fn
-                    = [&](auto const&... coords) { return elem_less(tGR_pD_vec(_0{}, coords...), problem_shape); };
+                // Create predicate tensor for bounds checking
+                Tensor pred_tensor = make_tensor<bool>(make_shape(size(tGR_pD_vec)), Stride<_1>{});
+
+                // Set predicate values based on coordinate bounds
+                CUTLASS_PRAGMA_UNROLL
+                for (int i = 0; i < size(pred_tensor); ++i)
+                {
+                    pred_tensor(i) = elem_less(tGR_pD_vec(_0{}, i), problem_shape);
+                }
 
                 // Read from self.
-                cute::copy_if(CopyAtomG2R{}, pred_fn, tGR_gD0_vec, tGR_rD0_vec);
+                cute::copy_if(CopyAtomG2R{}, pred_tensor, tGR_gD0_vec, tGR_rD0_vec);
                 // Read from remote.
-                cute::copy_if(CopyAtomG2R{}, pred_fn, tGR_gD1_vec, tGR_rD1_vec);
+                cute::copy_if(CopyAtomG2R{}, pred_tensor, tGR_gD1_vec, tGR_rD1_vec);
                 // Reduce
                 CUTLASS_PRAGMA_UNROLL
                 for (int i = 0; i < size(tGR_rD0_vec); i++)
                 {
                     tGR_rD0_vec(i) += tGR_rD1_vec(i);
                 }
                 // store to self.
-                cute::copy_if(CopyAtomG2R{}, pred_fn, tGR_rD0_vec, tRG_gOut_vec);
+                cute::copy_if(CopyAtomG2R{}, pred_tensor, tGR_rD0_vec, tRG_gOut_vec);
             }
         }
     }
@@ -386,13 +393,21 @@ class CollectiveAllReduceMulticastWarpSpecialized
                 Tensor tGR_gD_vec = zipped_divide(tGR_gD(_, _, _, red_m, red_n), Vec);
                 Tensor tRG_gD_vec = zipped_divide(tRG_gD(_, _, _, red_m, red_n), Vec);
                 Tensor tGR_pD_vec = zipped_divide(tGR_pD(_, _, _, red_m, red_n), Vec);
-                // problem shape bounds check
-                auto pred_fn
-                    = [&](auto const&... coords) { return elem_less(tGR_pD_vec(_0{}, coords...), problem_shape); };
+
+                // Create predicate tensor for bounds checking
+                Tensor pred_tensor = make_tensor<bool>(make_shape(size(tGR_gD_vec)), Stride<_1>{});
+
+                // Set predicate values based on coordinate bounds
+                CUTLASS_PRAGMA_UNROLL
+                for (int i = 0; i < size(pred_tensor); ++i)
+                {
+                    pred_tensor(i) = elem_less(tGR_pD_vec(_0{}, i), problem_shape);
+                }
+
                 // load-reduce in switch
-                cute::copy_if(CopyAtomG2R{}, pred_fn, tGR_gD_vec, tGR_rD_vec);
+                cute::copy_if(CopyAtomG2R{}, pred_tensor, tGR_gD_vec, tGR_rD_vec);
                 // store switch multicast
-                cute::copy_if(CopyAtomR2G{}, pred_fn, tGR_rD_vec, tRG_gD_vec);
+                cute::copy_if(CopyAtomR2G{}, pred_tensor, tGR_rD_vec, tRG_gD_vec);
             }
         }
     }
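The predicate lambda previously passed to cute::copy_if is replaced by an explicit per-element boolean tensor filled from the coordinate (identity) tensor, and the result is handed to the copy_if overload that takes a predicate tensor. A condensed sketch of that pattern, assuming CuTe headers and a source fragment whose size is known at compile time (illustrative helper, not the kernel's code):

    #include <cute/tensor.hpp>

    // Predicated copy driven by an explicit bool tensor rather than a lambda.
    template <class CopyAtom, class SrcTensor, class DstTensor, class CoordTensor, class ProblemShape>
    CUTE_DEVICE void predicated_copy(CopyAtom const& copy_atom, SrcTensor const& src, DstTensor& dst,
        CoordTensor const& coords, ProblemShape const& problem_shape)
    {
        using namespace cute;
        // One boolean per element: true when the element's global coordinate is in bounds.
        Tensor pred = make_tensor<bool>(make_shape(size(src)), Stride<_1>{});
        CUTE_UNROLL
        for (int i = 0; i < size(pred); ++i)
        {
            pred(i) = elem_less(coords(i), problem_shape);
        }
        // Elements with a false predicate are skipped by the copy.
        copy_if(copy_atom, pred, src, dst);
    }

A call site would pass the thread-local source and destination fragments together with the matching slice of the coordinate tensor, as the kernel does with tGR_pD_vec above.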

cpp/tensorrt_llm/kernels/cutlass_kernels/allreduce_gemm/epilogue/sm100_visitor_allreduce_tma_warpspecialized.hpp

Lines changed: 1 addition & 1 deletion
@@ -171,7 +171,7 @@ struct Sm100AllReduceArrive
         tma_store_wait<0>();
 
         int tile_idx = params_ptr->tile_layout(m, n);
-        SystemBarrier::arrive_inc(
+        SystemBarrier::arrive_inc<cuda::thread_scope::thread_scope_device>(
             params_ptr->barrier_params, thread_idx, tile_idx, params_ptr->rank, params_ptr->world_size);
     }
 }

cpp/tensorrt_llm/kernels/cutlass_kernels/allreduce_gemm/epilogue/sm90_visitor_allreduce_tma_warpspecialized.hpp

Lines changed: 1 addition & 1 deletion
@@ -268,7 +268,7 @@ struct Sm90AuxAllReduce
         tma_store_wait<0>();
 
         int tile_idx = params_ptr->tile_layout(m, n);
-        SystemBarrier::arrive_inc(
+        SystemBarrier::arrive_inc<cuda::thread_scope::thread_scope_device>(
            params_ptr->barrier_params, thread_idx, tile_idx, params_ptr->rank, params_ptr->world_size);
     }
 };
