Skip to content
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
name = "ChainRules"
uuid = "082447d4-558c-5d27-93f4-14fc19e9eca2"
version = "0.7.61"
version = "0.7.62"

[deps]
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
Expand Down
37 changes: 37 additions & 0 deletions src/rulesets/Base/array.jl
Original file line number Diff line number Diff line change
Expand Up @@ -103,3 +103,40 @@ function rrule(::typeof(fill), value::Any, dims::Int...)
end
return fill(value, dims), fill_pullback
end

#####
##### `repeat`
#####

function rrule(::typeof(repeat), x::AbstractVector, m::Integer)
function repeat_pullback(Ȳ)
return (NO_FIELDS, dropdims(sum(reshape(Ȳ, length(x), :); dims=2); dims=2), DoesNotExist())
Comment on lines +112 to +113
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

should this close over length(x) rather than x (to reduce memory requirements)?

Suggested change
function repeat_pullback(Ȳ)
return (NO_FIELDS, dropdims(sum(reshape(Ȳ, length(x), :); dims=2); dims=2), DoesNotExist())
length_x = length(x)
function repeat_pullback(Ȳ)
return (NO_FIELDS, dropdims(sum(reshape(Ȳ, length_x, :); dims=2); dims=2), DoesNotExist())

And likewise with the size calls in the rules below.

This is a genuine question, i'm not sure what best practice here is (cc @oxinabox).

I believe this would be manually doing one thing we want an OpaqueClosure to be able to do automatically.
Also maybe this should be mentioned on the "writing good rules docs"?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah, I am cautious about doing that since it makes code uglier.
So I wouldn't block a PR over it (it can always be optimized later)
But it does save memory.
We should discuss this with pros and cons in "Writing God Rules"

end
return repeat(x, m), repeat_pullback
end

function rrule(::typeof(repeat), x::AbstractVecOrMat, m::Integer, n::Integer=1)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Since things like repeat(rand(Int8,2,2,2),1,1,2,2) are allowed, it might be nice to treat the general case? Untested, but my attempt is:

function rrule(::typeof(repeat), x::AbstractArray, scales::Integer...)
    function repeat_pullback_1(dy)
        size2ndims = ntuple(d -> isodd(d) ? size(x,1+d÷2) : get(scales,d÷2,1), 2*ndims(dy))
        sumdy = sum(reshape(dy, size2ndims); dims = ntuple(d -> 2d, ndims(dy)))
        return (NO_FIELDS, reshape(sumdy, size(x)), map(_->DoesNotExist(), scales)...)
    end
    return repeat(x, scales...), repeat_pullback_1
end

What this won't handle is repeat(1:2, 3,2,1,0) but perhaps nobody will write that!

function repeat_pullback(Ȳ)
Ȳ′ = reshape(Ȳ, size(x, 1), m, size(x, 2), n)
return (NO_FIELDS, reshape(sum(Ȳ′; dims=(2,4)), size(x)), DoesNotExist(), DoesNotExist())
end
return repeat(x, m, n), repeat_pullback
end

function rrule(::typeof(repeat), xs::AbstractArray; inner=ntuple(_->1, ndims(xs)), outer=ntuple(_->1, ndims(xs)))
function repeat_pullback(Ȳ)
Ȳ′ = zero(xs)
S = size(xs)
for (dest_idx, val) ∈ pairs(IndexCartesian(), Ȳ)
src_idx = [mod1(div(dest_idx[dim] - 1, inner[dim]) + 1, S[dim]) for dim ∈ 1:length(S)]
Ȳ′[src_idx...] += val
end
return (NO_FIELDS, Ȳ′)
end
return repeat(xs; inner=inner, outer=outer), repeat_pullback
end

function rrule(::typeof(repeat), x::AbstractArray{<:Real, 0}, m::Integer)
repeat_pullback(Ȳ) = (NO_FIELDS, similar(x, eltype(Ȳ)) .= sum(Ȳ), DoesNotExist())
return repeat(x, m), repeat_pullback
end
10 changes: 10 additions & 0 deletions test/rulesets/Base/array.jl
Original file line number Diff line number Diff line change
Expand Up @@ -35,3 +35,13 @@ end
test_rrule(fill, 44.0, 4; check_inferred=false)
test_rrule(fill, 2.0, (3, 3, 3) ⊢ DoesNotExist())
end

@testset "repeat" begin
test_rrule(repeat, randn(5), 3)
test_rrule(repeat, randn(5), 3, 3)
test_rrule(repeat, randn(3, 3), 2)
test_rrule(repeat, randn(5, 5), 2,5)
test_rrule(repeat, randn(5, 4, 3); fkwargs=(inner=(2, 2, 1), outer=(1, 1, 3)))
test_rrule(repeat, fill(4.0), 3)

end