Skip to content
Open
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ BFloat16s = "0.2, 0.3, 0.4, 0.5"
CEnum = "0.4, 0.5"
ChainRulesCore = "1"
DynamicPPL = "0.35, 0.36, 0.37"
EnzymeCore = "0.8.14"
EnzymeCore = "0.8.15"
Enzyme_jll = "0.0.203"
GPUArraysCore = "0.1.6, 0.2"
GPUCompiler = "1.6.2"
Expand Down
2 changes: 1 addition & 1 deletion lib/EnzymeCore/Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "EnzymeCore"
uuid = "f151be2c-9106-41f4-ab19-57ee4f262869"
authors = ["William Moses <[email protected]>", "Valentin Churavy <[email protected]>"]
version = "0.8.14"
version = "0.8.15"

[compat]
Adapt = "3, 4"
Expand Down
40 changes: 40 additions & 0 deletions lib/EnzymeCore/src/EnzymeCore.jl
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ export DefaultABI, FFIABI, InlineABI, NonGenABI
export BatchDuplicatedFunc
export within_autodiff, ignore_derivatives
export needs_primal
export ChunkStrategy, OneChunk, AutoChunk, pick_chunksize

function batch_size end

Expand Down Expand Up @@ -797,4 +798,43 @@ end

Combined(mode::ReverseMode) = mode

"""
ChunkStrategy

Abstract type gathering strategies for chunk size selection.

# See also

- [`OneChunk`](@ref)
- [`AutoChunk`](@ref)
"""
abstract type ChunkStrategy end

"""
OneChunk()

Select chunk size so that the corresponding array is processed in a single chunk.
"""
struct OneChunk <: ChunkStrategy end

"""
AutoChunk()

Select chunk size automatically based on internal Enzyme-specific heuristics.
"""
struct AutoChunk <: ChunkStrategy end

const DEFAULT_CHUNK_SIZE = 16

"""
pick_chunksize(s::ChunkStrategy, a::AbstractArray)

Return the chunk size chosen by strategy `s` based on the dimension of array `a`.

- In forward-mode gradients and Jacobians, `a` would be the input array.
- In reverse-mode Jacobians, `a` would be the output array.
"""
pick_chunksize(::OneChunk, a::AbstractArray) = length(a)
pick_chunksize(::AutoChunk, a::AbstractArray) = min(DEFAULT_CHUNK_SIZE, length(a))

end # module EnzymeCore
12 changes: 12 additions & 0 deletions lib/EnzymeCore/test/chunk.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
using Test
using EnzymeCore

@testset "OneChunk" begin
@test pick_chunksize(OneChunk(), ones(10)) == 10
@test pick_chunksize(OneChunk(), ones(100)) == 100
end

@testset "AutoChunk" begin
@test pick_chunksize(AutoChunk(), ones(10)) == 10
@test pick_chunksize(AutoChunk(), ones(100)) == 16
end
17 changes: 9 additions & 8 deletions lib/EnzymeCore/test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -36,12 +36,13 @@ using EnzymeCore
@testset "Mode modification" begin
include("mode_modification.jl")
end
end

@testset "within_autodiff" begin
@test !EnzymeCore.within_autodiff()
end

@testset "ignore_derivatives" begin
@test EnzymeCore.ignore_derivatives(3) == 3
@testset "Chunk strategy" begin
include("chunk.jl")
end
@testset "within_autodiff" begin
@test !EnzymeCore.within_autodiff()
end
@testset "ignore_derivatives" begin
@test EnzymeCore.ignore_derivatives(3) == 3
end
end
2 changes: 2 additions & 0 deletions src/Enzyme.jl
Original file line number Diff line number Diff line change
Expand Up @@ -111,6 +111,8 @@ export autodiff,
make_zero!,
remake_zero!

import EnzymeCore: ChunkStrategy, pick_chunksize

export jacobian, gradient, gradient!, hvp, hvp!, hvp_and_gradient!
export batch_size, onehot, chunkedonehot

Expand Down
8 changes: 6 additions & 2 deletions src/sugar.jl
Original file line number Diff line number Diff line change
Expand Up @@ -428,6 +428,10 @@ end
return ((one(x),),)
end

@inline function chunkedonehot(x, strategy::ChunkStrategy)
return chunkedonehot(x, Val(pick_chunksize(strategy, x)))
end

@inline tupleconcat(x) = x
@inline tupleconcat(x, y) = (x..., y...)
@inline tupleconcat(x, y, z...) = (x..., tupleconcat(y, z...)...)
Expand Down Expand Up @@ -712,7 +716,7 @@ gradient(Forward, mul, [2.0, 3.0], Const([2.7, 3.1]))
push!(subderivatives, :(values($resp[1])))
end
:(($(subderivatives...),))
else
else # TODO: handle OneChunk and MaxChunk
subderivatives = Union{Symbol,Expr}[]
for an in 1:argnum
dargs = Union{Symbol,Expr}[]
Expand Down Expand Up @@ -914,7 +918,7 @@ end

chunksize = if chunk <: Val
chunk.parameters[1]
else
else # TODO: handle OneChunk and MaxChunk
1
end
num = ((n_out_val + chunksize - 1) ÷ chunksize)
Expand Down
9 changes: 9 additions & 0 deletions test/sugar.jl
Original file line number Diff line number Diff line change
Expand Up @@ -666,3 +666,12 @@ end
# @show J_r_3(u, A, x)
# @show J_f_3(u, A, x)
end

@testset "Chunk size strategies" begin # not passing yet
@test_nowarn gradient(Forward, sum, ones(10); chunk=OneChunk())
@test_nowarn gradient(Forward, sum, ones(10); chunk=AutoChunk())
@test_nowarn jacobian(Forward, copy, ones(10); chunk=OneChunk())
@test_nowarn jacobian(Forward, copy, ones(10); chunk=AutoChunk())
@test_nowarn jacobian(Reverse, copy, ones(10); chunk=OneChunk())
@test_nowarn jacobian(Reverse, copy, ones(10); chunk=AutoChunk())
end