From 1312204356ac997de963a0768fcfe858791712e2 Mon Sep 17 00:00:00 2001 From: Valentin Churavy Date: Wed, 26 Feb 2025 12:28:32 +0100 Subject: [PATCH] Check that malformed allocations throw and don't stackoverflow --- src/KernelAbstractions.jl | 8 ++++---- test/test.jl | 7 +++++++ 2 files changed, 11 insertions(+), 4 deletions(-) diff --git a/src/KernelAbstractions.jl b/src/KernelAbstractions.jl index 582a61e9d..5103882ef 100644 --- a/src/KernelAbstractions.jl +++ b/src/KernelAbstractions.jl @@ -533,15 +533,15 @@ Allocate a storage array appropriate for the computational backend. !!! note Backend implementations **must** implement `allocate(::NewBackend, T, dims::Tuple)` """ -allocate(backend::Backend, T, dims...) = allocate(backend, T, dims) -allocate(backend::Backend, T, dims::Tuple) = throw(MethodError(allocate, (backend, T, dims))) +allocate(backend::Backend, T::Type, dims...) = allocate(backend, T, dims) +allocate(backend::Backend, T::Type, dims::Tuple) = throw(MethodError(allocate, (backend, T, dims))) """ zeros(::Backend, Type, dims...)::AbstractArray Allocate a storage array appropriate for the computational backend filled with zeros. """ -zeros(backend::Backend, T, dims...) = zeros(backend, T, dims) +zeros(backend::Backend, T::Type, dims...) = zeros(backend, T, dims) function zeros(backend::Backend, ::Type{T}, dims::Tuple) where {T} data = allocate(backend, T, dims...) fill!(data, zero(T)) @@ -553,7 +553,7 @@ end Allocate a storage array appropriate for the computational backend filled with ones. """ -ones(backend::Backend, T, dims...) = ones(backend, T, dims) +ones(backend::Backend, T::Type, dims...) = ones(backend, T, dims) function ones(backend::Backend, ::Type{T}, dims::Tuple) where {T} data = allocate(backend, T, dims) fill!(data, one(T)) diff --git a/test/test.jl b/test/test.jl index d86d9803b..4e017c8a9 100644 --- a/test/test.jl +++ b/test/test.jl @@ -307,6 +307,13 @@ function unittest_testsuite(Backend, backend_str, backend_mod, BackendArrayT; sk @test size(KernelAbstractions.zeros(backend, Float32, 0, 9)) == (0, 9) end + @testset "Malformed allocations" begin + backend = Backend() + @test_throws MethodError KernelAbstractions.zeros(backend, 2, 2) + @test_throws MethodError KernelAbstractions.ones(backend, 2, 2) + @test_throws MethodError KernelAbstractions.allocate(backend, 2, 2) + end + @kernel cpu = false function gpu_return_kernel!(x) i = @index(Global) if i ≤ (length(x) ÷ 2)