Commit 4fe2c9d

Merge pull request #1548 from FluxML/mcabbott-patch-1
Improve the limitations page
2 parents bb3730e + 55b3947 commit 4fe2c9d

2 files changed: 21 additions and 28 deletions

Project.toml

Lines changed: 1 addition & 1 deletion
@@ -1,6 +1,6 @@
 name = "Zygote"
 uuid = "e88e6eb3-aa80-5325-afca-941959d7151f"
-version = "0.7.0"
+version = "0.7.1"

 [deps]
 AbstractFFTs = "621f4979-c628-5d54-868e-fcf4e3e8185c"

docs/src/limitations.md

Lines changed: 20 additions & 27 deletions
@@ -20,7 +20,6 @@ Let's explore this with a more concrete example. Here we define a simple mutatin
 ```julia
 function f!(x)
   x .= 2 .* x
-
   return x
 end
 ```
@@ -42,43 +41,36 @@ Stacktrace:
 ...
 ```
 We got an error message and a long stacktrace. The error informs us that our code performs array mutation by calling `copyto!` (we might not have directly called this function, but it is being invoked somewhere in the call stack). We see that our code includes `x .= ...` which is given as an example of array mutation. Other examples of mutating operations include:
-- setting values (`x .= ...`)
-- appending/popping values (`push!(x, v)` / `pop!(x)`)
-- calling mutating functions (`mul!(C, A, B)`)
+- setting values (`x[i] = val` or `x .= values`)
+- appending/popping values (`push!(x, v)` or `pop!(x)`)
+- calling mutating functions (such as `LinearAlgebra.mul!(C, A, B)`)

 !!! warning

     Non-mutating functions might also use mutation under the hood. This can be done for performance reasons or code re-use.

 ```julia
-function g!(x, y)
-  x .= 2 .* y
-
+function g_inner!(x, y)
+  for i in eachindex(x, y)
+    x[i] = 2 * y[i]
+  end
   return x
 end
-g(y) = g!(similar(y), y)
-```
-Here `g` is a "non-mutating function," and it indeed does not mutate `y`, its only argument. But it still allocates a new array and calls `g!` on this array which will result in a mutating operation. You may encounter such functions when working with another package.
-
-Specifically for array mutation, we can use [`Zygote.Buffer`](@ref) to re-write our function. For example, let's fix the function `g!` above.
-```julia
-function g!(x, y)
-  x .= 2 .* y

-  return x
+function g_outer(y)
+  z = similar(y)
+  g_inner!(z, y)
+  return z
 end
+```
+Here `g_outer` does not mutate `y`, its only argument. But it still allocates a new array `z` and calls `g_inner!` on this array, which will result in a mutating operation. You may encounter such functions when working with another package.

-function g(y)
-  x = Zygote.Buffer(y) # Buffer supports syntax like similar
-  g!(x, y)
-  return copy(x) # this step makes the Buffer immutable (w/o actually copying)
-end
+How can you solve this problem?
+* Re-write the code not to use mutation. Here we can obviously write `g_better(y) = 2 .* y` using broadcasting. Many other cases may be solved by writing comprehensions `[f(x, y) for x in xs, y in ys]` or using `map(f, xs, ys)`, instead of explicitly allocating an output array and then writing into it.
+* Write a custom rule, defining `rrule(::typeof(g), y)` using what you know about `g` to derive the right expression.
+* Use another AD package instead of Zygote for part of the calculation. Replacing `g(y)` with `Zygote.forwarddiff(g, y)` will compute the same value, but when it is time to find the gradient, this job is outsourced to [ForwardDiff.jl](https://github.com/JuliaDiff/ForwardDiff.jl). ForwardDiff has its own limitations but mutation isn't one of them.

-julia> gradient(rand(3)) do y
-  sum(g(y))
-end
-([2.0, 2.0, 2.0],)
-```
+Finally, there is also [`Zygote.Buffer`](@ref) which aims to handle the pattern of allocating space and then mutating it. But it has many bugs and is not really recommended.

 ## Try-catch statements

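Note on the hunk above (illustration only, not part of the commit): the three workarounds listed in the added text can be sketched roughly as follows. This assumes Zygote and ChainRulesCore are both installed in the active environment. `g_ruled` is a hypothetical wrapper introduced here so that the hand-written rule does not change how `g_outer` itself is treated, and the input vector `y` is arbitrary.

```julia
using Zygote, ChainRulesCore

# The mutating pattern from the diff; Zygote cannot differentiate it directly.
function g_inner!(x, y)
    for i in eachindex(x, y)
        x[i] = 2 * y[i]
    end
    return x
end

function g_outer(y)
    z = similar(y)
    g_inner!(z, y)
    return z
end

y = [1.0, 2.0, 3.0]  # arbitrary example input

# Option 1: rewrite without mutation, using broadcasting.
g_better(v) = 2 .* v
grad_better = gradient(v -> sum(g_better(v)), y)[1]

# Option 2: a hand-written rule. `g_ruled` is a hypothetical wrapper;
# since g_ruled(v) = 2v elementwise, the pullback multiplies the cotangent by 2.
g_ruled(v) = g_outer(v)
function ChainRulesCore.rrule(::typeof(g_ruled), v)
    g_ruled_pullback(z̄) = (NoTangent(), 2 .* unthunk(z̄))
    return g_outer(v), g_ruled_pullback
end
grad_rule = gradient(v -> sum(g_ruled(v)), y)[1]

# Option 3: hand this part of the computation to ForwardDiff.
grad_fwd = gradient(v -> sum(Zygote.forwarddiff(g_outer, v)), y)[1]
```

All three gradients should agree, since each function computes `2 .* y`. The `rrule` is defined before the first `gradient` call on `g_ruled`, which is the safest ordering.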
@@ -136,7 +128,8 @@ For all of the errors above, the suggested solutions are similar. You have the f
 2. define a [custom `ChainRulesCore.rrule`](https://juliadiff.org/ChainRulesCore.jl/stable/rule_author/example.html)
 3. open an [issue on Zygote](https://github.com/FluxML/Zygote.jl/issues)

-Avoiding the operation is simple, just don't do it! If you are using a mutating function, try to use a non-mutating variant. If you are using `try`/`catch` statements, try to use more graceful error handling such as returning `nothing` or another sentinel value. Recall that array mutation can also be avoided by using [`Zygote.Buffer`](@ref) as discussed above.
+Avoiding the operation is simple, just don't do it! If you are using a mutating function, try to use a non-mutating variant. Instead of allocating an array and writing into it, try to make the output directly using broadcasting, `map`, or a comprehension.
+If you are using `try`/`catch` statements, try to use more graceful error handling such as returning `nothing` or another sentinel value.

 Sometimes, we cannot avoid expressions that Zygote cannot differentiate, but we may be able to manually derive a gradient. In these cases, you can write [a custom `rrule`](https://juliadiff.org/ChainRulesCore.jl/stable/rule_author/example.html) using ChainRules.jl. Please refer to the linked ChainRules documentation for how to do this. _This solution is the only solution available for foreign call expressions._ Below, we provide a custom `rrule` for `jclock`.
 ```julia
