Skip to content

Commit 758ee13

Browse files
committed
fixed bug in test and syntax for thrust::transform
1 parent 82c02d6 commit 758ee13

File tree

2 files changed

+4
-4
lines changed

2 files changed

+4
-4
lines changed

src/LinAlg/VectorCudaKernels.cu

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1062,7 +1062,7 @@ double log_barr_wei_obj_kernel(int n, double* d1, const double* id, const double
10621062
// compute v[i] = ( log(v[i]) if v[i]>0, otherwise 0 )
10631063
thrust::transform(thrust::device, v_temp, v_temp+n, v_temp, log_select_op);
10641064
// compute v[i] = w[i]*v[i]
1065-
thrust::transform(thrust::device, v_temp, v_temp+n, v_temp, mult_op);
1065+
thrust::transform(thrust::device, v_temp, v_temp+n, wei_v, v_temp, mult_op);
10661066
// sum up
10671067
const double sum = thrust::reduce(thrust::device, v_temp, v_temp+n, 0.0, plus_op);
10681068

tests/LinAlg/vectorTests.hpp

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1240,8 +1240,8 @@ class VectorTests : public TestBase
12401240
setLocalElement(&weight, N-1, 100);
12411241

12421242
real_type expected = (N - 1) * std::log(x_val) / N;
1243-
real_type result = x.logBarrier_local(pattern);
1244-
1243+
real_type result = x.logBarrierWeighted_local(pattern, weight);
1244+
12451245
int fail = !isEqual(result, expected);
12461246

12471247
// Make sure pattern eliminates the correct elements
@@ -1253,7 +1253,7 @@ class VectorTests : public TestBase
12531253
setLocalElement(&weight, N - 1, w_val);
12541254

12551255
expected = std::log(x_val)*w_val;
1256-
result = x.logBarrier_local(pattern);
1256+
result = x.logBarrierWeighted_local(pattern, weight);
12571257
fail += !isEqual(result, expected);
12581258

12591259
printMessage(fail, __func__, rank);

0 commit comments

Comments
 (0)