Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 5 additions & 3 deletions .github/workflows/CI.yml
Original file line number Diff line number Diff line change
Expand Up @@ -19,12 +19,14 @@ jobs:
fail-fast: false
matrix:
runner:
- version: '1.11'
- version: '1'
os: 'ubuntu-latest'
- version: '1.11'
- version: '1'
os: 'macos-latest'
- version: '1.11'
- version: '1'
os: 'windows-latest'
- version: '1.11'
os: 'ubuntu-latest'
- version: 'min'
os: 'ubuntu-latest'
group:
Expand Down
2 changes: 1 addition & 1 deletion .github/workflows/DocTests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ jobs:

- uses: julia-actions/setup-julia@v2
with:
version: '1.11'
version: '1'

- uses: julia-actions/cache@v2

Expand Down
35 changes: 35 additions & 0 deletions .github/workflows/Enzyme.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
name: Enzyme AD tests

on:
push:
branches:
- main
pull_request:

# needed to allow julia-actions/cache to delete old caches that it has created
permissions:
actions: write
contents: read

# Cancel existing tests on the same PR if a new commit is added to a pull request
concurrency:
group: ${{ github.workflow }}-${{ github.ref || github.run_id }}
cancel-in-progress: ${{ startsWith(github.ref, 'refs/pull/') }}

jobs:
enzyme:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v5

- uses: julia-actions/setup-julia@v2
with:
version: "1.11"

- uses: julia-actions/cache@v2

- name: Run flaky Enzyme tests
working-directory: test/integration/enzyme
run: |
julia --project=. --color=yes -e 'using Pkg; Pkg.instantiate()'
julia --project=. --color=yes main.jl
2 changes: 1 addition & 1 deletion src/Bijectors.jl
Original file line number Diff line number Diff line change
Expand Up @@ -136,7 +136,7 @@ julia> d = LogNormal() # support is (0, Inf)
LogNormal{Float64}(μ=0.0, σ=1.0)

julia> b = bijector(d) # log function transforms to unconstrained space
(::Base.Fix1{typeof(broadcast), typeof(log)}) (generic function with 1 method)
(::Base.Fix1{typeof(broadcast), typeof(log)}) (generic function with 2 methods)

julia> b(1.0)
0.0
Expand Down
6 changes: 0 additions & 6 deletions test/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,6 @@ Combinatorics = "861a8166-3701-5b0c-9a16-15d98fcdc6aa"
DifferentiationInterface = "a0c0ee7d-e4b9-4e03-894e-1c5f64a51d63"
DistributionsAD = "ced4e74d-a319-5a8a-b0ac-84af2272839c"
Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4"
Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9"
EnzymeTestUtils = "12d8515a-0907-448a-8884-5fe00fdf1c5a"
FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b"
FiniteDifferences = "26cc04aa-876d-5657-8c51-4c34ba976000"
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
Expand All @@ -19,7 +17,6 @@ LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
LogDensityProblems = "6fdf6af0-433a-55f7-b3ed-c6c6e0b8df7c"
LogExpFunctions = "2ab3a3ac-af41-5b50-aa03-7779005ae688"
MCMCDiagnosticTools = "be115224-59cd-429b-ad48-344e309966f0"
Mooncake = "da2b9cff-9c12-43a0-ae48-6db2b0edb7d6"
Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267"
Expand All @@ -36,8 +33,6 @@ Combinatorics = "1.0.2"
DifferentiationInterface = "0.7.7"
DistributionsAD = "0.6.3"
Documenter = "1"
Enzyme = "0.13.12"
EnzymeTestUtils = "0.2.1"
FillArrays = "1"
FiniteDifferences = "0.11, 0.12"
ForwardDiff = "0.10, 1.0.1"
Expand All @@ -47,7 +42,6 @@ LazyArrays = "1, 2"
LogDensityProblems = "2"
LogExpFunctions = "0.3.1"
MCMCDiagnosticTools = "0.3"
Mooncake = "0.4.147"
ReverseDiff = "1.4.2"
StableRNGs = "1"
Tracker = "0.2.11"
Expand Down
21 changes: 7 additions & 14 deletions test/ad/corr.jl
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
using Enzyme: ForwardMode

@testset "VecCorrBijector: $backend_name" for (backend_name, adtype) in TEST_ADTYPES
ENZYME_FWD_AND_1p11 = VERSION >= v"1.11" && adtype isa AutoEnzyme{<:Enzyme.ForwardMode}
# Enzyme is tested separately as these tests are flaky
# TODO(penelopeysm): Fix upstream and re-enable.
if adtype isa AutoEnzyme
continue
end

@testset "d = $d" for d in (1, 2, 4)
dist = LKJ(d, 2.0)
Expand All @@ -13,17 +15,8 @@ using Enzyme: ForwardMode

roundtrip(y) = sum(transform(b, binv(y)))
inverse_only(y) = sum(transform(binv, y))
if d == 4 && ENZYME_FWD_AND_1p11
@test_throws Enzyme.Compiler.EnzymeNoDerivativeError test_ad(
roundtrip, adtype, y
)
@test_throws Enzyme.Compiler.EnzymeNoDerivativeError test_ad(
inverse_only, adtype, y
)
else
test_ad(roundtrip, adtype, y)
test_ad(inverse_only, adtype, y)
end
test_ad(roundtrip, adtype, y)
test_ad(inverse_only, adtype, y)
end
end

Expand Down
39 changes: 9 additions & 30 deletions test/ad/flows.jl
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
using Enzyme: Enzyme

