|
| 1 | +module CompatCartesian |
| 2 | + |
| 3 | +export @ngenerate, @nsplat |
| 4 | + |
| 5 | +macro ngenerate(itersym, returntypeexpr, funcexpr) |
| 6 | + if isa(funcexpr, Expr) && funcexpr.head == :macrocall && funcexpr.args[1] == symbol("@inline") |
| 7 | + funcexpr = Base._inline(funcexpr.args[2]) |
| 8 | + end |
| 9 | + isfuncexpr(funcexpr) || error("Requires a function expression") |
| 10 | + esc(_ngenerate(itersym, funcexpr)) |
| 11 | +end |
| 12 | + |
| 13 | +function _ngenerate(itersym::Symbol, funcexpr::Expr) |
| 14 | + prototype = funcexpr.args[1] |
| 15 | + body = funcexpr.args[2] |
| 16 | + varname, T = get_splatinfo(prototype, itersym) |
| 17 | + ex = Expr(:$, itersym) |
| 18 | + sreplace!(body, itersym, ex) |
| 19 | + if !isempty(varname) |
| 20 | + prototype, body = _nsplat(prototype, body, varname, T, itersym) |
| 21 | + else |
| 22 | + body = Expr(:quote, body) |
| 23 | + end |
| 24 | + Expr(:stagedfunction, prototype, body) |
| 25 | +end |
| 26 | + |
| 27 | +macro nsplat(itersym, args...) |
| 28 | + if length(args) == 1 |
| 29 | + funcexpr = args[1] |
| 30 | + elseif length(args) == 2 |
| 31 | + funcexpr = args[2] |
| 32 | + else |
| 33 | + error("Wrong number of arguments") |
| 34 | + end |
| 35 | + if isa(funcexpr, Expr) && funcexpr.head == :macrocall && funcexpr.args[1] == symbol("@inline") |
| 36 | + funcexpr = Base._inline(funcexpr.args[2]) |
| 37 | + end |
| 38 | + isfuncexpr(funcexpr) || error("Second argument must be a function expression") |
| 39 | + prototype = funcexpr.args[1] |
| 40 | + body = funcexpr.args[2] |
| 41 | + varname, T = get_splatinfo(prototype, itersym) |
| 42 | + isempty(varname) && error("Last argument must be a splat") |
| 43 | + prototype, body = _nsplat(prototype, body, varname, T, itersym) |
| 44 | + esc(Expr(:stagedfunction, prototype, body)) |
| 45 | +end |
| 46 | + |
| 47 | +function _nsplat(prototype, body, varname, T, itersym) |
| 48 | + varsym = symbol(varname) |
| 49 | + prototype.args[end] = Expr(:..., Expr(:(::), varsym, T)) # :($varsym::$T...) |
| 50 | + varquot = Expr(:quote, varsym) |
| 51 | + bodyquot = Expr(:quote, body) |
| 52 | + newbody = quote |
| 53 | + $itersym = length($varsym) |
| 54 | + Compat.CompatCartesian.resolvesplats!($bodyquot, $varquot, $itersym) |
| 55 | + end |
| 56 | + prototype, newbody |
| 57 | +end |
| 58 | + |
| 59 | +# If using the syntax that will need "desplatting", |
| 60 | +# myfunction(A::AbstractArray, I::NTuple{N, Int}...) |
| 61 | +# return the variable name (as a string) and type |
| 62 | +function get_splatinfo(ex::Expr, itersym::Symbol) |
| 63 | + if ex.head == :call |
| 64 | + a = ex.args[end] |
| 65 | + if isa(a, Expr) && a.head == :... && length(a.args) == 1 |
| 66 | + b = a.args[1] |
| 67 | + if isa(b, Expr) && b.head == :(::) |
| 68 | + varname = string(b.args[1]) |
| 69 | + c = b.args[2] |
| 70 | + if isa(c, Expr) && c.head == :curly && c.args[1] == :NTuple && c.args[2] == itersym |
| 71 | + T = c.args[3] |
| 72 | + return varname, T |
| 73 | + end |
| 74 | + end |
| 75 | + end |
| 76 | + end |
| 77 | + "", Void |
| 78 | +end |
| 79 | + |
| 80 | +resolvesplats!(arg, varname, N) = arg |
| 81 | +function resolvesplats!(ex::Expr, varname, N::Int) |
| 82 | + if ex.head == :call |
| 83 | + for i = 2:length(ex.args)-1 |
| 84 | + resolvesplats!(ex.args[i], varname, N) |
| 85 | + end |
| 86 | + a = ex.args[end] |
| 87 | + if isa(a, Expr) && a.head == :... && a.args[1] == symbol(varname) |
| 88 | + ex.args[end] = :($varname[1]) # Expr(:ref, varname, 1) |
| 89 | + for i = 2:N |
| 90 | + push!(ex.args, :($varname[$i])) # Expr(:ref, varname, i)) |
| 91 | + end |
| 92 | + else |
| 93 | + resolvesplats!(a, varname, N) |
| 94 | + end |
| 95 | + else |
| 96 | + for i = 1:length(ex.args) |
| 97 | + resolvesplats!(ex.args[i], varname, N) |
| 98 | + end |
| 99 | + end |
| 100 | + ex |
| 101 | +end |
| 102 | + |
| 103 | +isfuncexpr(ex::Expr) = |
| 104 | + ex.head == :function || (ex.head == :(=) && typeof(ex.args[1]) == Expr && ex.args[1].head == :call) |
| 105 | +isfuncexpr(arg) = false |
| 106 | + |
| 107 | +sreplace!(arg, sym, val) = arg |
| 108 | +function sreplace!(ex::Expr, sym, val) |
| 109 | + for i = 1:length(ex.args) |
| 110 | + ex.args[i] = sreplace!(ex.args[i], sym, val) |
| 111 | + end |
| 112 | + ex |
| 113 | +end |
| 114 | +sreplace!(s::Symbol, sym, val) = s == sym ? val : s |
| 115 | + |
| 116 | +# If using the syntax that will need "desplatting", |
| 117 | +# myfunction(A::AbstractArray, I::NTuple{N, Int}...) |
| 118 | +# return the variable name (as a string) and type |
| 119 | +function get_splatinfo(ex::Expr, itersym::Symbol) |
| 120 | + if ex.head == :call |
| 121 | + a = ex.args[end] |
| 122 | + if isa(a, Expr) && a.head == :... && length(a.args) == 1 |
| 123 | + b = a.args[1] |
| 124 | + if isa(b, Expr) && b.head == :(::) |
| 125 | + varname = string(b.args[1]) |
| 126 | + c = b.args[2] |
| 127 | + if isa(c, Expr) && c.head == :curly && c.args[1] == :NTuple && c.args[2] == itersym |
| 128 | + T = c.args[3] |
| 129 | + return varname, T |
| 130 | + end |
| 131 | + end |
| 132 | + end |
| 133 | + end |
| 134 | + "", Void |
| 135 | +end |
| 136 | + |
| 137 | +isfuncexpr(ex::Expr) = |
| 138 | + ex.head == :function || (ex.head == :(=) && typeof(ex.args[1]) == Expr && ex.args[1].head == :call) |
| 139 | +isfuncexpr(arg) = false |
| 140 | + |
| 141 | +end |
0 commit comments