// -*- C++ -*-
//===---------------------------- cmath -----------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#ifndef _LIBCUDACXX_CMATH
#define _LIBCUDACXX_CMATH

/*
    cmath synopsis

Macros:

    HUGE_VAL
    HUGE_VALF               // C99
    HUGE_VALL               // C99
    INFINITY                // C99
    NAN                     // C99
    FP_INFINITE             // C99
    FP_NAN                  // C99
    FP_NORMAL               // C99
    FP_SUBNORMAL            // C99
    FP_ZERO                 // C99
    FP_FAST_FMA             // C99
    FP_FAST_FMAF            // C99
    FP_FAST_FMAL            // C99
    FP_ILOGB0               // C99
    FP_ILOGBNAN             // C99
    MATH_ERRNO              // C99
    MATH_ERREXCEPT          // C99
    math_errhandling        // C99

namespace std
{

Types:

    float_t                 // C99
    double_t                // C99

// C90

floating_point abs(floating_point x);

floating_point acos (arithmetic x);
float          acosf(float x);
long double    acosl(long double x);

floating_point asin (arithmetic x);
float          asinf(float x);
long double    asinl(long double x);

floating_point atan (arithmetic x);
float          atanf(float x);
long double    atanl(long double x);

floating_point atan2 (arithmetic y, arithmetic x);
float          atan2f(float y, float x);
long double    atan2l(long double y, long double x);

floating_point ceil (arithmetic x);
float          ceilf(float x);
long double    ceill(long double x);

floating_point cos (arithmetic x);
float          cosf(float x);
long double    cosl(long double x);

floating_point cosh (arithmetic x);
float          coshf(float x);
long double    coshl(long double x);

floating_point exp (arithmetic x);
float          expf(float x);
long double    expl(long double x);

floating_point fabs (arithmetic x);
float          fabsf(float x);
long double    fabsl(long double x);

floating_point floor (arithmetic x);
float          floorf(float x);
long double    floorl(long double x);

floating_point fmod (arithmetic x, arithmetic y);
float          fmodf(float x, float y);
long double    fmodl(long double x, long double y);

floating_point frexp (arithmetic value, int* exp);
float          frexpf(float value, int* exp);
long double    frexpl(long double value, int* exp);

floating_point ldexp (arithmetic value, int exp);
float          ldexpf(float value, int exp);
long double    ldexpl(long double value, int exp);

floating_point log (arithmetic x);
float          logf(float x);
long double    logl(long double x);

floating_point log10 (arithmetic x);
float          log10f(float x);
long double    log10l(long double x);

floating_point modf (floating_point value, floating_point* iptr);
float          modff(float value, float* iptr);
long double    modfl(long double value, long double* iptr);

floating_point pow (arithmetic x, arithmetic y);
float          powf(float x, float y);
long double    powl(long double x, long double y);

floating_point sin (arithmetic x);
float          sinf(float x);
long double    sinl(long double x);

floating_point sinh (arithmetic x);
float          sinhf(float x);
long double    sinhl(long double x);

floating_point sqrt (arithmetic x);
float          sqrtf(float x);
long double    sqrtl(long double x);

floating_point tan (arithmetic x);
float          tanf(float x);
long double    tanl(long double x);

floating_point tanh (arithmetic x);
float          tanhf(float x);
long double    tanhl(long double x);

//  C99

bool signbit(arithmetic x);

int fpclassify(arithmetic x);

bool isfinite(arithmetic x);
bool isinf(arithmetic x);
bool isnan(arithmetic x);
bool isnormal(arithmetic x);

bool isgreater(arithmetic x, arithmetic y);
bool isgreaterequal(arithmetic x, arithmetic y);
bool isless(arithmetic x, arithmetic y);
bool islessequal(arithmetic x, arithmetic y);
bool islessgreater(arithmetic x, arithmetic y);
bool isunordered(arithmetic x, arithmetic y);

floating_point acosh (arithmetic x);
float          acoshf(float x);
long double    acoshl(long double x);

floating_point asinh (arithmetic x);
float          asinhf(float x);
long double    asinhl(long double x);

floating_point atanh (arithmetic x);
float          atanhf(float x);
long double    atanhl(long double x);

floating_point cbrt (arithmetic x);
float          cbrtf(float x);
long double    cbrtl(long double x);

floating_point copysign (arithmetic x, arithmetic y);
float          copysignf(float x, float y);
long double    copysignl(long double x, long double y);

floating_point erf (arithmetic x);
float          erff(float x);
long double    erfl(long double x);

floating_point erfc (arithmetic x);
float          erfcf(float x);
long double    erfcl(long double x);

floating_point exp2 (arithmetic x);
float          exp2f(float x);
long double    exp2l(long double x);

floating_point expm1 (arithmetic x);
float          expm1f(float x);
long double    expm1l(long double x);

floating_point fdim (arithmetic x, arithmetic y);
float          fdimf(float x, float y);
long double    fdiml(long double x, long double y);

floating_point fma (arithmetic x, arithmetic y, arithmetic z);
float          fmaf(float x, float y, float z);
long double    fmal(long double x, long double y, long double z);

floating_point fmax (arithmetic x, arithmetic y);
float          fmaxf(float x, float y);
long double    fmaxl(long double x, long double y);

floating_point fmin (arithmetic x, arithmetic y);
float          fminf(float x, float y);
long double    fminl(long double x, long double y);

floating_point hypot (arithmetic x, arithmetic y);
float          hypotf(float x, float y);
long double    hypotl(long double x, long double y);

double       hypot(double x, double y, double z);                // C++17
float        hypot(float x, float y, float z);                   // C++17
long double  hypot(long double x, long double y, long double z); // C++17

int ilogb (arithmetic x);
int ilogbf(float x);
int ilogbl(long double x);

floating_point lgamma (arithmetic x);
float          lgammaf(float x);
long double    lgammal(long double x);

long long llrint (arithmetic x);
long long llrintf(float x);
long long llrintl(long double x);

long long llround (arithmetic x);
long long llroundf(float x);
long long llroundl(long double x);

floating_point log1p (arithmetic x);
float          log1pf(float x);
long double    log1pl(long double x);

floating_point log2 (arithmetic x);
float          log2f(float x);
long double    log2l(long double x);

floating_point logb (arithmetic x);
float          logbf(float x);
long double    logbl(long double x);

long lrint (arithmetic x);
long lrintf(float x);
long lrintl(long double x);

long lround (arithmetic x);
long lroundf(float x);
long lroundl(long double x);

double      nan (const char* str);
float       nanf(const char* str);
long double nanl(const char* str);

floating_point nearbyint (arithmetic x);
float          nearbyintf(float x);
long double    nearbyintl(long double x);

floating_point nextafter (arithmetic x, arithmetic y);
float          nextafterf(float x, float y);
long double    nextafterl(long double x, long double y);

floating_point nexttoward (arithmetic x, long double y);
float          nexttowardf(float x, long double y);
long double    nexttowardl(long double x, long double y);

floating_point remainder (arithmetic x, arithmetic y);
float          remainderf(float x, float y);
long double    remainderl(long double x, long double y);

floating_point remquo (arithmetic x, arithmetic y, int* pquo);
float          remquof(float x, float y, int* pquo);
long double    remquol(long double x, long double y, int* pquo);

floating_point rint (arithmetic x);
float          rintf(float x);
long double    rintl(long double x);

floating_point round (arithmetic x);
float          roundf(float x);
long double    roundl(long double x);

floating_point scalbln (arithmetic x, long ex);
float          scalblnf(float x, long ex);
long double    scalblnl(long double x, long ex);

floating_point scalbn (arithmetic x, int ex);
float          scalbnf(float x, int ex);
long double    scalbnl(long double x, int ex);

floating_point tgamma (arithmetic x);
float          tgammaf(float x);
long double    tgammal(long double x);

floating_point trunc (arithmetic x);
float          truncf(float x);
long double    truncl(long double x);

}  // std

*/