@testset "PlanarLayer: $backend_name" for (backend_name, adtype) in TEST_ADTYPES
# https://github.com/TuringLang/Bijectors.jl/issues/415
ENZYME_FWD_AND_1p11 = VERSION >= v"1.11" && adtype isa AutoEnzyme{<:Enzyme.ForwardMode}
ENZYME_RVS_AND_1p11 = VERSION >= v"1.11" && adtype isa AutoEnzyme{<:Enzyme.ReverseMode}
# Enzyme is tested separately as these tests are flaky
# TODO(penelopeysm): Fix upstream and re-enable.
if adtype isa AutoEnzyme
continue
end

# logpdf of a flow with a planar layer and two-dimensional inputs
function f(θ)
Expand All @@ -12,24 +12,15 @@ using Enzyme: Enzyme
x = θ[6:7]
return logpdf(flow.dist, x) - logabsdetjac(flow.transform, x)
end
if ENZYME_FWD_AND_1p11
@test_throws Enzyme.Compiler.EnzymeInternalError test_ad(f, adtype, randn(7))
else
test_ad(f, adtype, randn(7))
end
test_ad(f, adtype, randn(7))

function g(θ)
layer = PlanarLayer(θ[1:2], θ[3:4], θ[5:5])
flow = transformed(MvNormal(zeros(2), I), layer)
x = reshape(θ[6:end], 2, :)
return sum(logpdf(flow.dist, x) - logabsdetjac(flow.transform, x))
end
if ENZYME_FWD_AND_1p11
@warn "Skipping forward-mode Enzyme for `g` on 1.11 due to segfault"
# @test_throws Enzyme.Compiler.EnzymeInternalError test_ad(g, adtype, randn(11))
else
test_ad(g, adtype, randn(11))
end
test_ad(g, adtype, randn(11))

# logpdf of a flow with the inverse of a planar layer and two-dimensional inputs
function finv(θ)
Expand All @@ -38,25 +29,13 @@ using Enzyme: Enzyme
x = θ[6:7]
return logpdf(flow.dist, x) - logabsdetjac(flow.transform, x)
end
if ENZYME_FWD_AND_1p11
@warn "Skipping forward-mode Enzyme for `finv` on 1.11 due to segfault"
elseif ENZYME_RVS_AND_1p11
@test_throws Enzyme.LLVM.LLVMException test_ad(finv, adtype, randn(7))
else
test_ad(finv, adtype, randn(7))
end
test_ad(finv, adtype, randn(7))

function ginv(θ)
layer = PlanarLayer(θ[1:2], θ[3:4], θ[5:5])
flow = transformed(MvNormal(zeros(2), I), inverse(layer))
x = reshape(θ[6:end], 2, :)
return sum(logpdf(flow.dist, x) - logabsdetjac(flow.transform, x))
end
if ENZYME_FWD_AND_1p11 || ENZYME_RVS_AND_1p11
@warn "Skipping forward-mode Enzyme for `ginv` on 1.11 due to segfault"
elseif ENZYME_RVS_AND_1p11
@test_throws Enzyme.LLVM.LLVMException test_ad(finv, adtype, randn(7))
else
test_ad(ginv, adtype, randn(11))
end
test_ad(ginv, adtype, randn(11))
end
29 changes: 9 additions & 20 deletions test/ad/pd.jl
Original file line number Diff line number Diff line change
@@ -1,7 +1,11 @@
_topd(x) = x * x' + I

@testset "PDVecBijector: $backend_name" for (backend_name, adtype) in TEST_ADTYPES
ENZYME_FWD_AND_1p11 = VERSION >= v"1.11" && adtype isa AutoEnzyme{<:Enzyme.ForwardMode}
# Enzyme is tested separately as these tests are flaky
# TODO(penelopeysm): Fix upstream and re-enable.
if adtype isa AutoEnzyme
continue
end

d = 4
b = Bijectors.PDVecBijector()
Expand All @@ -16,23 +20,8 @@ _topd(x) = x * x' + I
inverse_chol_lower(y) = sum(Bijectors.cholesky_lower(transform(binv, y)))
inverse_chol_upper(y) = sum(Bijectors.cholesky_upper(transform(binv, y)))

if ENZYME_FWD_AND_1p11
@test_throws Enzyme.Compiler.EnzymeNoDerivativeError test_ad(
forward_only, adtype, vec(z)
)
@test_throws Enzyme.Compiler.EnzymeNoDerivativeError test_ad(
inverse_only, adtype, vec(z)
)
@test_throws Enzyme.Compiler.EnzymeNoDerivativeError test_ad(
inverse_chol_lower, adtype, y
)
@test_throws Enzyme.Compiler.EnzymeNoDerivativeError test_ad(
inverse_chol_upper, adtype, y
)
else
test_ad(forward_only, adtype, vec(z))
test_ad(inverse_only, adtype, y)
test_ad(inverse_chol_lower, adtype, y)
test_ad(inverse_chol_upper, adtype, y)
end
test_ad(forward_only, adtype, vec(z))
test_ad(inverse_only, adtype, y)
test_ad(inverse_chol_lower, adtype, y)
test_ad(inverse_chol_upper, adtype, y)
end
11 changes: 11 additions & 0 deletions test/integration/enzyme/Project.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
[deps]
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
Bijectors = "76274a88-744f-5084-9051-94815aaf08c4"
DifferentiationInterface = "a0c0ee7d-e4b9-4e03-894e-1c5f64a51d63"
Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9"
FiniteDifferences = "26cc04aa-876d-5657-8c51-4c34ba976000"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"

[sources]
Bijectors = {path = "../../../"}
Loading