
Conversation

@leizhenyuan (Contributor)

Hi Unsloth team, we are going to add Intel GPU support to Unsloth over several PRs; this is the third one.

  • add Intel-dependent packages for PyTorch 2.6 in pyproject.toml
  • generalize device types and refactor device-biased code in `__init__.py`
  • refactor device-biased code in the kernels
  • refactor device-biased code for the Llama models

As a first step we are aiming to support several models with LoRA, and will expand feature coverage in the future (including BNB, FlashAttention, and xformers).

In this PR, we resolve the device-specific APIs for CUDA and Intel GPU (XPU) in the model utils and the Llama model.
For the CUDA-specific paths we did not change any logic; we only added device checks and indentation to keep the Python syntax valid. A sketch of that pattern follows below.
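
For illustration only, a minimal sketch of the pattern described above; `DEVICE_TYPE` and the detection logic here are assumptions based on the snippets in this PR, not the exact diff:

```python
import torch

# Assumed device detection: prefer Intel GPU (XPU) when available, else CUDA.
DEVICE_TYPE = "xpu" if hasattr(torch, "xpu") and torch.xpu.is_available() else "cuda"

if DEVICE_TYPE == "cuda":
    # Original CUDA-only logic, unchanged apart from the added indentation.
    device_properties = torch.cuda.get_device_properties(0)
elif DEVICE_TYPE == "xpu":
    # New XPU branch mirroring the CUDA path.
    device_properties = torch.xpu.get_device_properties(0)
```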

cc:

```python
        torch_amp_custom_bwd = torch.amp.custom_bwd(device_type = "cuda")
    pass
elif DEVICE_TYPE == "xpu":
    if Version(torch_version) < Version("2.4.0"):
```


Suggested change:

```diff
-    if Version(torch_version) < Version("2.4.0"):
+    if Version(torch_version) < Version("2.6.0"):
```

if DEVICE_TYPE == "cuda":
major_version, minor_version = torch.cuda.get_device_capability()



Is just one blank line OK here?
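
As background for the snippet above, a hedged sketch of a device-general capability check; note that `torch.xpu.get_device_capability` returns a dict of properties rather than a `(major, minor)` tuple, and the bf16 assumption below is illustrative, not Unsloth's actual check:

```python
import torch

DEVICE_TYPE = "xpu" if hasattr(torch, "xpu") and torch.xpu.is_available() else "cuda"

if DEVICE_TYPE == "cuda":
    # CUDA reports compute capability as a (major, minor) tuple;
    # bfloat16 needs Ampere (SM 8.0) or newer.
    major_version, minor_version = torch.cuda.get_device_capability()
    SUPPORTS_BFLOAT16 = major_version >= 8
elif DEVICE_TYPE == "xpu":
    # torch.xpu.get_device_capability returns a dict of properties,
    # so there is no (major, minor) pair to unpack on Intel GPUs.
    xpu_capability = torch.xpu.get_device_capability()
    SUPPORTS_BFLOAT16 = True  # assumption: targeted Intel GPUs support bf16
```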

@gujinghui

@danielhanchen, @shimmyshimmer
Could you help review this PR? Thanks a lot!

```python
try:    vllm_version = f" vLLM: {importlib_version('vllm')}."
except: vllm_version = ""

statistics = \
```
Contributor

On printouts, it's best not to make 2 copies of the block - instead define new variables like `device_count = torch.xpu.device_count() if DEVICE_TYPE == "xpu" else torch.cuda.device_count()`
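
A minimal sketch of that suggestion; the statistics string and variable names here are illustrative, not Unsloth's actual printout:

```python
import torch

DEVICE_TYPE = "xpu" if hasattr(torch, "xpu") and torch.xpu.is_available() else "cuda"

# Resolve the device-specific values once, then build a single statistics
# string, instead of duplicating the whole printout per device type.
device_count = torch.xpu.device_count() if DEVICE_TYPE == "xpu" else torch.cuda.device_count()
device_name  = (torch.xpu.get_device_name(0) if DEVICE_TYPE == "xpu"
                else torch.cuda.get_device_name(0))

statistics = f"Device: {device_name}. Num GPUs = {device_count}."
print(statistics)
```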

@leizhenyuan (Contributor, Author)

Closed this PR as a duplicate of #2801.
