Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 10 additions & 4 deletions faiss/gpu/GpuDistance.cu
Original file line number Diff line number Diff line change
Expand Up @@ -51,12 +51,18 @@ using namespace raft::distance;
using namespace raft::neighbors;
#endif

/// Caches device major version
int device_major_version = -1;

bool should_use_raft(GpuDistanceParams args) {
cudaDeviceProp prop;
int dev = args.device >= 0 ? args.device : getCurrentDevice();
cudaGetDeviceProperties(&prop, dev);
if (device_major_version < 0) {
cudaDeviceProp prop;
int dev = args.device >= 0 ? args.device : getCurrentDevice();
cudaGetDeviceProperties(&prop, dev);
device_major_version = prop.major;
}

if (prop.major < 7)
if (device_major_version < 7)
return false;

return args.use_raft;
Expand Down
12 changes: 9 additions & 3 deletions faiss/gpu/GpuIndex.cu
Original file line number Diff line number Diff line change
Expand Up @@ -42,11 +42,17 @@ constexpr idx_t kAddVecSize = (idx_t)512 * 1024;
// FIXME: parameterize based on algorithm need
constexpr idx_t kSearchVecSize = (idx_t)32 * 1024;

/// Caches device major version
extern int device_major_version;

bool should_use_raft(GpuIndexConfig config_) {
cudaDeviceProp prop;
cudaGetDeviceProperties(&prop, config_.device);
if (device_major_version < 0) {
cudaDeviceProp prop;
cudaGetDeviceProperties(&prop, config_.device);
device_major_version = prop.major;
}

if (prop.major < 7)
if (device_major_version < 7)
return false;

return config_.use_raft;
Expand Down