diff --git a/mlir/lib/Dialect/Math/Transforms/ExpandPatterns.cpp b/mlir/lib/Dialect/Math/Transforms/ExpandPatterns.cpp
index 989a3e5536ec6..e1ab9c905447b 100644
--- a/mlir/lib/Dialect/Math/Transforms/ExpandPatterns.cpp
+++ b/mlir/lib/Dialect/Math/Transforms/ExpandPatterns.cpp
@@ -91,34 +91,42 @@ static LogicalResult convertCoshOp(math::CoshOp op, PatternRewriter &rewriter) {
 }
 
 /// Expands tanh op into
-/// 1) 1-exp^{-2x} / 1+exp^{-2x}, if x => 0
-/// 2) exp^{2x}-1 / exp^{2x}+1 , if x < 0
+/// 1-exp^{-2x} / 1+exp^{-2x}
+/// To avoid overflow we exploit the reflection symmetry `tanh(-x) = -tanh(x)`.
+/// We compute a "signs" value which is -1 if input is negative and +1 if input
+/// is positive. Then multiply the input by this value, guaranteeing that the
+/// result is positive, which also guarantees `exp^{-2x * sign(x)}` is in (0,
+/// 1]. Expand the computation on the input `x * sign(x)`, then multiply the
+/// result by `sign(x)` to retain sign of the real result.
 static LogicalResult convertTanhOp(math::TanhOp op, PatternRewriter &rewriter) {
   auto floatType = op.getOperand().getType();
   Location loc = op.getLoc();
+  Value zero = createFloatConst(loc, floatType, 0.0, rewriter);
   Value one = createFloatConst(loc, floatType, 1.0, rewriter);
-  Value two = createFloatConst(loc, floatType, 2.0, rewriter);
-  Value doubledX = rewriter.create<arith::MulFOp>(loc, op.getOperand(), two);
-
-  // Case 1: tanh(x) = 1-exp^{-2x} / 1+exp^{-2x}
-  Value negDoubledX = rewriter.create<arith::NegFOp>(loc, doubledX);
+  Value negTwo = createFloatConst(loc, floatType, -2.0, rewriter);
+
+  // Compute sign(x) = cast(x < 0) * (-2) + 1
+  Value isNegative = rewriter.create<arith::CmpFOp>(
+      loc, arith::CmpFPredicate::OLT, op.getOperand(), zero);
+  Value isNegativeFloat =
+      rewriter.create<arith::UIToFPOp>(loc, floatType, isNegative);
+  Value isNegativeTimesNegTwo =
+      rewriter.create<arith::MulFOp>(loc, isNegativeFloat, negTwo);
+  Value sign = rewriter.create<arith::AddFOp>(loc, isNegativeTimesNegTwo, one);
+
+  // Normalize input to positive value: y = sign(x) * x
+  Value positiveX = rewriter.create<arith::MulFOp>(loc, sign, op.getOperand());
+
+  // Decompose on normalized input
+  Value negDoubledX = rewriter.create<arith::MulFOp>(loc, negTwo, positiveX);
   Value exp2x = rewriter.create<math::ExpOp>(loc, negDoubledX);
   Value dividend = rewriter.create<arith::SubFOp>(loc, one, exp2x);
   Value divisor = rewriter.create<arith::AddFOp>(loc, one, exp2x);
   Value positiveRes = rewriter.create<arith::DivFOp>(loc, dividend, divisor);
 
-  // Case 2: tanh(x) = exp^{2x}-1 / exp^{2x}+1
-  exp2x = rewriter.create<math::ExpOp>(loc, doubledX);
-  dividend = rewriter.create<arith::SubFOp>(loc, exp2x, one);
-  divisor = rewriter.create<arith::AddFOp>(loc, exp2x, one);
-  Value negativeRes = rewriter.create<arith::DivFOp>(loc, dividend, divisor);
+  // Multiply result by sign(x) to retain signs from negative inputs
+  rewriter.replaceOpWithNewOp<arith::MulFOp>(op, sign, positiveRes);
 
-  // tanh(x) = x >= 0 ? positiveRes : negativeRes
-  Value zero = createFloatConst(loc, floatType, 0.0, rewriter);
-  Value cmpRes = rewriter.create<arith::CmpFOp>(loc, arith::CmpFPredicate::OGE,
-                                                op.getOperand(), zero);
-  rewriter.replaceOpWithNewOp<arith::SelectOp>(op, cmpRes, positiveRes,
-                                               negativeRes);
   return success();
 }
 
