From 07fc806caeb53916c16205be7eac2bae614a54cf Mon Sep 17 00:00:00 2001
From: Nick Hill
Date: Sun, 31 Mar 2024 10:53:49 -0700
Subject: [PATCH 1/6] [Core] Eliminate parallel worker inter-token scheduling overhead

There's no need for the parallel workers to be scheduled each step.
---
 vllm/distributed/communication_op.py |   4 +-
 vllm/engine/async_llm_engine.py      |   6 +-
 vllm/engine/llm_engine.py            |   6 +-
 vllm/executor/executor_base.py       |  10 +++
 vllm/executor/ray_gpu_executor.py    | 101 ++++++++++++---------
 vllm/worker/worker.py                |  42 +++++++----
 6 files changed, 96 insertions(+), 73 deletions(-)

diff --git a/vllm/distributed/communication_op.py b/vllm/distributed/communication_op.py
index a3e93691a1e8..ff5e8611369d 100644
--- a/vllm/distributed/communication_op.py
+++ b/vllm/distributed/communication_op.py
@@ -147,14 +147,14 @@ def broadcast_tensor_dict(
 ) -> Optional[Dict[Any, Union[torch.Tensor, Any]]]:
     """Broadcast the input tensor dictionary."""
     group = group or torch.distributed.group.WORLD
-    ranks = torch.distributed.get_process_group_ranks(group)
-    assert src in ranks, f"Invalid src rank ({src})"
 
     # Bypass the function if we are using only 1 GPU.
     world_size = torch.distributed.get_world_size(group=group)
     if world_size == 1:
         return tensor_dict
 
+    ranks = torch.distributed.get_process_group_ranks(group)
+    assert src in ranks, f"Invalid src rank ({src})"
     rank = torch.distributed.get_rank()
     if rank == src:
         metadata_list: List[Tuple[Any, Any]] = []
diff --git a/vllm/engine/async_llm_engine.py b/vllm/engine/async_llm_engine.py
index f61049513512..3e361679db74 100644
--- a/vllm/engine/async_llm_engine.py
+++ b/vllm/engine/async_llm_engine.py
@@ -217,7 +217,11 @@ async def step_async(self) -> List[RequestOutput]:
         else:
             output = []
 
-        return self._process_model_outputs(output, scheduler_outputs)
+        outputs = self._process_model_outputs(output, scheduler_outputs)
+        if not outputs:
+            # Stop the execute model loop in parallel workers for now
+            await self.model_executor.halt_model_async()
+        return outputs
 
     async def encode_request_async(
         self,
diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py
index 8c37c5a9d6ee..8aca6fd4c315 100644
--- a/vllm/engine/llm_engine.py
+++ b/vllm/engine/llm_engine.py
@@ -726,7 +726,11 @@ def step(self) -> List[RequestOutput]:
         else:
             output = []
 
-        return self._process_model_outputs(output, scheduler_outputs)
+        outputs = self._process_model_outputs(output, scheduler_outputs)
+        if not outputs:
+            # Stop the execute model loop in parallel workers for now
+            self.model_executor.halt_model()
+        return outputs
 
     def do_log_stats(self) -> None:
         """Forced log when no requests active."""
diff --git a/vllm/executor/executor_base.py b/vllm/executor/executor_base.py
index bbb6ec80f7b7..109eef585a09 100644
--- a/vllm/executor/executor_base.py
+++ b/vllm/executor/executor_base.py
@@ -76,6 +76,11 @@ def execute_model(self,
         """Executes one model step on the given sequences."""
         raise NotImplementedError
 
+    @abstractmethod
+    async def halt_model(self) -> None:
+        """Releases parallel workers from model loop."""
+        pass
+
     @abstractmethod
     def add_lora(self, lora_request: LoRARequest) -> bool:
         raise NotImplementedError
@@ -108,6 +113,11 @@ async def execute_model_async(
         """Executes one model step on the given sequences."""
         raise NotImplementedError
 
+    @abstractmethod
+    async def halt_model_async(self) -> None:
+        """Releases parallel workers from model loop."""
+        pass
+
     async def check_health_async(self) -> None:
         """Checks if the executor is healthy. If not, it should raise an
         exception."""
diff --git a/vllm/executor/ray_gpu_executor.py b/vllm/executor/ray_gpu_executor.py
index 5db2f3f65253..43ac8c8aa00c 100644
--- a/vllm/executor/ray_gpu_executor.py
+++ b/vllm/executor/ray_gpu_executor.py
@@ -35,6 +35,7 @@ def _init_executor(self) -> None:
         assert self.parallel_config.worker_use_ray
         placement_group = self.parallel_config.placement_group
+        self.model_running = False
 
         # Disable Ray usage stats collection.
         ray_usage = os.environ.get("RAY_USAGE_STATS_ENABLED", "0")
@@ -223,19 +224,24 @@ def execute_model(self,
                       blocks_to_swap_in: Dict[int, int],
                       blocks_to_swap_out: Dict[int, int],
                       blocks_to_copy: Dict[int, List[int]]) -> SamplerOutput:
-        all_outputs = self._run_workers(
-            "execute_model",
-            driver_kwargs={
-                "seq_group_metadata_list": seq_group_metadata_list,
-                "blocks_to_swap_in": blocks_to_swap_in,
-                "blocks_to_swap_out": blocks_to_swap_out,
-                "blocks_to_copy": blocks_to_copy,
-            },
-            use_ray_compiled_dag=USE_RAY_COMPILED_DAG)
+        if not self.model_running:
+            # Start model execution loop running in the parallel workers
+            _ = self._run_workers("execute_model_parallel",
+                                  async_remote_only=True,
+                                  use_ray_compiled_dag=USE_RAY_COMPILED_DAG)
+            self.model_running = True
 
         # Only the driver worker returns the sampling results.
-        output = all_outputs[0]
-        return output
+        return self.driver_worker.execute_model(
+            seq_group_metadata_list=seq_group_metadata_list,
+            blocks_to_swap_in=blocks_to_swap_in,
+            blocks_to_swap_out=blocks_to_swap_out,
+            blocks_to_copy=blocks_to_copy)
+
+    def halt_model(self) -> None:
+        if self.model_running:
+            self.driver_worker.execute_model()
+            self.model_running = False
 
     def add_lora(self, lora_request: LoRARequest) -> bool:
         assert lora_request.lora_int_id > 0, "lora_id must be greater than 0."
@@ -258,8 +264,7 @@ def _run_workers(
         self,
         method: str,
         *args,
-        driver_args: Optional[Tuple[Any, ...]] = None,
-        driver_kwargs: Optional[Dict[str, Any]] = None,
+        async_remote_only: bool = False,
        max_concurrent_workers: Optional[int] = None,
         use_ray_compiled_dag: bool = False,
         **kwargs,
@@ -275,6 +280,7 @@
             # input. TODO(sang): Fix it.
             assert self.forward_dag is not None
             output_channels = self.forward_dag.execute(1)
+            ray_worker_outputs = None
         else:
             # Start the ray workers first.
             ray_worker_outputs = [
                 worker.execute_method.remote(method, *args, **kwargs)
                 for worker in self.workers
             ]
 
-        if driver_args is None:
-            driver_args = args
-        if driver_kwargs is None:
-            driver_kwargs = kwargs
+        if async_remote_only:
+            # Just return futures
+            return ray_worker_outputs
 
         # Start the driver worker after all the ray workers.
-        driver_worker_output = getattr(self.driver_worker,
-                                       method)(*driver_args, **driver_kwargs)
+        driver_worker_output = getattr(self.driver_worker, method)(*args,
+                                                                    **kwargs)
 
         # Get the results of the ray workers.
         if self.workers:
@@ -348,33 +353,6 @@ def _check_if_any_actor_is_dead(self):
 
 class RayGPUExecutorAsync(RayGPUExecutor, ExecutorAsyncBase):
 
-    async def _run_workers_async(
-        self,
-        method: str,
-        *args,
-        driver_args: Optional[Tuple[Any, ...]] = None,
-        driver_kwargs: Optional[Dict[str, Any]] = None,
-        **kwargs,
-    ) -> Any:
-        """Runs the given method on all workers."""
-        coros = []
-
-        if driver_args is None:
-            driver_args = args
-        if driver_kwargs is None:
-            driver_kwargs = kwargs
-
-        # Run the driver worker asynchronously.
-        driver_executor = make_async(getattr(self.driver_worker, method))
-        coros.append(driver_executor(*driver_args, **driver_kwargs))
-
-        # Run the ray workers asynchronously.
-        for worker in self.workers:
-            coros.append(worker.execute_method.remote(method, *args, **kwargs))
-
-        all_outputs = await asyncio.gather(*coros)
-        return all_outputs
-
     async def execute_model_async(
         self,
         seq_group_metadata_list: List[SequenceGroupMetadata],
@@ -382,15 +360,26 @@ async def execute_model_async(
         blocks_to_swap_in: Dict[int, int],
         blocks_to_swap_out: Dict[int, int],
         blocks_to_copy: Dict[int, List[int]],
     ) -> SamplerOutput:
-        all_outputs = await self._run_workers_async(
-            "execute_model",
-            driver_kwargs={
-                "seq_group_metadata_list": seq_group_metadata_list,
-                "blocks_to_swap_in": blocks_to_swap_in,
-                "blocks_to_swap_out": blocks_to_swap_out,
-                "blocks_to_copy": blocks_to_copy,
-            })
+        if not self.model_running:
+            # Start model execution loop running in the parallel workers
+            _ = asyncio.create_task(self._execute_model_parallel())
+            self.model_running = True
 
         # Only the driver worker returns the sampling results.
-        output = all_outputs[0]
-        return output
+        return await make_async(self.driver_worker.execute_model)(
+            seq_group_metadata_list=seq_group_metadata_list,
+            blocks_to_swap_in=blocks_to_swap_in,
+            blocks_to_swap_out=blocks_to_swap_out,
+            blocks_to_copy=blocks_to_copy)
+
+    async def _execute_model_parallel(self):
+        coros = [
+            worker.execute_method.remote("execute_model_parallel")
+            for worker in self.workers
+        ]
+        return await asyncio.gather(*coros)
+
+    async def halt_model_async(self) -> None:
+        if self.model_running:
+            await make_async(self.driver_worker.execute_model)()
+            self.model_running = False
diff --git a/vllm/worker/worker.py b/vllm/worker/worker.py
index 82491c6df661..8a3285a69b4a 100644
--- a/vllm/worker/worker.py
+++ b/vllm/worker/worker.py
@@ -211,8 +211,11 @@ def execute_model(
         blocks_to_swap_out: Optional[Dict[int, int]] = None,
         blocks_to_copy: Optional[Dict[int, List[int]]] = None,
     ) -> Optional[SamplerOutput]:
-        if self.is_driver_worker:
-            assert seq_group_metadata_list is not None
+        assert self.is_driver_worker
+        if seq_group_metadata_list is None:
+            data = {}
+            num_seq_groups = 0
+        else:
             num_seq_groups = len(seq_group_metadata_list)
             assert blocks_to_swap_in is not None
             assert blocks_to_swap_out is not None
@@ -223,23 +226,36 @@ def execute_model(
                 "blocks_to_swap_out": blocks_to_swap_out,
                 "blocks_to_copy": blocks_to_copy,
             }
-            broadcast_tensor_dict(data, src=0)
-        else:
-            data = broadcast_tensor_dict(src=0)
-            num_seq_groups = data["num_seq_groups"]
-            blocks_to_swap_in = data["blocks_to_swap_in"]
-            blocks_to_swap_out = data["blocks_to_swap_out"]
-            blocks_to_copy = data["blocks_to_copy"]
+        broadcast_tensor_dict(data, src=0)
 
         self.cache_swap(blocks_to_swap_in, blocks_to_swap_out, blocks_to_copy)
 
         # If there is no input, we don't need to execute the model.
         if num_seq_groups == 0:
-            return {}
+            return None
+
+        return self.model_runner.execute_model(seq_group_metadata_list,
+                                               self.gpu_cache)
+
+    @torch.inference_mode()
+    def execute_model_parallel(self) -> None:
+        """Execute model loop in parallel worker."""
+        assert not self.is_driver_worker
+        while True:
+            data = broadcast_tensor_dict(src=0)
+            num_seq_groups = data.get("num_seq_groups", 0)
+            blocks_to_swap_in = data.get("blocks_to_swap_in")
+            blocks_to_swap_out = data.get("blocks_to_swap_out")
+            blocks_to_copy = data.get("blocks_to_copy")
+
+            self.cache_swap(blocks_to_swap_in, blocks_to_swap_out,
+                            blocks_to_copy)
+
+            # If there is no input, we don't need to execute the model.
+            if num_seq_groups == 0:
+                return None
 
-        output = self.model_runner.execute_model(seq_group_metadata_list,
-                                                 self.gpu_cache)
-        return output
+            self.model_runner.execute_model(None, self.gpu_cache)
 
     def add_lora(self, lora_request: LoRARequest) -> bool:
         return self.model_runner.add_lora(lora_request)

From 5f638d4296a536cfd68d0a4a291dadc9b9b5d467 Mon Sep 17 00:00:00 2001
From: Nick Hill
Date: Mon, 1 Apr 2024 16:50:19 -0700
Subject: [PATCH 2/6] Check result of worker tasks between loop executions

So that any errors are still propagated properly
---
 vllm/executor/ray_gpu_executor.py | 46 +++++++++++++++++++++----------
 vllm/worker/worker.py             |  2 +-
 2 files changed, 32 insertions(+), 16 deletions(-)

diff --git a/vllm/executor/ray_gpu_executor.py b/vllm/executor/ray_gpu_executor.py
index 43ac8c8aa00c..4c4fc2ac6477 100644
--- a/vllm/executor/ray_gpu_executor.py
+++ b/vllm/executor/ray_gpu_executor.py
@@ -35,7 +35,10 @@ def _init_executor(self) -> None:
         assert self.parallel_config.worker_use_ray
         placement_group = self.parallel_config.placement_group
-        self.model_running = False
+
+        # This is non-None when the execute model loop is running
+        # in the parallel workers
+        self.parallel_worker_tasks = None
 
         # Disable Ray usage stats collection.
         ray_usage = os.environ.get("RAY_USAGE_STATS_ENABLED", "0")
@@ -224,12 +227,13 @@ def execute_model(self,
                       blocks_to_swap_in: Dict[int, int],
                       blocks_to_swap_out: Dict[int, int],
                       blocks_to_copy: Dict[int, List[int]]) -> SamplerOutput:
-        if not self.model_running:
+        if self.parallel_worker_tasks is None:
             # Start model execution loop running in the parallel workers
-            _ = self._run_workers("execute_model_parallel",
-                                  async_remote_only=True,
-                                  use_ray_compiled_dag=USE_RAY_COMPILED_DAG)
-            self.model_running = True
+            parallel_worker_tasks = self._run_workers(
+                "execute_model_parallel",
+                async_remote_only=True,
+                use_ray_compiled_dag=USE_RAY_COMPILED_DAG)
+            self.parallel_worker_tasks = asyncio.gather(*parallel_worker_tasks)
 
         # Only the driver worker returns the sampling results.
         return self.driver_worker.execute_model(
@@ -239,9 +243,15 @@ def execute_model(self,
             blocks_to_swap_out=blocks_to_swap_out,
             blocks_to_copy=blocks_to_copy)
 
     def halt_model(self) -> None:
-        if self.model_running:
-            self.driver_worker.execute_model()
-            self.model_running = False
+        if self.parallel_worker_tasks is None:
+            return
+
+        self.driver_worker.execute_model()
+        parallel_worker_tasks = self.parallel_worker_tasks
+        self.parallel_worker_tasks = None
+        # Ensure that workers exit model loop cleanly
+        # (this will raise otherwise)
+        ray.get(parallel_worker_tasks)
 
     def add_lora(self, lora_request: LoRARequest) -> bool:
         assert lora_request.lora_int_id > 0, "lora_id must be greater than 0."
@@ -360,10 +370,10 @@ async def execute_model_async(
         blocks_to_swap_out: Dict[int, int],
         blocks_to_copy: Dict[int, List[int]],
     ) -> SamplerOutput:
-        if not self.model_running:
+        if self.parallel_worker_tasks is None:
             # Start model execution loop running in the parallel workers
-            _ = asyncio.create_task(self._execute_model_parallel())
-            self.model_running = True
+            self.parallel_worker_tasks = asyncio.create_task(
+                self._execute_model_parallel())
 
         # Only the driver worker returns the sampling results.
         return await make_async(self.driver_worker.execute_model)(
@@ -380,6 +390,12 @@ async def _execute_model_parallel(self):
         return await asyncio.gather(*coros)
 
     async def halt_model_async(self) -> None:
-        if self.model_running:
-            await make_async(self.driver_worker.execute_model)()
-            self.model_running = False
+        if self.parallel_worker_tasks is None:
+            return
+
+        await make_async(self.driver_worker.execute_model)()
+        parallel_worker_tasks = self.parallel_worker_tasks
+        self.parallel_worker_tasks = None
+        # Ensure that workers exit model loop cleanly
+        # (this will raise otherwise)
+        await parallel_worker_tasks
diff --git a/vllm/worker/worker.py b/vllm/worker/worker.py
index 8a3285a69b4a..f2f9626aa57a 100644
--- a/vllm/worker/worker.py
+++ b/vllm/worker/worker.py
@@ -213,8 +213,8 @@ def execute_model(
     ) -> Optional[SamplerOutput]:
         assert self.is_driver_worker
         if seq_group_metadata_list is None:
-            data = {}
             num_seq_groups = 0
+            data = {}
         else:
             num_seq_groups = len(seq_group_metadata_list)
             assert blocks_to_swap_in is not None

From 8fcfaa7a7eb3a1a1482f992258a66d18917472a6 Mon Sep 17 00:00:00 2001
From: Nick Hill
Date: Mon, 1 Apr 2024 18:07:27 -0700
Subject: [PATCH 3/6] Change halt_model_async method to not be abstract

Default behaviour is no-op (single GPU)
---
 vllm/executor/executor_base.py | 1 -
 1 file changed, 1 deletion(-)

diff --git a/vllm/executor/executor_base.py b/vllm/executor/executor_base.py
index 109eef585a09..1ee3236ad9c6 100644
--- a/vllm/executor/executor_base.py
+++ b/vllm/executor/executor_base.py
@@ -113,7 +113,6 @@ async def execute_model_async(
         """Executes one model step on the given sequences."""
         raise NotImplementedError
 
-    @abstractmethod
     async def halt_model_async(self) -> None:
         """Releases parallel workers from model loop."""
         pass

From 4835c6af79fa3b7fa2d8718191a96dc00400b743 Mon Sep 17 00:00:00 2001
From: Nick Hill
Date: Mon, 1 Apr 2024 20:33:24 -0700
Subject: [PATCH 4/6] Change halt_model method to not be abstract

Default behaviour is no-op (single GPU)
---
 vllm/executor/executor_base.py | 1 -
 1 file changed, 1 deletion(-)

diff --git a/vllm/executor/executor_base.py b/vllm/executor/executor_base.py
index 1ee3236ad9c6..ed347721bea3 100644
--- a/vllm/executor/executor_base.py
+++ b/vllm/executor/executor_base.py
@@ -76,7 +76,6 @@ def execute_model(self,
         """Executes one model step on the given sequences."""
         raise NotImplementedError
 
-    @abstractmethod
     async def halt_model(self) -> None:
         """Releases parallel workers from model loop."""
         pass

From 2d7161645cbe574dcadaba7423809ec5fe354be4 Mon Sep 17 00:00:00 2001
From: Nick Hill
Date: Mon, 1 Apr 2024 20:58:13 -0700
Subject: [PATCH 5/6] Make ruff happy

---
 vllm/executor/executor_base.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/vllm/executor/executor_base.py b/vllm/executor/executor_base.py
index ed347721bea3..abb75bfd2db3 100644
--- a/vllm/executor/executor_base.py
+++ b/vllm/executor/executor_base.py
@@ -78,7 +78,7 @@ def execute_model(self,
     async def halt_model(self) -> None:
         """Releases parallel workers from model loop."""
-        pass
+        return
 
     @abstractmethod
     def add_lora(self, lora_request: LoRARequest) -> bool:
         raise NotImplementedError
@@ -114,7 +114,7 @@ async def execute_model_async(
     async def halt_model_async(self) -> None:
         """Releases parallel workers from model loop."""
-        pass
+        return
 
     async def check_health_async(self) -> None:
         """Checks if the executor is healthy. If not, it should raise an

From cc1b5455b0ef5b4ce42589f27a119c2cf0c79301 Mon Sep 17 00:00:00 2001
From: Nick Hill
Date: Tue, 9 Apr 2024 14:03:33 +0100
Subject: [PATCH 6/6] Address review comments from @zhuohan123

---
 vllm/engine/async_llm_engine.py   |  2 +-
 vllm/engine/llm_engine.py         |  2 +-
 vllm/executor/executor_base.py    |  4 ++--
 vllm/executor/ray_gpu_executor.py | 28 ++++++++++++++--------------
 vllm/worker/worker.py             |  3 ++-
 5 files changed, 20 insertions(+), 19 deletions(-)

diff --git a/vllm/engine/async_llm_engine.py b/vllm/engine/async_llm_engine.py
index 3e361679db74..7f0e16bdaedc 100644
--- a/vllm/engine/async_llm_engine.py
+++ b/vllm/engine/async_llm_engine.py
@@ -220,7 +220,7 @@ async def step_async(self) -> List[RequestOutput]:
         outputs = self._process_model_outputs(output, scheduler_outputs)
         if not outputs:
             # Stop the execute model loop in parallel workers for now
-            await self.model_executor.halt_model_async()
+            await self.model_executor.stop_remote_worker_execution_loop_async()
         return outputs
 
     async def encode_request_async(
diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py
index 8aca6fd4c315..824bf84a690d 100644
--- a/vllm/engine/llm_engine.py
+++ b/vllm/engine/llm_engine.py
@@ -729,7 +729,7 @@ def step(self) -> List[RequestOutput]:
         outputs = self._process_model_outputs(output, scheduler_outputs)
         if not outputs:
             # Stop the execute model loop in parallel workers for now
-            self.model_executor.halt_model()
+            self.model_executor.stop_remote_worker_execution_loop()
         return outputs
 
     def do_log_stats(self) -> None:
diff --git a/vllm/executor/executor_base.py b/vllm/executor/executor_base.py
index abb75bfd2db3..be2b0166e538 100644
--- a/vllm/executor/executor_base.py
+++ b/vllm/executor/executor_base.py
@@ -76,7 +76,7 @@ def execute_model(self,
         """Executes one model step on the given sequences."""
         raise NotImplementedError
 
-    async def halt_model(self) -> None:
+    def stop_remote_worker_execution_loop(self) -> None:
         """Releases parallel workers from model loop."""
         return
 
@@ -112,7 +112,7 @@ async def execute_model_async(
         """Executes one model step on the given sequences."""
         raise NotImplementedError
 
-    async def halt_model_async(self) -> None:
+    async def stop_remote_worker_execution_loop_async(self) -> None:
         """Releases parallel workers from model loop."""
         return
 
diff --git a/vllm/executor/ray_gpu_executor.py b/vllm/executor/ray_gpu_executor.py
index 4c4fc2ac6477..65aa2dc69b05 100644
--- a/vllm/executor/ray_gpu_executor.py
+++ b/vllm/executor/ray_gpu_executor.py
@@ -3,7 +3,8 @@
 import os
 import pickle
 from collections import defaultdict
-from typing import TYPE_CHECKING, Any, Dict, List, Optional, Set, Tuple
+from typing import (TYPE_CHECKING, Any, Awaitable, Dict, List, Optional, Set,
+                    Tuple, Union)
 
 from vllm.engine.ray_utils import RayWorkerVllm, ray
 from vllm.executor.executor_base import ExecutorAsyncBase, ExecutorBase
@@ -38,7 +39,7 @@ def _init_executor(self) -> None:
 
         # This is non-None when the execute model loop is running
         # in the parallel workers
-        self.parallel_worker_tasks = None
+        self.parallel_worker_tasks: Optional[Union[Any, Awaitable[Any]]] = None
 
         # Disable Ray usage stats collection.
         ray_usage = os.environ.get("RAY_USAGE_STATS_ENABLED", "0")
@@ -229,11 +230,10 @@ def execute_model(self,
                       blocks_to_copy: Dict[int, List[int]]) -> SamplerOutput:
         if self.parallel_worker_tasks is None:
             # Start model execution loop running in the parallel workers
-            parallel_worker_tasks = self._run_workers(
-                "execute_model_parallel",
-                async_remote_only=True,
+            self.parallel_worker_tasks = self._run_workers(
+                "start_worker_execution_loop",
+                remote_workers_only_async=True,
                 use_ray_compiled_dag=USE_RAY_COMPILED_DAG)
-            self.parallel_worker_tasks = asyncio.gather(*parallel_worker_tasks)
 
         # Only the driver worker returns the sampling results.
         return self.driver_worker.execute_model(
@@ -242,7 +242,7 @@ def execute_model(self,
             blocks_to_swap_out=blocks_to_swap_out,
             blocks_to_copy=blocks_to_copy)
 
-    def halt_model(self) -> None:
+    def stop_remote_worker_execution_loop(self) -> None:
         if self.parallel_worker_tasks is None:
             return
 
@@ -274,7 +274,7 @@ def _run_workers(
         self,
         method: str,
         *args,
-        async_remote_only: bool = False,
+        remote_workers_only_async: bool = False,
         max_concurrent_workers: Optional[int] = None,
         use_ray_compiled_dag: bool = False,
         **kwargs,
@@ -290,7 +290,7 @@ def _run_workers(
             # input. TODO(sang): Fix it.
             assert self.forward_dag is not None
             output_channels = self.forward_dag.execute(1)
-            ray_worker_outputs = None
+            ray_worker_outputs = []
         else:
             # Start the ray workers first.
             ray_worker_outputs = [
@@ -298,7 +298,7 @@ def _run_workers(
                 for worker in self.workers
             ]
 
-        if async_remote_only:
+        if remote_workers_only_async:
             # Just return futures
             return ray_worker_outputs
 
@@ -373,7 +373,7 @@ async def execute_model_async(
         if self.parallel_worker_tasks is None:
             # Start model execution loop running in the parallel workers
             self.parallel_worker_tasks = asyncio.create_task(
-                self._execute_model_parallel())
+                self._start_worker_execution_loop())
 
         # Only the driver worker returns the sampling results.
         return await make_async(self.driver_worker.execute_model)(
@@ -382,14 +382,14 @@ async def execute_model_async(
             blocks_to_swap_out=blocks_to_swap_out,
             blocks_to_copy=blocks_to_copy)
 
-    async def _execute_model_parallel(self):
+    async def _start_worker_execution_loop(self):
         coros = [
-            worker.execute_method.remote("execute_model_parallel")
+            worker.execute_method.remote("start_worker_execution_loop")
             for worker in self.workers
        ]
         return await asyncio.gather(*coros)
 
-    async def halt_model_async(self) -> None:
+    async def stop_remote_worker_execution_loop_async(self) -> None:
         if self.parallel_worker_tasks is None:
             return
diff --git a/vllm/worker/worker.py b/vllm/worker/worker.py
index f2f9626aa57a..b74ef5fb373d 100644
--- a/vllm/worker/worker.py
+++ b/vllm/worker/worker.py
@@ -213,6 +213,7 @@ def execute_model(
     ) -> Optional[SamplerOutput]:
         assert self.is_driver_worker
         if seq_group_metadata_list is None:
+            # No data to run, notify other workers to stop the execution loop.
             num_seq_groups = 0
             data = {}
         else:
@@ -238,7 +239,7 @@ def execute_model(
                                                self.gpu_cache)
 
     @torch.inference_mode()
-    def execute_model_parallel(self) -> None:
+    def start_worker_execution_loop(self) -> None:
         """Execute model loop in parallel worker."""
         assert not self.is_driver_worker
         while True:
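
Summary of the mechanism this series introduces: the driver no longer schedules every
Ray worker on every decoding step. The first execute_model call after an idle period
kicks off a long-running start_worker_execution_loop on each remote worker; afterwards
each step only broadcasts the step metadata via broadcast_tensor_dict, and a broadcast
with num_seq_groups == 0 releases the workers, at which point the saved
parallel_worker_tasks handle is awaited (or ray.get'd) so any worker exception is
re-raised. The sketch below is a simplified, self-contained analogue of that protocol,
not vLLM code: asyncio queues stand in for broadcast_tensor_dict, plain coroutines
stand in for the Ray workers, and names such as ExecutorSketch and _worker_loop are
illustrative assumptions.

import asyncio
from typing import Any, Dict, List, Optional


class ExecutorSketch:
    """Toy analogue of the executor/worker loop protocol (not vLLM code)."""

    def __init__(self, num_workers: int) -> None:
        # One queue per worker stands in for broadcast_tensor_dict(src=0).
        self.queues: List["asyncio.Queue[Dict[str, Any]]"] = [
            asyncio.Queue() for _ in range(num_workers)
        ]
        # Non-None while the remote execution loops are running.
        self.parallel_worker_tasks: Optional["asyncio.Task[Any]"] = None

    async def _worker_loop(self, q: "asyncio.Queue[Dict[str, Any]]") -> int:
        """Stand-in for the worker-side start_worker_execution_loop."""
        steps = 0
        while True:
            data = await q.get()
            if data.get("num_seq_groups", 0) == 0:
                return steps  # empty broadcast releases the worker
            steps += 1  # the real worker would run the model here

    async def _start_worker_execution_loop(self) -> List[int]:
        # Gather all worker loops so a failure in any of them is kept
        # and re-raised when the task is awaited on stop.
        return await asyncio.gather(*(self._worker_loop(q) for q in self.queues))

    async def execute_model_async(self, num_seq_groups: int) -> str:
        if self.parallel_worker_tasks is None:
            # First step after being idle: start the loops once instead of
            # scheduling every worker on every step.
            self.parallel_worker_tasks = asyncio.create_task(
                self._start_worker_execution_loop())
        for q in self.queues:  # per-step "broadcast" of the metadata
            q.put_nowait({"num_seq_groups": num_seq_groups})
        return f"driver-side sampler output for {num_seq_groups} seq groups"

    async def stop_remote_worker_execution_loop_async(self) -> None:
        if self.parallel_worker_tasks is None:
            return
        for q in self.queues:  # empty step tells the workers to exit the loop
            q.put_nowait({"num_seq_groups": 0})
        tasks, self.parallel_worker_tasks = self.parallel_worker_tasks, None
        await tasks  # re-raises any exception from a worker loop


async def _demo() -> None:
    executor = ExecutorSketch(num_workers=2)
    for _ in range(3):
        print(await executor.execute_model_async(num_seq_groups=4))
    # Engine produced no outputs -> release the workers until new requests arrive.
    await executor.stop_remote_worker_execution_loop_async()


if __name__ == "__main__":
    asyncio.run(_demo())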