Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 18 additions & 9 deletions python/paddle/v2/framework/tests/test_modified_huber_loss_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,22 +5,31 @@

def modified_huber_loss_forward(val):
if val < -1:
return -4 * val
return -4. * val
elif val < 1:
return (1 - val) * (1 - val)
return (1. - val) * (1. - val)
else:
return 0
return 0.
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

wow



class TestModifiedHuberLossOp(OpTest):
def setUp(self):
self.op_type = 'modified_huber_loss'
samples_num = 32
self.inputs = {
'X': np.random.uniform(-1, 1., (samples_num, 1)).astype('float32'),
'Y': np.random.choice([0, 1], samples_num).reshape((samples_num, 1))
}
product_res = self.inputs['X'] * (2 * self.inputs['Y'] - 1)

x_np = np.random.uniform(-2., 2., (samples_num, 1)).astype('float32')
y_np = np.random.choice([0, 1], samples_num).reshape(
(samples_num, 1)).astype('float32')
product_res = x_np * (2. * y_np - 1.)
# keep away from the junction of piecewise function
for pos, val in np.ndenumerate(product_res):
while abs(val - 1.) < 0.05:
x_np[pos] = np.random.uniform(-2., 2.)
y_np[pos] = np.random.choice([0, 1])
product_res[pos] = x_np[pos] * (2 * y_np[pos] - 1)
val = product_res[pos]

self.inputs = {'X': x_np, 'Y': y_np}
loss = np.vectorize(modified_huber_loss_forward)(product_res)

self.outputs = {
Expand All @@ -32,7 +41,7 @@ def test_check_output(self):
self.check_output()

def test_check_grad(self):
self.check_grad(['X'], 'Out', max_relative_error=0.005)
self.check_grad(['X'], 'Out', max_relative_error=0.01)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Just a random question: could we have a systematic way to set the error tolerance?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There are a lot of indifferentiable points for forwarding operators. Only writers of operators could know which point is indifferentiable.

So I think that should be managed manually by Op writers.



if __name__ == '__main__':
Expand Down