8 changes: 4 additions & 4 deletions unsloth/kernels/cross_entropy_loss.py
@@ -20,7 +20,7 @@
MAX_FUSED_SIZE,
triton_tanh,
triton_cast,
- torch_cuda_device,
+ torch_gpu_device,
)
from transformers.models.llama.modeling_llama import logger
from packaging.version import Version
@@ -301,7 +301,7 @@ def forward(ctx, logits, labels, logit_softcapping : float = 0, logit_scaling :
BLOCK_SIZE, num_warps = calculate_settings(vocab_size)
logsumexp = torch.empty(n_rows, dtype = torch.float32, device = device)

- with torch_cuda_device(device):
+ with torch_gpu_device(device):
_cross_entropy_forward[(n_rows,)](
logits, logits.stride(0),
losses,
@@ -319,7 +319,7 @@ def forward(ctx, logits, labels, logit_softcapping : float = 0, logit_scaling :
# For large vocabs > 65336 like Gemma 256K
logsumexp = torch.empty((n_rows, n_chunks,), dtype = torch.float32, device = device)

- with torch_cuda_device(device):
+ with torch_gpu_device(device):
_chunked_cross_entropy_forward[(n_rows, n_chunks,)](
logits, logits.stride(0),
losses,
@@ -363,7 +363,7 @@ def backward(ctx, dlosses):
div, mod = divmod(vocab_size, BLOCK_SIZE)
n_blocks : int = div + (mod != 0)

- with torch_cuda_device(dlosses.device):
+ with torch_gpu_device(dlosses.device):
_cross_entropy_backward[(n_rows, n_blocks,)](
logits, logits.stride(0),
dlosses, dlosses.stride(0),
10 changes: 5 additions & 5 deletions unsloth/kernels/geglu.py
@@ -18,7 +18,7 @@
from .utils import (
calculate_settings,
triton_tanh,
- torch_cuda_device,
+ torch_gpu_device,
)


@@ -48,7 +48,7 @@ def geglu_exact_forward_kernel(gate, up):
device = gate.device
out = torch.empty((batch, seq_len, hd), dtype = gate.dtype, device = device)
grid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE']),)
- with torch_cuda_device(device):
+ with torch_gpu_device(device):
_exact_forward_kernel[grid](gate, up, out, n_elements, BLOCK_SIZE = 1024,)
return out
pass
@@ -105,7 +105,7 @@ def geglu_exact_backward_kernel(DW, e, g):
batch_seq_len, hd = e.shape
n_elements = e.numel()
grid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE']),)
- with torch_cuda_device(e.device):
+ with torch_gpu_device(e.device):
_exact_backward_kernel[grid](DW, e, g, n_elements, BLOCK_SIZE = 1024,)
return DW, e, g
pass
@@ -143,7 +143,7 @@ def geglu_approx_forward_kernel(gate, up):
device = gate.device
out = torch.empty((batch, seq_len, hd), dtype = gate.dtype, device = device)
grid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE']),)
- with torch_cuda_device(device):
+ with torch_gpu_device(device):
_approx_forward_kernel[grid](gate, up, out, n_elements, BLOCK_SIZE = 1024,)
return out
pass
@@ -207,7 +207,7 @@ def geglu_approx_backward_kernel(DW, e, g):
batch_seq_len, hd = e.shape
n_elements = e.numel()
grid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE']),)
- with torch_cuda_device(e.device):
+ with torch_gpu_device(e.device):
_approx_backward_kernel[grid](DW, e, g, n_elements, BLOCK_SIZE = 1024,)
return DW, e, g
pass
6 changes: 3 additions & 3 deletions unsloth/kernels/layernorm.py
@@ -16,7 +16,7 @@
import triton
import triton.language as tl
import torch
- from .utils import calculate_settings, torch_cuda_device
+ from .utils import calculate_settings, torch_gpu_device
from unsloth_zoo.patching_utils import (
patch_layernorm,
)
@@ -113,7 +113,7 @@ def forward(ctx, X, W, b, eps):
r = torch.empty(n_rows, dtype = torch.float32, device = device)
mu = torch.empty(n_rows, dtype = torch.float32, device = device)

- with torch_cuda_device(device):
+ with torch_gpu_device(device):
layernorm_forward[(n_rows,)](
Y, Y.stride(0),
X, X.stride(0),
@@ -140,7 +140,7 @@ def backward(ctx, dY):
X, W, b, r, mu = ctx.saved_tensors
n_rows, n_cols = dY.shape

- with torch_cuda_device(dY.device):
+ with torch_gpu_device(dY.device):
layernorm_backward[(n_rows,)](
dY, dY.stride(0),
X, X .stride(0),
6 changes: 3 additions & 3 deletions unsloth/kernels/rms_layernorm.py
@@ -15,7 +15,7 @@
import triton
import triton.language as tl
import torch
- from .utils import calculate_settings, torch_cuda_device
+ from .utils import calculate_settings, torch_gpu_device

@triton.jit
def _rms_layernorm_forward(
@@ -156,7 +156,7 @@ def forward(ctx, X : torch.Tensor, W : torch.Tensor, eps : float, gemma : bool =
r = torch.empty(n_rows, dtype = torch.float32, device = device)

fx = _gemma_rms_layernorm_forward if gemma else _rms_layernorm_forward
- with torch_cuda_device(device):
+ with torch_gpu_device(device):
fx[(n_rows,)](
Y, Y.stride(0),
X, X.stride(0),
@@ -186,7 +186,7 @@ def backward(ctx, dY : torch.Tensor):
# dW = X
dX = torch.empty_like(dY) if ctx.GEMMA else dY

- with torch_cuda_device(dY.device):
+ with torch_gpu_device(dY.device):
_rms_layernorm_backward[(n_rows,)](
dY, dY.stride(0),
dX, dX.stride(0),
6 changes: 3 additions & 3 deletions unsloth/kernels/rope_embedding.py
@@ -15,7 +15,7 @@
import triton
import triton.language as tl
import torch
- from .utils import calculate_settings, torch_cuda_device
+ from .utils import calculate_settings, torch_gpu_device
ROPE_GROUP_SIZE : int = 4

def _rope_embedding(
@@ -100,7 +100,7 @@ def forward(ctx, Q, cos, sin):
div, mod = divmod(n_heads, ROPE_GROUP_SIZE)
n_groups : int = div + (mod != 0)

- with torch_cuda_device(Q.device):
+ with torch_gpu_device(Q.device):
_rope_embedding[(n_rows, n_groups, )](
Q, Q.stride(0),
cos, cos.stride(0),
@@ -135,7 +135,7 @@ def backward(ctx, dY):
cos = ctx.cos
sin = ctx.sin

- with torch_cuda_device(dY.device):
+ with torch_gpu_device(dY.device):
_rope_embedding[(n_rows, ctx.n_groups, )](
dY, dY .stride(0),
cos, cos.stride(0),
6 changes: 3 additions & 3 deletions unsloth/kernels/swiglu.py
@@ -15,7 +15,7 @@
import triton
import triton.language as tl
import torch
- from .utils import calculate_settings, torch_cuda_device
+ from .utils import calculate_settings, torch_gpu_device


@triton.jit
@@ -43,7 +43,7 @@ def swiglu_fg_kernel(e, g):
n_elements = e.numel()
h = torch.empty((batch, seq_len, hd), dtype = e.dtype, device = e.device)
grid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE']),)
- with torch_cuda_device(e.device):
+ with torch_gpu_device(e.device):
_fg_kernel[grid](e, g, h, n_elements, BLOCK_SIZE = 1024,)
return h
pass
@@ -95,7 +95,7 @@ def swiglu_DWf_DW_dfg_kernel(DW, e, g):
batch_seq_len, hd = e.shape
n_elements = e.numel()
grid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE']),)
- with torch_cuda_device(e.device):
+ with torch_gpu_device(e.device):
_DWf_DW_dfg_kernel[grid](DW, e, g, n_elements, BLOCK_SIZE = 1024,)
return DW, e, g
pass
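Note: the `torch_gpu_device` helper itself lives in `unsloth/kernels/utils.py` (it is imported via `from .utils import ...` above) and is not part of this diff. As a rough, hypothetical sketch only, not the actual Unsloth implementation, a device-agnostic launch context of this kind could look like:

```python
import contextlib
import torch

# Hypothetical sketch -- NOT the real unsloth/kernels/utils.py code.
# Assumes the helper just picks the matching per-backend device context.
def torch_gpu_device(device):
    device = torch.device(device)
    if device.type == "cuda":
        # Standard CUDA device context manager.
        return torch.cuda.device(device)
    if device.type == "xpu" and hasattr(torch, "xpu"):
        # Analogous context on XPU builds of PyTorch.
        return torch.xpu.device(device)
    # No-op for CPU or unrecognized device types.
    return contextlib.nullcontext()
```

Wrapping each Triton launch in such a context (e.g. `with torch_gpu_device(gate.device): _exact_forward_kernel[grid](...)`) keeps the kernel pinned to the tensor's own device on multi-GPU machines, without hard-coding a CUDA-only code path.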