#include <cuda/std/detail/__config>

#if defined(_CCCL_IMPLICIT_SYSTEM_HEADER_GCC)
#  pragma GCC system_header
#elif defined(_CCCL_IMPLICIT_SYSTEM_HEADER_CLANG)
#  pragma clang system_header
#elif defined(_CCCL_IMPLICIT_SYSTEM_HEADER_MSVC)
#  pragma system_header
#endif // no system header

#if !_CCCL_COMPILER(NVRTC)
#  include <math.h>
#endif // !_CCCL_COMPILER(NVRTC)

#if _CCCL_COMPILER(NVHPC)
#  include <cmath>
#endif // _CCCL_COMPILER(NVHPC)

#include <cuda/std/__cmath/abs.h>
#include <cuda/std/__cmath/copysign.h>
#include <cuda/std/__cmath/exponential_functions.h>
#include <cuda/std/__cmath/fpclassify.h>
#include <cuda/std/__cmath/gamma.h>
#include <cuda/std/__cmath/hyperbolic_functions.h>
#include <cuda/std/__cmath/hypot.h>
#include <cuda/std/__cmath/inverse_hyperbolic_functions.h>
#include <cuda/std/__cmath/inverse_trigonometric_functions.h>
#include <cuda/std/__cmath/isfinite.h>
#include <cuda/std/__cmath/isinf.h>
#include <cuda/std/__cmath/isnan.h>
#include <cuda/std/__cmath/isnormal.h>
#include <cuda/std/__cmath/lerp.h>
#include <cuda/std/__cmath/logarithms.h>
#include <cuda/std/__cmath/min_max.h>
#include <cuda/std/__cmath/roots.h>
#include <cuda/std/__cmath/rounding_functions.h>
#include <cuda/std/__cmath/signbit.h>
#include <cuda/std/__cmath/traits.h>
#include <cuda/std/__cmath/trigonometric_functions.h>
#include <cuda/std/__cstdlib/abs.h>
#include <cuda/std/limits>
#include <cuda/std/type_traits>
#include <cuda/std/version>

