@@ -212,19 +212,23 @@ nn::Int32Tensor2D QINCoStep::encode(
212212 // repeated codebook
213213 Tensor2D zqs_r (n * K, d); // size n, K, d
214214 Tensor2D cc (n * K, d * 2 ); // size n, K, d * 2
215- size_t d_2 = this ->d ;
216215
217- auto copy_row = [d_2](Tensor2D& t, size_t i, size_t j, const float * data) {
218- assert (i <= t.shape [0 ] && j <= t.shape [1 ]);
219- memcpy (t.data () + i * t.shape [1 ] + j, data, sizeof (float ) * d_2);
220- };
216+ size_t local_d = this ->d ;
217+
218+ auto copy_row =
219+ [local_d](Tensor2D& t, size_t i, size_t j, const float * data) {
220+ assert (i <= t.shape [0 ] && j <= t.shape [1 ]);
221+ memcpy (t.data () + i * t.shape [1 ] + j,
222+ data,
223+ sizeof (float ) * local_d);
224+ };
221225
222226 // manual broadcasting
223227 for (size_t i = 0 ; i < n; i++) {
224228 for (size_t j = 0 ; j < K; j++) {
225- copy_row (zqs_r, i * K + j, 0 , codebook.data () + j * d_2 );
226- copy_row (cc, i * K + j, 0 , codebook.data () + j * d_2 );
227- copy_row (cc, i * K + j, d_2 , xhat.data () + i * d_2 );
229+ copy_row (zqs_r, i * K + j, 0 , codebook.data () + j * d );
230+ copy_row (cc, i * K + j, 0 , codebook.data () + j * d );
231+ copy_row (cc, i * K + j, d , xhat.data () + i * d );
228232 }
229233 }
230234
@@ -237,13 +241,13 @@ nn::Int32Tensor2D QINCoStep::encode(
237241
238242 // add the xhat
239243 for (size_t i = 0 ; i < n; i++) {
240- float * zqs_r_row = zqs_r.data () + i * K * d_2 ;
241- const float * xhat_row = xhat.data () + i * d_2 ;
244+ float * zqs_r_row = zqs_r.data () + i * K * d ;
245+ const float * xhat_row = xhat.data () + i * d ;
242246 for (size_t l = 0 ; l < K; l++) {
243- for (size_t j = 0 ; j < d_2 ; j++) {
247+ for (size_t j = 0 ; j < d ; j++) {
244248 zqs_r_row[j] += xhat_row[j];
245249 }
246- zqs_r_row += d_2 ;
250+ zqs_r_row += d ;
247251 }
248252 }
249253
@@ -252,31 +256,31 @@ nn::Int32Tensor2D QINCoStep::encode(
252256 float * res = nullptr ;
253257 if (residuals) {
254258 FAISS_THROW_IF_NOT (
255- residuals->shape [0 ] == n && residuals->shape [1 ] == d_2 );
259+ residuals->shape [0 ] == n && residuals->shape [1 ] == d );
256260 res = residuals->data ();
257261 }
258262
259263 for (size_t i = 0 ; i < n; i++) {
260- const float * q = x.data () + i * d_2 ;
261- const float * db = zqs_r.data () + i * K * d_2 ;
264+ const float * q = x.data () + i * d ;
265+ const float * db = zqs_r.data () + i * K * d ;
262266 float dis_min = HUGE_VALF;
263267 int64_t idx = -1 ;
264268 for (size_t j = 0 ; j < K; j++) {
265- float dis = fvec_L2sqr (q, db, d_2 );
269+ float dis = fvec_L2sqr (q, db, d );
266270 if (dis < dis_min) {
267271 dis_min = dis;
268272 idx = j;
269273 }
270- db += d_2 ;
274+ db += d ;
271275 }
272276 codes.v [i] = idx;
273277 if (res) {
274- const float * xhat_row = xhat.data () + i * d_2 ;
275- const float * xhat_next_row = zqs_r.data () + (i * K + idx) * d_2 ;
276- for (size_t j = 0 ; j < d_2 ; j++) {
278+ const float * xhat_row = xhat.data () + i * d ;
279+ const float * xhat_next_row = zqs_r.data () + (i * K + idx) * d ;
280+ for (size_t j = 0 ; j < d ; j++) {
277281 res[j] = xhat_next_row[j] - xhat_row[j];
278282 }
279- res += d_2 ;
283+ res += d ;
280284 }
281285 }
282286 return codes;
0 commit comments