Skip to content
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 5 additions & 3 deletions vllm/engine/arg_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -316,9 +316,10 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
parser.add_argument('--block-size',
type=int,
default=EngineArgs.block_size,
choices=[8, 16, 32, 128, 256, 512, 1024, 2048],
choices=[8, 16, 32],
help='Token block size for contiguous chunks of '
'tokens.')
'tokens. This is ignored on neuron devices and '
'set to max-model-len')

parser.add_argument('--enable-prefix-caching',
action='store_true',
Expand Down Expand Up @@ -780,7 +781,8 @@ def create_engine_config(self, ) -> EngineConfig:
served_model_name=self.served_model_name,
multimodal_config=multimodal_config)
cache_config = CacheConfig(
block_size=self.block_size,
block_size=self.block_size if self.device != "neuron" else
self.max_model_len, # neuron needs block_size = max_model_len
gpu_memory_utilization=self.gpu_memory_utilization,
swap_space=self.swap_space,
cache_dtype=self.kv_cache_dtype,
Expand Down
20 changes: 12 additions & 8 deletions vllm/executor/neuron_executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,8 @@
from vllm.logger import init_logger
from vllm.lora.request import LoRARequest
from vllm.sequence import ExecuteModelRequest, SamplerOutput
from vllm.utils import make_async
from vllm.utils import (get_distributed_init_method, get_ip, get_open_port,
make_async)

logger = init_logger(__name__)

Expand All @@ -24,14 +25,17 @@ def _init_executor(self) -> None:

def _init_worker(self):
from vllm.worker.neuron_worker import NeuronWorker

distributed_init_method = get_distributed_init_method(
get_ip(), get_open_port())
self.driver_worker = NeuronWorker(
self.model_config,
self.parallel_config,
self.scheduler_config,
self.device_config,
self.cache_config,
)
model_config=self.model_config,
parallel_config=self.parallel_config,
scheduler_config=self.scheduler_config,
device_config=self.device_config,
cache_config=self.cache_config,
local_rank=0,
rank=0,
distributed_init_method=distributed_init_method)
self.driver_worker.init_device()
self.driver_worker.load_model()

Expand Down
29 changes: 29 additions & 0 deletions vllm/worker/neuron_worker.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,14 @@
"""A Neuron worker class."""
import os
from typing import List, Optional, Tuple

import torch
import torch.distributed

from vllm.config import (CacheConfig, DeviceConfig, ModelConfig,
ParallelConfig, SchedulerConfig)
from vllm.distributed import (ensure_model_parallel_initialized,
init_distributed_environment)
from vllm.model_executor import set_random_seed
from vllm.sequence import ExecuteModelRequest
from vllm.worker.neuron_model_runner import NeuronModelRunner
Expand All @@ -24,12 +27,18 @@ def __init__(
scheduler_config: SchedulerConfig,
device_config: DeviceConfig,
cache_config: CacheConfig,
local_rank: int,
rank: int,
distributed_init_method: str,
) -> None:
self.model_config = model_config
self.parallel_config = parallel_config
self.scheduler_config = scheduler_config
self.device_config = device_config
self.cache_config = cache_config
self.local_rank = local_rank
self.rank = rank
self.distributed_init_method = distributed_init_method
if self.model_config.trust_remote_code:
# note: lazy import to avoid importing torch before initializing
from vllm.utils import init_cached_hf_modules
Expand All @@ -40,6 +49,9 @@ def __init__(
self.is_driver_worker = True

def init_device(self) -> None:
os.environ["PJRT_DEVICE"] = "NEURON"
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Help take this out?

self.init_distributed_environment()

# Set random seed.
set_random_seed(self.model_config.seed)

Expand Down Expand Up @@ -98,3 +110,20 @@ def get_cache_block_size_bytes(self) -> int:
This is required for speculative decoding; it is not yet implemented.
"""
raise NotImplementedError

def init_distributed_environment(self):
    """Set up vLLM's distributed state as a single-member world.

    On Neuron, tensor parallelism is delegated to transformers-neuronx,
    so vLLM itself always runs as world_size=1 here; the distributed
    environment must still be created because other vLLM components
    expect it to be initialized even when TP/PP > 1.
    """
    # No NCCL on Neuron devices — use the CPU-capable gloo backend.
    init_distributed_environment(
        world_size=1,
        rank=self.rank,
        local_rank=self.local_rank,
        distributed_init_method=self.distributed_init_method,
        backend="gloo",
    )
    # TP=1, PP=1: the groups are trivial but must exist for vLLM's
    # model-parallel utilities to work.
    ensure_model_parallel_initialized(1, 1)