-
-
Notifications
You must be signed in to change notification settings - Fork 11.7k
[Neuron] Add custom_ops for neuron backend #13246
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
Merged
Changes from all commits
Commits
File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,42 @@ | ||
| # SPDX-License-Identifier: Apache-2.0 | ||
|
|
||
| import pytest | ||
| import torch | ||
| import torch.nn.functional as F | ||
|
|
||
| from vllm.model_executor.layers.activation import FastGELU, SiluAndMul | ||
| from vllm.platforms import current_platform | ||
|
|
||
|
|
||
| @pytest.mark.parametrize("activation", ["silu_and_mul", "gelu_fast"]) | ||
| @pytest.mark.parametrize("num_tokens,d,dtype", [ | ||
| (7, 512, torch.half), | ||
| (7, 512, torch.float), | ||
| (83, 512, torch.half), | ||
| ]) | ||
| @torch.inference_mode() | ||
| def test_act_and_mul( | ||
| activation: str, | ||
| num_tokens: int, | ||
| d: int, | ||
| dtype: torch.dtype, | ||
| ) -> None: | ||
| import torch_xla.core.xla_model as xm | ||
|
|
||
| device = xm.xla_device() | ||
| current_platform.seed_everything(0) | ||
| torch.set_default_device("cpu") | ||
| x = torch.randn(num_tokens, 2 * d, dtype=dtype).to(device=device) | ||
| if activation == "silu_and_mul": | ||
| layer = SiluAndMul() | ||
| fn = layer.forward_native | ||
| elif activation == "gelu_fast": | ||
| layer = FastGELU() | ||
| fn = F.gelu | ||
| else: | ||
| raise NotImplementedError( | ||
| f"activation {activation} is not implemented.") | ||
| assert x.is_xla, "input tensor under testing is expected to be XLA tensor." | ||
| out = layer.to(device=device).forward_neuron(x) | ||
| ref_out = fn(x.cpu()) | ||
| torch.testing.assert_close(out.cpu(), ref_out, atol=0.01, rtol=0.0) |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,56 @@ | ||
| # SPDX-License-Identifier: Apache-2.0 | ||
|
|
||
| import pytest | ||
| import torch | ||
|
|
||
| from vllm.model_executor.layers.layernorm import RMSNorm | ||
| from vllm.platforms import current_platform | ||
|
|
||
|
|
||
| @pytest.mark.parametrize("num_tokens,hidden_size,add_residual,dtype", [ | ||
| (7, 8, False, torch.half), | ||
| (83, 768, False, torch.half), | ||
| (83, 768, True, torch.half), | ||
| (83, 768, True, torch.bfloat16), | ||
| (83, 768, True, torch.float32), | ||
| ]) | ||
| @torch.inference_mode() | ||
| def test_rms_norm( | ||
| num_tokens: int, | ||
| hidden_size: int, | ||
| add_residual: bool, | ||
| dtype: torch.dtype, | ||
| ) -> None: | ||
| import torch_xla.core.xla_model as xm | ||
|
|
||
| device = xm.xla_device() | ||
| current_platform.seed_everything(0) | ||
| torch.set_default_device("cpu") | ||
| layer = RMSNorm(hidden_size).to(dtype=dtype) | ||
| layer.weight.data.normal_(mean=1.0, std=0.1) | ||
| scale = 1 / (2 * hidden_size) | ||
| x = torch.randn(num_tokens, hidden_size, dtype=dtype).to(device=device) | ||
| x *= scale | ||
| residual = torch.randn_like(x) * scale if add_residual else None | ||
|
|
||
| residual_cpu = residual.cpu() if add_residual else None | ||
| ref_out = layer.to(device="cpu").forward_native(x.cpu(), residual_cpu) | ||
| assert x.is_xla, "input tensor under testing is expected to be XLA tensor." | ||
| out = layer.to(device=device)(x, residual) | ||
|
|
||
| # NOTE(woosuk): LayerNorm operators (including RMS) typically have larger | ||
| # numerical errors than other operators because they involve reductions. | ||
| # Therefore, we use a larger tolerance. | ||
| if add_residual: | ||
| assert out[0].is_xla, "output tensor is expected to be XLA tensor" | ||
| torch.testing.assert_close(out[0].cpu(), | ||
| ref_out[0], | ||
| atol=1e-2, | ||
| rtol=1e-2) | ||
| torch.testing.assert_close(out[1].cpu(), | ||
| ref_out[1], | ||
| atol=1e-2, | ||
| rtol=1e-2) | ||
| else: | ||
| assert out.is_xla, "output tensor is expected to be XLA tensor" | ||
| torch.testing.assert_close(out.cpu(), ref_out, atol=1e-2, rtol=1e-2) |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,95 @@ | ||
| # SPDX-License-Identifier: Apache-2.0 | ||
|
|
||
| import random | ||
| from typing import Tuple | ||
| from unittest.mock import patch | ||
|
|
||
| import pytest | ||
| import torch | ||
|
|
||
| from vllm.model_executor.layers.logits_processor import LogitsProcessor | ||
| from vllm.model_executor.sampling_metadata import SamplingMetadata | ||
| from vllm.model_executor.utils import set_random_seed | ||
| from vllm.sequence import SamplingParams, SequenceData, SequenceGroupMetadata | ||
| from vllm.utils import is_pin_memory_available | ||
|
|
||
|
|
||
| class MockLogitsProcessor(LogitsProcessor): | ||
|
|
||
| def __init__(self, vocab_size: int, scale: float, | ||
| fake_logits: torch.Tensor): | ||
| super().__init__(vocab_size=vocab_size, scale=scale) | ||
| self.fake_logits = fake_logits.clone() | ||
|
|
||
| def forward(self, *args, **kwargs): | ||
| with patch( | ||
| "vllm.model_executor.layers.logits_processor._prune_hidden_states", | ||
| lambda x, y: x | ||
| ), patch( | ||
| "vllm.model_executor.layers.logits_processor.LogitsProcessor._get_logits", | ||
| lambda *args, **kwargs: self.fake_logits): | ||
| return super().forward(*args, **kwargs) | ||
|
|
||
|
|
||
| def _prepare_test( | ||
| batch_size: int | ||
| ) -> Tuple[torch.Tensor, torch.Tensor, MockLogitsProcessor]: | ||
| vocab_size = 32000 | ||
| input_tensor = torch.rand((batch_size, 1024), dtype=torch.float16) | ||
| fake_logits = torch.full((batch_size, vocab_size), | ||
| 1e-2, | ||
| dtype=input_tensor.dtype) | ||
| logits_processor = MockLogitsProcessor(32000, 0.5, fake_logits) | ||
| return input_tensor, fake_logits, logits_processor | ||
|
|
||
|
|
||
| RANDOM_SEEDS = list(range(8)) | ||
|
|
||
|
|
||
| @pytest.mark.parametrize("seed", RANDOM_SEEDS) | ||
| def test_logits_processors(seed: int): | ||
| import torch_xla.core.xla_model as xm | ||
|
|
||
| device = xm.xla_device() | ||
| set_random_seed(seed) | ||
| torch.set_default_device("cpu") | ||
| batch_size = random.randint(1, 256) | ||
| input_tensor, fake_logits, logits_processor = _prepare_test(batch_size) | ||
|
|
||
| # This sample logits processor gives infinite score to the i-th token, | ||
| # where i is the length of the input sequence. | ||
| # We therefore expect the output token sequence to be [0, 1, 2, ...] | ||
| def pick_ith(token_ids, logits): | ||
| logits[len(token_ids)] = float("inf") | ||
| return logits | ||
|
|
||
| seq_group_metadata_list = [] | ||
| seq_lens = [] | ||
| for i in range(batch_size): | ||
| seq_group_metadata_list.append( | ||
| SequenceGroupMetadata( | ||
| request_id=f"test_{i}", | ||
| is_prompt=True, | ||
| seq_data={0: SequenceData.from_seqs([1, 2, 3])}, | ||
| sampling_params=SamplingParams(temperature=0, | ||
| logits_processors=[pick_ith]), | ||
| block_tables={0: [1]}, | ||
| )) | ||
| seq_lens.append(seq_group_metadata_list[-1].seq_data[0].get_len()) | ||
|
|
||
| sampling_metadata = SamplingMetadata.prepare( | ||
| seq_group_metadata_list, | ||
| seq_lens, | ||
| query_lens=seq_lens, | ||
| device=device, | ||
| pin_memory=is_pin_memory_available()) | ||
| logits_processor_output = logits_processor( | ||
| lm_head=None, | ||
| hidden_states=input_tensor, | ||
| sampling_metadata=sampling_metadata) | ||
|
|
||
| fake_logits *= logits_processor.scale | ||
| torch.testing.assert_close(logits_processor_output[:, 1], | ||
| fake_logits[:, 1], | ||
| rtol=1e-4, | ||
| atol=0.0) |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,58 @@ | ||
| # SPDX-License-Identifier: Apache-2.0 | ||
| """ | ||
| Tests for miscellaneous utilities | ||
| """ | ||
|
|
||
| import pytest | ||
| import torch | ||
|
|
||
| from vllm.model_executor.layers.rotary_embedding import RotaryEmbedding | ||
| from vllm.platforms import current_platform | ||
|
|
||
|
|
||
| @pytest.mark.parametrize( | ||
| "max_position,is_neox_style,rotary_dim,head_size,seq_len", [ | ||
| (16, False, 32, 32, 1024), | ||
| (16, False, 32, 128, 1024), | ||
| (16, True, 32, 32, 1024), | ||
| (16, True, 32, 128, 1024), | ||
| ]) | ||
| def test_rotary_embedding_opcheck(max_position, is_neox_style, rotary_dim, | ||
| head_size, seq_len): | ||
| import torch_xla.core.xla_model as xm | ||
|
|
||
| device = xm.xla_device() | ||
| current_platform.seed_everything(0) | ||
| torch.set_default_device("cpu") | ||
|
|
||
| batch_size = 1 | ||
| base = 10000 | ||
| num_heads = 8 | ||
|
|
||
| rot = RotaryEmbedding(head_size, rotary_dim, max_position, base, | ||
| is_neox_style, torch.float32) | ||
|
|
||
| positions = torch.randint(0, | ||
| max_position, (batch_size, seq_len), | ||
| device="cpu") | ||
| query = torch.randn(batch_size, | ||
| seq_len, | ||
| num_heads * head_size, | ||
| dtype=torch.float32, | ||
| device="cpu") | ||
| key = torch.randn_like(query) | ||
|
|
||
| assert positions.is_cpu, \ | ||
| "reference input tensor is expected to be CPU tensor." | ||
| ref_query, ref_key = rot.to(device="cpu").forward_native( | ||
| positions, query, key) | ||
| out_query, out_key = rot.to(device=device).forward_neuron( | ||
| positions.to(device=device), query.to(device=device), | ||
| key.to(device=device)) | ||
| assert out_query.is_xla and out_key.is_xla, \ | ||
| "output tensor is expected to be XLA tensor" | ||
| torch.testing.assert_close(out_query.cpu(), | ||
| ref_query, | ||
| atol=1e-2, | ||
| rtol=1e-2) | ||
| torch.testing.assert_close(out_key.cpu(), ref_key, atol=1e-2, rtol=1e-2) | ||
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.