Skip to content

Commit b7abf7b

Browse files
committed
try to fix fma on windows
1 parent 5bd3a6d commit b7abf7b

File tree

4 files changed

+43
-26
lines changed

4 files changed

+43
-26
lines changed

base/compiler/tfuncs.jl

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1563,6 +1563,8 @@ function builtin_nothrow(@nospecialize(f), argtypes::Array{Any, 1}, @nospecializ
15631563
return _builtin_nothrow(f, argtypes, rt)
15641564
end
15651565

1566+
# Generic fallback for `julia_fma`: this method only calls `error()`.
# Float32/Float64 methods are attached from base/floatfuncs.jl, and
# `builtin_tfunction` calls `julia_fma` in place of `Intrinsics.fma_float`
# when all argument types are `Const`.
julia_fma(x, y, z) = error()
1567+
15661568
function builtin_tfunction(interp::AbstractInterpreter, @nospecialize(f), argtypes::Array{Any,1},
15671569
sv::Union{InferenceState,Nothing})
15681570
if f === tuple
@@ -1572,7 +1574,11 @@ function builtin_tfunction(interp::AbstractInterpreter, @nospecialize(f), argtyp
15721574
if is_pure_intrinsic_infer(f) && _all(@nospecialize(a) -> isa(a, Const), argtypes)
15731575
argvals = anymap(@nospecialize(a) -> (a::Const).val, argtypes)
15741576
try
1575-
return Const(f(argvals...))
1577+
if f === Intrinsics.fma_float
1578+
return Const(julia_fma(argvals...))
1579+
else
1580+
return Const(f(argvals...))
1581+
end
15761582
catch
15771583
end
15781584
end

base/floatfuncs.jl

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -415,6 +415,14 @@ fma_llvm(x::Float64, y::Float64, z::Float64) = fma_float(x, y, z)
415415
# `fma` for Float32 and Float64: call `fma_llvm` when `Core.Intrinsics.have_fma(T)`
# is true for the operand type, otherwise fall back to `fma_emulated`.
fma(x::Float32, y::Float32, z::Float32) = Core.Intrinsics.have_fma(Float32) ? fma_llvm(x,y,z) : fma_emulated(x,y,z)
416416
fma(x::Float64, y::Float64, z::Float64) = Core.Intrinsics.have_fma(Float64) ? fma_llvm(x,y,z) : fma_emulated(x,y,z)
417417

418+
# Methods of `Core.Compiler.julia_fma`, the function `builtin_tfunction` calls instead of
# `Intrinsics.fma_float` when every argument is a constant.
# On Windows both methods use `fma_emulated`; on every other OS they forward to `fma_float`.
# NOTE(review): presumably the Windows branch works around the broken FMA tracked in #43088 — confirm.
@static if Sys.iswindows()
419+
Core.Compiler.julia_fma(x::Float32, y::Float32, z::Float32) = fma_emulated(x,y,z)
420+
Core.Compiler.julia_fma(x::Float64, y::Float64, z::Float64) = fma_emulated(x,y,z)
421+
else
422+
Core.Compiler.julia_fma(x::Float32, y::Float32, z::Float32) = fma_float(x,y,z)
423+
Core.Compiler.julia_fma(x::Float64, y::Float64, z::Float64) = fma_float(x,y,z)
424+
end
425+
418426
# `fma` for Float16: widen all three operands to Float32, combine them with `muladd`,
# and convert the result back to Float16.
function fma(a::Float16, b::Float16, c::Float16)
419427
Float16(muladd(Float32(a), Float32(b), Float32(c))) # don't use fma if the hardware doesn't have it.
420428
end

src/llvm-cpufeatures.cpp

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -36,10 +36,10 @@ Optional<bool> always_have_fma(Function &intr) {
3636
auto intr_name = intr.getName();
3737
auto typ = intr_name.substr(strlen("julia.cpu.have_fma."));
3838

39-
#if defined(_OS_WINDOWS_)
39+
// #if defined(_OS_WINDOWS_)
4040
// FMA on Windows is weirdly broken (#43088)
41-
return false;
42-
#elif defined(_CPU_AARCH64_)
41+
// return false;
42+
#if defined(_CPU_AARCH64_)
4343
return typ == "f32" || typ == "f64";
4444
#else
4545
(void)typ;

test/math.jl

Lines changed: 25 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -1289,29 +1289,32 @@ end
12891289
end
12901290

12911291
@testset "fma" begin
1292-
if !(@static Sys.iswindows() && Int===Int64) # windows fma currently seems broken somehow.
1293-
for func in (fma, Base.fma_emulated)
1294-
@test func(nextfloat(1.),nextfloat(1.),-1.0) === 4.440892098500626e-16
1295-
@test func(nextfloat(1f0),nextfloat(1f0),-1f0) === 2.3841858f-7
1296-
@testset "$T" for T in (Float32, Float64)
1297-
@test func(floatmax(T), T(2), -floatmax(T)) === floatmax(T)
1298-
@test func(floatmax(T), T(1), eps(floatmax((T)))) === T(Inf)
1299-
@test func(T(Inf), T(Inf), T(Inf)) === T(Inf)
1300-
@test func(floatmax(T), floatmax(T), -T(Inf)) === -T(Inf)
1301-
@test func(floatmax(T), -floatmax(T), T(Inf)) === T(Inf)
1302-
@test isnan_type(T, func(T(Inf), T(1), -T(Inf)))
1303-
@test isnan_type(T, func(T(Inf), T(0), -T(0)))
1304-
@test func(-zero(T), zero(T), -zero(T)) === -zero(T)
1305-
for _ in 1:2^18
1306-
a, b, c = reinterpret.(T, rand(Base.uinttype(T), 3))
1307-
@test isequal(func(a, b, c), fma(a, b, c)) || (a,b,c)
1308-
end
1292+
for func in (fma, Base.fma_emulated)
1293+
@test func(nextfloat(1.),nextfloat(1.),-1.0) === 4.440892098500626e-16
1294+
@test func(nextfloat(1f0),nextfloat(1f0),-1f0) === 2.3841858f-7
1295+
@testset "$T" for T in (Float32, Float64)
1296+
@test func(floatmax(T), T(2), -floatmax(T)) === floatmax(T)
1297+
@test func(floatmax(T), T(1), eps(floatmax((T)))) === T(Inf)
1298+
@test func(T(Inf), T(Inf), T(Inf)) === T(Inf)
1299+
@test func(floatmax(T), floatmax(T), -T(Inf)) === -T(Inf)
1300+
@test func(floatmax(T), -floatmax(T), T(Inf)) === T(Inf)
1301+
@test isnan_type(T, func(T(Inf), T(1), -T(Inf)))
1302+
@test isnan_type(T, func(T(Inf), T(0), -T(0)))
1303+
@test func(-zero(T), zero(T), -zero(T)) === -zero(T)
1304+
for _ in 1:2^18
1305+
a, b, c = reinterpret.(T, rand(Base.uinttype(T), 3))
1306+
@test isequal(func(a, b, c), fma(a, b, c)) || (a,b,c)
13091307
end
1310-
@test func(floatmax(Float64), nextfloat(1.0), -floatmax(Float64)) === 3.991680619069439e292
1311-
@test func(floatmax(Float32), nextfloat(1f0), -floatmax(Float32)) === 4.0564817f31
1312-
@test func(1.6341681540852291e308, -2., floatmax(Float64)) == -1.4706431733081426e308 # case where inv(a)*c*a == Inf
1313-
@test func(-2., 1.6341681540852291e308, floatmax(Float64)) == -1.4706431733081426e308 # case where inv(b)*c*b == Inf
1314-
@test func(-1.9369631f13, 2.1513551f-7, -1.7354427f-24) == -4.1670958f6
13151308
end
1309+
@test func(floatmax(Float64), nextfloat(1.0), -floatmax(Float64)) === 3.991680619069439e292
1310+
@test func(floatmax(Float32), nextfloat(1f0), -floatmax(Float32)) === 4.0564817f31
1311+
@test func(1.6341681540852291e308, -2., floatmax(Float64)) == -1.4706431733081426e308 # case where inv(a)*c*a == Inf
1312+
@test func(-2., 1.6341681540852291e308, floatmax(Float64)) == -1.4706431733081426e308 # case where inv(b)*c*b == Inf
1313+
@test func(-1.9369631f13, 2.1513551f-7, -1.7354427f-24) == -4.1670958f6
1314+
end
1315+
# Windows-only check: evaluate `Base.fma_float` directly via `@eval` and compare the
# result with `fma` for one fixed Float32 input; the comparison is marked `@test_broken`.
@static if Sys.iswindows()
1316+
# TODO: if this passes one day, then we can remove the fma hack on Windows
1317+
# NOTE(review): the local variable below is named `error`, the same name as the `error` function — consider a more descriptive name; confirm.
error = @eval Base.fma_float(-1.9369631f13, 2.1513551f-7, -1.7354427f-24)
1318+
@test_broken error == fma(-1.9369631f13, 2.1513551f-7, -1.7354427f-24)
13161319
end
13171320
end

0 commit comments

Comments
 (0)