Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
55 changes: 55 additions & 0 deletions pyg_lib/csrc/random/cpu/randint_engine.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
#include <torch/torch.h>
Comment thread
ZenoTan marked this conversation as resolved.
Outdated
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
#include <torch/torch.h>
#include <ATen/Aten.h>

and replace all torch:: with at::. Just found out that this heavily improves compilation time.


#include <limits.h>

namespace pyg {
namespace random {
const int RAND_PREFETCH_THRESHOLD = 128;
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: I think this should be called RAND_PREFETCH_SIZE instead of THRESHOLD.

const int RAND_PREFETCH_BITS = 64;
Comment thread
ZenoTan marked this conversation as resolved.
template <typename T>
class RandintEngine {
public:
RandintEngine() : size_(RAND_PREFETCH_THRESHOLD), bits_(64) {
prefetch_randint_ = torch::randint(
0L, std::numeric_limits<int64_t>::max(), {RAND_PREFETCH_THRESHOLD},
torch::TensorOptions().dtype(torch::kInt64));
}
Comment thread
ZenoTan marked this conversation as resolved.
Outdated
template <unsigned B>
T rand(T range) {
T num;
if (bits_ < B) {
Comment thread
ZenoTan marked this conversation as resolved.
Outdated
if (size_ > 0) {
size_--;
bits_ = 64;
} else {
prefetch_randint_ = torch::randint(
0L, std::numeric_limits<int64_t>::max(), {RAND_PREFETCH_THRESHOLD},
torch::TensorOptions().dtype(torch::kInt64));
size_ = RAND_PREFETCH_THRESHOLD;
bits_ = RAND_PREFETCH_BITS;
}
}
int64_t* prefetch_ptr = prefetch_randint_.data_ptr<int64_t>();
Comment thread
ZenoTan marked this conversation as resolved.
Outdated
int64_t res = (prefetch_ptr[size_ - 1] % range) & ((1ULL << B) - 1);
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

why do we want ((1ULL << B) - 1);? to make sure we return positive numbers? I think maybe we can make sure whatever in prefetch_ptr is unsigned.

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Treat random numbers as just unsigned bits by pointer casting

return (T)res;
}

T operator()(T beg, T end) {
T range = end - beg;
if (range <= (1 << 15)) {
return rand<16>(range) + beg;
} else if (range <= (1 << 31)) {
return rand<32>(range) + beg;
}
return rand<63>(range) + beg;
}

private:
torch::Tensor prefetch_randint_;
int size_;
int bits_;
};

} // namespace random

} // namespace pyg
79 changes: 79 additions & 0 deletions test/csrc/random/test_randint.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,79 @@
#include <gtest/gtest.h>

#include <vector>

#include "../../../pyg_lib/csrc/random/cpu/randint_engine.h"

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

could you add a simple test to ensure random maybe 1000 times from a seed without duplicates.

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Added one. It sometimes failed for 10000 times so I just have 1000 times. You have made a good guess :)

TEST(RandintPrefetchTest, BasicAssertions) {
pyg::random::RandintEngine<int64_t> eng;

// Test many times to enable prefetching

int iter = 10000;
int64_t beg = 86421357;
int64_t end = 97538642;

for (int i = 0; i < iter; i++) {
auto res = eng(beg, end);
EXPECT_LT(res, end);
EXPECT_GE(res, beg);
}
}

TEST(RandintValidTest, BasicAssertions) {
pyg::random::RandintEngine<int64_t> eng;

// Test ranges

std::vector<unsigned> test_bits{10, 20, 30, 40, 50};

for (auto b : test_bits) {
int64_t beg = 321;
int64_t end = beg + (1ULL << b);
auto res = eng(beg, end);
EXPECT_LT(res, end);
EXPECT_GE(res, beg);
}

// Test types

pyg::random::RandintEngine<unsigned short> eng_short_unsigned;
int64_t beg = 12345;
int64_t end = 54321;
auto res_short_unsigned = eng(beg, end);
EXPECT_LT(res_short_unsigned, end);
EXPECT_GE(res_short_unsigned, beg);

pyg::random::RandintEngine<unsigned> eng_int_unsigned;
beg = 12345678;
end = 87654321;
auto res_int_unsigned = eng(beg, end);
EXPECT_LT(res_int_unsigned, end);
EXPECT_GE(res_int_unsigned, beg);

pyg::random::RandintEngine<int> eng_int_signed;
beg = 12345678;
end = 87654321;
auto res_int_signed = eng(beg, end);
EXPECT_LT(res_int_signed, end);
EXPECT_GE(res_int_signed, beg);
}

TEST(RandintSeedTest, BasicAssertions) {
int64_t beg = 12345678;
int64_t end = 87654321;

torch::manual_seed(147);
pyg::random::RandintEngine<int64_t> eng1;

std::vector<int64_t> res;
for (int i = 0; i < 100; i++) {
res.push_back(eng1(beg, end));
}

torch::manual_seed(147);
pyg::random::RandintEngine<int64_t> eng2;
for (auto r : res) {
EXPECT_EQ(eng2(beg, end), r);
}
}