31 changes: 10 additions & 21 deletions vllm/executor/cpu_executor.py
@@ -1,10 +1,9 @@
 import os
-from typing import Dict, List, Optional, Tuple
+from typing import Dict, List, Set, Tuple
 
 import torch
 
-from vllm.config import (CacheConfig, DeviceConfig, LoRAConfig, ModelConfig,
-                         ParallelConfig, SchedulerConfig)
+from vllm.config import CacheConfig, ModelConfig, SchedulerConfig
 from vllm.executor.executor_base import ExecutorBase
 from vllm.logger import init_logger
 from vllm.lora.request import LoRARequest
@@ -16,23 +15,13 @@

 class CPUExecutor(ExecutorBase):
 
-    def __init__(self, model_config: ModelConfig, cache_config: CacheConfig,
-                 parallel_config: ParallelConfig,
-                 scheduler_config: SchedulerConfig,
-                 device_config: DeviceConfig,
-                 lora_config: Optional[LoRAConfig], *args, **kwargs) -> None:
-        assert device_config.device_type == "cpu"
-        assert lora_config is None, "cpu backend doesn't support LoRA"
-        model_config = _verify_and_get_model_config(model_config)
-        cache_config = _verify_and_get_cache_config(cache_config)
-        scheduler_config = _verify_and_get_scheduler_config(scheduler_config)
-
-        self.model_config = model_config
-        self.cache_config = cache_config
-        self.lora_config = lora_config
-        self.parallel_config = parallel_config
-        self.scheduler_config = scheduler_config
-        self.device_config = device_config
+    def _init_executor(self) -> None:
+        assert self.device_config.device_type == "cpu"
+        assert self.lora_config is None, "cpu backend doesn't support LoRA"
+        self.model_config = _verify_and_get_model_config(self.model_config)
+        self.cache_config = _verify_and_get_cache_config(self.cache_config)
+        self.scheduler_config = _verify_and_get_scheduler_config(
+            self.scheduler_config)
 
         # Instantiate the worker and load the model to CPU.
         self._init_worker()
@@ -96,7 +85,7 @@ def add_lora(self, lora_request: LoRARequest) -> bool:
     def remove_lora(self, lora_id: int) -> bool:
         return self.driver_worker.remove_lora(lora_id)
 
-    def list_loras(self) -> List[int]:
+    def list_loras(self) -> Set[int]:
         return self.driver_worker.list_loras()
 
     def check_health(self) -> None:
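The List[int] -> Set[int] change above recurs in every backend in this PR: active LoRA adapter ids are unique and unordered, so a set models them more faithfully than a list. A minimal runnable sketch of the idea (the function and the registry dict here are illustrative stand-ins, not from this PR):

    from typing import Dict, Set

    def list_loras_sketch(active_adapters: Dict[int, str]) -> Set[int]:
        # Adapter ids are unique and unordered, hence a set.
        return set(active_adapters)

    assert list_loras_sketch({1: "sql-lora", 3: "chat-lora"}) == {1, 3}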
27 changes: 20 additions & 7 deletions vllm/executor/executor_base.py
@@ -1,9 +1,9 @@
 from abc import ABC, abstractmethod
-from typing import Dict, List, Optional, Tuple
+from typing import Dict, List, Optional, Set, Tuple
 
 from vllm.config import (CacheConfig, DeviceConfig, LoRAConfig, ModelConfig,
                          ParallelConfig, SchedulerConfig, SpeculativeConfig,
-                         VisionLanguageConfig)
+                         TensorizerConfig, VisionLanguageConfig)
 from vllm.lora.request import LoRARequest
 from vllm.sequence import SamplerOutput, SequenceGroupMetadata
 
@@ -16,7 +16,6 @@ class ExecutorBase(ABC):
     that can execute the model on multiple devices.
     """
 
-    @abstractmethod
     def __init__(
         self,
         model_config: ModelConfig,
@@ -27,8 +26,23 @@ def __init__(
         lora_config: Optional[LoRAConfig],
         vision_language_config: Optional[VisionLanguageConfig],
         speculative_config: Optional[SpeculativeConfig],
+        tensorizer_config: Optional[TensorizerConfig],
     ) -> None:
-        raise NotImplementedError
+        self.model_config = model_config
+        self.cache_config = cache_config
+        self.lora_config = lora_config
+        self.parallel_config = parallel_config
+        self.scheduler_config = scheduler_config
+        self.device_config = device_config
+        self.vision_language_config = vision_language_config
+        self.speculative_config = speculative_config
+        self.tensorizer_config = tensorizer_config
+
+        self._init_executor()
+
+    @abstractmethod
+    def _init_executor(self) -> None:
+        pass
 
     @abstractmethod
     def determine_num_available_blocks(self) -> Tuple[int, int]:
@@ -71,7 +85,7 @@ def remove_lora(self, lora_id: int) -> bool:
         raise NotImplementedError
 
     @abstractmethod
-    def list_loras(self) -> List[int]:
+    def list_loras(self) -> Set[int]:
         raise NotImplementedError
 
     @abstractmethod
@@ -94,8 +108,7 @@ async def execute_model_async(
"""Executes one model step on the given sequences."""
raise NotImplementedError

@abstractmethod
async def check_health_async(self) -> None:
"""Checks if the executor is healthy. If not, it should raise an
exception."""
raise NotImplementedError
self.check_health()
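This file is the heart of the refactor: ExecutorBase.__init__ is now concrete, stores every config on self, and finishes by invoking the abstract _init_executor() hook, so each backend overrides only that hook instead of repeating constructor boilerplate. A runnable toy sketch of the control flow (class and attribute names are illustrative stand-ins, not vLLM APIs):

    from abc import ABC, abstractmethod

    class ExecutorSketch(ABC):
        """Toy stand-in for ExecutorBase."""

        def __init__(self, device_type: str) -> None:
            # The base class stores shared state first...
            self.device_type = device_type
            # ...then hands control to the backend-specific hook.
            self._init_executor()

        @abstractmethod
        def _init_executor(self) -> None:
            ...

    class CPUSketch(ExecutorSketch):
        def _init_executor(self) -> None:
            # Config is already an attribute by the time this runs.
            assert self.device_type == "cpu"

    CPUSketch("cpu")  # __init__ is inherited; only the hook is overridden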
32 changes: 4 additions & 28 deletions vllm/executor/gpu_executor.py
@@ -1,8 +1,5 @@
-from typing import Dict, List, Optional, Tuple
+from typing import Dict, List, Set, Tuple
 
-from vllm.config import (CacheConfig, DeviceConfig, LoRAConfig, ModelConfig,
-                         ParallelConfig, SchedulerConfig, SpeculativeConfig,
-                         TensorizerConfig, VisionLanguageConfig)
 from vllm.executor.executor_base import ExecutorAsyncBase, ExecutorBase
 from vllm.logger import init_logger
 from vllm.lora.request import LoRARequest
@@ -15,24 +12,8 @@

 class GPUExecutor(ExecutorBase):
 
-    def __init__(self, model_config: ModelConfig, cache_config: CacheConfig,
-                 parallel_config: ParallelConfig,
-                 scheduler_config: SchedulerConfig,
-                 device_config: DeviceConfig,
-                 lora_config: Optional[LoRAConfig],
-                 vision_language_config: Optional[VisionLanguageConfig],
-                 speculative_config: Optional[SpeculativeConfig],
-                 tensorizer_config: Optional[TensorizerConfig]) -> None:
-        self.model_config = model_config
-        self.cache_config = cache_config
-        self.lora_config = lora_config
-        self.parallel_config = parallel_config
-        self.scheduler_config = scheduler_config
-        self.device_config = device_config
-        self.vision_language_config = vision_language_config
-        self.tensorizer_config = tensorizer_config
-
-        assert (not speculative_config
+    def _init_executor(self) -> None:
+        assert (not self.speculative_config
                 ), "Speculative decoding not yet supported for GPU backend"
 
         # Instantiate the worker and load the model to GPU.
@@ -103,7 +84,7 @@ def remove_lora(self, lora_id: int) -> bool:
         assert lora_id > 0, "lora_id must be greater than 0."
         return self.driver_worker.remove_lora(lora_id)
 
-    def list_loras(self) -> List[int]:
+    def list_loras(self) -> Set[int]:
         return self.driver_worker.list_loras()
 
     def check_health(self) -> None:
@@ -127,8 +108,3 @@ async def execute_model_async(
             blocks_to_swap_out=blocks_to_swap_out,
             blocks_to_copy=blocks_to_copy)
         return output
-
-    async def check_health_async(self) -> None:
-        # GPUExecutor will always be healthy as long as
-        # it's running.
-        return
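The per-backend check_health_async overrides can be deleted because the async base class's check_health_async is no longer abstract: its default bridges to the synchronous check (see the executor_base.py hunk above). A runnable toy sketch of that bridge (illustrative names, not vLLM APIs):

    import asyncio

    class HealthSketch:
        def check_health(self) -> None:
            pass  # would raise if the executor were unhealthy

        async def check_health_async(self) -> None:
            # The default async path reuses the sync check, so a backend
            # that is "healthy as long as it's running" needs no override.
            self.check_health()

    asyncio.run(HealthSketch().check_health_async())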
29 changes: 6 additions & 23 deletions vllm/executor/neuron_executor.py
@@ -1,8 +1,5 @@
-from typing import Dict, List, Optional, Tuple
+from typing import Dict, List, Set, Tuple
 
-from vllm.config import (CacheConfig, DeviceConfig, LoRAConfig, ModelConfig,
-                         ParallelConfig, SchedulerConfig, SpeculativeConfig,
-                         VisionLanguageConfig)
 from vllm.executor.executor_base import ExecutorBase
 from vllm.logger import init_logger
 from vllm.lora.request import LoRARequest
@@ -13,24 +10,10 @@

 class NeuronExecutor(ExecutorBase):
 
-    def __init__(
-        self,
-        model_config: ModelConfig,
-        cache_config: CacheConfig,
-        parallel_config: ParallelConfig,
-        scheduler_config: SchedulerConfig,
-        device_config: DeviceConfig,
-        lora_config: Optional[LoRAConfig],
-        vision_language_config: Optional[VisionLanguageConfig],
-        speculative_config: Optional[SpeculativeConfig],
-    ) -> None:
-        self.model_config = model_config
-        self.cache_config = cache_config
-        assert lora_config is None, "LoRA is not supported for Neuron backend."
-        self.parallel_config = parallel_config
-        self.scheduler_config = scheduler_config
-        self.device_config = device_config
-        assert (not speculative_config
+    def _init_executor(self) -> None:
+        assert (self.lora_config is
+                None), "LoRA is not supported for Neuron backend."
+        assert (not self.speculative_config
                 ), "Speculative decoding not yet supported for Neuron backend."
 
         # Instantiate the worker and load the model to the device.
@@ -80,7 +63,7 @@ def add_lora(self, lora_request: LoRARequest) -> bool:
     def remove_lora(self, lora_id: int) -> bool:
         return self.driver_worker.remove_lora(lora_id)
 
-    def list_loras(self) -> List[int]:
+    def list_loras(self) -> Set[int]:
         return self.driver_worker.list_loras()
 
     def check_health(self) -> None:
34 changes: 4 additions & 30 deletions vllm/executor/ray_gpu_executor.py
@@ -3,11 +3,8 @@
 import os
 import pickle
 from collections import defaultdict
-from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple
+from typing import TYPE_CHECKING, Any, Dict, List, Optional, Set, Tuple
 
-from vllm.config import (CacheConfig, DeviceConfig, LoRAConfig, ModelConfig,
-                         ParallelConfig, SchedulerConfig, SpeculativeConfig,
-                         TensorizerConfig, VisionLanguageConfig)
 from vllm.engine.ray_utils import RayWorkerVllm, ray
 from vllm.executor.executor_base import ExecutorAsyncBase, ExecutorBase
 from vllm.logger import init_logger
@@ -32,27 +29,8 @@

 class RayGPUExecutor(ExecutorBase):
 
-    def __init__(
-        self,
-        model_config: ModelConfig,
-        cache_config: CacheConfig,
-        parallel_config: ParallelConfig,
-        scheduler_config: SchedulerConfig,
-        device_config: DeviceConfig,
-        lora_config: Optional[LoRAConfig],
-        vision_language_config: Optional[VisionLanguageConfig],
-        speculative_config: Optional[SpeculativeConfig],
-        tensorizer_config: Optional[TensorizerConfig],
-    ) -> None:
-        self.model_config = model_config
-        self.cache_config = cache_config
-        self.lora_config = lora_config
-        self.parallel_config = parallel_config
-        self.scheduler_config = scheduler_config
-        self.device_config = device_config
-        self.vision_language_config = vision_language_config
-        self.tensorizer_config = tensorizer_config
-        assert (not speculative_config
+    def _init_executor(self) -> None:
+        assert (not self.speculative_config
                 ), "Speculative decoding not yet supported for RayGPU backend."
 
         assert self.parallel_config.worker_use_ray
@@ -273,7 +251,7 @@ def remove_lora(self, lora_id: int) -> bool:
             lora_id=lora_id,
         )
 
-    def list_loras(self) -> List[int]:
+    def list_loras(self) -> Set[int]:
         return self._run_workers("list_loras")
 
     def _run_workers(
@@ -416,7 +394,3 @@ async def execute_model_async(
         # Only the driver worker returns the sampling results.
         output = all_outputs[0]
         return output
-
-    async def check_health_async(self) -> None:
-        """Raises an error if engine is unhealthy."""
-        self._check_if_any_actor_is_dead()