Skip to content

Commit 4929f4a

Browse files
committed
TST: work around out of bound indexing for small lqmn matrices
1 parent e2b4ee3 commit 4929f4a

File tree

1 file changed

+41
-16
lines changed

1 file changed

+41
-16
lines changed

tests/scipy_special_tests/test_lqmn.cpp

Lines changed: 41 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -2,14 +2,35 @@
22

33
#include <xsf/legendre.h>
44

5-
// backport of std::mdspan (since C++23)
6-
#define MDSPAN_USE_PAREN_OPERATOR 1
7-
#include <xsf/third_party/kokkos/mdspan.hpp>
8-
9-
// Type aliases for commonly used mdspan types
10-
using mdspan_1d_double = std::mdspan<double, std::dextents<ptrdiff_t, 1>>;
11-
using mdspan_2d_double = std::mdspan<double, std::dextents<ptrdiff_t, 2>>;
12-
using mdspan_2d_cdouble = std::mdspan<std::complex<double>, std::dextents<ptrdiff_t, 2>>;
5+
// `std::mdspan` (the kokkos one) was causing issues on windows and mac, so we duck-type it instead
6+
template <typename T>
7+
class duck_mdspan_1d {
8+
private:
9+
T *data_;
10+
size_t size_;
11+
12+
public:
13+
duck_mdspan_1d(T *data, size_t size) : data_(data), size_(size) {}
14+
T &operator()(size_t i) { return data_[i]; }
15+
const T &operator()(size_t i) const { return data_[i]; }
16+
size_t size() const { return size_; }
17+
// `xsf::lqn` also uses [] for indexing
18+
T &operator[](size_t i) { return data_[i]; }
19+
const T &operator[](size_t i) const { return data_[i]; }
20+
};
21+
22+
template <typename T>
23+
class duck_mdspan_2d {
24+
private:
25+
T *data_;
26+
size_t rows_, cols_;
27+
28+
public:
29+
duck_mdspan_2d(T *data, size_t rows, size_t cols) : data_(data), rows_(rows), cols_(cols) {}
30+
T &operator()(size_t i, size_t j) { return data_[i * cols_ + j]; }
31+
const T &operator()(size_t i, size_t j) const { return data_[i * cols_ + j]; }
32+
size_t extent(int dim) const { return (dim == 0) ? rows_ : cols_; }
33+
};
1334

1435
// From https://github.com/scipy/scipy/blob/bdd3b0e/scipy/special/tests/test_legendre.py#L693-L697
1536
TEST_CASE("lqmn TestLegendreFunctions.test_lqmn", "[lqmn][lqn][real][smoketest]") {
@@ -21,15 +42,17 @@ TEST_CASE("lqmn TestLegendreFunctions.test_lqmn", "[lqmn][lqn][real][smoketest]"
2142

2243
constexpr int m1p = m + 1;
2344
constexpr int n1p = n + 1;
45+
// lqmn requires buffer space for at least 2x2
46+
constexpr int bufsize = std::max(2, m1p) * std::max(2, n1p);
2447

2548
// lqmnf = special.lqmn(0, 2, .5)
26-
double lqmnf0_data[m1p * n1p], lqmnf1_data[m1p * n1p];
27-
mdspan_2d_double lqmnf0(lqmnf0_data, m1p, n1p), lqmnf1(lqmnf1_data, m1p, n1p);
49+
double lqmnf0_data[bufsize], lqmnf1_data[bufsize];
50+
duck_mdspan_2d<double> lqmnf0(lqmnf0_data, m1p, n1p), lqmnf1(lqmnf1_data, m1p, n1p);
2851
xsf::lqmn(x, lqmnf0, lqmnf1);
2952

3053
// lqf = special.lqn(2, .5)
3154
double lqf0_data[n1p], lqf1_data[n1p];
32-
mdspan_1d_double lqf0(lqf0_data, n1p), lqf1(lqf1_data, n1p);
55+
duck_mdspan_1d<double> lqf0(lqf0_data, n1p), lqf1(lqf1_data, n1p);
3356
xsf::lqn(x, lqf0, lqf1);
3457

3558
// assert_allclose(lqmnf[0][0], lqf[0], atol=1.5e-4, rtol=0)
@@ -69,7 +92,7 @@ TEST_CASE("lqmn TestLegendreFunctions.test_lqmn_gt1", "[lqmn][real][smoketest]")
6992
constexpr int n1p = n + 1;
7093

7194
double lqmnf0_data[m1p * n1p], lqmnf1_data[m1p * n1p];
72-
mdspan_2d_double lqmnf0(lqmnf0_data, m1p, n1p), lqmnf1(lqmnf1_data, m1p, n1p);
95+
duck_mdspan_2d<double> lqmnf0(lqmnf0_data, m1p, n1p), lqmnf1(lqmnf1_data, m1p, n1p);
7396

7497
// algorithm for real arguments changes at 1.0001
7598
// test against analytical result for m=2, n=1
@@ -95,17 +118,19 @@ TEST_CASE("lqmn TestLegendreFunctions.test_lqmn_gt1", "[lqmn][real][smoketest]")
95118
TEST_CASE("lqmn complex", "[lqmn][complex][smoketest]") {
96119
constexpr double atol = 1e-16;
97120
constexpr double x = 0.5;
121+
// lqmn requires buffer space for at least 2x2
122+
constexpr int bufsize = 2 * 2;
98123

99124
// (q_mn, qp_mn) = lqmn(0, 0, 0.5)
100-
double q_data[1], qp_data[1];
101-
mdspan_2d_double q_mn(q_data, 1, 1), qp_mn(qp_data, 1, 1);
125+
double q_data[bufsize], qp_data[bufsize];
126+
duck_mdspan_2d<double> q_mn(q_data, 1, 1), qp_mn(qp_data, 1, 1);
102127
xsf::lqmn(x, q_mn, qp_mn);
103128
auto q = q_mn(0, 0);
104129
auto qp = qp_mn(0, 0);
105130

106131
// (cq_mn, cqp_mn) = lqmn(0, 0, 0.5 + 0j)
107-
std::complex<double> cq_data[1], cqp_data[1];
108-
mdspan_2d_cdouble cq_mn(cq_data, 1, 1), cqp_mn(cqp_data, 1, 1);
132+
std::complex<double> cq_data[bufsize], cqp_data[bufsize];
133+
duck_mdspan_2d<std::complex<double>> cq_mn(cq_data, 1, 1), cqp_mn(cqp_data, 1, 1);
109134
xsf::lqmn(std::complex<double>(x, 0.0), cq_mn, cqp_mn);
110135
auto cq = cq_mn(0, 0);
111136
auto cqp = cqp_mn(0, 0);

0 commit comments

Comments
 (0)