Skip to content

Commit 24da9aa

Browse files
authored
Ann-bench: fix unsafe lazy blobs (#828)
The ann-bench dataset uses lazy-loading blobs to move data between storage and host and device memory. The data may be moved between memory spaces at the moment some properties/pointers are requested. In the search throughput mode, this sometimes causes a problem: two concurrent benchmark threads access the same property and concurrently modify the state of the blobs, which leads to various segfaults. The fix is to guard the critical sections with a mutex lock. There shouldn't be any impact on benchmark QPS results. Only one method, `dataset->dim()` is accessed within the benchmark loop. To avoid locking the mutex in this method, this PR changes the way `dim()` is evaluated; it's cached in `dim_` variable while still maintaining the behavior of loading it either from the query set or the base set depending on what is available/accessed first. Authors: - Artem M. Chirkin (https://github.com/achirkin) - Corey J. Nolet (https://github.com/cjnolet) Approvers: - Tarang Jain (https://github.com/tarang-jain) - Tamas Bela Feher (https://github.com/tfeher) URL: #828
1 parent 5b1825c commit 24da9aa

1 file changed

Lines changed: 52 additions & 22 deletions

File tree

cpp/bench/ann/src/common/dataset.hpp

Lines changed: 52 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -18,8 +18,10 @@
1818
#include "ann_types.hpp"
1919
#include "blob.hpp"
2020

21+
#include <atomic>
2122
#include <cstdint>
2223
#include <cstdio>
24+
#include <mutex>
2325
#include <optional>
2426
#include <random>
2527
#include <string>
@@ -56,8 +58,19 @@ struct dataset {
5658
std::optional<blob<IdxT>> ground_truth_set_;
5759
std::optional<blob<bitset_carrier_type>> filter_bitset_;
5860

59-
mutable bool base_set_accessed_ = false;
60-
mutable bool query_set_accessed_ = false;
61+
// Protects the lazy mutations of the blobs accessed by multiple threads
62+
mutable std::mutex mutex_;
63+
// The dim can be read either from the training set or from the query set.
64+
// This cache variable is filled from either of the two sets loaded first.
65+
mutable std::atomic<int> dim_ = -1;
66+
67+
// Cache the dim value from the passed blob.
68+
inline void cache_dim(const blob<DataT>& blob) const
69+
{
70+
if (dim_.load(std::memory_order_relaxed) == -1) {
71+
dim_.store(static_cast<int>(blob.n_cols()), std::memory_order_relaxed);
72+
}
73+
}
6174

6275
public:
6376
dataset(std::string name,
@@ -98,71 +111,87 @@ struct dataset {
98111
[[nodiscard]] auto distance() const -> std::string { return distance_; }
99112
[[nodiscard]] auto dim() const -> int
100113
{
101-
// If any of base/query set are already accessed, use those
102-
if (base_set_accessed_) { return static_cast<int>(base_set_.n_cols()); }
103-
if (query_set_accessed_) { return static_cast<int>(query_set_.n_cols()); }
114+
auto d = dim_.load(std::memory_order_relaxed);
115+
if (d > -1) { return d; }
116+
std::lock_guard<std::mutex> lock(mutex_);
104117
// Otherwise, try reading both (one of the two sets may be missing)
105118
try {
106-
query_set_accessed_ = true;
107-
return static_cast<int>(query_set_.n_cols());
119+
d = static_cast<int>(query_set_.n_cols());
108120
} catch (const std::runtime_error& e) {
109121
// Any exception raised above will re-raise next time we try to access the query set.
110-
query_set_accessed_ = false;
111122
query_set_.reset_lazy_state();
123+
// If the query set is not accessible, use the base set.
124+
// Don't catch the exception here, because we have nothing else to do anyway.
125+
d = static_cast<int>(base_set_.n_cols());
112126
}
113-
base_set_accessed_ = true;
114-
return static_cast<int>(base_set_.n_cols());
127+
dim_.store(d, std::memory_order_relaxed);
128+
return d;
115129
}
116130
[[nodiscard]] auto max_k() const -> uint32_t
117131
{
132+
std::lock_guard<std::mutex> lock(mutex_);
118133
if (ground_truth_set_.has_value()) { return ground_truth_set_->n_cols(); }
119134
return 0;
120135
}
121136
[[nodiscard]] auto base_set_size() const -> size_t
122137
{
123-
base_set_accessed_ = true;
124-
return base_set_.n_rows();
138+
std::lock_guard<std::mutex> lock(mutex_);
139+
auto r = base_set_.n_rows();
140+
cache_dim(base_set_);
141+
return r;
125142
}
126143
[[nodiscard]] auto query_set_size() const -> size_t
127144
{
128-
query_set_accessed_ = true;
129-
return query_set_.n_rows();
145+
std::lock_guard<std::mutex> lock(mutex_);
146+
auto r = query_set_.n_rows();
147+
cache_dim(query_set_);
148+
return r;
130149
}
131150

132151
[[nodiscard]] auto gt_set() const -> const IdxT*
133152
{
153+
std::lock_guard<std::mutex> lock(mutex_);
134154
if (ground_truth_set_.has_value()) { return ground_truth_set_->data(); }
135155
return nullptr;
136156
}
137157

138158
[[nodiscard]] auto query_set() const -> const DataT*
139159
{
140-
query_set_accessed_ = true;
141-
return query_set_.data();
160+
std::lock_guard<std::mutex> lock(mutex_);
161+
auto* r = query_set_.data();
162+
cache_dim(query_set_);
163+
return r;
142164
}
143165
[[nodiscard]] auto query_set(MemoryType memory_type,
144166
HugePages request_hugepages_2mb = HugePages::kDisable) const
145167
-> const DataT*
146168
{
147-
query_set_accessed_ = true;
148-
return query_set_.data(memory_type, request_hugepages_2mb);
169+
std::lock_guard<std::mutex> lock(mutex_);
170+
auto* r = query_set_.data(memory_type, request_hugepages_2mb);
171+
cache_dim(query_set_);
172+
return r;
149173
}
150174

151175
[[nodiscard]] auto base_set() const -> const DataT*
152176
{
153-
base_set_accessed_ = true;
154-
return base_set_.data();
177+
std::lock_guard<std::mutex> lock(mutex_);
178+
auto* r = base_set_.data();
179+
cache_dim(base_set_);
180+
return r;
155181
}
156182
[[nodiscard]] auto base_set(MemoryType memory_type,
157183
HugePages request_hugepages_2mb = HugePages::kDisable) const
158184
-> const DataT*
159185
{
160-
base_set_accessed_ = true;
161-
return base_set_.data(memory_type, request_hugepages_2mb);
186+
std::lock_guard<std::mutex> lock(mutex_);
187+
auto* r = base_set_.data(memory_type, request_hugepages_2mb);
188+
cache_dim(base_set_);
189+
return r;
162190
}
163191

164192
[[nodiscard]] auto filter_bitset() const -> const bitset_carrier_type*
165193
{
194+
std::lock_guard<std::mutex> lock(mutex_);
166195
if (filter_bitset_.has_value()) { return filter_bitset_->data(); }
167196
return nullptr;
168197
}
@@ -171,6 +200,7 @@ struct dataset {
171200
HugePages request_hugepages_2mb = HugePages::kDisable) const
172201
-> const bitset_carrier_type*
173202
{
203+
std::lock_guard<std::mutex> lock(mutex_);
174204
if (filter_bitset_.has_value()) {
175205
return filter_bitset_->data(memory_type, request_hugepages_2mb);
176206
}

0 commit comments

Comments
 (0)