From c4ed25d065d1aa6d6529f7b81cc4423abb75de7d Mon Sep 17 00:00:00 2001 From: Liangfu Chen Date: Tue, 25 Feb 2025 00:56:11 +0000 Subject: [PATCH] [Neuron] add custom_ops for neuron backend Co-authored-by: George Novack Co-authored-by: Aoyu Zhang Signed-off-by: Liangfu Chen --- tests/neuron/test_activation.py | 42 ++++++++ tests/neuron/test_layernorm.py | 56 +++++++++++ tests/neuron/test_logits_processor.py | 95 +++++++++++++++++++ tests/neuron/test_prefix_prefill.py | 7 +- tests/neuron/test_rotary_embedding.py | 58 +++++++++++ vllm/model_executor/custom_op.py | 7 ++ vllm/model_executor/layers/activation.py | 7 ++ .../model_executor/layers/logits_processor.py | 1 + .../model_executor/layers/rotary_embedding.py | 76 +++++++++++++++ 9 files changed, 346 insertions(+), 3 deletions(-) create mode 100644 tests/neuron/test_activation.py create mode 100644 tests/neuron/test_layernorm.py create mode 100644 tests/neuron/test_logits_processor.py create mode 100644 tests/neuron/test_rotary_embedding.py diff --git a/tests/neuron/test_activation.py b/tests/neuron/test_activation.py new file mode 100644 index 000000000000..ec2b1238e404 --- /dev/null +++ b/tests/neuron/test_activation.py @@ -0,0 +1,42 @@ +# SPDX-License-Identifier: Apache-2.0 + +import pytest +import torch +import torch.nn.functional as F + +from vllm.model_executor.layers.activation import FastGELU, SiluAndMul +from vllm.platforms import current_platform + + +@pytest.mark.parametrize("activation", ["silu_and_mul", "gelu_fast"]) +@pytest.mark.parametrize("num_tokens,d,dtype", [ + (7, 512, torch.half), + (7, 512, torch.float), + (83, 512, torch.half), +]) +@torch.inference_mode() +def test_act_and_mul( + activation: str, + num_tokens: int, + d: int, + dtype: torch.dtype, +) -> None: + import torch_xla.core.xla_model as xm + + device = xm.xla_device() + current_platform.seed_everything(0) + torch.set_default_device("cpu") + x = torch.randn(num_tokens, 2 * d, dtype=dtype).to(device=device) + if activation == "silu_and_mul": + layer = SiluAndMul() + fn = layer.forward_native + elif activation == "gelu_fast": + layer = FastGELU() + fn = F.gelu + else: + raise NotImplementedError( + f"activation {activation} is not implemented.") + assert x.is_xla, "input tensor under testing is expected to be XLA tensor." 
+ out = layer.to(device=device).forward_neuron(x) + ref_out = fn(x.cpu()) + torch.testing.assert_close(out.cpu(), ref_out, atol=0.01, rtol=0.0) diff --git a/tests/neuron/test_layernorm.py b/tests/neuron/test_layernorm.py new file mode 100644 index 000000000000..e96df8db6ccd --- /dev/null +++ b/tests/neuron/test_layernorm.py @@ -0,0 +1,56 @@ +# SPDX-License-Identifier: Apache-2.0 + +import pytest +import torch + +from vllm.model_executor.layers.layernorm import RMSNorm +from vllm.platforms import current_platform + + +@pytest.mark.parametrize("num_tokens,hidden_size,add_residual,dtype", [ + (7, 8, False, torch.half), + (83, 768, False, torch.half), + (83, 768, True, torch.half), + (83, 768, True, torch.bfloat16), + (83, 768, True, torch.float32), +]) +@torch.inference_mode() +def test_rms_norm( + num_tokens: int, + hidden_size: int, + add_residual: bool, + dtype: torch.dtype, +) -> None: + import torch_xla.core.xla_model as xm + + device = xm.xla_device() + current_platform.seed_everything(0) + torch.set_default_device("cpu") + layer = RMSNorm(hidden_size).to(dtype=dtype) + layer.weight.data.normal_(mean=1.0, std=0.1) + scale = 1 / (2 * hidden_size) + x = torch.randn(num_tokens, hidden_size, dtype=dtype).to(device=device) + x *= scale + residual = torch.randn_like(x) * scale if add_residual else None + + residual_cpu = residual.cpu() if add_residual else None + ref_out = layer.to(device="cpu").forward_native(x.cpu(), residual_cpu) + assert x.is_xla, "input tensor under testing is expected to be XLA tensor." + out = layer.to(device=device)(x, residual) + + # NOTE(woosuk): LayerNorm operators (including RMS) typically have larger + # numerical errors than other operators because they involve reductions. + # Therefore, we use a larger tolerance. + if add_residual: + assert out[0].is_xla, "output tensor is expected to be XLA tensor" + torch.testing.assert_close(out[0].cpu(), + ref_out[0], + atol=1e-2, + rtol=1e-2) + torch.testing.assert_close(out[1].cpu(), + ref_out[1], + atol=1e-2, + rtol=1e-2) + else: + assert out.is_xla, "output tensor is expected to be XLA tensor" + torch.testing.assert_close(out.cpu(), ref_out, atol=1e-2, rtol=1e-2) diff --git a/tests/neuron/test_logits_processor.py b/tests/neuron/test_logits_processor.py new file mode 100644 index 000000000000..37d59c9e76a7 --- /dev/null +++ b/tests/neuron/test_logits_processor.py @@ -0,0 +1,95 @@ +# SPDX-License-Identifier: Apache-2.0 + +import random +from typing import Tuple +from unittest.mock import patch + +import pytest +import torch + +from vllm.model_executor.layers.logits_processor import LogitsProcessor +from vllm.model_executor.sampling_metadata import SamplingMetadata +from vllm.model_executor.utils import set_random_seed +from vllm.sequence import SamplingParams, SequenceData, SequenceGroupMetadata +from vllm.utils import is_pin_memory_available + + +class MockLogitsProcessor(LogitsProcessor): + + def __init__(self, vocab_size: int, scale: float, + fake_logits: torch.Tensor): + super().__init__(vocab_size=vocab_size, scale=scale) + self.fake_logits = fake_logits.clone() + + def forward(self, *args, **kwargs): + with patch( + "vllm.model_executor.layers.logits_processor._prune_hidden_states", + lambda x, y: x + ), patch( + "vllm.model_executor.layers.logits_processor.LogitsProcessor._get_logits", + lambda *args, **kwargs: self.fake_logits): + return super().forward(*args, **kwargs) + + +def _prepare_test( + batch_size: int +) -> Tuple[torch.Tensor, torch.Tensor, MockLogitsProcessor]: + vocab_size = 32000 + input_tensor = 
torch.rand((batch_size, 1024), dtype=torch.float16) + fake_logits = torch.full((batch_size, vocab_size), + 1e-2, + dtype=input_tensor.dtype) + logits_processor = MockLogitsProcessor(32000, 0.5, fake_logits) + return input_tensor, fake_logits, logits_processor + + +RANDOM_SEEDS = list(range(8)) + + +@pytest.mark.parametrize("seed", RANDOM_SEEDS) +def test_logits_processors(seed: int): + import torch_xla.core.xla_model as xm + + device = xm.xla_device() + set_random_seed(seed) + torch.set_default_device("cpu") + batch_size = random.randint(1, 256) + input_tensor, fake_logits, logits_processor = _prepare_test(batch_size) + + # This sample logits processor gives infinite score to the i-th token, + # where i is the length of the input sequence. + # We therefore expect the output token sequence to be [0, 1, 2, ...] + def pick_ith(token_ids, logits): + logits[len(token_ids)] = float("inf") + return logits + + seq_group_metadata_list = [] + seq_lens = [] + for i in range(batch_size): + seq_group_metadata_list.append( + SequenceGroupMetadata( + request_id=f"test_{i}", + is_prompt=True, + seq_data={0: SequenceData.from_seqs([1, 2, 3])}, + sampling_params=SamplingParams(temperature=0, + logits_processors=[pick_ith]), + block_tables={0: [1]}, + )) + seq_lens.append(seq_group_metadata_list[-1].seq_data[0].get_len()) + + sampling_metadata = SamplingMetadata.prepare( + seq_group_metadata_list, + seq_lens, + query_lens=seq_lens, + device=device, + pin_memory=is_pin_memory_available()) + logits_processor_output = logits_processor( + lm_head=None, + hidden_states=input_tensor, + sampling_metadata=sampling_metadata) + + fake_logits *= logits_processor.scale + torch.testing.assert_close(logits_processor_output[:, 1], + fake_logits[:, 1], + rtol=1e-4, + atol=0.0) diff --git a/tests/neuron/test_prefix_prefill.py b/tests/neuron/test_prefix_prefill.py index 347a139f39b4..2c6ac47888d5 100644 --- a/tests/neuron/test_prefix_prefill.py +++ b/tests/neuron/test_prefix_prefill.py @@ -345,6 +345,7 @@ def test_contexted_kv_attention( torch.manual_seed(0) torch.set_printoptions(sci_mode=False) + torch.set_default_device("cpu") dtype = torch.float32 min_ctx_len = 32 @@ -438,9 +439,9 @@ def pad_to_next_power_of_2(a): # transform block table active_block_table = get_active_block_tables( - block_table, - torch.tensor(query_lens), - torch.tensor(seq_lens), + block_table.cpu(), + torch.tensor(query_lens).cpu(), + torch.tensor(seq_lens).cpu(), block_size, num_active_blocks, ) diff --git a/tests/neuron/test_rotary_embedding.py b/tests/neuron/test_rotary_embedding.py new file mode 100644 index 000000000000..c015b80bd472 --- /dev/null +++ b/tests/neuron/test_rotary_embedding.py @@ -0,0 +1,58 @@ +# SPDX-License-Identifier: Apache-2.0 +""" +Tests for miscellaneous utilities +""" + +import pytest +import torch + +from vllm.model_executor.layers.rotary_embedding import RotaryEmbedding +from vllm.platforms import current_platform + + +@pytest.mark.parametrize( + "max_position,is_neox_style,rotary_dim,head_size,seq_len", [ + (16, False, 32, 32, 1024), + (16, False, 32, 128, 1024), + (16, True, 32, 32, 1024), + (16, True, 32, 128, 1024), + ]) +def test_rotary_embedding_opcheck(max_position, is_neox_style, rotary_dim, + head_size, seq_len): + import torch_xla.core.xla_model as xm + + device = xm.xla_device() + current_platform.seed_everything(0) + torch.set_default_device("cpu") + + batch_size = 1 + base = 10000 + num_heads = 8 + + rot = RotaryEmbedding(head_size, rotary_dim, max_position, base, + is_neox_style, torch.float32) + + positions 
= torch.randint(0, + max_position, (batch_size, seq_len), + device="cpu") + query = torch.randn(batch_size, + seq_len, + num_heads * head_size, + dtype=torch.float32, + device="cpu") + key = torch.randn_like(query) + + assert positions.is_cpu, \ + "reference input tensor is expected to be CPU tensor." + ref_query, ref_key = rot.to(device="cpu").forward_native( + positions, query, key) + out_query, out_key = rot.to(device=device).forward_neuron( + positions.to(device=device), query.to(device=device), + key.to(device=device)) + assert out_query.is_xla and out_key.is_xla, \ + "output tensor is expected to be XLA tensor" + torch.testing.assert_close(out_query.cpu(), + ref_query, + atol=1e-2, + rtol=1e-2) + torch.testing.assert_close(out_key.cpu(), ref_key, atol=1e-2, rtol=1e-2) diff --git a/vllm/model_executor/custom_op.py b/vllm/model_executor/custom_op.py index ee4f41ea6ec9..dfd052f62521 100644 --- a/vllm/model_executor/custom_op.py +++ b/vllm/model_executor/custom_op.py @@ -59,6 +59,11 @@ def forward_hpu(self, *args, **kwargs): # PyTorch-native implementation. return self.forward_native(*args, **kwargs) + def forward_neuron(self, *args, **kwargs): + # By default, we assume that Neuron ops are compatible with the + # PyTorch-native implementation. + return self.forward_native(*args, **kwargs) + def forward_oot(self, *args, **kwargs): # By default, we assume that OOT ops are compatible with the # PyTorch-native implementation. @@ -88,6 +93,8 @@ def dispatch_forward(self): return self.forward_tpu elif current_platform.is_xpu(): return self.forward_xpu + elif current_platform.is_neuron(): + return self.forward_neuron elif current_platform.is_out_of_tree(): return self.forward_oot else: diff --git a/vllm/model_executor/layers/activation.py b/vllm/model_executor/layers/activation.py index f782920d06a0..1de0f499c1a6 100644 --- a/vllm/model_executor/layers/activation.py +++ b/vllm/model_executor/layers/activation.py @@ -89,6 +89,13 @@ def forward_xpu(self, x: torch.Tensor) -> torch.Tensor: self.op(out, x) return out + def forward_neuron(self, x: torch.Tensor) -> torch.Tensor: + d = x.shape[-1] // 2 + x_reshaped = x.view(-1, x.shape[-1]) + s = x_reshaped[:, :d] * F.sigmoid(x_reshaped[:, :d]) + result = s * x_reshaped[:, d:] + return result.view(*x.shape[:-1], d) + @CustomOp.register("mul_and_silu") class MulAndSilu(CustomOp): diff --git a/vllm/model_executor/layers/logits_processor.py b/vllm/model_executor/layers/logits_processor.py index 9b1742998578..2f39a0e87854 100644 --- a/vllm/model_executor/layers/logits_processor.py +++ b/vllm/model_executor/layers/logits_processor.py @@ -53,6 +53,7 @@ def __init__(self, # Whether to use gather or all-gather to gather the logits. 
parallel_config = get_current_vllm_config().parallel_config self.use_all_gather = current_platform.is_tpu() \ + or current_platform.is_neuron() \ or envs.VLLM_USE_V1 \ or parallel_config.distributed_executor_backend == "external_launcher" # noqa diff --git a/vllm/model_executor/layers/rotary_embedding.py b/vllm/model_executor/layers/rotary_embedding.py index ce1bc98ea426..64c2dac524f2 100644 --- a/vllm/model_executor/layers/rotary_embedding.py +++ b/vllm/model_executor/layers/rotary_embedding.py @@ -254,6 +254,82 @@ def forward_hpu( key = torch.cat((key_rot, key_pass), dim=-1).reshape(key_shape) return query, key + def forward_neuron( + self, + positions: torch.Tensor, + query: torch.Tensor, + key: torch.Tensor, + offsets: Optional[torch.Tensor] = None, + ) -> Tuple[torch.Tensor, torch.Tensor]: + + def _apply_rotary_emb_neuron( + x: torch.Tensor, + cos: torch.Tensor, + sin: torch.Tensor, + is_neox_style: bool, + ) -> torch.Tensor: + cos = cos.unsqueeze(-2).to(x.dtype) + sin = sin.unsqueeze(-2).to(x.dtype) + if is_neox_style: + x1, x2 = torch.chunk(x, 2, dim=-1) + else: + # x1 = x[..., ::2] + + # x2 = x[..., 1::2] + d = x.shape[-1] // 2 + x_reshaped = x.view(-1, x.shape[-1]) + x1 = x_reshaped[:, ::2].view(*x.shape[:-1], d) + x2 = x_reshaped[:, 1::2].view(*x.shape[:-1], d) + o1 = x1 * cos - x2 * sin + o2 = x2 * cos + x1 * sin + if is_neox_style: + return torch.cat((o1, o2), dim=-1) + else: + return torch.stack((o1, o2), dim=-1).flatten(-2) + + if offsets is not None: + positions = positions + offsets + + self.cos_sin_cache = self.cos_sin_cache.to(query.device, + dtype=query.dtype) + + positions = positions.flatten() + num_tokens = positions.shape[0] + cos_sin = self.cos_sin_cache.index_select(0, positions) + cos, sin = cos_sin.chunk(2, dim=-1) + + query_shape = query.shape + query = query.view(num_tokens, -1, self.head_size) + key_shape = key.shape + key = key.view(num_tokens, -1, self.head_size) + + if self.rotary_dim == self.head_size: + query = _apply_rotary_emb(query, cos, sin, self.is_neox_style) + query = query.reshape(query_shape) + key = _apply_rotary_emb(key, cos, sin, self.is_neox_style) + key = key.reshape(key_shape) + else: + head_size = query.shape[-1] + query_reshaped = query.view(-1, head_size) + query_pass = query_reshaped[:, self.rotary_dim:].view( + *query.shape[:-1], head_size - self.rotary_dim) + query_rot = query_reshaped[:, :self.rotary_dim].view( + *query.shape[:-1], self.rotary_dim) + query_rot = _apply_rotary_emb_neuron(query_rot, cos, sin, + self.is_neox_style) + query = torch.cat((query_rot, query_pass), + dim=-1).reshape(query_shape) + + key_reshaped = key.view(-1, head_size) + key_pass = key_reshaped[:, self.rotary_dim:].view( + *key.shape[:-1], head_size - self.rotary_dim) + key_rot = key_reshaped[:, :self.rotary_dim].view( + *key.shape[:-1], self.rotary_dim) + key_rot = _apply_rotary_emb_neuron(key_rot, cos, sin, + self.is_neox_style) + key = torch.cat((key_rot, key_pass), dim=-1).reshape(key_shape) + return query, key + def extra_repr(self) -> str: s = f"head_size={self.head_size}, rotary_dim={self.rotary_dim}" s += f", max_position_embeddings={self.max_position_embeddings}"
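
Note (illustrative, not part of the patch): the new forward_neuron overrides in SiluAndMul and RotaryEmbedding avoid multi-dimensional strided slicing such as x[..., ::2] and instead flatten to 2-D, slice columns, and restore the leading shape afterwards, presumably because that pattern lowers more cleanly through torch-xla on Neuron. The snippet below is a minimal CPU-only sketch showing that the reshaped formulations used in the patch are numerically equivalent to the naive slicing they replace; it uses only public torch APIs and is not vLLM code.

    import torch
    import torch.nn.functional as F

    torch.manual_seed(0)

    # silu_and_mul written the way SiluAndMul.forward_neuron writes it:
    # flatten to 2-D, slice columns, then restore the original leading shape.
    x = torch.randn(7, 2 * 512)
    d = x.shape[-1] // 2
    x2 = x.view(-1, x.shape[-1])
    gate = x2[:, :d] * torch.sigmoid(x2[:, :d])
    neuron_style = (gate * x2[:, d:]).view(*x.shape[:-1], d)
    reference = F.silu(x[..., :d]) * x[..., d:]
    torch.testing.assert_close(neuron_style, reference)

    # Interleaved (non-neox) rotary split written the way
    # _apply_rotary_emb_neuron writes it: the flattened column slice matches
    # the ellipsis slice it replaces.
    q = torch.randn(7, 8, 64)
    half = q.shape[-1] // 2
    q2 = q.view(-1, q.shape[-1])
    q1_neuron = q2[:, ::2].view(*q.shape[:-1], half)
    torch.testing.assert_close(q1_neuron, q[..., ::2])

Both assertions pass on stock PyTorch, so the Neuron-specific code paths change only how the tensors are indexed, not the math they compute.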