diff --git a/python/sglang/srt/distributed/device_communicators/pynccl_allocator.py b/python/sglang/srt/distributed/device_communicators/pynccl_allocator.py index d7274cf2ccba..1c0852041d9c 100644 --- a/python/sglang/srt/distributed/device_communicators/pynccl_allocator.py +++ b/python/sglang/srt/distributed/device_communicators/pynccl_allocator.py @@ -5,7 +5,6 @@ from torch.cuda.memory import CUDAPluggableAllocator from sglang.srt.distributed.parallel_state import GroupCoordinator -from sglang.srt.managers.schedule_batch import global_server_args_dict nccl_allocator_source = """ #include @@ -29,12 +28,26 @@ _mem_pool = None _registered_base_addrs = set() _graph_pool_id = None +_cached_pool_snapshot = None def is_symmetric_memory_enabled(): + # Import here to avoid circular import + from sglang.srt.managers.schedule_batch import global_server_args_dict + return global_server_args_dict["enable_symm_mem"] +def is_symmetric_memory_tensor(tensor: torch.Tensor): + if not is_symmetric_memory_enabled() or _cached_pool_snapshot is None: + return False + for segment in _cached_pool_snapshot: + for block in segment["blocks"]: + if block["address"] == tensor.untyped_storage().data_ptr(): + return True + return False + + def set_graph_pool_id(graph_pool_id): global _graph_pool_id _graph_pool_id = graph_pool_id @@ -64,8 +77,17 @@ def get_nccl_mem_pool(): class use_symmetric_memory: - def __init__(self, group_coordinator: GroupCoordinator): - if not is_symmetric_memory_enabled(): + def __init__( + self, + group_coordinator: GroupCoordinator, + disabled: bool = False, + ): + self.disabled = ( + disabled + or not is_symmetric_memory_enabled() + or group_coordinator.world_size == 1 + ) + if self.disabled: self.group_coordinator = None self._mem_pool_ctx = None self.is_graph_capture = None @@ -79,7 +101,7 @@ def __init__(self, group_coordinator: GroupCoordinator): self.pre_2_8_0 = version.parse(torch.__version__) < version.parse("2.8.0") def __enter__(self): - if not is_symmetric_memory_enabled(): + if self.disabled: return self assert ( self.group_coordinator.pynccl_comm is not None @@ -101,17 +123,14 @@ def __enter__(self): self._mem_pool_ctx.__enter__() return self - def tag(self, tensor: torch.Tensor): - if not is_symmetric_memory_enabled(): - return - tensor.symmetric_memory = True - def __exit__(self, exc_type, exc_val, exc_tb): - if not is_symmetric_memory_enabled(): + if self.disabled: return + global _cached_pool_snapshot global _registered_base_addrs self._mem_pool_ctx.__exit__(exc_type, exc_val, exc_tb) - for segment in get_nccl_mem_pool().snapshot(): + _cached_pool_snapshot = get_nccl_mem_pool().snapshot() + for segment in _cached_pool_snapshot: if segment["address"] not in _registered_base_addrs: if segment["stream"] == 0 and self.pre_2_8_0: # PyTorch version < 2.8.0 has a multi-thread MemPool bug diff --git a/python/sglang/srt/distributed/parallel_state.py b/python/sglang/srt/distributed/parallel_state.py index dc120f761814..6be29a4ed5d2 100644 --- a/python/sglang/srt/distributed/parallel_state.py +++ b/python/sglang/srt/distributed/parallel_state.py @@ -284,7 +284,13 @@ def __init__( from sglang.srt.distributed.device_communicators.pynccl import ( PyNcclCommunicator, ) + from sglang.srt.distributed.device_communicators.pynccl_allocator import ( + is_symmetric_memory_tensor, + use_symmetric_memory, + ) + self.is_symmetric_memory_tensor = is_symmetric_memory_tensor + self.use_symmetric_memory = use_symmetric_memory if is_hip(): from sglang.srt.distributed.device_communicators.quick_all_reduce import ( QuickAllReduce, @@ -509,11 +515,7 @@ def all_reduce(self, input_: torch.Tensor) -> torch.Tensor: if self.npu_communicator is not None and not self.npu_communicator.disabled: return self.npu_communicator.all_reduce(input_) - if ( - self.pynccl_comm is not None - and hasattr(input_, "symmetric_memory") - and input_.symmetric_memory - ): + if self.pynccl_comm is not None and self.is_symmetric_memory_tensor(input_): with self.pynccl_comm.change_state( enable=True, stream=torch.cuda.current_stream() ): @@ -579,9 +581,23 @@ def reduce_scatter_tensor( self, output: torch.Tensor, input: torch.Tensor, - ) -> None: - # TODO(ch-wan): support other backends - torch.distributed.reduce_scatter_tensor(output, input, group=self.device_group) + ) -> torch.Tensor: + pynccl_comm = self.pynccl_comm + if pynccl_comm is not None and ( + not pynccl_comm.disabled + or ( + self.is_symmetric_memory_tensor(output) + and self.is_symmetric_memory_tensor(input) + ) + ): + with pynccl_comm.change_state( + enable=True, stream=torch.cuda.current_stream() + ): + pynccl_comm.reduce_scatter(output, input) + else: + torch.distributed.reduce_scatter_tensor( + output, input, group=self.device_group + ) return output def reduce_scatter( @@ -628,8 +644,17 @@ def reduce_scatterv( def _all_gather_into_tensor(self, output: torch.Tensor, input: torch.Tensor): pynccl_comm = self.pynccl_comm - if pynccl_comm is not None and not pynccl_comm.disabled: - pynccl_comm.all_gather(output, input) + if pynccl_comm is not None and ( + not pynccl_comm.disabled + or ( + self.is_symmetric_memory_tensor(output) + and self.is_symmetric_memory_tensor(input) + ) + ): + with pynccl_comm.change_state( + enable=True, stream=torch.cuda.current_stream() + ): + pynccl_comm.all_gather(output, input) else: torch.distributed.all_gather_into_tensor( output, input, group=self.device_group @@ -691,9 +716,10 @@ def all_gather( # torch.compile . see https://github.com/pytorch/pytorch/issues/138795 output_size = (input_size[0] * world_size,) + input_size[1:] # Allocate output tensor. - output_tensor = torch.empty( - output_size, dtype=input_.dtype, device=input_.device - ) + with self.use_symmetric_memory(self): + output_tensor = torch.empty( + output_size, dtype=input_.dtype, device=input_.device + ) # All-gather. if input_.is_cpu and is_shm_available( diff --git a/python/sglang/srt/layers/communicator.py b/python/sglang/srt/layers/communicator.py index 4e422a3601a2..64a4df14f610 100644 --- a/python/sglang/srt/layers/communicator.py +++ b/python/sglang/srt/layers/communicator.py @@ -21,8 +21,12 @@ from sglang.srt.distributed import ( get_tensor_model_parallel_world_size, + get_tp_group, tensor_model_parallel_all_reduce, ) +from sglang.srt.distributed.device_communicators.pynccl_allocator import ( + use_symmetric_memory, +) from sglang.srt.layers.dp_attention import ( attn_tp_all_gather_into_tensor, attn_tp_reduce_scatter_tensor, @@ -469,7 +473,12 @@ def _gather_hidden_states_and_residual( use_layer_norm_before_gather = context.attn_tp_size == 1 if use_layer_norm_before_gather and hidden_states.shape[0] != 0: residual = hidden_states - hidden_states = layernorm(hidden_states) + with use_symmetric_memory( + get_tp_group(), + disabled=not forward_batch.dp_padding_mode.is_max_len(), + ): + hidden_states = layernorm(hidden_states) + hidden_states, local_hidden_states = ( get_global_dp_buffer(), hidden_states, diff --git a/python/sglang/srt/layers/dp_attention.py b/python/sglang/srt/layers/dp_attention.py index 1250636eb900..2863e0097d54 100644 --- a/python/sglang/srt/layers/dp_attention.py +++ b/python/sglang/srt/layers/dp_attention.py @@ -17,6 +17,9 @@ get_tp_group, tensor_model_parallel_all_reduce, ) +from sglang.srt.distributed.device_communicators.pynccl_allocator import ( + use_symmetric_memory, +) if TYPE_CHECKING: from sglang.srt.configs.model_config import ModelConfig @@ -72,6 +75,7 @@ class _DpGatheredBufferWrapper: _device: torch.device _global_dp_buffer_len: int _local_dp_buffer_len: int + _dp_max_padding: bool = True _global_num_tokens: Optional[List[int]] @classmethod @@ -85,27 +89,33 @@ def set_dp_buffer_len( cls, global_dp_buffer_len: int, local_dp_buffer_len: int, + dp_max_padding: bool, global_num_tokens: Optional[List[int]] = None, ): cls._global_dp_buffer_len = global_dp_buffer_len cls._local_dp_buffer_len = local_dp_buffer_len + cls._dp_max_padding = dp_max_padding cls._global_num_tokens = global_num_tokens @classmethod def get_global_dp_buffer(cls) -> torch.Tensor: - return torch.empty( - (cls._global_dp_buffer_len, cls._hidden_size), - dtype=cls._dtype, - device=cls._device, - ) + with use_symmetric_memory(get_tp_group()): + buffer = torch.empty( + (cls._global_dp_buffer_len, cls._hidden_size), + dtype=cls._dtype, + device=cls._device, + ) + return buffer @classmethod def get_local_dp_buffer(cls) -> torch.Tensor: - return torch.empty( - (cls._local_dp_buffer_len, cls._hidden_size), - dtype=cls._dtype, - device=cls._device, - ) + with use_symmetric_memory(get_tp_group(), disabled=not cls._dp_max_padding): + buffer = torch.empty( + (cls._local_dp_buffer_len, cls._hidden_size), + dtype=cls._dtype, + device=cls._device, + ) + return buffer @classmethod def get_global_dp_buffer_len(cls) -> int: @@ -119,14 +129,19 @@ def get_local_dp_buffer_len(cls) -> int: def get_dp_global_num_tokens(cls) -> List[int]: return cls._global_num_tokens + @classmethod + def is_dp_max_padding(cls) -> bool: + return cls._dp_max_padding + def set_dp_buffer_len( global_dp_buffer_len: int, local_dp_buffer_len: int, + dp_max_padding: bool, global_num_tokens: Optional[List[int]] = None, ): _DpGatheredBufferWrapper.set_dp_buffer_len( - global_dp_buffer_len, local_dp_buffer_len, global_num_tokens + global_dp_buffer_len, local_dp_buffer_len, dp_max_padding, global_num_tokens ) @@ -150,6 +165,10 @@ def get_dp_global_num_tokens() -> List[int]: return _DpGatheredBufferWrapper.get_dp_global_num_tokens() +def is_dp_max_padding() -> bool: + return _DpGatheredBufferWrapper.is_dp_max_padding() + + def compute_dp_attention_world_info(enable_dp_attention, tp_rank, tp_size, dp_size): if not enable_dp_attention: return tp_rank, tp_size, 0 @@ -408,7 +427,10 @@ def _dp_gather_via_all_gather( scattered_local_tokens = local_tokens.tensor_split(get_attention_tp_size())[ get_attention_tp_rank() ] - get_attention_tp_group().reduce_scatter_tensor(scattered_local_tokens, local_tokens) + if get_attention_tp_size() > 1: + get_attention_tp_group().reduce_scatter_tensor( + scattered_local_tokens, local_tokens + ) get_tp_group().all_gather_into_tensor(global_tokens, scattered_local_tokens) @@ -467,7 +489,7 @@ def dp_scatter( def dp_reduce_scatter_tensor(output: torch.Tensor, input: torch.Tensor): - if get_tensor_model_parallel_world_size() == get_attention_dp_size(): + if get_attention_tp_size() == 1: get_tp_group().reduce_scatter_tensor(output, input) else: scattered_local_tokens = input.tensor_split( diff --git a/python/sglang/srt/layers/linear.py b/python/sglang/srt/layers/linear.py index 47dfc7324fc0..71fa26ffd9dc 100644 --- a/python/sglang/srt/layers/linear.py +++ b/python/sglang/srt/layers/linear.py @@ -13,7 +13,7 @@ divide, get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size, - parallel_state, + get_tp_group, split_tensor_along_last_dim, tensor_model_parallel_all_gather, tensor_model_parallel_all_reduce, @@ -21,6 +21,7 @@ from sglang.srt.distributed.device_communicators.pynccl_allocator import ( use_symmetric_memory, ) +from sglang.srt.layers.dp_attention import is_dp_max_padding from sglang.srt.layers.parameter import ( BasevLLMParameter, BlockQuantScaleParameter, @@ -1316,9 +1317,8 @@ def forward(self, input_, skip_all_reduce=False): # Only fuse bias add into GEMM for rank 0 (this ensures that # bias will not get added more than once in TP>1 case) bias_ = None if (self.tp_rank > 0 or self.skip_bias_add) else self.bias - with use_symmetric_memory(parallel_state.get_tp_group()) as sm: + with use_symmetric_memory(get_tp_group(), disabled=not is_dp_max_padding()): output_parallel = self.quant_method.apply(self, input_parallel, bias=bias_) - sm.tag(output_parallel) if self.reduce_results and self.tp_size > 1 and not skip_all_reduce: output = tensor_model_parallel_all_reduce(output_parallel) diff --git a/python/sglang/srt/layers/logits_processor.py b/python/sglang/srt/layers/logits_processor.py index a4fb29929de7..4f255c8f4163 100644 --- a/python/sglang/srt/layers/logits_processor.py +++ b/python/sglang/srt/layers/logits_processor.py @@ -194,6 +194,7 @@ def compute_dp_attention_metadata(self): set_dp_buffer_len( self.global_dp_buffer_len, self.dp_local_num_tokens, + False, self.global_num_tokens_for_logprob_cpu, ) diff --git a/python/sglang/srt/layers/moe/cutlass_moe.py b/python/sglang/srt/layers/moe/cutlass_moe.py index d0fb4e3ef48b..f26bf5dbb61a 100755 --- a/python/sglang/srt/layers/moe/cutlass_moe.py +++ b/python/sglang/srt/layers/moe/cutlass_moe.py @@ -40,6 +40,7 @@ def cutlass_fused_experts_fp8( problem_sizes1: torch.Tensor, problem_sizes2: torch.Tensor, use_fp8_blockscale: bool = True, + output: Optional[torch.Tensor] = None, ) -> torch.Tensor: """Performs Fused MoE computation using CUTLASS-like kernels with FP8 weights and activations. @@ -94,7 +95,7 @@ def cutlass_fused_experts_fp8( b_scales_ptrs (torch.Tensor): Pointers container for calculating offsets of the input scales for each expert. use_fp8_blockscale (bool, optional): Flag indicating usage of FP8 with block scaling. Currently, only `True` is supported. Defaults to `True`. - + output (torch.Tensor, optional): Output tensor. If not provided, a new tensor will be created. Returns: torch.Tensor: The computed MoE layer output. Shape: `(m, k)`, dtype matches `a`. @@ -202,9 +203,11 @@ def cutlass_fused_experts_fp8( workspace, ) - result = torch.empty((m, k), device=device, dtype=out_dtype) - apply_shuffle_mul_sum(c2, result, c_map, topk_weights.to(out_dtype)) - return result + if output is None: + output = torch.empty((m, k), device=device, dtype=out_dtype) + + apply_shuffle_mul_sum(c2, output, c_map, topk_weights.to(out_dtype)) + return output FLOAT4_E2M1_MAX = 6.0 diff --git a/python/sglang/srt/layers/moe/fused_moe_triton/layer.py b/python/sglang/srt/layers/moe/fused_moe_triton/layer.py index b88c60d969b2..cacdfa5595a2 100644 --- a/python/sglang/srt/layers/moe/fused_moe_triton/layer.py +++ b/python/sglang/srt/layers/moe/fused_moe_triton/layer.py @@ -18,6 +18,7 @@ use_symmetric_memory, ) from sglang.srt.eplb.expert_location import get_global_expert_location_metadata +from sglang.srt.layers.dp_attention import is_dp_max_padding from sglang.srt.layers.moe import ( MoeRunnerConfig, get_moe_runner_backend, @@ -812,15 +813,12 @@ def forward(self, hidden_states: torch.Tensor, topk_output: TopKOutput): raise NotImplementedError() # Matrix multiply. - with use_symmetric_memory(get_tp_group()) as sm: - - final_hidden_states = self.quant_method.apply( - layer=self, - x=hidden_states, - topk_output=topk_output, - moe_runner_config=self.moe_runner_config, - ) - sm.tag(final_hidden_states) + final_hidden_states = self.quant_method.apply( + layer=self, + x=hidden_states, + topk_output=topk_output, + moe_runner_config=self.moe_runner_config, + ) final_hidden_states = final_hidden_states[ ..., :origin_hidden_states_dim @@ -1018,6 +1016,8 @@ def forward(self, hidden_states: torch.Tensor, topk_output: TopKOutput): router_logits = router_logits.to(torch.float32) + with use_symmetric_memory(get_tp_group(), disabled=not is_dp_max_padding()): + symm_output = torch.empty_like(hidden_states) result = trtllm_fp4_block_scale_moe( routing_logits=router_logits, routing_bias=topk_config.correction_bias.to(hidden_states.dtype), @@ -1052,6 +1052,7 @@ def forward(self, hidden_states: torch.Tensor, topk_output: TopKOutput): ), routing_method_type=RoutingMethodType.DeepSeekV3, do_finalize=True, + output=symm_output, )[0] return result diff --git a/python/sglang/srt/layers/quantization/fp8.py b/python/sglang/srt/layers/quantization/fp8.py index 4915d4d084e1..4e46a31817d4 100644 --- a/python/sglang/srt/layers/quantization/fp8.py +++ b/python/sglang/srt/layers/quantization/fp8.py @@ -10,6 +10,12 @@ from torch.nn import Module from torch.nn.parameter import Parameter +from sglang.srt.distributed import get_tp_group +from sglang.srt.distributed.device_communicators.pynccl_allocator import ( + use_symmetric_memory, +) +from sglang.srt.layers.dp_attention import is_dp_max_padding + try: from vllm.model_executor.layers.quantization.utils.marlin_utils_fp8 import ( apply_fp8_marlin_linear, @@ -1033,6 +1039,8 @@ def apply( from sglang.srt.layers.moe.cutlass_moe import cutlass_fused_experts_fp8 topk_weights, topk_ids, _ = topk_output + with use_symmetric_memory(get_tp_group(), disabled=not is_dp_max_padding()): + symm_output = torch.empty_like(x) output = cutlass_fused_experts_fp8( x, layer.w13_weight.transpose(1, 2), @@ -1055,6 +1063,7 @@ def apply( self.problem_sizes1, self.problem_sizes2, use_fp8_blockscale=True, + output=symm_output, ) # Scale by routed_scaling_factor is fused into select_experts. return output @@ -1112,31 +1121,32 @@ def apply_with_router_logits( correction_bias = topk_config.correction_bias.to(x.dtype) else: correction_bias = None - return trtllm_fp8_block_scale_moe( - routing_logits=router_logits.to(torch.float32), - routing_bias=correction_bias, - hidden_states=a_q, - hidden_states_scale=a_sf_t, - gemm1_weights=layer.w13_weight, - gemm1_weights_scale=layer.w13_weight_scale_inv, - gemm2_weights=layer.w2_weight, - gemm2_weights_scale=layer.w2_weight_scale_inv, - num_experts=layer.num_experts, - top_k=topk_config.top_k, - n_group=topk_config.num_expert_group, - topk_group=topk_config.topk_group, - intermediate_size=layer.w2_weight.shape[2], - local_expert_offset=layer.moe_ep_rank * layer.num_local_experts, - local_num_experts=layer.num_local_experts, - routed_scaling_factor=( - routed_scaling_factor if routed_scaling_factor is not None else 1.0 - ), - tile_tokens_dim=get_tile_tokens_dim( - x.shape[0], topk_config.top_k, layer.num_experts - ), - routing_method_type=2, # DeepSeek-styled routing method - use_shuffled_weight=False, - ) + with use_symmetric_memory(get_tp_group(), disabled=not is_dp_max_padding()): + return trtllm_fp8_block_scale_moe( + routing_logits=router_logits.to(torch.float32), + routing_bias=correction_bias, + hidden_states=a_q, + hidden_states_scale=a_sf_t, + gemm1_weights=layer.w13_weight, + gemm1_weights_scale=layer.w13_weight_scale_inv, + gemm2_weights=layer.w2_weight, + gemm2_weights_scale=layer.w2_weight_scale_inv, + num_experts=layer.num_experts, + top_k=topk_config.top_k, + n_group=topk_config.num_expert_group, + topk_group=topk_config.topk_group, + intermediate_size=layer.w2_weight.shape[2], + local_expert_offset=layer.moe_ep_rank * layer.num_local_experts, + local_num_experts=layer.num_local_experts, + routed_scaling_factor=( + routed_scaling_factor if routed_scaling_factor is not None else 1.0 + ), + tile_tokens_dim=get_tile_tokens_dim( + x.shape[0], topk_config.top_k, layer.num_experts + ), + routing_method_type=2, # DeepSeek-styled routing method + use_shuffled_weight=False, + ) def maybe_apply_hip_fused_experts( self, diff --git a/python/sglang/srt/layers/quantization/modelopt_quant.py b/python/sglang/srt/layers/quantization/modelopt_quant.py index bd43672341f6..3dfdb895e1a7 100755 --- a/python/sglang/srt/layers/quantization/modelopt_quant.py +++ b/python/sglang/srt/layers/quantization/modelopt_quant.py @@ -8,7 +8,14 @@ from torch.nn.parameter import Parameter from sglang.srt.distributed import get_tp_group -from sglang.srt.layers.dp_attention import get_dp_global_num_tokens, get_local_dp_buffer +from sglang.srt.distributed.device_communicators.pynccl_allocator import ( + use_symmetric_memory, +) +from sglang.srt.layers.dp_attention import ( + get_dp_global_num_tokens, + get_local_dp_buffer, + is_dp_max_padding, +) from sglang.srt.layers.moe import ( should_use_flashinfer_cutlass_moe_fp4_allgather, should_use_flashinfer_trtllm_moe, @@ -1303,7 +1310,9 @@ def apply( topk_weights, topk_ids = topk_output.topk_weights, topk_output.topk_ids output_dtype = x.dtype + output_col = x.shape[1] x_sf = None + if should_use_flashinfer_cutlass_moe_fp4_allgather(): from flashinfer import fp4_quantize, nvfp4_block_scale_interleave @@ -1323,7 +1332,13 @@ def apply( ) x_sf = nvfp4_block_scale_interleave(x_sf) + with use_symmetric_memory(get_tp_group(), disabled=not is_dp_max_padding()): + symm_output = torch.empty( + x.shape[0], output_col, dtype=output_dtype, device=x.device + ) + output = flashinfer_cutlass_fused_moe( + output=symm_output, input=x, token_selected_experts=topk_ids.to(torch.int), token_final_scales=topk_weights, diff --git a/python/sglang/srt/layers/quantization/mxfp4.py b/python/sglang/srt/layers/quantization/mxfp4.py index c353cbba32a4..992a16dd6da4 100644 --- a/python/sglang/srt/layers/quantization/mxfp4.py +++ b/python/sglang/srt/layers/quantization/mxfp4.py @@ -22,6 +22,11 @@ import torch from torch.nn.parameter import Parameter +from sglang.srt.distributed.device_communicators.pynccl_allocator import ( + use_symmetric_memory, +) +from sglang.srt.distributed.parallel_state import get_tp_group +from sglang.srt.layers.dp_attention import is_dp_max_padding from sglang.srt.layers.moe.utils import get_moe_runner_backend from sglang.srt.layers.quantization.base_config import ( FusedMoEMethodBase, @@ -653,6 +658,8 @@ def apply( top_k = topk_output.topk_config.top_k router_logits = topk_output.router_logits + with use_symmetric_memory(get_tp_group(), disabled=not is_dp_max_padding()): + symm_output = torch.empty_like(x) trtllm_gen_output = trtllm_fp4_block_scale_moe( router_logits.to(torch.bfloat16), None, # routing_bias @@ -681,6 +688,7 @@ def apply( self._get_tile_tokens_dim(x, top_k), 1, # routing_method_type, renormalize True, # do finalize + output=symm_output, )[0] return trtllm_gen_output diff --git a/python/sglang/srt/layers/vocab_parallel_embedding.py b/python/sglang/srt/layers/vocab_parallel_embedding.py index 66abb75410bc..7b9f9c3d2cda 100644 --- a/python/sglang/srt/layers/vocab_parallel_embedding.py +++ b/python/sglang/srt/layers/vocab_parallel_embedding.py @@ -11,7 +11,7 @@ divide, get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size, - parallel_state, + get_tp_group, tensor_model_parallel_all_reduce, ) from sglang.srt.distributed.device_communicators.pynccl_allocator import ( @@ -473,9 +473,8 @@ def forward(self, input_): else: masked_input = input_ # Get the embeddings. - with use_symmetric_memory(parallel_state.get_tp_group()) as sm: + with use_symmetric_memory(get_tp_group(), disabled=not self.enable_tp): output_parallel = self.quant_method.embedding(self, masked_input.long()) - sm.tag(output_parallel) # Mask the output embedding. if self.tp_size > 1: output_parallel.masked_fill_(input_mask.unsqueeze(-1), 0) diff --git a/python/sglang/srt/model_executor/cuda_graph_runner.py b/python/sglang/srt/model_executor/cuda_graph_runner.py index 2effec9c02a0..453385eb6860 100644 --- a/python/sglang/srt/model_executor/cuda_graph_runner.py +++ b/python/sglang/srt/model_executor/cuda_graph_runner.py @@ -638,7 +638,11 @@ def capture_one_batch_size(self, bs: int, forward: Callable): def run_once(): # Clean intermediate result cache for DP attention forward_batch.dp_local_start_pos = forward_batch.dp_local_num_tokens = None - set_dp_buffer_len(global_dp_buffer_len, num_tokens) + set_dp_buffer_len( + global_dp_buffer_len, + num_tokens, + forward_batch.dp_padding_mode.is_max_len(), + ) kwargs = {} if ( diff --git a/python/sglang/srt/model_executor/forward_batch_info.py b/python/sglang/srt/model_executor/forward_batch_info.py index 65c0a07f8ab1..8fe1a7f54ff4 100644 --- a/python/sglang/srt/model_executor/forward_batch_info.py +++ b/python/sglang/srt/model_executor/forward_batch_info.py @@ -652,7 +652,9 @@ def prepare_mlp_sync_batch(self, model_runner: ModelRunner): num_tokens = global_num_tokens[0] self.global_dp_buffer_len = buffer_len - set_dp_buffer_len(buffer_len, num_tokens, global_num_tokens) + set_dp_buffer_len( + buffer_len, num_tokens, dp_padding_mode.is_max_len(), global_num_tokens + ) bs = self.batch_size diff --git a/python/sglang/srt/models/deepseek_v2.py b/python/sglang/srt/models/deepseek_v2.py index bceb60cfefb5..2b949ba0a66d 100644 --- a/python/sglang/srt/models/deepseek_v2.py +++ b/python/sglang/srt/models/deepseek_v2.py @@ -35,9 +35,6 @@ parallel_state, tensor_model_parallel_all_reduce, ) -from sglang.srt.distributed.device_communicators.pynccl_allocator import ( - use_symmetric_memory, -) from sglang.srt.eplb.expert_distribution import get_global_expert_distribution_recorder from sglang.srt.eplb.expert_location import ModelConfigForExpertLocation from sglang.srt.eplb.expert_location_dispatch import ExpertLocationDispatchInfo @@ -231,7 +228,8 @@ def forward( gate_up, _ = self.gate_up_proj(x) x = self.act_fn(gate_up) x, _ = self.down_proj( - x, skip_all_reduce=should_allreduce_fusion or use_reduce_scatter + x, + skip_all_reduce=should_allreduce_fusion or use_reduce_scatter, ) return x @@ -482,12 +480,8 @@ def forward_normal_dual_stream( final_hidden_states *= self.routed_scaling_factor current_stream.wait_stream(self.alt_stream) - with use_symmetric_memory(parallel_state.get_tp_group()) as sm: - final_hidden_states_out = torch.empty_like(final_hidden_states) + final_hidden_states += shared_output - torch.add(final_hidden_states, shared_output, out=final_hidden_states_out) - final_hidden_states = final_hidden_states_out - sm.tag(final_hidden_states) if ( self.tp_size > 1 and not should_allreduce_fusion @@ -522,11 +516,7 @@ def forward_normal( # fused in biased_grouped_topk so we can skip here final_hidden_states *= self.routed_scaling_factor if shared_output is not None: - with use_symmetric_memory(parallel_state.get_tp_group()) as sm: - final_hidden_states_out = torch.empty_like(final_hidden_states) - torch.add(final_hidden_states, shared_output, out=final_hidden_states_out) - final_hidden_states = final_hidden_states_out - sm.tag(final_hidden_states) + final_hidden_states += shared_output if ( self.tp_size > 1 and not should_allreduce_fusion @@ -595,7 +585,9 @@ def forward_cpu( return final_hidden_states def forward_deepep( - self, hidden_states: torch.Tensor, forward_batch: ForwardBatch + self, + hidden_states: torch.Tensor, + forward_batch: ForwardBatch, ) -> torch.Tensor: shared_output = None if hidden_states.shape[0] > 0: @@ -1301,7 +1293,13 @@ def forward_absorb_prepare( return q_pe, k_pe, q_nope_out, k_nope, forward_batch, zero_allocator def forward_absorb_core( - self, q_pe, k_pe, q_nope_out, k_nope, forward_batch, zero_allocator + self, + q_pe, + k_pe, + q_nope_out, + k_nope, + forward_batch, + zero_allocator, ): if ( self.current_attention_backend == "fa3" @@ -1892,7 +1890,10 @@ def forward( forward_batch ) hidden_states = self.mlp( - hidden_states, forward_batch, should_allreduce_fusion, use_reduce_scatter + hidden_states, + forward_batch, + should_allreduce_fusion, + use_reduce_scatter, ) if should_allreduce_fusion: diff --git a/python/sglang/srt/operations.py b/python/sglang/srt/operations.py index f8730cd77232..9d824587c4da 100644 --- a/python/sglang/srt/operations.py +++ b/python/sglang/srt/operations.py @@ -85,6 +85,7 @@ def __init__(self, debug_name: str, stages: List[Stage], inputs: dict): self._global_dp_buffer_len = forward_batch.global_dp_buffer_len self._local_dp_buffer_len = forward_batch.input_ids.shape[0] self._global_num_tokens = forward_batch.global_num_tokens_cpu + self._is_dp_max_padding = forward_batch.dp_padding_mode.is_max_len() def next(self): assert not self.done @@ -95,6 +96,7 @@ def next(self): set_dp_buffer_len( self._global_dp_buffer_len, self._local_dp_buffer_len, + self._is_dp_max_padding, self._global_num_tokens, ) diff --git a/python/sglang/srt/speculative/eagle_draft_cuda_graph_runner.py b/python/sglang/srt/speculative/eagle_draft_cuda_graph_runner.py index 3ee3b1c54967..81097c4b7063 100644 --- a/python/sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +++ b/python/sglang/srt/speculative/eagle_draft_cuda_graph_runner.py @@ -244,7 +244,11 @@ def capture_one_batch_size(self, num_seqs: int, forward: Callable): def run_once(): # Clean intermediate result cache for DP attention forward_batch.dp_local_start_pos = forward_batch.dp_local_num_tokens = None - set_dp_buffer_len(global_dp_buffer_len, num_tokens) + set_dp_buffer_len( + global_dp_buffer_len, + num_tokens, + forward_batch.dp_padding_mode.is_max_len(), + ) # Backup two fields, which will be modified in-place in `draft_forward`. output_cache_loc_backup = forward_batch.out_cache_loc diff --git a/python/sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py b/python/sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py index 4f4403fee50c..f1f4533b8d10 100644 --- a/python/sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +++ b/python/sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py @@ -273,7 +273,11 @@ def capture_one_batch_size(self, bs: int, forward: Callable): def run_once(): # Clean intermediate result cache for DP attention forward_batch.dp_local_start_pos = forward_batch.dp_local_num_tokens = None - set_dp_buffer_len(global_dp_buffer_len, num_tokens) + set_dp_buffer_len( + global_dp_buffer_len, + num_tokens, + forward_batch.dp_padding_mode.is_max_len(), + ) # Backup two fields, which will be modified in-place in `draft_forward`. output_cache_loc_backup = forward_batch.out_cache_loc