-
Notifications
You must be signed in to change notification settings - Fork 41
Remove Zygote dependency and update documentation for ForwardDiff integration #393
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
|
Bijectors.jl documentation for PR #393 is available at: |
|
Keeping as draft until my local tests pass |
|
also next commit will remove the Zygote CI |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Pull Request Overview
This PR removes Zygote support from Bijectors.jl and updates tests, configuration, and documentation to use ForwardDiff and other supported AD backends.
- Deletes the Zygote-specific extension module and removes Zygote entries from configs and tests
- Updates test calls and CI workflows to drop Zygote and ensure ForwardDiff/ReverseDiff/Mooncake coverage
- Revises documentation examples to use ForwardDiff.gradient instead of Zygote.gradient
Reviewed Changes
Copilot reviewed 11 out of 11 changed files in this pull request and generated no comments.
Show a summary per file
| File | Description |
|---|---|
| test/runtests.jl | Removed using Zygote; needs to import Enzyme for AD tests |
| test/ad/utils.jl | Removed Zygote blocks from test_ad |
| test/ad/flows.jl | Added (:EnzymeForward,) broken tuple in tests; may be incorrect |
| test/Project.toml | Deleted Zygote dependency |
| src/chainrules.jl | Generalized @debug comment and added new rrule definitions |
| src/bijectors/pd.jl | Updated comment to drop Zygote reference |
| ext/BijectorsZygoteExt.jl | Entire file deleted to remove Zygote adjoints |
| docs/src/examples.md | Swapped out Zygote example for ForwardDiff |
| docs/Project.toml | Removed Zygote from documentation dependencies |
| .github/workflows/AD.yml | Removed Zygote from CI matrix |
Comments suppressed due to low confidence (3)
src/chainrules.jl:293
- New chain rules for
aT_band_vecwere added without corresponding tests; consider adding unit tests to validate their gradient behavior.
function ChainRulesCore.rrule(::typeof(aT_b), a::AbstractVector, b::AbstractMatrix)
test/runtests.jl:15
- Enzyme is used in AD tests but not imported; add
using Enzymealongside other AD backends
using Tracker
test/ad/flows.jl:3
- The second argument to
test_adspecifies broken backends, not selection of AD to test. This tuple will not limit to EnzymeForward as intended; remove the(:EnzymeForward,)argument or adjusttest_adto support selecting specific backends.
test_ad(randn(7), (:EnzymeForward,)) do θ
mhauru
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Looks mostly good, one problem with the docs and a couple of questions about PlanarLayer tests as also discussed on Slack.
src/chainrules.jl
Outdated
| # Fixes AD issues with `@debug` | ||
| ChainRulesCore.@non_differentiable _debug(::Any) | ||
|
|
||
| # ChainRules for utility functions used by PlanarLayer |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Are these copied over from somewhere else, or new code? Do we need them? I think ChainRules rules are mostly for Zygote, although other AD packages use them too, so wondering if/why these are necessary.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think these functions are used by PlanarLayer (src/bijectors/planar_layer.jl). The ChainRules are necessary for proper automatic differentiation support?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
ChainRules is only directly used by Zygote, all other (major) AD systems have their own way of handling them. For example ForwardDiff and ReverseDiff use operator overloading. For some AD systems (Mooncake, Enzyme) you can import rules from ChainRules but this is an optional mechanism, because those two backends differentiate code at a lower level. (It's messy, but there's an overview of Julia AD at: https://juliadiff.org/DifferentiationInterface.jl/DifferentiationInterface/stable/faq/differentiability/) So, I'm a bit surprised that we would need to add new rules in this PR, given that we're removing Zygote. Indeed I would have been more expecting to see removal of rules.
If you remove these do any of the tests fail?
(Also, please don't resolve comments unless you're sure that it's been resolved! Otherwise it's hard for people to see what still needs to be discussed)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@AoifeHughes, did you a chance to check yet what happens if you remove these?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yes, I'm doing it right now, it still breaks :(
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Hmm, odd. Happy to have a look at the error and figure it out together if it helps.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Not gonna lie, I'm totally lost and feel like Im just randomly changing things I don't understand here. I'd really appreciate someone taking over so I can learn a bit about how best to work through the problem
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Absolutely. Ping me or Penny on Slack when you have a moment and let's have a look at it together.
|
I don't think there's a need for me to review this separately. I'm happy to follow up / approve if further changes are made but I think Markus's preexisting comments should be dealt with first, in particular the one about addition of ChainRulesCore functionality. |
|
Im super stuck on this and just cannot get the tests to work. I'd really appreciate any kind of help with this PR. |
0f5c58d to
ba1bf9c
Compare
| # Fixes AD issues with `@debug` | ||
| ChainRulesCore.@non_differentiable _debug(::Any) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Do you know if this is still needed? Does some other AD backend rely on this?
src/chainrules.jl
Outdated
| # Fixes AD issues with `@debug` | ||
| ChainRulesCore.@non_differentiable _debug(::Any) | ||
|
|
||
| # ChainRules for utility functions used by PlanarLayer |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@AoifeHughes, did you a chance to check yet what happens if you remove these?
8676750 to
3ccd657
Compare
mhauru
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks @AoifeHughes!
Remove Zygote Support from
Bijectors.jl- Addresses: #374Summary
This PR removes Zygote as a supported automatic differentiation backend from
Bijectors.jl, aligning with the policy that only ForwardDiff, ReverseDiff, and Mooncake are officially supported forDynamicPPL.Changes Made
Files Deleted
ext/BijectorsZygoteExt.jlComplete extension module (199 lines) containing 20+ custom
@adjointimplementations for Zygote.Project Configuration
Project.tomlRemoved Zygote from:
[weakdeps]section[extensions]section (BijectorsZygoteExt = "Zygote")[compat]constraints[extras]dependenciestest/Project.tomldocs/Project.tomlSource Code
test/runtests.jlusing Zygoteimport statementtest/ad/utils.jltest_ad()::Zygotefrom broken backends listsrc/bijectors/pd.jlsrc/chainrules.jl@debugmacroDocumentation
docs/src/examples.mdZygote.gradientexample withForwardDiff.gradientin the normalizing flows training exampleCLAUDE.mdImpact
ForwardDiff,ReverseDiff, andMooncakeremain fully supportedTesting
ForwardDiff,ReverseDiff, andMooncakefunctionality