Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 6 additions & 2 deletions src/ptx.jl
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,8 @@ Base.@kwdef struct PTXCompilerTarget <: AbstractCompilerTarget
blocks_per_sm::Union{Nothing,Int} = nothing
maxregs::Union{Nothing,Int} = nothing

fastmath::Bool = Base.JLOptions().fast_math == 1

# deprecated; remove with next major version
exitable::Union{Nothing,Bool} = nothing
unreachable::Union{Nothing,Bool} = nothing
Expand All @@ -33,6 +35,7 @@ function Base.hash(target::PTXCompilerTarget, h::UInt)
h = hash(target.maxthreads, h)
h = hash(target.blocks_per_sm, h)
h = hash(target.maxregs, h)
h = hash(target.fastmath, h)

h
end
Expand Down Expand Up @@ -81,6 +84,7 @@ function Base.show(io::IO, @nospecialize(job::CompilerJob{PTXCompilerTarget}))
job.config.target.maxthreads !== nothing && print(io, ", maxthreads=$(job.config.target.maxthreads)")
job.config.target.blocks_per_sm !== nothing && print(io, ", blocks_per_sm=$(job.config.target.blocks_per_sm)")
job.config.target.maxregs !== nothing && print(io, ", maxregs=$(job.config.target.maxregs)")
job.config.target.fastmath && print(io, ", fast math enabled")
end

const ptx_intrinsics = ("vprintf", "__assertfail", "malloc", "free")
Expand Down Expand Up @@ -423,7 +427,7 @@ function nvvm_reflect!(fun::LLVM.Function)
# handle possible cases
# XXX: put some of these property in the compiler job?
# and/or first set the "nvvm-reflect-*" module flag like Clang does?
fast_math = Base.JLOptions().fast_math == 1
fast_math = current_job.config.target.fastmath
# NOTE: we follow nvcc's --use_fast_math
reflect_val = if reflect_arg == "__CUDA_FTZ"
# single-precision denormals support
Expand All @@ -432,7 +436,7 @@ function nvvm_reflect!(fun::LLVM.Function)
# single-precision floating-point division and reciprocals.
ConstantInt(reflect_typ, fast_math ? 0 : 1)
elseif reflect_arg == "__CUDA_PREC_SQRT"
# single-precision denormals support
# single-precision floating point square roots.
ConstantInt(reflect_typ, fast_math ? 0 : 1)
elseif reflect_arg == "__CUDA_FMAD"
# contraction of floating-point multiplies and adds/subtracts into
Expand Down