Skip to content

Commit c33ddc7

Browse files
committed
Fix some bugs, add more unit tests.
1 parent e9cc328 commit c33ddc7

3 files changed

Lines changed: 79 additions & 20 deletions

File tree

paddle/operators/squared_l2_distance_op.cc

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,9 @@ class SquaredL2DistanceOp : public framework::OperatorWithKernel {
4949
"First dimension of target must be equal to input "
5050
"or to 1.");
5151

52-
ctx.Output<Tensor>("sub_result")->Resize(x_dims);
52+
ctx.Output<Tensor>("sub_result")
53+
->Resize({static_cast<int>(x_dims[0]),
54+
static_cast<int>(framework::product(x_dims) / x_dims[0])});
5355
ctx.Output<Tensor>("Out")->Resize({x_dims[0], 1});
5456
}
5557
};
@@ -97,8 +99,8 @@ class SquaredL2DistanceGradOp : public framework::OperatorWithKernel {
9799
"must be 1.");
98100
auto* x_grad = ctx.Output<Tensor>(framework::GradVarName("X"));
99101
auto* y_grad = ctx.Output<Tensor>(framework::GradVarName("Y"));
100-
if (x_grad != nullptr) x_grad->Resize(x_dims);
101-
if (y_grad != nullptr) y_grad->Resize(y_dims);
102+
if (x_grad) x_grad->Resize(x_dims);
103+
if (y_grad) y_grad->Resize(y_dims);
102104
}
103105
};
104106

paddle/operators/squared_l2_distance_op.h

Lines changed: 11 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -53,14 +53,16 @@ class SquaredL2DistanceKernel : public framework::OpKernel {
5353
auto y_dims = y.dimensions();
5454
// buffer the subtraction result
5555
if (y_dims[0] == 1 && x_dims[0] > y_dims[0]) {
56-
auto y_broadcast_dims = y_dims;
57-
y_broadcast_dims[0] = x_dims[0];
58-
sub_result.device(place) = x - y.broadcast(y_broadcast_dims);
56+
sub_result.device(place) =
57+
x -
58+
y.broadcast(Eigen::array<int, 2>({static_cast<int>(x_dims[0]), 1}));
5959
} else {
6060
sub_result.device(place) = x - y;
6161
}
62-
63-
z.device(place) = sub_result.pow(2).sum(Eigen::array<int, 1>({1}));
62+
auto sub_res_pow2 = sub_result * sub_result;
63+
z.device(place) =
64+
sub_res_pow2.sum(Eigen::array<int, 1>({1}))
65+
.reshape(Eigen::array<int, 2>({static_cast<int>(x_dims[0]), 1}));
6466
}
6567
};
6668

