diff --git a/vllm/envs.py b/vllm/envs.py
index 45547416314f..8a8ad3a88e13 100644
--- a/vllm/envs.py
+++ b/vllm/envs.py
@@ -74,6 +74,7 @@
     VLLM_SKIP_P2P_CHECK: bool = False
     VLLM_DISABLED_KERNELS: List[str] = []
     VLLM_USE_V1: bool = False
+    VLLM_ROCM_FP8_PADDING: bool = True
     VLLM_ENABLE_V1_MULTIPROCESSING: bool = True
     VLLM_LOG_BATCHSIZE_INTERVAL: float = -1
     VLLM_DISABLE_COMPILE_CACHE: bool = False
@@ -503,6 +504,9 @@ def maybe_convert_int(value: Optional[str]) -> Optional[int]:
     "VLLM_USE_V1":
     lambda: bool(int(os.getenv("VLLM_USE_V1", "0"))),
 
+    # Pad fp8 weights by 256 bytes on ROCm
+    "VLLM_ROCM_FP8_PADDING":
+    lambda: bool(int(os.getenv("VLLM_ROCM_FP8_PADDING", "1"))),
     # Divisor for dynamic key scale factor calculation for FP8 KV Cache
     "K_SCALE_CONSTANT":
     lambda: int(os.getenv("K_SCALE_CONSTANT", "200")),
diff --git a/vllm/model_executor/layers/quantization/fp8.py b/vllm/model_executor/layers/quantization/fp8.py
index fe8ff7ca5e12..1ca39b0ffa82 100644
--- a/vllm/model_executor/layers/quantization/fp8.py
+++ b/vllm/model_executor/layers/quantization/fp8.py
@@ -3,6 +3,7 @@
 from typing import Any, Callable, Dict, List, Optional
 
 import torch
+import torch.nn.functional as F
 from torch.nn import Module
 from torch.nn.parameter import Parameter
 
@@ -251,6 +252,17 @@ def create_weights(
         else:
             layer.register_parameter("input_scale", None)
 
+    def add_padding_to_weight(self, weight: torch.Tensor) -> torch.Tensor:
+        # Pad the weight tensor. This is an optimization on ROCm, which
+        # benefits from tensors being spaced far enough apart in memory.
+        if (envs.VLLM_ROCM_FP8_PADDING and current_platform.is_rocm()
+                and weight.stride(-1) == 1
+                and (weight.stride(-2) * weight.element_size()) % 512 == 0):
+            num_pad = 256 // weight.element_size()
+            weight = F.pad(weight, (0, num_pad), "constant", 0)[..., :-num_pad]
+            torch.cuda.empty_cache()
+        return weight
+
     def process_weights_after_loading(self, layer: Module) -> None:
         # TODO(rob): refactor block quant into separate class.
         if self.block_quant:
@@ -264,6 +276,8 @@ def process_weights_after_loading(self, layer: Module) -> None:
                 weight = layer.weight.data
                 weight_scale_inv = layer.weight_scale_inv.data
 
+            weight = self.add_padding_to_weight(weight)
+
             # Torch.compile cannot use Parameter subclasses.
             layer.weight = Parameter(weight, requires_grad=False)
             layer.weight_scale_inv = Parameter(weight_scale_inv,
@@ -327,6 +341,7 @@ def process_weights_after_loading(self, layer: Module) -> None:
                 logical_widths=layer.logical_widths,
             )
 
+            weight = self.add_padding_to_weight(weight)
             # Update layer with new values.
             layer.weight = Parameter(weight.t(), requires_grad=False)
             layer.weight_scale = Parameter(weight_scale, requires_grad=False)
diff --git a/vllm/model_executor/layers/quantization/utils/fp8_utils.py b/vllm/model_executor/layers/quantization/utils/fp8_utils.py
index 9895537c219a..cd584b62eefb 100644
--- a/vllm/model_executor/layers/quantization/utils/fp8_utils.py
+++ b/vllm/model_executor/layers/quantization/utils/fp8_utils.py
@@ -477,7 +477,7 @@ def w8a8_block_fp8_matmul(
     assert triton.cdiv(A.shape[-1], block_k) == As.shape[-1]
     M = A.numel() // A.shape[-1]
 
-    assert B.ndim == 2 and B.is_contiguous() and Bs.ndim == 2
+    assert B.ndim == 2 and Bs.ndim == 2
     N, K = B.shape
     assert triton.cdiv(N, block_n) == Bs.shape[0]
    assert triton.cdiv(K, block_k) == Bs.shape[1]
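
Note (not part of the patch): the padding helper relies on `F.pad` copying the weight into a freshly allocated, wider buffer, after which the trailing slice drops the padding from the view only. The logical shape and values are unchanged, but consecutive rows end up 256 bytes further apart in memory. A minimal CPU sketch of that stride arithmetic, using `torch.uint8` as a stand-in for the fp8 dtype (the tensor sizes here are illustrative only):

```python
import torch
import torch.nn.functional as F

# Stand-in for an fp8 weight whose row pitch is a multiple of 512 bytes,
# the case the ROCm padding targets.
weight = torch.zeros(4, 512, dtype=torch.uint8)
print(weight.stride())               # (512, 1): rows are 512 bytes apart

# Same arithmetic as add_padding_to_weight: pad each row by 256 bytes,
# then slice the padding back off so only the storage layout changes.
num_pad = 256 // weight.element_size()
padded = F.pad(weight, (0, num_pad), "constant", 0)[..., :-num_pad]

print(padded.shape)                  # torch.Size([4, 512]): logical shape unchanged
print(padded.stride())               # (768, 1): rows are now 768 bytes apart
print(torch.equal(weight, padded))   # True: values unchanged
print(padded.is_contiguous())        # False: the view skips the per-row padding
```

Because the padded weight is a non-contiguous view, the `B.is_contiguous()` check in `w8a8_block_fp8_matmul` would now reject it, which is why the last hunk relaxes that assertion.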