10 changes: 9 additions & 1 deletion vllm/v1/core/sched/interface.py
@@ -6,7 +6,7 @@

if TYPE_CHECKING:
from vllm.distributed.kv_transfer.kv_connector.v1 import KVConnectorBase_V1
from vllm.v1.core.sched.output import SchedulerOutput
from vllm.v1.core.sched.output import GrammarBitmask, SchedulerOutput
from vllm.v1.engine import EngineCoreOutputs
from vllm.v1.metrics.stats import SchedulerStats
from vllm.v1.outputs import DraftTokenIds, ModelRunnerOutput
@@ -41,6 +41,14 @@ def schedule(self) -> "SchedulerOutput":
"""
raise NotImplementedError

@abstractmethod
def get_grammar_bitmask(
self,
scheduler_output: "SchedulerOutput",
) -> Optional["GrammarBitmask"]:
"""Get the grammar bitmask for the scheduled requests."""
raise NotImplementedError

@abstractmethod
def update_from_output(
self,
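For illustration only (not part of this diff), a minimal sketch of how a custom SchedulerInterface implementation that never uses structured output could satisfy the new abstract method. The class name is hypothetical and only the new method is shown; the real implementation is in scheduler.py below.

from typing import Optional

from vllm.v1.core.sched.interface import SchedulerInterface
from vllm.v1.core.sched.output import GrammarBitmask, SchedulerOutput


class NoStructuredOutputScheduler(SchedulerInterface):  # hypothetical name
    def get_grammar_bitmask(
        self,
        scheduler_output: SchedulerOutput,
    ) -> Optional[GrammarBitmask]:
        # No request uses structured output, so there is nothing to mask.
        return None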
10 changes: 7 additions & 3 deletions vllm/v1/core/sched/output.py
@@ -147,11 +147,15 @@ class SchedulerOutput:
# Used to free the encoder cache.
free_encoder_input_ids: list[tuple[str, int]]

# KV Cache Connector metadata.
kv_connector_metadata: Optional[KVConnectorMetadata] = None


@dataclass
class GrammarBitmask:

# Dict of request ids to their index within the batch
# for filling the next token bitmask
structured_output_request_ids: dict[str, int]
# the bitmask for the whole batch
grammar_bitmask: Optional[npt.NDArray[np.int32]]

# KV Cache Connector metadata.
kv_connector_metadata: Optional[KVConnectorMetadata] = None
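For illustration only (not part of this diff), the following sketch shows what the new GrammarBitmask payload might look like for a batch of three scheduled requests, two of which use structured output. The vocabulary size and the 32-tokens-per-int32 packing convention are assumptions for this example.

import numpy as np

from vllm.v1.core.sched.output import GrammarBitmask

# Hypothetical batch: "req-a" (batch index 0) and "req-c" (batch index 2)
# use structured output; the request at index 1 does not.
vocab_size = 32000                       # assumed vocab size for illustration
words_per_row = (vocab_size + 31) // 32  # 32 vocab entries packed per int32

bitmask = GrammarBitmask(
    structured_output_request_ids={"req-a": 0, "req-c": 2},
    # One packed row per structured-output request; all bits set means
    # "every token allowed" and is just a placeholder here.
    grammar_bitmask=np.full((2, words_per_row), -1, dtype=np.int32),
)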
35 changes: 15 additions & 20 deletions vllm/v1/core/sched/scheduler.py
@@ -21,8 +21,8 @@
compute_encoder_budget)
from vllm.v1.core.kv_cache_manager import KVCacheBlocks, KVCacheManager
from vllm.v1.core.sched.interface import SchedulerInterface
from vllm.v1.core.sched.output import (CachedRequestData, NewRequestData,
SchedulerOutput)
from vllm.v1.core.sched.output import (CachedRequestData, GrammarBitmask,
NewRequestData, SchedulerOutput)
from vllm.v1.core.sched.request_queue import (SchedulingPolicy,
create_request_queue)
from vllm.v1.core.sched.utils import check_stop, remove_all
@@ -534,9 +534,6 @@ def schedule(self) -> SchedulerOutput:
scheduled_spec_decode_tokens,
req_to_new_blocks,
)
structured_output_request_ids, grammar_bitmask = (
self.get_grammar_bitmask(self.running,
scheduled_spec_decode_tokens))
scheduler_output = SchedulerOutput(
scheduled_new_reqs=new_reqs_data,
scheduled_cached_reqs=cached_reqs_data,
@@ -551,8 +548,6 @@
# the previous and the current steps.
finished_req_ids=self.finished_req_ids,
free_encoder_input_ids=self.encoder_cache_manager.get_freed_ids(),
structured_output_request_ids=structured_output_request_ids,
grammar_bitmask=grammar_bitmask,
)

# NOTE(Kuntai): this function is designed for multiple purposes:
@@ -736,33 +731,33 @@ def _try_schedule_encoder_inputs(

def get_grammar_bitmask(
self,
requests: list[Request],
scheduled_spec_decode_tokens: dict[str, list[int]],
):
scheduler_output: SchedulerOutput,
) -> Optional[GrammarBitmask]:
# NOTE: structured_output_request_ids maps the request_id of each
# request that uses structured output to its index in the batch.
# This helps us determine how to slice the grammar bitmask and
# apply the mask only to requests that use structured decoding.
structured_output_request_ids: dict[str, int] = {}
for i, req in enumerate(requests):
if req.use_structured_output:
req_ids = scheduler_output.num_scheduled_tokens.keys()
for i, req_id in enumerate(req_ids):
req = self.requests.get(req_id)
if req is not None and req.use_structured_output:
# PERF: with chunked prefill, a request might not include any
# new tokens, so we may spend extra cycles filling in the
# bitmask, which could be a big no-op.
structured_output_request_ids[req.request_id] = i

if not structured_output_request_ids:
bitmask = None
else:
bitmask = self.structured_output_manager.grammar_bitmask(
self.requests,
structured_output_request_ids,
scheduled_spec_decode_tokens,
)
return structured_output_request_ids, bitmask
return None
bitmask = self.structured_output_manager.grammar_bitmask(
self.requests,
structured_output_request_ids,
scheduler_output.scheduled_spec_decode_tokens,
)
return GrammarBitmask(structured_output_request_ids, bitmask)

def update_from_output(
self,
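For illustration only, the sketch below shows how a consumer might use structured_output_request_ids to scatter the per-request bitmask rows onto the right batch rows when masking logits. This is not code from this PR; the row ordering and bit-packing convention are assumptions.

import numpy as np

def apply_grammar_bitmask(logits: np.ndarray,
                          structured_output_request_ids: dict[str, int],
                          grammar_bitmask: np.ndarray) -> None:
    # logits: (batch_size, vocab_size) float array, modified in place.
    # grammar_bitmask: one packed int32 row per structured-output request,
    # assumed to be in the same order as structured_output_request_ids.
    for row, batch_idx in enumerate(structured_output_request_ids.values()):
        packed = grammar_bitmask[row]
        # Unpack 32 vocab entries per int32 word into a boolean "allowed" mask.
        bits = ((packed[:, None] >> np.arange(32)) & 1).astype(bool).reshape(-1)
        allowed = bits[:logits.shape[1]]
        logits[batch_idx, ~allowed] = -np.inf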
49 changes: 42 additions & 7 deletions vllm/v1/engine/core.py
@@ -286,12 +286,12 @@
if not self.scheduler.has_requests():
return {}, False
scheduler_output = self.scheduler.schedule()
model_output = self.execute_model_with_error_logging(
self.model_executor.execute_model, # type: ignore
scheduler_output)
self.model_executor.prepare_inputs(scheduler_output)
self.model_executor.execute_model()
bitmask = self.scheduler.get_grammar_bitmask(scheduler_output)
model_output = self.model_executor.sample(bitmask, non_block=False)
engine_core_outputs = self.scheduler.update_from_output(
scheduler_output, model_output) # type: ignore

return (engine_core_outputs,
scheduler_output.total_num_scheduled_tokens > 0)

Expand Down Expand Up @@ -327,7 +327,10 @@
if not self.batch_queue.full():
scheduler_output = self.scheduler.schedule()
if scheduler_output.total_num_scheduled_tokens > 0:
future = self.model_executor.execute_model(scheduler_output)
self.model_executor.prepare_inputs(scheduler_output)
self.model_executor.execute_model()
bitmask = self.scheduler.get_grammar_bitmask(scheduler_output)
future = self.model_executor.sample(bitmask)
self.batch_queue.put_nowait(
(future, scheduler_output)) # type: ignore

@@ -353,6 +356,32 @@

return engine_core_outputs, scheduled_batch

def step_async(
self) -> tuple[Optional[dict[int, EngineCoreOutputs]], bool]:
model_output = None
engine_core_outputs = None
bitmask = None

scheduler_output = self.scheduler.schedule()
is_scheduled = scheduler_output.total_num_scheduled_tokens > 0
if is_scheduled:
self.model_executor.prepare_inputs(scheduler_output)
if self.inflight_batch:

[GitHub Actions / pre-commit] Check failure on line 369: Cannot determine type of "inflight_batch" [has-type]
model_output = self.model_executor.sample(self.prev_bitmask)

[GitHub Actions / pre-commit] Check failure on line 370: Cannot determine type of "prev_bitmask" [has-type]
self.model_executor.execute_model()
bitmask = self.scheduler.get_grammar_bitmask(scheduler_output)
elif self.inflight_batch:

[GitHub Actions / pre-commit] Check failure on line 373: Cannot determine type of "inflight_batch" [has-type]
model_output = self.model_executor.sample(self.prev_bitmask)

[GitHub Actions / pre-commit] Check failure on line 374: Cannot determine type of "prev_bitmask" [has-type]

if model_output is not None:
engine_core_outputs = self.scheduler.update_from_output(
self.prev_scheduler_output, model_output.result())

[GitHub Actions / pre-commit] Check failure on line 378: Item "ModelRunnerOutput" of "Union[ModelRunnerOutput, Future[ModelRunnerOutput]]" has no attribute "result" [union-attr]
[GitHub Actions / pre-commit] Check failure on line 378: Cannot determine type of "prev_scheduler_output" [has-type]

self.inflight_batch = is_scheduled
self.prev_scheduler_output = scheduler_output
self.prev_bitmask = bitmask
return engine_core_outputs, is_scheduled

def shutdown(self):
self.structured_output_manager.clear_backend()
if self.model_executor:
@@ -529,8 +558,14 @@
assert addresses.coordinator_input is not None
logger.info("Waiting for READY message from DP Coordinator...")

self.step_fn = (self.step if self.batch_queue is None else
self.step_with_batch_queue)
if self.batch_queue is None:
if self.vllm_config.scheduler_config.async_scheduling:
self.step_fn = self.step_async
self.inflight_batch = False
else:
self.step_fn = self.step
else:
self.step_fn = self.step_with_batch_queue

@contextmanager
def _perform_handshakes(
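For illustration only, this is roughly how the new step_async is intended to overlap work across iterations: the inputs for batch N are prepared and its forward pass is launched before the sampled output of batch N-1 is collected, so grammar-bitmask computation and other CPU-side scheduling overlap with GPU execution. The driver loop and the push_outputs helper below are hypothetical, not part of this PR.

# engine_core is assumed to be an EngineCore running with async scheduling
# enabled; push_outputs() is a hypothetical callback for emitting results.
while True:
    outputs, scheduled = engine_core.step_async()
    if outputs is not None:
        # These are the results of the *previous* in-flight batch.
        push_outputs(outputs)
    if not scheduled and outputs is None:
        # Nothing newly scheduled and nothing left in flight.
        break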
15 changes: 11 additions & 4 deletions vllm/v1/executor/abstract.py
@@ -80,12 +80,19 @@ def get_kv_cache_specs(self) -> list[dict[str, KVCacheSpec]]:
output = self.collective_rpc("get_kv_cache_spec")
return output

def execute_model(
def prepare_inputs(self, scheduler_output) -> None:
self.collective_rpc("prepare_inputs", args=(scheduler_output, ))

def execute_model(self) -> None:
self.collective_rpc("execute_model")

def sample(
self,
scheduler_output,
grammar_bitmask,
non_block: bool = True,
) -> Union[ModelRunnerOutput, Future[ModelRunnerOutput]]:
output = self.collective_rpc("execute_model",
args=(scheduler_output, ))
del non_block
[Contributor review comment on "del non_block"] what's this for?

output = self.collective_rpc("sample", args=(grammar_bitmask, ))
return output[0]

def take_draft_token_ids(self) -> Optional[DraftTokenIds]:
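For illustration only, this is the call sequence an engine is expected to follow with the split executor API, compared with the old single execute_model(scheduler_output) call. The executor, scheduler, and scheduler_output objects are assumed to already exist.

# Old flow (before this PR):
#     output = executor.execute_model(scheduler_output)
#
# New flow:
executor.prepare_inputs(scheduler_output)   # stage the batch on the workers
executor.execute_model()                    # launch the forward pass
bitmask = scheduler.get_grammar_bitmask(scheduler_output)   # overlaps with GPU work
output = executor.sample(bitmask, non_block=False)          # collect ModelRunnerOutput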
55 changes: 39 additions & 16 deletions vllm/v1/executor/multiproc_executor.py
@@ -115,12 +115,11 @@ def _init_executor(self) -> None:

# For pipeline parallel, we use a thread pool for asynchronous
# execute_model.
if self.max_concurrent_batches > 1:
# Note: must use only 1 IO thread to keep dequeue sequence
# from the response queue
# _async_aggregate_workers_output also assumes a single IO thread
self.io_thread_pool = ThreadPoolExecutor(
max_workers=1, thread_name_prefix="mp_exec_io")
# Note: must use only 1 IO thread to keep dequeue sequence
# from the response queue
# _async_aggregate_workers_output also assumes a single IO thread
self.io_thread_pool = ThreadPoolExecutor(
max_workers=1, thread_name_prefix="mp_exec_io")

self.output_rank = self._get_output_rank()
self.has_connector = self.vllm_config.kv_transfer_config is not None
@@ -162,26 +161,42 @@ def register_failure_callback(self, callback: FailureCallback):
else:
self.failure_callback = callback

def execute_model(
def prepare_inputs(self, scheduler_output) -> None:
self.collective_rpc(
"prepare_inputs",
args=(scheduler_output, ),
non_block=True,
skip_response=True,
timeout=envs.VLLM_EXECUTE_MODEL_TIMEOUT_SECONDS,
)

def execute_model(self) -> None:
self.collective_rpc(
"execute_model",
non_block=True,
skip_response=True,
timeout=envs.VLLM_EXECUTE_MODEL_TIMEOUT_SECONDS,
)

def sample(
self,
scheduler_output,
grammar_bitmask,
non_block: bool = True,
) -> Union[ModelRunnerOutput, Future[ModelRunnerOutput]]:
non_block = self.max_concurrent_batches > 1

if not self.has_connector:
# get output only from a single worker (output_rank)
(output, ) = self.collective_rpc(
"execute_model",
args=(scheduler_output, ),
"sample",
args=(grammar_bitmask, ),
unique_reply_rank=self.output_rank,
non_block=non_block,
timeout=envs.VLLM_EXECUTE_MODEL_TIMEOUT_SECONDS)
return output

# get output from all workers
outputs = self.collective_rpc(
"execute_model",
args=(scheduler_output, ),
"sample",
args=(grammar_bitmask, ),
non_block=non_block,
timeout=envs.VLLM_EXECUTE_MODEL_TIMEOUT_SECONDS)

@@ -203,6 +218,7 @@ def collective_rpc(self,
args: tuple = (),
kwargs: Optional[dict] = None,
non_block: bool = False,
skip_response: bool = False,
unique_reply_rank: Optional[int] = None) -> list[Any]:
if self.is_failed:
raise RuntimeError("Executor failed.")
@@ -219,6 +235,15 @@
else:
send_method = cloudpickle.dumps(
method, protocol=pickle.HIGHEST_PROTOCOL)

if skip_response:
if unique_reply_rank is not None:
raise ValueError(
"unique_reply_rank must be None "
f"when skip_response is True. got {unique_reply_rank}")
self.rpc_broadcast_mq.enqueue((send_method, args, kwargs, -1))
return []

self.rpc_broadcast_mq.enqueue(
(send_method, args, kwargs, unique_reply_rank))

@@ -309,8 +334,6 @@ def check_health(self) -> None:

@property
def max_concurrent_batches(self) -> int:
if self.scheduler_config.async_scheduling:
return 2
return self.parallel_config.pipeline_parallel_size

def _get_output_rank(self) -> int:
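For illustration only, the sketch below shows the fire-and-forget pattern that the new skip_response flag enables: prepare_inputs and execute_model are broadcast without allocating response futures, and a later blocking sample call collects the result from the output rank. The executor and argument objects are assumed to exist; this is not code from the PR.

# Broadcast without waiting for (or even creating) per-worker responses.
executor.collective_rpc(
    "prepare_inputs",
    args=(scheduler_output, ),
    non_block=True,
    skip_response=True,
)
executor.collective_rpc(
    "execute_model",
    non_block=True,
    skip_response=True,
)

# Later, a normal blocking RPC collects the sampled output from the
# designated output rank on the single IO thread.
(output, ) = executor.collective_rpc(
    "sample",
    args=(grammar_bitmask, ),
    unique_reply_rank=executor.output_rank,
)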
2 changes: 0 additions & 2 deletions vllm/v1/executor/ray_distributed_executor.py
@@ -58,11 +58,9 @@
"""Ray distributed executor supports pipeline parallelism,
meaning that it allows PP size batches to be executed concurrently.
"""
if self.scheduler_config.async_scheduling:
return 2
return self.parallel_config.pipeline_parallel_size

def execute_model(

[GitHub Actions / pre-commit] Check failure on line 63: Signature of "execute_model" incompatible with supertype "Executor" [override]
self,
scheduler_output,
) -> Union[ModelRunnerOutput, Future[ModelRunnerOutput]]: