Skip to content
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 0 additions & 5 deletions ext/BijectorsZygoteExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,6 @@ using Bijectors:
find_alpha,
pd_logpdf_with_trans,
istraining,
mapvcat,
eachcolmaphcat,
sumeachcol,
pd_link,
Expand All @@ -36,10 +35,6 @@ using Bijectors.Distributions: LocationScale

@adjoint istraining() = true, _ -> nothing

@adjoint function mapvcat(f, args...)
g(f, args...) = map(f, args...)
return pullback(g, f, args...)
end
@adjoint function eachcolmaphcat(f, x1, x2)
function g(f, x1, x2)
init = reshape(f(view(x1, :, 1), x2[1]), :, 1)
Expand Down
11 changes: 11 additions & 0 deletions test/ad/stacked.jl
Original file line number Diff line number Diff line change
Expand Up @@ -24,4 +24,15 @@
test_ad(y) do y
sum(transform(binv, y))
end

bvec = Stacked([b1, b2], [1:4, 5:5])
bvec_inv = inverse(bvec)

test_ad(y) do x
sum(transform(bvec, binv(x)))
end

test_ad(y) do y
sum(transform(bvec_inv, y))
end
end
Loading