Skip to content

Commit 9b029a3

Browse files
committed
Merge pull request #15973 from JuliaLang/teh/more_reshape
reshape: helpful error message, more tests
2 parents 17dfdc1 + 64745d7 commit 9b029a3

File tree

3 files changed

+134
-54
lines changed

3 files changed

+134
-54
lines changed

base/reshapedarray.jl

Lines changed: 61 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -11,28 +11,28 @@ ReshapedArray{T,N}(parent::AbstractArray{T}, dims::NTuple{N,Int}, mi) = Reshaped
1111
typealias ReshapedArrayLF{T,N,P<:AbstractArray} ReshapedArray{T,N,P,Tuple{}}
1212

1313
# Fast iteration on ReshapedArrays: use the parent iterator
14-
immutable ReshapedRange{I,M}
14+
immutable ReshapedArrayIterator{I,M}
1515
iter::I
1616
mi::NTuple{M,SignedMultiplicativeInverse{Int}}
1717
end
18-
ReshapedRange(A::ReshapedArray) = reshapedrange(parent(A), A.mi)
19-
function reshapedrange{M}(P, mi::NTuple{M})
18+
ReshapedArrayIterator(A::ReshapedArray) = _rs_iterator(parent(A), A.mi)
19+
function _rs_iterator{M}(P, mi::NTuple{M})
2020
iter = eachindex(P)
21-
ReshapedRange{typeof(iter),M}(iter, mi)
21+
ReshapedArrayIterator{typeof(iter),M}(iter, mi)
2222
end
2323

2424
immutable ReshapedIndex{T}
2525
parentindex::T
2626
end
2727

28-
# eachindex(A::ReshapedArray) = ReshapedRange(A) # TODO: uncomment this line
29-
start(R::ReshapedRange) = start(R.iter)
30-
@inline done(R::ReshapedRange, i) = done(R.iter, i)
31-
@inline function next(R::ReshapedRange, i)
28+
# eachindex(A::ReshapedArray) = ReshapedArrayIterator(A) # TODO: uncomment this line
29+
start(R::ReshapedArrayIterator) = start(R.iter)
30+
@inline done(R::ReshapedArrayIterator, i) = done(R.iter, i)
31+
@inline function next(R::ReshapedArrayIterator, i)
3232
item, inext = next(R.iter, i)
3333
ReshapedIndex(item), inext
3434
end
35-
length(R::ReshapedRange) = length(R.iter)
35+
length(R::ReshapedArrayIterator) = length(R.iter)
3636

