Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions python/sglang/srt/managers/data_parallel_controller.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@

from sglang.srt.layers.dp_attention import compute_dp_attention_world_info
from sglang.srt.managers.io_struct import (
BlockReqInput,
TokenizedEmbeddingReqInput,
TokenizedGenerateReqInput,
)
Expand Down Expand Up @@ -282,6 +283,9 @@ def event_loop(self):
),
):
self.dispatching(recv_req)
elif isinstance(recv_req, BlockReqInput):
for worker in self.workers:
worker.send_pyobj(recv_req)
else:
# Send other control messages to first worker of tp group
for worker in self.workers[:: self.control_message_step]:
Expand Down
10 changes: 10 additions & 0 deletions python/sglang/srt/managers/io_struct.py
Original file line number Diff line number Diff line change
Expand Up @@ -1071,3 +1071,13 @@ class LoRAUpdateResult:


LoadLoRAAdapterReqOutput = UnloadLoRAAdapterReqOutput = LoRAUpdateResult


class BlockReqType(Enum):
    """Kind of scheduler input-blocking control message (see BlockReqInput)."""

    BLOCK = 1
    UNBLOCK = 2


@dataclass
class BlockReqInput:
    """Control message asking schedulers to block or unblock input handling."""

    # Whether this message requests blocking or unblocking.
    type: BlockReqType
10 changes: 10 additions & 0 deletions python/sglang/srt/managers/scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,6 +122,7 @@
PrefillAdder,
SchedulePolicy,
)
from sglang.srt.managers.scheduler_input_blocker import SchedulerInputBlocker
from sglang.srt.managers.scheduler_output_processor_mixin import (
SchedulerOutputProcessorMixin,
)
Expand Down Expand Up @@ -486,6 +487,12 @@ def __init__(
)
self.init_profier()

self.input_blocker = (
SchedulerInputBlocker(noop=self.attn_tp_rank != 0)
if get_bool_env_var("SGLANG_ENABLE_COLOCATED_BATCH_GEN")
else None
)

# Init metrics stats
self.init_metrics()
self.init_kv_events(server_args.kv_events_config)
Expand Down Expand Up @@ -992,6 +999,9 @@ def recv_requests(self) -> List[Req]:
else:
recv_reqs = None

if self.input_blocker is not None:
recv_reqs = self.input_blocker.handle(recv_reqs)

