diff --git a/src/Interpolations.jl b/src/Interpolations.jl index 5d480ee9..eeedaa22 100644 --- a/src/Interpolations.jl +++ b/src/Interpolations.jl @@ -3,7 +3,12 @@ module Interpolations using Base.Cartesian using Compat -import Base: size, eltype, getindex, ndims +import Base: + eltype, + gradient, + getindex, + ndims, + size export Interpolation, @@ -152,9 +157,31 @@ for IT in ( ret end )) + + eval(ngenerate( + :N, + :(Array{promote_type(T,typeof(x)...),1}), + :(gradient!{T,N}(g::Array{T,1}, itp::Interpolation{T,N,$IT,$EB}, x::NTuple{N,Real}...)), + N->quote + $(extrap_transform_x(gr,eb,N)) + $(define_indices(it,N)) + @nexprs $N dim->begin + @nexprs $N d->begin + (d==dim + ? $(gradient_coefficients(it,N,:d)) + : $(coefficients(it,N,:d))) + end + + @inbounds g[dim] = $(index_gen(degree(it),N)) + end + g + end + )) end end +gradient{T}(itp::Interpolation{T}, x...) = gradient!(Array(T,ndims(itp)), itp, x...) + # This creates prefilter specializations for all interpolation types that need them for IT in ( Quadratic{Flat,OnCell}, diff --git a/src/constant.jl b/src/constant.jl index 54b33f23..6e029bea 100644 --- a/src/constant.jl +++ b/src/constant.jl @@ -6,8 +6,18 @@ function define_indices(::Constant, N) :(@nexprs $N d->(ix_d = clamp(round(Int,x_d), 1, size(itp,d)))) end -function coefficients(::Constant, N) - :(@nexprs $N d->(c_d = one(typeof(x_d)))) +function coefficients(c::Constant, N) + :(@nexprs $N d->($(coefficients(c, N, :d)))) +end + +function coefficients(::Constant, N, d) + sym, symx = symbol(string("c_",d)), symbol(string("x_",d)) + :($sym = one(typeof($symx))) +end + +function gradient_coefficients(::Constant, N, d) + sym, symx = symbol(string("c_",d)), symbol(string("x_",d)) + :($sym = zero(typeof($symx))) end function index_gen(degree::ConstantDegree, N::Integer, offsets...) diff --git a/src/linear.jl b/src/linear.jl index 955c0d41..0543db2e 100644 --- a/src/linear.jl +++ b/src/linear.jl @@ -12,12 +12,23 @@ function define_indices(::Linear, N) end end -function coefficients(::Linear, N) +function coefficients(l::Linear, N) + :(@nexprs $N d->($(coefficients(l, N, :d)))) +end + +function coefficients(::Linear, N, d) + sym, symp, symfx = symbol(string("c_",d)), symbol(string("cp_",d)), symbol(string("fx_",d)) quote - @nexprs $N d->begin - c_d = one(typeof(fx_d)) - fx_d - cp_d = fx_d - end + $sym = one(typeof($symfx)) - $symfx + $symp = $symfx + end +end + +function gradient_coefficients(::Linear,N,d) + sym, symp, symfx = symbol(string("c_",d)), symbol(string("cp_",d)), symbol(string("fx_",d)) + quote + $sym = -one(typeof($symfx)) + $symp = one(typeof($symfx)) end end diff --git a/src/quadratic.jl b/src/quadratic.jl index 060c75ae..b5f29a30 100644 --- a/src/quadratic.jl +++ b/src/quadratic.jl @@ -28,13 +28,27 @@ function define_indices(q::Quadratic{Periodic}, N) end end -function coefficients(::Quadratic, N) +function coefficients(q::Quadratic, N) + :(@nexprs $N d->($(coefficients(q, N, :d)))) +end + +function coefficients(q::Quadratic, N, d) + symm, sym, symp = symbol(string("cm_",d)), symbol(string("c_",d)), symbol(string("cp_",d)) + symfx = symbol(string("fx_",d)) quote - @nexprs $N d->begin - cm_d = .5 * (fx_d-.5)^2 - c_d = .75 - fx_d^2 - cp_d = .5 * (fx_d+.5)^2 - end + $symm = .5 * ($symfx - .5)^2 + $sym = .75 - $symfx^2 + $symp = .5 * ($symfx + .5)^2 + end +end + +function gradient_coefficients(q::Quadratic, N, d) + symm, sym, symp = symbol(string("cm_",d)), symbol(string("c_",d)), symbol(string("cp_",d)) + symfx = symbol(string("fx_",d)) + quote + $symm = $symfx-.5 + $sym = -2*$symfx + $symp = $symfx+.5 end end diff --git a/test/gradient.jl b/test/gradient.jl new file mode 100644 index 00000000..77d809db --- /dev/null +++ b/test/gradient.jl @@ -0,0 +1,30 @@ +module GradientTests +println("Testing gradient evaluation") +using Base.Test, Interpolations + +nx = 10 +f1(x) = sin((x-3)*2pi/(nx-1) - 1) +g1(x) = 2pi/(nx-1) * cos((x-3)*2pi/(nx-1) - 1) + +# Gradient of Constant should always be 0 +itp1 = Interpolation(Float64[f1(x) for x in 1:nx-1], + Constant(OnGrid()), ExtrapPeriodic()) +for x in 1:nx + @test gradient(itp1, x)[1] == 0 +end + +# Since Linear is OnGrid in the domain, check the gradients between grid points +itp1 = Interpolation(Float64[f1(x) for x in 1:nx-1], + Linear(OnGrid()), ExtrapPeriodic()) +for x in 2.5:nx-1.5 + @test_approx_eq_eps g1(x) gradient(itp1, x)[1] abs(.1*g1(x)) +end + +# Since Quadratic is OnCell in the domain, check gradients at grid points +itp1 = Interpolation(Float64[f1(x) for x in 1:nx-1], + Quadratic(Periodic(),OnGrid()), ExtrapPeriodic()) +for x in 2:nx-1 + @test_approx_eq_eps g1(x) gradient(itp1, x)[1] abs(.05*g1(x)) +end + +end diff --git a/test/runtests.jl b/test/runtests.jl index c4525440..27fd6b0e 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -12,6 +12,9 @@ include("quadratic.jl") # indices inbounds in A. include("on-grid.jl") +# test gradient evaluation +include("gradient.jl") + # Tests copied from Grid.jl's old test suite #include("grid.jl")