Skip to content

Commit 4d6e516

Browse files
committed
Don't allow reinterprets that would expose padding
In #25908 it was noted that reinterpreting structures with padding exposes undef LLVM values to user code. This is problematic, because an LLVM undef value is quite dangerous (it can have a different value at every use, e.g. for `a::Bool` undef, we can have `a || !a == true`). There are proposals in LLVM to create values that are merely arbitrary (but the same at every use), but that capability does not currently exist in LLVM. As such, we should try hard to prevent `undef` from showing up in a user-visible way. There are several ways to fix this: 1. Wait until LLVM comes up with a safer `undef` and have the value merely be arbitrary, but not dangerous. 2. Always guarantee that padding bytes will be 0. 3. For contiguous-memory arrays, guarantee that we end up with the underlying bytes from that array. However, for now, I don't think we should make a choice here. Issues like #21912 may play into the consideration, and I think we should be able to reserve making a choice until that point. So what this PR does is only allow reinterprets when they would not expose padding. This should hopefully cover the most common use cases of reinterpret: - Reinterpreting a vector or matrix of values to StaticVectors of the same element type. These should generally always have compatible padding (if not, reinterpret was likely the wrong API to use). - Reinterpreting from a Vector{UInt8} to a vector of structs (that may have padding). This PR allows this for reading (but not for writing). Both cases are generally better served by the IO APIs, but hopefully this should still allow the common cases. Fixes #25908
1 parent c670739 commit 4d6e516

5 files changed

Lines changed: 147 additions & 6 deletions

File tree

base/atomics.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -337,7 +337,7 @@ inttype(::Type{Float32}) = Int32
337337
inttype(::Type{Float64}) = Int64
338338

339339

340-
alignment(::Type{T}) where {T} = ccall(:jl_alignment, Cint, (Csize_t,), sizeof(T))
340+
gc_alignment(::Type{T}) where {T} = ccall(:jl_alignment, Cint, (Csize_t,), sizeof(T))
341341

342342
# All atomic operations have acquire and/or release semantics, depending on
343343
# whether the load or store values. Most of the time, this is what one wants
@@ -350,13 +350,13 @@ for typ in atomictypes
350350
@eval getindex(x::Atomic{$typ}) =
351351
llvmcall($"""
352352
%ptr = inttoptr i$WORD_SIZE %0 to $lt*
353-
%rv = load atomic $rt %ptr acquire, align $(alignment(typ))
353+
%rv = load atomic $rt %ptr acquire, align $(gc_alignment(typ))
354354
ret $lt %rv
355355
""", $typ, Tuple{Ptr{$typ}}, unsafe_convert(Ptr{$typ}, x))
356356
@eval setindex!(x::Atomic{$typ}, v::$typ) =
357357
llvmcall($"""
358358
%ptr = inttoptr i$WORD_SIZE %0 to $lt*
359-
store atomic $lt %1, $lt* %ptr release, align $(alignment(typ))
359+
store atomic $lt %1, $lt* %ptr release, align $(gc_alignment(typ))
360360
ret void
361361
""", Cvoid, Tuple{Ptr{$typ}, $typ}, unsafe_convert(Ptr{$typ}, x), v)
362362

base/iterators.jl

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1030,6 +1030,9 @@ mutable struct Stateful{T, VS}
10301030
# A bit awkward right now, but adapted to the new iteration protocol
10311031
nextvalstate::Union{VS, Nothing}
10321032
taken::Int
1033+
@inline function Stateful{<:Any, Any}(itr::T) where {T}
1034+
new{T, Any}(itr, iterate(itr), 0)
1035+
end
10331036
@inline function Stateful(itr::T) where {T}
10341037
VS = approx_iter_type(T)
10351038
new{T, VS}(itr, iterate(itr)::VS, 0)

base/reinterpretarray.jl

