diff --git a/faiss/gpu/GpuDistance.cu b/faiss/gpu/GpuDistance.cu
index 38a62f03bb..599f4a3072 100644
--- a/faiss/gpu/GpuDistance.cu
+++ b/faiss/gpu/GpuDistance.cu
@@ -51,12 +51,23 @@ using namespace raft::distance;
 using namespace raft::neighbors;
 #endif
 
+/// Decides whether the RAFT code path is used for this request.
+/// RAFT is declined on devices whose major compute capability is below 7;
+/// otherwise the caller's `use_raft` flag decides.
 bool should_use_raft(GpuDistanceParams args) {
-    cudaDeviceProp prop;
-    int dev = args.device >= 0 ? args.device : getCurrentDevice();
-    cudaGetDeviceProperties(&prop, dev);
+    // A negative device id selects the currently active device.
+    int dev = args.device >= 0 ? args.device : getCurrentDevice();
+
+    // Query the major compute capability of this specific device on every
+    // call rather than caching one process-wide value: a single cached value
+    // would be wrong for every other device on a host with mixed GPUs, and
+    // cudaDeviceGetAttribute is a single-attribute query.
+    int major = 0;
+    cudaError_t err = cudaDeviceGetAttribute(
+            &major, cudaDevAttrComputeCapabilityMajor, dev);
 
-    if (prop.major < 7)
+    // If the capability cannot be determined, conservatively decline RAFT.
+    if (err != cudaSuccess || major < 7)
         return false;
 
     return args.use_raft;
diff --git a/faiss/gpu/GpuIndex.cu b/faiss/gpu/GpuIndex.cu
index d1ae3b5384..f91b7dc9c5 100644
--- a/faiss/gpu/GpuIndex.cu
+++ b/faiss/gpu/GpuIndex.cu
@@ -42,11 +42,20 @@ constexpr idx_t kAddVecSize = (idx_t)512 * 1024;
 // FIXME: parameterize based on algorithm need
 constexpr idx_t kSearchVecSize = (idx_t)32 * 1024;
 
+/// Decides whether the RAFT code path is used for this index configuration.
+/// RAFT is declined on devices whose major compute capability is below 7;
+/// otherwise the configuration's `use_raft` flag decides.
 bool should_use_raft(GpuIndexConfig config_) {
-    cudaDeviceProp prop;
-    cudaGetDeviceProperties(&prop, config_.device);
+    // Query the major compute capability of the configured device on every
+    // call rather than caching one process-wide value: a single cached value
+    // would be wrong for every other device on a host with mixed GPUs, and
+    // cudaDeviceGetAttribute is a single-attribute query.
+    int major = 0;
+    cudaError_t err = cudaDeviceGetAttribute(
+            &major, cudaDevAttrComputeCapabilityMajor, config_.device);
 
-    if (prop.major < 7)
+    // If the capability cannot be determined, conservatively decline RAFT.
+    if (err != cudaSuccess || major < 7)
         return false;
 
     return config_.use_raft;