
Commit 4ae2f54

alexm-redhat authored and Mu Huai committed

[V1] TPU - Revert to exponential padding by default (vllm-project#15565)

Signed-off-by: Alexander Matveev <[email protected]>
Signed-off-by: Mu Huai <[email protected]>

1 parent 7e22baa commit 4ae2f54

File tree

2 files changed: +28 -11 lines changed


vllm/envs.py

Lines changed: 2 additions & 2 deletions

@@ -99,7 +99,7 @@
     VLLM_MARLIN_USE_ATOMIC_ADD: bool = False
     VLLM_V0_USE_OUTLINES_CACHE: bool = False
     VLLM_TPU_DISABLE_TOPK_TOPP_OPTIMIZATION: bool = False
-    VLLM_TPU_BUCKET_PADDING_GAP: int = 64
+    VLLM_TPU_BUCKET_PADDING_GAP: int = 0


 def get_default_cache_root():

@@ -648,7 +648,7 @@ def maybe_convert_int(value: Optional[str]) -> Optional[int]:
     # 8, we will run forward pass with [16, 24, 32, ...].
     "VLLM_TPU_BUCKET_PADDING_GAP":
     lambda: int(os.environ["VLLM_TPU_BUCKET_PADDING_GAP"])
-    if "VLLM_TPU_BUCKET_PADDING_GAP" in os.environ else 64,
+    if "VLLM_TPU_BUCKET_PADDING_GAP" in os.environ else 0,
 }

 # end-env-vars-definition
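
The env-var resolution above follows a simple pattern: parse VLLM_TPU_BUCKET_PADDING_GAP if it is set, otherwise fall back to the new default of 0, which selects exponential padding. A minimal standalone sketch of that pattern is below; the helper name resolve_padding_gap is illustrative and not part of vllm.

import os

def resolve_padding_gap() -> int:
    # Mirrors the lambda in vllm/envs.py: parse VLLM_TPU_BUCKET_PADDING_GAP
    # when it is set, otherwise fall back to the new default of 0.
    if "VLLM_TPU_BUCKET_PADDING_GAP" in os.environ:
        return int(os.environ["VLLM_TPU_BUCKET_PADDING_GAP"])
    return 0

# Leaving the variable unset now selects exponential padding; exporting
# VLLM_TPU_BUCKET_PADDING_GAP=64 restores the previous incremental behavior.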

vllm/v1/worker/tpu_model_runner.py

Lines changed: 26 additions & 9 deletions

@@ -944,18 +944,35 @@ def _get_paddings(min_token_size: int, max_token_size: int,
                   padding_gap: int) -> list[int]:
     """Generate a list of padding size, starting from min_token_size,
     ending with a number that can cover max_token_size
-    first increase the size to twice,
-    then increase the padding size by padding_gap.
+
+    If padding_gap == 0 then:
+        increase 2X each time (exponential)
+    else:
+        first increase the size to twice,
+        then increase the padding size by padding_gap.
     """
     paddings = []
     num = min_token_size
-    while num <= padding_gap:
-        paddings.append(num)
-        num *= 2
-    num //= 2
-    while num < max_token_size:
-        num += padding_gap
-        paddings.append(num)
+
+    if padding_gap == 0:
+        logger.info("Using exponential paddings:")
+        while num <= max_token_size:
+            logger.info("    %d", num)
+            paddings.append(num)
+            num *= 2
+    else:
+        logger.info("Using incremental paddings:")
+        while num <= padding_gap:
+            logger.info("    %d", num)
+            paddings.append(num)
+            num *= 2
+        num //= 2
+        while num < max_token_size:
+            num += padding_gap
+            logger.info("    %d", num)
+            paddings.append(num)
+
     return paddings