22
33#include < xsf/legendre.h>
44
5- // backport of std::mdspan (since C++23)
6- #define MDSPAN_USE_PAREN_OPERATOR 1
7- #include < xsf/third_party/kokkos/mdspan.hpp>
8-
9- // Type aliases for commonly used mdspan types
10- using mdspan_1d_double = std::mdspan<double , std::dextents<ptrdiff_t , 1 >>;
11- using mdspan_2d_double = std::mdspan<double , std::dextents<ptrdiff_t , 2 >>;
12- using mdspan_2d_cdouble = std::mdspan<std::complex <double >, std::dextents<ptrdiff_t , 2 >>;
5+ // `std::mdspan` (the kokkos one) was causing issues on windows and mac, so we duck-type it instead
6+ template <typename T>
7+ class duck_mdspan_1d {
8+ private:
9+ T *data_;
10+ size_t size_;
11+
12+ public:
13+ duck_mdspan_1d (T *data, size_t size) : data_(data), size_(size) {}
14+ T &operator ()(size_t i) { return data_[i]; }
15+ const T &operator ()(size_t i) const { return data_[i]; }
16+ size_t size () const { return size_; }
17+ // `xsf::lqn` also uses [] for indexing
18+ T &operator [](size_t i) { return data_[i]; }
19+ const T &operator [](size_t i) const { return data_[i]; }
20+ };
21+
22+ template <typename T>
23+ class duck_mdspan_2d {
24+ private:
25+ T *data_;
26+ size_t rows_, cols_;
27+
28+ public:
29+ duck_mdspan_2d (T *data, size_t rows, size_t cols) : data_(data), rows_(rows), cols_(cols) {}
30+ T &operator ()(size_t i, size_t j) { return data_[i * cols_ + j]; }
31+ const T &operator ()(size_t i, size_t j) const { return data_[i * cols_ + j]; }
32+ size_t extent (int dim) const { return (dim == 0 ) ? rows_ : cols_; }
33+ };
1334
1435// From https://github.com/scipy/scipy/blob/bdd3b0e/scipy/special/tests/test_legendre.py#L693-L697
1536TEST_CASE (" lqmn TestLegendreFunctions.test_lqmn" , " [lqmn][lqn][real][smoketest]" ) {
@@ -21,15 +42,17 @@ TEST_CASE("lqmn TestLegendreFunctions.test_lqmn", "[lqmn][lqn][real][smoketest]"
2142
2243 constexpr int m1p = m + 1 ;
2344 constexpr int n1p = n + 1 ;
45+ // lqmn requires buffer space for at least 2x2
46+ constexpr int bufsize = std::max (2 , m1p) * std::max (2 , n1p);
2447
2548 // lqmnf = special.lqmn(0, 2, .5)
26- double lqmnf0_data[m1p * n1p ], lqmnf1_data[m1p * n1p ];
27- mdspan_2d_double lqmnf0 (lqmnf0_data, m1p, n1p), lqmnf1 (lqmnf1_data, m1p, n1p);
49+ double lqmnf0_data[bufsize ], lqmnf1_data[bufsize ];
50+ duck_mdspan_2d< double > lqmnf0 (lqmnf0_data, m1p, n1p), lqmnf1 (lqmnf1_data, m1p, n1p);
2851 xsf::lqmn (x, lqmnf0, lqmnf1);
2952
3053 // lqf = special.lqn(2, .5)
3154 double lqf0_data[n1p], lqf1_data[n1p];
32- mdspan_1d_double lqf0 (lqf0_data, n1p), lqf1 (lqf1_data, n1p);
55+ duck_mdspan_1d< double > lqf0 (lqf0_data, n1p), lqf1 (lqf1_data, n1p);
3356 xsf::lqn (x, lqf0, lqf1);
3457
3558 // assert_allclose(lqmnf[0][0], lqf[0], atol=1.5e-4, rtol=0)
@@ -69,7 +92,7 @@ TEST_CASE("lqmn TestLegendreFunctions.test_lqmn_gt1", "[lqmn][real][smoketest]")
6992 constexpr int n1p = n + 1 ;
7093
7194 double lqmnf0_data[m1p * n1p], lqmnf1_data[m1p * n1p];
72- mdspan_2d_double lqmnf0 (lqmnf0_data, m1p, n1p), lqmnf1 (lqmnf1_data, m1p, n1p);
95+ duck_mdspan_2d< double > lqmnf0 (lqmnf0_data, m1p, n1p), lqmnf1 (lqmnf1_data, m1p, n1p);
7396
7497 // algorithm for real arguments changes at 1.0001
7598 // test against analytical result for m=2, n=1
@@ -95,17 +118,19 @@ TEST_CASE("lqmn TestLegendreFunctions.test_lqmn_gt1", "[lqmn][real][smoketest]")
95118TEST_CASE (" lqmn complex" , " [lqmn][complex][smoketest]" ) {
96119 constexpr double atol = 1e-16 ;
97120 constexpr double x = 0.5 ;
121+ // lqmn requires buffer space for at least 2x2
122+ constexpr int bufsize = 2 * 2 ;
98123
99124 // (q_mn, qp_mn) = lqmn(0, 0, 0.5)
100- double q_data[1 ], qp_data[1 ];
101- mdspan_2d_double q_mn (q_data, 1 , 1 ), qp_mn (qp_data, 1 , 1 );
125+ double q_data[bufsize ], qp_data[bufsize ];
126+ duck_mdspan_2d< double > q_mn (q_data, 1 , 1 ), qp_mn (qp_data, 1 , 1 );
102127 xsf::lqmn (x, q_mn, qp_mn);
103128 auto q = q_mn (0 , 0 );
104129 auto qp = qp_mn (0 , 0 );
105130
106131 // (cq_mn, cqp_mn) = lqmn(0, 0, 0.5 + 0j)
107- std::complex <double > cq_data[1 ], cqp_data[1 ];
108- mdspan_2d_cdouble cq_mn (cq_data, 1 , 1 ), cqp_mn (cqp_data, 1 , 1 );
132+ std::complex <double > cq_data[bufsize ], cqp_data[bufsize ];
133+ duck_mdspan_2d<std:: complex < double >> cq_mn (cq_data, 1 , 1 ), cqp_mn (cqp_data, 1 , 1 );
109134 xsf::lqmn (std::complex <double >(x, 0.0 ), cq_mn, cqp_mn);
110135 auto cq = cq_mn (0 , 0 );
111136 auto cqp = cqp_mn (0 , 0 );
0 commit comments