Skip to content

Commit c42cdb0

Browse files
xw285cornellhouseroad
authored and committed
[Hardware][AMD] Improve OAM device ID + llama4 Maverick MOE tuning (vllm-project#16263)
Signed-off-by: Lu Fang <[email protected]> Co-authored-by: Lu Fang <[email protected]> Signed-off-by: Mu Huai <[email protected]>
1 parent 7c43ea5 commit c42cdb0

File tree

3 files changed

+231
-3
lines changed

3 files changed

+231
-3
lines changed

benchmarks/kernels/benchmark_moe.py

Lines changed: 17 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -442,8 +442,14 @@ def tune(
442442
hidden_size, search_space,
443443
is_fp16, topk)
444444

445-
with torch.cuda.device(self.device_id) if current_platform.is_rocm(
446-
) else nullcontext():
445+
need_device_guard = False
446+
if current_platform.is_rocm():
447+
visible_device = os.environ.get("ROCR_VISIBLE_DEVICES", None)
448+
if visible_device != f"{self.device_id}":
449+
need_device_guard = True
450+
451+
with torch.cuda.device(
452+
self.device_id) if need_device_guard else nullcontext():
447453
for config in tqdm(search_space):
448454
try:
449455
kernel_time = benchmark_config(
@@ -578,6 +584,15 @@ def main(args: argparse.Namespace):
578584

579585
use_deep_gemm = bool(args.use_deep_gemm)
580586

587+
if current_platform.is_rocm() and "HIP_VISIBLE_DEVICES" in os.environ:
588+
# Ray will set ROCR_VISIBLE_DEVICES for device visibility
589+
logger.warning(
590+
"Ray uses ROCR_VISIBLE_DEVICES to control device accessibility."
591+
"Replacing HIP_VISIBLE_DEVICES with ROCR_VISIBLE_DEVICES.")
592+
val = os.environ["HIP_VISIBLE_DEVICES"]
593+
os.environ["ROCR_VISIBLE_DEVICES"] = val
594+
del os.environ["HIP_VISIBLE_DEVICES"]
595+
581596
ray.init()
582597
num_gpus = int(ray.available_resources()["GPU"])
583598
workers = [BenchmarkWorker.remote(args.seed) for _ in range(num_gpus)]
Lines changed: 200 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,200 @@
1+
{
2+
"1": {
3+
"BLOCK_SIZE_M": 16,
4+
"BLOCK_SIZE_N": 16,
5+
"BLOCK_SIZE_K": 256,
6+
"GROUP_SIZE_M": 1,
7+
"num_warps": 4,
8+
"num_stages": 2,
9+
"waves_per_eu": 0,
10+
"matrix_instr_nonkdim": 16,
11+
"kpack": 2
12+
},
13+
"2": {
14+
"BLOCK_SIZE_M": 16,
15+
"BLOCK_SIZE_N": 16,
16+
"BLOCK_SIZE_K": 256,
17+
"GROUP_SIZE_M": 1,
18+
"num_warps": 4,
19+
"num_stages": 2,
20+
"waves_per_eu": 0,
21+
"matrix_instr_nonkdim": 16,
22+
"kpack": 1
23+
},
24+
"4": {
25+
"BLOCK_SIZE_M": 16,
26+
"BLOCK_SIZE_N": 16,
27+
"BLOCK_SIZE_K": 256,
28+
"GROUP_SIZE_M": 1,
29+
"num_warps": 2,
30+
"num_stages": 2,
31+
"waves_per_eu": 0,
32+
"matrix_instr_nonkdim": 16,
33+
"kpack": 2
34+
},
35+
"8": {
36+
"BLOCK_SIZE_M": 16,
37+
"BLOCK_SIZE_N": 64,
38+
"BLOCK_SIZE_K": 128,
39+
"GROUP_SIZE_M": 1,
40+
"num_warps": 4,
41+
"num_stages": 2,
42+
"waves_per_eu": 0,
43+
"matrix_instr_nonkdim": 16,
44+
"kpack": 2
45+
},
46+
"16": {
47+
"BLOCK_SIZE_M": 16,
48+
"BLOCK_SIZE_N": 128,
49+
"BLOCK_SIZE_K": 128,
50+
"GROUP_SIZE_M": 1,
51+
"num_warps": 4,
52+
"num_stages": 2,
53+
"waves_per_eu": 0,
54+
"matrix_instr_nonkdim": 16,
55+
"kpack": 2
56+
},
57+
"24": {
58+
"BLOCK_SIZE_M": 16,
59+
"BLOCK_SIZE_N": 64,
60+
"BLOCK_SIZE_K": 128,
61+
"GROUP_SIZE_M": 1,
62+
"num_warps": 4,
63+
"num_stages": 2,
64+
"waves_per_eu": 0,
65+
"matrix_instr_nonkdim": 16,
66+
"kpack": 1
67+
},
68+
"32": {
69+
"BLOCK_SIZE_M": 16,
70+
"BLOCK_SIZE_N": 128,
71+
"BLOCK_SIZE_K": 128,
72+
"GROUP_SIZE_M": 1,
73+
"num_warps": 8,
74+
"num_stages": 2,
75+
"waves_per_eu": 0,
76+
"matrix_instr_nonkdim": 16,
77+
"kpack": 2
78+
},
79+
"48": {
80+
"BLOCK_SIZE_M": 16,
81+
"BLOCK_SIZE_N": 64,
82+
"BLOCK_SIZE_K": 128,
83+
"GROUP_SIZE_M": 1,
84+
"num_warps": 2,
85+
"num_stages": 2,
86+
"waves_per_eu": 0,
87+
"matrix_instr_nonkdim": 16,
88+
"kpack": 2
89+
},
90+
"64": {
91+
"BLOCK_SIZE_M": 16,
92+
"BLOCK_SIZE_N": 64,
93+
"BLOCK_SIZE_K": 128,
94+
"GROUP_SIZE_M": 4,
95+
"num_warps": 2,
96+
"num_stages": 2,
97+
"waves_per_eu": 0,
98+
"matrix_instr_nonkdim": 16,
99+
"kpack": 2
100+
},
101+
"96": {
102+
"BLOCK_SIZE_M": 16,
103+
"BLOCK_SIZE_N": 32,
104+
"BLOCK_SIZE_K": 128,
105+
"GROUP_SIZE_M": 1,
106+
"num_warps": 1,
107+
"num_stages": 2,
108+
"waves_per_eu": 0,
109+
"matrix_instr_nonkdim": 16,
110+
"kpack": 1
111+
},
112+
"128": {
113+
"BLOCK_SIZE_M": 16,
114+
"BLOCK_SIZE_N": 64,
115+
"BLOCK_SIZE_K": 128,
116+
"GROUP_SIZE_M": 1,
117+
"num_warps": 4,
118+
"num_stages": 2,
119+
"waves_per_eu": 0,
120+
"matrix_instr_nonkdim": 16,
121+
"kpack": 1
122+
},
123+
"256": {
124+
"BLOCK_SIZE_M": 16,
125+
"BLOCK_SIZE_N": 32,
126+
"BLOCK_SIZE_K": 128,
127+
"GROUP_SIZE_M": 1,
128+
"num_warps": 2,
129+
"num_stages": 2,
130+
"waves_per_eu": 0,
131+
"matrix_instr_nonkdim": 16,
132+
"kpack": 2
133+
},
134+
"512": {
135+
"BLOCK_SIZE_M": 16,
136+
"BLOCK_SIZE_N": 128,
137+
"BLOCK_SIZE_K": 128,
138+
"GROUP_SIZE_M": 8,
139+
"num_warps": 8,
140+
"num_stages": 2,
141+
"waves_per_eu": 0,
142+
"matrix_instr_nonkdim": 16,
143+
"kpack": 2
144+
},
145+
"1024": {
146+
"BLOCK_SIZE_M": 16,
147+
"BLOCK_SIZE_N": 128,
148+
"BLOCK_SIZE_K": 128,
149+
"GROUP_SIZE_M": 1,
150+
"num_warps": 8,
151+
"num_stages": 2,
152+
"waves_per_eu": 0,
153+
"matrix_instr_nonkdim": 16,
154+
"kpack": 2
155+
},
156+
"1536": {
157+
"BLOCK_SIZE_M": 32,
158+
"BLOCK_SIZE_N": 128,
159+
"BLOCK_SIZE_K": 128,
160+
"GROUP_SIZE_M": 8,
161+
"num_warps": 8,
162+
"num_stages": 2,
163+
"waves_per_eu": 0,
164+
"matrix_instr_nonkdim": 16,
165+
"kpack": 2
166+
},
167+
"2048": {
168+
"BLOCK_SIZE_M": 64,
169+
"BLOCK_SIZE_N": 64,
170+
"BLOCK_SIZE_K": 128,
171+
"GROUP_SIZE_M": 16,
172+
"num_warps": 8,
173+
"num_stages": 2,
174+
"waves_per_eu": 0,
175+
"matrix_instr_nonkdim": 16,
176+
"kpack": 2
177+
},
178+
"3072": {
179+
"BLOCK_SIZE_M": 64,
180+
"BLOCK_SIZE_N": 128,
181+
"BLOCK_SIZE_K": 128,
182+
"GROUP_SIZE_M": 8,
183+
"num_warps": 8,
184+
"num_stages": 2,
185+
"waves_per_eu": 0,
186+
"matrix_instr_nonkdim": 16,
187+
"kpack": 1
188+
},
189+
"4096": {
190+
"BLOCK_SIZE_M": 64,
191+
"BLOCK_SIZE_N": 128,
192+
"BLOCK_SIZE_K": 128,
193+
"GROUP_SIZE_M": 32,
194+
"num_warps": 8,
195+
"num_stages": 2,
196+
"waves_per_eu": 0,
197+
"matrix_instr_nonkdim": 16,
198+
"kpack": 2
199+
}
200+
}

vllm/platforms/rocm.py

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,15 @@
5858
"excessive use of shared memory. If this happens, disable Triton FA "
5959
"by setting `VLLM_USE_TRITON_FLASH_ATTN=0`")
6060
}
61+
_ROCM_DEVICE_ID_NAME_MAP: Dict[str, str] = {
62+
"0x74a0": "AMD_Instinct_MI300A",
63+
"0x74a1": "AMD_Instinct_MI300X",
64+
"0x74b5": "AMD_Instinct_MI300X", # MI300X VF
65+
"0x74a5": "AMD_Instinct_MI325X",
66+
"0x74b9": "AMD_Instinct_MI325X", # MI325X VF
67+
"0x74a9": "AMD_Instinct_MI300X_HF",
68+
"0x74bd": "AMD_Instinct_MI300X_HF",
69+
}
6170

6271
# Prevent use of clashing `{CUDA/HIP}_VISIBLE_DEVICES``
6372
if "HIP_VISIBLE_DEVICES" in os.environ:
@@ -225,7 +234,11 @@ def is_fully_connected(physical_device_ids: List[int]) -> bool:
225234
def get_device_name(cls, device_id: int = 0) -> str:
226235
physical_device_id = device_id_to_physical_device_id(device_id)
227236
handle = amdsmi_get_processor_handles()[physical_device_id]
228-
return amdsmi_get_gpu_asic_info(handle)["market_name"]
237+
asic_info = amdsmi_get_gpu_asic_info(handle)
238+
device_name: str = asic_info["device_id"]
239+
if device_name in _ROCM_DEVICE_ID_NAME_MAP:
240+
return _ROCM_DEVICE_ID_NAME_MAP[device_name]
241+
return asic_info["market_name"]
229242

230243
@classmethod
231244
def get_device_total_memory(cls, device_id: int = 0) -> int:

0 commit comments

Comments (0)