Skip to content

Commit cd1880b

Browse files
authored
Revert "SROA: generalize unswitchtupleunion optimization (#50502)"
This reverts commit 3995278.
1 parent 5baaafd commit cd1880b

File tree

3 files changed

+21
-59
lines changed

3 files changed

+21
-59
lines changed

base/compiler/ssair/passes.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1107,8 +1107,8 @@ function sroa_pass!(ir::IRCode, inlining::Union{Nothing,InliningState}=nothing)
11071107
end
11081108
struct_typ = widenconst(argextype(val, compact))
11091109
struct_typ_unwrapped = unwrap_unionall(struct_typ)
1110-
if isa(struct_typ, Union)
1111-
struct_typ_unwrapped = unswitchtypeunion(struct_typ_unwrapped)
1110+
if isa(struct_typ, Union) && struct_typ <: Tuple
1111+
struct_typ_unwrapped = unswitchtupleunion(struct_typ_unwrapped)
11121112
end
11131113
if isa(struct_typ_unwrapped, Union) && is_isdefined
11141114
lift_comparison!(isdefined, compact, idx, stmt, lifting_cache, 𝕃ₒ)

base/compiler/typeutils.jl

Lines changed: 19 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -317,40 +317,33 @@ function unionall_depth(@nospecialize ua) # aka subtype_env_size
317317
return depth
318318
end
319319

320-
# convert a Union of same `UnionAll` types to the `UnionAll` type whose parameter is the Unions
320+
# convert a Union of Tuple types to a Tuple of Unions
321+
unswitchtupleunion(u::Union) = unswitchtypeunion(u, Tuple.name)
322+
321323
function unswitchtypeunion(u::Union, typename::Union{Nothing,Core.TypeName}=nothing)
322324
ts = uniontypes(u)
323325
n = -1
324326
for t in ts
325-
t isa DataType || return u
326-
if typename === nothing
327-
typename = t.name
328-
elseif typename !== t.name
329-
return u
330-
end
331-
params = t.parameters
332-
np = length(params)
333-
if np == 0 || isvarargtype(params[end])
334-
return u
335-
end
336-
if n == -1
337-
n = np
338-
elseif n np
327+
if t isa DataType
328+
if typename === nothing
329+
typename = t.name
330+
elseif typename !== t.name
331+
return u
332+
end
333+
if length(t.parameters) != 0 && !isvarargtype(t.parameters[end])
334+
if n == -1
335+
n = length(t.parameters)
336+
elseif n != length(t.parameters)
337+
return u
338+
end
339+
end
340+
else
339341
return u
340342
end
341343
end
342344
Head = (typename::Core.TypeName).wrapper
343-
hparams = Any[]
344-
for i = 1:n
345-
uparams = Any[]
346-
for t in ts
347-
tpᵢ = (t::DataType).parameters[i]
348-
tpᵢ isa Type || return u
349-
push!(uparams, tpᵢ)
350-
end
351-
push!(hparams, Union{uparams...})
352-
end
353-
return Head{hparams...}
345+
unionparams = Any[ Union{Any[(t::DataType).parameters[i] for t in ts]...} for i in 1:n ]
346+
return Head{unionparams...}
354347
end
355348

356349
function unwraptv_ub(@nospecialize t)

test/compiler/irpasses.jl

Lines changed: 0 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -1390,34 +1390,3 @@ function wrap1_wrap1_wrapper(b, x, y)
13901390
end
13911391
@test wrap1_wrap1_wrapper(true, 1, 1.0) === 1.0
13921392
@test wrap1_wrap1_wrapper(false, 1, 1.0) === 1
1393-
1394-
# Test unswitching-union optimization within SRO Apass
1395-
function sroaunswitchuniontuple(c, x1, x2)
1396-
t = c ? (x1,) : (x2,)
1397-
return getfield(t, 1)
1398-
end
1399-
struct SROAUnswitchUnion1{T}
1400-
x::T
1401-
end
1402-
struct SROAUnswitchUnion2{S,T}
1403-
x::T
1404-
@inline SROAUnswitchUnion2{S}(x::T) where {S,T} = new{S,T}(x)
1405-
end
1406-
function sroaunswitchunionstruct1(c, x1, x2)
1407-
x = c ? SROAUnswitchUnion1(x1) : SROAUnswitchUnion1(x2)
1408-
return getfield(x, :x)
1409-
end
1410-
function sroaunswitchunionstruct2(c, x1, x2)
1411-
x = c ? SROAUnswitchUnion2{:a}(x1) : SROAUnswitchUnion2{:a}(x2)
1412-
return getfield(x, :x)
1413-
end
1414-
let src = code_typed1(sroaunswitchuniontuple, Tuple{Bool, Int, Float64})
1415-
@test count(isnew, src.code) == 0
1416-
@test count(iscall((src, getfield)), src.code) == 0
1417-
end
1418-
let src = code_typed1(sroaunswitchunionstruct1, Tuple{Bool, Int, Float64})
1419-
@test count(isnew, src.code) == 0
1420-
@test count(iscall((src, getfield)), src.code) == 0
1421-
end
1422-
@test sroaunswitchunionstruct2(true, 1, 1.0) === 1
1423-
@test sroaunswitchunionstruct2(false, 1, 1.0) === 1.0

0 commit comments

Comments
 (0)