diff --git a/.github/workflows/Integration.yml b/.github/workflows/Integration.yml index 71a1a308dc..54eea4c241 100644 --- a/.github/workflows/Integration.yml +++ b/.github/workflows/Integration.yml @@ -51,6 +51,7 @@ jobs: - DynamicExpressions - Lux - SciML + - KernelAbstractions steps: - uses: actions/checkout@v5 - uses: julia-actions/setup-julia@v2 diff --git a/test/integration/KernelAbstractions/Project.toml b/test/integration/KernelAbstractions/Project.toml new file mode 100644 index 0000000000..ae658be845 --- /dev/null +++ b/test/integration/KernelAbstractions/Project.toml @@ -0,0 +1,11 @@ +[deps] +Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9" +EnzymeCore = "f151be2c-9106-41f4-ab19-57ee4f262869" +KernelAbstractions = "63c18a36-062a-441e-b654-da1e3ab1ce7c" + +[sources] +Enzyme = {path = "../../.."} +EnzymeCore = {path = "../../../lib/EnzymeCore"} + +[compat] +KernelAbstractions = "0.9" diff --git a/test/integration/KernelAbstractions/runtests.jl b/test/integration/KernelAbstractions/runtests.jl new file mode 100644 index 0000000000..f66bf87ae2 --- /dev/null +++ b/test/integration/KernelAbstractions/runtests.jl @@ -0,0 +1,55 @@ +using Test +using Enzyme +using KernelAbstractions + +@kernel function square!(A) + I = @index(Global, Linear) + @inbounds A[I] *= A[I] +end + +function square_caller(A) + backend = get_backend(A) + kernel = square!(backend) + kernel(A, ndrange = size(A)) + KernelAbstractions.synchronize(backend) + return +end + + +@kernel function mul!(A, B) + I = @index(Global, Linear) + @inbounds A[I] *= B +end + +function mul_caller(A, B) + backend = get_backend(A) + kernel = mul!(backend) + kernel(A, B, ndrange = size(A)) + KernelAbstractions.synchronize(backend) + return +end + +@testset "kernels" begin + A = Array{Float64}(undef, 64) + dA = Array{Float64}(undef, 64) + + A .= (1:1:64) + dA .= 1 + + Enzyme.autodiff(Reverse, square_caller, Duplicated(A, dA)) + @test all(dA .≈ (2:2:128)) + + A .= (1:1:64) + dA .= 1 + + _, dB = Enzyme.autodiff(Reverse, mul_caller, Duplicated(A, dA), Active(1.2))[1] + + @test all(dA .≈ 1.2) + @test dB ≈ sum(1:1:64) + + A .= (1:1:64) + dA .= 1 + + Enzyme.autodiff(Forward, square_caller, Duplicated(A, dA)) + @test all(dA .≈ 2:2:128) +end