 from typing import Any, Callable, Dict, List, Optional

 import torch
+import torch.nn.functional as F
 from torch.nn import Module
 from torch.nn.parameter import Parameter

@@ -251,6 +252,17 @@ def create_weights(
         else:
             layer.register_parameter("input_scale", None)

+    def add_padding_to_weight(self, weight: torch.Tensor) -> torch.Tensor:
+        # Pad the weight tensor. This is an optimization on the ROCm platform,
+        # which can benefit from tensors located far enough apart in memory.
+        if (envs.VLLM_ROCM_FP8_PADDING and current_platform.is_rocm()
+                and weight.stride(-1) == 1
+                and (weight.stride(-2) * weight.element_size()) % 512 == 0):
+            num_pad = 256 // weight.element_size()
+            weight = F.pad(weight, (0, num_pad), "constant", 0)[..., :-num_pad]
+            torch.cuda.empty_cache()
+        return weight
+
     def process_weights_after_loading(self, layer: Module) -> None:
         # TODO(rob): refactor block quant into separate class.
         if self.block_quant:
@@ -264,6 +276,8 @@ def process_weights_after_loading(self, layer: Module) -> None:
             weight = layer.weight.data
             weight_scale_inv = layer.weight_scale_inv.data

+            weight = self.add_padding_to_weight(weight)
+
             # Torch.compile cannot use Parameter subclasses.
             layer.weight = Parameter(weight, requires_grad=False)
             layer.weight_scale_inv = Parameter(weight_scale_inv,
@@ -327,6 +341,7 @@ def process_weights_after_loading(self, layer: Module) -> None:
                 logical_widths=layer.logical_widths,
             )

+            weight = self.add_padding_to_weight(weight)
             # Update layer with new values.
             layer.weight = Parameter(weight.t(), requires_grad=False)
             layer.weight_scale = Parameter(weight_scale, requires_grad=False)
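
For reference, the pad-then-slice trick works because `F.pad` allocates a wider contiguous buffer, and the trailing slice only shrinks the logical shape: each row keeps the padded pitch, so `weight.stride(-2)` grows while the tensor still looks like the original weight. A minimal CPU-only sketch of the same idea (the shapes and the 1-byte `uint8` stand-in for FP8 are illustrative assumptions, not from this diff; the `torch.cuda.empty_cache()` call is omitted since no GPU memory is involved):

```python
import torch
import torch.nn.functional as F

# Toy stand-in for an FP8 weight: 1-byte elements, 512-byte row pitch,
# which is exactly the case the guard in add_padding_to_weight targets.
weight = torch.zeros((8, 512), dtype=torch.uint8)
print(weight.stride())                                      # (512, 1): rows packed back to back
print((weight.stride(-2) * weight.element_size()) % 512)    # 0 -> the padding condition holds

# Same pad-then-slice as in the diff.
num_pad = 256 // weight.element_size()                      # 256 elements for a 1-byte dtype
padded = F.pad(weight, (0, num_pad), "constant", 0)[..., :-num_pad]

print(padded.shape)    # torch.Size([8, 512]): logical shape is unchanged
print(padded.stride()) # (768, 1): each row now starts 256 elements further apart in memory
```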