#if _LIBCUDACXX_HAS_NVFP16()
#  include <cuda/std/__cmath/nvfp16.h>
#endif // _LIBCUDACXX_HAS_NVFP16()

#if _LIBCUDACXX_HAS_NVBF16()
#  include <cuda/std/__cmath/nvbf16.h>
#endif // _LIBCUDACXX_HAS_NVBF16()

// <math.h> is not included when compiling with NVRTC (see the guarded include above), so the
// C99 INFINITY and NAN macros are provided here in terms of numeric_limits<float>.
#if _CCCL_COMPILER(NVRTC)
#  define INFINITY _CUDA_VSTD::numeric_limits<float>::infinity()
#  define NAN      _CUDA_VSTD::numeric_limits<float>::quiet_NaN()
#endif // _CCCL_COMPILER(NVRTC)

_CCCL_PUSH_MACROS

_LIBCUDACXX_BEGIN_NAMESPACE_STD

// Re-export the <math.h> entities that are not provided by the <cuda/std/__cmath/*.h> headers.
// <math.h> is not included under NVRTC, so none of these names are imported there.
#if !_CCCL_COMPILER(NVRTC)

// Types
using ::double_t;
using ::float_t;

// Error functions
using ::erf;
using ::erff;
using ::erfl;

using ::erfc;
using ::erfcf;
using ::erfcl;

// Positive difference
using ::fdim;
using ::fdimf;
using ::fdiml;

// Fused multiply-add
using ::fma;
using ::fmaf;
using ::fmal;

// Remainders
using ::fmod;
using ::fmodf;
using ::fmodl;

using ::remainder;
using ::remainderf;
using ::remainderl;

using ::remquo;
using ::remquof;
using ::remquol;

// Integral / fractional decomposition
using ::modf;
using ::modff;
using ::modfl;

// NaN generation
using ::nan;
using ::nanf;
using ::nanl;

#endif // !_CCCL_COMPILER(NVRTC)

// _CCCL_CONSTEXPR_CXX14_COMPLEX expands to `constexpr` when
// _LIBCUDACXX_HAS_CONSTEXPR_COMPLEX_OPERATIONS() is true and to nothing otherwise. It decorates
// the __constexpr_* helpers below that carry a hand-written constant-evaluation path.
#if _LIBCUDACXX_HAS_CONSTEXPR_COMPLEX_OPERATIONS()
#  define _CCCL_CONSTEXPR_CXX14_COMPLEX constexpr
#else
#  define _CCCL_CONSTEXPR_CXX14_COMPLEX
#endif // !_LIBCUDACXX_HAS_CONSTEXPR_COMPLEX_OPERATIONS()

