Skip to content

Commit 43e2113

Browse files
authored
Merge pull request #214 from JuliaDiff/ox/scalarmultiple
Make scalar_rule's frule return a Composite not a tuple
2 parents 7582999 + 52ca7d6 commit 43e2113

File tree

5 files changed

+38
-21
lines changed

5 files changed

+38
-21
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
name = "ChainRulesCore"
22
uuid = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
3-
version = "0.9.12"
3+
version = "0.9.13"
44

55
[deps]
66
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"

src/rule_definition_tools.jl

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -148,10 +148,12 @@ function scalar_frule_expr(f, call, setup_stmts, inputs, partials)
148148
propagation_expr(Δs, ∂s)
149149
end
150150
if n_outputs > 1
151-
# For forward-mode we only return a tuple if output actually a tuple.
152-
pushforward_returns = Expr(:tuple, pushforward_returns...)
151+
# For forward-mode we return a Composite if output actually a tuple.
152+
pushforward_returns = Expr(
153+
:call, :(ChainRulesCore.Composite{typeof($(esc()))}), pushforward_returns...
154+
)
153155
else
154-
pushforward_returns = pushforward_returns[1]
156+
pushforward_returns = first(pushforward_returns)
155157
end
156158

157159
return quote

src/rules.jl

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -29,18 +29,30 @@ julia> Δsinx == cos(x)
2929
true
3030
```
3131
32-
unary input, binary output scalar function:
32+
Unary input, binary output scalar function:
3333
3434
```jldoctest frule
3535
julia> sincosx, Δsincosx = frule((dself, 1), sincos, x);
3636
3737
julia> sincosx == sincos(x)
3838
true
3939
40-
julia> Δsincosx == (cos(x), -sin(x))
40+
julia> Δsincosx[1] == cos(x)
41+
true
42+
43+
julia> Δsincosx[2] == -sin(x)
4144
true
4245
```
4346
47+
Note that techically speaking julia does not have multiple output functions, just functions
48+
that return a single output that is iterable, like a `Tuple`.
49+
So this is actually a [`Composite`](@ref):
50+
```jldoctest frule
51+
julia> Δsincosx
52+
Composite{Tuple{Float64,Float64}}(0.6795498147167869, -0.7336293678134624)
53+
```.
54+
55+
4456
See also: [`rrule`](@ref), [`@scalar_rule`](@ref)
4557
"""
4658
frule(::Any, ::Vararg{Any}; kwargs...) = nothing

test/rule_definition_tools.jl

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -110,4 +110,21 @@ end
110110
)
111111
end
112112
end
113+
114+
@testset "@scalar_rule" begin
115+
@testset "@scalar_rule with multiple output" begin
116+
simo(x) = (x, 2x)
117+
@scalar_rule(simo(x), 1f0, 2f0)
118+
119+
y, simo_pb = rrule(simo, π)
120+
121+
@test simo_pb((10f0, 20f0)) == (NO_FIELDS, 50f0)
122+
123+
y, ẏ = frule((NO_FIELDS, 50f0), simo, π)
124+
@test y == (π, 2π)
125+
@test== Composite{typeof(y)}(50f0, 100f0)
126+
# make sure type is exactly as expected:
127+
@testisa Composite{Tuple{Irrational{}, Float64}, Tuple{Float32, Float32}}
128+
end
129+
end
113130
end

test/rules.jl

Lines changed: 1 addition & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -136,18 +136,4 @@ _second(t) = Base.tuple_type_head(Base.tuple_type_tail(t))
136136
@test ∂self == NO_FIELDS
137137
@test ∂x j′vp(central_fdm(5, 1), complex_times, Ω̄, x)[1]
138138
end
139-
end
140-
141-
142-
simo(x) = (x, 2x)
143-
@scalar_rule(simo(x), 1, 2)
144-
145-
@testset "@scalar_rule with multiple inputs" begin
146-
y, simo_pb = rrule(simo, π)
147-
148-
@test simo_pb((10, 20)) == (NO_FIELDS, 50)
149-
150-
y, ẏ = frule((NO_FIELDS, 50), simo, π)
151-
@test y == (π, 2π)
152-
@test== (50, 100)
153-
end
139+
end

0 commit comments

Comments
 (0)