[Hardware][AMD] Improve OAM device ID + llama4 Maverick MOE tuning #16263
In the MoE tuning benchmark, the unconditional ROCm device guard is replaced by one that is entered only when the worker is not already pinned to the target device:

```diff
@@ -442,8 +442,13 @@ def tune(
                           hidden_size, search_space,
                           is_fp16, topk)

-        with torch.cuda.device(self.device_id) if current_platform.is_rocm(
-        ) else nullcontext():
+        need_device_guard = False
+        if current_platform.is_rocm():
+            visible_device = os.environ.get("ROCR_VISIBLE_DEVICES", None)
+            if visible_device != f"{self.device_id}":
+                need_device_guard = True
+
+        with torch.cuda.device(self.device_id) if need_device_guard else nullcontext():
             for config in tqdm(search_space):
                 try:
                     kernel_time = benchmark_config(
```
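For context, here is a minimal standalone sketch of the guard pattern above: `torch.cuda.device` is entered only when the worker is not already pinned to the target GPU, otherwise a no-op `nullcontext()` is used. The helper name `maybe_device_guard` is hypothetical, not part of the PR, which inlines this logic in `tune()`.

```python
import os
from contextlib import nullcontext

import torch


def maybe_device_guard(device_id: int):
    """Hypothetical helper mirroring the PR's guard logic.

    On ROCm, Ray pins each worker to one GPU via ROCR_VISIBLE_DEVICES, and
    that GPU then shows up as device 0 inside the worker, so entering
    torch.cuda.device(device_id) with device_id >= 1 would fail. The guard
    is therefore taken only when the visible device does not already match
    device_id.
    """
    is_rocm = torch.version.hip is not None  # stand-in for current_platform.is_rocm()
    if is_rocm and os.environ.get("ROCR_VISIBLE_DEVICES") != f"{device_id}":
        return torch.cuda.device(device_id)
    return nullcontext()


# Usage: wrap per-device work exactly like the tune() loop above.
with maybe_device_guard(0):
    pass  # benchmark kernels here
```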
In `main()`, a user-supplied `HIP_VISIBLE_DEVICES` is translated to `ROCR_VISIBLE_DEVICES` before Ray starts, since Ray controls AMD device visibility through the latter:

```diff
@@ -578,6 +583,16 @@ def main(args: argparse.Namespace):

     use_deep_gemm = bool(args.use_deep_gemm)

+    if current_platform.is_rocm() and "HIP_VISIBLE_DEVICES" in os.environ:
+        # Ray will set ROCR_VISIBLE_DEVICES for device visibility
+        logger.warning(
+            "Ray uses ROCR_VISIBLE_DEVICES to control device accessibility. "
+            "Replacing HIP_VISIBLE_DEVICES with ROCR_VISIBLE_DEVICES."
+        )
+        val = os.environ["HIP_VISIBLE_DEVICES"]
+        os.environ["ROCR_VISIBLE_DEVICES"] = val
+        del os.environ["HIP_VISIBLE_DEVICES"]
+
     ray.init()
     num_gpus = int(ray.available_resources()["GPU"])
     workers = [BenchmarkWorker.remote(args.seed) for _ in range(num_gpus)]
```
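The same normalization as a self-contained sketch; the function name `normalize_rocm_visibility` is made up for illustration, and `os.environ.pop` combines the copy and delete into one step:

```python
import logging
import os

logger = logging.getLogger(__name__)


def normalize_rocm_visibility() -> None:
    """Hypothetical helper mirroring the PR's env normalization.

    Ray selects AMD GPUs through ROCR_VISIBLE_DEVICES, so a user-supplied
    HIP_VISIBLE_DEVICES would be ignored (per the review thread, Ray would
    still launch one worker per physical GPU). Moving the value over keeps
    the user's device selection effective.
    """
    if "HIP_VISIBLE_DEVICES" in os.environ:
        logger.warning(
            "Ray uses ROCR_VISIBLE_DEVICES to control device accessibility. "
            "Replacing HIP_VISIBLE_DEVICES with ROCR_VISIBLE_DEVICES.")
        os.environ["ROCR_VISIBLE_DEVICES"] = os.environ.pop("HIP_VISIBLE_DEVICES")


normalize_rocm_visibility()  # call before ray.init()
```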
A new tuning configuration file (200 lines) is added, mapping token batch size to Triton fused-MoE kernel parameters:

```json
{
  "1":    {"BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 16,  "BLOCK_SIZE_K": 256, "GROUP_SIZE_M": 1,  "num_warps": 4, "num_stages": 2, "waves_per_eu": 0, "matrix_instr_nonkdim": 16, "kpack": 2},
  "2":    {"BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 16,  "BLOCK_SIZE_K": 256, "GROUP_SIZE_M": 1,  "num_warps": 4, "num_stages": 2, "waves_per_eu": 0, "matrix_instr_nonkdim": 16, "kpack": 1},
  "4":    {"BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 16,  "BLOCK_SIZE_K": 256, "GROUP_SIZE_M": 1,  "num_warps": 2, "num_stages": 2, "waves_per_eu": 0, "matrix_instr_nonkdim": 16, "kpack": 2},
  "8":    {"BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 64,  "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 1,  "num_warps": 4, "num_stages": 2, "waves_per_eu": 0, "matrix_instr_nonkdim": 16, "kpack": 2},
  "16":   {"BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 1,  "num_warps": 4, "num_stages": 2, "waves_per_eu": 0, "matrix_instr_nonkdim": 16, "kpack": 2},
  "24":   {"BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 64,  "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 1,  "num_warps": 4, "num_stages": 2, "waves_per_eu": 0, "matrix_instr_nonkdim": 16, "kpack": 1},
  "32":   {"BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 1,  "num_warps": 8, "num_stages": 2, "waves_per_eu": 0, "matrix_instr_nonkdim": 16, "kpack": 2},
  "48":   {"BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 64,  "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 1,  "num_warps": 2, "num_stages": 2, "waves_per_eu": 0, "matrix_instr_nonkdim": 16, "kpack": 2},
  "64":   {"BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 64,  "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 4,  "num_warps": 2, "num_stages": 2, "waves_per_eu": 0, "matrix_instr_nonkdim": 16, "kpack": 2},
  "96":   {"BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 32,  "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 1,  "num_warps": 1, "num_stages": 2, "waves_per_eu": 0, "matrix_instr_nonkdim": 16, "kpack": 1},
  "128":  {"BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 64,  "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 1,  "num_warps": 4, "num_stages": 2, "waves_per_eu": 0, "matrix_instr_nonkdim": 16, "kpack": 1},
  "256":  {"BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 32,  "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 1,  "num_warps": 2, "num_stages": 2, "waves_per_eu": 0, "matrix_instr_nonkdim": 16, "kpack": 2},
  "512":  {"BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 8,  "num_warps": 8, "num_stages": 2, "waves_per_eu": 0, "matrix_instr_nonkdim": 16, "kpack": 2},
  "1024": {"BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 1,  "num_warps": 8, "num_stages": 2, "waves_per_eu": 0, "matrix_instr_nonkdim": 16, "kpack": 2},
  "1536": {"BLOCK_SIZE_M": 32, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 8,  "num_warps": 8, "num_stages": 2, "waves_per_eu": 0, "matrix_instr_nonkdim": 16, "kpack": 2},
  "2048": {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 64,  "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 16, "num_warps": 8, "num_stages": 2, "waves_per_eu": 0, "matrix_instr_nonkdim": 16, "kpack": 2},
  "3072": {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 8,  "num_warps": 8, "num_stages": 2, "waves_per_eu": 0, "matrix_instr_nonkdim": 16, "kpack": 1},
  "4096": {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 32, "num_warps": 8, "num_stages": 2, "waves_per_eu": 0, "matrix_instr_nonkdim": 16, "kpack": 2}
}
```
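For reference, a minimal sketch of how a table like this is typically consumed: keys are token batch sizes M, and the entry whose key is nearest to the actual M is used, similar in spirit to vLLM's fused-MoE config lookup. The file path below is hypothetical (real vLLM config filenames encode expert count, shard size, device name, and dtype).

```python
import json

CONFIG_PATH = "E=128,N=1024,device_name=AMD_Instinct_MI300X.json"  # hypothetical


def load_moe_config(path: str) -> dict[int, dict]:
    """Load a tuning table keyed by token batch size M."""
    with open(path) as f:
        return {int(m): cfg for m, cfg in json.load(f).items()}


def pick_config(configs: dict[int, dict], m: int) -> dict:
    """Pick the tuned kernel config whose batch-size key is closest to m."""
    return configs[min(configs, key=lambda k: abs(k - m))]


if __name__ == "__main__":
    configs = load_moe_config(CONFIG_PATH)
    print(pick_config(configs, m=40))  # falls back to the nearest tuned key
```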
Wondering about the old approach: blindly setting the guard, is there any problem with it?
Yeah, with `ROCR_VISIBLE_DEVICES` we can only see one device, and the device guard will use device X (X >= 1), which will fail.
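A hypothetical repro of that failure (not from the PR): with `ROCR_VISIBLE_DEVICES=3`, the process sees a single GPU as `cuda:0`, so a guard targeting ordinal 3 raises.

```python
# Run as: ROCR_VISIBLE_DEVICES=3 python repro.py  (on a multi-GPU ROCm box)
import torch

print(torch.cuda.device_count())  # prints 1: only the selected GPU is visible

try:
    with torch.cuda.device(3):  # ordinal 3 does not exist in this process
        pass
except RuntimeError as err:
    print(f"device guard failed: {err}")
```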
This seems like a good potential fix and could be used to remove the dependency on the env variable `RAY_EXPERIMENTAL_NOSET_ROCR_VISIBLE_DEVICES=1` (vllm/docker/Dockerfile.rocm, line 114 at e1a2c69).

However, it comes with a caveat: this would require users to be mindful of `HIP_VISIBLE_DEVICES` vs `ROCR_VISIBLE_DEVICES`; between the two, `HIP_VISIBLE_DEVICES` is more commonly used. For example, if `HIP_VISIBLE_DEVICES` is set in the env, this PR would throw the following error:
This feels more like a Ray problem: it probably shouldn't set `ROCR_VISIBLE_DEVICES`, or it should set `ROCR_VISIBLE_DEVICES` based on `HIP_VISIBLE_DEVICES`. There are people who don't use Docker and install from source, and they will hit this problem.
I could force `HIP_VISIBLE_DEVICES` to be the same as `ROCR_VISIBLE_DEVICES`.
Maybe we should handle HIP_VISIBLE_DEVICES as well?
Yes, let's add a guard which avoids any mismatch between `HIP_VISIBLE_DEVICES` and `ROCR_VISIBLE_DEVICES`.
It's not super clear to me how to add the guard: the check happens at import time, when the Ray worker starts. So I just deleted the `HIP_VISIBLE_DEVICES` env var; I don't think Ray would handle it anyway (it will always use 8 GPUs regardless of `HIP_VISIBLE_DEVICES`). Let me know what you think.