-
Notifications
You must be signed in to change notification settings - Fork 60
randint utils for samplers
#26
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 1 commit
e404549
f8455a6
8583458
a0f700c
213053b
9e2f09e
7b91c37
7f600c9
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change | ||||
|---|---|---|---|---|---|---|
| @@ -0,0 +1,55 @@ | ||||||
| #include <torch/torch.h> | ||||||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
and replace all |
||||||
|
|
||||||
| #include <limits.h> | ||||||
|
|
||||||
| namespace pyg { | ||||||
| namespace random { | ||||||
| const int RAND_PREFETCH_THRESHOLD = 128; | ||||||
|
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. nit: I think this should be called RAND_PREFETCH_SIZE instead of THRESHOLD. |
||||||
| const int RAND_PREFETCH_BITS = 64; | ||||||
|
ZenoTan marked this conversation as resolved.
|
||||||
| template <typename T> | ||||||
| class RandintEngine { | ||||||
| public: | ||||||
| RandintEngine() : size_(RAND_PREFETCH_THRESHOLD), bits_(64) { | ||||||
| prefetch_randint_ = torch::randint( | ||||||
| 0L, std::numeric_limits<int64_t>::max(), {RAND_PREFETCH_THRESHOLD}, | ||||||
| torch::TensorOptions().dtype(torch::kInt64)); | ||||||
| } | ||||||
|
ZenoTan marked this conversation as resolved.
Outdated
|
||||||
| template <unsigned B> | ||||||
| T rand(T range) { | ||||||
| T num; | ||||||
| if (bits_ < B) { | ||||||
|
ZenoTan marked this conversation as resolved.
Outdated
|
||||||
| if (size_ > 0) { | ||||||
| size_--; | ||||||
| bits_ = 64; | ||||||
| } else { | ||||||
| prefetch_randint_ = torch::randint( | ||||||
| 0L, std::numeric_limits<int64_t>::max(), {RAND_PREFETCH_THRESHOLD}, | ||||||
| torch::TensorOptions().dtype(torch::kInt64)); | ||||||
| size_ = RAND_PREFETCH_THRESHOLD; | ||||||
| bits_ = RAND_PREFETCH_BITS; | ||||||
| } | ||||||
| } | ||||||
| int64_t* prefetch_ptr = prefetch_randint_.data_ptr<int64_t>(); | ||||||
|
ZenoTan marked this conversation as resolved.
Outdated
|
||||||
| int64_t res = (prefetch_ptr[size_ - 1] % range) & ((1ULL << B) - 1); | ||||||
|
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. why do we want
Member
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Treat random numbers as just unsigned bits by pointer casting |
||||||
| return (T)res; | ||||||
| } | ||||||
|
|
||||||
| T operator()(T beg, T end) { | ||||||
| T range = end - beg; | ||||||
| if (range <= (1 << 15)) { | ||||||
| return rand<16>(range) + beg; | ||||||
| } else if (range <= (1 << 31)) { | ||||||
| return rand<32>(range) + beg; | ||||||
| } | ||||||
| return rand<63>(range) + beg; | ||||||
| } | ||||||
|
|
||||||
| private: | ||||||
| torch::Tensor prefetch_randint_; | ||||||
| int size_; | ||||||
| int bits_; | ||||||
| }; | ||||||
|
|
||||||
| } // namespace random | ||||||
|
|
||||||
| } // namespace pyg | ||||||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,79 @@ | ||
| #include <gtest/gtest.h> | ||
|
|
||
| #include <vector> | ||
|
|
||
| #include "../../../pyg_lib/csrc/random/cpu/randint_engine.h" | ||
|
|
||
|
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. could you add a simple test to ensure random maybe 1000 times from a seed without duplicates.
Member
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Added one. It sometimes failed for 10000 times so I just have 1000 times. You have made a good guess :) |
||
| TEST(RandintPrefetchTest, BasicAssertions) { | ||
| pyg::random::RandintEngine<int64_t> eng; | ||
|
|
||
| // Test many times to enable prefetching | ||
|
|
||
| int iter = 10000; | ||
| int64_t beg = 86421357; | ||
| int64_t end = 97538642; | ||
|
|
||
| for (int i = 0; i < iter; i++) { | ||
| auto res = eng(beg, end); | ||
| EXPECT_LT(res, end); | ||
| EXPECT_GE(res, beg); | ||
| } | ||
| } | ||
|
|
||
| TEST(RandintValidTest, BasicAssertions) { | ||
| pyg::random::RandintEngine<int64_t> eng; | ||
|
|
||
| // Test ranges | ||
|
|
||
| std::vector<unsigned> test_bits{10, 20, 30, 40, 50}; | ||
|
|
||
| for (auto b : test_bits) { | ||
| int64_t beg = 321; | ||
| int64_t end = beg + (1ULL << b); | ||
| auto res = eng(beg, end); | ||
| EXPECT_LT(res, end); | ||
| EXPECT_GE(res, beg); | ||
| } | ||
|
|
||
| // Test types | ||
|
|
||
| pyg::random::RandintEngine<unsigned short> eng_short_unsigned; | ||
| int64_t beg = 12345; | ||
| int64_t end = 54321; | ||
| auto res_short_unsigned = eng(beg, end); | ||
| EXPECT_LT(res_short_unsigned, end); | ||
| EXPECT_GE(res_short_unsigned, beg); | ||
|
|
||
| pyg::random::RandintEngine<unsigned> eng_int_unsigned; | ||
| beg = 12345678; | ||
| end = 87654321; | ||
| auto res_int_unsigned = eng(beg, end); | ||
| EXPECT_LT(res_int_unsigned, end); | ||
| EXPECT_GE(res_int_unsigned, beg); | ||
|
|
||
| pyg::random::RandintEngine<int> eng_int_signed; | ||
| beg = 12345678; | ||
| end = 87654321; | ||
| auto res_int_signed = eng(beg, end); | ||
| EXPECT_LT(res_int_signed, end); | ||
| EXPECT_GE(res_int_signed, beg); | ||
| } | ||
|
|
||
| TEST(RandintSeedTest, BasicAssertions) { | ||
| int64_t beg = 12345678; | ||
| int64_t end = 87654321; | ||
|
|
||
| torch::manual_seed(147); | ||
| pyg::random::RandintEngine<int64_t> eng1; | ||
|
|
||
| std::vector<int64_t> res; | ||
| for (int i = 0; i < 100; i++) { | ||
| res.push_back(eng1(beg, end)); | ||
| } | ||
|
|
||
| torch::manual_seed(147); | ||
| pyg::random::RandintEngine<int64_t> eng2; | ||
| for (auto r : res) { | ||
| EXPECT_EQ(eng2(beg, end), r); | ||
| } | ||
| } | ||
Uh oh!
There was an error while loading. Please reload this page.