diff --git a/mlir/test/Dialect/Math/expand-math.mlir b/mlir/test/Dialect/Math/expand-math.mlir
index 6ee65b085dad1..6326d3a71874b 100644
--- a/mlir/test/Dialect/Math/expand-math.mlir
+++ b/mlir/test/Dialect/Math/expand-math.mlir
@@ -7,19 +7,18 @@ func.func @tanh(%arg: f32) -> f32 {
 }
 // CHECK-DAG: %[[ZERO:.+]] = arith.constant 0.000000e+00 : f32
 // CHECK-DAG: %[[ONE:.+]] = arith.constant 1.000000e+00 : f32
-// CHECK-DAG: %[[TWO:.+]] = arith.constant 2.000000e+00 : f32
-// CHECK: %[[DOUBLEDX:.+]] = arith.mulf %arg0, %[[TWO]] : f32
-// CHECK: %[[NEGDOUBLEDX:.+]] = arith.negf %[[DOUBLEDX]] : f32
+// CHECK-DAG: %[[TWO:.+]] = arith.constant -2.000000e+00 : f32
+// CHECK: %[[VAL0:.+]] = arith.cmpf olt, %arg0, %[[ZERO]] : f32
+// CHECK: %[[VAL1:.+]] = arith.uitofp %[[VAL0]] : i1 to f32
+// CHECK: %[[VAL2:.+]] = arith.mulf %[[VAL1]], %[[TWO]] : f32
+// CHECK: %[[SIGN:.+]] = arith.addf %[[VAL2]], %[[ONE]] : f32
+// CHECK: %[[POSX:.+]] = arith.mulf %[[SIGN]], %arg0 : f32
+// CHECK: %[[NEGDOUBLEDX:.+]] = arith.mulf %[[POSX]], %[[TWO]] : f32
 // CHECK: %[[EXP1:.+]] = math.exp %[[NEGDOUBLEDX]] : f32
 // CHECK: %[[DIVIDEND1:.+]] = arith.subf %[[ONE]], %[[EXP1]] : f32
 // CHECK: %[[DIVISOR1:.+]] = arith.addf %[[EXP1]], %[[ONE]] : f32
-// CHECK: %[[RES1:.+]] = arith.divf %[[DIVIDEND1]], %[[DIVISOR1]] : f32
-// CHECK: %[[EXP2:.+]] = math.exp %[[DOUBLEDX]] : f32
-// CHECK: %[[DIVIDEND2:.+]] = arith.subf %[[EXP2]], %[[ONE]] : f32
-// CHECK: %[[DIVISOR2:.+]] = arith.addf %[[EXP2]], %[[ONE]] : f32
-// CHECK: %[[RES2:.+]] = arith.divf %[[DIVIDEND2]], %[[DIVISOR2]] : f32
-// CHECK: %[[COND:.+]] = arith.cmpf oge, %arg0, %[[ZERO]] : f32
-// CHECK: %[[RESULT:.+]] = arith.select %[[COND]], %[[RES1]], %[[RES2]] : f32
+// CHECK: %[[POSRES:.+]] = arith.divf %[[DIVIDEND1]], %[[DIVISOR1]] : f32
+// CHECK: %[[RESULT:.+]] = arith.mulf %[[SIGN]], %[[POSRES]] : f32
 // CHECK: return %[[RESULT]]
 
 // -----
diff --git a/mlir/test/mlir-cpu-runner/test-expand-math-approx.mlir b/mlir/test/mlir-cpu-runner/test-expand-math-approx.mlir
index 541a201c94c58..e2229a392bbf7 100644
--- a/mlir/test/mlir-cpu-runner/test-expand-math-approx.mlir
+++ b/mlir/test/mlir-cpu-runner/test-expand-math-approx.mlir
@@ -683,6 +683,24 @@ func.func @cosh() {
   return
 }
 
+// -------------------------------------------------------------------------- //
+// Tanh.
+// -------------------------------------------------------------------------- //
+
+func.func @tanh_8xf32(%a : vector<8xf32>) {
+  %r = math.tanh %a : vector<8xf32>
+  vector.print %r : vector<8xf32>
+  return
+}
+
+func.func @tanh() {
+  // CHECK: -1, -0.761594, -0.291313, 0, 0.291313, 0.761594, 1, 1
+  %v3 = arith.constant dense<[0xff800000, -1.0, -0.3, 0.0, 0.3, 1.0, 10.0, 0x7f800000]> : vector<8xf32>
+  call @tanh_8xf32(%v3) : (vector<8xf32>) -> ()
+
+  return
+}
+
 func.func @main() {
   call @exp2f() : () -> ()
   call @roundf() : () -> ()
@@ -690,5 +708,6 @@ func.func @main() {
   call @roundeven() : () -> ()
   call @sinh() : () -> ()
   call @cosh() : () -> ()
+  call @tanh() : () -> ()
   return
 }