42 changes: 42 additions & 0 deletions tests/neuron/test_activation.py
@@ -0,0 +1,42 @@
# SPDX-License-Identifier: Apache-2.0

import pytest
import torch
import torch.nn.functional as F

from vllm.model_executor.layers.activation import FastGELU, SiluAndMul
from vllm.platforms import current_platform


@pytest.mark.parametrize("activation", ["silu_and_mul", "gelu_fast"])
@pytest.mark.parametrize("num_tokens,d,dtype", [
(7, 512, torch.half),
(7, 512, torch.float),
(83, 512, torch.half),
])
@torch.inference_mode()
def test_act_and_mul(
activation: str,
num_tokens: int,
d: int,
dtype: torch.dtype,
) -> None:
import torch_xla.core.xla_model as xm

device = xm.xla_device()
current_platform.seed_everything(0)
torch.set_default_device("cpu")
x = torch.randn(num_tokens, 2 * d, dtype=dtype).to(device=device)
if activation == "silu_and_mul":
layer = SiluAndMul()
fn = layer.forward_native
elif activation == "gelu_fast":
layer = FastGELU()
fn = F.gelu
else:
raise NotImplementedError(
f"activation {activation} is not implemented.")
assert x.is_xla, "input tensor under testing is expected to be XLA tensor."
out = layer.to(device=device).forward_neuron(x)
ref_out = fn(x.cpu())
torch.testing.assert_close(out.cpu(), ref_out, atol=0.01, rtol=0.0)
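For context, a minimal sketch of the reference computation this test compares against: SiluAndMul splits the last dimension in half, applies SiLU to the first half, and gates it with the second half; FastGELU is checked against F.gelu, whose tanh approximation matches within the tolerance above. Shapes below are illustrative, not part of the test.

import torch
import torch.nn.functional as F

# Reference semantics of SiluAndMul: silu(x[..., :d]) * x[..., d:]
x = torch.randn(7, 2 * 512)
d = x.shape[-1] // 2
ref = F.silu(x[..., :d]) * x[..., d:]
assert ref.shape == (7, 512)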
56 changes: 56 additions & 0 deletions tests/neuron/test_layernorm.py
@@ -0,0 +1,56 @@
# SPDX-License-Identifier: Apache-2.0

import pytest
import torch

from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.platforms import current_platform


@pytest.mark.parametrize("num_tokens,hidden_size,add_residual,dtype", [
(7, 8, False, torch.half),
(83, 768, False, torch.half),
(83, 768, True, torch.half),
(83, 768, True, torch.bfloat16),
(83, 768, True, torch.float32),
])
@torch.inference_mode()
def test_rms_norm(
num_tokens: int,
hidden_size: int,
add_residual: bool,
dtype: torch.dtype,
) -> None:
import torch_xla.core.xla_model as xm

device = xm.xla_device()
current_platform.seed_everything(0)
torch.set_default_device("cpu")
layer = RMSNorm(hidden_size).to(dtype=dtype)
layer.weight.data.normal_(mean=1.0, std=0.1)
scale = 1 / (2 * hidden_size)
x = torch.randn(num_tokens, hidden_size, dtype=dtype).to(device=device)
x *= scale
residual = torch.randn_like(x) * scale if add_residual else None

residual_cpu = residual.cpu() if add_residual else None
ref_out = layer.to(device="cpu").forward_native(x.cpu(), residual_cpu)
assert x.is_xla, "input tensor under testing is expected to be XLA tensor."
out = layer.to(device=device)(x, residual)

# NOTE(woosuk): LayerNorm operators (including RMS) typically have larger
# numerical errors than other operators because they involve reductions.
# Therefore, we use a larger tolerance.
if add_residual:
assert out[0].is_xla, "output tensor is expected to be XLA tensor"
torch.testing.assert_close(out[0].cpu(),
ref_out[0],
atol=1e-2,
rtol=1e-2)
torch.testing.assert_close(out[1].cpu(),
ref_out[1],
atol=1e-2,
rtol=1e-2)
else:
assert out.is_xla, "output tensor is expected to be XLA tensor"
torch.testing.assert_close(out.cpu(), ref_out, atol=1e-2, rtol=1e-2)
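As a reminder of what the reference path computes, here is a simplified sketch of RMSNorm; the actual forward_native additionally fuses the residual add and handles dtype upcasting details. The helper name is illustrative.

import torch

def rms_norm_ref(x: torch.Tensor, weight: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
    # Normalize by the root mean square over the hidden dimension, then scale.
    variance = x.float().pow(2).mean(dim=-1, keepdim=True)
    return (x.float() * torch.rsqrt(variance + eps)).to(x.dtype) * weight

x = torch.randn(83, 768, dtype=torch.half)
out = rms_norm_ref(x, torch.ones(768, dtype=torch.half))
assert out.shape == x.shape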
95 changes: 95 additions & 0 deletions tests/neuron/test_logits_processor.py
@@ -0,0 +1,95 @@
# SPDX-License-Identifier: Apache-2.0

import random
from typing import Tuple
from unittest.mock import patch

import pytest
import torch

from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.model_executor.utils import set_random_seed
from vllm.sequence import SamplingParams, SequenceData, SequenceGroupMetadata
from vllm.utils import is_pin_memory_available


class MockLogitsProcessor(LogitsProcessor):

def __init__(self, vocab_size: int, scale: float,
fake_logits: torch.Tensor):
super().__init__(vocab_size=vocab_size, scale=scale)
self.fake_logits = fake_logits.clone()

def forward(self, *args, **kwargs):
with patch(
"vllm.model_executor.layers.logits_processor._prune_hidden_states",
lambda x, y: x
), patch(
"vllm.model_executor.layers.logits_processor.LogitsProcessor._get_logits",
lambda *args, **kwargs: self.fake_logits):
return super().forward(*args, **kwargs)


def _prepare_test(
batch_size: int
) -> Tuple[torch.Tensor, torch.Tensor, MockLogitsProcessor]:
vocab_size = 32000
input_tensor = torch.rand((batch_size, 1024), dtype=torch.float16)
fake_logits = torch.full((batch_size, vocab_size),
1e-2,
dtype=input_tensor.dtype)
logits_processor = MockLogitsProcessor(32000, 0.5, fake_logits)
return input_tensor, fake_logits, logits_processor


RANDOM_SEEDS = list(range(8))


@pytest.mark.parametrize("seed", RANDOM_SEEDS)
def test_logits_processors(seed: int):
import torch_xla.core.xla_model as xm

device = xm.xla_device()
set_random_seed(seed)
torch.set_default_device("cpu")
batch_size = random.randint(1, 256)
input_tensor, fake_logits, logits_processor = _prepare_test(batch_size)

# This sample logits processor gives an infinite score to the i-th token,
# where i is the length of the input sequence, so only that column is
# modified; the assertion below checks a column it leaves untouched.
def pick_ith(token_ids, logits):
logits[len(token_ids)] = float("inf")
return logits

seq_group_metadata_list = []
seq_lens = []
for i in range(batch_size):
seq_group_metadata_list.append(
SequenceGroupMetadata(
request_id=f"test_{i}",
is_prompt=True,
seq_data={0: SequenceData.from_seqs([1, 2, 3])},
sampling_params=SamplingParams(temperature=0,
logits_processors=[pick_ith]),
block_tables={0: [1]},
))
seq_lens.append(seq_group_metadata_list[-1].seq_data[0].get_len())

sampling_metadata = SamplingMetadata.prepare(
seq_group_metadata_list,
seq_lens,
query_lens=seq_lens,
device=device,
pin_memory=is_pin_memory_available())
logits_processor_output = logits_processor(
lm_head=None,
hidden_states=input_tensor,
sampling_metadata=sampling_metadata)

fake_logits *= logits_processor.scale
torch.testing.assert_close(logits_processor_output[:, 1],
fake_logits[:, 1],
rtol=1e-4,
atol=0.0)
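For context, per-request logits processors are plain callables applied to each sequence's logits row; pick_ith above is one such callable. A standalone sketch of its effect (tensor size is illustrative):

import torch

def pick_ith(token_ids, logits):
    # Give an infinite score to the token whose index equals the prompt length.
    logits[len(token_ids)] = float("inf")
    return logits

row = torch.zeros(10)
print(pick_ith([1, 2, 3], row))  # only index 3 becomes inf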
7 changes: 4 additions & 3 deletions tests/neuron/test_prefix_prefill.py
@@ -345,6 +345,7 @@ def test_contexted_kv_attention(

torch.manual_seed(0)
torch.set_printoptions(sci_mode=False)
torch.set_default_device("cpu")
dtype = torch.float32

min_ctx_len = 32
@@ -438,9 +439,9 @@ def pad_to_next_power_of_2(a):

# transform block table
active_block_table = get_active_block_tables(
block_table,
torch.tensor(query_lens),
torch.tensor(seq_lens),
block_table.cpu(),
torch.tensor(query_lens).cpu(),
torch.tensor(seq_lens).cpu(),
block_size,
num_active_blocks,
)
58 changes: 58 additions & 0 deletions tests/neuron/test_rotary_embedding.py
@@ -0,0 +1,58 @@
# SPDX-License-Identifier: Apache-2.0
"""
Tests for the rotary embedding Neuron (XLA) forward pass
"""

import pytest
import torch

from vllm.model_executor.layers.rotary_embedding import RotaryEmbedding
from vllm.platforms import current_platform


@pytest.mark.parametrize(
"max_position,is_neox_style,rotary_dim,head_size,seq_len", [
(16, False, 32, 32, 1024),
(16, False, 32, 128, 1024),
(16, True, 32, 32, 1024),
(16, True, 32, 128, 1024),
])
def test_rotary_embedding_opcheck(max_position, is_neox_style, rotary_dim,
head_size, seq_len):
import torch_xla.core.xla_model as xm

device = xm.xla_device()
current_platform.seed_everything(0)
torch.set_default_device("cpu")

batch_size = 1
base = 10000
num_heads = 8

rot = RotaryEmbedding(head_size, rotary_dim, max_position, base,
is_neox_style, torch.float32)

positions = torch.randint(0,
max_position, (batch_size, seq_len),
device="cpu")
query = torch.randn(batch_size,
seq_len,
num_heads * head_size,
dtype=torch.float32,
device="cpu")
key = torch.randn_like(query)

assert positions.is_cpu, \
"reference input tensor is expected to be CPU tensor."
ref_query, ref_key = rot.to(device="cpu").forward_native(
positions, query, key)
out_query, out_key = rot.to(device=device).forward_neuron(
positions.to(device=device), query.to(device=device),
key.to(device=device))
assert out_query.is_xla and out_key.is_xla, \
"output tensor is expected to be XLA tensor"
torch.testing.assert_close(out_query.cpu(),
ref_query,
atol=1e-2,
rtol=1e-2)
torch.testing.assert_close(out_key.cpu(), ref_key, atol=1e-2, rtol=1e-2)
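For readers unfamiliar with the op, a textbook sketch of the neox-style rotary application that the reference path implements; this is the standard formulation, not vLLM's exact kernel, and the non-neox (interleaved) variant pairs dimensions differently.

import torch

def rotate_half(x: torch.Tensor) -> torch.Tensor:
    x1, x2 = x.chunk(2, dim=-1)
    return torch.cat((-x2, x1), dim=-1)

head_size, rotary_dim, base, seq_len = 32, 32, 10000, 4
inv_freq = 1.0 / (base ** (torch.arange(0, rotary_dim, 2).float() / rotary_dim))
freqs = torch.outer(torch.arange(seq_len).float(), inv_freq)
emb = torch.cat((freqs, freqs), dim=-1)   # [seq_len, rotary_dim]
cos, sin = emb.cos(), emb.sin()

q = torch.randn(seq_len, head_size)
q_rot = q * cos + rotate_half(q) * sin    # rotary-embedded query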
7 changes: 7 additions & 0 deletions vllm/model_executor/custom_op.py
@@ -59,6 +59,11 @@ def forward_hpu(self, *args, **kwargs):
# PyTorch-native implementation.
return self.forward_native(*args, **kwargs)

def forward_neuron(self, *args, **kwargs):
# By default, we assume that Neuron ops are compatible with the
# PyTorch-native implementation.
return self.forward_native(*args, **kwargs)

def forward_oot(self, *args, **kwargs):
# By default, we assume that OOT ops are compatible with the
# PyTorch-native implementation.
@@ -88,6 +93,8 @@ def dispatch_forward(self):
return self.forward_tpu
elif current_platform.is_xpu():
return self.forward_xpu
elif current_platform.is_neuron():
return self.forward_neuron
elif current_platform.is_out_of_tree():
return self.forward_oot
else:
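The dispatch pattern this change extends can be summarized with a small illustrative class (not the vLLM CustomOp itself): a subclass only overrides forward_neuron when it has a Neuron-specific path, and the default falls back to the PyTorch-native implementation.

import torch
import torch.nn.functional as F

class TinyOp:
    def forward_native(self, x: torch.Tensor) -> torch.Tensor:
        return F.relu(x)

    def forward_neuron(self, x: torch.Tensor) -> torch.Tensor:
        # Default mirrors the change above: reuse the PyTorch-native path.
        return self.forward_native(x)

    def dispatch_forward(self, is_neuron: bool):
        # Stand-in for the current_platform.is_neuron() check in the real dispatcher.
        return self.forward_neuron if is_neuron else self.forward_native

op = TinyOp()
print(op.dispatch_forward(is_neuron=True)(torch.tensor([-1.0, 2.0])))  # tensor([0., 2.])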
7 changes: 7 additions & 0 deletions vllm/model_executor/layers/activation.py
@@ -89,6 +89,13 @@ def forward_xpu(self, x: torch.Tensor) -> torch.Tensor:
self.op(out, x)
return out

def forward_neuron(self, x: torch.Tensor) -> torch.Tensor:
d = x.shape[-1] // 2
x_reshaped = x.view(-1, x.shape[-1])
s = x_reshaped[:, :d] * F.sigmoid(x_reshaped[:, :d])
result = s * x_reshaped[:, d:]
return result.view(*x.shape[:-1], d)


@CustomOp.register("mul_and_silu")
class MulAndSilu(CustomOp):
1 change: 1 addition & 0 deletions vllm/model_executor/layers/logits_processor.py
@@ -53,6 +53,7 @@ def __init__(self,
# Whether to use gather or all-gather to gather the logits.
parallel_config = get_current_vllm_config().parallel_config
self.use_all_gather = current_platform.is_tpu() \
or current_platform.is_neuron() \
or envs.VLLM_USE_V1 \
or parallel_config.distributed_executor_backend == "external_launcher" # noqa
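For context on the use_all_gather switch: with tensor parallelism the vocab dimension of the logits is sharded across ranks, and platforms such as TPU and Neuron collect the full vocabulary on every rank (all-gather) rather than on a single rank (gather). An in-process illustration, with no torch.distributed involved:

import torch

tp_size, vocab_per_rank, batch = 2, 4, 3
shards = [torch.randn(batch, vocab_per_rank) for _ in range(tp_size)]
# After an all-gather along the vocab dimension, every rank holds the full logits.
full_logits = torch.cat(shards, dim=-1)
assert full_logits.shape == (batch, tp_size * vocab_per_rank)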
