Skip to content

Commit 6fc7b0a

Browse files
committed
Make foreachfield a static function
This allows `setindex!` on `StructArray`s to be used in `CUDAnative` kernels.
1 parent 9a5918f commit 6fc7b0a

File tree

1 file changed

+37
-7
lines changed

1 file changed

+37
-7
lines changed

src/utils.jl

Lines changed: 37 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -28,20 +28,50 @@ else
2828
const _getproperty = getproperty
2929
end
3030

31-
function _foreachfield(names, xs)
31+
function _sstuple(::Type{<:NTuple{N, Any}}) where {N}
32+
ntuple(j->Symbol(j), N)
33+
end
34+
35+
function _sstuple(::Type{NT}) where {NT<:NamedTuple}
36+
_map_params(x->_sstuple(staticschema(x)), NT)
37+
end
38+
39+
function _getcolproperties!(exprs, s, es=[])
40+
if typeof(s) <: Symbol
41+
push!(exprs, es)
42+
return
43+
end
44+
for key in keys(s)
45+
_getcolproperties!(exprs, getproperty(s,key), vcat(es, key))
46+
end
47+
end
48+
49+
@generated function foreachfield(::Type{T}, f, xs...) where {T<:Tup}
50+
# TODO get columnsproperties directly from T without converting to the
51+
# tuple s.
52+
s = _sstuple(T)
53+
columnsproperties = []
54+
_getcolproperties!(columnsproperties, s)
55+
3256
exprs = Expr[]
33-
for field in names
34-
sym = QuoteNode(field)
35-
args = [Expr(:call, :_getproperty, :(getfield(xs, $j)), sym) for j in 1:length(xs)]
57+
for col in columnsproperties
58+
args = Expr[]
59+
for prop in col
60+
sym = QuoteNode(prop)
61+
if length(args) == 0
62+
args = [Expr(:call, :_getproperty, :(getfield(xs, $j)), sym) for j in 1:length(xs)]
63+
else
64+
for j in 1:length(xs)
65+
args[j] = Expr(:call, :_getproperty, args[j], sym)
66+
end
67+
end
68+
end
3669
push!(exprs, Expr(:call, :f, args...))
3770
end
3871
push!(exprs, :(return nothing))
3972
return Expr(:block, exprs...)
4073
end
4174

42-
@generated foreachfield(::Type{<:NamedTuple{names}}, f, xs...) where {names} = _foreachfield(names, xs)
43-
@generated foreachfield(::Type{<:NTuple{N, Any}}, f, xs...) where {N} = _foreachfield(Base.OneTo(N), xs)
44-
4575
foreachfield(f, x::T, xs...) where {T} = foreachfield(staticschema(T), f, x, xs...)
4676

4777
"""

0 commit comments

Comments
 (0)