Skip to content

Commit 9a78af4

Browse files
authored
Implement std::chi_squared_distribution (#6856)
1 parent 6d2f46e commit 9a78af4

10 files changed

Lines changed: 249 additions & 7 deletions

File tree

libcudacxx/include/cuda/std/__random/bernoulli_distribution.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,9 +20,9 @@
2020
# pragma system_header
2121
#endif // no system header
2222

23+
#include <cuda/std/__limits/numeric_limits.h>
2324
#include <cuda/std/__random/generate_canonical.h>
2425
#include <cuda/std/__random/is_valid.h>
25-
#include <cuda/std/limits>
2626
#if !_CCCL_COMPILER(NVRTC)
2727
# include <ios>
2828
#endif // !_CCCL_COMPILER(NVRTC)

libcudacxx/include/cuda/std/__random/cauchy_distribution.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,9 +21,9 @@
2121
#endif // no system header
2222

2323
#include <cuda/std/__cmath/trigonometric_functions.h>
24+
#include <cuda/std/__limits/numeric_limits.h>
2425
#include <cuda/std/__random/is_valid.h>
2526
#include <cuda/std/__random/uniform_real_distribution.h>
26-
#include <cuda/std/limits>
2727

2828
#include <cuda/std/__cccl/prologue.h>
2929

Lines changed: 179 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,179 @@
1+
//===----------------------------------------------------------------------===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
// SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES.
7+
//
8+
//===----------------------------------------------------------------------===//
9+
10+
#ifndef _CUDA_STD___CHI_SQUARED_DISTRIBUTION_H
11+
#define _CUDA_STD___CHI_SQUARED_DISTRIBUTION_H
12+
13+
#include <cuda/std/detail/__config>
14+
15+
#if defined(_CCCL_IMPLICIT_SYSTEM_HEADER_GCC)
16+
# pragma GCC system_header
17+
#elif defined(_CCCL_IMPLICIT_SYSTEM_HEADER_CLANG)
18+
# pragma clang system_header
19+
#elif defined(_CCCL_IMPLICIT_SYSTEM_HEADER_MSVC)
20+
# pragma system_header
21+
#endif // no system header
22+
23+
#include <cuda/std/__limits/numeric_limits.h>
24+
#include <cuda/std/__random/gamma_distribution.h>
25+
#include <cuda/std/__random/is_valid.h>
26+
27+
#include <cuda/std/__cccl/prologue.h>
28+
29+
_CCCL_BEGIN_NAMESPACE_CUDA_STD
30+
31+
template <class _RealType = double>
32+
class chi_squared_distribution
33+
{
34+
static_assert(__libcpp_random_is_valid_realtype<_RealType>, "RealType must be a supported floating-point type");
35+
36+
public:
37+
// types
38+
using result_type = _RealType;
39+
40+
class param_type
41+
{
42+
private:
43+
result_type __n_ = result_type{1};
44+
45+
public:
46+
using distribution_type = chi_squared_distribution;
47+
48+
constexpr param_type() noexcept = default;
49+
50+
_CCCL_API constexpr explicit param_type(result_type __n) noexcept
51+
: __n_{__n}
52+
{}
53+
54+
[[nodiscard]] _CCCL_API constexpr result_type n() const noexcept
55+
{
56+
return __n_;
57+
}
58+
59+
[[nodiscard]] friend _CCCL_API constexpr bool operator==(const param_type& __x, const param_type& __y) noexcept
60+
{
61+
return __x.__n_ == __y.__n_;
62+
}
63+
#if _CCCL_STD_VER <= 2017
64+
[[nodiscard]] friend _CCCL_API constexpr bool operator!=(const param_type& __x, const param_type& __y) noexcept
65+
{
66+
return !(__x == __y);
67+
}
68+
#endif // _CCCL_STD_VER <= 2017
69+
};
70+
71+
private:
72+
param_type __p_{};
73+
74+
public:
75+
// constructor and reset functions
76+
constexpr chi_squared_distribution() noexcept = default;
77+
78+
_CCCL_API constexpr explicit chi_squared_distribution(result_type __n) noexcept
79+
: __p_{param_type{__n}}
80+
{}
81+
_CCCL_API constexpr explicit chi_squared_distribution(const param_type& __p) noexcept
82+
: __p_{__p}
83+
{}
84+
_CCCL_API void reset() noexcept {}
85+
86+
// generating functions
87+
template <class _URng>
88+
[[nodiscard]] _CCCL_API result_type operator()(_URng& __g)
89+
{
90+
return (*this)(__g, __p_);
91+
}
92+
template <class _URng>
93+
[[nodiscard]] _CCCL_API result_type operator()(_URng& __g, const param_type& __p)
94+
{
95+
static_assert(__cccl_random_is_valid_urng<_URng>, "URng must meet the UniformRandomBitGenerator requirements");
96+
return gamma_distribution<result_type>(__p.n() / 2, 2)(__g);
97+
}
98+
99+
// property functions
100+
[[nodiscard]] _CCCL_API constexpr result_type n() const noexcept
101+
{
102+
return __p_.n();
103+
}
104+
105+
[[nodiscard]] _CCCL_API constexpr param_type param() const noexcept
106+
{
107+
return __p_;
108+
}
109+
_CCCL_API constexpr void param(const param_type& __p) noexcept
110+
{
111+
__p_ = __p;
112+
}
113+
114+
[[nodiscard]] _CCCL_API static constexpr result_type min() noexcept
115+
{
116+
return result_type{0};
117+
}
118+
[[nodiscard]] _CCCL_API static constexpr result_type max() noexcept
119+
{
120+
return numeric_limits<result_type>::infinity();
121+
}
122+
123+
[[nodiscard]] friend _CCCL_API constexpr bool
124+
operator==(const chi_squared_distribution& __x, const chi_squared_distribution& __y) noexcept
125+
{
126+
return __x.__p_ == __y.__p_;
127+
}
128+
#if _CCCL_STD_VER <= 2017
129+
[[nodiscard]] friend _CCCL_API constexpr bool
130+
operator!=(const chi_squared_distribution& __x, const chi_squared_distribution& __y) noexcept
131+
{
132+
return !(__x == __y);
133+
}
134+
#endif // _CCCL_STD_VER <= 2017
135+
136+
#if !_CCCL_COMPILER(NVRTC)
137+
template <class _CharT, class _Traits>
138+
friend ::std::basic_ostream<_CharT, _Traits>&
139+
operator<<(::std::basic_ostream<_CharT, _Traits>& __os, const chi_squared_distribution& __x)
140+
{
141+
using ostream_type = ::std::basic_ostream<_CharT, _Traits>;
142+
using ios_base = typename ostream_type::ios_base;
143+
const typename ios_base::fmtflags __flags = __os.flags();
144+
const _CharT __fill = __os.fill();
145+
const ::std::streamsize __precision = __os.precision();
146+
__os.flags(ios_base::dec | ios_base::left | ios_base::scientific);
147+
__os.precision(numeric_limits<result_type>::max_digits10);
148+
__os << __x.n();
149+
__os.flags(__flags);
150+
__os.fill(__fill);
151+
__os.precision(__precision);
152+
return __os;
153+
}
154+
155+
template <class _CharT, class _Traits>
156+
friend ::std::basic_istream<_CharT, _Traits>&
157+
operator>>(::std::basic_istream<_CharT, _Traits>& __is, chi_squared_distribution& __x)
158+
{
159+
using istream_type = ::std::basic_istream<_CharT, _Traits>;
160+
using ios_base = typename istream_type::ios_base;
161+
const typename ios_base::fmtflags __flags = __is.flags();
162+
__is.flags(ios_base::dec | ios_base::skipws);
163+
result_type __n;
164+
__is >> __n;
165+
if (!__is.fail())
166+
{
167+
__x.param(param_type{__n});
168+
}
169+
__is.flags(__flags);
170+
return __is;
171+
}
172+
#endif // !_CCCL_COMPILER(NVRTC)
173+
};
174+
175+
_CCCL_END_NAMESPACE_CUDA_STD
176+
177+
#include <cuda/std/__cccl/epilogue.h>
178+
179+
#endif // _CUDA_STD___CHI_SQUARED_DISTRIBUTION_H

libcudacxx/include/cuda/std/__random/generate_canonical.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,8 +21,8 @@
2121
#endif // no system header
2222

2323
#include <cuda/std/__bit/integral.h>
24+
#include <cuda/std/__limits/numeric_limits.h>
2425
#include <cuda/std/cstdint>
25-
#include <cuda/std/limits>
2626

2727
#include <cuda/std/__cccl/prologue.h>
2828

libcudacxx/include/cuda/std/__random/normal_distribution.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,9 +14,9 @@
1414

1515
#include <cuda/std/__cmath/logarithms.h>
1616
#include <cuda/std/__cmath/roots.h>
17+
#include <cuda/std/__limits/numeric_limits.h>
1718
#include <cuda/std/__random/is_valid.h>
1819
#include <cuda/std/__random/uniform_real_distribution.h>
19-
#include <cuda/std/limits>
2020

2121
#if defined(_CCCL_IMPLICIT_SYSTEM_HEADER_GCC)
2222
# pragma GCC system_header

libcudacxx/include/cuda/std/__random/poisson_distribution.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,12 +20,12 @@
2020
# pragma system_header
2121
#endif // no system header
2222

23+
#include <cuda/std/__limits/numeric_limits.h>
2324
#include <cuda/std/__random/generate_canonical.h>
2425
#include <cuda/std/__random/is_valid.h>
2526
#include <cuda/std/__random/normal_distribution.h>
2627
#include <cuda/std/__random/uniform_real_distribution.h>
2728
#include <cuda/std/cmath>
28-
#include <cuda/std/limits>
2929
#if !_CCCL_COMPILER(NVRTC)
3030
# include <ios>
3131
#endif // !_CCCL_COMPILER(NVRTC)

libcudacxx/include/cuda/std/__random/uniform_int_distribution.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,12 +22,12 @@
2222

2323
#include <cuda/std/__bit/countl.h>
2424
#include <cuda/std/__bit/integral.h>
25+
#include <cuda/std/__limits/numeric_limits.h>
2526
#include <cuda/std/__random/is_valid.h>
2627
#include <cuda/std/__type_traits/conditional.h>
2728
#include <cuda/std/__type_traits/make_unsigned.h>
2829
#include <cuda/std/cstddef>
2930
#include <cuda/std/cstdint>
30-
#include <cuda/std/limits>
3131

3232
#include <cuda/std/__cccl/prologue.h>
3333

libcudacxx/include/cuda/std/__random/uniform_real_distribution.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,9 +20,9 @@
2020
# pragma system_header
2121
#endif // no system header
2222

23+
#include <cuda/std/__limits/numeric_limits.h>
2324
#include <cuda/std/__random/generate_canonical.h>
2425
#include <cuda/std/__random/is_valid.h>
25-
#include <cuda/std/limits>
2626

2727
#if !_CCCL_COMPILER(NVRTC)
2828
# include <iosfwd>

libcudacxx/include/cuda/std/__random_

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
#include <cuda/std/__random/bernoulli_distribution.h>
2525
#include <cuda/std/__random/binomial_distribution.h>
2626
#include <cuda/std/__random/cauchy_distribution.h>
27+
#include <cuda/std/__random/chi_squared_distribution.h>
2728
#include <cuda/std/__random/exponential_distribution.h>
2829
#include <cuda/std/__random/extreme_value_distribution.h>
2930
#include <cuda/std/__random/fisher_f_distribution.h>
Lines changed: 62 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,62 @@
1+
//===----------------------------------------------------------------------===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
// SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES.
7+
//
8+
//===----------------------------------------------------------------------===//
9+
//
10+
// REQUIRES: long_tests
11+
12+
// <random>
13+
14+
// template<class RealType = double>
15+
// class chi_squared_distribution
16+
17+
#include <cuda/std/__random_>
18+
#include <cuda/std/cassert>
19+
#include <cuda/std/cmath>
20+
21+
#include "random_utilities/stats_functions.h"
22+
#include "random_utilities/test_distribution.h"
23+
#include "test_macros.h"
24+
25+
template <class T>
26+
struct chi_squared_cdf
27+
{
28+
using P = typename cuda::std::chi_squared_distribution<T>::param_type;
29+
30+
__host__ __device__ double operator()(double x, const P& p) const
31+
{
32+
if (x <= 0.0)
33+
{
34+
return 0.0;
35+
}
36+
37+
// Chi-squared distribution is a special case of gamma distribution
38+
// Chi-squared(k) = Gamma(k/2, 2)
39+
// CDF: P(k/2, x/2)
40+
double k = p.n();
41+
return incomplete_gamma(k / 2.0, x / 2.0);
42+
}
43+
};
44+
45+
template <class T>
46+
__host__ __device__ void test()
47+
{
48+
// Can be true if/when cuda::std::lgamma is constexpr
49+
[[maybe_unused]] const bool test_constexpr = false;
50+
using D = cuda::std::chi_squared_distribution<T>;
51+
using P = typename D::param_type;
52+
using G = cuda::std::philox4x64;
53+
cuda::std::array<P, 5> params = {P(1), P(2), P(3), P(5), P(10)};
54+
test_distribution<D, true, G, test_constexpr>(params, chi_squared_cdf<T>{});
55+
}
56+
57+
int main(int, char**)
58+
{
59+
test<double>();
60+
test<float>();
61+
return 0;
62+
}

0 commit comments

Comments
 (0)