diff --git a/.github/workflows/CI.yml b/.github/workflows/CI.yml index ec3aeaabf..881f4e9ad 100644 --- a/.github/workflows/CI.yml +++ b/.github/workflows/CI.yml @@ -15,7 +15,7 @@ jobs: fail-fast: false matrix: version: - - "1.0" # LTS + - "1.6" # LTS - "1" # Latest Release os: - ubuntu-latest diff --git a/Project.toml b/Project.toml index 2dfec9103..55a1947f7 100644 --- a/Project.toml +++ b/Project.toml @@ -1,6 +1,6 @@ name = "ChainRules" uuid = "082447d4-558c-5d27-93f4-14fc19e9eca2" -version = "1.21" +version = "1.22" [deps] ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" @@ -15,10 +15,10 @@ ChainRulesCore = "1.11.5" ChainRulesTestUtils = "1" Compat = "3.35" FiniteDifferences = "0.12.20" -JuliaInterpreter = "0.8" +JuliaInterpreter = "0.8" # latest is "0.9.1" RealDot = "0.1" StaticArrays = "1.2" -julia = "1" +julia = "1.6" [extras] ChainRulesTestUtils = "cdddcdb0-9152-4a09-a978-84456f9df70a" diff --git a/src/ChainRules.jl b/src/ChainRules.jl index f2806fd3d..197ce7ccb 100644 --- a/src/ChainRules.jl +++ b/src/ChainRules.jl @@ -13,13 +13,6 @@ using Statistics # to the normal rule of only overload via `ChainRulesCore.rrule`. import ChainRulesCore: rrule, frule -if VERSION < v"1.3.0-DEV.142" - # In prior versions, the BLAS submodule also exported `dot`, which caused a conflict - # with its parent module. To get around this, we can simply create a hard binding for - # the one we want to use without qualification. - import LinearAlgebra: dot -end - # numbers that we know commute under multiplication const CommutativeMulNumber = Union{Real,Complex} diff --git a/src/rulesets/Base/base.jl b/src/rulesets/Base/base.jl index 7eb0a2fdc..e95ff6eff 100644 --- a/src/rulesets/Base/base.jl +++ b/src/rulesets/Base/base.jl @@ -148,9 +148,7 @@ end @scalar_rule sinc(x) cosc(x) # the position of the minus sign below warrants the correct type for π -if VERSION ≥ v"1.6" - @scalar_rule sincospi(x) @setup((sinpix, cospix) = Ω) (π * cospix) (π * (-sinpix)) -end +@scalar_rule sincospi(x) @setup((sinpix, cospix) = Ω) (π * cospix) (π * (-sinpix)) @scalar_rule( clamp(x, low, high), diff --git a/src/rulesets/Base/evalpoly.jl b/src/rulesets/Base/evalpoly.jl index 014374097..ab06e6a88 100644 --- a/src/rulesets/Base/evalpoly.jl +++ b/src/rulesets/Base/evalpoly.jl @@ -1,151 +1,149 @@ -if VERSION ≥ v"1.4" - function frule((_, ẋ, ṗ), ::typeof(evalpoly), x, p::Union{Tuple,AbstractVector}) - Δx, Δp = ẋ, unthunk(ṗ) - N = length(p) - @inbounds y = p[N] - Δy = Δp[N] - @inbounds for i in (N - 1):-1:1 - Δy = muladd(Δx, y, muladd(x, Δy, Δp[i])) - y = muladd(x, y, p[i]) - end - return y, Δy +function frule((_, ẋ, ṗ), ::typeof(evalpoly), x, p::Union{Tuple,AbstractVector}) + Δx, Δp = ẋ, unthunk(ṗ) + N = length(p) + @inbounds y = p[N] + Δy = Δp[N] + @inbounds for i in (N - 1):-1:1 + Δy = muladd(Δx, y, muladd(x, Δy, Δp[i])) + y = muladd(x, y, p[i]) end + return y, Δy +end - function rrule(::typeof(evalpoly), x, p::Union{Tuple,AbstractVector}) - y, ys = _evalpoly_intermediates(x, p) - project_x = ProjectTo(x) - project_p = p isa Tuple ? identity : ProjectTo(p) - function evalpoly_pullback(Δy) - ∂x, ∂p = _evalpoly_back(x, p, ys, Δy) - return NoTangent(), project_x(∂x), project_p(∂p) - end - return y, evalpoly_pullback +function rrule(::typeof(evalpoly), x, p::Union{Tuple,AbstractVector}) + y, ys = _evalpoly_intermediates(x, p) + project_x = ProjectTo(x) + project_p = p isa Tuple ? identity : ProjectTo(p) + function evalpoly_pullback(Δy) + ∂x, ∂p = _evalpoly_back(x, p, ys, Δy) + return NoTangent(), project_x(∂x), project_p(∂p) end + return y, evalpoly_pullback +end - function rrule(::typeof(evalpoly), x, p::Vector{<:Matrix}) # does not type infer with ProjectTo - y, ys = _evalpoly_intermediates(x, p) - function evalpoly_pullback(Δy) - ∂x, ∂p = _evalpoly_back(x, p, ys, Δy) - return NoTangent(), ∂x, ∂p - end - return y, evalpoly_pullback +function rrule(::typeof(evalpoly), x, p::Vector{<:Matrix}) # does not type infer with ProjectTo + y, ys = _evalpoly_intermediates(x, p) + function evalpoly_pullback(Δy) + ∂x, ∂p = _evalpoly_back(x, p, ys, Δy) + return NoTangent(), ∂x, ∂p end + return y, evalpoly_pullback +end - # evalpoly but storing intermediates - function _evalpoly_intermediates(x, p::Tuple) - return if @generated - N = length(p.parameters) - exs = [] - vars = [] - ex = :(p[$N]) - for i in 1:(N - 1) - yi = Symbol("y", i) - push!(vars, yi) - push!(exs, :($yi = $ex)) - ex = :(muladd(x, $yi, p[$(N - i)])) - end - push!(exs, :(y = $ex)) - Expr(:block, exs..., :(y, ($(vars...),))) - else - _evalpoly_intermediates_fallback(x, p) +# evalpoly but storing intermediates +function _evalpoly_intermediates(x, p::Tuple) + return if @generated + N = length(p.parameters) + exs = [] + vars = [] + ex = :(p[$N]) + for i in 1:(N - 1) + yi = Symbol("y", i) + push!(vars, yi) + push!(exs, :($yi = $ex)) + ex = :(muladd(x, $yi, p[$(N - i)])) end + push!(exs, :(y = $ex)) + Expr(:block, exs..., :(y, ($(vars...),))) + else + _evalpoly_intermediates_fallback(x, p) end - function _evalpoly_intermediates_fallback(x, p::Tuple) - N = length(p) - y = p[N] - ys = (y, ntuple(N - 2) do i - return y = muladd(x, y, p[N - i]) - end...) - y = muladd(x, y, p[1]) - return y, ys - end - function _evalpoly_intermediates(x, p) - N = length(p) - @inbounds yn = one(x) * p[N] - ys = similar(p, typeof(yn), N - 1) - @inbounds ys[1] = yn - @inbounds for i in 2:(N - 1) - ys[i] = muladd(x, ys[i - 1], p[N - i + 1]) - end - @inbounds y = muladd(x, ys[N - 1], p[1]) - return y, ys +end +function _evalpoly_intermediates_fallback(x, p::Tuple) + N = length(p) + y = p[N] + ys = (y, ntuple(N - 2) do i + return y = muladd(x, y, p[N - i]) + end...) + y = muladd(x, y, p[1]) + return y, ys +end +function _evalpoly_intermediates(x, p) + N = length(p) + @inbounds yn = one(x) * p[N] + ys = similar(p, typeof(yn), N - 1) + @inbounds ys[1] = yn + @inbounds for i in 2:(N - 1) + ys[i] = muladd(x, ys[i - 1], p[N - i + 1]) end + @inbounds y = muladd(x, ys[N - 1], p[1]) + return y, ys +end - # TODO: Handle following cases - # 1) x is a UniformScaling, pᵢ is a matrix - # 2) x is a matrix, pᵢ is a UniformScaling - @inline _evalpoly_backx(x, yi, ∂yi) = ∂yi * yi' - @inline _evalpoly_backx(x, yi, ∂x, ∂yi) = muladd(∂yi, yi', ∂x) - @inline _evalpoly_backx(x::Number, yi, ∂yi) = conj(dot(∂yi, yi)) - @inline _evalpoly_backx(x::Number, yi, ∂x, ∂yi) = _evalpoly_backx(x, yi, ∂yi) + ∂x +# TODO: Handle following cases +# 1) x is a UniformScaling, pᵢ is a matrix +# 2) x is a matrix, pᵢ is a UniformScaling +@inline _evalpoly_backx(x, yi, ∂yi) = ∂yi * yi' +@inline _evalpoly_backx(x, yi, ∂x, ∂yi) = muladd(∂yi, yi', ∂x) +@inline _evalpoly_backx(x::Number, yi, ∂yi) = conj(dot(∂yi, yi)) +@inline _evalpoly_backx(x::Number, yi, ∂x, ∂yi) = _evalpoly_backx(x, yi, ∂yi) + ∂x - @inline _evalpoly_backp(pi, ∂yi) = ∂yi +@inline _evalpoly_backp(pi, ∂yi) = ∂yi - function _evalpoly_back(x, p::Tuple, ys, Δy) - return if @generated - exs = [] - vars = [] - N = length(p.parameters) - for i in 2:(N - 1) - ∂pi = Symbol("∂p", i) - push!(vars, ∂pi) - push!(exs, :(∂x = _evalpoly_backx(x, ys[$(N - i)], ∂x, ∂yi))) - push!(exs, :($∂pi = _evalpoly_backp(p[$i], ∂yi))) - push!(exs, :(∂yi = x′ * ∂yi)) - end - push!(vars, :(_evalpoly_backp(p[$N], ∂yi))) # ∂pN - Expr( - :block, - :(x′ = x'), - :(∂yi = Δy), - :(∂p1 = _evalpoly_backp(p[1], ∂yi)), - :(∂x = _evalpoly_backx(x, ys[$(N - 1)], ∂yi)), - :(∂yi = x′ * ∂yi), - exs..., - :(∂p = (∂p1, $(vars...))), - :(∂x, Tangent{typeof(p),typeof(∂p)}(∂p)), - ) - else - _evalpoly_back_fallback(x, p, ys, Δy) +function _evalpoly_back(x, p::Tuple, ys, Δy) + return if @generated + exs = [] + vars = [] + N = length(p.parameters) + for i in 2:(N - 1) + ∂pi = Symbol("∂p", i) + push!(vars, ∂pi) + push!(exs, :(∂x = _evalpoly_backx(x, ys[$(N - i)], ∂x, ∂yi))) + push!(exs, :($∂pi = _evalpoly_backp(p[$i], ∂yi))) + push!(exs, :(∂yi = x′ * ∂yi)) end + push!(vars, :(_evalpoly_backp(p[$N], ∂yi))) # ∂pN + Expr( + :block, + :(x′ = x'), + :(∂yi = Δy), + :(∂p1 = _evalpoly_backp(p[1], ∂yi)), + :(∂x = _evalpoly_backx(x, ys[$(N - 1)], ∂yi)), + :(∂yi = x′ * ∂yi), + exs..., + :(∂p = (∂p1, $(vars...))), + :(∂x, Tangent{typeof(p),typeof(∂p)}(∂p)), + ) + else + _evalpoly_back_fallback(x, p, ys, Δy) end - function _evalpoly_back_fallback(x, p::Tuple, ys, Δy) - x′ = x' - ∂yi = unthunk(Δy) - N = length(p) - ∂p1 = _evalpoly_backp(p[1], ∂yi) +end +function _evalpoly_back_fallback(x, p::Tuple, ys, Δy) + x′ = x' + ∂yi = unthunk(Δy) + N = length(p) + ∂p1 = _evalpoly_backp(p[1], ∂yi) + ∂x = _evalpoly_backx(x, ys[N - 1], ∂yi) + ∂yi = x′ * ∂yi + ∂p = ( + ∂p1, + ntuple(N - 2) do i + ∂x = _evalpoly_backx(x, ys[N-i-1], ∂x, ∂yi) + ∂pi = _evalpoly_backp(p[i+1], ∂yi) + ∂yi = x′ * ∂yi + return ∂pi + end..., + _evalpoly_backp(p[N], ∂yi), # ∂pN + ) + return ∂x, Tangent{typeof(p),typeof(∂p)}(∂p) +end +function _evalpoly_back(x, p, ys, Δy) + x′ = x' + ∂yi = one(x′) * Δy + N = length(p) + @inbounds ∂p1 = _evalpoly_backp(p[1], ∂yi) + ∂p = similar(p, typeof(∂p1)) + @inbounds begin ∂x = _evalpoly_backx(x, ys[N - 1], ∂yi) ∂yi = x′ * ∂yi - ∂p = ( - ∂p1, - ntuple(N - 2) do i - ∂x = _evalpoly_backx(x, ys[N-i-1], ∂x, ∂yi) - ∂pi = _evalpoly_backp(p[i+1], ∂yi) - ∂yi = x′ * ∂yi - return ∂pi - end..., - _evalpoly_backp(p[N], ∂yi), # ∂pN - ) - return ∂x, Tangent{typeof(p),typeof(∂p)}(∂p) - end - function _evalpoly_back(x, p, ys, Δy) - x′ = x' - ∂yi = one(x′) * Δy - N = length(p) - @inbounds ∂p1 = _evalpoly_backp(p[1], ∂yi) - ∂p = similar(p, typeof(∂p1)) - @inbounds begin - ∂x = _evalpoly_backx(x, ys[N - 1], ∂yi) + ∂p[1] = ∂p1 + for i in 2:(N - 1) + ∂x = _evalpoly_backx(x, ys[N - i], ∂x, ∂yi) + ∂p[i] = _evalpoly_backp(p[i], ∂yi) ∂yi = x′ * ∂yi - ∂p[1] = ∂p1 - for i in 2:(N - 1) - ∂x = _evalpoly_backx(x, ys[N - i], ∂x, ∂yi) - ∂p[i] = _evalpoly_backp(p[i], ∂yi) - ∂yi = x′ * ∂yi - end - ∂p[N] = _evalpoly_backp(p[N], ∂yi) end - return ∂x, ∂p + ∂p[N] = _evalpoly_backp(p[N], ∂yi) end + return ∂x, ∂p end diff --git a/src/rulesets/Base/mapreduce.jl b/src/rulesets/Base/mapreduce.jl index 11b03d962..21bf194be 100644 --- a/src/rulesets/Base/mapreduce.jl +++ b/src/rulesets/Base/mapreduce.jl @@ -115,12 +115,10 @@ function frule( y = sum(abs2, x; dims=dims) ∂y = if dims isa Colon 2 * realdot(x, ẋ) - elseif VERSION ≥ v"1.2" # multi-iterator mapreduce introduced in v1.2 + else mapreduce(+, x, ẋ; dims=dims) do xi, dxi 2 * realdot(xi, dxi) end - else - 2 * sum(realdot.(x, ẋ); dims=dims) end return y, ∂y end @@ -419,16 +417,10 @@ end # To support 2nd derivatives, some may need their own gradient rules. And _drop1 should perhaps # be replaced by _peel1 like Iterators.peel -if VERSION >= v"1.6" - _reverse1(x) = Iterators.reverse(x) - _drop1(x) = Iterators.drop(x, 1) - _zip2(x, y) = zip(x, y) # for `accumulate`, below -else - # Old versions don't support accumulate(::itr), nor multi-dim reverse - _reverse1(x) = reverse(vec(x)) - _drop1(x) = vec(x)[2:end] - _zip2(x, y) = collect(zip(x, y)) -end +_reverse1(x) = Iterators.reverse(x) +_drop1(x) = Iterators.drop(x, 1) +_zip2(x, y) = zip(x, y) # for `accumulate`, below + _reverse1(x::Tuple) = reverse(x) _drop1(x::Tuple) = Base.tail(x) _zip2(x::Tuple{Vararg{Any,N}}, y::Tuple{Vararg{Any,N}}) where N = ntuple(i -> (x[i],y[i]), N) @@ -480,16 +472,9 @@ function rrule( function decumulate(dy) dy_plain = _no_tuple_tangent(unthunk(dy)) rev_list = if init === _InitialValue() - if VERSION >= v"1.6" - # Here we rely on `zip` to stop early. Begin explicit with _reverse1(_drop1(...)) - # gets "no method matching iterate(::Base.Iterators.Reverse{Base.Iterators.Drop{Array{" - _zip2(_reverse1(hobbits), _reverse1(dy_plain)) - else - # However, on 1.0 and some others, zip does not stop early. But since accumulate - # also doesn't work on iterators, `_drop1` doesn't make one, so this should work: - _zip2(_reverse1(hobbits), _reverse1(_drop1(dy_plain))) - # What an awful tangle. - end + # Here we rely on `zip` to stop early. Begin explicit with _reverse1(_drop1(...)) + # gets "no method matching iterate(::Base.Iterators.Reverse{Base.Iterators.Drop{Array{" + _zip2(_reverse1(hobbits), _reverse1(dy_plain)) else _zip2(_reverse1(hobbits), _reverse1(dy_plain)) end diff --git a/src/rulesets/Base/nondiff.jl b/src/rulesets/Base/nondiff.jl index 0336a2ba1..f4764b54e 100644 --- a/src/rulesets/Base/nondiff.jl +++ b/src/rulesets/Base/nondiff.jl @@ -155,7 +155,7 @@ @non_differentiable fd(::Base.Filesystem.File) @non_differentiable fd(::IOStream) @non_differentiable fieldtype(T, ::Union{Symbol, Integer}) -VERSION >= v"1.1" && @non_differentiable fieldtypes(T) +@non_differentiable fieldtypes(T) @non_differentiable fieldname(T, ::Integer) @non_differentiable fieldnames(T) @@ -199,8 +199,8 @@ VERSION >= v"1.1" && @non_differentiable fieldtypes(T) @non_differentiable ignorestatus(::Cmd) @non_differentiable in(::Any, ::Any) -VERSION >= v"1.6" && @non_differentiable insorted(::Any, ::AbstractVector) -VERSION >= v"1.6" && @non_differentiable insorted(::Any, ::AbstractRange) +@non_differentiable insorted(::Any, ::AbstractVector) +@non_differentiable insorted(::Any, ::AbstractRange) @non_differentiable include_dependency(::AbstractString) @non_differentiable isa(::Any, ::Any) @non_differentiable isabspath(::AbstractString) @@ -219,7 +219,7 @@ VERSION >= v"1.6" && @non_differentiable insorted(::Any, ::AbstractRange) @non_differentiable isdigit(::AbstractChar) @non_differentiable isdir(::Any...) @non_differentiable isdirpath(::AbstractString) -VERSION >= v"1.5" && @non_differentiable isdisjoint(::Any, ::Any) +@non_differentiable isdisjoint(::Any, ::Any) @non_differentiable isdispatchtuple(::Any) @non_differentiable isempty(::Any) @non_differentiable isequal(::Any) @@ -238,9 +238,9 @@ VERSION >= v"1.5" && @non_differentiable isdisjoint(::Any, ::Any) @non_differentiable ismarked(::IO) @non_differentiable ismissing(::Any) @non_differentiable ismount(::Any...) -VERSION >= v"1.5" && @non_differentiable ismutable(::Any) +@non_differentiable ismutable(::Any) @non_differentiable isnan(::Any) -VERSION >= v"1.1" && @non_differentiable isnothing(::Any) +@non_differentiable isnothing(::Any) @non_differentiable isnumeric(::AbstractChar) @non_differentiable isodd(::Any) @non_differentiable isone(::Any) @@ -268,7 +268,7 @@ VERSION >= v"1.1" && @non_differentiable isnothing(::Any) @non_differentiable issubnormal(::Any) @non_differentiable issubset(::Any, ::Any) @non_differentiable istaskdone(::Task) -VERSION >= v"1.3" && @non_differentiable istaskfailed(::Task) +@non_differentiable istaskfailed(::Task) @non_differentiable istaskstarted(::Task) @non_differentiable istextmime(::AbstractString) @non_differentiable isuppercase(::AbstractChar) @@ -325,7 +325,7 @@ VERSION >= v"1.3" && @non_differentiable istaskfailed(::Task) @non_differentiable occursin(::Union{AbstractChar, AbstractString}, ::AbstractString) @non_differentiable one(::Any) @non_differentiable ones(::Any...) -VERSION >= v"1.4" && @non_differentiable only(::Char) +@non_differentiable only(::Char) @non_differentiable open(::Any) @non_differentiable partialsortperm(::AbstractVector, ::Union{Integer, OrdinalRange}) @@ -398,7 +398,7 @@ end @non_differentiable splitdrive(::AbstractString) @non_differentiable splitext(::AbstractString) @non_differentiable startswith(::AbstractString, ::AbstractString) -VERSION >= v"1.1" && @non_differentiable splitpath(::AbstractString) +@non_differentiable splitpath(::AbstractString) @non_differentiable startswith(::AbstractString, ::Regex) @non_differentiable stat(::AbstractString) @non_differentiable stat(::Base.Filesystem.File) @@ -451,9 +451,9 @@ VERSION >= v"1.1" && @non_differentiable splitpath(::AbstractString) @non_differentiable Base.time_ns() @non_differentiable Base.typename(::Any) @non_differentiable Base.depwarn(::Any...) -VERSION >= v"1.6" && @non_differentiable Base.cumulative_compile_time_ns_before() -VERSION >= v"1.6" && @non_differentiable Base.cumulative_compile_time_ns_after() -VERSION >= v"1.5" && @non_differentiable Base.time_print(::Any...) +@non_differentiable Base.cumulative_compile_time_ns_before() +@non_differentiable Base.cumulative_compile_time_ns_after() +@non_differentiable Base.time_print(::Any...) @non_differentiable Broadcast.combine_styles(::Any...) @non_differentiable Broadcast.result_style(::Any) @@ -470,13 +470,13 @@ VERSION >= v"1.5" && @non_differentiable Base.time_print(::Any...) @non_differentiable Sys.cpu_summary(::IO) @non_differentiable Sys.isapple(::Symbol) @non_differentiable Sys.isbsd(::Symbol) -VERSION >= v"1.1" && @non_differentiable Sys.isdragonfly(::Symbol) +@non_differentiable Sys.isdragonfly(::Symbol) @non_differentiable Sys.isexecutable(::AbstractString) -VERSION >= v"1.1" && @non_differentiable Sys.isfreebsd(::Symbol) -VERSION >= v"1.2" && @non_differentiable Sys.isjsvm(::Symbol) +@non_differentiable Sys.isfreebsd(::Symbol) +@non_differentiable Sys.isjsvm(::Symbol) @non_differentiable Sys.islinux(::Symbol) -VERSION >= v"1.1" && @non_differentiable Sys.isnetbsd(::Symbol) -VERSION >= v"1.1" && @non_differentiable Sys.isopenbsd(::Symbol) +@non_differentiable Sys.isnetbsd(::Symbol) +@non_differentiable Sys.isopenbsd(::Symbol) @non_differentiable Sys.isunix(::Symbol) @non_differentiable Sys.iswindows(::Symbol) @non_differentiable Sys.which(::AbstractString) diff --git a/src/rulesets/LinearAlgebra/factorization.jl b/src/rulesets/LinearAlgebra/factorization.jl index d7530a16e..41bc24b8c 100644 --- a/src/rulesets/LinearAlgebra/factorization.jl +++ b/src/rulesets/LinearAlgebra/factorization.jl @@ -286,7 +286,7 @@ end function frule((_, ΔA), ::typeof(eigen!), A::StridedMatrix{T}; kwargs...) where {T<:BlasFloat} ΔA isa AbstractZero && return (eigen!(A; kwargs...), ΔA) if ishermitian(A) - sortby = get(kwargs, :sortby, VERSION ≥ v"1.2.0" ? LinearAlgebra.eigsortby : nothing) + sortby = get(kwargs, :sortby, LinearAlgebra.eigsortby) return if sortby === nothing frule((ZeroTangent(), Hermitian(ΔA)), eigen!, Hermitian(A)) else @@ -398,7 +398,7 @@ function frule((_, ΔA), ::typeof(eigvals!), A::StridedMatrix{T}; kwargs...) whe ΔA isa AbstractZero && return eigvals!(A; kwargs...), ΔA if ishermitian(A) λ, ∂λ = frule((ZeroTangent(), Hermitian(ΔA)), eigvals!, Hermitian(A)) - sortby = get(kwargs, :sortby, VERSION ≥ v"1.2.0" ? LinearAlgebra.eigsortby : nothing) + sortby = get(kwargs, :sortby, LinearAlgebra.eigsortby) _sorteig!_fwd(∂λ, λ, sortby) else F = eigen!(A; kwargs...) diff --git a/src/rulesets/LinearAlgebra/structured.jl b/src/rulesets/LinearAlgebra/structured.jl index 513f8e3fb..1a5a4bcd0 100644 --- a/src/rulesets/LinearAlgebra/structured.jl +++ b/src/rulesets/LinearAlgebra/structured.jl @@ -71,21 +71,21 @@ function rrule(::typeof(diag), A::AbstractMatrix) end return diag(A), diag_pullback end -if VERSION ≥ v"1.3" - function rrule(::typeof(diag), A::AbstractMatrix, k::Integer) - function diag_pullback(ȳ) - return (NoTangent(), diagm(size(A)..., k => ȳ), NoTangent()) - end - return diag(A, k), diag_pullback + +function rrule(::typeof(diag), A::AbstractMatrix, k::Integer) + function diag_pullback(ȳ) + return (NoTangent(), diagm(size(A)..., k => ȳ), NoTangent()) end + return diag(A, k), diag_pullback +end - function rrule(::typeof(diagm), m::Integer, n::Integer, kv::Pair{<:Integer,<:AbstractVector}...) - function diagm_pullback(ȳ) - return (NoTangent(), NoTangent(), NoTangent(), _diagm_back.(kv, Ref(ȳ))...) - end - return diagm(m, n, kv...), diagm_pullback +function rrule(::typeof(diagm), m::Integer, n::Integer, kv::Pair{<:Integer,<:AbstractVector}...) + function diagm_pullback(ȳ) + return (NoTangent(), NoTangent(), NoTangent(), _diagm_back.(kv, Ref(ȳ))...) end + return diagm(m, n, kv...), diagm_pullback end + function rrule(::typeof(diagm), kv::Pair{<:Integer,<:AbstractVector}...) function diagm_pullback(ȳ) return (NoTangent(), _diagm_back.(kv, Ref(ȳ))...) diff --git a/src/rulesets/Random/random.jl b/src/rulesets/Random/random.jl index 6ea2cc648..d1eeab514 100644 --- a/src/rulesets/Random/random.jl +++ b/src/rulesets/Random/random.jl @@ -37,7 +37,5 @@ end @non_differentiable copy(::AbstractRNG) @non_differentiable copy!(::AbstractRNG, ::AbstractRNG) -@static if VERSION > v"1.3" - @non_differentiable Random.default_rng() - @non_differentiable Random.default_rng(::Int) -end +@non_differentiable Random.default_rng() +@non_differentiable Random.default_rng(::Int) diff --git a/test/rulesets/Base/array.jl b/test/rulesets/Base/array.jl index 12ece43db..7a82bd350 100644 --- a/test/rulesets/Base/array.jl +++ b/test/rulesets/Base/array.jl @@ -29,7 +29,7 @@ end @testset "inhomogeneous type" begin test_rrule( Base.vect, 5.0, 3f0; - atol=1e-6, rtol=1e-6, check_inferred=VERSION>=v"1.6", + atol=1e-6, rtol=1e-6, ) # tolerance due to Float32. test_rrule(Base.vect, 5.0, randn(3, 3); check_inferred=false) test_rrule(Base.vect, (5.0, 4.0), (y=randn(3),); check_inferred=false) @@ -50,7 +50,7 @@ end # Note BTW that permutedims(Diagonal(rand(5))) does not use the rule at all @test invperm((3, 1, 2)) != (3, 1, 2) - test_rrule(permutedims, rand(3, 4, 5), (3, 1, 2); check_inferred=VERSION>=v"1.1") + test_rrule(permutedims, rand(3, 4, 5), (3, 1, 2)) @test_skip test_rrule(PermutedDimsArray, rand(3, 4, 5), (3, 1, 2)) x = rand(2, 3, 4) @@ -65,29 +65,26 @@ end test_rrule(repeat, rand(4, 5); fkwargs = (outer=(1,2),)) test_rrule(repeat, rand(4, 5); fkwargs = (inner=(1,2), outer=(1,3))) - test_rrule(repeat, rand(4, ), 2; check_inferred=VERSION>=v"1.6") - test_rrule(repeat, rand(4, 5), 2; check_inferred=VERSION>=v"1.6") - test_rrule(repeat, rand(4, 5), 2, 3; check_inferred=VERSION>=v"1.6") + test_rrule(repeat, rand(4, ), 2) + test_rrule(repeat, rand(4, 5), 2) + test_rrule(repeat, rand(4, 5), 2, 3) test_rrule(repeat, rand(1,2,3), 2,3,4; check_inferred=VERSION>v"1.6") test_rrule(repeat, rand(0,2,3), 2,0,4; check_inferred=VERSION>v"1.6") test_rrule(repeat, rand(1,1,1,1), 2,3,4,5; check_inferred=VERSION>v"1.6") + # These need Julia 1.6 + test_rrule(repeat, rand(4, 5); fkwargs = (inner=(2,4), outer=(1,1,1,3))) + test_rrule(repeat, rand(1,2,3), 2,3) + test_rrule(repeat, rand(1,2,3), 2,3,4,2) + test_rrule(repeat, fill(1.0), 2) + test_rrule(repeat, fill(1.0), 2, 3) - if VERSION>=v"1.6" - # These are cases where repeat itself fails in earlier versions - test_rrule(repeat, rand(4, 5); fkwargs = (inner=(2,4), outer=(1,1,1,3))) - test_rrule(repeat, rand(1,2,3), 2,3) - test_rrule(repeat, rand(1,2,3), 2,3,4,2) - test_rrule(repeat, fill(1.0), 2) - test_rrule(repeat, fill(1.0), 2, 3) + # These fail for other v1.0 related issues (add!!) + # v"1.0": fill(1.0) + fill(1.0) != fill(2.0) + # v"1.6: fill(1.0) + fill(1.0) == fill(2.0) # Expected + test_rrule(repeat, fill(1.0); fkwargs = (inner=2,)) + test_rrule(repeat, fill(1.0); fkwargs = (inner=2, outer=3,)) - # These fail for other v1.0 related issues (add!!) - # v"1.0": fill(1.0) + fill(1.0) != fill(2.0) - # v"1.6: fill(1.0) + fill(1.0) == fill(2.0) # Expected - test_rrule(repeat, fill(1.0); fkwargs = (inner=2,)) - test_rrule(repeat, fill(1.0); fkwargs = (inner=2, outer=3,)) - - end @test rrule(repeat, [1,2,3], 4)[2](ones(12))[2] == [4,4,4] @test rrule(repeat, [1,2,3], outer=4)[2](ones(12))[2] == [4,4,4] @@ -95,12 +92,12 @@ end end @testset "hcat" begin - test_rrule(hcat, randn(3, 2), randn(3), randn(3, 3); check_inferred=VERSION>v"1.1") - test_rrule(hcat, rand(), rand(1,2), rand(1,2,1); check_inferred=VERSION>v"1.1") - test_rrule(hcat, rand(3,1,1,2), rand(3,3,1,2); check_inferred=VERSION>v"1.1") + test_rrule(hcat, randn(3, 2), randn(3), randn(3, 3)) + test_rrule(hcat, rand(), rand(1,2), rand(1,2,1)) + test_rrule(hcat, rand(3,1,1,2), rand(3,3,1,2)) # mix types - test_rrule(hcat, rand(2, 2), rand(2, 2)'; check_inferred=VERSION>v"1.1") + test_rrule(hcat, rand(2, 2), rand(2, 2)') end @testset "reduce hcat" begin @@ -120,17 +117,17 @@ end # mix types mats = [randn(2, 2), rand(2, 2)'] - test_rrule(reduce, hcat, mats; check_inferred=VERSION>v"1.1") + test_rrule(reduce, hcat, mats) end @testset "vcat" begin - test_rrule(vcat, randn(2, 4), randn(1, 4), randn(3, 4); check_inferred=VERSION>v"1.1") - test_rrule(vcat, rand(), rand(); check_inferred=VERSION>v"1.1") - test_rrule(vcat, rand(), rand(3), rand(3,1,1); check_inferred=VERSION>v"1.1") - test_rrule(vcat, rand(3,1,2), rand(4,1,2); check_inferred=VERSION>v"1.1") + test_rrule(vcat, randn(2, 4), randn(1, 4), randn(3, 4)) + test_rrule(vcat, rand(), rand()) + test_rrule(vcat, rand(), rand(3), rand(3,1,1)) + test_rrule(vcat, rand(3,1,2), rand(4,1,2)) # mix types - test_rrule(vcat, rand(2, 2), rand(2, 2)'; check_inferred=VERSION>v"1.1") + test_rrule(vcat, rand(2, 2), rand(2, 2)') end @testset "reduce vcat" begin @@ -145,22 +142,22 @@ end end @testset "cat" begin - test_rrule(cat, rand(2, 4), rand(1, 4); fkwargs=(dims=1,), check_inferred=VERSION>v"1.1") - test_rrule(cat, rand(2, 4), rand(2); fkwargs=(dims=Val(2),), check_inferred=VERSION>v"1.1") - test_rrule(cat, rand(), rand(2, 3); fkwargs=(dims=[1,2],), check_inferred=VERSION>v"1.1") + test_rrule(cat, rand(2, 4), rand(1, 4); fkwargs=(dims=1,)) + test_rrule(cat, rand(2, 4), rand(2); fkwargs=(dims=Val(2),)) + test_rrule(cat, rand(), rand(2, 3); fkwargs=(dims=[1,2],)) test_rrule(cat, rand(1), rand(3, 2, 1); fkwargs=(dims=(1,2),), check_inferred=false) # infers Tuple{Zero, Vector{Float64}, Any} - test_rrule(cat, rand(2, 2), rand(2, 2)'; fkwargs=(dims=1,), check_inferred=VERSION>v"1.1") + test_rrule(cat, rand(2, 2), rand(2, 2)'; fkwargs=(dims=1,)) end @testset "hvcat" begin - test_rrule(hvcat, 2, rand(ComplexF64, 6)...; check_inferred=VERSION>v"1.1") - test_rrule(hvcat, (2, 1), rand(), rand(1,1), rand(2,2); check_inferred=VERSION>v"1.1") - test_rrule(hvcat, 1, rand(3)' ⊢ rand(1,3), transpose(rand(3)) ⊢ rand(1,3); check_inferred=VERSION>v"1.1") - test_rrule(hvcat, 1, rand(0,3), rand(2,3), rand(1,3,1); check_inferred=VERSION>v"1.1") + test_rrule(hvcat, 2, rand(ComplexF64, 6)...) + test_rrule(hvcat, (2, 1), rand(), rand(1,1), rand(2,2)) + test_rrule(hvcat, 1, rand(3)' ⊢ rand(1,3), transpose(rand(3)) ⊢ rand(1,3)) + test_rrule(hvcat, 1, rand(0,3), rand(2,3), rand(1,3,1)) # mix types (adjoint and transpose) - test_rrule(hvcat, 1, rand(3)', transpose(rand(3)) ⊢ rand(1,3); check_inferred=VERSION>v"1.1") + test_rrule(hvcat, 1, rand(3)', transpose(rand(3)) ⊢ rand(1,3)) end @testset "reverse" begin @@ -175,10 +172,8 @@ end test_frule(reverse, rand(5), fkwargs=(dims=1,)) test_frule(reverse, rand(3,4), fkwargs=(dims=2,)) - if VERSION >= v"1.6" - test_frule(reverse, rand(3,4)) - test_frule(reverse, rand(3,4,5), fkwargs=(dims=(1,3),)) - end + test_frule(reverse, rand(3,4)) + test_frule(reverse, rand(3,4,5), fkwargs=(dims=(1,3),)) # Reverse test_rrule(reverse, rand(5)) @@ -186,16 +181,14 @@ end test_rrule(reverse, rand(5), fkwargs=(dims=1,)) test_rrule(reverse, rand(3,4), fkwargs=(dims=2,)) - if VERSION >= v"1.6" - test_rrule(reverse, rand(3,4)) - test_rrule(reverse, rand(3,4,5), fkwargs=(dims=(1,3),)) - - # Structured - y, pb = rrule(reverse, Diagonal([1,2,3])) - # We only preserve structure in this case if given structured tangent (no ProjectTo) - @test unthunk(pb(Diagonal([1.1, 2.1, 3.1]))[2]) isa Diagonal - @test unthunk(pb(rand(3, 3))[2]) isa AbstractArray - end + test_rrule(reverse, rand(3,4)) + test_rrule(reverse, rand(3,4,5), fkwargs=(dims=(1,3),)) + + # Structured + y, pb = rrule(reverse, Diagonal([1,2,3])) + # We only preserve structure in this case if given structured tangent (no ProjectTo) + @test unthunk(pb(Diagonal([1.1, 2.1, 3.1]))[2]) isa Diagonal + @test unthunk(pb(rand(3, 3))[2]) isa AbstractArray end end diff --git a/test/rulesets/Base/base.jl b/test/rulesets/Base/base.jl index bb2324b9d..77dff1827 100644 --- a/test/rulesets/Base/base.jl +++ b/test/rulesets/Base/base.jl @@ -58,13 +58,11 @@ test_scalar(sinc, x) end - if VERSION ≥ v"1.6" - @testset "sincospi" for T in (Float64, ComplexF64) - Δz = Tangent{Tuple{T,T}}(randn(T), randn(T)) + @testset "sincospi" for T in (Float64, ComplexF64) + Δz = Tangent{Tuple{T,T}}(randn(T), randn(T)) - test_frule(sincospi, randn(T)) - test_rrule(sincospi, randn(T); output_tangent=Δz) - end + test_frule(sincospi, randn(T)) + test_rrule(sincospi, randn(T); output_tangent=Δz) end end # Trig @@ -108,10 +106,9 @@ # This promotion is only for FiniteDifferences, the rules allow mixtures: x, y = Base.promote(x, y) - # Inference fails on 1.0, passes on 1.6 - test_rrule(*, x, y, x+y; check_inferred=VERSION>v"1.5") - test_rrule(*, x, y, 17x, 23y; check_inferred=VERSION>v"1.5") - test_rrule(*, x, y, 7x, 3y, x+y+pi; check_inferred=VERSION>v"1.5") + test_rrule(*, x, y, x+y) + test_rrule(*, x, y, 17x, 23y) + test_rrule(*, x, y, 7x, 3y, x+y+pi) end end diff --git a/test/rulesets/Base/evalpoly.jl b/test/rulesets/Base/evalpoly.jl index 34e0a7f46..25f0202ed 100644 --- a/test/rulesets/Base/evalpoly.jl +++ b/test/rulesets/Base/evalpoly.jl @@ -1,4 +1,5 @@ -VERSION ≥ v"1.4" && @testset "evalpoly" begin + +@testset "evalpoly" begin # test fallbacks for when code generation fails @testset "fallbacks for $T" for T in (Float64, ComplexF64) x, p = randn(T), Tuple(randn(T, 10)) diff --git a/test/rulesets/Base/mapreduce.jl b/test/rulesets/Base/mapreduce.jl index 36f7612ef..c9a740206 100644 --- a/test/rulesets/Base/mapreduce.jl +++ b/test/rulesets/Base/mapreduce.jl @@ -222,7 +222,7 @@ const CFG = ChainRulesTestUtils.ADviaRuleConfig() test_rrule(foldl, +, rand(ComplexF64,7); fkwargs=(; init=rand(ComplexF64))) test_rrule(foldl, max, rand(3); fkwargs=(; init=999)) end - VERSION >= v"1.5" && @testset "foldl(f, ::Tuple)" begin + @testset "foldl(f, ::Tuple)" begin y1, b1 = rrule(CFG, foldl, *, (1,2,3); init=1) @test y1 == 6 b1(7) == (NoTangent(), NoTangent(), Tangent{NTuple{3,Int}}(42, 21, 14)) @@ -295,11 +295,9 @@ end @test y1 == [1, 2, 6, 24] @test b1([1, 1, 1, 1]) == (NoTangent(), NoTangent(), [33, 16, 10, 6]) - if VERSION >= v"1.5" - y2, b2 = rrule(CFG, accumulate, /, [1 2; 3 4]) - @test y2 ≈ accumulate(/, [1 2; 3 4]) - @test b2(ones(2, 2))[3] ≈ [1.5416666 -0.104166664; -0.18055555 -0.010416667] atol=1e-6 - end + y2, b2 = rrule(CFG, accumulate, /, [1 2; 3 4]) + @test y2 ≈ accumulate(/, [1 2; 3 4]) + @test b2(ones(2, 2))[3] ≈ [1.5416666 -0.104166664; -0.18055555 -0.010416667] atol=1e-6 # Test execution order c3 = Counter() @@ -328,12 +326,10 @@ end # Finite differencing test_rrule(accumulate, *, randn(5); fkwargs=(; init=rand())) - if VERSION >= v"1.5" - test_rrule(accumulate, /, 1 .+ rand(3, 4)) - test_rrule(accumulate, ^, 1 .+ rand(2, 3); fkwargs=(; init=rand())) - end + test_rrule(accumulate, /, 1 .+ rand(3, 4)) + test_rrule(accumulate, ^, 1 .+ rand(2, 3); fkwargs=(; init=rand())) end - VERSION >= v"1.5" && @testset "accumulate(f, ::Tuple)" begin + @testset "accumulate(f, ::Tuple)" begin # Simple y1, b1 = rrule(CFG, accumulate, *, (1, 2, 3, 4); init=1) @test y1 == (1, 2, 6, 24) diff --git a/test/rulesets/LinearAlgebra/dense.jl b/test/rulesets/LinearAlgebra/dense.jl index c7413dbb3..601000d7e 100644 --- a/test/rulesets/LinearAlgebra/dense.jl +++ b/test/rulesets/LinearAlgebra/dense.jl @@ -64,8 +64,7 @@ @testset "$F{Vector{$T}}" for T in (Float64, ComplexF64), F in (Transpose, Adjoint) test_frule(pinv, F(randn(T, 3))) - check_inferred = VERSION ≥ v"1.5" - test_rrule(pinv, F(randn(T, 3)); check_inferred=check_inferred) + test_rrule(pinv, F(randn(T, 3))) # Check types. # TODO: Do we need this still? diff --git a/test/rulesets/LinearAlgebra/factorization.jl b/test/rulesets/LinearAlgebra/factorization.jl index 982141568..97b0cfa0f 100644 --- a/test/rulesets/LinearAlgebra/factorization.jl +++ b/test/rulesets/LinearAlgebra/factorization.jl @@ -383,8 +383,7 @@ end # also we might be missing some overloads for different tangent-types in the rules @testset "cholesky" begin @testset "Real" begin - check_inferred = VERSION ≥ v"1.5" - test_rrule(cholesky, 0.8; check_inferred=check_inferred) + test_rrule(cholesky, 0.8) end @testset "Diagonal{<:Real}" begin D = Diagonal(rand(5) .+ 0.1) diff --git a/test/rulesets/LinearAlgebra/norm.jl b/test/rulesets/LinearAlgebra/norm.jl index 07e80c2cf..3551583fa 100644 --- a/test/rulesets/LinearAlgebra/norm.jl +++ b/test/rulesets/LinearAlgebra/norm.jl @@ -79,7 +79,7 @@ @testset "rrule" begin test_rrule(norm, x) x isa Matrix && @testset "$MT" for MT in (Diagonal, UpperTriangular, LowerTriangular) - test_rrule(norm, MT(x); check_inferred=VERSION>=v"1.5") + test_rrule(norm, MT(x)) end ȳ = rand_tangent(norm(x)) @@ -123,7 +123,7 @@ test_rrule(fnorm, x, p; kwargs...) x isa Matrix && @testset "$MT" for MT in (Diagonal, UpperTriangular, LowerTriangular) - test_rrule(fnorm, MT(x), p; kwargs..., check_inferred=VERSION>=v"1.5") + test_rrule(fnorm, MT(x), p; kwargs...) end ȳ = rand_tangent(fnorm(x, p)) diff --git a/test/rulesets/LinearAlgebra/structured.jl b/test/rulesets/LinearAlgebra/structured.jl index 7a02223d8..653b5852c 100644 --- a/test/rulesets/LinearAlgebra/structured.jl +++ b/test/rulesets/LinearAlgebra/structured.jl @@ -40,8 +40,9 @@ test_rrule(diag, Diagonal(randn(N))) test_rrule(diag, randn(N, N) ⊢ Diagonal(randn(N))) test_rrule(diag, Diagonal(randn(N)) ⊢ Diagonal(randn(N))) - VERSION ≥ v"1.3" && @testset "k=$k" for k in (-1, 0, 2) + for k in (-1, 0, 2) test_rrule(diag, randn(N, N), k) + @test_skip test_rrule(diag, Diagonal(randn(N)), k) end end @testset "diagm" begin @@ -65,7 +66,7 @@ @test ∂px.second ≈ ∂x_fd end end - VERSION ≥ v"1.3" && @testset "with size" begin + @testset "with size" begin M, N = 7, 9 a, ā = randn(M), randn(M) b, b̄ = randn(M), randn(M) diff --git a/test/rulesets/LinearAlgebra/symmetric.jl b/test/rulesets/LinearAlgebra/symmetric.jl index 35e20fa96..648b446d3 100644 --- a/test/rulesets/LinearAlgebra/symmetric.jl +++ b/test/rulesets/LinearAlgebra/symmetric.jl @@ -233,7 +233,7 @@ end n = 10 - VERSION ≥ v"1.3.0" && @testset "rrule for svd(::$SymHerm{$T}) uplo=$uplo" for SymHerm in (Symmetric, Hermitian), + @testset "rrule for svd(::$SymHerm{$T}) uplo=$uplo" for SymHerm in (Symmetric, Hermitian), T in (SymHerm === Symmetric ? (Float64,) : (Float64, ComplexF64)), uplo in (:L, :U) @@ -262,7 +262,7 @@ return (; (s => getproperty(F_, s) for s in nzprops)...) end - VERSION ≥ v"1.6.0-DEV.1686" && @maybe_inferred back(∂F) + @maybe_inferred back(∂F) ∂self, ∂symA = back(∂F) @test ∂self === NoTangent() @test ∂symA isa typeof(symA) diff --git a/test/runtests.jl b/test/runtests.jl index 61548644b..52656f4cb 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -7,7 +7,6 @@ using ChainRules using ChainRulesCore using ChainRulesTestUtils using ChainRulesTestUtils: rand_tangent, _fdm -using Compat: Compat, hasproperty, only, cispi, eachcol using FiniteDifferences using LinearAlgebra using LinearAlgebra.BLAS @@ -18,7 +17,7 @@ using Statistics using Test using JuliaInterpreter -union!(JuliaInterpreter.compiled_modules, Any[Base, Base.Broadcast, Compat, LinearAlgebra, Random, StaticArrays, Statistics]) +union!(JuliaInterpreter.compiled_modules, Any[Base, Base.Broadcast, LinearAlgebra, Random, StaticArrays, Statistics]) Random.seed!(1) # Set seed that all testsets should reset to.