// __constexpr_copysign(x, y): value with the magnitude of `x` and the sign of `y`.
//
// MSVC and NVRTC: a single non-constexpr template that forwards to the global ::copysign.
// Other compilers: constexpr overloads wrapping the __builtin_copysign* intrinsic that matches
// the argument type, plus a mixed-type overload that promotes both arguments first.
#if _CCCL_COMPILER(MSVC) || _CCCL_COMPILER(NVRTC)
template <class _A1>
_LIBCUDACXX_HIDE_FROM_ABI _A1 __constexpr_copysign(_A1 __x, _A1 __y) noexcept
{
  return ::copysign(__x, __y);
}
#else
_LIBCUDACXX_HIDE_FROM_ABI constexpr float __constexpr_copysign(float __x, float __y) noexcept
{
  return __builtin_copysignf(__x, __y);
}

_LIBCUDACXX_HIDE_FROM_ABI constexpr double __constexpr_copysign(double __x, double __y) noexcept
{
  return __builtin_copysign(__x, __y);
}

#  if _CCCL_HAS_LONG_DOUBLE()
_LIBCUDACXX_HIDE_FROM_ABI constexpr long double __constexpr_copysign(long double __x, long double __y) noexcept
{
  return __builtin_copysignl(__x, __y);
}
#  endif // _CCCL_HAS_LONG_DOUBLE()

// Mixed arithmetic arguments: both are converted to the promoted type before the sign is copied.
// The static_assert rejects the case where both arguments already have the promoted type, which
// is expected to be served by the non-template overloads above.
template <class _A1, class _A2>
_LIBCUDACXX_HIDE_FROM_ABI constexpr enable_if_t<is_arithmetic<_A1>::value && is_arithmetic<_A2>::value,
                                                __promote_t<_A1, _A2>>
__constexpr_copysign(_A1 __x, _A2 __y) noexcept
{
  using __result_type = __promote_t<_A1, _A2>;
  static_assert((!(_IsSame<_A1, __result_type>::value && _IsSame<_A2, __result_type>::value)), "");
#  if _CCCL_HAS_LONG_DOUBLE()
  // Dispatch on the promoted type so that a long double result is not narrowed through the
  // double-precision builtin.
  return _CUDA_VSTD::__constexpr_copysign(static_cast<__result_type>(__x), static_cast<__result_type>(__y));
#  else // ^^^ _CCCL_HAS_LONG_DOUBLE() ^^^ / vvv !_CCCL_HAS_LONG_DOUBLE() vvv
  // No long double overload exists in this configuration; keep using the double builtin.
  return __builtin_copysign(static_cast<__result_type>(__x), static_cast<__result_type>(__y));
#  endif // !_CCCL_HAS_LONG_DOUBLE()
}
#endif // !_CCCL_COMPILER(MSVC) && !_CCCL_COMPILER(NVRTC)

// __constexpr_fabs(x): absolute value of `x`.
//
// MSVC and NVRTC: a single non-constexpr template that forwards to the global ::fabs and
// converts the result back to the argument type.
// Other compilers: constexpr overloads wrapping the __builtin_fabs* intrinsic that matches the
// argument type; integral arguments are converted to double and return double.
#if _CCCL_COMPILER(MSVC) || _CCCL_COMPILER(NVRTC)
template <class _A1>
_LIBCUDACXX_HIDE_FROM_ABI _A1 __constexpr_fabs(_A1 __x) noexcept
{
  return ::fabs(__x);
}
#else
_LIBCUDACXX_HIDE_FROM_ABI constexpr float __constexpr_fabs(float __x) noexcept
{
  return __builtin_fabsf(__x);
}

_LIBCUDACXX_HIDE_FROM_ABI constexpr double __constexpr_fabs(double __x) noexcept
{
  return __builtin_fabs(__x);
}

