From 0643cef202718fc4c6922285176b254c69588a11 Mon Sep 17 00:00:00 2001 From: Neven Sajko Date: Mon, 2 Jun 2025 11:48:13 +0200 Subject: [PATCH 1/4] prevent `get_backend` from overflowing the stack Prevent the `get_backend` methods from overflowing the stack/recurring without bound. Hoping this doesn't cause inference issues due to deeper call stacks. Fixes #588 --- ext/LinearAlgebraExt.jl | 4 ++-- src/KernelAbstractions.jl | 11 ++++++++++- test/test.jl | 4 ++++ 3 files changed, 16 insertions(+), 3 deletions(-) diff --git a/ext/LinearAlgebraExt.jl b/ext/LinearAlgebraExt.jl index adff3179b..45c9065dc 100644 --- a/ext/LinearAlgebraExt.jl +++ b/ext/LinearAlgebraExt.jl @@ -3,7 +3,7 @@ module LinearAlgebraExt using KernelAbstractions: KernelAbstractions using LinearAlgebra: Tridiagonal, Diagonal -KernelAbstractions.get_backend(A::Diagonal) = KernelAbstractions.get_backend(A.diag) -KernelAbstractions.get_backend(A::Tridiagonal) = KernelAbstractions.get_backend(A.d) +KernelAbstractions.get_backend(A::Diagonal) = KernelAbstractions.get_backend_recur(x -> x.diag, A) +KernelAbstractions.get_backend(A::Tridiagonal) = KernelAbstractions.get_backend_recur(x -> x.d, A) end diff --git a/src/KernelAbstractions.jl b/src/KernelAbstractions.jl index ae258a569..c7c3b79a0 100644 --- a/src/KernelAbstractions.jl +++ b/src/KernelAbstractions.jl @@ -510,8 +510,17 @@ Get a [`Backend`](@ref) instance suitable for array `A`. """ function get_backend end +function get_backend_recur(f::F, x) where {F} + t() = throw(ArgumentError("throwing to prevent a stack overflow, possibly a `get_backend` method is missing?")) + y = f(x) + if y isa typeof(x) + @noinline t() + end + return get_backend(y) +end + # Should cover SubArray, ReshapedArray, ReinterpretArray, Hermitian, AbstractTriangular, etc.: -get_backend(A::AbstractArray) = get_backend(parent(A)) +get_backend(A::AbstractArray) = get_backend_recur(parent, A) # Define: # adapt_storage(::Backend, a::Array) = adapt(BackendArray, a) diff --git a/test/test.jl b/test/test.jl index 0ac7df17a..640b32c0d 100644 --- a/test/test.jl +++ b/test/test.jl @@ -7,6 +7,9 @@ using Adapt identity(x) = x +struct UnknownAbstractVector <: AbstractVector{Float32} # issue #588 +end + function unittest_testsuite(Backend, backend_str, backend_mod, BackendArrayT; skip_tests = Set{String}()) @conditional_testset "partition" skip_tests begin backend = Backend() @@ -80,6 +83,7 @@ function unittest_testsuite(Backend, backend_str, backend_mod, BackendArrayT; sk @test @inferred(KernelAbstractions.get_backend(view(A, 2:4, 1:3))) isa backendT @test @inferred(KernelAbstractions.get_backend(Diagonal(x))) isa backendT @test @inferred(KernelAbstractions.get_backend(Tridiagonal(A))) isa backendT + @test_throws ArgumentError KernelAbstractions.get_backend(UnknownAbstractVector()) # issue #588 end @conditional_testset "sparse" skip_tests begin From 557ae76866ffd080bd4ef975698250ee746ee7a7 Mon Sep 17 00:00:00 2001 From: Neven Sajko <4944410+nsajko@users.noreply.github.com> Date: Mon, 9 Jun 2025 17:25:16 +0200 Subject: [PATCH 2/4] inline the `AbstractArray` change, the other are not necessary Co-authored-by: Valentin Churavy --- src/KernelAbstractions.jl | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/src/KernelAbstractions.jl b/src/KernelAbstractions.jl index c7c3b79a0..4a07b8d65 100644 --- a/src/KernelAbstractions.jl +++ b/src/KernelAbstractions.jl @@ -520,7 +520,13 @@ function get_backend_recur(f::F, x) where {F} end # Should cover SubArray, ReshapedArray, ReinterpretArray, Hermitian, AbstractTriangular, etc.: -get_backend(A::AbstractArray) = get_backend_recur(parent, A) +function get_backend(A::AbstractArray) + P = parent(A) + if P isa typeof(A) + throw(ArgumentError("Implement `KernelAbstractions.get_backend(::$(typeof(A)))`")) + end + return get_backend(P) +end # Define: # adapt_storage(::Backend, a::Array) = adapt(BackendArray, a) From 3c62fce2dbe5ebe958df71170d0ddce87f328f9a Mon Sep 17 00:00:00 2001 From: Neven Sajko <4944410+nsajko@users.noreply.github.com> Date: Mon, 9 Jun 2025 17:26:19 +0200 Subject: [PATCH 3/4] revert 1 --- ext/LinearAlgebraExt.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/ext/LinearAlgebraExt.jl b/ext/LinearAlgebraExt.jl index 45c9065dc..adff3179b 100644 --- a/ext/LinearAlgebraExt.jl +++ b/ext/LinearAlgebraExt.jl @@ -3,7 +3,7 @@ module LinearAlgebraExt using KernelAbstractions: KernelAbstractions using LinearAlgebra: Tridiagonal, Diagonal -KernelAbstractions.get_backend(A::Diagonal) = KernelAbstractions.get_backend_recur(x -> x.diag, A) -KernelAbstractions.get_backend(A::Tridiagonal) = KernelAbstractions.get_backend_recur(x -> x.d, A) +KernelAbstractions.get_backend(A::Diagonal) = KernelAbstractions.get_backend(A.diag) +KernelAbstractions.get_backend(A::Tridiagonal) = KernelAbstractions.get_backend(A.d) end From fc8bf7d875ebbe0e6ec8f848325bfd028f411697 Mon Sep 17 00:00:00 2001 From: Neven Sajko <4944410+nsajko@users.noreply.github.com> Date: Mon, 9 Jun 2025 17:27:12 +0200 Subject: [PATCH 4/4] revert 2 --- src/KernelAbstractions.jl | 9 --------- 1 file changed, 9 deletions(-) diff --git a/src/KernelAbstractions.jl b/src/KernelAbstractions.jl index 4a07b8d65..15757e3a2 100644 --- a/src/KernelAbstractions.jl +++ b/src/KernelAbstractions.jl @@ -510,15 +510,6 @@ Get a [`Backend`](@ref) instance suitable for array `A`. """ function get_backend end -function get_backend_recur(f::F, x) where {F} - t() = throw(ArgumentError("throwing to prevent a stack overflow, possibly a `get_backend` method is missing?")) - y = f(x) - if y isa typeof(x) - @noinline t() - end - return get_backend(y) -end - # Should cover SubArray, ReshapedArray, ReinterpretArray, Hermitian, AbstractTriangular, etc.: function get_backend(A::AbstractArray) P = parent(A)