Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions NEWS.md
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,11 @@ Standard library changes
------------------------
* The `@timed` macro now returns a `NamedTuple` ([#34149])
* New `supertypes(T)` function returns a tuple of all supertypes of `T` ([#34419]).
* Sorting-related functions such as `sort` that take the keyword arguments `lt`, `rev`, `order`
and `by` now do not discard `order` if `by` or `lt` are passed. In the former case, the
order from `order` is used to compare the values of `by(element)`. In the latter case,
any order different from `Forward` or `Reverse` will raise an error about the
ambiguity.

#### LinearAlgebra
* The BLAS submodule now supports the level-2 BLAS subroutine `hpmv!` ([#34211]).
Expand Down
30 changes: 18 additions & 12 deletions base/ordering.jl
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ export # not exported by Base
By, Lt, Perm,
ReverseOrdering, ForwardOrdering,
DirectOrdering,
lt, ord, ordtype
lt, ord

abstract type Ordering end

Expand All @@ -34,8 +34,9 @@ const DirectOrdering = Union{ForwardOrdering,ReverseOrdering{ForwardOrdering}}
const Forward = ForwardOrdering()
const Reverse = ReverseOrdering()

struct By{T} <: Ordering
struct By{T, O} <: Ordering
by::T
order::O
end

struct Lt{T} <: Ordering
Expand All @@ -47,9 +48,12 @@ struct Perm{O<:Ordering,V<:AbstractVector} <: Ordering
data::V
end

ReverseOrdering(by::By) = By(by.by, ReverseOrdering(by.order))
ReverseOrdering(perm::Perm) = Perm(ReverseOrdering(perm.order), perm.data)

lt(o::ForwardOrdering, a, b) = isless(a,b)
lt(o::ReverseOrdering, a, b) = lt(o.fwd,b,a)
lt(o::By, a, b) = isless(o.by(a),o.by(b))
lt(o::By, a, b) = lt(o.order,o.by(a),o.by(b))
lt(o::Lt, a, b) = o.lt(a,b)

@propagate_inbounds function lt(p::Perm, a::Integer, b::Integer)
Expand All @@ -58,16 +62,18 @@ lt(o::Lt, a, b) = o.lt(a,b)
lt(p.order, da, db) | (!lt(p.order, db, da) & (a < b))
end

ordtype(o::ReverseOrdering, vs::AbstractArray) = ordtype(o.fwd, vs)
ordtype(o::Perm, vs::AbstractArray) = ordtype(o.order, o.data)
# TODO: here, we really want the return type of o.by, without calling it
ordtype(o::By, vs::AbstractArray) = try typeof(o.by(vs[1])) catch; Any end
ordtype(o::Ordering, vs::AbstractArray) = eltype(vs)

_ord(lt::typeof(isless), by::typeof(identity), order::Ordering) = order
_ord(lt::typeof(isless), by, order::Ordering) = By(by)
_ord(lt, by::typeof(identity), order::Ordering) = Lt(lt)
_ord(lt, by, order::Ordering) = Lt((x,y)->lt(by(x),by(y)))
_ord(lt::typeof(isless), by, order::Ordering) = By(by, order)

function _ord(lt, by, order::Ordering)
if order == Forward
return Lt((x, y) -> lt(by(x), by(y)))
elseif order == Reverse
return Lt((x, y) -> lt(by(y), by(x)))
else
error("Passing both lt= and order= arguments is ambiguous; please pass order=Forward or order=Reverse (or leave default)")
end
end

ord(lt, by, rev::Nothing, order::Ordering=Forward) = _ord(lt, by, order)

Expand Down
38 changes: 38 additions & 0 deletions test/ordering.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
using Test

import Base.Order: Forward, Reverse

# every argument can flip the integer order by passing the right value. Here,
# we enumerate a few of these combinations and check that all these flips
# compound so that in total we either have an increasing or decreasing sort.
for (s1, rev) in enumerate([true, false])
for (s2, lt) in enumerate([>, <, (a, b) -> a - b > 0, (a, b) -> a - b < 0])
for (s3, by) in enumerate([-, +])
for (s4, order) in enumerate([Reverse, Forward])
if iseven(s1 + s2 + s3 + s4)
target = [1, 2, 3]
else
target = [3, 2, 1]
end
@test target == sort([2, 3, 1], rev=rev, lt=lt, by=by, order=order)
end
end
end
end

@test [1 => 3, 2 => 5, 3 => 1] ==
sort([1 => 3, 2 => 5, 3 => 1]) ==
sort([1 => 3, 2 => 5, 3 => 1], by=first) ==
sort([1 => 3, 2 => 5, 3 => 1], rev=true, order=Reverse) ==
sort([1 => 3, 2 => 5, 3 => 1], lt= >, order=Reverse)

@test [3 => 1, 1 => 3, 2 => 5] ==
sort([1 => 3, 2 => 5, 3 => 1], by=last) ==
sort([1 => 3, 2 => 5, 3 => 1], by=last, rev=true, order=Reverse) ==
sort([1 => 3, 2 => 5, 3 => 1], by=last, lt= >, order=Reverse)


struct SomeOtherOrder <: Base.Order.Ordering end

@test_throws ErrorException sort([1, 2, 3], lt=(a, b) -> a - b < 0, order=SomeOtherOrder())