Skip to content

Commit 0b439f7

Browse files
committed
fix format issue
Signed-off-by: Yu Gong <[email protected]>
1 parent 22faf7e commit 0b439f7

File tree

1 file changed

+14
-14
lines changed

1 file changed

+14
-14
lines changed

vllm/lora/punica_wrapper/punica_gpu.py

Lines changed: 14 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -389,19 +389,19 @@ def add_lora_fused_moe(
389389
top_k_num,
390390
lora_ids,
391391
adapter_enabled,
392-
shrink_config.get("BLOCK_SIZE_M", 64),
393-
shrink_config.get("BLOCK_SIZE_N", 64),
394-
shrink_config.get("BLOCK_SIZE_K", 32),
395-
shrink_config.get("GROUP_SIZE_M", 8),
396-
shrink_config.get("NUM_WARPS", 4),
397-
shrink_config.get("NUM_STAGES", 3),
398-
shrink_config.get("SPLIT_K", 1),
399-
expand_config.get("BLOCK_SIZE_M", 64),
400-
expand_config.get("BLOCK_SIZE_N", 64),
401-
expand_config.get("BLOCK_SIZE_K", 64),
402-
expand_config.get("GROUP_SIZE_M", 64),
403-
expand_config.get("NUM_WARPS", 4),
404-
expand_config.get("NUM_STAGES", 3),
405-
expand_config.get("SPLIT_K", 1),
392+
shrink_config.get("BLOCK_SIZE_M") or shrink_config.get("block_m") or 64,
393+
shrink_config.get("BLOCK_SIZE_N") or shrink_config.get("block_n") or 64,
394+
shrink_config.get("BLOCK_SIZE_K") or shrink_config.get("block_k") or 32,
395+
shrink_config.get("GROUP_SIZE_M") or shrink_config.get("group_m") or 8,
396+
shrink_config.get("NUM_WARPS") or shrink_config.get("num_warps") or 4,
397+
shrink_config.get("NUM_STAGES") or shrink_config.get("num_stages") or 3,
398+
shrink_config.get("SPLIT_K") or shrink_config.get("split_k") or 1,
399+
expand_config.get("BLOCK_SIZE_M") or expand_config.get("block_m") or 64,
400+
expand_config.get("BLOCK_SIZE_N") or expand_config.get("block_n") or 64,
401+
expand_config.get("BLOCK_SIZE_K") or expand_config.get("block_k") or 64,
402+
expand_config.get("GROUP_SIZE_M") or expand_config.get("group_m") or 64,
403+
expand_config.get("NUM_WARPS") or expand_config.get("num_warps") or 4,
404+
expand_config.get("NUM_STAGES") or expand_config.get("num_stages") or 3,
405+
expand_config.get("SPLIT_K") or expand_config.get("split_k") or 1,
406406
mul_routed_weight,
407407
)

0 commit comments

Comments
 (0)