diff --git a/tests/neuron/test_comm_ops.py b/tests/neuron/test_comm_ops.py
new file mode 100644
index 000000000000..3cad160b2cb7
--- /dev/null
+++ b/tests/neuron/test_comm_ops.py
@@ -0,0 +1,100 @@
+# SPDX-License-Identifier: Apache-2.0
+import functools
+from typing import Callable
+from unittest.mock import patch
+
+import pytest
+import torch
+import torch_xla.distributed.xla_multiprocessing as xmp
+from typing_extensions import ParamSpec
+
+from vllm.distributed.communication_op import (
+    tensor_model_parallel_all_gather, tensor_model_parallel_all_reduce)
+from vllm.distributed.parallel_state import (ensure_model_parallel_initialized,
+                                             init_distributed_environment)
+from vllm.utils import get_distributed_init_method, get_open_port
+
+_P = ParamSpec("_P")
+
+
+def reinitialize_neuron_runtime(f: Callable[_P, None]) -> Callable[_P, None]:
+    """Decorator to reinitialize the Neuron Runtime before executing a test.
+    This is necessary for distributed tests, which need to reallocate Neuron
+    Cores to separate subprocesses.
+    """
+
+    @functools.wraps(f)
+    def wrapper(*args: _P.args, **kwargs: _P.kwargs) -> None:
+        runtime = torch.classes.neuron.Runtime()
+        runtime.initialize()
+        runtime.unsafe_close()
+
+        f(*args, **kwargs)
+        runtime.initialize()
+
+    return wrapper
+
+
+def all_gather_test_worker(index, tp_degree, distributed_init_method):
+    init_distributed_environment(tp_degree,
+                                 index,
+                                 distributed_init_method,
+                                 index,
+                                 backend="xla")
+    ensure_model_parallel_initialized(tp_degree, 1)
+
+    num_dimensions = 3
+    tensor_size = list(range(2, num_dimensions + 2))
+    total_size = 1
+    for s in tensor_size:
+        total_size *= s
+
+    all_gather_dimension = -1
+    all_tensors = [
+        torch.arange(total_size, dtype=torch.float32,
+                     device="xla").reshape(tensor_size) * (r + 1)
+        for r in range(tp_degree)
+    ]
+    expected = torch.cat(all_tensors, dim=all_gather_dimension)
+    t = all_tensors[index % tp_degree]
+    t = tensor_model_parallel_all_gather(t, all_gather_dimension)
+    torch.testing.assert_close(t, expected)
+
+
+def all_reduce_test_worker(index, tp_degree, distributed_init_method):
+    init_distributed_environment(tp_degree,
+                                 index,
+                                 distributed_init_method,
+                                 index,
+                                 backend="xla")
+    ensure_model_parallel_initialized(tp_degree, 1)
+
+    num_elements = 8
+    all_tensors = [
+        torch.arange(num_elements, dtype=torch.float32, device="xla") * (r + 1)
+        for r in range(tp_degree)
+    ]
+    expected = torch.sum(torch.stack(all_tensors, dim=0), dim=0)
+    t = all_tensors[index % tp_degree]
+    t = tensor_model_parallel_all_reduce(t)
+    torch.testing.assert_close(t, expected)
+
+
+@pytest.mark.parametrize("tp_size", [2])
+@pytest.mark.parametrize("test_target",
+                         [all_reduce_test_worker, all_gather_test_worker])
+@reinitialize_neuron_runtime
+def test_neuron_multi_process_tensor_parallel(monkeypatch, tp_size,
+                                              test_target):
+
+    with patch('torch_xla._XLAC._xla_runtime_is_initialized',
+               return_value=False):
+        distributed_init_method = get_distributed_init_method(
+            "127.0.0.1", get_open_port())
+
+        monkeypatch.setenv("VLLM_USE_V1", "1")
+        monkeypatch.setenv("NEURONCORE_NUM_DEVICES", str(tp_size))
+        monkeypatch.setenv("NEURON_PJRT_PROCESSES_NUM_DEVICES",
+                           ','.join(['1' for _ in range(tp_size)]))
+
+        xmp.spawn(test_target, args=(tp_size, distributed_init_method))
diff --git a/vllm/distributed/device_communicators/neuron_communicator.py b/vllm/distributed/device_communicators/neuron_communicator.py
new file mode 100644
index 000000000000..dfa4b5194bdb
--- /dev/null
+++ b/vllm/distributed/device_communicators/neuron_communicator.py
@@ -0,0 +1,19 @@
+# SPDX-License-Identifier: Apache-2.0
+import torch
+
+from vllm.distributed.device_communicators.base_device_communicator import (
+    DeviceCommunicatorBase)
+from vllm.platforms import current_platform
+
+if current_platform.is_neuron():
+    import torch_xla.core.xla_model as xm
+
+
+class NeuronCommunicator(DeviceCommunicatorBase):
+
+    def all_reduce(self, x: torch.Tensor) -> torch.Tensor:
+        return xm.all_reduce(xm.REDUCE_SUM, x)
+
+    def all_gather(self, x: torch.Tensor, dim: int = -1) -> torch.Tensor:
+        assert dim == -1, "Neuron only supports dim=-1 for all-gather."
+        return xm.all_gather(x, dim=dim)
diff --git a/vllm/platforms/neuron.py b/vllm/platforms/neuron.py
index b2eadb7932f3..c1f426e5b880 100644
--- a/vllm/platforms/neuron.py
+++ b/vllm/platforms/neuron.py
@@ -2,6 +2,7 @@
 
 from typing import TYPE_CHECKING, Optional
 
+from vllm import envs
 from vllm.logger import init_logger
 
 from .interface import Platform, PlatformEnum
@@ -56,6 +57,13 @@ def is_pin_memory_available(cls) -> bool:
         logger.warning("Pin memory is not supported on Neuron.")
         return False
 
+    @classmethod
+    def get_device_communicator_cls(cls) -> str:
+        if envs.VLLM_USE_V1:
+            return "vllm.distributed.device_communicators.neuron_communicator.NeuronCommunicator"  # noqa
+        else:
+            return Platform.get_device_communicator_cls()
+
     @classmethod
     def use_all_gather(cls) -> bool:
         return True
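
Reviewer note (not part of the diff): the collective semantics that the new tests assert can be sanity-checked on CPU with plain torch, without Neuron hardware or torch_xla. The sketch below mirrors the per-rank inputs built by the test workers above; names like tp_degree and num_elements are copied from the tests, and the emulation is illustrative only, not the code path the tests exercise.

```python
# CPU-only sketch of the expected collective results, assuming tp_degree
# ranks each hold torch.arange(num_elements) * (rank + 1), as in the tests.
import torch

tp_degree = 2
num_elements = 8

# Per-rank inputs, as built in all_reduce_test_worker.
all_tensors = [
    torch.arange(num_elements, dtype=torch.float32) * (r + 1)
    for r in range(tp_degree)
]

# all-reduce(SUM): every rank should end up with the elementwise sum.
reduced = torch.sum(torch.stack(all_tensors, dim=0), dim=0)

# all-gather(dim=-1): every rank should end up with the concatenation
# along the last dimension (the only dim NeuronCommunicator supports).
gathered = torch.cat(all_tensors, dim=-1)

assert reduced.shape == (num_elements, )
assert gathered.shape == (num_elements * tp_degree, )
```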