From 03e97afc315737cdf54fcc1ca7ad4d894f70659e Mon Sep 17 00:00:00 2001 From: lxvm Date: Sat, 16 Sep 2023 12:50:14 -0400 Subject: [PATCH 1/3] initial commit --- ext/IntegralsForwardDiffExt.jl | 4 +- ext/IntegralsZygoteExt.jl | 77 ++++++++++++------- .../src/IntegralsCubature.jl | 53 ++++--------- test/derivative_tests.jl | 26 +++---- test/interface_tests.jl | 34 ++++---- 5 files changed, 99 insertions(+), 95 deletions(-) diff --git a/ext/IntegralsForwardDiffExt.jl b/ext/IntegralsForwardDiffExt.jl index 6cfb3254..307b4b3e 100644 --- a/ext/IntegralsForwardDiffExt.jl +++ b/ext/IntegralsForwardDiffExt.jl @@ -29,7 +29,7 @@ function Integrals.__solvebp(cache, alg, sensealg, lb, ub, dfdp = function (out, x, p) dualp = reinterpret(ForwardDiff.Dual{T, V, P}, p) if cache.batch > 0 - dx = similar(dualp, cache.nout, size(x, 2)) + dx = cache.nout == 1 ? similar(dualp, size(x, ndims(x))) : similar(dualp, cache.nout, size(x, ndims(x))) else dx = similar(dualp, cache.nout) end @@ -49,7 +49,7 @@ function Integrals.__solvebp(cache, alg, sensealg, lb, ub, dualp = reinterpret(ForwardDiff.Dual{T, V, P}, p) ys = cache.f(x, dualp) if cache.batch > 0 - out = similar(p, V, nout, size(x, 2)) + out = similar(p, V, nout, size(x, ndims(x))) else out = similar(p, V, nout) end diff --git a/ext/IntegralsZygoteExt.jl b/ext/IntegralsZygoteExt.jl index a5c41870..8c0c284a 100644 --- a/ext/IntegralsZygoteExt.jl +++ b/ext/IntegralsZygoteExt.jl @@ -3,37 +3,44 @@ using Integrals if isdefined(Base, :get_extension) using Zygote import ChainRulesCore - import ChainRulesCore: NoTangent + import ChainRulesCore: NoTangent, ProjectTo else using ..Zygote import ..Zygote.ChainRulesCore - import ..Zygote.ChainRulesCore: NoTangent + import ..Zygote.ChainRulesCore: NoTangent, ProjectTo end ChainRulesCore.@non_differentiable Integrals.checkkwargs(kwargs...) +ChainRulesCore.@non_differentiable Integrals.isinplace(f, n) # fixes #99 function ChainRulesCore.rrule(::typeof(Integrals.__solvebp), cache, alg, sensealg, lb, ub, p; kwargs...) out = Integrals.__solvebp_call(cache, alg, sensealg, lb, ub, p; kwargs...) + # the adjoint will be the integral of the input sensitivities, so it maps the + # sensitivity of the output to an object of the type of the parameters function quadrature_adjoint(Δ) - y = typeof(Δ) <: Array{<:Number, 0} ? Δ[1] : Δ + # https://juliadiff.org/ChainRulesCore.jl/dev/design/many_tangents.html#manytypes + y = cache.nout == 1 ? Δ[1] : Δ # interpret the output as scalar + # this will not be type-stable, but I believe it is unavoidable due to two ambiguities: + # 1. Δ is the output of the algorithm, and when nout = 1 it is undefined whether the + # output of the algorithm must be a scalar or a vector of length 1 + # 2. when nout = 1 the integrand can either be a scalar or a vector of length 1 if isinplace(cache) dx = zeros(cache.nout) _f = x -> cache.f(dx, x, p) if sensealg.vjp isa Integrals.ZygoteVJP dfdp = function (dx, x, p) - _, back = Zygote.pullback(p) do p - _dx = Zygote.Buffer(x, cache.nout, size(x, 2)) + z, back = Zygote.pullback(p) do p + _dx = cache.nout == 1 ? Zygote.Buffer(dx, eltype(y), size(x, ndims(x))) : Zygote.Buffer(dx, eltype(y), cache.nout, size(x, ndims(x))) cache.f(_dx, x, p) copy(_dx) end - - z = zeros(size(x, 2)) - for idx in 1:size(x, 2) - z[1] = 1 - dx[:, idx] = back(z)[1] - z[idx] = 0 + z .= zero(eltype(z)) + for idx in 1:size(x, ndims(x)) + z isa Vector ? (z[idx] = y) : (z[:,idx] .= y) + dx[:, idx] .= back(z)[1] + z isa Vector ? (z[idx] = zero(eltype(z))) : (z[:, idx] .= zero(eltype(z))) end end elseif sensealg.vjp isa Integrals.ReverseDiffVJP @@ -44,14 +51,19 @@ function ChainRulesCore.rrule(::typeof(Integrals.__solvebp), cache, alg, senseal if sensealg.vjp isa Integrals.ZygoteVJP if cache.batch > 0 dfdp = function (x, p) - _, back = Zygote.pullback(p -> cache.f(x, p), p) + z, back = Zygote.pullback(p -> cache.f(x, p), p) + # messy, there are 4 cases, some better in forward mode than reverse + # 1: length(y) == 1 and length(p) == 1 + # 2: length(y) > 1 and length(p) == 1 + # 3: length(y) == 1 and length(p) > 1 + # 4: length(y) > 1 and length(p) > 1 - out = zeros(length(p), size(x, 2)) - z = zeros(size(x, 2)) - for idx in 1:size(x, 2) - z[idx] = 1 - out[:, idx] = back(z)[1] - z[idx] = 0 + z .= zero(eltype(z)) + out = zeros(eltype(p), size(p)..., size(x, ndims(x))) + for idx in 1:size(x, ndims(x)) + z isa Vector ? (z[idx] = y) : (z[:, idx] .= y) + out isa Vector ? (out[idx] = back(z)[1]) : (out[:, idx] .= back(z)[1]) + z isa Vector ? (z[idx] = zero(y)) : (z[:, idx] .= zero(eltype(y))) end out end @@ -76,17 +88,24 @@ function ChainRulesCore.rrule(::typeof(Integrals.__solvebp), cache, alg, senseal do_inf_transformation = Val(false), cache.kwargs...) - if p isa Number - dp = Integrals.__solvebp_call(dp_cache, alg, sensealg, lb, ub, p; kwargs...)[1] - else - dp = Integrals.__solvebp_call(dp_cache, alg, sensealg, lb, ub, p; kwargs...).u - end + project_p = ProjectTo(p) + dp = project_p(Integrals.__solvebp_call(dp_cache, alg, sensealg, lb, ub, p; kwargs...).u) if lb isa Number - dlb = -_f(lb) - dub = _f(ub) + dlb = cache.batch > 0 ? -_f([lb]) : -_f(lb) + dub = cache.batch > 0 ? _f([ub]) : _f(ub) return (NoTangent(), NoTangent(), NoTangent(), NoTangent(), dlb, dub, dp) else + # we need to compute 2*length(lb) integrals on the faces of the hypercube, as we + # can see from writing the multidimensional integral as an iterated integral + # alternatively we can use Stokes' theorem to replace the integral on the + # boundary with a volume integral of the flux of the integrand + # ∫∂Ω ω = ∫Ω dω, which would be better since we won't have to change the + # dimensionality of the integral or the quadrature used (such as quadratures + # that don't evaluate points on the boundaries) and it could be generalized to + # other kinds of domains. The only question is to determine ω in terms of f and + # the deformation of the surface (e.g. consider integral over an ellipse and + # asking for the derivative of the result w.r.t. the semiaxes of the ellipse) return (NoTangent(), NoTangent(), NoTangent(), NoTangent(), NoTangent(), NoTangent(), dp) end @@ -94,8 +113,8 @@ function ChainRulesCore.rrule(::typeof(Integrals.__solvebp), cache, alg, senseal out, quadrature_adjoint end -Zygote.@adjoint function Zygote.literal_getproperty(sol::SciMLBase.IntegralSolution, - ::Val{:u}) - sol.u, Δ -> (SciMLBase.build_solution(sol.prob, sol.alg, Δ, sol.resid),) -end +# Zygote.@adjoint function Zygote.literal_getproperty(sol::SciMLBase.IntegralSolution, +# ::Val{:u}) +# sol.u, Δ -> (SciMLBase.build_solution(sol.prob, sol.alg, Δ, sol.resid),) +# end end diff --git a/lib/IntegralsCubature/src/IntegralsCubature.jl b/lib/IntegralsCubature/src/IntegralsCubature.jl index 120fdb62..a97af4b8 100644 --- a/lib/IntegralsCubature/src/IntegralsCubature.jl +++ b/lib/IntegralsCubature/src/IntegralsCubature.jl @@ -54,6 +54,10 @@ function Integrals.__solvebp_call(prob::IntegralProblem, maxiters = typemax(Int)) nout = prob.nout if nout == 1 + # the output of prob.f could be either scalar or a vector of length 1, however + # the behavior of the output of the integration routine is undefined (could differ + # across algorithms) + # Cubature will output a real number in when called without nout/fdim if prob.batch == 0 if isinplace(prob) dx = zeros(eltype(lb), prob.nout) @@ -63,74 +67,53 @@ function Integrals.__solvebp_call(prob::IntegralProblem, end if lb isa Number if alg isa CubatureJLh - _val, err = Cubature.hquadrature(f, lb, ub; + val, err = Cubature.hquadrature(f, lb, ub; reltol = reltol, abstol = abstol, maxevals = maxiters) else - _val, err = Cubature.pquadrature(f, lb, ub; + val, err = Cubature.pquadrature(f, lb, ub; reltol = reltol, abstol = abstol, maxevals = maxiters) end - val = prob.f(lb, p) isa Number ? _val : [_val] else if alg isa CubatureJLh - _val, err = Cubature.hcubature(f, lb, ub; + val, err = Cubature.hcubature(f, lb, ub; reltol = reltol, abstol = abstol, maxevals = maxiters) else - _val, err = Cubature.pcubature(f, lb, ub; + val, err = Cubature.pcubature(f, lb, ub; reltol = reltol, abstol = abstol, maxevals = maxiters) end - if isinplace(prob) || !isa(prob.f(lb, p), Number) - val = [_val] - else - val = _val - end end else if isinplace(prob) - f = (x, dx) -> prob.f(dx', x, p) - elseif lb isa Number - if prob.f([lb ub], p) isa Vector - f = (x, dx) -> (dx .= prob.f(x', p)) - else - f = function (x, dx) - dx[:] = prob.f(x', p) - end - end + f = (x, dx) -> prob.f(dx, x, p) else - if prob.f([lb ub], p) isa Vector - f = (x, dx) -> (dx .= prob.f(x, p)) - else - f = function (x, dx) - dx .= prob.f(x, p)[:] - end - end + f = (x, dx) -> (dx .= prob.f(x, p)) end if lb isa Number if alg isa CubatureJLh - _val, err = Cubature.hquadrature_v(f, lb, ub; + val, err = Cubature.hquadrature_v(f, lb, ub; reltol = reltol, abstol = abstol, maxevals = maxiters) else - _val, err = Cubature.pquadrature_v(f, lb, ub; + val, err = Cubature.pquadrature_v(f, lb, ub; reltol = reltol, abstol = abstol, maxevals = maxiters) end else if alg isa CubatureJLh - _val, err = Cubature.hcubature_v(f, lb, ub; + val, err = Cubature.hcubature_v(f, lb, ub; reltol = reltol, abstol = abstol, maxevals = maxiters) else - _val, err = Cubature.pcubature_v(f, lb, ub; + val, err = Cubature.pcubature_v(f, lb, ub; reltol = reltol, abstol = abstol, maxevals = maxiters) end end - val = _val isa Number ? [_val] : _val end else if prob.batch == 0 @@ -166,13 +149,9 @@ function Integrals.__solvebp_call(prob::IntegralProblem, end else if isinplace(prob) - f = (x, dx) -> prob.f(dx, x, p) + f = (x, dx) -> (prob.f(dx, x, p); dx) else - if lb isa Number - f = (x, dx) -> (dx .= prob.f(x', p)) - else - f = (x, dx) -> (dx .= prob.f(x, p)) - end + f = (x, dx) -> (dx .= prob.f(x, p)) end if lb isa Number diff --git a/test/derivative_tests.jl b/test/derivative_tests.jl index 8eeb5f0a..f6a95857 100644 --- a/test/derivative_tests.jl +++ b/test/derivative_tests.jl @@ -1,4 +1,4 @@ -using Integrals, Zygote, FiniteDiff, ForwardDiff, SciMLSensitivity +using Integrals, Zygote, FiniteDiff, ForwardDiff#, SciMLSensitivity using IntegralsCuba, IntegralsCubature using Test @@ -117,7 +117,7 @@ dp4 = ForwardDiff.gradient(p -> testf(lb, ub, p), p) @test dp1 ≈ dp4 ### Batch Single dim -f(x, p) = x * p[1] .+ p[2] * p[3] +f(x, p) = x * p[1] .+ p[2] * p[3] # scalar integrand lb = 1.0 ub = 3.0 @@ -130,14 +130,14 @@ function testf3(lb, ub, p; f = f) end dp1 = ForwardDiff.gradient(p -> testf3(lb, ub, p), p) -# dp2 = Zygote.gradient(p->testf3(lb,ub,p),p)[1] # TODO fix: LoadError: DimensionMismatch("variable with size(x) == (1, 15) cannot have a gradient with size(dx) == (15,)") +dp2 = Zygote.gradient(p->testf3(lb,ub,p),p)[1] # TODO fix: LoadError: DimensionMismatch("variable with size(x) == (1, 15) cannot have a gradient with size(dx) == (15,)") dp3 = FiniteDiff.finite_difference_gradient(p -> testf3(lb, ub, p), p) @test dp1 ≈ dp3 #passes -@test_broken dp2 ≈ dp3 #passes +@test dp2 ≈ dp3 #passes ### Batch single dim, nout -f(x, p) = (x * p[1] .+ p[2] * p[3]) .* [1; 2] +f(x, p) = (x' * p[1] .+ p[2] * p[3]) .* [1; 2] lb = 1.0 ub = 3.0 @@ -150,11 +150,11 @@ function testf3(lb, ub, p; f = f) end dp1 = ForwardDiff.gradient(p -> testf3(lb, ub, p), p) -# dp2 = Zygote.gradient(p->testf3(lb,ub,p),p)[1] +dp2 = Zygote.gradient(p->testf3(lb,ub,p),p)[1] dp3 = FiniteDiff.finite_difference_gradient(p -> testf3(lb, ub, p), p) @test dp1 ≈ dp3 #passes -# @test dp2 ≈ dp3 #passes +@test dp2 ≈ dp3 #passes ### Batch multi dim f(x, p) = x[1, :] * p[1] .+ p[2] * p[3] @@ -190,15 +190,15 @@ function testf3(lb, ub, p; f = f) end dp1 = ForwardDiff.gradient(p -> testf3(lb, ub, p), p) -# dp2 = Zygote.gradient(p->testf3(lb,ub,p),p)[1] +dp2 = Zygote.gradient(p->testf3(lb,ub,p),p)[1] dp3 = FiniteDiff.finite_difference_gradient(p -> testf3(lb, ub, p), p) @test dp1 ≈ dp3 -# @test dp2 ≈ dp3 +@test dp2 ≈ dp3 -## iip Batch mulit dim +## iip Batch multi dim function g(dx, x, p) - dx .= sum(x * p[1] .+ p[2] * p[3], dims = 1) + dx .= dropdims(sum(x * p[1] .+ p[2] * p[3], dims = 1), dims = 1) end lb = [1.0, 1.0] @@ -236,8 +236,8 @@ function testf3(lb, ub, p; f = g) end dp1 = ForwardDiff.gradient(p -> testf3(lb, ub, p), p) -# dp2 = Zygote.gradient(p->testf3(lb,ub,p),p)[1] +dp2 = Zygote.gradient(p->testf3(lb,ub,p),p)[1] dp3 = FiniteDiff.finite_difference_gradient(p -> testf3(lb, ub, p), p) @test dp1 ≈ dp3 -# @test dp2 ≈ dp3 +@test dp2 ≈ dp3 diff --git a/test/interface_tests.jl b/test/interface_tests.jl index 6789f653..2d597adc 100644 --- a/test/interface_tests.jl +++ b/test/interface_tests.jl @@ -50,14 +50,15 @@ exact_sol_v = [ ] batch_f(f) = (pts, p) -> begin - fevals = zeros(size(pts, 2)) - for i in 1:size(pts, 2) - x = pts[:, i] + fevals = zeros(size(pts, ndims(pts))) + for i in axes(pts, ndims(pts)) + x = pts isa Vector ? pts[i] : pts[:, i] fevals[i] = f(x, p) end fevals end +# TODO ? check if pts is a vector or matrix batch_iip_f(f) = (fevals, pts, p) -> begin for i in 1:size(pts, 2) x = pts[:, i] @@ -67,18 +68,19 @@ batch_iip_f(f) = (fevals, pts, p) -> begin end batch_f_v(f, nout) = (pts, p) -> begin - fevals = zeros(nout, size(pts, 2)) - for i in 1:size(pts, 2) - x = pts[:, i] - fevals[:, i] = f(x, p, nout) + fevals = zeros(nout, size(pts, ndims(pts))) + for i in axes(pts, ndims(pts)) + x = pts isa Vector ? pts[i] : pts[:, i] + fevals[:, i] .= f(x, p, nout) end fevals end +# TODO ? check if pts is a vector or matrix batch_iip_f_v(f, nout) = (fevals, pts, p) -> begin for i in 1:size(pts, 2) x = pts[:, i] - fevals[:, i] = f(x, p, nout) + fevals[:, i] .= f(x, p, nout) end nothing end @@ -158,7 +160,7 @@ end end @info "Alg = $alg, Integrand = $i, Dimension = $dim, Output Dimension = $nout" sol = solve(prob, alg, reltol = reltol, abstol = abstol) - @test sol.u≈[exact_sol[i](dim, nout, lb, ub)] rtol=1e-2 + @test sol.u[1]≈exact_sol[i](dim, nout, lb, ub) rtol=1e-2 end end end @@ -225,7 +227,11 @@ end end @info "Alg = $alg, Integrand = $i, Dimension = $dim, Output Dimension = $nout" sol = solve(prob, alg, reltol = reltol, abstol = abstol) - @test sol.u≈exact_sol_v[i](dim, nout, lb, ub) rtol=1e-2 + if nout == 1 + @test sol.u[1]≈exact_sol_v[i](dim, nout, lb, ub)[1] rtol=1e-2 + else + @test sol.u≈exact_sol_v[i](dim, nout, lb, ub) rtol=1e-2 + end end end end @@ -247,8 +253,8 @@ end nout = nout) @info "Alg = $alg, Integrand = $i, Dimension = $dim, Output Dimension = $nout" sol = solve(prob, alg, reltol = reltol, abstol = abstol) - if sol.u isa Number - @test sol.u≈exact_sol_v[i](dim, nout, lb, ub)[1] rtol=1e-2 + if nout == 1 + @test sol.u[1]≈exact_sol_v[i](dim, nout, lb, ub)[1] rtol=1e-2 else @test sol.u≈exact_sol_v[i](dim, nout, lb, ub) rtol=1e-2 end @@ -277,8 +283,8 @@ end else sol = solve(prob, alg, reltol = reltol, abstol = abstol) end - if sol.u isa Number - @test sol.u≈exact_sol_v[i](dim, nout, lb, ub)[1] rtol=1e-2 + if nout == 1 + @test sol.u[1]≈exact_sol_v[i](dim, nout, lb, ub)[1] rtol=1e-2 else @test sol.u≈exact_sol_v[i](dim, nout, lb, ub) rtol=1e-2 end From 058941a34efefad97233bf0cd0186eceb84215cd Mon Sep 17 00:00:00 2001 From: lxvm Date: Sat, 16 Sep 2023 14:16:14 -0400 Subject: [PATCH 2/3] apply format --- ext/IntegralsForwardDiffExt.jl | 3 ++- ext/IntegralsZygoteExt.jl | 23 ++++++++++++++----- .../src/IntegralsCubature.jl | 1 - test/derivative_tests.jl | 8 +++---- 4 files changed, 23 insertions(+), 12 deletions(-) diff --git a/ext/IntegralsForwardDiffExt.jl b/ext/IntegralsForwardDiffExt.jl index 307b4b3e..83d4a064 100644 --- a/ext/IntegralsForwardDiffExt.jl +++ b/ext/IntegralsForwardDiffExt.jl @@ -29,7 +29,8 @@ function Integrals.__solvebp(cache, alg, sensealg, lb, ub, dfdp = function (out, x, p) dualp = reinterpret(ForwardDiff.Dual{T, V, P}, p) if cache.batch > 0 - dx = cache.nout == 1 ? similar(dualp, size(x, ndims(x))) : similar(dualp, cache.nout, size(x, ndims(x))) + dx = cache.nout == 1 ? similar(dualp, size(x, ndims(x))) : + similar(dualp, cache.nout, size(x, ndims(x))) else dx = similar(dualp, cache.nout) end diff --git a/ext/IntegralsZygoteExt.jl b/ext/IntegralsZygoteExt.jl index 8c0c284a..58d732ba 100644 --- a/ext/IntegralsZygoteExt.jl +++ b/ext/IntegralsZygoteExt.jl @@ -32,15 +32,18 @@ function ChainRulesCore.rrule(::typeof(Integrals.__solvebp), cache, alg, senseal if sensealg.vjp isa Integrals.ZygoteVJP dfdp = function (dx, x, p) z, back = Zygote.pullback(p) do p - _dx = cache.nout == 1 ? Zygote.Buffer(dx, eltype(y), size(x, ndims(x))) : Zygote.Buffer(dx, eltype(y), cache.nout, size(x, ndims(x))) + _dx = cache.nout == 1 ? + Zygote.Buffer(dx, eltype(y), size(x, ndims(x))) : + Zygote.Buffer(dx, eltype(y), cache.nout, size(x, ndims(x))) cache.f(_dx, x, p) copy(_dx) end z .= zero(eltype(z)) for idx in 1:size(x, ndims(x)) - z isa Vector ? (z[idx] = y) : (z[:,idx] .= y) + z isa Vector ? (z[idx] = y) : (z[:, idx] .= y) dx[:, idx] .= back(z)[1] - z isa Vector ? (z[idx] = zero(eltype(z))) : (z[:, idx] .= zero(eltype(z))) + z isa Vector ? (z[idx] = zero(eltype(z))) : + (z[:, idx] .= zero(eltype(z))) end end elseif sensealg.vjp isa Integrals.ReverseDiffVJP @@ -62,8 +65,10 @@ function ChainRulesCore.rrule(::typeof(Integrals.__solvebp), cache, alg, senseal out = zeros(eltype(p), size(p)..., size(x, ndims(x))) for idx in 1:size(x, ndims(x)) z isa Vector ? (z[idx] = y) : (z[:, idx] .= y) - out isa Vector ? (out[idx] = back(z)[1]) : (out[:, idx] .= back(z)[1]) - z isa Vector ? (z[idx] = zero(y)) : (z[:, idx] .= zero(eltype(y))) + out isa Vector ? (out[idx] = back(z)[1]) : + (out[:, idx] .= back(z)[1]) + z isa Vector ? (z[idx] = zero(y)) : + (z[:, idx] .= zero(eltype(y))) end out end @@ -89,7 +94,13 @@ function ChainRulesCore.rrule(::typeof(Integrals.__solvebp), cache, alg, senseal cache.kwargs...) project_p = ProjectTo(p) - dp = project_p(Integrals.__solvebp_call(dp_cache, alg, sensealg, lb, ub, p; kwargs...).u) + dp = project_p(Integrals.__solvebp_call(dp_cache, + alg, + sensealg, + lb, + ub, + p; + kwargs...).u) if lb isa Number dlb = cache.batch > 0 ? -_f([lb]) : -_f(lb) diff --git a/lib/IntegralsCubature/src/IntegralsCubature.jl b/lib/IntegralsCubature/src/IntegralsCubature.jl index a97af4b8..384a5008 100644 --- a/lib/IntegralsCubature/src/IntegralsCubature.jl +++ b/lib/IntegralsCubature/src/IntegralsCubature.jl @@ -85,7 +85,6 @@ function Integrals.__solvebp_call(prob::IntegralProblem, reltol = reltol, abstol = abstol, maxevals = maxiters) end - end else if isinplace(prob) diff --git a/test/derivative_tests.jl b/test/derivative_tests.jl index f6a95857..400e1bbe 100644 --- a/test/derivative_tests.jl +++ b/test/derivative_tests.jl @@ -130,7 +130,7 @@ function testf3(lb, ub, p; f = f) end dp1 = ForwardDiff.gradient(p -> testf3(lb, ub, p), p) -dp2 = Zygote.gradient(p->testf3(lb,ub,p),p)[1] # TODO fix: LoadError: DimensionMismatch("variable with size(x) == (1, 15) cannot have a gradient with size(dx) == (15,)") +dp2 = Zygote.gradient(p -> testf3(lb, ub, p), p)[1] # TODO fix: LoadError: DimensionMismatch("variable with size(x) == (1, 15) cannot have a gradient with size(dx) == (15,)") dp3 = FiniteDiff.finite_difference_gradient(p -> testf3(lb, ub, p), p) @test dp1 ≈ dp3 #passes @@ -150,7 +150,7 @@ function testf3(lb, ub, p; f = f) end dp1 = ForwardDiff.gradient(p -> testf3(lb, ub, p), p) -dp2 = Zygote.gradient(p->testf3(lb,ub,p),p)[1] +dp2 = Zygote.gradient(p -> testf3(lb, ub, p), p)[1] dp3 = FiniteDiff.finite_difference_gradient(p -> testf3(lb, ub, p), p) @test dp1 ≈ dp3 #passes @@ -190,7 +190,7 @@ function testf3(lb, ub, p; f = f) end dp1 = ForwardDiff.gradient(p -> testf3(lb, ub, p), p) -dp2 = Zygote.gradient(p->testf3(lb,ub,p),p)[1] +dp2 = Zygote.gradient(p -> testf3(lb, ub, p), p)[1] dp3 = FiniteDiff.finite_difference_gradient(p -> testf3(lb, ub, p), p) @test dp1 ≈ dp3 @@ -236,7 +236,7 @@ function testf3(lb, ub, p; f = g) end dp1 = ForwardDiff.gradient(p -> testf3(lb, ub, p), p) -dp2 = Zygote.gradient(p->testf3(lb,ub,p),p)[1] +dp2 = Zygote.gradient(p -> testf3(lb, ub, p), p)[1] dp3 = FiniteDiff.finite_difference_gradient(p -> testf3(lb, ub, p), p) @test dp1 ≈ dp3 From 2d3ee8eb25258eec31b4a09b82204ededd29c23a Mon Sep 17 00:00:00 2001 From: Christopher Rackauckas Date: Sun, 17 Sep 2023 01:45:47 -0400 Subject: [PATCH 3/3] Update ext/IntegralsZygoteExt.jl --- ext/IntegralsZygoteExt.jl | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/ext/IntegralsZygoteExt.jl b/ext/IntegralsZygoteExt.jl index 58d732ba..986c522d 100644 --- a/ext/IntegralsZygoteExt.jl +++ b/ext/IntegralsZygoteExt.jl @@ -124,8 +124,8 @@ function ChainRulesCore.rrule(::typeof(Integrals.__solvebp), cache, alg, senseal out, quadrature_adjoint end -# Zygote.@adjoint function Zygote.literal_getproperty(sol::SciMLBase.IntegralSolution, -# ::Val{:u}) -# sol.u, Δ -> (SciMLBase.build_solution(sol.prob, sol.alg, Δ, sol.resid),) -# end +Zygote.@adjoint function Zygote.literal_getproperty(sol::SciMLBase.IntegralSolution, + ::Val{:u}) + sol.u, Δ -> (SciMLBase.build_solution(sol.prob, sol.alg, Δ, sol.resid),) +end end