Merged
73 changes: 24 additions & 49 deletions tests/lora/test_punica_sizes.py
@@ -4,8 +4,6 @@
whether the corresponding Triton kernel can run normally when tensor parallelism
is set to [1, 2, 4, 8, 16, 32, 64].
"""
from unittest.mock import patch

import pytest
import torch

@@ -16,7 +14,6 @@
from vllm.lora.ops.sgmv_expand_slice import sgmv_expand_slice
from vllm.lora.ops.sgmv_shrink import sgmv_shrink
from vllm.platforms import current_platform
from vllm.triton_utils.libentry import LibEntry

from .utils import (generate_data, generate_data_for_expand_nslices,
ref_torch_groupgemm)
@@ -235,9 +232,6 @@ def test_punica_bgmv(
seed: int,
device: str,
):
from vllm.lora.ops.bgmv_expand import _bgmv_expand_kernel
from vllm.lora.ops.bgmv_shrink import _bgmv_shrink_kernel

torch.set_default_device(device)
current_platform.seed_everything(seed)

@@ -262,33 +256,21 @@
device,
)
if op_type == "shrink":
# The current _bgmv_shrink_kernel does not require the libentry
# decoration. The purpose of adding this patch is to test the
# correctness of libentry.
with patch(
"vllm.lora.ops.bgmv_shrink._bgmv_shrink_kernel",
LibEntry(_bgmv_shrink_kernel),
):
bgmv_shrink(
inputs_tensor,
lora_weights,
our_out_tensor,
indices,
scaling,
)
bgmv_shrink(
inputs_tensor,
lora_weights,
our_out_tensor,
indices,
scaling,
)
else:
# ditto
with patch(
"vllm.lora.ops.bgmv_expand._bgmv_expand_kernel",
LibEntry(_bgmv_expand_kernel),
):
bgmv_expand(
inputs_tensor,
lora_weights,
our_out_tensor,
indices,
add_inputs=True,
)
bgmv_expand(
inputs_tensor,
lora_weights,
our_out_tensor,
indices,
add_inputs=True,
)
ref_torch_groupgemm(
ref_out_tensor,
inputs_tensor,
@@ -324,7 +306,6 @@ def test_punica_expand_nslices(
seed: int,
device: str,
):
from vllm.lora.ops.bgmv_expand_slice import _bgmv_expand_slice_kernel

torch.set_default_device(device)
current_platform.seed_everything(seed)
@@ -374,22 +355,16 @@
add_inputs=True,
)
else:
# The current _bgmv_expand_slice_kernel does not require the
# libentry decoration. The purpose of adding this patch is to test
# the correctness of libentry.
with patch(
"vllm.lora.ops.bgmv_expand_slice._bgmv_expand_slice_kernel",
LibEntry(_bgmv_expand_slice_kernel),
):
bgmv_expand_slice(
inputs_tensor,
lora_weights,
our_outputs,
indices,
slice_offset,
slice_size=hidden_size,
add_inputs=True,
)

bgmv_expand_slice(
inputs_tensor,
lora_weights,
our_outputs,
indices,
slice_offset,
slice_size=hidden_size,
add_inputs=True,
)
ref_torch_groupgemm(
ref_outputs[:, slice_offset:slice_offset + hidden_size],
inputs_tensor,
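In short, the change in this test file drops the unittest.mock.patch + LibEntry wrapping and calls the ops directly. A condensed before/after sketch, assuming the tensors, indices, and scaling arguments have already been generated as in the test body above:

# Old pattern (removed): patch the Triton kernel symbol with a LibEntry wrapper
# so that the test also exercises the libentry machinery.
# with patch("vllm.lora.ops.bgmv_shrink._bgmv_shrink_kernel",
#            LibEntry(_bgmv_shrink_kernel)):
#     bgmv_shrink(inputs_tensor, lora_weights, our_out_tensor, indices, scaling)

# New pattern: call the op directly; no patching and no LibEntry involved.
bgmv_shrink(inputs_tensor, lora_weights, our_out_tensor, indices, scaling)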
73 changes: 24 additions & 49 deletions tests/lora/test_punica_variation.py
@@ -3,8 +3,6 @@
under different conditions, including various batches, numbers of LoRAs, and
maximum ranks.
"""
from unittest.mock import patch

import pytest
import torch

@@ -15,7 +13,6 @@
from vllm.lora.ops.sgmv_expand_slice import sgmv_expand_slice
from vllm.lora.ops.sgmv_shrink import sgmv_shrink
from vllm.platforms import current_platform
from vllm.triton_utils.libentry import LibEntry

from .utils import (generate_data, generate_data_for_expand_nslices,
ref_torch_groupgemm)
@@ -150,8 +147,6 @@ def test_punica_bgmv(
seed: int,
device: str,
):
from vllm.lora.ops.bgmv_expand import _bgmv_expand_kernel
from vllm.lora.ops.bgmv_shrink import _bgmv_shrink_kernel

torch.set_default_device(device)
current_platform.seed_everything(seed)
@@ -177,33 +172,22 @@
device,
)
if op_type == "shrink":
# The current _bgmv_shrink_kernel does not require the libentry
# decoration. The purpose of adding this patch is to test the
# correctness of libentry.
with patch(
"vllm.lora.ops.bgmv_shrink._bgmv_shrink_kernel",
LibEntry(_bgmv_shrink_kernel),
):
bgmv_shrink(
inputs_tensor,
lora_weights,
our_out_tensor,
indices,
scaling,
)
bgmv_shrink(
inputs_tensor,
lora_weights,
our_out_tensor,
indices,
scaling,
)
else:
# ditto
with patch(
"vllm.lora.ops.bgmv_expand._bgmv_expand_kernel",
LibEntry(_bgmv_expand_kernel),
):
bgmv_expand(
inputs_tensor,
lora_weights,
our_out_tensor,
indices,
add_inputs=True,
)

bgmv_expand(
inputs_tensor,
lora_weights,
our_out_tensor,
indices,
add_inputs=True,
)
ref_torch_groupgemm(
ref_out_tensor,
inputs_tensor,
@@ -239,8 +223,6 @@ def test_punica_expand_nslices(
seed: int,
device: str,
):
from vllm.lora.ops.bgmv_expand_slice import _bgmv_expand_slice_kernel

torch.set_default_device(device)
current_platform.seed_everything(seed)

@@ -289,22 +271,15 @@
add_inputs=True,
)
else:
# The current _bgmv_expand_slice_kernel does not require the
# libentry decoration. The purpose of adding this patch is to test
# the correctness of libentry.
with patch(
"vllm.lora.ops.bgmv_expand_slice._bgmv_expand_slice_kernel",
LibEntry(_bgmv_expand_slice_kernel),
):
bgmv_expand_slice(
inputs_tensor,
lora_weights,
our_outputs,
indices,
slice_offset,
slice_size=hidden_size,
add_inputs=True,
)
bgmv_expand_slice(
inputs_tensor,
lora_weights,
our_outputs,
indices,
slice_offset,
slice_size=hidden_size,
add_inputs=True,
)
ref_torch_groupgemm(
ref_outputs[:, slice_offset:slice_offset + hidden_size],
inputs_tensor,
3 changes: 0 additions & 3 deletions vllm/lora/ops/sgmv_expand.py
@@ -9,10 +9,7 @@
import triton
import triton.language as tl

from vllm.triton_utils import libentry


@libentry()
@triton.jit
def _sgmv_expand_kernel(
input_ptr,
3 changes: 0 additions & 3 deletions vllm/lora/ops/sgmv_expand_slice.py
@@ -9,10 +9,7 @@
import triton
import triton.language as tl

from vllm.triton_utils import libentry


@libentry()
@triton.jit
def _sgmv_expand_slice_kernel(
input_ptr,
3 changes: 0 additions & 3 deletions vllm/lora/ops/sgmv_shrink.py
@@ -9,10 +9,7 @@
import triton
import triton.language as tl

from vllm.triton_utils import libentry


@libentry()
@triton.jit
def _sgmv_shrink_kernel(
input_ptr,
3 changes: 1 addition & 2 deletions vllm/triton_utils/__init__.py
@@ -6,6 +6,5 @@

from vllm.triton_utils.custom_cache_manager import (
maybe_set_triton_cache_manager)
from vllm.triton_utils.libentry import libentry

__all__ += ["maybe_set_triton_cache_manager", "libentry"]
__all__ += ["maybe_set_triton_cache_manager"]