
Commit 7626844

gshtras authored and Akshat-Tripathi committed
[ROCm] Apply FP8 weights padding to values not divisible by 512 bytes on ROCm (vllm-project#13231)
1 parent 3ffae46 commit 7626844

3 files changed: +20 −1 lines changed

vllm/envs.py

Lines changed: 4 additions & 0 deletions

@@ -74,6 +74,7 @@
     VLLM_SKIP_P2P_CHECK: bool = False
     VLLM_DISABLED_KERNELS: List[str] = []
     VLLM_USE_V1: bool = False
+    VLLM_ROCM_FP8_PADDING: bool = True
     VLLM_ENABLE_V1_MULTIPROCESSING: bool = True
     VLLM_LOG_BATCHSIZE_INTERVAL: float = -1
     VLLM_DISABLE_COMPILE_CACHE: bool = False

@@ -507,6 +508,9 @@ def maybe_convert_int(value: Optional[str]) -> Optional[int]:
     "VLLM_USE_V1":
     lambda: bool(int(os.getenv("VLLM_USE_V1", "0"))),

+    # Pad the fp8 weights to 256 bytes for ROCm
+    "VLLM_ROCM_FP8_PADDING":
+    lambda: bool(int(os.getenv("VLLM_ROCM_FP8_PADDING", "1"))),
     # Divisor for dynamic key scale factor calculation for FP8 KV Cache
     "K_SCALE_CONSTANT":
     lambda: int(os.getenv("K_SCALE_CONSTANT", "200")),
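
The new flag follows the same pattern as the other boolean toggles in envs.py: its lambda is evaluated when the attribute is read, any non-zero integer string enables the padding, and "0" disables it. A minimal sketch of turning it off from user code, assuming only that the variable is set before vllm.envs is first queried (the surrounding script is illustrative):

import os

# "1" (the default) enables ROCm FP8 weight padding, "0" disables it.
# Set it before vllm.envs evaluates the corresponding lambda.
os.environ["VLLM_ROCM_FP8_PADDING"] = "0"

import vllm.envs as envs

print(envs.VLLM_ROCM_FP8_PADDING)  # False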

vllm/model_executor/layers/quantization/fp8.py

Lines changed: 15 additions & 0 deletions

@@ -3,6 +3,7 @@
 from typing import Any, Callable, Dict, List, Optional

 import torch
+import torch.nn.functional as F
 from torch.nn import Module
 from torch.nn.parameter import Parameter

@@ -251,6 +252,17 @@ def create_weights(
         else:
             layer.register_parameter("input_scale", None)

+    def add_padding_to_weight(self, weight: torch.Tensor) -> torch.Tensor:
+        # Pad the weight tensor. This is an optimization on ROCm platform, which
+        # can benefit from tensors located far enough from one another in memory
+        if (envs.VLLM_ROCM_FP8_PADDING and current_platform.is_rocm()
+                and weight.stride(-1) == 1
+                and (weight.stride(-2) * weight.element_size()) % 512 == 0):
+            num_pad = 256 // weight.element_size()
+            weight = F.pad(weight, (0, num_pad), "constant", 0)[..., :-num_pad]
+            torch.cuda.empty_cache()
+        return weight
+
     def process_weights_after_loading(self, layer: Module) -> None:
         # TODO(rob): refactor block quant into separate class.
         if self.block_quant:

@@ -264,6 +276,8 @@ def process_weights_after_loading(self, layer: Module) -> None:
                 weight = layer.weight.data
                 weight_scale_inv = layer.weight_scale_inv.data

+            weight = self.add_padding_to_weight(weight)
+
             # Torch.compile cannot use Parameter subclasses.
             layer.weight = Parameter(weight, requires_grad=False)
             layer.weight_scale_inv = Parameter(weight_scale_inv,

@@ -327,6 +341,7 @@ def process_weights_after_loading(self, layer: Module) -> None:
                 logical_widths=layer.logical_widths,
             )

+        weight = self.add_padding_to_weight(weight)
         # Update layer with new values.
         layer.weight = Parameter(weight.t(), requires_grad=False)
         layer.weight_scale = Parameter(weight_scale, requires_grad=False)
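
add_padding_to_weight only rewrites a weight when padding is enabled, the platform is ROCm, the last dimension is contiguous, and the row stride is an exact multiple of 512 bytes; it then appends 256 bytes per row with F.pad and immediately slices them back off, so the tensor keeps its shape while its rows move further apart in memory. The method is called from both paths of process_weights_after_loading, i.e. for block-quantized and for regular FP8 weights. Below is a small self-contained sketch of the same pad-and-slice trick, using int8 as a stand-in for the one-byte FP8 dtype so it runs on any build; the sizes are illustrative, not taken from the commit:

import torch
import torch.nn.functional as F

# 512 one-byte elements per row, so every row starts at a 512-byte-aligned
# offset, which is exactly the case the commit pads.
weight = torch.zeros(1024, 512, dtype=torch.int8)
print(weight.stride(-2) * weight.element_size())     # 512 bytes per row

num_pad = 256 // weight.element_size()                # 256 bytes -> 256 elements here
padded = F.pad(weight, (0, num_pad), "constant", 0)[..., :-num_pad]

print(padded.shape)                                   # torch.Size([1024, 512]), unchanged
print(padded.stride(-2) * padded.element_size())      # 768 bytes per row
print(padded.is_contiguous())                         # False: the padding stays in storage

The slice keeps the padded storage alive while exposing the original shape, and the torch.cuda.empty_cache() call in the real method releases the pre-padding allocation; per the in-code comment, the benefit on ROCm comes from tensors being located far enough from one another in memory.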

vllm/model_executor/layers/quantization/utils/fp8_utils.py

Lines changed: 1 addition & 1 deletion

@@ -494,7 +494,7 @@ def w8a8_block_fp8_matmul(
     assert triton.cdiv(A.shape[-1], block_k) == As.shape[-1]
     M = A.numel() // A.shape[-1]

-    assert B.ndim == 2 and B.is_contiguous() and Bs.ndim == 2
+    assert B.ndim == 2 and Bs.ndim == 2
     N, K = B.shape
     assert triton.cdiv(N, block_n) == Bs.shape[0]
     assert triton.cdiv(K, block_k) == Bs.shape[1]
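
The contiguity check is dropped because a weight that went through the new padding path is a strided view: its dimensionality and shape are unchanged, so the remaining assertions still hold, but is_contiguous() now returns False (presumably the Triton kernel already receives B's strides explicitly, so only the assertion needed relaxing). A short sketch of the shape relations the surviving assertions still enforce, using an assumed 128x128 quantization block and illustrative sizes:

import torch

N, K = 1024, 512
block_n, block_k = 128, 128            # assumed block size, for illustration only

B = torch.zeros(N, K, dtype=torch.int8)        # int8 stands in for FP8
# One scale per (block_n x block_k) tile of B, as the cdiv asserts require.
Bs = torch.ones(-(-N // block_n), -(-K // block_k))

assert B.ndim == 2 and Bs.ndim == 2            # the relaxed assertion
assert -(-N // block_n) == Bs.shape[0]         # i.e. triton.cdiv(N, block_n)
assert -(-K // block_k) == Bs.shape[1]         # i.e. triton.cdiv(K, block_k)
print(Bs.shape)                                # torch.Size([8, 4])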