@@ -86,7 +88,7 @@ class SquaredL2DistanceGradKernel : public framework::OpKernel {
8688

8789
// propagate back to input
8890
auto eigen_place = context.GetEigenDevice<Place>();
89-
if (x_g != nullptr) {
91+
if (x_g) {
9092
x_g->mutable_data<T>(context.GetPlace());
9193
// eigen matrix
9294
auto x_grad =
@@ -95,7 +97,7 @@ class SquaredL2DistanceGradKernel : public framework::OpKernel {
9597
x_grad.device(eigen_place) = grad_mat;
9698
}
9799

98-
if (y_g != nullptr) {
100+
if (y_g) {
99101
y_g->mutable_data<T>(context.GetPlace());
100102
auto y_grad =
101103
EigenMatrix<T>::From(*y_g, framework::make_ddim({y_dims[0], cols}));
@@ -107,8 +109,9 @@ class SquaredL2DistanceGradKernel : public framework::OpKernel {
107109
if (sub_result.dimensions()[0] == y_dims[0]) {
108110
y_grad.device(eigen_place) = -1 * grad_mat;
109111
} else {
112+
auto col_sum_res = -1 * (grad_mat.sum(Eigen::array<int, 1>({0})));
110113
y_grad.device(eigen_place) =
111-
-1 * (grad_mat.sum(Eigen::array<int, 2>({0})));
114+
col_sum_res.reshape(Eigen::array<int, 2>({1, cols}));
112115
}
113116
}
114117
}

python/paddle/v2/framework/tests/test_squared_l2_distance_op.py

Lines changed: 63 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -4,30 +4,84 @@
44
import numpy as np
55

66

7-
class TestSquaredL2DistanceOp(unittest.TestCase):
7+
class TestSquaredL2DistanceOp_f0(unittest.TestCase):
88
__metaclass__ = OpTestMeta
99

1010
def setUp(self):
1111
self.type = 'squared_l2_distance'
1212
self.inputs = {
13-
'X': np.random.uniform(0.1, 1., (2, 3)).astype('float32'),
14-
'Y': np.random.uniform(0.1, 1., (2, 3)).astype('float32')
13+
'X': np.random.uniform(0.1, 1., (32, 64)).astype('float32'),
14+
'Y': np.random.uniform(0.1, 1., (32, 64)).astype('float32')
1515
}
16-
subRes = self.inputs['X'] - self.inputs['Y']
17-
output = subRes * subRes
16+
sub_res = self.inputs['X'] - self.inputs['Y']
17+
output = sub_res * sub_res
1818
self.outputs = {
19-
'sub_result': subRes,
19+
'sub_result': sub_res,
20+
'Out': np.expand_dims(output.sum(1), 1)
21+
}
22+
23+
24+
class TestSquaredL2DistanceOp_f1(unittest.TestCase):
25+
__metaclass__ = OpTestMeta
26+
27+
def setUp(self):
28+
self.type = 'squared_l2_distance'
29+
self.inputs = {
30+
'X': np.random.uniform(0.1, 1., (32, 64)).astype('float32'),
31+
'Y': np.random.uniform(0.1, 1., (1, 64)).astype('float32')
32+
}
33+
sub_res = self.inputs['X'] - self.inputs['Y']
34+
output = sub_res * sub_res
35+
self.outputs = {
36+
'sub_result': sub_res,
37+
'Out': np.expand_dims(output.sum(1), 1)
38+
}
39+
40+
41+
class TestSquaredL2DistanceOp_f2(unittest.TestCase):
42+
__metaclass__ = OpTestMeta
43+
44+
def setUp(self):
45+
self.type = 'squared_l2_distance'
46+
self.inputs = {
47+
'X': np.random.uniform(0.1, 1., (32, 64, 128)).astype('float32'),
48+
'Y': np.random.uniform(0.1, 1., (1, 64, 128)).astype('float32')
49+
}
50+
sub_res = self.inputs['X'] - self.inputs['Y']
51+
sub_res = sub_res.reshape((32, 64 * 128))
52+
output = sub_res * sub_res
53+
self.outputs = {
54+
'sub_result': sub_res,
2055
'Out': np.expand_dims(output.sum(1), 1)
2156
}
2257

2358

2459
class TestSquaredL2DistanceGradOp(GradientChecker):
25-
def test_squared_l2_distance(self):
60+
def test_squared_l2_distance_b0(self):
61+
op = create_op("squared_l2_distance")
62+
inputs = {
63+
'X': np.random.uniform(0.1, .6, (2, 3)).astype('float32'),
64+
'Y': np.random.uniform(0.1, .6, (2, 3)).astype('float32')
65+
}
66+
self.compare_grad(op, inputs)
67+
self.check_grad(op, inputs, set(["X", "Y"]), "Out")
68+
69+
def test_squared_l2_distance_b1(self):
70+
op = create_op("squared_l2_distance")
71+
inputs = {
72+
'X': np.random.uniform(0.1, .6, (2, 3)).astype('float32'),
73+
'Y': np.random.uniform(0.1, .6, (1, 3)).astype('float32')
74+
}
75+
self.compare_grad(op, inputs)
76+
self.check_grad(op, inputs, set(["X", "Y"]), "Out")
77+
78+
def test_squared_l2_distance_b2(self):
2679
op = create_op("squared_l2_distance")
2780
inputs = {
28-
'X': np.random.uniform(0.1, 1., (2, 3)).astype('float32'),
29-
'Y': np.random.uniform(0.1, 1., (2, 3)).astype('float32')
81+
'X': np.random.uniform(0.1, .6, (2, 3, 4)).astype('float32'),
82+
'Y': np.random.uniform(0.1, .6, (1, 3, 4)).astype('float32')
3083
}
84+
self.compare_grad(op, inputs)
3185
self.check_grad(op, inputs, set(["X", "Y"]), "Out")
3286

3387

0 commit comments

Comments
 (0)