#  if _CCCL_HAS_LONG_DOUBLE()
_LIBCUDACXX_HIDE_FROM_ABI constexpr long double __constexpr_fabs(long double __x) noexcept
{
  return __builtin_fabsl(__x);
}
#  endif // _CCCL_HAS_LONG_DOUBLE()

// Integral arguments: computed in double.
template <class _Tp, enable_if_t<is_integral<_Tp>::value, int> = 0>
_LIBCUDACXX_HIDE_FROM_ABI constexpr double __constexpr_fabs(_Tp __x) noexcept
{
  return __builtin_fabs(static_cast<double>(__x));
}
#endif // !_CCCL_COMPILER(MSVC) && !_CCCL_COMPILER(NVRTC)

// __constexpr_fmax(x, y): the larger of `x` and `y`; if exactly one argument is NaN the other
// one is returned.
//
// MSVC and NVRTC: a single non-constexpr template that forwards to the global ::fmax.
// Other compilers: one overload per floating-point type. When constant evaluation can be
// detected and constexpr complex operations are enabled, the result is computed by hand during
// constant evaluation; otherwise the __builtin_fmax* intrinsic matching the argument type is used.
#if _CCCL_COMPILER(MSVC) || _CCCL_COMPILER(NVRTC)
template <class _A1>
_LIBCUDACXX_HIDE_FROM_ABI _A1 __constexpr_fmax(_A1 __x, _A1 __y) noexcept
{
  return ::fmax(__x, __y);
}
#else
_LIBCUDACXX_HIDE_FROM_ABI _CCCL_CONSTEXPR_CXX14_COMPLEX float __constexpr_fmax(float __x, float __y) noexcept
{
#  if defined(_CCCL_BUILTIN_IS_CONSTANT_EVALUATED) && _LIBCUDACXX_HAS_CONSTEXPR_COMPLEX_OPERATIONS()
  if (_CCCL_BUILTIN_IS_CONSTANT_EVALUATED())
  {
    // A NaN operand loses against the other operand.
    if (_CUDA_VSTD::isnan(__x))
    {
      return __y;
    }
    if (_CUDA_VSTD::isnan(__y))
    {
      return __x;
    }
    return __x < __y ? __y : __x;
  }
#  endif // _CCCL_BUILTIN_IS_CONSTANT_EVALUATED && _LIBCUDACXX_HAS_CONSTEXPR_COMPLEX_OPERATIONS()
  return __builtin_fmaxf(__x, __y);
}

_LIBCUDACXX_HIDE_FROM_ABI _CCCL_CONSTEXPR_CXX14_COMPLEX double __constexpr_fmax(double __x, double __y) noexcept
{
#  if defined(_CCCL_BUILTIN_IS_CONSTANT_EVALUATED) && _LIBCUDACXX_HAS_CONSTEXPR_COMPLEX_OPERATIONS()
  if (_CCCL_BUILTIN_IS_CONSTANT_EVALUATED())
  {
    // A NaN operand loses against the other operand.
    if (_CUDA_VSTD::isnan(__x))
    {
      return __y;
    }
    if (_CUDA_VSTD::isnan(__y))
    {
      return __x;
    }
    return __x < __y ? __y : __x;
  }
#  endif // _CCCL_BUILTIN_IS_CONSTANT_EVALUATED && _LIBCUDACXX_HAS_CONSTEXPR_COMPLEX_OPERATIONS()
  return __builtin_fmax(__x, __y);
}

#  if _CCCL_HAS_LONG_DOUBLE()
_LIBCUDACXX_HIDE_FROM_ABI _CCCL_CONSTEXPR_CXX14_COMPLEX long double
__constexpr_fmax(long double __x, long double __y) noexcept
{
#    if defined(_CCCL_BUILTIN_IS_CONSTANT_EVALUATED) && _LIBCUDACXX_HAS_CONSTEXPR_COMPLEX_OPERATIONS()
  if (_CCCL_BUILTIN_IS_CONSTANT_EVALUATED())
  {
    // A NaN operand loses against the other operand.
    if (_CUDA_VSTD::isnan(__x))
    {
      return __y;
    }
    if (_CUDA_VSTD::isnan(__y))
    {
      return __x;
    }
    return __x < __y ? __y : __x;
  }
#    endif // _CCCL_BUILTIN_IS_CONSTANT_EVALUATED && _LIBCUDACXX_HAS_CONSTEXPR_COMPLEX_OPERATIONS()
  // Use the long double builtin: __builtin_fmax would narrow both operands to double.
  return __builtin_fmaxl(__x, __y);
}
#  endif // _CCCL_HAS_LONG_DOUBLE()

