Skip to content

Commit aa7b3f9

Browse files
mbaumanJeffBezanson
authored andcommitted
Fix #26488: don't map over values not provided (#26521)
This is a symptom of the good old how-to-allocate-a-result-array-of-an-arbitrary-transform-of-its-elements problem. Eventually it'd be nice to solve this with `collect` of a lazy implementation, but for now this papers over the egregious problem.
1 parent a7104ac commit aa7b3f9

File tree

2 files changed

+20
-4
lines changed

2 files changed

+20
-4
lines changed

base/reducedim.jl

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -140,10 +140,14 @@ let
140140
[AbstractArray{t} for t in uniontypes(BitIntFloat)]...,
141141
[AbstractArray{Complex{t}} for t in uniontypes(BitIntFloat)]...}
142142

143-
global reducedim_init(f, op::Union{typeof(+),typeof(add_sum)}, A::T, region) =
144-
reducedim_initarray(A, region, mapreduce_first(f, op, zero(eltype(A))))
145-
global reducedim_init(f, op::Union{typeof(*),typeof(mul_prod)}, A::T, region) =
146-
reducedim_initarray(A, region, mapreduce_first(f, op, one(eltype(A))))
143+
global function reducedim_init(f, op::Union{typeof(+),typeof(add_sum)}, A::T, region)
144+
z = zero(f(zero(eltype(A))))
145+
reducedim_initarray(A, region, op(z, z))
146+
end
147+
global function reducedim_init(f, op::Union{typeof(*),typeof(mul_prod)}, A::T, region)
148+
u = one(f(one(eltype(A))))
149+
reducedim_initarray(A, region, op(u, u))
150+
end
147151
end
148152

149153
## generic (map)reduction

test/reducedim.jl

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -338,6 +338,18 @@ for region in Any[-1, 0, (-1, 2), [0, 1], (1,-2,3), [0 1;
338338
@test_throws ArgumentError minimum(abs, Areduc, dims=region)
339339
end
340340

341+
# issue #26488
342+
@testset "don't map over initial values not provided" begin
343+
@test sum(x->x+1, [1], dims=1)[1] === sum(x->x+1, [1]) === 2
344+
@test prod(x->x+1, [1], dims=1)[1] === prod(x->x+1, [1]) === 2
345+
@test mapreduce(x->x+1, +, [1], dims=1)[1] === mapreduce(x->x+1, +, [1]) === 2
346+
@test mapreduce(x->x+1, *, [1], dims=1)[1] === mapreduce(x->x+1, *, [1]) === 2
347+
@test mapreduce(!, &, [false], dims=1)[1] === mapreduce(!, &, [false]) === true
348+
@test mapreduce(!, |, [true], dims=1)[1] === mapreduce(!, |, [true]) === false
349+
@test mapreduce(x->1/x, max, [1], dims=1)[1] === mapreduce(x->1/x, max, [1]) === 1.0
350+
@test mapreduce(x->-1/x, min, [1], dims=1)[1] === mapreduce(x->-1/x, min, [1]) === -1.0
351+
end
352+
341353
# check type of result
342354
@testset "type of sum(::Array{$T}" for T in [UInt8, Int8, Int32, Int64, BigInt]
343355
result = sum(T[1 2 3; 4 5 6; 7 8 9], dims=2)

0 commit comments

Comments
 (0)