diff --git a/src/rng/backends/rocrand/mrg32k3a.cpp b/src/rng/backends/rocrand/mrg32k3a.cpp index 1709bd6c7..424f14caf 100644 --- a/src/rng/backends/rocrand/mrg32k3a.cpp +++ b/src/rng/backends/rocrand/mrg32k3a.cpp @@ -1,7 +1,7 @@ /******************************************************************************* * Copyright (C) 2022 Heidelberg University, Engineering Mathematics and Computing Lab (EMCL) * and Computing Centre (URZ) - * cuRAND back-end Copyright (c) 2021, The Regents of the University of + * rocRAND back-end Copyright (c) 2021, The Regents of the University of * California, through Lawrence Berkeley National Laboratory (subject to receipt * of any required approvals from the U.S. Dept. of Energy). All rights * reserved. @@ -88,7 +88,9 @@ namespace rocrand { class mrg32k3a_impl : public oneapi::mkl::rng::detail::engine_impl { public: mrg32k3a_impl(sycl::queue queue, std::uint32_t seed) - : oneapi::mkl::rng::detail::engine_impl(queue) { + : oneapi::mkl::rng::detail::engine_impl(queue), + seed_(seed), + offset_(0) { rocrand_status status; ROCRAND_CALL(rocrand_create_generator, status, &engine_, ROCRAND_RNG_PSEUDO_MRG32K3A); ROCRAND_CALL(rocrand_set_seed, status, engine_, (unsigned long long)seed); @@ -97,12 +99,19 @@ class mrg32k3a_impl : public oneapi::mkl::rng::detail::engine_impl { mrg32k3a_impl(sycl::queue queue, std::initializer_list seed) : oneapi::mkl::rng::detail::engine_impl(queue) { throw oneapi::mkl::unimplemented("rng", "mrg32ka engine", - "multi-seed unsupported by cuRAND backend"); + "multi-seed unsupported by rocRAND backend"); } - mrg32k3a_impl(const mrg32k3a_impl* other) : oneapi::mkl::rng::detail::engine_impl(*other) { - throw oneapi::mkl::unimplemented("rng", "mrg32ka engine", - "copy construction unsupported by cuRAND backend"); + mrg32k3a_impl(const mrg32k3a_impl* other) + : oneapi::mkl::rng::detail::engine_impl(*other), + seed_(other->seed_), + offset_(other->offset_) { + rocrand_status status; + ROCRAND_CALL(rocrand_create_generator, 
status, &engine_, ROCRAND_RNG_PSEUDO_MRG32K3A); + ROCRAND_CALL(rocrand_set_seed, status, engine_, (unsigned long long)seed_); + + // Align this->engine_'s offset state with other->engine_'s offset + skip_ahead(offset_); } // Buffers API @@ -119,6 +128,9 @@ class mrg32k3a_impl : public oneapi::mkl::rng::detail::engine_impl { }); }) .wait_and_throw(); + + increment_internal_offset(n); + range_transform_fp(queue_, distr.a(), distr.b(), n, r); } @@ -134,6 +146,9 @@ class mrg32k3a_impl : public oneapi::mkl::rng::detail::engine_impl { }); }) .wait_and_throw(); + + increment_internal_offset(n); + range_transform_fp(queue_, distr.a(), distr.b(), n, r); } @@ -150,6 +165,9 @@ class mrg32k3a_impl : public oneapi::mkl::rng::detail::engine_impl { }); }) .wait_and_throw(); + + increment_internal_offset(n); + range_transform_int(queue_, distr.a(), distr.b(), n, ib, r); } @@ -165,6 +183,9 @@ class mrg32k3a_impl : public oneapi::mkl::rng::detail::engine_impl { }); }) .wait_and_throw(); + + increment_internal_offset(n); + range_transform_fp_accurate(queue_, distr.a(), distr.b(), n, r); } @@ -180,6 +201,9 @@ class mrg32k3a_impl : public oneapi::mkl::rng::detail::engine_impl { }); }) .wait_and_throw(); + + increment_internal_offset(n); + range_transform_fp_accurate(queue_, distr.a(), distr.b(), n, r); } @@ -196,6 +220,8 @@ class mrg32k3a_impl : public oneapi::mkl::rng::detail::engine_impl { }); }) .wait_and_throw(); + + increment_internal_offset(n); } virtual void generate(const oneapi::mkl::rng::gaussian< @@ -211,22 +237,42 @@ class mrg32k3a_impl : public oneapi::mkl::rng::detail::engine_impl { }); }) .wait_and_throw(); + + increment_internal_offset(n); } virtual void generate( const oneapi::mkl::rng::gaussian& distr, std::int64_t n, sycl::buffer& r) override { - throw oneapi::mkl::unimplemented( - "rng", "mrg32ka engine", - "ICDF method not used for pseudorandom generators in cuRAND backend"); + queue_ + .submit([&](sycl::handler& cgh) { + auto acc = r.get_access(cgh); + 
onemkl_rocrand_host_task(cgh, acc, engine_, [=](float* r_ptr) { + rocrand_status status; + ROCRAND_CALL(rocrand_generate_normal, status, engine_, r_ptr, n, distr.mean(), + distr.stddev()); + }); + }) + .wait_and_throw(); + + increment_internal_offset(n); } virtual void generate( const oneapi::mkl::rng::gaussian& distr, std::int64_t n, sycl::buffer& r) override { - throw oneapi::mkl::unimplemented( - "rng", "mrg32ka engine", - "ICDF method not used for pseudorandom generators in cuRAND backend"); + queue_ + .submit([&](sycl::handler& cgh) { + auto acc = r.get_access(cgh); + onemkl_rocrand_host_task(cgh, acc, engine_, [=](double* r_ptr) { + rocrand_status status; + ROCRAND_CALL(rocrand_generate_normal_double, status, engine_, r_ptr, n, + distr.mean(), distr.stddev()); + }); + }) + .wait_and_throw(); + + increment_internal_offset(n); } virtual void generate(const oneapi::mkl::rng::lognormal< @@ -242,6 +288,8 @@ class mrg32k3a_impl : public oneapi::mkl::rng::detail::engine_impl { }); }) .wait_and_throw(); + + increment_internal_offset(n); } virtual void generate(const oneapi::mkl::rng::lognormal< @@ -257,50 +305,88 @@ class mrg32k3a_impl : public oneapi::mkl::rng::detail::engine_impl { }); }) .wait_and_throw(); + + increment_internal_offset(n); } virtual void generate( const oneapi::mkl::rng::lognormal& distr, std::int64_t n, sycl::buffer& r) override { - throw oneapi::mkl::unimplemented( - "rng", "mrg32ka engine", - "ICDF method not used for pseudorandom generators in cuRAND backend"); + queue_ + .submit([&](sycl::handler& cgh) { + auto acc = r.get_access(cgh); + onemkl_rocrand_host_task(cgh, acc, engine_, [=](float* r_ptr) { + rocrand_status status; + ROCRAND_CALL(rocrand_generate_log_normal, status, engine_, r_ptr, n, distr.m(), + distr.s()); + }); + }) + .wait_and_throw(); + + increment_internal_offset(n); } virtual void generate( const oneapi::mkl::rng::lognormal& distr, std::int64_t n, sycl::buffer& r) override { - throw oneapi::mkl::unimplemented( - "rng", 
"mrg32ka engine", - "ICDF method not used for pseudorandom generators in cuRAND backend"); + queue_ + .submit([&](sycl::handler& cgh) { + auto acc = r.get_access(cgh); + onemkl_rocrand_host_task(cgh, acc, engine_, [=](double* r_ptr) { + rocrand_status status; + ROCRAND_CALL(rocrand_generate_log_normal_double, status, engine_, r_ptr, n, + distr.m(), distr.s()); + }); + }) + .wait_and_throw(); + + increment_internal_offset(n); } virtual void generate(const bernoulli& distr, std::int64_t n, sycl::buffer& r) override { throw oneapi::mkl::unimplemented( "rng", "mrg32ka engine", - "ICDF method not used for pseudorandom generators in cuRAND backend"); + "Bernoulli distribution method unsupported by rocRAND backend"); } virtual void generate(const bernoulli& distr, std::int64_t n, sycl::buffer& r) override { throw oneapi::mkl::unimplemented( "rng", "mrg32ka engine", - "ICDF method not used for pseudorandom generators in cuRAND backend"); + "Bernoulli distribution method unsupported by rocRAND backend"); } virtual void generate(const poisson& distr, std::int64_t n, sycl::buffer& r) override { - throw oneapi::mkl::unimplemented( - "rng", "mrg32ka engine", - "ICDF method not used for pseudorandom generators in cuRAND backend"); + queue_ + .submit([&](sycl::handler& cgh) { + auto acc = r.get_access(cgh); + onemkl_rocrand_host_task(cgh, acc, engine_, [=](std::int32_t* r_ptr) { + rocrand_status status; + ROCRAND_CALL(rocrand_generate_poisson, status, engine_, (std::uint32_t*)r_ptr, + n, distr.lambda()); + }); + }) + .wait_and_throw(); + + increment_internal_offset(n); } virtual void generate(const poisson& distr, std::int64_t n, sycl::buffer& r) override { - throw oneapi::mkl::unimplemented( - "rng", "mrg32ka engine", - "ICDF method not used for pseudorandom generators in cuRAND backend"); + queue_ + .submit([&](sycl::handler& cgh) { + auto acc = r.get_access(cgh); + onemkl_rocrand_host_task(cgh, acc, engine_, [=](std::uint32_t* r_ptr) { + rocrand_status status; + 
ROCRAND_CALL(rocrand_generate_poisson, status, engine_, r_ptr, n, + distr.lambda()); + }); + }) + .wait_and_throw(); + + increment_internal_offset(n); } virtual void generate(const bits& distr, std::int64_t n, @@ -314,6 +400,8 @@ class mrg32k3a_impl : public oneapi::mkl::rng::detail::engine_impl { }); }) .wait_and_throw(); + + increment_internal_offset(n); } // USM APIs @@ -330,6 +418,9 @@ class mrg32k3a_impl : public oneapi::mkl::rng::detail::engine_impl { }); }) .wait_and_throw(); + + increment_internal_offset(n); + return range_transform_fp(queue_, distr.a(), distr.b(), n, r); } @@ -345,6 +436,9 @@ class mrg32k3a_impl : public oneapi::mkl::rng::detail::engine_impl { }); }) .wait_and_throw(); + + increment_internal_offset(n); + return range_transform_fp(queue_, distr.a(), distr.b(), n, r); } @@ -362,6 +456,9 @@ class mrg32k3a_impl : public oneapi::mkl::rng::detail::engine_impl { }); }) .wait_and_throw(); + + increment_internal_offset(n); + return range_transform_int(queue_, distr.a(), distr.b(), n, ib, r); } @@ -377,6 +474,9 @@ class mrg32k3a_impl : public oneapi::mkl::rng::detail::engine_impl { }); }) .wait_and_throw(); + + increment_internal_offset(n); + return range_transform_fp_accurate(queue_, distr.a(), distr.b(), n, r); } @@ -392,6 +492,9 @@ class mrg32k3a_impl : public oneapi::mkl::rng::detail::engine_impl { }); }) .wait_and_throw(); + + increment_internal_offset(n); + return range_transform_fp_accurate(queue_, distr.a(), distr.b(), n, r); } @@ -400,13 +503,17 @@ class mrg32k3a_impl : public oneapi::mkl::rng::detail::engine_impl { distr, std::int64_t n, float* r, const std::vector& dependencies) override { sycl::event::wait_and_throw(dependencies); - return queue_.submit([&](sycl::handler& cgh) { + auto event = queue_.submit([&](sycl::handler& cgh) { onemkl_rocrand_host_task(cgh, engine_, [=](sycl::interop_handle ih) { rocrand_status status; ROCRAND_CALL(rocrand_generate_normal, status, engine_, r, n, distr.mean(), distr.stddev()); }); }); + + 
increment_internal_offset(n); + + return event; } virtual sycl::event generate( @@ -414,31 +521,51 @@ class mrg32k3a_impl : public oneapi::mkl::rng::detail::engine_impl { distr, std::int64_t n, double* r, const std::vector& dependencies) override { sycl::event::wait_and_throw(dependencies); - return queue_.submit([&](sycl::handler& cgh) { + auto event = queue_.submit([&](sycl::handler& cgh) { onemkl_rocrand_host_task(cgh, engine_, [=](sycl::interop_handle ih) { rocrand_status status; ROCRAND_CALL(rocrand_generate_normal_double, status, engine_, r, n, distr.mean(), distr.stddev()); }); }); + + increment_internal_offset(n); + + return event; } virtual sycl::event generate( const oneapi::mkl::rng::gaussian& distr, std::int64_t n, float* r, const std::vector& dependencies) override { - throw oneapi::mkl::unimplemented( - "rng", "mrg32ka engine", - "ICDF method not used for pseudorandom generators in cuRAND backend"); - return sycl::event{}; + sycl::event::wait_and_throw(dependencies); + auto event = queue_.submit([&](sycl::handler& cgh) { + onemkl_rocrand_host_task(cgh, engine_, [=](sycl::interop_handle ih) { + rocrand_status status; + ROCRAND_CALL(rocrand_generate_normal, status, engine_, r, n, distr.mean(), + distr.stddev()); + }); + }); + + increment_internal_offset(n); + + return event; } virtual sycl::event generate( const oneapi::mkl::rng::gaussian& distr, std::int64_t n, double* r, const std::vector& dependencies) override { - throw oneapi::mkl::unimplemented( - "rng", "mrg32ka engine", - "ICDF method not used for pseudorandom generators in cuRAND backend"); - return sycl::event{}; + sycl::event::wait_and_throw(dependencies); + auto event = queue_.submit([&](sycl::handler& cgh) { + onemkl_rocrand_host_task(cgh, engine_, [=](sycl::interop_handle ih) { + rocrand_status status; + ROCRAND_CALL(rocrand_generate_normal_double, status, engine_, r, n, distr.mean(), + distr.stddev()); + }); + }); + + increment_internal_offset(n); + + return event; } virtual sycl::event 
generate( @@ -446,13 +573,17 @@ class mrg32k3a_impl : public oneapi::mkl::rng::detail::engine_impl { distr, std::int64_t n, float* r, const std::vector& dependencies) override { sycl::event::wait_and_throw(dependencies); - return queue_.submit([&](sycl::handler& cgh) { + auto event = queue_.submit([&](sycl::handler& cgh) { onemkl_rocrand_host_task(cgh, engine_, [=](sycl::interop_handle ih) { rocrand_status status; ROCRAND_CALL(rocrand_generate_log_normal, status, engine_, r, n, distr.m(), distr.s()); }); }); + + increment_internal_offset(n); + + return event; } virtual sycl::event generate( @@ -460,31 +591,51 @@ class mrg32k3a_impl : public oneapi::mkl::rng::detail::engine_impl { distr, std::int64_t n, double* r, const std::vector& dependencies) override { sycl::event::wait_and_throw(dependencies); - return queue_.submit([&](sycl::handler& cgh) { + auto event = queue_.submit([&](sycl::handler& cgh) { onemkl_rocrand_host_task(cgh, engine_, [=](sycl::interop_handle ih) { rocrand_status status; ROCRAND_CALL(rocrand_generate_log_normal_double, status, engine_, r, n, distr.m(), distr.s()); }); }); + + increment_internal_offset(n); + + return event; } virtual sycl::event generate( const oneapi::mkl::rng::lognormal& distr, std::int64_t n, float* r, const std::vector& dependencies) override { - throw oneapi::mkl::unimplemented( - "rng", "mrg32ka engine", - "ICDF method not used for pseudorandom generators in cuRAND backend"); - return sycl::event{}; + sycl::event::wait_and_throw(dependencies); + auto event = queue_.submit([&](sycl::handler& cgh) { + onemkl_rocrand_host_task(cgh, engine_, [=](sycl::interop_handle ih) { + rocrand_status status; + ROCRAND_CALL(rocrand_generate_log_normal, status, engine_, r, n, distr.m(), + distr.s()); + }); + }); + + increment_internal_offset(n); + + return event; } virtual sycl::event generate( const oneapi::mkl::rng::lognormal& distr, std::int64_t n, double* r, const std::vector& dependencies) override { - throw oneapi::mkl::unimplemented( 
- "rng", "mrg32ka engine", - "ICDF method not used for pseudorandom generators in cuRAND backend"); - return sycl::event{}; + sycl::event::wait_and_throw(dependencies); + auto event = queue_.submit([&](sycl::handler& cgh) { + onemkl_rocrand_host_task(cgh, engine_, [=](sycl::interop_handle ih) { + rocrand_status status; + ROCRAND_CALL(rocrand_generate_log_normal_double, status, engine_, r, n, distr.m(), + distr.s()); + }); + }); + + increment_internal_offset(n); + + return event; } virtual sycl::event generate(const bernoulli& distr, @@ -492,7 +643,7 @@ class mrg32k3a_impl : public oneapi::mkl::rng::detail::engine_impl { const std::vector& dependencies) override { throw oneapi::mkl::unimplemented( "rng", "mrg32ka engine", - "ICDF method not used for pseudorandom generators in cuRAND backend"); + "Bernoulli distribution method unsupported by rocRAND backend"); return sycl::event{}; } @@ -501,37 +652,57 @@ class mrg32k3a_impl : public oneapi::mkl::rng::detail::engine_impl { const std::vector& dependencies) override { throw oneapi::mkl::unimplemented( "rng", "mrg32ka engine", - "ICDF method not used for pseudorandom generators in cuRAND backend"); + "Bernoulli distribution method unsupported by rocRAND backend"); return sycl::event{}; } virtual sycl::event generate( const poisson& distr, std::int64_t n, std::int32_t* r, const std::vector& dependencies) override { - throw oneapi::mkl::unimplemented( - "rng", "mrg32ka engine", - "ICDF method not used for pseudorandom generators in cuRAND backend"); - return sycl::event{}; + sycl::event::wait_and_throw(dependencies); + auto event = queue_.submit([&](sycl::handler& cgh) { + onemkl_rocrand_host_task(cgh, engine_, [=](sycl::interop_handle ih) { + rocrand_status status; + ROCRAND_CALL(rocrand_generate_poisson, status, engine_, (std::uint32_t*)r, n, + distr.lambda()); + }); + }); + + increment_internal_offset(n); + + return event; } virtual sycl::event generate( const poisson& distr, std::int64_t n, std::uint32_t* r, const 
std::vector& dependencies) override { - throw oneapi::mkl::unimplemented( - "rng", "mrg32ka engine", - "ICDF method not used for pseudorandom generators in cuRAND backend"); - return sycl::event{}; + sycl::event::wait_and_throw(dependencies); + + auto event = queue_.submit([&](sycl::handler& cgh) { + onemkl_rocrand_host_task(cgh, engine_, [=](sycl::interop_handle ih) { + rocrand_status status; + ROCRAND_CALL(rocrand_generate_poisson, status, engine_, r, n, distr.lambda()); + }); + }); + + increment_internal_offset(n); + + return event; } virtual sycl::event generate(const bits& distr, std::int64_t n, std::uint32_t* r, const std::vector& dependencies) override { sycl::event::wait_and_throw(dependencies); - return queue_.submit([&](sycl::handler& cgh) { + auto event = queue_.submit([&](sycl::handler& cgh) { onemkl_rocrand_host_task(cgh, engine_, [=](sycl::interop_handle ih) { rocrand_status status; ROCRAND_CALL(rocrand_generate, status, engine_, r, n); }); }); + + increment_internal_offset(n); + + return event; } virtual oneapi::mkl::rng::detail::engine_impl* copy_state() override { @@ -545,11 +716,11 @@ class mrg32k3a_impl : public oneapi::mkl::rng::detail::engine_impl { virtual void skip_ahead(std::initializer_list num_to_skip) override { throw oneapi::mkl::unimplemented("rng", "skip_ahead", - "initializer list unsupported by cuRAND backend"); + "initializer list unsupported by rocRAND backend"); } virtual void leapfrog(std::uint64_t idx, std::uint64_t stride) override { - throw oneapi::mkl::unimplemented("rng", "leapfrog", "unsupported by cuRAND backend"); + throw oneapi::mkl::unimplemented("rng", "leapfrog", "unsupported by rocRAND backend"); } virtual ~mrg32k3a_impl() override { @@ -559,8 +730,13 @@ class mrg32k3a_impl : public oneapi::mkl::rng::detail::engine_impl { private: rocrand_generator engine_; std::uint32_t seed_; + std::uint64_t offset_; + + void increment_internal_offset(std::uint64_t n) { + offset_ += n; + } }; -#else // cuRAND backend is currently 
not supported on Windows +#else // rocRAND backend is currently not supported on Windows class mrg32k3a_impl : public oneapi::mkl::rng::detail::engine_impl { public: mrg32k3a_impl(sycl::queue queue, std::uint32_t seed) diff --git a/src/rng/backends/rocrand/philox4x32x10.cpp b/src/rng/backends/rocrand/philox4x32x10.cpp index 1b3511a1c..5bc241360 100644 --- a/src/rng/backends/rocrand/philox4x32x10.cpp +++ b/src/rng/backends/rocrand/philox4x32x10.cpp @@ -1,7 +1,7 @@ /******************************************************************************* * Copyright (C) 2022 Heidelberg University, Engineering Mathematics and Computing Lab (EMCL) * and Computing Centre (URZ) - * cuRAND back-end Copyright (c) 2021, The Regents of the University of + * rocRAND back-end Copyright (c) 2021, The Regents of the University of * California, through Lawrence Berkeley National Laboratory (subject to receipt * of any required approvals from the U.S. Dept. of Energy). All rights * reserved. @@ -86,7 +86,7 @@ namespace rocrand { #if !defined(_WIN64) /* - * Note that cuRAND consists of two pieces: a host (CPU) API and a device (GPU) + * Note that rocRAND consists of two pieces: a host (CPU) API and a device (GPU) * API. The host API acts like any standard library; the `rocrand.h' header is * included and the functions can be called as usual. 
The generator is * instantiated on the host and random numbers can be generated on either the @@ -110,7 +110,9 @@ namespace rocrand { class philox4x32x10_impl : public oneapi::mkl::rng::detail::engine_impl { public: philox4x32x10_impl(sycl::queue queue, std::uint64_t seed) - : oneapi::mkl::rng::detail::engine_impl(queue) { + : oneapi::mkl::rng::detail::engine_impl(queue), + seed_(seed), + offset_(0) { rocrand_status status; ROCRAND_CALL(rocrand_create_generator, status, &engine_, ROCRAND_RNG_PSEUDO_PHILOX4_32_10); ROCRAND_CALL(rocrand_set_seed, status, engine_, (unsigned long long)seed); @@ -119,13 +121,19 @@ class philox4x32x10_impl : public oneapi::mkl::rng::detail::engine_impl { philox4x32x10_impl(sycl::queue queue, std::initializer_list seed) : oneapi::mkl::rng::detail::engine_impl(queue) { throw oneapi::mkl::unimplemented("rng", "philox4x32x10 engine", - "multi-seed unsupported by cuRAND backend"); + "multi-seed unsupported by rocRAND backend"); } philox4x32x10_impl(const philox4x32x10_impl* other) - : oneapi::mkl::rng::detail::engine_impl(*other) { - throw oneapi::mkl::unimplemented("rng", "philox4x32x10 engine", - "copy construction unsupported by cuRAND backend"); + : oneapi::mkl::rng::detail::engine_impl(*other), + seed_(other->seed_), + offset_(other->offset_) { + rocrand_status status; + ROCRAND_CALL(rocrand_create_generator, status, &engine_, ROCRAND_RNG_PSEUDO_PHILOX4_32_10); + ROCRAND_CALL(rocrand_set_seed, status, engine_, (unsigned long long)seed_); + + // Align this->engine_'s offset state with other->engine_'s offset + skip_ahead(offset_); } // Buffers API @@ -142,6 +150,9 @@ class philox4x32x10_impl : public oneapi::mkl::rng::detail::engine_impl { }); }) .wait_and_throw(); + + increment_internal_offset(n); + range_transform_fp(queue_, distr.a(), distr.b(), n, r); } @@ -157,6 +168,9 @@ class philox4x32x10_impl : public oneapi::mkl::rng::detail::engine_impl { }); }) .wait_and_throw(); + + increment_internal_offset(n); + range_transform_fp(queue_, 
distr.a(), distr.b(), n, r); } @@ -173,6 +187,9 @@ class philox4x32x10_impl : public oneapi::mkl::rng::detail::engine_impl { }); }) .wait_and_throw(); + + increment_internal_offset(n); + range_transform_int(queue_, distr.a(), distr.b(), n, ib, r); } @@ -188,6 +205,9 @@ class philox4x32x10_impl : public oneapi::mkl::rng::detail::engine_impl { }); }) .wait_and_throw(); + + increment_internal_offset(n); + range_transform_fp_accurate(queue_, distr.a(), distr.b(), n, r); } @@ -203,6 +223,9 @@ class philox4x32x10_impl : public oneapi::mkl::rng::detail::engine_impl { }); }) .wait_and_throw(); + + increment_internal_offset(n); + range_transform_fp_accurate(queue_, distr.a(), distr.b(), n, r); } @@ -219,6 +242,8 @@ class philox4x32x10_impl : public oneapi::mkl::rng::detail::engine_impl { }); }) .wait_and_throw(); + + increment_internal_offset(n); } virtual void generate(const oneapi::mkl::rng::gaussian< @@ -234,22 +259,42 @@ class philox4x32x10_impl : public oneapi::mkl::rng::detail::engine_impl { }); }) .wait_and_throw(); + + increment_internal_offset(n); } virtual void generate( const oneapi::mkl::rng::gaussian& distr, std::int64_t n, sycl::buffer& r) override { - throw oneapi::mkl::unimplemented( - "rng", "philox4x32x10 engine", - "ICDF method not used for pseudorandom generators in cuRAND backend"); + queue_ + .submit([&](sycl::handler& cgh) { + auto acc = r.get_access(cgh); + onemkl_rocrand_host_task(cgh, acc, engine_, [=](float* r_ptr) { + rocrand_status status; + ROCRAND_CALL(rocrand_generate_normal, status, engine_, r_ptr, n, distr.mean(), + distr.stddev()); + }); + }) + .wait_and_throw(); + + increment_internal_offset(n); } virtual void generate( const oneapi::mkl::rng::gaussian& distr, std::int64_t n, sycl::buffer& r) override { - throw oneapi::mkl::unimplemented( - "rng", "philox4x32x10 engine", - "ICDF method not used for pseudorandom generators in cuRAND backend"); + queue_ + .submit([&](sycl::handler& cgh) { + auto acc = r.get_access(cgh); + 
onemkl_rocrand_host_task(cgh, acc, engine_, [=](double* r_ptr) { + rocrand_status status; + ROCRAND_CALL(rocrand_generate_normal_double, status, engine_, r_ptr, n, + distr.mean(), distr.stddev()); + }); + }) + .wait_and_throw(); + + increment_internal_offset(n); } virtual void generate(const oneapi::mkl::rng::lognormal< @@ -265,6 +310,8 @@ class philox4x32x10_impl : public oneapi::mkl::rng::detail::engine_impl { }); }) .wait_and_throw(); + + increment_internal_offset(n); } virtual void generate(const oneapi::mkl::rng::lognormal< @@ -280,50 +327,88 @@ class philox4x32x10_impl : public oneapi::mkl::rng::detail::engine_impl { }); }) .wait_and_throw(); + + increment_internal_offset(n); } virtual void generate( const oneapi::mkl::rng::lognormal& distr, std::int64_t n, sycl::buffer& r) override { - throw oneapi::mkl::unimplemented( - "rng", "philox4x32x10 engine", - "ICDF method not used for pseudorandom generators in cuRAND backend"); + queue_ + .submit([&](sycl::handler& cgh) { + auto acc = r.get_access(cgh); + onemkl_rocrand_host_task(cgh, acc, engine_, [=](float* r_ptr) { + rocrand_status status; + ROCRAND_CALL(rocrand_generate_log_normal, status, engine_, r_ptr, n, distr.m(), + distr.s()); + }); + }) + .wait_and_throw(); + + increment_internal_offset(n); } virtual void generate( const oneapi::mkl::rng::lognormal& distr, std::int64_t n, sycl::buffer& r) override { - throw oneapi::mkl::unimplemented( - "rng", "philox4x32x10 engine", - "ICDF method not used for pseudorandom generators in cuRAND backend"); + queue_ + .submit([&](sycl::handler& cgh) { + auto acc = r.get_access(cgh); + onemkl_rocrand_host_task(cgh, acc, engine_, [=](double* r_ptr) { + rocrand_status status; + ROCRAND_CALL(rocrand_generate_log_normal_double, status, engine_, r_ptr, n, + distr.m(), distr.s()); + }); + }) + .wait_and_throw(); + + increment_internal_offset(n); } virtual void generate(const bernoulli& distr, std::int64_t n, sycl::buffer& r) override { throw oneapi::mkl::unimplemented( "rng", 
"philox4x32x10 engine", - "ICDF method not used for pseudorandom generators in cuRAND backend"); + "Bernoulli distribution method unsupported by rocRAND backend"); } virtual void generate(const bernoulli& distr, std::int64_t n, sycl::buffer& r) override { throw oneapi::mkl::unimplemented( "rng", "philox4x32x10 engine", - "ICDF method not used for pseudorandom generators in cuRAND backend"); + "Bernoulli distribution method unsupported by rocRAND backend"); } virtual void generate(const poisson& distr, std::int64_t n, sycl::buffer& r) override { - throw oneapi::mkl::unimplemented( - "rng", "philox4x32x10 engine", - "ICDF method not used for pseudorandom generators in cuRAND backend"); + queue_ + .submit([&](sycl::handler& cgh) { + auto acc = r.get_access(cgh); + onemkl_rocrand_host_task(cgh, acc, engine_, [=](std::int32_t* r_ptr) { + rocrand_status status; + ROCRAND_CALL(rocrand_generate_poisson, status, engine_, (std::uint32_t*)r_ptr, + n, distr.lambda()); + }); + }) + .wait_and_throw(); + + increment_internal_offset(n); } virtual void generate(const poisson& distr, std::int64_t n, sycl::buffer& r) override { - throw oneapi::mkl::unimplemented( - "rng", "philox4x32x10 engine", - "ICDF method not used for pseudorandom generators in cuRAND backend"); + queue_ + .submit([&](sycl::handler& cgh) { + auto acc = r.get_access(cgh); + onemkl_rocrand_host_task(cgh, acc, engine_, [=](std::uint32_t* r_ptr) { + rocrand_status status; + ROCRAND_CALL(rocrand_generate_poisson, status, engine_, r_ptr, n, + distr.lambda()); + }); + }) + .wait_and_throw(); + + increment_internal_offset(n); } virtual void generate(const bits& distr, std::int64_t n, @@ -337,6 +422,8 @@ class philox4x32x10_impl : public oneapi::mkl::rng::detail::engine_impl { }); }) .wait_and_throw(); + + increment_internal_offset(n); } // USM APIs @@ -353,6 +440,9 @@ class philox4x32x10_impl : public oneapi::mkl::rng::detail::engine_impl { }); }) .wait_and_throw(); + + increment_internal_offset(n); + return 
range_transform_fp(queue_, distr.a(), distr.b(), n, r); } @@ -368,6 +458,9 @@ class philox4x32x10_impl : public oneapi::mkl::rng::detail::engine_impl { }); }) .wait_and_throw(); + + increment_internal_offset(n); + return range_transform_fp(queue_, distr.a(), distr.b(), n, r); } @@ -385,6 +478,9 @@ class philox4x32x10_impl : public oneapi::mkl::rng::detail::engine_impl { }); }) .wait_and_throw(); + + increment_internal_offset(n); + return range_transform_int(queue_, distr.a(), distr.b(), n, ib, r); } @@ -400,6 +496,9 @@ class philox4x32x10_impl : public oneapi::mkl::rng::detail::engine_impl { }); }) .wait_and_throw(); + + increment_internal_offset(n); + return range_transform_fp_accurate(queue_, distr.a(), distr.b(), n, r); } @@ -415,6 +514,9 @@ class philox4x32x10_impl : public oneapi::mkl::rng::detail::engine_impl { }); }) .wait_and_throw(); + + increment_internal_offset(n); + return range_transform_fp_accurate(queue_, distr.a(), distr.b(), n, r); } @@ -423,13 +525,17 @@ class philox4x32x10_impl : public oneapi::mkl::rng::detail::engine_impl { distr, std::int64_t n, float* r, const std::vector& dependencies) override { sycl::event::wait_and_throw(dependencies); - return queue_.submit([&](sycl::handler& cgh) { + auto event = queue_.submit([&](sycl::handler& cgh) { onemkl_rocrand_host_task(cgh, engine_, [=](sycl::interop_handle ih) { rocrand_status status; ROCRAND_CALL(rocrand_generate_normal, status, engine_, r, n, distr.mean(), distr.stddev()); }); }); + + increment_internal_offset(n); + + return event; } virtual sycl::event generate( @@ -437,31 +543,51 @@ class philox4x32x10_impl : public oneapi::mkl::rng::detail::engine_impl { distr, std::int64_t n, double* r, const std::vector& dependencies) override { sycl::event::wait_and_throw(dependencies); - return queue_.submit([&](sycl::handler& cgh) { + auto event = queue_.submit([&](sycl::handler& cgh) { onemkl_rocrand_host_task(cgh, engine_, [=](sycl::interop_handle ih) { rocrand_status status; 
ROCRAND_CALL(rocrand_generate_normal_double, status, engine_, r, n, distr.mean(), distr.stddev()); }); }); + + increment_internal_offset(n); + + return event; } virtual sycl::event generate( const oneapi::mkl::rng::gaussian& distr, std::int64_t n, float* r, const std::vector& dependencies) override { - throw oneapi::mkl::unimplemented( - "rng", "philox4x32x10 engine", - "ICDF method not used for pseudorandom generators in cuRAND backend"); - return sycl::event{}; + sycl::event::wait_and_throw(dependencies); + auto event = queue_.submit([&](sycl::handler& cgh) { + onemkl_rocrand_host_task(cgh, engine_, [=](sycl::interop_handle ih) { + rocrand_status status; + ROCRAND_CALL(rocrand_generate_normal, status, engine_, r, n, distr.mean(), + distr.stddev()); + }); + }); + + increment_internal_offset(n); + + return event; } virtual sycl::event generate( const oneapi::mkl::rng::gaussian& distr, std::int64_t n, double* r, const std::vector& dependencies) override { - throw oneapi::mkl::unimplemented( - "rng", "philox4x32x10 engine", - "ICDF method not used for pseudorandom generators in cuRAND backend"); - return sycl::event{}; + sycl::event::wait_and_throw(dependencies); + auto event = queue_.submit([&](sycl::handler& cgh) { + onemkl_rocrand_host_task(cgh, engine_, [=](sycl::interop_handle ih) { + rocrand_status status; + ROCRAND_CALL(rocrand_generate_normal_double, status, engine_, r, n, distr.mean(), + distr.stddev()); + }); + }); + + increment_internal_offset(n); + + return event; } virtual sycl::event generate( @@ -469,13 +595,17 @@ class philox4x32x10_impl : public oneapi::mkl::rng::detail::engine_impl { distr, std::int64_t n, float* r, const std::vector& dependencies) override { sycl::event::wait_and_throw(dependencies); - return queue_.submit([&](sycl::handler& cgh) { + auto event = queue_.submit([&](sycl::handler& cgh) { onemkl_rocrand_host_task(cgh, engine_, [=](sycl::interop_handle ih) { rocrand_status status; ROCRAND_CALL(rocrand_generate_log_normal, status, 
engine_, r, n, distr.m(), distr.s()); }); }); + + increment_internal_offset(n); + + return event; } virtual sycl::event generate( @@ -483,31 +613,51 @@ class philox4x32x10_impl : public oneapi::mkl::rng::detail::engine_impl { distr, std::int64_t n, double* r, const std::vector& dependencies) override { sycl::event::wait_and_throw(dependencies); - return queue_.submit([&](sycl::handler& cgh) { + auto event = queue_.submit([&](sycl::handler& cgh) { onemkl_rocrand_host_task(cgh, engine_, [=](sycl::interop_handle ih) { rocrand_status status; ROCRAND_CALL(rocrand_generate_log_normal_double, status, engine_, r, n, distr.m(), distr.s()); }); }); + + increment_internal_offset(n); + + return event; } virtual sycl::event generate( const oneapi::mkl::rng::lognormal& distr, std::int64_t n, float* r, const std::vector& dependencies) override { - throw oneapi::mkl::unimplemented( - "rng", "philox4x32x10 engine", - "ICDF method not used for pseudorandom generators in cuRAND backend"); - return sycl::event{}; + sycl::event::wait_and_throw(dependencies); + auto event = queue_.submit([&](sycl::handler& cgh) { + onemkl_rocrand_host_task(cgh, engine_, [=](sycl::interop_handle ih) { + rocrand_status status; + ROCRAND_CALL(rocrand_generate_log_normal, status, engine_, r, n, distr.m(), + distr.s()); + }); + }); + + increment_internal_offset(n); + + return event; } virtual sycl::event generate( const oneapi::mkl::rng::lognormal& distr, std::int64_t n, double* r, const std::vector& dependencies) override { - throw oneapi::mkl::unimplemented( - "rng", "philox4x32x10 engine", - "ICDF method not used for pseudorandom generators in cuRAND backend"); - return sycl::event{}; + sycl::event::wait_and_throw(dependencies); + auto event = queue_.submit([&](sycl::handler& cgh) { + onemkl_rocrand_host_task(cgh, engine_, [=](sycl::interop_handle ih) { + rocrand_status status; + ROCRAND_CALL(rocrand_generate_log_normal_double, status, engine_, r, n, distr.m(), + distr.s()); + }); + }); + + 
increment_internal_offset(n); + + return event; } virtual sycl::event generate(const bernoulli& distr, @@ -515,7 +665,7 @@ class philox4x32x10_impl : public oneapi::mkl::rng::detail::engine_impl { const std::vector& dependencies) override { throw oneapi::mkl::unimplemented( "rng", "philox4x32x10 engine", - "ICDF method not used for pseudorandom generators in cuRAND backend"); + "Bernoulli distribution method unsupported by rocRAND backend"); return sycl::event{}; } @@ -524,37 +674,56 @@ class philox4x32x10_impl : public oneapi::mkl::rng::detail::engine_impl { const std::vector& dependencies) override { throw oneapi::mkl::unimplemented( "rng", "philox4x32x10 engine", - "ICDF method not used for pseudorandom generators in cuRAND backend"); + "Bernoulli distribution method unsupported by rocRAND backend"); return sycl::event{}; } virtual sycl::event generate( const poisson& distr, std::int64_t n, std::int32_t* r, const std::vector& dependencies) override { - throw oneapi::mkl::unimplemented( - "rng", "philox4x32x10 engine", - "ICDF method not used for pseudorandom generators in cuRAND backend"); - return sycl::event{}; + sycl::event::wait_and_throw(dependencies); + auto event = queue_.submit([&](sycl::handler& cgh) { + onemkl_rocrand_host_task(cgh, engine_, [=](sycl::interop_handle ih) { + rocrand_status status; + ROCRAND_CALL(rocrand_generate_poisson, status, engine_, (std::uint32_t*)r, n, + distr.lambda()); + }); + }); + + increment_internal_offset(n); + + return event; } virtual sycl::event generate( const poisson& distr, std::int64_t n, std::uint32_t* r, const std::vector& dependencies) override { - throw oneapi::mkl::unimplemented( - "rng", "philox4x32x10 engine", - "ICDF method not used for pseudorandom generators in cuRAND backend"); - return sycl::event{}; + sycl::event::wait_and_throw(dependencies); + auto event = queue_.submit([&](sycl::handler& cgh) { + onemkl_rocrand_host_task(cgh, engine_, [=](sycl::interop_handle ih) { + rocrand_status status; + 
ROCRAND_CALL(rocrand_generate_poisson, status, engine_, r, n, distr.lambda()); + }); + }); + + increment_internal_offset(n); + + return event; } virtual sycl::event generate(const bits& distr, std::int64_t n, std::uint32_t* r, const std::vector& dependencies) override { sycl::event::wait_and_throw(dependencies); - return queue_.submit([&](sycl::handler& cgh) { + auto event = queue_.submit([&](sycl::handler& cgh) { onemkl_rocrand_host_task(cgh, engine_, [=](sycl::interop_handle ih) { rocrand_status status; ROCRAND_CALL(rocrand_generate, status, engine_, r, n); }); }); + + increment_internal_offset(n); + + return event; } virtual oneapi::mkl::rng::detail::engine_impl* copy_state() override { @@ -568,11 +737,11 @@ class philox4x32x10_impl : public oneapi::mkl::rng::detail::engine_impl { virtual void skip_ahead(std::initializer_list num_to_skip) override { throw oneapi::mkl::unimplemented("rng", "skip_ahead", - "initializer list unsupported by cuRAND backend"); + "initializer list unsupported by rocRAND backend"); } virtual void leapfrog(std::uint64_t idx, std::uint64_t stride) override { - throw oneapi::mkl::unimplemented("rng", "leapfrog", "unsupported by cuRAND backend"); + throw oneapi::mkl::unimplemented("rng", "leapfrog", "unsupported by rocRAND backend"); } virtual ~philox4x32x10_impl() override { @@ -581,8 +750,14 @@ class philox4x32x10_impl : public oneapi::mkl::rng::detail::engine_impl { private: rocrand_generator engine_; + std::uint64_t seed_; + std::uint64_t offset_; + + void increment_internal_offset(std::uint64_t n) { + offset_ += n; + } }; -#else // cuRAND backend is currently not supported on Windows +#else // rocRAND backend is currently not supported on Windows class philox4x32x10_impl : public oneapi::mkl::rng::detail::engine_impl { public: philox4x32x10_impl(sycl::queue queue, std::uint64_t seed) @@ -859,8 +1034,7 @@ class philox4x32x10_impl : public oneapi::mkl::rng::detail::engine_impl { #endif oneapi::mkl::rng::detail::engine_impl* 
create_philox4x32x10(sycl::queue queue, std::uint64_t seed) { - auto a = new philox4x32x10_impl(queue, seed); - return a; + return new philox4x32x10_impl(queue, seed); } oneapi::mkl::rng::detail::engine_impl* create_philox4x32x10( diff --git a/src/rng/backends/rocrand/rocrand_helper.hpp b/src/rng/backends/rocrand/rocrand_helper.hpp index 6be9269a6..5f759a695 100644 --- a/src/rng/backends/rocrand/rocrand_helper.hpp +++ b/src/rng/backends/rocrand/rocrand_helper.hpp @@ -315,10 +315,10 @@ class rocm_error : virtual public std::runtime_error { } }; -#define HIP_ERROR_FUNC(name, err, ...) \ - err = name(__VA_ARGS__); \ - if (err != HIP_SUCCESS) { \ - throw hip_error(std::string(#name) + std::string(" : "), err); \ +#define HIP_ERROR_FUNC(name, err, ...) \ + err = name(__VA_ARGS__); \ + if (err != HIP_SUCCESS) { \ + throw rocm_error(std::string(#name) + std::string(" : "), err); \ } #define ROCRAND_CALL(func, status, ...) \ diff --git a/src/rng/backends/rocrand/rocrand_task.hpp b/src/rng/backends/rocrand/rocrand_task.hpp index 4646ca342..2588dc901 100644 --- a/src/rng/backends/rocrand/rocrand_task.hpp +++ b/src/rng/backends/rocrand/rocrand_task.hpp @@ -43,6 +43,9 @@ static inline void host_task_internal(H &cgh, A acc, E e, F f) { auto r_ptr = reinterpret_cast( ih.get_native_mem(acc)); f(r_ptr); + + hipError_t err; + HIP_ERROR_FUNC(hipStreamSynchronize, err, stream); }); } @@ -53,6 +56,9 @@ static inline void host_task_internal(H &cgh, E e, F f) { auto stream = ih.get_native_queue(); ROCRAND_CALL(rocrand_set_stream, status, e, stream); f(ih); + + hipError_t err; + HIP_ERROR_FUNC(hipStreamSynchronize, err, stream); }); } #endif