3 changes: 2 additions & 1 deletion src/Compiler.jl
@@ -702,9 +702,10 @@ function optimization_passes(
dus_to_concat::Bool=false,
recognize_comms::Bool=true,
lower_comms::Bool=true,
max_constant_threshold::Int=1024,
backend::String="gpu",
)
(; max_constant_threshold) = compile_options

transform_passes_list = [
"patterns=compare_op_canon<16>",
"transpose_transpose<16>",
161 changes: 147 additions & 14 deletions src/TestUtils.jl
@@ -1,6 +1,7 @@
module TestUtils

using ..Reactant: Reactant, TracedRArray
using ..Reactant: Reactant, TracedRArray, TracedRNumber, TracedUtils
using Reactant.Ops: @opcall
using ReactantCore: ReactantCore
using LinearAlgebra: LinearAlgebra

@@ -20,22 +21,154 @@ function construct_test_array(::Type{T}, dims::Int...) where {T}
return reshape(collect(T, 1:prod(dims)), dims...)
end

function finite_difference_gradient(
f, x::AbstractArray{T}; epsilon=eps(T)^(3 / 4)
) where {T}
# https://github.com/JuliaDiff/FiniteDiff.jl/blob/3a8c3d8d87e59de78e2831787a3f54b12b7c2075/src/epsilons.jl#L133
function default_epsilon(::Val{fdtype}, ::Type{T}) where {fdtype,T}
if fdtype == :forward
return sqrt(eps(real(T)))
elseif fdtype == :central
return cbrt(eps(real(T)))
elseif fdtype == :hcentral
return eps(T)^(T(1 / 4))
else
return one(real(T))
end
end

function get_perturbation(x::AbstractArray{T}, epsilon) where {T}
onehot_matrix = Reactant.promote_to(
TracedRArray{Reactant.unwrapped_eltype(T),2}, LinearAlgebra.I(length(x))
TracedRArray{Reactant.unwrapped_eltype(T),2},
LinearAlgebra.Diagonal(fill(epsilon, length(x)));
)
return permutedims(
reshape(onehot_matrix, size(x)..., length(x)), (ndims(x) + 1, 1:(ndims(x))...)
)
end

function generate_perturbed_array(::Val{:central}, x::AbstractArray{T}, epsilon) where {T}
perturbation = get_perturbation(x, epsilon)
x_ = reshape(x, 1, size(x)...)
return cat(x_ .+ perturbation, x_ .- perturbation; dims=1)
end

function generate_perturbed_array(::Val{:forward}, x::AbstractArray{T}, epsilon) where {T}
perturbation = get_perturbation(x, epsilon)
x_ = reshape(x, 1, size(x)...)
return cat(x_ .+ perturbation, x_; dims=1)
end
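# These helpers build the perturbed inputs for the standard difference quotients used below:
#   central: ∂f/∂xᵢ ≈ (f(x + ε eᵢ) - f(x - ε eᵢ)) / (2ε), with ε = cbrt(eps(real(T)))
#   forward: ∂f/∂xᵢ ≈ (f(x + ε eᵢ) - f(x)) / ε,           with ε = sqrt(eps(real(T)))
# where eᵢ is the i-th one-hot basis array. The perturbed copies are stacked along a new
# leading dimension so that `f` can be evaluated on all of them in one batched call.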

function finite_difference_gradient(
f::F, args...; method::Union{Val{:central},Val{:forward}}=Val(:central)
) where {F}
argprefix = gensym("finitediffarg")
resprefix = gensym("finitediffresult")
resargprefix = gensym("finitediffresarg")

# TODO: can we detect and prevent using functions that mutate their arguments?
mlir_fn_res = TracedUtils.make_mlir_fn(
Member: do we need to do the full make_mlir_fn here, or can we reuse the equivalent of traced_call or traced functions to achieve the same effect?

Collaborator (author): I am using make_mlir_fn mostly to linearize the arguments. Not sure how to use traced_call here, though.

Member: ah, so it was linearization, not efficiency.

f,
args,
(),
"finite_difference_gradient_fn",
false;
args_in_result=:none,
argprefix,
resprefix,
resargprefix,
)
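# make_mlir_fn traces `f(args...)` into a standalone MLIR function (`mlir_fn_res.f`), and in
# doing so flattens arbitrarily nested arguments into their traced leaves (this is the
# "linearization" discussed in the review thread above). The tracer calls below rebuild that
# flat list as `linear_args` so each array argument can be perturbed independently.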
perturbation = reshape(onehot_matrix .* epsilon, size(x)..., length(x))
f_input = cat(x .+ perturbation, x .- perturbation; dims=ndims(x) + 1)

f_evaluated = mapslices(f, f_input; dims=ntuple(identity, ndims(x)))
return ReactantCore.materialize_traced_array(
reshape(
(f_evaluated[1:length(x)] - f_evaluated[(length(x) + 1):end]) ./ (2 * epsilon),
size(x),
),

seenargs = Reactant.OrderedIdDict()
Reactant.make_tracer(seenargs, f, (argprefix,), Reactant.TracedSetPath)
for (i, arg) in enumerate(args)
Reactant.make_tracer(seenargs, arg, (argprefix, i), Reactant.TracedSetPath)
end

linear_args = Reactant.TracedType[]
for (k, v) in seenargs
v isa Reactant.TracedType || continue
push!(linear_args, v)
end

if (
length(mlir_fn_res.linear_results) != 1 ||
!(mlir_fn_res.linear_results[1] isa TracedRNumber)
)
error("`finite_difference_gradient` only supports functions with a single scalar \
output. Received: $(mlir_fn_res.linear_results)")
end

gradient_results = TracedRArray[]
gradient_result_map_path = []
for i in 1:length(linear_args)
arg = linear_args[i]
if arg isa TracedRArray && TracedUtils.has_idx(arg, argprefix)
path = TracedUtils.get_idx(arg, argprefix)
if mlir_fn_res.fnwrapped && length(path) > 1 && path[2] == 1
continue
end

# We need the gradient wrt this argument
# we will naively insert the args here, cse will take care of the rest
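# Strategy: every other argument is broadcast along a new leading batch dimension (one slice
# per perturbed copy of `arg`), so the batched call below evaluates `f` at all perturbation
# points in a single traced call.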
new_arguments = TracedRArray[]

epsilon = default_epsilon(method, Reactant.unwrapped_eltype(arg))
perturbed_arg = generate_perturbed_array(method, arg, epsilon)

bsize = size(perturbed_arg, 1)
for j in 1:length(linear_args)
if i == j
new_arg = perturbed_arg
elseif linear_args[j] isa TracedRNumber
new_arg = @opcall broadcast_in_dim(
linear_args[j], Int64[], Int64[bsize]
)
else
new_arg = @opcall broadcast_in_dim(
linear_args[j],
collect(Int64, 2:(ndims(linear_args[j]) + 1)),
Int64[bsize, size(linear_args[j])...],
)
end
new_arg = @opcall transpose(new_arg, Int64[1, ((ndims(new_arg)):-1:2)...];)
push!(new_arguments, new_arg)
end

batched_res = @opcall batch(
new_arguments,
[
Reactant.MLIR.IR.TensorType(
Int64[bsize],
Reactant.MLIR.IR.Type(
Reactant.unwrapped_eltype(mlir_fn_res.linear_results[1])
),
),
],
Int64[bsize];
fn=mlir_fn_res.f,
)
batched_res = only(batched_res)
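# Each entry of `batched_res` is `f` evaluated at one perturbed copy of `arg`; slice along
# that batch dimension to form the difference quotients.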

if method isa Val{:central}
diff = batched_res[1:(bsize ÷ 2)] - batched_res[((bsize ÷ 2) + 1):end]
grad_res = diff ./ (2 * epsilon)
elseif method isa Val{:forward}
diff = batched_res[1:(end - 1)] .- batched_res[end:end]
grad_res = diff ./ epsilon
end

push!(gradient_result_map_path, TracedUtils.get_idx(arg, argprefix))
push!(
gradient_results,
ReactantCore.materialize_traced_array(reshape(grad_res, size(arg))),
)
end
end

results = deepcopy(args)
for (path, grad_res) in zip(gradient_result_map_path, gradient_results)
TracedUtils.set!(results, path[2:end], grad_res.mlir_data)
end
length(args) == 1 && return results[1]
return results
end

end
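
For orientation, the traced helper above computes the same values as the following plain-Julia central-difference sketch (illustrative only: `reference_central_fd_gradient` is a made-up name and is not part of this diff):

    function reference_central_fd_gradient(f, x::AbstractArray{T}) where {T}
        ε = cbrt(eps(real(T)))              # same default step as the :central case above
        grad = similar(x, float(real(T)))
        for i in eachindex(x)
            xp = copy(x); xp[i] += ε        # x + ε * e_i
            xm = copy(x); xm[i] -= ε        # x - ε * e_i
            grad[i] = (f(xp) - f(xm)) / (2ε)
        end
        return grad
    end

    # Example: reference_central_fd_gradient(sum, [1.0 2.0; 3.0 4.0]) ≈ ones(2, 2)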
11 changes: 10 additions & 1 deletion src/TracedRArray.jl
@@ -33,6 +33,15 @@ Base.convert(T::Type{<:TracedRArray}, x::AbstractArray) = Reactant.promote_to(T,
Base.complex(x::TracedRArray{<:Real}) = complex.(x)
Base.complex(x::TracedRArray{<:Complex}) = x
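# Note: this memoized shallow copy means `deepcopy` of a traced array returns a fresh
# TracedRArray (via `copy`) instead of recursing into its fields; the
# `results = deepcopy(args)` call in TestUtils.finite_difference_gradient above relies on
# this before overwriting mlir_data with TracedUtils.set!.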

function Base.deepcopy_internal(x::TracedRArray, stackdict::IdDict)
if haskey(stackdict, x)
return stackdict[x]::typeof(x)
end
y = copy(x)
stackdict[x] = y
return y
end

TracedRArray{T,N}(x::AbstractArray) where {T,N} = convert(TracedRArray{T,N}, x)

function maybe_assert_scalar_setindexing(
@@ -1109,7 +1118,7 @@ function Base.accumulate_pairwise!(op, A::AnyTracedRVector, B::AnyTracedRVector)
return accumulate!(op, A, B; dims=1)
end

if isdefined(Base, :_accumulate_promote_op)
@static if isdefined(Base, :_accumulate_promote_op)
function Base._accumulate_promote_op(op, A::AnyTracedRArray{T}; init=nothing) where {T}
if init !== nothing
init isa TracedRNumber && (init = zero(unwrapped_eltype(init)))
1 change: 1 addition & 0 deletions src/TracedRNumber.jl
@@ -491,6 +491,7 @@ for (jlop, hloop) in (
(:(Base.log), :log),
(:(Base.log1p), :log_plus_one),
(:(Base.sqrt), :sqrt),
(:(Base.cbrt), :cbrt),
(:(Base.acos), :acos),
(:(Base.acosh), :acosh),
(:(Base.asin), :asin),
4 changes: 2 additions & 2 deletions src/stdlibs/LinearAlgebra.jl
@@ -273,7 +273,7 @@ function overloaded_mul!(
return C
end

if isdefined(LinearAlgebra, :_triu)
@static if isdefined(LinearAlgebra, :_triu)
function LinearAlgebra._triu(A::AnyTracedRArray{T,2}, ::Val{true}, k::Integer) where {T}
return overloaded_triu(materialize_traced_array(A), k)
end
@@ -284,7 +284,7 @@ if isdefined(LinearAlgebra, :_triu)
end
end

if isdefined(LinearAlgebra, :_tril)
@static if isdefined(LinearAlgebra, :_tril)
function LinearAlgebra._tril(A::AnyTracedRArray{T,2}, ::Val{true}, k::Integer) where {T}
return overloaded_tril(materialize_traced_array(A), k)
end
41 changes: 41 additions & 0 deletions test/autodiff.jl
@@ -366,3 +366,44 @@ end

@test @jit(jvp_vjp_cubic(v_r, x_r, lambdas_r)) ≈ fill(6, (3, 2))
end

@testset "Finite Difference Gradient" begin
x = Reactant.to_rarray(Reactant.TestUtils.construct_test_array(Float16, 2, 2))
res = @jit Reactant.TestUtils.finite_difference_gradient(sum, x)
@test res isa Reactant.ConcreteRArray{Float16,2}
end

function fdiff_multiple_args(f, nt, x)
return sum(abs2, f(nt.y .+ x .- nt.x))
end

struct WrapperFunc{T}
x::T
end

(f::WrapperFunc)(x) = x .^ 3 .+ f.x

@testset "Finite Difference Gradient (non vector inputs)" begin
nt = (;
x=Reactant.TestUtils.construct_test_array(Float64, 3, 4),
y=Reactant.TestUtils.construct_test_array(Float64, 3, 4),
)
fn = WrapperFunc(Reactant.TestUtils.construct_test_array(Float64, 3, 4))
x = Reactant.TestUtils.construct_test_array(Float64, 3, 4)

nt_ra = Reactant.to_rarray(nt)
fn_ra = Reactant.to_rarray(fn)
x_ra = Reactant.to_rarray(x)

results_fd = @jit Reactant.TestUtils.finite_difference_gradient(
fdiff_multiple_args, fn_ra, nt_ra, x_ra
)
@test results_fd isa typeof((fn_ra, nt_ra, x_ra))

results_enz = @jit Enzyme.gradient(Reverse, fdiff_multiple_args, fn_ra, nt_ra, x_ra)

@test results_fd[1].x ≈ results_enz[1].x
@test results_fd[2].x ≈ results_enz[2].x
@test results_fd[2].y ≈ results_enz[2].y
@test results_fd[3] ≈ results_enz[3]
end