diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala
index 0f74389a9a5f2..f1176cc5a611c 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala
@@ -412,10 +412,28 @@ case class Cosh(child: Expression) extends UnaryMathExpression(math.cosh, "COSH"
   since = "3.0.0",
   group = "math_funcs")
 case class Acosh(child: Expression)
-  extends UnaryMathExpression((x: Double) => StrictMath.log(x + math.sqrt(x * x - 1.0)), "ACOSH") {
+  extends UnaryMathExpression((x: Double) => x match {
+    // For large values, x * x overflows to Infinity, and the "- 1" is lost to rounding anyway.
+    // There log(x + sqrt(x * x - 1)) is evaluated as log(2x) = log(2) + log(x), which cannot
+    // overflow. Inputs below 1 are outside the domain of acosh and yield NaN.
+    case x if x >= Math.sqrt(Double.MaxValue) =>
+      StrictMath.log(2) + StrictMath.log(x)
+    case x if x < 1 =>
+      Double.NaN
+    case _ => StrictMath.log(x + math.sqrt(x * x - 1.0)) }, "ACOSH") {
   override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
-    defineCodeGen(ctx, ev,
-      c => s"java.lang.StrictMath.log($c + java.lang.Math.sqrt($c * $c - 1.0))")
+    nullSafeCodeGen(ctx, ev, c => {
+      val sm = "java.lang.StrictMath"
+      s"""
+         |if ($c >= ${Math.sqrt(Double.MaxValue)}) {
+         |  ${ev.value} = $sm.log($c) + $sm.log(2);
+         |} else if ($c < 1) {
+         |  ${ev.value} = java.lang.Double.NaN;
+         |} else {
+         |  ${ev.value} = $sm.log($c + java.lang.Math.sqrt($c * $c - 1.0));
+         |}
+         |""".stripMargin
+    })
   }
   override protected def withNewChildInternal(newChild: Expression): Acosh = copy(child = newChild)
 }
@@ -842,12 +860,22 @@ case class Sinh(child: Expression) extends UnaryMathExpression(math.sinh, "SINH"
   group = "math_funcs")
 case class Asinh(child: Expression)
   extends UnaryMathExpression((x: Double) => x match {
-    case Double.NegativeInfinity => Double.NegativeInfinity
+    // For large |x|, x * x overflows to Infinity, and the "+ 1" is lost to rounding anyway.
+    // There log(x + sqrt(x * x + 1)) is evaluated as log(2x) = log(2) + log(x) for positive
+    // values. Since asinh is an odd function, large values of either sign are computed as
+    // signum(x) * log(2|x|).
+    case x if Math.abs(x) >= Math.sqrt(Double.MaxValue) =>
+      Math.signum(x) * (StrictMath.log(2) + StrictMath.log(Math.abs(x)))
+    // NOTE(review): for negative x below the threshold, x + sqrt(x * x + 1) cancels (e.g. it is
+    // exactly 0 for x = -1e9, giving -Infinity); consider signum(x) * log(|x| + sqrt(...)) too.
     case _ => StrictMath.log(x + math.sqrt(x * x + 1.0)) }, "ASINH") {
   override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
-    defineCodeGen(ctx, ev, c =>
-      s"$c == Double.NEGATIVE_INFINITY ? Double.NEGATIVE_INFINITY : " +
-      s"java.lang.StrictMath.log($c + java.lang.Math.sqrt($c * $c + 1.0))")
+    defineCodeGen(ctx, ev, c => {
+      val sm = "java.lang.StrictMath"
+      s"$sm.abs($c) >= ${Math.sqrt(Double.MaxValue)} ? " +
+        s"$sm.signum($c) * ($sm.log($sm.abs($c)) + $sm.log(2)) : " +
+        s"$sm.log($c + java.lang.Math.sqrt($c * $c + 1.0))"
+    })
   }
   override protected def withNewChildInternal(newChild: Expression): Asinh = copy(child = newChild)
 }
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathExpressionsSuite.scala
index 823a6d2ce8675..c78a9f563d424 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathExpressionsSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathExpressionsSuite.scala
@@ -280,7 +280,9 @@ class MathExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper {
   }
 
   test("acosh") {
-    testUnary(Acosh, (x: Double) => StrictMath.log(x + math.sqrt(x * x - 1.0)))
-    checkConsistencyBetweenInterpretedAndCodegen(Cosh, DoubleType)
+    val f: Double => Double = x => StrictMath.log(x + math.sqrt(x * x - 1.0))
+    testUnary(Acosh, f, (10 to 20).map(_ * 0.1))
+    testUnary(Acosh, f, (-20 to 9).map(_ * 0.1), expectNaN = true)
+    checkConsistencyBetweenInterpretedAndCodegen(Acosh, DoubleType)
 
     val nullLit = Literal.create(null, NullType)
@@ -963,4 +965,16 @@ class MathExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper {
     checkEvaluation(WidthBucket(5.35, 0.024, Double.NegativeInfinity, 5L), null)
     checkEvaluation(WidthBucket(5.35, 0.024, Double.PositiveInfinity, 5L), null)
   }
+
+  test("SPARK-55557: hyperbolic functions should not overflow with large inputs") {
+    checkEvaluation(Asinh(Double.MaxValue), 710.4758600739439)
+    checkEvaluation(Asinh(Math.sqrt(Double.MaxValue)), 355.58450362725193)
+    checkEvaluation(Acosh(Double.MaxValue), 710.4758600739439)
+    checkEvaluation(Acosh(Math.sqrt(Double.MaxValue)), 355.58450362725193)
+    checkEvaluation(Asinh(Double.MinValue), -710.4758600739439)
+    checkEvaluation(Asinh(-Math.sqrt(Double.MaxValue)), -355.58450362725193)
+    checkNaN(Acosh(Double.MinValue))
+    checkNaN(Acosh(-Math.sqrt(Double.MaxValue) + 1))
+    checkNaN(Acosh(-Math.sqrt(Double.MaxValue) + 2))
+  }
 }