Skip to content

Commit 48a86d6

Browse files
committed
optimized mapreduce using sub group shuffle
ref #352. Unfortunately, I don't really see any performance improvement with this. Any ideas why? I expected it to be quite a bit faster.
1 parent 7c4881d commit 48a86d6

File tree

2 files changed

+115
-1
lines changed

2 files changed

+115
-1
lines changed

lib/intrinsics/src/work_item.jl

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,50 @@ for (julia_name, (spirv_name, julia_type, offset)) in [
3434
end
3535
end
3636

37+
38+
# Sub-group shuffle intrinsics, generated with a loop and @eval to match the style of the neighbouring 1D/3D value loops
39+
export sub_group_shuffle, sub_group_shuffle_xor
40+
41+
# Generate one `sub_group_shuffle` and one `sub_group_shuffle_xor` method per scalar type.
# Each tuple is (Julia type used for dispatch, LLVM IR type name, Julia type name as a
# Symbol that is spliced into the `llvmcall` return/argument type positions).
#
# Both functions are exported once, by the `export` statement preceding this loop, so the
# generated code does not repeat the export on every iteration.
for (jltype, llvmtype, julia_type_str) in [
        (Int8,    "i8",     :Int8),
        (UInt8,   "i8",     :UInt8),
        (Int16,   "i16",    :Int16),
        (UInt16,  "i16",    :UInt16),
        (Int32,   "i32",    :Int32),
        (UInt32,  "i32",    :UInt32),
        (Int64,   "i64",    :Int64),
        (UInt64,  "i64",    :UInt64),
        (Float16, "half",   :Float16),
        (Float32, "float",  :Float32),
        (Float64, "double", :Float64),
    ]
    @eval begin
        # Return the value of `x` held by the sub-group invocation selected by `idx`.
        # `idx` is truncated to `Int32` (wrapping, never throwing) and decremented by one
        # before being handed to the builtin, i.e. callers pass a 1-based lane index.
        # The leading `i32 3` operand is the execution scope (3 is `Subgroup` in the
        # SPIR-V Scope enumeration).
        function sub_group_shuffle(x::$jltype, idx::Integer)
            return Base.llvmcall(
                $("""
                declare $llvmtype @__spirv_GroupNonUniformShuffle(i32, $llvmtype, i32)
                define $llvmtype @entry($llvmtype %val, i32 %idx) #0 {
                    %res = call $llvmtype @__spirv_GroupNonUniformShuffle(i32 3, $llvmtype %val, i32 %idx)
                    ret $llvmtype %res
                }
                attributes #0 = { alwaysinline }
                """, "entry"), $julia_type_str, Tuple{$julia_type_str, Int32}, x, idx % Int32 - 1i32)
        end

        # Return the value of `x` held by the sub-group invocation whose id is this
        # invocation's id XOR `mask`. The mask is passed through unchanged (no 1-based
        # adjustment), only truncated to 32 bits.
        function sub_group_shuffle_xor(x::$jltype, mask::Integer)
            return Base.llvmcall(
                $("""
                declare $llvmtype @__spirv_GroupNonUniformShuffleXor(i32, $llvmtype, i32)
                define $llvmtype @entry($llvmtype %val, i32 %mask) #0 {
                    %res = call $llvmtype @__spirv_GroupNonUniformShuffleXor(i32 3, $llvmtype %val, i32 %mask)
                    ret $llvmtype %res
                }
                attributes #0 = { alwaysinline }
                # `llvmcall` performs no argument conversion, so the value must already be
                # an `Int32` to match the declared `Tuple{..., Int32}` signature; the
                # wrapping `% Int32` keeps the same 32 bits a `UInt32` would carry.
                """, "entry"), $julia_type_str, Tuple{$julia_type_str, Int32}, x, mask % Int32)
        end
    end
end
80+
3781
# 3D values
3882
for (julia_name, (spirv_name, offset)) in [
3983
# indices

src/mapreduce.jl

Lines changed: 71 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,8 +5,78 @@
55
# - group-stride loop to delay need for second kernel launch
66
# - let the driver choose the local size
77

8+
# Build an expression that fetches a value of type `T` from another sub-group lane.
#
# The returned expression is meant to be spliced into generated code: it refers to a
# variable named `val` (the value to shuffle) and a variable named `i` (the lane index
# handed to `sub_group_shuffle`), both of which must exist at the splice site.
#
# - For types listed in `SPIRVIntrinsics.generic_integer_types` or
#   `SPIRVIntrinsics.generic_types`, this is a single `sub_group_shuffle(val, i)` call.
# - For struct types, every field is shuffled recursively and the pieces are reassembled
#   with a `:new` expression.
# - For anything else, or for a struct containing a field that cannot be shuffled,
#   `nothing` is returned so the caller can pick another strategy.
function shuffle_expr(::Type{T}) where {T}
    has_intrinsic = T in SPIRVIntrinsics.generic_integer_types ||
                    T in SPIRVIntrinsics.generic_types
    has_intrinsic && return :(sub_group_shuffle(val, i))

    Base.isstructtype(T) || return nothing

    rebuilt = Expr(:new, T)
    for name in fieldnames(T)
        field_ex = shuffle_expr(fieldtype(T, name))
        # a single unsupported field makes the whole struct unsupported
        field_ex === nothing && return nothing
        # rebind `val` to the field so the nested expression shuffles just that field
        push!(rebuilt.args, :(let val = getfield(val, $(QuoteNode(name)))
            $field_ex
        end))
    end
    return rebuilt
end
25+
26+
# Reduce `val` across a work-group with `op`, using sub-group shuffles within each
# sub-group and local memory to combine the per-sub-group partial results.
#
# Arguments:
# - `op`:       binary reduction operator (applied as `op(lower_lane, higher_lane)`).
# - `val`:      this work-item's contribution, of type `T`.
# - `neutral`:  value substituted when a tree-reduction partner slot does not exist.
# - `Val{maxitems}`: static size of the local-memory scratch array.
#
# When `T` cannot be shuffled (`shuffle_expr(T) === nothing`) the call is forwarded to
# `reduce_group_fallback` unchanged.
#
# The fully reduced value is only loaded on the work-item with `lane == 1` in the
# sub-group with `item == 1`; every other work-item returns a partial value.
@inline @generated function reduce_group(op, val::T, neutral, ::Val{maxitems}) where {T, maxitems}
    ex = shuffle_expr(T)
    if ex === nothing
        return :(reduce_group_fallback(op, val, neutral, Val(maxitems)))
    end

    quote
        # ids are treated as 1-based throughout (see the `== 1` checks below)
        lane = get_sub_group_local_id()
        width = get_sub_group_size()

        # Phase 1: shuffle-down tree reduction inside the sub-group, so that the result
        # ends up on lane 1, which is the lane that publishes it below.
        # Every lane executes the shuffle unconditionally: a shuffle whose source lane is
        # inactive yields an undefined value, so it must not sit in a divergent branch.
        offset = 1
        while offset < width
            src = lane + offset
            in_range = src <= width
            # lanes without a partner read from themselves and discard the result
            i = in_range ? src : oftype(src, lane)
            other = $ex
            if in_range
                val = op(val, other)
            end
            offset <<= 1
        end

        items = get_num_sub_groups()
        item = get_sub_group_id()

        # Phase 2: combine one partial result per sub-group through local memory.
        shared = CLLocalArray(T, (maxitems,))
        # `items` is the same for every work-item of the group, so this branch and the
        # loop trip count are uniform and all work-items reach each barrier together.
        if items > 1
            if lane == 1
                @inbounds shared[item] = val
            end

            d = 1
            while d < items
                work_group_barrier(LOCAL_MEM_FENCE)
                # only the first lane of each sub-group takes part in the tree reduction
                if lane == 1
                    index = 2 * d * (item-1) + 1
                    @inbounds if index <= items
                        other_val = if index + d <= items
                            shared[index+d]
                        else
                            neutral
                        end
                        shared[index] = op(shared[index], other_val)
                    end
                end
                d *= 2
            end

            # slot 1 was last written by this very work-item, so no extra barrier is needed
            if lane == 1 && item == 1
                val = @inbounds shared[item]
            end
        end

        return val
    end
end
77+
878
# Reduce a value across a group, using local memory for communication
9-
@inline function reduce_group(op, val::T, neutral, ::Val{maxitems}) where {T, maxitems}
79+
@inline function reduce_group_fallback(op, val::T, neutral, ::Val{maxitems}) where {T, maxitems}
1080
items = get_local_size()
1181
item = get_local_id()
1282

0 commit comments

Comments
 (0)