if self.server_args.enable_dp_attention:
if self.attn_tp_rank == 0:
work_reqs = [
Expand Down
106 changes: 106 additions & 0 deletions python/sglang/srt/managers/scheduler_input_blocker.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,106 @@
# Copyright 2023-2024 SGLang Team
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
import logging
from contextlib import contextmanager
from enum import Enum, auto
from typing import Any, List, Optional

from sglang.srt.managers.io_struct import BlockReqInput, BlockReqType
from sglang.srt.poll_based_barrier import PollBasedBarrier

logger = logging.getLogger(__name__)


class SchedulerInputBlocker:
    """Buffers scheduler input requests while a block is in effect.

    Between a BLOCK and the matching UNBLOCK control message, ordinary
    requests are held back; once every rank has observed the unblock via the
    poll-based barrier, the buffered requests are released in arrival order.
    """

    def __init__(self, noop: bool):
        # noop ranks never receive real requests and only participate in the
        # barrier collective.
        self._noop = noop
        self._state = _State.UNBLOCKED
        self._pending_reqs = []
        self._global_unblock_barrier = PollBasedBarrier(noop=noop)

    def handle(self, recv_reqs: Optional[List[Any]]):
        """Filter one batch of received requests; returns the releasable ones.

        Returns None on noop ranks (which pass recv_reqs=None).
        """
        assert (recv_reqs is None) == self._noop

        released = []
        if not self._noop:
            for req in recv_reqs:
                released.extend(self._handle_recv_req(req))

        # The barrier poll performs a collective, so it must run on every
        # rank each loop iteration regardless of local state.
        barrier_arrived = self._global_unblock_barrier.poll_global_arrived()
        if self._state == _State.GLOBAL_UNBLOCK_BARRIER and barrier_arrived:
            released.extend(self._handle_arrive_unblock_barrier())

        return None if self._noop else released

    def _handle_recv_req(self, recv_req):
        """Route one request; returns the list of requests to pass through."""
        if not isinstance(recv_req, BlockReqInput):
            # Ordinary request: forward when unblocked, buffer otherwise.
            if self._state == _State.UNBLOCKED:
                return [recv_req]
            self._pending_reqs.append(recv_req)
            return []

        if recv_req.type == BlockReqType.BLOCK:
            self._execute_block_req()
        elif recv_req.type == BlockReqType.UNBLOCK:
            self._execute_unblock_req()
        else:
            raise NotImplementedError(f"{recv_req=}")
        return []

    def _execute_block_req(self):
        logger.info("Handle block req")
        self._change_state(original=_State.UNBLOCKED, target=_State.BLOCKED)

    def _execute_unblock_req(self):
        # Unblocking is not immediate: wait for all ranks via the barrier.
        logger.info("Handle unblock req")
        self._change_state(
            original=_State.BLOCKED, target=_State.GLOBAL_UNBLOCK_BARRIER
        )
        self._global_unblock_barrier.local_arrive()

    def _handle_arrive_unblock_barrier(self):
        """All ranks reached the barrier: release everything buffered."""
        logger.info(f"Arrived at unblock barrier ({len(self._pending_reqs)=})")
        self._change_state(
            original=_State.GLOBAL_UNBLOCK_BARRIER, target=_State.UNBLOCKED
        )
        released, self._pending_reqs = self._pending_reqs, []
        return released

    def _change_state(self, original: "_State", target: "_State"):
        # Enforce the legal state-machine transitions.
        assert self._state == original, f"{self._state=} {original=} {target=}"
        self._state = target


class _State(Enum):
    """Lifecycle states of SchedulerInputBlocker."""

    UNBLOCKED = auto()  # Requests flow through normally.
    BLOCKED = auto()  # Requests are buffered until an UNBLOCK arrives.
    GLOBAL_UNBLOCK_BARRIER = auto()  # Waiting for all ranks to unblock.


@contextmanager
def input_blocker_guard_region(send_to_scheduler):
    """Block scheduler input for the duration of the `with` region.

    Sends a BLOCK control message on entry and always sends the matching
    UNBLOCK on exit — even when the guarded body raises — so the scheduler
    cannot be left blocked.
    """
    send_to_scheduler.send_pyobj(BlockReqInput(BlockReqType.BLOCK))
    try:
        yield
    finally:
        send_to_scheduler.send_pyobj(BlockReqInput(BlockReqType.UNBLOCK))
24 changes: 18 additions & 6 deletions python/sglang/srt/managers/tokenizer_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
import time
import uuid
from collections import deque
from contextlib import nullcontext
from datetime import datetime
from http import HTTPStatus
from typing import (
Expand Down Expand Up @@ -68,6 +69,7 @@
BatchMultimodalOut,
BatchStrOut,
BatchTokenIDOut,
BlockReqType,
CloseSessionReqInput,
ConfigureLoggingReq,
EmbeddingReqInput,
Expand Down Expand Up @@ -112,6 +114,7 @@
UpdateWeightsFromTensorReqOutput,
)
from sglang.srt.managers.multimodal_processor import get_mm_processor, import_processors
from sglang.srt.managers.scheduler_input_blocker import input_blocker_guard_region
from sglang.srt.metrics.collector import TokenizerMetricsCollector
from sglang.srt.sampling.sampling_params import SamplingParams
from sglang.srt.server_args import PortArgs, ServerArgs
Expand Down Expand Up @@ -796,12 +799,21 @@ async def _handle_batch_request(
rids.append(tmp_obj.rid)
else:
# Sequential tokenization and processing
for i in range(batch_size):
tmp_obj = obj[i]
tokenized_obj = await self._tokenize_one_request(tmp_obj)
state = self._send_one_request(tmp_obj, tokenized_obj, created_time)
generators.append(self._wait_one_response(tmp_obj, state, request))
rids.append(tmp_obj.rid)
with (
input_blocker_guard_region(send_to_scheduler=self.send_to_scheduler)
if get_bool_env_var("SGLANG_ENABLE_COLOCATED_BATCH_GEN")
else nullcontext()
):
for i in range(batch_size):
tmp_obj = obj[i]
tokenized_obj = await self._tokenize_one_request(tmp_obj)
state = self._send_one_request(
tmp_obj, tokenized_obj, created_time
)
generators.append(
self._wait_one_response(tmp_obj, state, request)
)
rids.append(tmp_obj.rid)
else:
# FIXME: When using batch and parallel_sample_num together, the perf is not optimal.
if batch_size > 128:
Expand Down
31 changes: 31 additions & 0 deletions python/sglang/srt/poll_based_barrier.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
import torch

from sglang.srt.distributed import get_world_group


class PollBasedBarrier:
    """A non-blocking barrier that is repeatedly polled instead of waited on.

    Each rank marks arrival with `local_arrive()`; `poll_global_arrived()`
    returns True exactly once, after every rank has arrived. Because the poll
    performs a `torch.distributed.all_reduce`, all ranks in the world group
    must call `poll_global_arrived()` at the same cadence, even ranks created
    with noop=True.
    """

    def __init__(self, noop: bool = False):
        # A noop barrier always reports itself as arrived in the collective,
        # so it never holds back the other ranks.
        self._noop = noop
        self._local_arrived = False

    def local_arrive(self):
        # Must not be called again until the barrier has released
        # (poll_global_arrived returned True and reset the flag).
        assert not self._local_arrived
        self._local_arrived = True

    def poll_global_arrived(self) -> bool:
        """Return True once when all ranks have arrived; also resets this rank.

        Note: the collective in _compute_global_arrived runs unconditionally
        on every call — do not short-circuit on `self._local_arrived`, or
        ranks would issue mismatched collectives and hang.
        """
        global_arrived = self._compute_global_arrived()
        output = self._local_arrived and global_arrived
        if output:
            # Barrier released: re-arm for the next use.
            self._local_arrived = False
        return output

    def _compute_global_arrived(self) -> bool:
        # noop ranks vote "arrived" so the MIN reduction only reflects the
        # real participants.
        local_arrived = self._noop or self._local_arrived
        global_arrived = torch.tensor(local_arrived)
        # Can optimize if bottleneck
        # MIN over boolean votes == logical AND across ranks; runs on the
        # CPU (gloo) group so it does not touch GPU streams.
        torch.distributed.all_reduce(
            global_arrived,
            torch.distributed.ReduceOp.MIN,
            group=get_world_group().cpu_group,
        )
        return global_arrived.item()
Loading