
Conversation

@xw285cornell (Contributor)
@xw285cornell xw285cornell commented Apr 8, 2025

This PR improves the device name handling and adds tuning files for Llama4 Maverick.

For OAM, amdsmi's amdsmi_get_gpu_asic_info returns "MI300X-O", which is an abbreviation. Change to amdsmi_get_gpu_board_info, which seems to be a more reliable source of naming. We need to confirm with AMD whether that applies to all MI300 SKUs.

When running the benchmark_moe script, it fails with "invalid device ordinal". This is because Ray sets ROCR_VISIBLE_DEVICES so that each subprocess can only see one device. So we first check whether ROCR_VISIBLE_DEVICES is set; if so, we skip the torch.cuda.device context manager.
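As a rough sketch of that check (the helper name below is hypothetical, not the PR's actual code), the device guard can simply be skipped whenever Ray has already restricted visibility:

```python
import os
from contextlib import nullcontext


def device_guard(device_id: int):
    """Return a context manager selecting the target GPU.

    Hypothetical sketch: when Ray has already narrowed visibility via
    ROCR_VISIBLE_DEVICES, each worker sees exactly one device, so
    selecting device_id >= 1 would raise "invalid device ordinal".
    In that case we skip the torch.cuda.device guard entirely.
    """
    if os.environ.get("ROCR_VISIBLE_DEVICES") is not None:
        return nullcontext()
    import torch  # deferred so the env-var path needs no GPU stack

    return torch.cuda.device(device_id)
```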

And finally, add the missing tuning file for the 128-expert Llama4 Maverick model.

Improving #16114

@github-actions

github-actions bot commented Apr 8, 2025

👋 Hi! Thank you for contributing to the vLLM project.

💬 Join our developer Slack at https://slack.vllm.ai to discuss your PR in #pr-reviews, coordinate on features in #feat- channels, or join special interest groups in #sig- channels.

Just a reminder: PRs would not trigger full CI run by default. Instead, it would only run fastcheck CI which starts running only a small and essential subset of CI tests to quickly catch errors. You can run other CI tests on top of those by going to your fastcheck build on Buildkite UI (linked in the PR checks section) and unblock them. If you do not have permission to unblock, ping simon-mo or khluu to add you in our Buildkite org.

Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging.

To run CI, PR reviewers can either: Add ready label to the PR or enable auto-merge.

🚀

Contributor

From our past experiments, we found amdsmi_get_gpu_asic_info()["market_name"] to be more reliable across a set of different MI Instinct machines. Hence, we should stick to the previous implementation for distinguishing AMD GPU names.
Also, if we make this change, we might need to update the names of the existing tuned configs for AMD GPUs.

Contributor Author

@divakar-amd for our SKU it returns MI300-O, which is not really usable. Can you help dump some output from both board info and asic info, and we'll see which one is better?

Collaborator

I am wondering if we can just introduce another mapping, or rules, to map to a generic name?

For example: MI300X-O => MI300X, and MI300X => MI300X?
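A minimal sketch of that suggestion (the map contents and function name are illustrative, not the PR's code; unknown names pass through unchanged):

```python
# Illustrative normalization map; only the SKU mentioned in this thread
# is listed here.
_GENERIC_NAME_MAP = {
    "MI300X-O": "MI300X",  # OAM SKU as reported by amdsmi
}


def normalize_device_name(raw_name: str) -> str:
    """Map a vendor-specific SKU string to a generic device name,
    falling back to the raw name when no mapping exists."""
    return _GENERIC_NAME_MAP.get(raw_name, raw_name)
```

As the discussion notes, a blanket mapping like this may not hold for every SKU, so this is only the shape of the idea.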


There are some SKUs we cannot talk about that would break if you do that.

Contributor

@divakar-amd commented Apr 8, 2025

@xw285cornell @houseroad Let's retain the name as it is for your SKU (i.e. MI300X-O) and stick to amdsmi_get_gpu_asic_info. We'll push another config too with MI300X. So we'll have both MI300X-O and MI300X.

Contributor Author

@xw285cornell commented Apr 8, 2025

@divakar-amd can we get on the Slack channel and discuss a solution? I don't really have a strong opinion on asic vs. board, but duplicating the config file seems really ugly.

Collaborator

Agree, it's ugly if they are just duplicated, and it may fail on another type of similar ASIC.

Collaborator

@houseroad left a comment

Thanks for the improvement!

Collaborator

Wondering about the old approach (blindly setting the guard): is there any problem with it?

Contributor Author

Yeah, with ROCR_VISIBLE_DEVICES set, we can only see one device, but the device guard will try to use device X (X >= 1), and this will fail.

Contributor

This seems like a good potential fix and can be used to remove the dependency on the env variable RAY_EXPERIMENTAL_NOSET_ROCR_VISIBLE_DEVICES=1:

ENV RAY_EXPERIMENTAL_NOSET_ROCR_VISIBLE_DEVICES=1

However, it comes with a caveat: this would require users to be mindful of HIP_VISIBLE_DEVICES vs. ROCR_VISIBLE_DEVICES; between the two, HIP_VISIBLE_DEVICES is more commonly used.
For example, if HIP_VISIBLE_DEVICES is set in the env, this PR would throw the following error:

RuntimeError: HIP_VISIBLE_DEVICES contains more devices than ROCR_VISIBLE_DEVICES

Contributor Author

This feels more like a Ray problem: it probably shouldn't set ROCR_VISIBLE_DEVICES, or it should set ROCR_VISIBLE_DEVICES based on HIP_VISIBLE_DEVICES. There are people who don't use Docker and install from source, and they will hit this problem.

Contributor Author

I could force HIP_VISIBLE_DEVICES to be the same as ROCR_VISIBLE_DEVICES.

Collaborator

Maybe we should handle HIP_VISIBLE_DEVICES as well?

Contributor

Yes, let's add a guard which avoids any mismatch between HIP_VISIBLE_DEVICES and ROCR_VISIBLE_DEVICES.
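One possible shape for such a guard (hypothetical function; the comparison rule below is a guess based on the error message quoted earlier, not vLLM's actual check):

```python
import os


def check_visible_devices_consistent() -> None:
    """Fail fast when HIP_VISIBLE_DEVICES lists more devices than
    ROCR_VISIBLE_DEVICES, instead of surfacing a confusing runtime
    error later."""
    hip = os.environ.get("HIP_VISIBLE_DEVICES")
    rocr = os.environ.get("ROCR_VISIBLE_DEVICES")
    if hip is None or rocr is None:
        return  # nothing to cross-check
    if len(hip.split(",")) > len(rocr.split(",")):
        raise RuntimeError(
            "HIP_VISIBLE_DEVICES contains more devices than "
            "ROCR_VISIBLE_DEVICES")
```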

Contributor Author

It's not super clear to me how to add the guard, since the check happens at import time when the Ray worker starts. So I just deleted the HIP_VISIBLE_DEVICES env var; I don't think Ray will handle it anyway (it'll always use 8 GPUs regardless of HIP_VISIBLE_DEVICES). Let me know what you think.


Collaborator

The map approach is much cleaner. :-)

Contributor

LGTM!

Contributor

If you want, you could include these too in the map:

    "0x74a0": "MI300A",
    "0x74a1": "MI300X",
    "0x74b5": "MI300X",  # MI300X VF
    "0x74a5": "MI325X",
    "0x74b9": "MI325X",  # MI325X VF
    "0x74a9": "MI300X-HF",
    "0x74bd": "MI300X-HF",

Collaborator

lint?

Contributor

  • Can we add a log message?
  • Also, let's use the value of HIP_VISIBLE_DEVICES to set ROCR_VISIBLE_DEVICES. This would allow us to expose the number of GPUs for tuning correctly, e.g. if you only want to tune over 4 GPUs.

Something like

import os

if "HIP_VISIBLE_DEVICES" in os.environ:
    logger.warning(
        "Removing HIP_VISIBLE_DEVICES. Using ROCR_VISIBLE_DEVICES "
        "for GPU visibility for Ray."
    )
    # Mirror HIP_VISIBLE_DEVICES into ROCR_VISIBLE_DEVICES, then drop it.
    val = os.environ["HIP_VISIBLE_DEVICES"]
    os.environ["ROCR_VISIBLE_DEVICES"] = val
    del os.environ["HIP_VISIBLE_DEVICES"]

Contributor Author

sounds good!

@xw285cornell
Contributor Author

@divakar-amd fixed if you want to take a look again

Collaborator

@houseroad left a comment

Looks good to me. Will temporarily put on hold until internal ROCm upgrade is done, sorry about the inconvenience.

Contributor

(nit) accessibility . -> accessibility.

@houseroad houseroad self-requested a review May 1, 2025 19:03
@houseroad
Collaborator

@xw285cornell could you rebase again? We should be good to merge this PR :-)

@xw285cornell
Contributor Author

sounds good, let me rebase

@xw285cornell
Contributor Author

done, rebased :)

Collaborator

@houseroad left a comment

Looks good. Could you address the lint?

Signed-off-by: Lu Fang <[email protected]>
Collaborator

@houseroad left a comment

Anyway, I lint it :-)

@houseroad houseroad added rocm Related to AMD ROCm ready ONLY add when PR is ready to merge/full CI is needed labels May 2, 2025
@xw285cornell
Contributor Author

Oh sorry, didn't notice the lint error, thanks!

@houseroad houseroad enabled auto-merge (squash) May 2, 2025 18:29
@houseroad houseroad merged commit 9352cdb into vllm-project:main May 2, 2025
72 of 73 checks passed
RichardoMrMu pushed a commit to RichardoMrMu/vllm that referenced this pull request May 12, 2025
mawong-amd pushed a commit to ROCm/vllm that referenced this pull request May 14, 2025
zzzyq pushed a commit to zzzyq/vllm that referenced this pull request May 24, 2025