Lines changed: 121 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,8 @@ the first dimension.
77
"""
88
struct ReinterpretArray{T,N,S,A<:AbstractArray{S, N}} <: AbstractArray{T, N}
99
parent::A
10+
readable::Bool
11+
writable::Bool
1012
global reinterpret
1113
function reinterpret(::Type{T}, a::A) where {T,N,S,A<:AbstractArray{S, N}}
1214
function throwbits(::Type{S}, ::Type{T}, ::Type{U}) where {S,T,U}
@@ -31,10 +33,32 @@ struct ReinterpretArray{T,N,S,A<:AbstractArray{S, N}} <: AbstractArray{T, N}
3133
dim = size(a)[1]
3234
rem(dim*sizeof(S),sizeof(T)) == 0 || thrownonint(S, T, dim)
3335
end
34-
new{T, N, S, A}(a)
36+
readable = array_subpadding(T, S)
37+
writable = array_subpadding(S, T)
38+
new{T, N, S, A}(a, readable, writable)
3539
end
3640
end
3741

42+
function check_readable(a::ReinterpretArray{T, N, S} where N) where {T,S}
43+
# See comment in check_writable
44+
if !a.readable && !array_subpadding(T, S)
45+
throw(PaddingError(T, S))
46+
end
47+
end
48+
49+
function check_writable(a::ReinterpretArray{T, N, S} where N) where {T,S}
50+
# `array_subpadding` is relatively expensive (compared to a simple arrayref),
51+
# so it is cached in the array. However, it is computable at compile time if
52+
# inference has the types available. By using this form of the check, we can
53+
# get the best of both worlds for the success case. If the types were not
54+
# available to inference, we simply need to check the field (relatively cheap)
55+
# and if they were we should be able to fold this check away entirely.
56+
if !a.writable && !array_subpadding(S, T)
57+
throw(PaddingError(T, S))
58+
end
59+
end
60+
61+
3862
parent(a::ReinterpretArray) = a.parent
3963
dataids(a::ReinterpretArray) = dataids(a.parent)
4064

@@ -51,6 +75,7 @@ unsafe_convert(::Type{Ptr{T}}, a::ReinterpretArray{T,N,S} where N) where {T,S} =
5175
@inline @propagate_inbounds getindex(a::ReinterpretArray) = a[1]
5276

5377
@inline @propagate_inbounds function getindex(a::ReinterpretArray{T,N,S}, inds::Vararg{Int, N}) where {T,N,S}
78+
check_readable(a)
5479
# Make sure to match the scalar reinterpret if that is applicable
5580
if sizeof(T) == sizeof(S) && (fieldcount(T) + fieldcount(S)) == 0
5681
return reinterpret(T, a.parent[inds...])
@@ -85,6 +110,7 @@ end
85110
@inline @propagate_inbounds setindex!(a::ReinterpretArray, v) = (a[1] = v)
86111

87112
@inline @propagate_inbounds function setindex!(a::ReinterpretArray{T,N,S}, v, inds::Vararg{Int, N}) where {T,N,S}
113+
check_writable(a)
88114
v = convert(T, v)::T
89115
# Make sure to match the scalar reinterpret if that is applicable
90116
if sizeof(T) == sizeof(S) && (fieldcount(T) + fieldcount(S)) == 0
@@ -136,3 +162,97 @@ end
136162
end
137163
return a
138164
end
165+
166+
# Padding
167+
struct Padding
168+
offset::Int
169+
size::Int
170+
end
171+
function intersect(p1::Padding, p2::Padding)
172+
start = max(p1.offset, p2.offset)
173+
stop = min(p1.offset + p1.size, p2.offset + p2.size)
174+
Padding(start, max(0, stop-start))
175+
end
176+
177+
struct PaddingError
178+
S::Type
179+
T::Type
180+
end
181+
182+
function showerror(io::IO, p::PaddingError)
183+
print(io, "Padding of type $(p.S) is not compatible with type $(p.T).")
184+
end
185+
186+
"""
187+
CyclePadding(padding, total_size)
188+
189+
Cycles an iterator of `Padding` structs, restarting the padding at `total_size`.
190+
E.g. if `padding` is all the padding in a struct and `total_size` is the total
191+
aligned size of that struct, `CyclePadding` will correspond to the padding in an
192+
infinite vector of such structs.
193+
"""
194+
struct CyclePadding{P}
195+
padding::P
196+
total_size::Int
197+
end
198+
eltype(::Type{<:CyclePadding}) = Padding
199+
IteratorSize(::Type{<:CyclePadding}) = IsInfinite()
200+
isempty(cp::CyclePadding) = isempty(cp.padding)
201+
function iterate(cp::CyclePadding)
202+
y = iterate(cp.padding)
203+
y === nothing && return nothing
204+
y[1], (0, y[2])
205+
end
206+
function iterate(cp::CyclePadding, state::Tuple)
207+
y = iterate(cp.padding, tail(state)...)
208+
y === nothing && return iterate(cp, (state[1]+cp.total_size,))
209+
Padding(y[1].offset+state[1], y[1].size), (state[1], tail(y)...)
210+
end
211+
212+
"""
213+
Compute the location of padding in a type.
214+
"""
215+
function padding(T)
216+
padding = Padding[]
217+
last_end::Int = 0
218+
for i = 1:fieldcount(T)
219+
offset = fieldoffset(T, i)
220+
fT = fieldtype(T, i)
221+
if offset != last_end
222+
push!(padding, Padding(offset, offset-last_end))
223+
end
224+
last_end = offset + sizeof(fT)
225+
end
226+
padding
227+
end
228+
229+
function CyclePadding(T::DataType)
230+
a, s = datatype_alignment(T), sizeof(T)
231+
as = s + (a - (s % a)) % a
232+
pad = padding(T)
233+
s != as && push!(pad, Padding(s, as - s))
234+
CyclePadding(pad, as)
235+
end
236+
237+
using .Iterators: Stateful
238+
@pure function array_subpadding(S, T)
239+
checked_size = 0
240+
lcm_size = lcm(sizeof(S), sizeof(T))
241+
s, t = Stateful{<:Any, Any}(CyclePadding(S)),
242+
Stateful{<:Any, Any}(CyclePadding(T))
243+
isempty(t) && return true
244+
isempty(s) && return false
245+
while checked_size < lcm_size
246+
# Take padding in T
247+
pad = popfirst!(t)
248+
# See if there's corresponding padding in S
249+
while true
250+
ps = peek(s)
251+
ps.offset > pad.offset && return false
252+
intersect(ps, pad) == pad && break
253+
popfirst!(s)
254+
end
255+
checked_size = pad.offset + pad.size
256+
end
257+
return true
258+
end

base/sysimg.jl

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -137,8 +137,6 @@ include("array.jl")
137137
include("abstractarray.jl")
138138
include("subarray.jl")
139139
include("views.jl")
140-
include("reinterpretarray.jl")
141-
142140

143141
# ## dims-type-converting Array constructors for convenience
144142
# type and dimensionality specified, accepting dims as series of Integers
@@ -205,6 +203,7 @@ include("reduce.jl")
205203

206204
## core structures
207205
include("reshapedarray.jl")
206+
include("reinterpretarray.jl")
208207
include("bitarray.jl")
209208
include("bitset.jl")
210209

test/reinterpretarray.jl

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -49,3 +49,22 @@ let A = collect(reshape(1:20, 5, 4))
4949
@test view(R, :, :) isa StridedArray
5050
@test reshape(R, :) isa StridedArray
5151
end
52+
53+
# Error on reinterprets that would expose padding
54+
struct S1
55+
a::Int8
56+
b::Int64
57+
end
58+
59+
struct S2
60+
a::Int16
61+
b::Int64
62+
end
63+
64+
A1 = S1[S1(0, 0)]
65+
A2 = S2[S2(0, 0)]
66+
@test reinterpret(S1, A2)[1] == S1(0, 0)
67+
@test_throws Base.PaddingError (reinterpret(S1, A2)[1] = S2(1, 2))
68+
@test_throws Base.PaddingError reinterpret(S2, A1)[1]
69+
reinterpret(S2, A1)[1] = S2(1, 2)
70+
@test A1[1] == S1(1, 2)

0 commit comments

Comments
 (0)