Skip to content

Commit e8a2aa3

Browse files
committed
TST: replace 🦆-type with std::mdspan
1 parent d3d334d commit e8a2aa3

File tree

1 file changed

+13
-34
lines changed

1 file changed

+13
-34
lines changed

‎tests/scipy_special_tests/test_lqmn.cpp‎

Lines changed: 13 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -2,35 +2,9 @@
22

33
#include <xsf/legendre.h>
44

5-
// `std::mdspan` (the kokkos one) was causing issues on windows and mac, so we duck-type it instead
6-
template <typename T>
7-
class duck_mdspan_1d {
8-
private:
9-
T *data_;
10-
size_t size_;
11-
12-
public:
13-
duck_mdspan_1d(T *data, size_t size) : data_(data), size_(size) {}
14-
T &operator()(size_t i) { return data_[i]; }
15-
const T &operator()(size_t i) const { return data_[i]; }
16-
size_t size() const { return size_; }
17-
// `xsf::lqn` also uses [] for indexing
18-
T &operator[](size_t i) { return data_[i]; }
19-
const T &operator[](size_t i) const { return data_[i]; }
20-
};
21-
22-
template <typename T>
23-
class duck_mdspan_2d {
24-
private:
25-
T *data_;
26-
size_t rows_, cols_;
27-
28-
public:
29-
duck_mdspan_2d(T *data, size_t rows, size_t cols) : data_(data), rows_(rows), cols_(cols) {}
30-
T &operator()(size_t i, size_t j) { return data_[i * cols_ + j]; }
31-
const T &operator()(size_t i, size_t j) const { return data_[i * cols_ + j]; }
32-
size_t extent(int dim) const { return (dim == 0) ? rows_ : cols_; }
33-
};
5+
// backport of std::mdspan (since C++23)
6+
#define MDSPAN_USE_PAREN_OPERATOR 1
7+
#include <xsf/third_party/kokkos/mdspan.hpp>
348

359
// From https://github.com/scipy/scipy/blob/bdd3b0e/scipy/special/tests/test_legendre.py#L693-L697
3610
TEST_CASE("lqmn TestLegendreFunctions.test_lqmn", "[lqmn][lqn][real][smoketest]") {
@@ -47,12 +21,14 @@ TEST_CASE("lqmn TestLegendreFunctions.test_lqmn", "[lqmn][lqn][real][smoketest]"
4721

4822
// lqmnf = special.lqmn(0, 2, .5)
4923
double lqmnf0_data[bufsize], lqmnf1_data[bufsize];
50-
duck_mdspan_2d<double> lqmnf0(lqmnf0_data, m1p, n1p), lqmnf1(lqmnf1_data, m1p, n1p);
24+
auto lqmnf0 = std::mdspan(lqmnf0_data, m1p, n1p);
25+
auto lqmnf1 = std::mdspan(lqmnf1_data, m1p, n1p);
5126
xsf::lqmn(x, lqmnf0, lqmnf1);
5227

5328
// lqf = special.lqn(2, .5)
5429
double lqf0_data[n1p], lqf1_data[n1p];
55-
duck_mdspan_1d<double> lqf0(lqf0_data, n1p), lqf1(lqf1_data, n1p);
30+
auto lqf0 = std::mdspan(lqf0_data, n1p);
31+
auto lqf1 = std::mdspan(lqf1_data, n1p);
5632
xsf::lqn(x, lqf0, lqf1);
5733

5834
// assert_allclose(lqmnf[0][0], lqf[0], atol=1.5e-4, rtol=0)
@@ -92,7 +68,8 @@ TEST_CASE("lqmn TestLegendreFunctions.test_lqmn_gt1", "[lqmn][real][smoketest]")
9268
constexpr int n1p = n + 1;
9369

9470
double lqmnf0_data[m1p * n1p], lqmnf1_data[m1p * n1p];
95-
duck_mdspan_2d<double> lqmnf0(lqmnf0_data, m1p, n1p), lqmnf1(lqmnf1_data, m1p, n1p);
71+
auto lqmnf0 = std::mdspan(lqmnf0_data, m1p, n1p);
72+
auto lqmnf1 = std::mdspan(lqmnf1_data, m1p, n1p);
9673

9774
// algorithm for real arguments changes at 1.0001
9875
// test against analytical result for m=2, n=1
@@ -123,14 +100,16 @@ TEST_CASE("lqmn complex", "[lqmn][complex][smoketest]") {
123100

124101
// (q_mn, qp_mn) = lqmn(0, 0, 0.5)
125102
double q_data[bufsize], qp_data[bufsize];
126-
duck_mdspan_2d<double> q_mn(q_data, 1, 1), qp_mn(qp_data, 1, 1);
103+
auto q_mn = std::mdspan(q_data, 1, 1);
104+
auto qp_mn = std::mdspan(qp_data, 1, 1);
127105
xsf::lqmn(x, q_mn, qp_mn);
128106
auto q = q_mn(0, 0);
129107
auto qp = qp_mn(0, 0);
130108

131109
// (cq_mn, cqp_mn) = lqmn(0, 0, 0.5 + 0j)
132110
std::complex<double> cq_data[bufsize], cqp_data[bufsize];
133-
duck_mdspan_2d<std::complex<double>> cq_mn(cq_data, 1, 1), cqp_mn(cqp_data, 1, 1);
111+
auto cq_mn = std::mdspan(cq_data, 1, 1);
112+
auto cqp_mn = std::mdspan(cqp_data, 1, 1);
134113
xsf::lqmn(std::complex<double>(x, 0.0), cq_mn, cqp_mn);
135114
auto cq = cq_mn(0, 0);
136115
auto cqp = cqp_mn(0, 0);

0 commit comments

Comments
 (0)