Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 18 additions & 9 deletions src/faster_code.jl
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ mutable struct CodegenState{T}
In such a case, if we assume that the index of `x + y` in `ir` is `4`, then
`cache[4] == Symbol("%3")`.
"""
const cache::Dictionary{Int32, Symbol}
const cache::Dictionary{Int32, Union{Symbol, Int, Float64, Float32}}
"""
Rewrite rules, similar to `NameState`.
"""
Expand All @@ -78,7 +78,7 @@ written to `block` and `ir` is the underlying `IRStructure`. Rewrite rules can o
be supplied as the last argument.
"""
function CodegenState(expr::Expr, block::Expr, ir::IRStructure{T}, rewrites = Dict()) where {T}
CodegenState{T}(expr, block, ir, Dictionary{Int32, Symbol}(), rewrites, 0)
CodegenState{T}(expr, block, ir, Dictionary{Int32, Union{Symbol, Int, Float64, Float32}}(), rewrites, 0)
end

"""
Expand Down Expand Up @@ -199,7 +199,7 @@ function enter_scope(cs::CodegenState{T}) where {T}
new_scope = Expr(:block)
bm = bookmark(cs)
scoped_cs = CodegenState{T}(new_scope, new_scope, cs.ir, cs.cache, cs.rewrites, cs.misc_idx)

return scoped_cs, bm
end

Expand Down Expand Up @@ -312,7 +312,10 @@ end
function fast_toexpr(sym::CodegenPrimitive, ir::IRStructure{T}, rewrites::Dict{Any, Any}) where {T}
expr = block = Expr(:block)
state = CodegenState(expr, block, ir, rewrites)
lhs = state(sym)::Symbol
lhs = state(sym)
if !(lhs isa Symbol)
return lhs
end
for line in expr.args
if Meta.isexpr(line, :(=)) && line.args[1] === lhs
return line.args[2]
Expand Down Expand Up @@ -535,7 +538,7 @@ function codegen_function!(::Type{ArrayMaker{T}}, cs::CodegenState{T}, expr::Bas
end
return declare!(cs, get_misc_identifier(cs), result)
end

if _allocator !== zeros && !__allocator_is_returns_expr(T, _allocator) &&
isequal(regions[1], sh) && __is_fill_zero(cs.ir[first(values_exprs_idxs)])
output_buffer = codegen_allocator_call!(
Expand Down Expand Up @@ -853,7 +856,7 @@ end

function codegen_ir!(cs::CodegenState{T}, idx::Integer) where {T}
cached = get(cs.cache, idx, nothing)
if cached isa Symbol
if cached !== nothing
return cached
end
ir = cs.ir
Expand All @@ -875,6 +878,9 @@ function codegen_ir!(cs::CodegenState{T}, idx::Integer) where {T}
@match sym begin
BSImpl.Const(; val) => if val isa CodegenPrimitive
return cs(val)
elseif val isa Union{Int, Float64, Float32}
insert!(cs.cache, idx, val)
return val
else
return codegen!(cs, idx, val)
end
Expand Down Expand Up @@ -933,10 +939,13 @@ function (cs::CodegenState)(@nospecialize(thing))
if uthing !== thing
return cs(uthing)
end
if thing isa Union{Int, Float64, Float32}
return thing
end
return declare!(cs, get_misc_identifier(cs), thing)
end

function (cs::CodegenState)(expr::BasicSymbolic{T})::Symbol where {T}
function (cs::CodegenState)(expr::BasicSymbolic{T}) where {T}
idx = populate_ir!(cs.ir, expr)
codegen_ir!(cs, idx)
end
Expand Down Expand Up @@ -1070,9 +1079,9 @@ function (cs::CodegenState{T})(fn::Func) where {T}
end

function (cs::CodegenState)(ex::SetArray)
arr = cs(ex.arr)::Symbol
arr = cs(ex.arr)::Union{Symbol, Int, Float64, Float32}
lhss = []
rhss = Symbol[]
rhss = Union{Symbol, Int, Float64, Float32}[]
for (i, elem) in enumerate(ex.elems)
if elem isa AtIndex
push!(lhss, cs(elem.i))
Expand Down
Loading
Loading