// Mixed arithmetic arguments: both are converted to the promoted type and the call is forwarded
// to the overload for that type.
template <class _Tp, class _Up, enable_if_t<is_arithmetic<_Tp>::value && is_arithmetic<_Up>::value, int> = 0>
_LIBCUDACXX_HIDE_FROM_ABI _CCCL_CONSTEXPR_CXX14_COMPLEX __promote_t<_Tp, _Up> __constexpr_fmax(_Tp __x, _Up __y) noexcept
{
  using __result_type = __promote_t<_Tp, _Up>;
  return _CUDA_VSTD::__constexpr_fmax(static_cast<__result_type>(__x), static_cast<__result_type>(__y));
}
#endif // !_CCCL_COMPILER(MSVC) && !_CCCL_COMPILER(NVRTC)

// __constexpr_logb(x): the unbiased radix-independent exponent of `x` as a value of type _Tp.
//
// MSVC, NVRTC and clang-cuda: a non-constexpr template that forwards to _CUDA_VSTD::logb.
// Other compilers: when constant evaluation can be detected and constexpr complex operations
// are enabled, the exponent is computed by hand during constant evaluation; otherwise
// __builtin_logb is used.
#if _CCCL_COMPILER(MSVC) || _CCCL_COMPILER(NVRTC) || _CCCL_CUDA_COMPILER(CLANG)
template <class _A1>
_LIBCUDACXX_HIDE_FROM_ABI _A1 __constexpr_logb(_A1 __x)
{
  return _CUDA_VSTD::logb(__x);
}
#else
template <class _Tp>
_LIBCUDACXX_HIDE_FROM_ABI _CCCL_CONSTEXPR_CXX14_COMPLEX _Tp __constexpr_logb(_Tp __x)
{
#  if defined(_CCCL_BUILTIN_IS_CONSTANT_EVALUATED) && _LIBCUDACXX_HAS_CONSTEXPR_COMPLEX_OPERATIONS()
  if (_CCCL_BUILTIN_IS_CONSTANT_EVALUATED())
  {
    if (__x == _Tp(0))
    {
      // raise FE_DIVBYZERO
      return -numeric_limits<_Tp>::infinity();
    }

    if (_CUDA_VSTD::isinf(__x))
    {
      return numeric_limits<_Tp>::infinity();
    }

    if (_CUDA_VSTD::isnan(__x))
    {
      return numeric_limits<_Tp>::quiet_NaN();
    }

    // From here on __x is finite and non-zero. Normalise |__x| into [1, radix) and count the
    // number of radix steps taken; the signed count is the exponent.
    __x             = _CUDA_VSTD::__constexpr_fabs(__x);
    long long __exp = 0;
    // |__x| >= radix: scale down, exponent is positive.
    while (__x >= _Tp(numeric_limits<_Tp>::radix))
    {
      __x /= numeric_limits<_Tp>::radix;
      ++__exp;
    }
    // |__x| < 1: scale up, exponent is negative (e.g. logb(0.5) == -1).
    while (__x < _Tp(1))
    {
      __x *= numeric_limits<_Tp>::radix;
      --__exp;
    }
    return _Tp(__exp);
  }
#  endif // _CCCL_BUILTIN_IS_CONSTANT_EVALUATED && _LIBCUDACXX_HAS_CONSTEXPR_COMPLEX_OPERATIONS()
  // NOTE(review): __builtin_logb is the double-precision builtin, so a long double argument is
  // narrowed to double here -- confirm whether a type-specific dispatch is wanted.
  return __builtin_logb(__x);
}
#endif // !_CCCL_COMPILER(MSVC) && !_CCCL_COMPILER(NVRTC) && !_CCCL_CUDA_COMPILER(CLANG)

