diff --git a/unsloth/kernels/cross_entropy_loss.py b/unsloth/kernels/cross_entropy_loss.py
index 018e89c90..e0af458a9 100644
--- a/unsloth/kernels/cross_entropy_loss.py
+++ b/unsloth/kernels/cross_entropy_loss.py
@@ -21,6 +21,7 @@
     triton_tanh,
     triton_cast,
     torch_gpu_device,
+    is_cdna,
 )
 from transformers.models.llama.modeling_llama import logger
 from packaging.version import Version
@@ -332,7 +333,7 @@ def forward(ctx, logits, labels, logit_softcapping : float = 0, logit_scaling :
                 SOFTCAP = logit_softcapping,
                 DO_LOGIT_SCALING = DO_LOGIT_SCALING,
                 LOGIT_SCALE = logit_scaling,
-                num_warps = 32,
+                num_warps = 32 if not is_cdna() else 16,
             )
             # logsumexp(chunked_logsumexp) - x # Do the -x separately
diff --git a/unsloth/kernels/utils.py b/unsloth/kernels/utils.py
index cb8982df0..03397572d 100644
--- a/unsloth/kernels/utils.py
+++ b/unsloth/kernels/utils.py
@@ -62,6 +62,14 @@ def triton_cast(x, dtype):
 pass
 
 
+def is_hip():
+    return triton.runtime.driver.active.get_current_target().backend == "hip"
+
+
+def is_cdna():
+    return is_hip() and triton.runtime.driver.active.get_current_target().arch in ('gfx940', 'gfx941', 'gfx942')
+
+
 def calculate_settings(n : int) -> (int, int,):
     BLOCK_SIZE : int = next_power_of_2(n)
     if BLOCK_SIZE > MAX_FUSED_SIZE:
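
For context, a minimal standalone sketch of how the new helpers and the `num_warps` change fit together. It assumes Triton 3.x, where `triton.runtime.driver.active.get_current_target()` returns a target with `backend` and `arch` attributes (the API the diff relies on); the `__main__` demo and the comment about CDNA wavefront width are illustrative additions, not part of the patch.

```python
# Sketch (not part of the patch): the detection helpers added to
# unsloth/kernels/utils.py, plus a demo of the num_warps gating used in
# cross_entropy_loss.py. Requires Triton 3.x with an active GPU backend,
# since get_current_target() queries the current device.
import triton


def is_hip() -> bool:
    # True when Triton is running on the AMD ROCm/HIP backend.
    return triton.runtime.driver.active.get_current_target().backend == "hip"


def is_cdna() -> bool:
    # True on MI300-class CDNA3 GPUs (gfx940/941/942), the arches the patch checks.
    return is_hip() and triton.runtime.driver.active.get_current_target().arch in (
        "gfx940", "gfx941", "gfx942",
    )


if __name__ == "__main__":
    # Mirrors the kernel-launch change: drop from 32 to 16 warps on CDNA,
    # presumably because the 64-wide wavefronts there cap a workgroup at
    # 16 waves (16 * 64 = 1024 threads).
    num_warps = 32 if not is_cdna() else 16
    print(f"hip={is_hip()}  cdna={is_cdna()}  num_warps={num_warps}")
```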