diff --git a/vllm/sequence.py b/vllm/sequence.py
index 99208fbad65f..1f507add0d91 100644
--- a/vllm/sequence.py
+++ b/vllm/sequence.py
@@ -1173,6 +1173,10 @@ class PoolingSequenceGroupOutput(
     # The actual type is in SequenceGroup.pooled_data
     data: Any
 
+    def get_data_nbytes(self) -> int:
+        data: torch.Tensor = self.data
+        return data.nbytes
+
     def __repr__(self) -> str:
         return f"PoolingSequenceGroupOutput(data={self.data}"
 
@@ -1234,6 +1238,9 @@ class PoolerOutput(
     """The output from a pooling operation in the pooling model."""
     outputs: list[PoolingSequenceGroupOutput]
 
+    def get_data_nbytes(self) -> int:
+        return sum(o.get_data_nbytes() for o in self.outputs)
+
     def __getitem__(self, idx: int) -> PoolingSequenceGroupOutput:
         return self.outputs[idx]
 
diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py
index cd66d8bcd634..9e3ccb7073a3 100644
--- a/vllm/v1/worker/gpu_model_runner.py
+++ b/vllm/v1/worker/gpu_model_runner.py
@@ -41,7 +41,7 @@
 from vllm.multimodal.utils import group_mm_inputs_by_modality
 from vllm.pooling_params import PoolingParams, PoolingTask
 from vllm.sampling_params import SamplingType
-from vllm.sequence import IntermediateTensors
+from vllm.sequence import IntermediateTensors, PoolerOutput
 from vllm.utils import (STR_DTYPE_TO_TORCH_DTYPE, DeviceMemoryProfiler,
                         GiB_bytes, LazyLoader, check_use_alibi,
                         get_dtype_size, is_pin_memory_available, round_up)
@@ -1819,7 +1819,7 @@ def load_model(self, eep_scale_up: bool = False) -> None:
             old_global_expert_indices = None
             rank_mapping = None
 
-        with DeviceMemoryProfiler() as m:  # noqa: SIM117
+        with DeviceMemoryProfiler() as m:
             time_before_load = time.perf_counter()
             model_loader = get_model_loader(self.load_config)
             if not hasattr(self, "model"):
@@ -2215,12 +2215,11 @@ def _dummy_sampler_run(
             )
         return sampler_output
 
-    @torch.inference_mode()
-    def _dummy_pooler_run(
+    def _dummy_pooler_run_task(
         self,
         hidden_states: torch.Tensor,
-    ) -> torch.Tensor:
-
+        task: PoolingTask,
+    ) -> PoolerOutput:
         num_tokens = hidden_states.shape[0]
         max_num_reqs = self.scheduler_config.max_num_seqs
         num_reqs = min(num_tokens, max_num_reqs)
@@ -2232,37 +2231,55 @@ def _dummy_pooler_run(
 
         hidden_states_list = list(
             torch.split(hidden_states, num_scheduled_tokens_list))
-
         req_num_tokens = num_tokens // num_reqs
 
-        model = cast(VllmModelForPooling, self.model)
-        dummy_task = self.get_supported_pooling_tasks()[0]
-        dummy_pooling_params = PoolingParams(task=dummy_task)
+        dummy_prompt_lens = torch.tensor(
+            [h.shape[0] for h in hidden_states_list],
+            device=self.device,
+        )
+        dummy_token_ids = torch.zeros((num_reqs, req_num_tokens),
+                                      dtype=torch.int32,
+                                      device=self.device)
 
-        to_update = model.pooler.get_pooling_updates(dummy_task)
+        model = cast(VllmModelForPooling, self.model)
+        dummy_pooling_params = PoolingParams(task=task)
+        to_update = model.pooler.get_pooling_updates(task)
         to_update.apply(dummy_pooling_params)
 
         dummy_metadata = PoolingMetadata(
-            prompt_lens=torch.tensor([h.shape[0] for h in hidden_states_list],
-                                     device=self.device),
-            prompt_token_ids=torch.zeros((num_reqs, req_num_tokens),
-                                         dtype=torch.int32,
-                                         device=self.device),
-            pooling_params=[dummy_pooling_params] * num_reqs)
+            prompt_lens=dummy_prompt_lens,
+            prompt_token_ids=dummy_token_ids,
+            pooling_params=[dummy_pooling_params] * num_reqs,
+        )
 
         try:
-            pooler_output = model.pooler(hidden_states=hidden_states_list,
-                                         pooling_metadata=dummy_metadata)
+            return model.pooler(hidden_states=hidden_states_list,
+                                pooling_metadata=dummy_metadata)
         except RuntimeError as e:
             if 'out of memory' in str(e):
                 raise RuntimeError(
-                    "CUDA out of memory occurred when warming up pooler with "
-                    f"{num_reqs} dummy requests. Please try lowering "
-                    "`max_num_seqs` or `gpu_memory_utilization` when "
+                    "CUDA out of memory occurred when warming up pooler "
+                    f"({task=}) with {num_reqs} dummy requests. Please try "
+                    "lowering `max_num_seqs` or `gpu_memory_utilization` when "
                     "initializing the engine.") from e
             else:
                 raise e
-        return pooler_output
+
+    @torch.inference_mode()
+    def _dummy_pooler_run(
+        self,
+        hidden_states: torch.Tensor,
+    ) -> PoolerOutput:
+        # Find the task that has the largest output for subsequent steps
+        output_size = dict[PoolingTask, float]()
+        for task in self.get_supported_pooling_tasks():
+            # Run a full batch with each task to ensure none of them OOMs
+            output = self._dummy_pooler_run_task(hidden_states, task)
+            output_size[task] = output.get_data_nbytes()
+            del output  # Allow GC
+
+        max_task = max(output_size.items(), key=lambda x: x[1])[0]
+        return self._dummy_pooler_run_task(hidden_states, max_task)
 
     def profile_run(self) -> None:
         # Profile with multimodal encoder & encoder cache.
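The behavioral change in the patch is in `_dummy_pooler_run`: instead of warming up only the first supported pooling task, every supported task is run once, the size in bytes of each task's output is measured via the new `get_data_nbytes()` helpers, and only the task with the largest output is re-run so that subsequent memory profiling reflects the worst case. Below is a minimal, self-contained sketch of that selection pattern, separate from the patch itself; `output_nbytes`, `dummy_run`, and `warmup` are hypothetical stand-ins for illustration and are not vLLM API.

# --- Illustrative sketch (not part of the patch) ---------------------------
import torch


def output_nbytes(output: list[torch.Tensor]) -> int:
    # Analogous to PoolerOutput.get_data_nbytes(): sum of per-request tensor sizes.
    return sum(t.nbytes for t in output)


def dummy_run(task: str, num_reqs: int, hidden_size: int) -> list[torch.Tensor]:
    # Stand-in for a per-task warmup run; the output shape depends on the task.
    out_dim = hidden_size if task == "embed" else 1
    return [torch.zeros(out_dim) for _ in range(num_reqs)]


def warmup(tasks: list[str], num_reqs: int = 4, hidden_size: int = 8) -> list[torch.Tensor]:
    sizes: dict[str, int] = {}
    for task in tasks:
        # Run a full batch with each task so none of them can OOM later.
        output = dummy_run(task, num_reqs, hidden_size)
        sizes[task] = output_nbytes(output)
        del output  # free the intermediate result before the next task
    # Keep only the task with the largest output for the final profiling pass.
    max_task = max(sizes.items(), key=lambda kv: kv[1])[0]
    return dummy_run(max_task, num_reqs, hidden_size)


print(len(warmup(["embed", "classify"])))  # -> 4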