Skip to content

Commit db08f05

Browse files
authored
Add GPUArraysCore update in place scalar (EnzymeAD#2220)
1 parent a069fba commit db08f05

File tree

3 files changed

+37
-3
lines changed

3 files changed

+37
-3
lines changed

Project.toml

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,13 +21,15 @@ SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
2121
[weakdeps]
2222
BFloat16s = "ab4f0b2a-ad5b-11e8-123f-65d77653426b"
2323
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
24+
GPUArraysCore = "46192b85-c4d5-4398-a991-12ede77f4527"
2425
LogExpFunctions = "2ab3a3ac-af41-5b50-aa03-7779005ae688"
2526
SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b"
2627
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
2728

2829
[extensions]
2930
EnzymeBFloat16sExt = "BFloat16s"
3031
EnzymeChainRulesCoreExt = "ChainRulesCore"
32+
EnzymeGPUArraysCoreExt = "GPUArraysCore"
3133
EnzymeLogExpFunctionsExt = "LogExpFunctions"
3234
EnzymeSpecialFunctionsExt = "SpecialFunctions"
3335
EnzymeStaticArraysExt = "StaticArrays"
@@ -52,6 +54,7 @@ julia = "1.10"
5254
[extras]
5355
BFloat16s = "ab4f0b2a-ad5b-11e8-123f-65d77653426b"
5456
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
57+
GPUArraysCore = "46192b85-c4d5-4398-a991-12ede77f4527"
5558
LogExpFunctions = "2ab3a3ac-af41-5b50-aa03-7779005ae688"
5659
SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b"
5760
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"

ext/EnzymeGPUArraysCore.jl

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,26 @@
1+
module EnzymeGPUArraysCoreExt
2+
3+
using GPUArraysCore
4+
using Enzyme
5+
6+
@inline function Enzyme.onehot(x::AbstractGPUArray)
7+
onehot_internal(zerosetfn, x, 0, length(x))
8+
end
9+
10+
@inline function Enzyme.onehot(x::AbstractGPUArray, start::Int, endl::Int)
11+
onehot_internal(zerosetfn, x, start-1, endl-start+1)
12+
end
13+
14+
function Enzyme.zerosetfn(x::AbstractGPUArray, i::Int)
15+
res = zero(x)
16+
@allowscalar @inbounds res[i] = 1
17+
return res
18+
end
19+
20+
function Enzyme.zerosetfn!(x::AbstractGPUArray, i::Int, val)
21+
@allowscalar @inbounds x[i] = += val
22+
return
23+
end
24+
25+
26+
end # module

src/sugar.jl

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,11 @@ function zerosetfn(x, i::Int)
77
return res
88
end
99

10+
function zerosetfn!(x, i::Int, val)
11+
@inbounds x[i] += val
12+
nothing
13+
end
14+
1015
@generated function onehot_internal(fn::F, x::T, startv::Int, lengthv::Int) where {F, T<:Array}
1116
ir = GPUCompiler.JuliaContext() do ctx
1217
Base.@_inline_meta
@@ -927,7 +932,7 @@ this function will retun an AbstractArray of shape `size(output)` of values of t
927932
dx = MD ? Ref(z) : z
928933
res = primal(Const(f), MD ? MixedDuplicated(x, dx) : Duplicated(x, dx))
929934
tape = res[1]
930-
@inbounds res[3][i] += Compiler.default_adjoint(eltype(typeof(res[3])))
935+
zerosetfn!(res[3], i, Compiler.default_adjoint(eltype(typeof(res[3]))))
931936
adjoint(Const(f), MD ? MixedDuplicated(x, dx) : Duplicated(x, dx), tape)
932937
return MD ? dx[] : dx, (i == 1 ? size(res[3]) : nothing)
933938
end
@@ -994,8 +999,8 @@ this function will retun an AbstractArray of shape `size(output)` of values of t
994999
j = 0
9951000
for shadow in res[3]
9961001
j += 1
997-
@inbounds shadow[(i-1)*chunksize+j] +=
998-
Compiler.default_adjoint(eltype(typeof(shadow)))
1002+
zerosetfn!(shadow, (i-1)*chunksize+j,
1003+
Compiler.default_adjoint(eltype(typeof(shadow))))
9991004
end
10001005
(i == num ? adjoint2 : adjoint)(
10011006
Const(f),

0 commit comments

Comments
 (0)