3737
function reshape(parent::AbstractArray, dims::Dims)
3838
prod(dims) == length(parent) || throw(DimensionMismatch("parent has $(length(parent)) elements, which is incompatible with size $dims"))
@@ -84,19 +84,61 @@ reinterpret{T}(::Type{T}, A::ReshapedArray, dims::Dims) = reinterpret(T, parent(
8484
ind2sub_rs((d+1, out...), tail(strds), r)
8585
end
8686

87-
@inline getindex(A::ReshapedArrayLF, index::Int) = (@boundscheck checkbounds(A, index); @inbounds ret = parent(A)[index]; ret)
88-
@inline getindex(A::ReshapedArray, indexes::Int...) = (@boundscheck checkbounds(A, indexes...); _unsafe_getindex(A, indexes...))
89-
@inline getindex(A::ReshapedArray, index::ReshapedIndex) = (@boundscheck checkbounds(parent(A), index.parentindex); @inbounds ret = parent(A)[index.parentindex]; ret)
87+
@inline function getindex(A::ReshapedArrayLF, index::Int)
88+
@boundscheck checkbounds(A, index)
89+
@inbounds ret = parent(A)[index]
90+
ret
91+
end
92+
@inline function getindex(A::ReshapedArray, indexes::Int...)
93+
@boundscheck checkbounds(A, indexes...)
94+
_unsafe_getindex(A, indexes...)
95+
end
96+
@inline function getindex(A::ReshapedArray, index::ReshapedIndex)
97+
@boundscheck checkbounds(parent(A), index.parentindex)
98+
@inbounds ret = parent(A)[index.parentindex]
99+
ret
100+
end
101+
102+
@inline function _unsafe_getindex(A::ReshapedArray, indexes::Int...)
103+
@inbounds ret = parent(A)[ind2sub_rs(A.mi, sub2ind(size(A), indexes...))...]
104+
ret
105+
end
106+
@inline function _unsafe_getindex(A::ReshapedArrayLF, indexes::Int...)
107+
@inbounds ret = parent(A)[sub2ind(size(A), indexes...)]
108+
ret
109+
end
90110

91-
@inline _unsafe_getindex(A::ReshapedArray, indexes::Int...) = (@inbounds ret = parent(A)[ind2sub_rs(A.mi, sub2ind(size(A), indexes...))...]; ret)
92-
@inline _unsafe_getindex(A::ReshapedArrayLF, indexes::Int...) = (@inbounds ret = parent(A)[sub2ind(size(A), indexes...)]; ret)
111+
@inline function setindex!(A::ReshapedArrayLF, val, index::Int)
112+
@boundscheck checkbounds(A, index)
113+
@inbounds parent(A)[index] = val
114+
val
115+
end
116+
@inline function setindex!(A::ReshapedArray, val, indexes::Int...)
117+
@boundscheck checkbounds(A, indexes...)
118+
_unsafe_setindex!(A, val, indexes...)
119+
end
120+
@inline function setindex!(A::ReshapedArray, val, index::ReshapedIndex)
121+
@boundscheck checkbounds(parent(A), index.parentindex)
122+
@inbounds parent(A)[index.parentindex] = val
123+
val
124+
end
125+
126+
@inline function _unsafe_setindex!(A::ReshapedArray, val, indexes::Int...)
127+
@inbounds parent(A)[ind2sub_rs(A.mi, sub2ind(size(A), indexes...))...] = val
128+
val
129+
end
130+
@inline function _unsafe_setindex!(A::ReshapedArrayLF, val, indexes::Int...)
131+
@inbounds parent(A)[sub2ind(size(A), indexes...)] = val
132+
val
133+
end
93134

94-
@inline setindex!(A::ReshapedArrayLF, val, index::Int) = (@boundscheck checkbounds(A, index); @inbounds parent(A)[index] = val; val)
95-
@inline setindex!(A::ReshapedArray, val, indexes::Int...) = (@boundscheck checkbounds(A, indexes...); _unsafe_setindex!(A, val, indexes...))
96-
@inline setindex!(A::ReshapedArray, val, index::ReshapedIndex) = (@boundscheck checkbounds(parent(A), index.parentindex); @inbounds parent(A)[index.parentindex] = val; val)
135+
# helpful error message for a common failure case
136+
typealias ReshapedRange{T,N,A<:Range} ReshapedArray{T,N,A,Tuple{}}
137+
setindex!(A::ReshapedRange, val, index::Int) = _rs_setindex!_err()
138+
setindex!(A::ReshapedRange, val, indexes::Int...) = _rs_setindex!_err()
139+
setindex!(A::ReshapedRange, val, index::ReshapedIndex) = _rs_setindex!_err()
97140

98-
@inline _unsafe_setindex!(A::ReshapedArray, val, indexes::Int...) = (@inbounds parent(A)[ind2sub_rs(A.mi, sub2ind(size(A), indexes...))...] = val; val)
99-
@inline _unsafe_setindex!(A::ReshapedArrayLF, val, indexes::Int...) = (@inbounds parent(A)[sub2ind(size(A), indexes...)] = val; val)
141+
_rs_setindex!_err() = error("indexed assignment fails for a reshaped range; consider calling collect")
100142

101143
typealias ArrayT{N, T} Array{T,N}
102144
convert{T,S,N}(::Type{Array{T,N}}, V::ReshapedArray{S,N}) = copy!(Array(T, size(V)), V)

test/arrayops.jl

Lines changed: 39 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -88,13 +88,50 @@ a = reshape(b, (2, 2, 2, 2, 2))
8888
@test a[2,2,2,2,2] == b[end]
8989

9090
# reshaping linearslow arrays
91-
a = zeros(1, 5)
91+
a = collect(reshape(1:5, 1, 5))
9292
s = sub(a, :, [2,3,5])
93-
@test length(reshape(s, length(s))) == 3
93+
r = reshape(s, length(s))
94+
@test length(r) == 3
95+
@test r[1] == 2
96+
@test r[3,1] == 5
97+
@test r[Base.ReshapedIndex(CartesianIndex((1,2)))] == 3
98+
@test parent(reshape(r, (1,3))) === r.parent === s
99+
@test parentindexes(r) == (1:1, 1:3)
100+
@test reshape(r, (3,)) === r
101+
r[2] = -1
102+
@test a[3] == -1
94103
a = zeros(0, 5) # an empty linearslow array
95104
s = sub(a, :, [2,3,5])
96105
@test length(reshape(s, length(s))) == 0
97106

107+
@test reshape(1:5, (5,)) === 1:5
108+
@test reshape(1:5, 5) === 1:5
109+
110+
# setindex! on a reshaped range
111+
a = reshape(1:20, 5, 4)
112+
for idx in ((3,), (2,2), (Base.ReshapedIndex(1),))
113+
try
114+
a[idx...] = 7
115+
catch err
116+
@test err.msg == "indexed assignment fails for a reshaped range; consider calling collect"
117+
end
118+
end
119+
120+
# operations with LinearFast ReshapedArray
121+
b = collect(1:12)
122+
a = Base.ReshapedArray(b, (4,3), ())
123+
@test a[3,2] == 7
124+
@test a[6] == 6
125+
a[3,2] = -2
126+
a[6] = -3
127+
a[Base.ReshapedIndex(5)] = -4
128+
@test b[5] == -4
129+
@test b[6] == -3
130+
@test b[7] == -2
131+
b = reinterpret(Int, a, (3,4))
132+
b[1] = -1
133+
@test vec(b) == vec(a)
134+
98135
a = rand(1, 1, 8, 8, 1)
99136
@test @inferred(squeeze(a, 1)) == @inferred(squeeze(a, (1,))) == reshape(a, (1, 8, 8, 1))
100137
@test @inferred(squeeze(a, (1, 5))) == squeeze(a, (5, 1)) == reshape(a, (1, 8, 8))

test/bitarray.jl

Lines changed: 34 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -324,22 +324,21 @@ t1 = bitrand(n1, n2)
324324
b2 = bitrand(countnz(t1))
325325
@check_bit_operation setindex!(b1, b2, t1) BitMatrix
326326

327-
m1 = rand(1:n1)
328-
m2 = rand(1:n2)
329-
330-
t1 = bitrand(n1)
331-
b2 = bitrand(countnz(t1), m2)
332-
k2 = randperm(m2)
333-
@check_bit_operation setindex!(b1, b2, t1, 1:m2) BitMatrix
334-
@check_bit_operation setindex!(b1, b2, t1, n2-m2+1:n2) BitMatrix
335-
@check_bit_operation setindex!(b1, b2, t1, k2) BitMatrix
336-
337-
t2 = bitrand(n2)
338-
b2 = bitrand(m1, countnz(t2))
339-
k1 = randperm(m1)
340-
@check_bit_operation setindex!(b1, b2, 1:m1, t2) BitMatrix
341-
@check_bit_operation setindex!(b1, b2, n1-m1+1:n1, t2) BitMatrix
342-
@check_bit_operation setindex!(b1, b2, k1, t2) BitMatrix
327+
let m1 = rand(1:n1), m2 = rand(1:n2)
328+
t1 = bitrand(n1)
329+
b2 = bitrand(countnz(t1), m2)
330+
k2 = randperm(m2)
331+
@check_bit_operation setindex!(b1, b2, t1, 1:m2) BitMatrix
332+
@check_bit_operation setindex!(b1, b2, t1, n2-m2+1:n2) BitMatrix
333+
@check_bit_operation setindex!(b1, b2, t1, k2) BitMatrix
334+
335+
t2 = bitrand(n2)
336+
b2 = bitrand(m1, countnz(t2))
337+
k1 = randperm(m1)
338+
@check_bit_operation setindex!(b1, b2, 1:m1, t2) BitMatrix
339+
@check_bit_operation setindex!(b1, b2, n1-m1+1:n1, t2) BitMatrix
340+
@check_bit_operation setindex!(b1, b2, k1, t2) BitMatrix
341+
end
343342

344343
timesofar("indexing")
345344

@@ -1054,23 +1053,25 @@ end
10541053

10551054
## Reductions ##
10561055

1057-
b1 = bitrand(s1, s2, s3, s4)
1058-
m1 = 1
1059-
m2 = 3
1060-
@check_bit_operation maximum(b1, (m1, m2)) BitArray{4}
1061-
@check_bit_operation minimum(b1, (m1, m2)) BitArray{4}
1062-
@check_bit_operation sum(b1, (m1, m2)) Array{Int,4}
1063-
1064-
@check_bit_operation maximum(b1) Bool
1065-
@check_bit_operation minimum(b1) Bool
1066-
@check_bit_operation any(b1) Bool
1067-
@check_bit_operation all(b1) Bool
1068-
@check_bit_operation sum(b1) Int
1069-
1070-
b0 = falses(0)
1071-
@check_bit_operation any(b0) Bool
1072-
@check_bit_operation all(b0) Bool
1073-
@check_bit_operation sum(b0) Int
1056+
let
1057+
b1 = bitrand(s1, s2, s3, s4)
1058+
m1 = 1
1059+
m2 = 3
1060+
@check_bit_operation maximum(b1, (m1, m2)) BitArray{4}
1061+
@check_bit_operation minimum(b1, (m1, m2)) BitArray{4}
1062+
@check_bit_operation sum(b1, (m1, m2)) Array{Int,4}
1063+
1064+
@check_bit_operation maximum(b1) Bool
1065+
@check_bit_operation minimum(b1) Bool
1066+
@check_bit_operation any(b1) Bool
1067+
@check_bit_operation all(b1) Bool
1068+
@check_bit_operation sum(b1) Int
1069+
1070+
b0 = falses(0)
1071+
@check_bit_operation any(b0) Bool
1072+
@check_bit_operation all(b0) Bool
1073+
@check_bit_operation sum(b0) Int
1074+
end
10741075

10751076
timesofar("reductions")
10761077

0 commit comments

Comments
 (0)