remove hardcoded device="cuda" to support more device
#2503
Conversation
WoosukKwon
left a comment
@jikunshang Thanks for submitting the PR! While the code looks good overall, I have two concerns:
- I believe `device` should be automatically detected instead of being declared by users. This also aligns with our current design. While I don't know a good way to implement this, I feel there should be one as long as the device is supported by PyTorch.
- I believe `device` should not be an attribute of `ModelConfig`. Can we make a new config class like `DeviceConfig`?
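For context, a `DeviceConfig` along the lines the reviewer suggests might look like this (a minimal sketch; the field name, the `"auto"` sentinel, and the detection logic are assumptions, not vLLM's actual implementation):

```python
from dataclasses import dataclass


def _cuda_available() -> bool:
    # Probe defensively so the sketch also works without PyTorch installed.
    try:
        import torch
        return torch.cuda.is_available()
    except ImportError:
        return False


@dataclass
class DeviceConfig:
    # Hypothetical sentinel: "auto" means "detect the device at init time";
    # any other value is taken as an explicit user choice.
    device: str = "auto"

    def __post_init__(self):
        if self.device == "auto":
            self.device = "cuda" if _cuda_available() else "cpu"
```

A config object like this keeps device selection out of `ModelConfig` while still letting callers override it explicitly, e.g. `DeviceConfig(device="xpu")`.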
```diff
             pad=0,
-            dtype=torch.long)
+            dtype=torch.long,
+            device=self.device)
```
Can we use the `_set_default_torch_device` context manager here to avoid repeating `device=self.device`?
On second thought, I found that explicitly specifying the device would be better, since we might mix CPU and accelerator tensors in some cases.
Force-pushed from 2df9f74 to 114a846
Force-pushed from 37ff8f3 to 88782e0
Hi @jikunshang, could you resolve the merge conflicts? The PR looks good overall.
Sure, I have resolved the conflicts. All tests have been fixed.
```diff
-        self.device = torch.device(torch.cuda.current_device())
+        self.device_config = (device_config
+                              if device_config is not None else DeviceConfig())
+        self.device = self.device_config.device
```
BTW, I intentionally avoided using `torch.set_current_device`, since it can affect user code when the `LLM` class is used.
WoosukKwon
left a comment
@jikunshang LGTM! Thanks for submitting the PR and sorry for the delay in my second review.
While vLLM still has several `torch.cuda` calls, I believe this is a good first step towards supporting non-CUDA devices. Thanks for the great work!
…ct#2503) Co-authored-by: Jiang Li <[email protected]> Co-authored-by: Kunshang Ji <[email protected]>
Refer to #1948: a lot of code uses `cuda` as the device, especially in tensor creation, which is not friendly to adding support for other devices. This PR refactors the code to leave an interface for more easily adding new devices like `cpu` or `xpu`.
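One plausible shape for the automatic device detection the reviewer asked for is sketched below (the helper name and probing order are assumptions, not the PR's actual code; `torch.xpu` is probed defensively since it only exists in some PyTorch builds):

```python
def detect_device() -> str:
    """Best-effort device autodetection: prefer an accelerator if one
    is visible to PyTorch, otherwise fall back to CPU."""
    try:
        import torch
        if torch.cuda.is_available():
            return "cuda"
        # xpu support is optional; guard the attribute before probing it.
        if hasattr(torch, "xpu") and torch.xpu.is_available():
            return "xpu"
    except ImportError:
        pass
    return "cpu"
```

A helper like this could feed the default value of a device config, so users only specify a device when they want to override the detected one.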