// __constexpr_scalbn(x, exp): `x` multiplied by radix raised to the power `exp`.
//
// MSVC, NVRTC and clang-cuda: non-constexpr; the primary template computes in double via
// ::scalbn, with explicit specialisations calling the ::scalbn* function matching the type.
// Other compilers: when constant evaluation can be detected and constexpr complex operations
// are enabled, the scaling is computed by hand during constant evaluation; otherwise
// __builtin_scalbn is used.
#if _CCCL_COMPILER(MSVC) || _CCCL_COMPILER(NVRTC) || _CCCL_CUDA_COMPILER(CLANG)
template <class _Tp>
_LIBCUDACXX_HIDE_FROM_ABI _Tp __constexpr_scalbn(_Tp __x, int __i)
{
  return static_cast<_Tp>(::scalbn(static_cast<double>(__x), __i));
}

template <>
_LIBCUDACXX_HIDE_FROM_ABI float __constexpr_scalbn<float>(float __x, int __i)
{
  return ::scalbnf(__x, __i);
}

template <>
_LIBCUDACXX_HIDE_FROM_ABI double __constexpr_scalbn<double>(double __x, int __i)
{
  return ::scalbn(__x, __i);
}

#  if _CCCL_HAS_LONG_DOUBLE()
template <>
_LIBCUDACXX_HIDE_FROM_ABI long double __constexpr_scalbn<long double>(long double __x, int __i)
{
  return ::scalbnl(__x, __i);
}
#  endif // _CCCL_HAS_LONG_DOUBLE()
#else
template <class _Tp>
_LIBCUDACXX_HIDE_FROM_ABI _CCCL_CONSTEXPR_CXX14_COMPLEX _Tp __constexpr_scalbn(_Tp __x, int __exp)
{
#  if defined(_CCCL_BUILTIN_IS_CONSTANT_EVALUATED) && _LIBCUDACXX_HAS_CONSTEXPR_COMPLEX_OPERATIONS()
  if (_CCCL_BUILTIN_IS_CONSTANT_EVALUATED())
  {
    // Zero, infinity and a zero exponent leave the value untouched.
    if (__x == _Tp(0))
    {
      return __x;
    }

    if (_CUDA_VSTD::isinf(__x))
    {
      return __x;
    }

    if (_Tp(__exp) == _Tp(0))
    {
      return __x;
    }

    if (_CUDA_VSTD::isnan(__x))
    {
      return numeric_limits<_Tp>::quiet_NaN();
    }

    // Square-and-multiply: fold radix^|__exp| (or its reciprocal) into __x one binary digit of
    // the exponent at a time. The magnitude is formed in unsigned arithmetic so that the most
    // negative int does not overflow on negation.
    const bool __scale_up = __exp > 0;
    unsigned int __n      = __scale_up ? static_cast<unsigned int>(__exp) : 0u - static_cast<unsigned int>(__exp);
    _Tp __mult            = __scale_up ? _Tp(numeric_limits<_Tp>::radix) : _Tp(1) / _Tp(numeric_limits<_Tp>::radix);
    for (;;)
    {
      if (__n & 1u)
      {
        __x *= __mult;
      }
      __n >>= 1;
      if (__n == 0u)
      {
        break;
      }
      // Only square while further digits remain. Known limitation: for |__exp| beyond the
      // exponent range of _Tp the multiplier itself may overflow or underflow, and results in
      // the subnormal range may be rounded more than once.
      __mult *= __mult;
    }
    return __x;
  }
#  endif // _CCCL_BUILTIN_IS_CONSTANT_EVALUATED && _LIBCUDACXX_HAS_CONSTEXPR_COMPLEX_OPERATIONS()
  // NOTE(review): __builtin_scalbn is the double-precision builtin, so a long double argument is
  // narrowed to double here -- confirm whether a type-specific dispatch is wanted.
  return __builtin_scalbn(__x, __exp);
}
#endif // !_CCCL_COMPILER(MSVC) && !_CCCL_COMPILER(NVRTC) && !_CCCL_CUDA_COMPILER(CLANG)

_LIBCUDACXX_END_NAMESPACE_STD

_CCCL_POP_MACROS

#endif // _LIBCUDACXX_CMATH
