Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .github/workflows/Integration.yml
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@ jobs:
- DynamicExpressions
- Lux
- SciML
- KernelAbstractions
steps:
- uses: actions/checkout@v5
- uses: julia-actions/setup-julia@v2
Expand Down
11 changes: 11 additions & 0 deletions test/integration/KernelAbstractions/Project.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
[deps]
Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9"
EnzymeCore = "f151be2c-9106-41f4-ab19-57ee4f262869"
KernelAbstractions = "63c18a36-062a-441e-b654-da1e3ab1ce7c"

[sources]
Enzyme = {path = "../../.."}
EnzymeCore = {path = "../../../lib/EnzymeCore"}

[compat]
KernelAbstractions = "0.9"
55 changes: 55 additions & 0 deletions test/integration/KernelAbstractions/runtests.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
using Test
using Enzyme
using KernelAbstractions

@kernel function square!(A)
I = @index(Global, Linear)
@inbounds A[I] *= A[I]
end

function square_caller(A)
backend = get_backend(A)
kernel = square!(backend)
kernel(A, ndrange = size(A))
KernelAbstractions.synchronize(backend)
return
end


@kernel function mul!(A, B)
I = @index(Global, Linear)
@inbounds A[I] *= B
end

function mul_caller(A, B)
backend = get_backend(A)
kernel = mul!(backend)
kernel(A, B, ndrange = size(A))
KernelAbstractions.synchronize(backend)
return
end

@testset "kernels" begin
A = Array{Float64}(undef, 64)
dA = Array{Float64}(undef, 64)

A .= (1:1:64)
dA .= 1

Enzyme.autodiff(Reverse, square_caller, Duplicated(A, dA))
@test all(dA .≈ (2:2:128))

A .= (1:1:64)
dA .= 1

_, dB, _ = Enzyme.autodiff(Reverse, mul_caller, Duplicated(A, dA), Active(1.2))[1]

@test all(dA .≈ 1.2)
@test dB ≈ sum(1:1:64)

A .= (1:1:64)
dA .= 1

Enzyme.autodiff(Forward, square_caller, Duplicated(A, dA))
@test all(dA .≈ 2:2:128)
end
Loading