From 32ba9950478d8d35f332fa5083852a04de8a5d00 Mon Sep 17 00:00:00 2001 From: David Widmann Date: Fri, 12 May 2023 02:20:44 +0200 Subject: [PATCH] Add derivative of `logabsgamma` --- Project.toml | 2 +- src/lib/real.jl | 10 ++++++++++ test/tracker.jl | 5 +++++ 3 files changed, 16 insertions(+), 1 deletion(-) diff --git a/Project.toml b/Project.toml index 81402f1..d2f65ee 100644 --- a/Project.toml +++ b/Project.toml @@ -1,6 +1,6 @@ name = "Tracker" uuid = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c" -version = "0.2.24" +version = "0.2.25" [deps] Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e" diff --git a/src/lib/real.jl b/src/lib/real.jl index 569e9f4..83d399b 100644 --- a/src/lib/real.jl +++ b/src/lib/real.jl @@ -149,3 +149,13 @@ end collectmemaybe(xs::AbstractArray{>:TrackedReal}) = collect(xs) collectmemaybe(xs::AbstractArray{<:TrackedReal}) = collect(xs) + +# `logabsgamma` returns a tuple and hence its derivative is not defined in DiffRules +SpecialFunctions.logabsgamma(x::TrackedReal) = track(SpecialFunctions.logabsgamma, x) +@grad function SpecialFunctions.logabsgamma(x::Real) + data_x = data(x) + function logabsgamma_pullback(Δ) + return (SpecialFunctions.digamma(data_x) * first(Δ),) + end + return SpecialFunctions.logabsgamma(data_x), logabsgamma_pullback +end diff --git a/test/tracker.jl b/test/tracker.jl index a6169a3..c6f03bd 100644 --- a/test/tracker.jl +++ b/test/tracker.jl @@ -2,6 +2,7 @@ using Tracker, Test, NNlib using Tracker: TrackedReal, gradient, gradcheck, grad, checkpoint, forwarddiff using NNlib: conv, ∇conv_data, depthwiseconv using PDMats +using SpecialFunctions: logabsgamma using Printf: @sprintf using LinearAlgebra: diagm, dot, LowerTriangular, norm, det, logdet, logabsdet, I, Diagonal using Statistics: mean, std, var @@ -504,3 +505,7 @@ end @test size(y) == (5, 3) end +@testset "logabsgamma" begin + @test gradcheck(x -> logabsgamma(only(x))[1], rand(1)) + @test gradcheck(x -> logabsgamma(only(x))[2], rand(1)) +end