From 7d912cb3f243dd7ae7034b6b741f94f77813a57c Mon Sep 17 00:00:00 2001 From: Bagaev Dmitry Date: Thu, 27 Jun 2024 16:35:19 +0200 Subject: [PATCH] fix any --- src/array_partition.jl | 4 ++-- test/partitions_test.jl | 20 ++++++++++++++++++++ 2 files changed, 22 insertions(+), 2 deletions(-) diff --git a/src/array_partition.jl b/src/array_partition.jl index b4047d3b..5b3a1c09 100644 --- a/src/array_partition.jl +++ b/src/array_partition.jl @@ -169,8 +169,8 @@ function Base.mapreduce(f, op, A::ArrayPartition{T}; kwargs...) where {T} mapreduce(f, op, (i for i in A); kwargs...) end Base.filter(f, A::ArrayPartition) = ArrayPartition(map(x -> filter(f, x), A.x)) -Base.any(f, A::ArrayPartition) = any(f, (any(f, x) for x in A.x)) -Base.any(f::Function, A::ArrayPartition) = any(f, (any(f, x) for x in A.x)) +Base.any(f, A::ArrayPartition) = any((any(f, x) for x in A.x)) +Base.any(f::Function, A::ArrayPartition) = any((any(f, x) for x in A.x)) Base.any(A::ArrayPartition) = any(identity, A) Base.all(f, A::ArrayPartition) = all(f, (all(f, x) for x in A.x)) Base.all(f::Function, A::ArrayPartition) = all((all(f, x) for x in A.x)) diff --git a/test/partitions_test.jl b/test/partitions_test.jl index e31e9da5..1b0cf3f8 100644 --- a/test/partitions_test.jl +++ b/test/partitions_test.jl @@ -136,6 +136,26 @@ y = ArrayPartition(ArrayPartition([1], [2.0]), ArrayPartition([3], [4.0])) @inferred mapreduce(string, *, x) @test mapreduce(i -> string(i) * "q", *, x) == "1q2q3.0q4.0q" +# any +@test !any(isnan, ArrayPartition([1, 2], [3.0, 4.0])) +@test !any(isnan, ArrayPartition([3.0, 4.0])) +@test any(isnan, ArrayPartition([NaN], [3.0, 4.0])) +@test any(isnan, ArrayPartition([NaN])) +@test any(isnan, ArrayPartition(ArrayPartition([NaN]))) +@test any(isnan, ArrayPartition([2], [NaN])) +@test any(isnan, ArrayPartition([2], ArrayPartition([NaN]))) + +# all +@test !all(isnan, ArrayPartition([1, 2], [3.0, 4.0])) +@test !all(isnan, ArrayPartition([3.0, 4.0])) +@test !all(isnan, ArrayPartition([NaN], [3.0, 4.0])) +@test all(isnan, ArrayPartition([NaN])) +@test all(isnan, ArrayPartition(ArrayPartition([NaN]))) +@test !all(isnan, ArrayPartition([2], [NaN])) +@test all(isnan, ArrayPartition([NaN], [NaN])) +@test all(isnan, ArrayPartition([NaN], ArrayPartition([NaN]))) + + # broadcasting _scalar_op(y) = y + 1 # Can't do `@inferred(_scalar_op.(x))` so we wrap that in a function: