Skip to content
Merged
Show file tree
Hide file tree
Changes from 11 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -54,12 +54,14 @@ Each distance corresponds to a *distance type*. You can always compute a certain

```julia
r = evaluate(dist, x, y)
r = dist(x, y)
```

Here, `dist` is an instance of a distance type. For example, the type for Euclidean distance is ``Euclidean`` (more distance types are introduced in the next section), so you can compute the Euclidean distance between ``x`` and ``y`` as

```julia
r = evaluate(Euclidean(), x, y)
r = Euclidean()(x, y)
```

Common distances also come with convenience functions for distance evaluation. For example, you can also compute the Euclidean distance between two vectors as follows:
Expand Down
15 changes: 6 additions & 9 deletions src/bhattacharyya.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@ struct BhattacharyyaDist <: SemiMetric end

struct HellingerDist <: Metric end


# Bhattacharyya coefficient

function bhattacharyya_coeff(a::AbstractVector{T}, b::AbstractVector{T}) where {T <: Number}
Expand Down Expand Up @@ -37,13 +36,11 @@ bhattacharyya_coeff(a::T, b::T) where {T <: Number} = throw("Bhattacharyya coeff


# Bhattacharyya distance
evaluate(dist::BhattacharyyaDist, a::AbstractVector{T}, b::AbstractVector{T}) where {T <: Number} = -log(bhattacharyya_coeff(a, b))
bhattacharyya(a::AbstractVector, b::AbstractVector) = evaluate(BhattacharyyaDist(), a, b)
evaluate(dist::BhattacharyyaDist, a::T, b::T) where {T <: Number} = throw("Bhattacharyya distance cannot be calculated for scalars")
bhattacharyya(a::T, b::T) where {T <: Number} = evaluate(BhattacharyyaDist(), a, b)
# Vector inputs: the distance is the negative log of the Bhattacharyya coefficient.
function (::BhattacharyyaDist)(a::AbstractVector{T}, b::AbstractVector{T}) where {T <: Number}
    coeff = bhattacharyya_coeff(a, b)
    return -log(coeff)
end

# Scalar inputs are rejected outright.
function (::BhattacharyyaDist)(a::T, b::T) where {T <: Number}
    throw("Bhattacharyya distance cannot be calculated for scalars")
end

# Convenience wrapper: apply a fresh `BhattacharyyaDist` instance to `a` and `b`.
function bhattacharyya(a, b)
    return BhattacharyyaDist()(a, b)
end

# Hellinger distance
evaluate(dist::HellingerDist, a::AbstractVector{T}, b::AbstractVector{T}) where {T <: Number} = sqrt(1 - bhattacharyya_coeff(a, b))
hellinger(a::AbstractVector, b::AbstractVector) = evaluate(HellingerDist(), a, b)
evaluate(dist::HellingerDist, a::T, b::T) where {T <: Number} = throw("Hellinger distance cannot be calculated for scalars")
hellinger(a::T, b::T) where {T <: Number} = evaluate(HellingerDist(), a, b)
# Vector inputs: sqrt(1 - BC), where BC is the Bhattacharyya coefficient of `a` and `b`.
function (::HellingerDist)(a::AbstractVector{T}, b::AbstractVector{T}) where {T <: Number}
    coeff = bhattacharyya_coeff(a, b)
    return sqrt(1 - coeff)
end

# Scalar inputs are rejected outright.
function (::HellingerDist)(a::T, b::T) where {T <: Number}
    throw("Hellinger distance cannot be calculated for scalars")
end

# Convenience wrapper: apply a fresh `HellingerDist` instance to `a` and `b`.
function hellinger(a, b)
    return HellingerDist()(a, b)
end
44 changes: 22 additions & 22 deletions src/bregman.jl
Original file line number Diff line number Diff line change
@@ -1,48 +1,48 @@
# Bregman divergence
# Bregman divergence

"""
Implements the Bregman divergence, a friendly introduction to which can be found
[here](http://mark.reid.name/blog/meet-the-bregman-divergences.html).
Bregman divergences are a minimal implementation of the "mean-minimizer" property.
[here](http://mark.reid.name/blog/meet-the-bregman-divergences.html).
Bregman divergences are a minimal implementation of the "mean-minimizer" property.
It is assumed that the (convex differentiable) function F maps vectors (of any type or size) to real numbers.
The inner product used is `Base.dot`, but one can be passed in either by defining `inner` or by
passing in a keyword argument. If an analytic gradient isn't available, Julia offers a suite
of good automatic differentiation packages.
It is assumed that the (convex differentiable) function F maps vectors (of any type or size) to real numbers.
The inner product used is `Base.dot`, but one can be passed in either by defining `inner` or by
passing in a keyword argument. If an analytic gradient isn't available, Julia offers a suite
of good automatic differentiation packages.
function evaluate(dist::Bregman, p::AbstractVector, q::AbstractVector)
"""
struct Bregman{T1 <: Function, T2 <: Function, T3 <: Function} <: PreMetric
struct Bregman{T1 <: Function,T2 <: Function,T3 <: Function} <: PreMetric
    # Convex differentiable function; expected to map a vector to a real number.
    F::T1
    # Gradient of `F`; its result is expected to have the same size as the inputs.
    ∇::T2
    # Inner product applied to the gradient value and the difference `p - q`.
    inner::T3
end

# Default constructor.
# Default constructor.
# Two-argument form: fall back to `LinearAlgebra.dot` as the inner product.
function Bregman(F, ∇)
    return Bregman(F, ∇, LinearAlgebra.dot)
end

# Evaluation function
function evaluate(dist::Bregman, p::AbstractVector, q::AbstractVector)
# Evaluation function
function (dist::Bregman)(p::AbstractVector, q::AbstractVector)
# Create cache vals.
FP_val = dist.F(p);
FQ_val = dist.F(q);
FQ_val = dist.F(q);
DQ_val = dist.∇(q);
p_size = size(p);
# Check F codomain.
# Check F codomain.
if !(isa(FP_val, Real) && isa(FQ_val, Real))
throw(ArgumentError("F Codomain Error: F doesn't map the vectors to real numbers"))
end
# Check vector size.
end
# Check vector size.
if !(p_size == size(q))
throw(DimensionMismatch("The vector p ($(size(p))) and q ($(size(q))) are different sizes."))
end
# Check gradient size.
# Check gradient size.
if !(size(DQ_val) == p_size)
throw(DimensionMismatch("The gradient result is not the same size as p and q"))
end
# Return the Bregman divergence.
return FP_val - FQ_val - dist.inner(DQ_val, p-q);
end
end
# Return the Bregman divergence.
return FP_val - FQ_val - dist.inner(DQ_val, p - q);
end

# Convenience function.
bregman(F, ∇, x, y; inner = LinearAlgebra.dot) = evaluate(Bregman(F, ∇, inner), x, y)
# Convenience function.
# Build a `Bregman` divergence from `F`, its gradient `∇` and the inner product
# `inner` (default `LinearAlgebra.dot`), then apply it to `x` and `y`.
function bregman(F, ∇, x, y; inner = LinearAlgebra.dot)
    divergence = Bregman(F, ∇, inner)
    return divergence(x, y)
end
25 changes: 13 additions & 12 deletions src/generic.jl
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ abstract type SemiMetric <: PreMetric end
#
abstract type Metric <: SemiMetric end

# Generic fallback: `evaluate(dist, a, b)` forwards to calling the distance
# object itself, i.e. `dist(a, b)`.
function evaluate(dist::PreMetric, a, b)
    return dist(a, b)
end

# Generic functions

Expand All @@ -33,7 +34,7 @@ function colwise!(r::AbstractArray, metric::PreMetric, a::AbstractVector, b::Abs
n = size(b, 2)
length(r) == n || throw(DimensionMismatch("Incorrect size of r."))
@inbounds for j = 1:n
r[j] = evaluate(metric, a, view(b, :, j))
r[j] = metric(a, view(b, :, j))
end
r
end
Expand All @@ -42,7 +43,7 @@ function colwise!(r::AbstractArray, metric::PreMetric, a::AbstractMatrix, b::Abs
n = size(a, 2)
length(r) == n || throw(DimensionMismatch("Incorrect size of r."))
@inbounds for j = 1:n
r[j] = evaluate(metric, view(a, :, j), b)
r[j] = metric(view(a, :, j), b)
end
r
end
Expand All @@ -51,7 +52,7 @@ function colwise!(r::AbstractArray, metric::PreMetric, a::AbstractMatrix, b::Abs
n = get_common_ncols(a, b)
length(r) == n || throw(DimensionMismatch("Incorrect size of r."))
@inbounds for j = 1:n
r[j] = evaluate(metric, view(a, :, j), view(b, :, j))
r[j] = metric(view(a, :, j), view(b, :, j))
end
r
end
Expand Down Expand Up @@ -82,14 +83,14 @@ end
# Generic pairwise evaluation

function _pairwise!(r::AbstractMatrix, metric::PreMetric,
a::AbstractMatrix, b::AbstractMatrix=a)
a::AbstractMatrix, b::AbstractMatrix = a)
na = size(a, 2)
nb = size(b, 2)
size(r) == (na, nb) || throw(DimensionMismatch("Incorrect size of r."))
@inbounds for j = 1:size(b, 2)
bj = view(b, :, j)
for i = 1:size(a, 2)
r[i, j] = evaluate(metric, view(a, :, i), bj)
r[i, j] = metric(view(a, :, i), bj)
end
end
r
Expand All @@ -101,7 +102,7 @@ function _pairwise!(r::AbstractMatrix, metric::SemiMetric, a::AbstractMatrix)
@inbounds for j = 1:n
aj = view(a, :, j)
for i = (j + 1):n
r[i, j] = evaluate(metric, view(a, :, i), aj)
r[i, j] = metric(view(a, :, i), aj)
end
r[j, j] = 0
for i = 1:(j - 1)
Expand Down Expand Up @@ -135,7 +136,7 @@ If a single matrix `a` is provided, compute distances between its rows or column
"""
function pairwise!(r::AbstractMatrix, metric::PreMetric,
a::AbstractMatrix, b::AbstractMatrix;
dims::Union{Nothing,Integer}=nothing)
dims::Union{Nothing,Integer} = nothing)
dims = deprecated_dims(dims)
dims in (1, 2) || throw(ArgumentError("dims should be 1 or 2 (got $dims)"))
if dims == 1
Expand All @@ -159,7 +160,7 @@ function pairwise!(r::AbstractMatrix, metric::PreMetric,
end

function pairwise!(r::AbstractMatrix, metric::PreMetric, a::AbstractMatrix;
dims::Union{Nothing,Integer}=nothing)
dims::Union{Nothing,Integer} = nothing)
dims = deprecated_dims(dims)
dims in (1, 2) || throw(ArgumentError("dims should be 1 or 2 (got $dims)"))
if dims == 1
Expand All @@ -186,20 +187,20 @@ compute distances between its rows or columns.
`a` and `b` must have the same numbers of columns if `dims=1`, or of rows if `dims=2`.
"""
function pairwise(metric::PreMetric, a::AbstractMatrix, b::AbstractMatrix;
dims::Union{Nothing,Integer}=nothing)
dims::Union{Nothing,Integer} = nothing)
dims = deprecated_dims(dims)
dims in (1, 2) || throw(ArgumentError("dims should be 1 or 2 (got $dims)"))
m = size(a, dims)
n = size(b, dims)
r = Matrix{result_type(metric, a, b)}(undef, m, n)
pairwise!(r, metric, a, b, dims=dims)
pairwise!(r, metric, a, b, dims = dims)
end

function pairwise(metric::PreMetric, a::AbstractMatrix;
dims::Union{Nothing,Integer}=nothing)
dims::Union{Nothing,Integer} = nothing)
dims = deprecated_dims(dims)
dims in (1, 2) || throw(ArgumentError("dims should be 1 or 2 (got $dims)"))
n = size(a, dims)
r = Matrix{result_type(metric, a, a)}(undef, n, n)
pairwise!(r, metric, a, dims=dims)
pairwise!(r, metric, a, dims = dims)
end
12 changes: 6 additions & 6 deletions src/haversine.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,13 +6,13 @@ The haversine distance between two locations on a sphere of given `radius`.
Locations are described with longitude and latitude in degrees.
The computed distance has the same units as that of the radius.
"""
struct Haversine{T<:Real} <: Metric
struct Haversine{T<:Real} <: Metric
    # Radius of the sphere; the computed distance has the same units as this value.
    radius::T
end

const VecOrLengthTwoTuple{T} = Union{AbstractVector{T}, NTuple{2, T}}
# Accepted location containers: any abstract vector, or a homogeneous 2-tuple.
const VecOrLengthTwoTuple{T} = Union{AbstractVector{T}, Tuple{T, T}}

function evaluate(dist::Haversine, x::VecOrLengthTwoTuple, y::VecOrLengthTwoTuple)
function (dist::Haversine)(x::VecOrLengthTwoTuple, y::VecOrLengthTwoTuple)
length(x) == length(y) == 2 || haversine_error()

@inbounds begin
Expand All @@ -27,12 +27,12 @@ function evaluate(dist::Haversine, x::VecOrLengthTwoTuple, y::VecOrLengthTwoTupl
Δφ = φ₂ - φ₁

# haversine formula
a = sin(Δφ/2)^2 + cos(φ₁)*cos(φ₂)*sin(Δλ/2)^2
a = sin(Δφ / 2)^2 + cos(φ₁) * cos(φ₂) * sin(Δλ / 2)^2

# distance on the sphere
2 * dist.radius * asin( min(√a, one(a)) ) # take care of floating point errors
2 * dist.radius * asin(min(√a, one(a))) # take care of floating point errors
end

haversine(x::VecOrLengthTwoTuple, y::VecOrLengthTwoTuple, radius::Real) = evaluate(Haversine(radius), x, y)
# Convenience wrapper: construct a `Haversine` metric with the given `radius`
# and apply it to the two locations `x` and `y`.
function haversine(x::VecOrLengthTwoTuple, y::VecOrLengthTwoTuple, radius::Real)
    metric = Haversine(radius)
    return metric(x, y)
end

# Throw the argument error used when Haversine inputs do not both have length 2.
# Marked `@noinline` so the throwing code stays out of line from its callers.
@noinline function haversine_error()
    throw(ArgumentError("expected both inputs to have length 2 in Haversine distance"))
end
10 changes: 5 additions & 5 deletions src/mahalanobis.jl
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ result_type(::SqMahalanobis{T}, ::AbstractArray, ::AbstractArray) where {T} = T

# SqMahalanobis

function evaluate(dist::SqMahalanobis{T}, a::AbstractVector, b::AbstractVector) where {T <: Real}
function (dist::SqMahalanobis{T})(a::AbstractVector, b::AbstractVector) where {T <: Real}
if length(a) != length(b)
throw(DimensionMismatch("first array has length $(length(a)) which does not match the length of the second, $(length(b))."))
end
Expand All @@ -23,7 +23,7 @@ function evaluate(dist::SqMahalanobis{T}, a::AbstractVector, b::AbstractVector)
return dot(z, Q * z)
end

sqmahalanobis(a::AbstractVector, b::AbstractVector, Q::AbstractMatrix) = evaluate(SqMahalanobis(Q), a, b)
# Convenience wrapper: construct `SqMahalanobis(Q)` and apply it to `a` and `b`.
function sqmahalanobis(a::AbstractVector, b::AbstractVector, Q::AbstractMatrix)
    metric = SqMahalanobis(Q)
    return metric(a, b)
end

function colwise!(r::AbstractArray, dist::SqMahalanobis{T}, a::AbstractMatrix, b::AbstractMatrix) where {T <: Real}
Q = dist.qmat
Expand Down Expand Up @@ -83,11 +83,11 @@ end

# Mahalanobis

function evaluate(dist::Mahalanobis{T}, a::AbstractVector, b::AbstractVector) where {T <: Real}
sqrt(evaluate(SqMahalanobis(dist.qmat), a, b))
# Mahalanobis distance: the square root of the squared Mahalanobis distance
# computed with the same `qmat`.
function (dist::Mahalanobis{T})(a::AbstractVector, b::AbstractVector) where {T <: Real}
    squared = SqMahalanobis(dist.qmat)
    return sqrt(squared(a, b))
end

mahalanobis(a::AbstractVector, b::AbstractVector, Q::AbstractMatrix) = evaluate(Mahalanobis(Q), a, b)
# Convenience wrapper: construct `Mahalanobis(Q)` and apply it to `a` and `b`.
function mahalanobis(a::AbstractVector, b::AbstractVector, Q::AbstractMatrix)
    metric = Mahalanobis(Q)
    return metric(a, b)
end

function colwise!(r::AbstractArray, dist::Mahalanobis{T}, a::AbstractMatrix, b::AbstractMatrix) where {T <: Real}
sqrt!(colwise!(r, SqMahalanobis(dist.qmat), a, b))
Expand Down
Loading