
Conversation

gdalle (Contributor) commented Oct 8, 2025

gdalle marked this pull request as draft October 8, 2025 11:48

github-actions bot commented Oct 8, 2025

Your PR requires formatting changes to meet the project's style guidelines.
Please consider running Runic (`git runic main`) to apply these changes.

Suggested changes:
diff --git a/lib/EnzymeCore/src/EnzymeCore.jl b/lib/EnzymeCore/src/EnzymeCore.jl
index 37d96040..0cd7d57e 100644
--- a/lib/EnzymeCore/src/EnzymeCore.jl
+++ b/lib/EnzymeCore/src/EnzymeCore.jl
@@ -863,6 +863,6 @@ pick_chunksize(::LargestChunk, a::AbstractArray) = Val(length(a))  # allows infe
 
 pick_chunksize(::AutoChunk, n::Integer) = Val(min(DEFAULT_CHUNK_SIZE, n))  # TODO: improve
 pick_chunksize(s::AutoChunk, a::AbstractArray) = pick_chunksize(s, length(a))
-pick_chunksize(::FixedChunk{C}, ::Union{Integer,AbstractArray}) where {C} = Val{C}()
+pick_chunksize(::FixedChunk{C}, ::Union{Integer, AbstractArray}) where {C} = Val{C}()
 
 end # module EnzymeCore
diff --git a/src/sugar.jl b/src/sugar.jl
index 319f937e..71ed1ed6 100644
--- a/src/sugar.jl
+++ b/src/sugar.jl
@@ -418,7 +418,7 @@ end
 const ExtendedChunkStrategy = Union{ChunkStrategy, Nothing, Val}
 
 # eats and returns a type because generated functions work on argument types
-get_strategy(chunk::Type{CS}) where {CS<:ChunkStrategy} = chunk
+get_strategy(chunk::Type{CS}) where {CS <: ChunkStrategy} = chunk
 
 function get_strategy(::Type{Nothing})
     Base.depwarn(
@@ -457,7 +457,7 @@ end
 @inline tupleconcat(x, y) = (x..., y...)
 @inline tupleconcat(x, y, z...) = (x..., tupleconcat(y, z...)...)
 
-@generated function create_shadows(chunk::ExtendedChunkStrategy, x::X, vargs::Vararg{Any,N}) where {X, N}
+@generated function create_shadows(chunk::ExtendedChunkStrategy, x::X, vargs::Vararg{Any, N}) where {X, N}
     chunk_strategy = get_strategy(chunk)
     args =  Union{Symbol,Expr}[:x]
     tys =  Type[X]
@@ -614,9 +614,9 @@ gradient(Forward, mul, [2.0, 3.0], Const([2.7, 3.1]))
     f::F,
     x::ty_0,
     args::Vararg{Any,N};
-    chunk::ExtendedChunkStrategy = LargestChunk(),
+        chunk::ExtendedChunkStrategy = LargestChunk(),
     shadows::ST = create_shadows(chunk, x, args...),
-) where {F, ReturnPrimal,ABI,ErrIfFuncWritten,RuntimeActivity,StrongZero,ST, ty_0, N}
+    ) where {F, ReturnPrimal, ABI, ErrIfFuncWritten, RuntimeActivity, StrongZero, ST, ty_0, N}
 
     chunk_strategy = get_strategy(chunk)
 
@@ -827,10 +827,10 @@ end
     mode::ReverseMode{ReturnPrimal},
     RT::RType,
     n_outs::OutType,
-    chunk::ExtendedChunkStrategy,
+        chunk::ExtendedChunkStrategy,
     f::F,
     xs::Vararg{Any, Nargs}
-) where {ReturnPrimal,RType, F,Nargs,OutType}
+    ) where {ReturnPrimal, RType, F, Nargs, OutType}
     chunk_strategy = get_strategy(chunk)
     fty = if f <: Enzyme.Annotation
         f.parameters[1]
@@ -1255,8 +1255,8 @@ this function will retun an AbstractArray of shape `size(output)` of values of t
     f::F,
     xs::Vararg{Any, Nargs};
     n_outs::OutType = nothing,
-    chunk::ExtendedChunkStrategy = LargestChunk(),
-) where {F,Nargs, OutType}
+        chunk::ExtendedChunkStrategy = LargestChunk(),
+    ) where {F, Nargs, OutType}
 
     fty = if f <: Enzyme.Annotation
         f.parameters[1]
diff --git a/test/sugar.jl b/test/sugar.jl
index 00d13f18..78325557 100644
--- a/test/sugar.jl
+++ b/test/sugar.jl
@@ -675,9 +675,9 @@ fchunk2(x) = map(sin, x) + map(cos, reverse(x))
         @test Enzyme.chunkedonehot(ones(30), Enzyme.LargestChunk()) isa Tuple{NTuple{30}}
         @test Enzyme.chunkedonehot(ones(10), Enzyme.LargestChunk()) isa Tuple{NTuple{10}}
         @test Enzyme.chunkedonehot(ones(30), Enzyme.LargestChunk()) isa Tuple{NTuple{30}}
-        @test Enzyme.chunkedonehot(ones(3), Enzyme.FixedChunk{1}()) isa Tuple{NTuple{1},NTuple{1},NTuple{1}}
-        @test Enzyme.chunkedonehot(ones(10), Enzyme.FixedChunk{4}()) isa Tuple{NTuple{4},NTuple{4},NTuple{2}}
-        @test Enzyme.chunkedonehot(ones(10), Enzyme.FixedChunk{5}()) isa Tuple{NTuple{5},NTuple{5}}
+        @test Enzyme.chunkedonehot(ones(3), Enzyme.FixedChunk{1}()) isa Tuple{NTuple{1}, NTuple{1}, NTuple{1}}
+        @test Enzyme.chunkedonehot(ones(10), Enzyme.FixedChunk{4}()) isa Tuple{NTuple{4}, NTuple{4}, NTuple{2}}
+        @test Enzyme.chunkedonehot(ones(10), Enzyme.FixedChunk{5}()) isa Tuple{NTuple{5}, NTuple{5}}
         @test Enzyme.chunkedonehot(ones(10), Enzyme.AutoChunk()) isa Tuple{NTuple{10}}
         @test Enzyme.chunkedonehot(ones(30), Enzyme.AutoChunk()) isa Tuple{NTuple{16}, NTuple{14}}
         @test Enzyme.chunkedonehot(ones(30), Enzyme.AutoChunk()) isa Tuple{NTuple{16}, NTuple{14}}

gdalle marked this pull request as ready for review October 29, 2025 09:00
gdalle requested a review from wsmoses October 29, 2025 09:00

gdalle (Contributor, Author) commented Oct 29, 2025

@wsmoses I think this is a good first step, and I'm not able to do the reverse-mode Jacobian fix on my own. Can someone help? Also, should we make `create_shadows` public and documented?

gdalle (Contributor, Author) commented Oct 30, 2025

@wsmoses this is good to go, I fixed the reverse mode too

src/sugar.jl Outdated

"""
gradient(::ForwardMode, f, x; shadows=onehot(x), chunk=nothing)
gradient(::ForwardMode, f, x, args...; chunk=nothing, shadows=create_shadows(chunk, x, args...))

Member:

we should change `chunk=nothing` to the relevant explicit default.

we should also not support `Val`/`nothing` inside of here, and instead add a deprecated method (or perhaps first check in the expr) if it's one of the legacy methods and mark it as deprecated
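
A minimal sketch of that deprecation path, assuming the `ChunkStrategy`, `LargestChunk`, and `FixedChunk` types this PR adds to EnzymeCore; the stand-in definitions, the chosen default, and the warning text are illustrative, and the PR's actual `get_strategy` dispatches on argument types inside a generated function rather than on values:

# Illustrative stand-ins for the chunk-strategy types from EnzymeCore
abstract type ChunkStrategy end
struct LargestChunk <: ChunkStrategy end
struct FixedChunk{C} <: ChunkStrategy end

# New API: an explicit strategy passes through unchanged
get_strategy(s::ChunkStrategy) = s

# Legacy `chunk=nothing`: warn and fall back to an explicit default
function get_strategy(::Nothing)
    Base.depwarn("`chunk=nothing` is deprecated, pass an explicit `ChunkStrategy` such as `LargestChunk()`.", :get_strategy)
    return LargestChunk()
end

# Legacy `chunk=Val(C)`: warn and translate to the new fixed-size strategy
function get_strategy(::Val{C}) where {C}
    Base.depwarn("`chunk=Val($C)` is deprecated, use `FixedChunk{$C}()` instead.", :get_strategy)
    return FixedChunk{C}()
end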

Contributor Author:

Done in the latest commit.

An issue with the current code is that the deprecation warning will only be visible at the first function call, since that is the only time the generated function's body is actually generated:

julia> jacobian(Forward, copy, ones(2); chunk=nothing)
┌ Warning: The `chunk=nothing` configuration will be deprecated in a future release. Please use `chunk=SmallestChunk()` instead.
│   caller = #s719#135 at sugar.jl:461 [inlined]
└ @ Core ~/Documents/GitHub/Julia/Enzyme.jl/src/sugar.jl:461
┌ Warning: The `chunk=nothing` configuration will be deprecated in a future release. Please use `chunk=SmallestChunk()` instead.
│   caller = #s717#137 at sugar.jl:621 [inlined]
└ @ Core ~/Documents/GitHub/Julia/Enzyme.jl/src/sugar.jl:621
([1.0 0.0; 0.0 1.0],)

julia> jacobian(Forward, copy, ones(2); chunk=nothing)
([1.0 0.0; 0.0 1.0],)

Not sure whether that's an issue or not

Contributor Author:

If we move it to a permanent warning at every call, we should probably add tests for it too
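
A minimal sketch of what a per-call warning and its test could look like; `warn_legacy_chunk` is a hypothetical runtime helper, not the PR's generated-function path, and it uses `@warn` rather than `Base.depwarn` so the message is emitted (and testable) on every call:

using Test

# Hypothetical helper: warns at run time, on every call with the legacy value
function warn_legacy_chunk(chunk)
    chunk === nothing && @warn "`chunk=nothing` is deprecated, pass an explicit `ChunkStrategy`."
    return chunk
end

@testset "legacy chunk warning" begin
    @test_logs (:warn, r"deprecated") warn_legacy_chunk(nothing)
    @test_logs (:warn, r"deprecated") warn_legacy_chunk(nothing)  # still fires on the second call
end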

gdalle (Contributor, Author) commented Oct 30, 2025

Do you want this to error?

jacobian(Forward, copy, ones(2); chunk=FixedChunk{3}())

Currently Enzyme doesn't mind when the chunk size is larger than the input; I'm not sure what the expected behavior is there (and it's hard to deduce from the code). For comparison:

julia> cfg = ForwardDiff.JacobianConfig(copy, ones(2), ForwardDiff.Chunk{3}());

julia> ForwardDiff.jacobian(copy, ones(2), cfg, Val(true))
ERROR: ArgumentError: chunk size cannot be greater than ForwardDiff.structural_length(x) (3 > 2)
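
For reference, a hypothetical guard along the lines of ForwardDiff's check; this is not part of the PR, which currently accepts oversized fixed chunks, and `FixedChunk` plus `pick_chunksize_checked` below are local stand-ins rather than Enzyme's actual API:

# Local stand-in for EnzymeCore.FixedChunk, to keep the sketch self-contained
struct FixedChunk{C} end

# Hypothetical check: reject fixed chunk sizes larger than the input length
function pick_chunksize_checked(::FixedChunk{C}, a::AbstractArray) where {C}
    C <= length(a) || throw(ArgumentError(
        "chunk size cannot be greater than length(x) ($C > $(length(a)))"))
    return Val{C}()
end

pick_chunksize_checked(FixedChunk{2}(), ones(2))    # Val{2}()
# pick_chunksize_checked(FixedChunk{3}(), ones(2))  # would throw ArgumentError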

codecov bot commented Oct 30, 2025

Codecov Report

✅ All modified and coverable lines are covered by tests.
✅ Project coverage is 53.47%. Comparing base (107b327) to head (dc5ff05).
⚠️ Report is 9 commits behind head on main.

❗ There is a different number of reports uploaded between BASE (107b327) and HEAD (dc5ff05).

HEAD has 29 fewer uploads than BASE: 34 for BASE (107b327) vs. 5 for HEAD (dc5ff05).
Additional details and impacted files
@@             Coverage Diff             @@
##             main    #2659       +/-   ##
===========================================
- Coverage   72.61%   53.47%   -19.15%     
===========================================
  Files          58       12       -46     
  Lines       18746     1210    -17536     
===========================================
- Hits        13613      647    -12966     
+ Misses       5133      563     -4570     

wsmoses (Member) commented Nov 20, 2025

I'm okay with it not erroring at the moment, but we can always add it later

gdalle requested a review from wsmoses November 21, 2025 07:05

gdalle (Contributor, Author) commented Nov 21, 2025

@wsmoses I think the last aspect to settle is #2659 (comment), but we can also change it after merging since it's only about when to display deprecation warnings

gdalle (Contributor, Author) commented Nov 21, 2025

Good to go, if CI passes!

wsmoses (Member) left a comment

LGTM, but I'm going to wait for @vchuravy to get back from vacation (ending this week IIRC) to see if he has any API thoughts before merging

gdalle (Contributor, Author) commented Dec 2, 2025

@vchuravy what do you think?
