21 changes: 21 additions & 0 deletions ext/cuda/cuda_utils.jl
@@ -209,6 +209,27 @@ function threads_via_occupancy(f!::F!, args) where {F!}
     return config.threads
 end

+function config_via_occupancy(f!::F!, nitems, args) where {F!}
+    kernel = CUDA.@cuda always_inline = true launch = false f!(args...)
+    config = CUDA.launch_configuration(kernel.fun)
+    SM_count = CUDA.attribute(CUDA.device(), CUDA.DEVICE_ATTRIBUTE_MULTIPROCESSOR_COUNT)
+    max_block_size =
+        CUDA.attribute(CUDA.device(), CUDA.DEVICE_ATTRIBUTE_MAX_BLOCK_DIM_X)
+    if cld(nitems, config.threads) < config.blocks
+        # GPU will not saturate, so spread out threads across more SMs
+        even_distribution_threads = cld(nitems, SM_count)
+        even_distribution_threads =
+            even_distribution_threads > max_block_size ?
+            div(even_distribution_threads, 2) : even_distribution_threads
+        threads = min(even_distribution_threads, config.threads)
+        blocks = cld(nitems, threads)
+    else
+        threads = min(nitems, config.threads)
+        blocks = cld(nitems, threads)
+    end
+    return (; threads, blocks)
+end
+
 """
     thread_index()

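A worked sketch of the new `config_via_occupancy` heuristic, with hypothetical occupancy numbers (the `config` values, SM count, and block limit below are illustrative, not measured):

    # Suppose CUDA.launch_configuration suggests (threads = 768, blocks = 216)
    # on a device with SM_count = 108 and max_block_size = 1024.
    nitems = 320                          # a small columnwise workload
    cld(nitems, 768)                      # = 1 < 216: the GPU would not saturate
    threads = min(cld(nitems, 108), 768)  # = 3 threads per block
    blocks = cld(nitems, threads)         # = 107 blocks, spread across most SMs
    # For a large workload such as nitems = 63 * 4 * 4 * 1500 = 1_512_000,
    # cld(nitems, 768) = 1969 >= 216, so the occupancy config is used directly.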
46 changes: 46 additions & 0 deletions ext/cuda/data_layouts_copyto.jl
@@ -21,6 +21,14 @@ function knl_copyto_linear!(dest, src, us)
     return nothing
 end

+function knl_copyto_VIJFH_64!(dest, src, ::Val{P}) where {P}
+    # P is a boolean indicating whether the column is padded
+    P && threadIdx().x == 64 && return nothing
+    I = CartesianIndex(threadIdx().y, blockIdx().x, 1, threadIdx().x, blockIdx().y)
+    @inbounds dest[I] = src[I]
+    return nothing
+end
+
 if VERSION ≥ v"1.11.0-beta"
     # https://github.com/JuliaLang/julia/issues/56295
     # Julia 1.11's Base.Broadcast currently requires
@@ -104,6 +112,44 @@ else
     end
 end

+# Specialized kernel launch for VIJFHStyle{63, 4} and VIJFHStyle{64, 4} arrays. This uses
+# block and grid indices instead of computing Cartesian indices from a linear index. The
+# threads are launched so that a set of 64 threads covers a column.
+function Base.copyto!(
+    dest::AbstractData,
+    bc::BC,
+    to::ToCUDA,
+    mask::NoMask = NoMask(),
+) where {BC <: Base.Broadcast.Broadcasted{<:ClimaCore.DataLayouts.VIJFHStyle{63, 4}}}
+    (Ni, Nj, _, Nv, Nh) = DataLayouts.universal_size(dest)
+    Nv > 0 && Nh > 0 || return dest
+    args = (dest, bc, Val(true))
+    auto_launch!(
+        knl_copyto_VIJFH_64!,
+        args;
+        threads_s = (64, Ni, 1),
+        blocks_s = (Nj, Nh, 1),
+    )
+    return dest
+end
+function Base.copyto!(
+    dest::AbstractData,
+    bc::BC,
+    to::ToCUDA,
+    mask::NoMask = NoMask(),
+) where {BC <: Base.Broadcast.Broadcasted{<:ClimaCore.DataLayouts.VIJFHStyle{64, 4}}}
+    (Ni, Nj, _, Nv, Nh) = DataLayouts.universal_size(dest)
+    Nv > 0 && Nh > 0 || return dest
+    args = (dest, bc, Val(false))
+    auto_launch!(
+        knl_copyto_VIJFH_64!,
+        args;
+        threads_s = (64, Ni, 1),
+        blocks_s = (Nj, Nh, 1),
+    )
+    return dest
+end
+
 # broadcasting scalar assignment
 # Performance optimization for the common identity scalar case: dest .= val
 # And this is valid for the CPU or GPU, since the broadcasted object
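The launch geometry assumed by these specializations pins each column to 64 consecutive threads: `threadIdx().x` is the level index `v`, `threadIdx().y` the nodal index `i`, and the block indices carry `(j, h)`. Using 64 threads even when `Nv == 63` presumably keeps each column on exactly two full warps; the `Val(true)` padding flag simply retires the 64th thread. A sketch of the mapping inside `knl_copyto_VIJFH_64!` (the comments on index roles are inferred, not stated in the PR):

    I = CartesianIndex(
        threadIdx().y, # i in 1:Ni
        blockIdx().x,  # j in 1:Nj
        1,             # f (field index, fixed)
        threadIdx().x, # v in 1:64; thread 64 exits early when the column is padded
        blockIdx().y,  # h in 1:Nh
    )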
9 changes: 3 additions & 6 deletions ext/cuda/matrix_fields_multiple_field_solve.jl
@@ -38,15 +38,12 @@ NVTX.@annotate function multiple_field_solve!(
     args = (device, caches, xs, As, bs, x1, us, mask, cart_inds, Val(Nnames))

     nitems = Ni * Nj * Nh * Nnames
-    threads = threads_via_occupancy(multiple_field_solve_kernel!, args)
-    n_max_threads = min(threads, nitems)
-    p = linear_partition(nitems, n_max_threads)
-
+    (; threads, blocks) = config_via_occupancy(multiple_field_solve_kernel!, nitems, args)
     auto_launch!(
         multiple_field_solve_kernel!,
         args;
-        threads_s = p.threads,
-        blocks_s = p.blocks,
+        threads_s = threads,
+        blocks_s = blocks,
         always_inline = true,
     )
     call_post_op_callback() && post_op_callback(x, dev, cache, x, A, b, x1)
8 changes: 3 additions & 5 deletions ext/cuda/matrix_fields_single_field_solve.jl
@@ -19,15 +19,13 @@ function single_field_solve!(device::ClimaComms.CUDADevice, cache, x, A, b)
     mask = Spaces.get_mask(axes(x))
     cart_inds = cartesian_indices_columnwise(us)
     args = (device, cache, x, A, b, us, mask, cart_inds)
-    threads = threads_via_occupancy(single_field_solve_kernel!, args)
     nitems = Ni * Nj * Nh
-    n_max_threads = min(threads, nitems)
-    p = linear_partition(nitems, n_max_threads)
+    (; threads, blocks) = config_via_occupancy(single_field_solve_kernel!, nitems, args)
     auto_launch!(
         single_field_solve_kernel!,
         args;
-        threads_s = p.threads,
-        blocks_s = p.blocks,
+        threads_s = threads,
+        blocks_s = blocks,
     )
     call_post_op_callback() && post_op_callback(x, device, cache, x, A, b)
 end
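For context, the `linear_partition` helper that both solver files previously used presumably amounts to the following sketch (an assumption about its behavior, not the exact ClimaCore implementation):

    linear_partition(nitems, n_max_threads) =
        (; threads = n_max_threads, blocks = cld(nitems, n_max_threads))

so the old path always launched at the occupancy-limited block size, while `config_via_occupancy` trades block size for block count when the problem is too small to saturate the device.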
47 changes: 46 additions & 1 deletion ext/cuda/operators_finite_difference.jl
@@ -78,6 +78,25 @@ function Base.copyto!(
         )
     else
         bc′ = disable_shmem_style(bc)
+        (Ni, Nj, _, Nv, Nh) = DataLayouts.universal_size(out_fv)
+        # Specialized kernel launch for the common case. This uses block and grid
+        # indices instead of computing Cartesian indices from a linear index.
+        if (Nv == 64 || Nv == 63) && mask isa NoMask && Ni == 4 && Nj == 4 && Nh >= 1500
+            args = (
+                strip_space(out, space),
+                strip_space(bc′, space),
+                axes(out),
+                bounds,
+                Val(Nv == 63),
+            )
+            auto_launch!(
+                copyto_stencil_kernel_64!,
+                args;
+                threads_s = (64, Ni, 1),
+                blocks_s = (Nj, Nh, 1),
+            )
+            return out
+        end
         @assert !any_fd_shmem_style(bc′)
         cart_inds = if mask isa NoMask
             cartesian_indices(us)
@@ -102,7 +121,6 @@ else
         else
             masked_partition(mask, n_max_threads, us)
         end
-
        auto_launch!(
            copyto_stencil_kernel!,
            args;
Expand All @@ -115,6 +133,33 @@ function Base.copyto!(
end
import ClimaCore.DataLayouts: get_N, get_Nv, get_Nij, get_Nij, get_Nh

# Specialized kernel for common case of Nv == 64 or Nv == 63
function copyto_stencil_kernel_64!(
out,
bc::Union{
StencilBroadcasted{CUDAColumnStencilStyle},
Broadcasted{CUDAColumnStencilStyle},
},
space,
bds,
::Val{P},
) where {P}
@inbounds begin
# P is a boolean, indicating if the column is padded
P && threadIdx().x == 64 && return nothing
i = threadIdx().y
j = blockIdx().x
v = threadIdx().x
h = blockIdx().y
hidx = (i, j, h)
(li, lw, rw, ri) = bds
idx = v - 1 + li
val = Operators.getidx(space, bc, idx, hidx)
setidx!(space, out, idx, hidx, val)
end
return nothing
end

function copyto_stencil_kernel!(
out,
bc::Union{
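Given the launch shape at the call site, `threads_s = (64, Ni, 1)` and `blocks_s = (Nj, Nh, 1)`, each thread of `copyto_stencil_kernel_64!` owns one `(v, i)` pair and each block one `(j, h)` pair. A sketch of the resulting index arithmetic for `Nv == 63` (assuming `li` is the lower index bound unpacked from `bds`):

    v = threadIdx().x  # 1:64; thread 64 returns early because Val(P) = Val(true)
    idx = v - 1 + li   # with li = 1 this gives idx in 1:63, one level per thread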
8 changes: 4 additions & 4 deletions test/MatrixFields/field_matrix_solvers.jl
@@ -46,11 +46,11 @@ function test_field_matrix_solver(; test_name, alg, A, b, use_rel_error = false)
     @info "$test_name:\n\tSolve Time = $solve_time_rounded s, \
           Multiplication Time = $mul_time_rounded s (Ratio = \
           $time_ratio_rounded)\n\t$error_string"
-
     if use_rel_error
         @test rel_error < 1e-5
     else
-        @test max_eps_error <= 3
+        # when run with CUDA, the errors are larger due to larger columns
+        @test max_eps_error <= (using_cuda ? 4 : 3)
     end

     # In addition to ignoring the type instabilities from CUDA, ignore those
@@ -59,7 +59,7 @@ function test_field_matrix_solver(; test_name, alg, A, b, use_rel_error = false)
         cuda_frames...,
         cublas_frames...,
         AnyFrameModule(MatrixFields.KrylovKit),
-        AnyFrameModule(Base.CoreLogging),
+        AnyFrameModule(Base.CoreLogging)
     )
     using_cuda ||
         @test_opt ignored_modules = ignored FieldMatrixWithSolver(A, b, alg)
@@ -127,7 +127,7 @@ end
         MatrixFields.BlockLowerTriangularSolve(@name(c)),
         MatrixFields.BlockArrowheadSolve(@name(c)),
         MatrixFields.ApproximateBlockArrowheadIterativeSolve(@name(c)),
-        MatrixFields.StationaryIterativeSolve(; n_iters = using_cuda ? 28 : 18),
+        MatrixFields.StationaryIterativeSolve(; n_iters = using_cuda ? 42 : 18),
     )
     test_field_matrix_solver(;
         test_name = "$(typeof(alg).name.name) for a block diagonal matrix \
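Assuming `max_eps_error` counts error in ulps of the element type, the relaxed CUDA bound allows roughly

    4 * eps(Float64) # ≈ 8.9e-16 relative error

which matches the comment in the diff: the GPU configuration now solves over 63-level columns instead of 20, so more rounding error accumulates per solve.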
14 changes: 10 additions & 4 deletions test/MatrixFields/matrix_field_test_utils.jl
@@ -58,7 +58,7 @@ const comms_device = ClimaComms.device()
 const using_cuda = comms_device isa ClimaComms.CUDADevice
 cuda_module(ext) = using_cuda ? ext.CUDA : ext
 const cuda_mod = cuda_module(Base.get_extension(ClimaComms, :ClimaCommsCUDAExt))
-const cuda_frames = using_cuda ? (AnyFrameModule(cuda_mod),) : ()
+const cuda_frames = using_cuda ? (AnyFrameModule(cuda_mod), AnyFrameModule(Base.StackTraces)) : ()
 const cublas_frames = using_cuda ? (AnyFrameModule(cuda_mod.CUBLAS),) : ()
 const invalid_ir_error = using_cuda ? cuda_mod.InvalidIRError : ErrorException

@@ -343,9 +343,15 @@ end

 # Generate extruded finite difference spaces for testing. Include topography
 # when possible.
-function test_spaces(::Type{FT}) where {FT}
-    velem = 20 # This should be big enough to test high-bandwidth matrices.
-    helem = npoly = 1 # These should be small enough for the tests to be fast.
+function test_spaces(::Type{FT}; high_res = using_cuda) where {FT}
+    if high_res
+        velem = 63
+        helem = 16
+        npoly = 3
+    else
+        velem = 20 # This should be big enough to test high-bandwidth matrices.
+        helem = npoly = 1 # These should be small enough for the tests to be fast.
+    end

     comms_ctx = ClimaComms.SingletonCommsContext(comms_device)
     hdomain = Domains.SphereDomain(FT(10))
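With the new keyword, the high-resolution spaces (whose velem = 63 exercises the specialized Nv == 63 CUDA kernels added in this PR) become the default on GPU runs but can still be bypassed explicitly (a hypothetical call; the return value keeps whatever shape `test_spaces` already had):

    spaces = test_spaces(Float64; high_res = false) # CPU-sized spaces (velem = 20) even under CUDA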