Skip to content
Merged
Show file tree
Hide file tree
Changes from 24 commits
Commits
Show all changes
28 commits
Select commit Hold shift + click to select a range
f6d9404
add new log2 operation
Joejiong Oct 30, 2020
f1b4f73
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
Joejiong Oct 30, 2020
cca2f9a
fix sample code
Joejiong Oct 30, 2020
3049700
test fp16
Joejiong Oct 30, 2020
665f827
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
Joejiong Oct 30, 2020
eebde3d
fix fp16_error_ratio
Joejiong Oct 30, 2020
933dd5d
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
Joejiong Oct 30, 2020
02fcf16
fix latex
Joejiong Nov 2, 2020
cd94a3a
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
Joejiong Nov 2, 2020
d1838bf
fix paddle2.0 api style
Joejiong Nov 2, 2020
b72f42c
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
Joejiong Nov 2, 2020
e5a5c26
add dygraph example code
Joejiong Nov 3, 2020
8ae9d2c
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
Joejiong Nov 3, 2020
bfc4d79
fix doc gen
Joejiong Nov 3, 2020
e08c326
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
Joejiong Nov 3, 2020
9eadc71
clean doc fluid
Joejiong Nov 4, 2020
faedba2
change directory
Joejiong Nov 5, 2020
f0151d0
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
Joejiong Nov 5, 2020
bb143ad
optimize log2
Joejiong Nov 5, 2020
6ca56f2
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
Joejiong Nov 5, 2020
c8b5fcf
clean code
Joejiong Nov 5, 2020
23e94c0
fix float16
Joejiong Nov 6, 2020
4b35414
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
Joejiong Nov 6, 2020
d508c4c
remove grad_atol
Joejiong Nov 6, 2020
9a6d1bb
fix example code
Joejiong Nov 9, 2020
c84a99f
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
Joejiong Nov 9, 2020
c7023f9
clean example
Joejiong Nov 10, 2020
816a086
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
Joejiong Nov 10, 2020
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 10 additions & 0 deletions paddle/fluid/operators/activation_op.cc
100644 → 100755
Original file line number Diff line number Diff line change
Expand Up @@ -301,6 +301,15 @@ Natural logarithm of x.

)DOC";

UNUSED constexpr char Log2Doc[] = R"DOC(
Log2 Activation Operator.

$$out = \log_2x$$

logarithm of x base to 2.

)DOC";

UNUSED constexpr char Log1pDoc[] = R"DOC(
Log Activation Operator.

Expand Down Expand Up @@ -697,6 +706,7 @@ REGISTER_ACTIVATION_OP_MAKER(Cosh, CoshDoc);
REGISTER_ACTIVATION_OP_MAKER(Round, RoundDoc);
REGISTER_ACTIVATION_OP_MAKER(Reciprocal, ReciprocalDoc);
REGISTER_ACTIVATION_OP_MAKER(Log, LogDoc);
REGISTER_ACTIVATION_OP_MAKER(Log2, Log2Doc);
REGISTER_ACTIVATION_OP_MAKER(Log1p, Log1pDoc);
REGISTER_ACTIVATION_OP_MAKER(Square, SquareDoc);
REGISTER_ACTIVATION_OP_MAKER(Softsign, SoftsignDoc);
Expand Down
22 changes: 22 additions & 0 deletions paddle/fluid/operators/activation_op.h
100644 → 100755
Original file line number Diff line number Diff line change
Expand Up @@ -820,6 +820,27 @@ struct LogGradFunctor : public BaseActivationFunctor<T> {
static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepX; }
};

// log2(x) = logarithm to the base 2 of the elements of x
template <typename T>
struct Log2Functor : public BaseActivationFunctor<T> {
template <typename Device, typename X, typename Out>
void operator()(Device d, X x, Out out) const {
out.device(d) = x.log() / static_cast<T>(log(2));
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

尽管数学上等价,但是计算机应该算以log2为底会更简单快速。比算log(x)/log(2)快。如果有空可以自己写写看这里有没有更快的实现。

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这里看看有空调查和实现一下,没空时现在这样也能勉强接受。。

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

之后实现,谢谢

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这个问题,除了性能,还有计算误差的问题,建议再调研下,看是否能优化。

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

换成tensor原生实现,thx
done;

}
};

// the gradient of log2(x) is 1/(x*ln(2))
template <typename T>
struct Log2GradFunctor : public BaseActivationFunctor<T> {
template <typename Device, typename X, typename Out, typename dOut,
typename dX>
void operator()(Device d, X x, Out out, dOut dout, dX dx) const {
dx.device(d) = dout * static_cast<T>(1) / (x * static_cast<T>(log(2)));
}

static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepX; }
};

// log1p(x) = natural logarithm of x+1
template <typename T>
struct Log1pFunctor : public BaseActivationFunctor<T> {
Expand Down Expand Up @@ -1908,6 +1929,7 @@ struct LogGradGradFunctor : public BaseActivationFunctor<T> {
__macro(round, Round, RoundFunctor, ZeroGradFunctor); \
__macro(reciprocal, Reciprocal, ReciprocalFunctor, ReciprocalGradFunctor); \
__macro(log1p, Log1p, Log1pFunctor, Log1pGradFunctor); \
__macro(log2, Log2, Log2Functor, Log2GradFunctor); \
__macro(brelu, BRelu, BReluFunctor, BReluGradFunctor); \
__macro(soft_relu, SoftRelu, SoftReluFunctor, SoftReluGradFunctor); \
__macro(stanh, STanh, STanhFunctor, STanhGradFunctor); \
Expand Down
1 change: 1 addition & 0 deletions python/paddle/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -153,6 +153,7 @@
from .tensor.math import floor #DEFINE_ALIAS
from .tensor.math import increment #DEFINE_ALIAS
from .tensor.math import log #DEFINE_ALIAS
from .tensor.math import log2 #DEFINE_ALIAS
from .tensor.math import multiplex #DEFINE_ALIAS
from .tensor.math import pow #DEFINE_ALIAS
from .tensor.math import reciprocal #DEFINE_ALIAS
Expand Down
155 changes: 120 additions & 35 deletions python/paddle/fluid/tests/unittests/test_activation_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -180,10 +180,12 @@ def test_errors(self):
# The input type must be Variable.
self.assertRaises(TypeError, F.log_sigmoid, 1)
# The input dtype must be float16, float32, float64.
x_int32 = paddle.fluid.data(name='x_int32', shape=[11, 17], dtype='int32')
x_int32 = paddle.fluid.data(
name='x_int32', shape=[11, 17], dtype='int32')
self.assertRaises(TypeError, F.log_sigmoid, x_int32)
# support the input dtype is float16
x_fp16 = paddle.fluid.data(name='x_fp16', shape=[11, 17], dtype='float16')
x_fp16 = paddle.fluid.data(
name='x_fp16', shape=[11, 17], dtype='float16')
F.log_sigmoid(x_fp16)


Expand Down Expand Up @@ -260,10 +262,12 @@ def test_errors(self):
# The input type must be Variable.
self.assertRaises(TypeError, F.tanh, 1)
# The input dtype must be float16, float32.
x_int32 = paddle.fluid.data(name='x_int32', shape=[12, 10], dtype='int32')
x_int32 = paddle.fluid.data(
name='x_int32', shape=[12, 10], dtype='int32')
self.assertRaises(TypeError, F.tanh, x_int32)
# support the input dtype is float16
x_fp16 = paddle.fluid.data(name='x_fp16', shape=[12, 10], dtype='float16')
x_fp16 = paddle.fluid.data(
name='x_fp16', shape=[12, 10], dtype='float16')
F.tanh(x_fp16)


Expand Down Expand Up @@ -519,10 +523,12 @@ def test_errors(self):
# The input type must be Variable.
self.assertRaises(TypeError, F.tanhshrink, 1)
# The input dtype must be float16, float32, float64.
x_int32 = paddle.fluid.data(name='x_int32', shape=[12, 10], dtype='int32')
x_int32 = paddle.fluid.data(
name='x_int32', shape=[12, 10], dtype='int32')
self.assertRaises(TypeError, F.tanhshrink, x_int32)
# support the input dtype is float16
x_fp16 = paddle.fluid.data(name='x_fp16', shape=[12, 10], dtype='float16')
x_fp16 = paddle.fluid.data(
name='x_fp16', shape=[12, 10], dtype='float16')
F.tanhshrink(x_fp16)


Expand Down Expand Up @@ -616,10 +622,12 @@ def test_errors(self):
# The input type must be Variable.
self.assertRaises(TypeError, F.hardshrink, 1)
# The input dtype must be float16, float32, float64.
x_int32 = paddle.fluid.data(name='x_int32', shape=[12, 10], dtype='int32')
x_int32 = paddle.fluid.data(
name='x_int32', shape=[12, 10], dtype='int32')
self.assertRaises(TypeError, F.hardshrink, x_int32)
# support the input dtype is float16
x_fp16 = paddle.fluid.data(name='x_fp16', shape=[12, 10], dtype='float16')
x_fp16 = paddle.fluid.data(
name='x_fp16', shape=[12, 10], dtype='float16')
F.hardshrink(x_fp16)


Expand Down Expand Up @@ -676,10 +684,12 @@ def test_errors(self):
# The input type must be Variable.
self.assertRaises(TypeError, F.hardtanh, 1)
# The input dtype must be float16, float32, float64.
x_int32 = paddle.fluid.data(name='x_int32', shape=[12, 10], dtype='int32')
x_int32 = paddle.fluid.data(
name='x_int32', shape=[12, 10], dtype='int32')
self.assertRaises(TypeError, F.hardtanh, x_int32)
# support the input dtype is float16
x_fp16 = paddle.fluid.data(name='x_fp16', shape=[12, 10], dtype='float16')
x_fp16 = paddle.fluid.data(
name='x_fp16', shape=[12, 10], dtype='float16')
F.hardtanh(x_fp16)


Expand Down Expand Up @@ -759,13 +769,16 @@ def test_errors(self):
# The input type must be Variable.
self.assertRaises(TypeError, F.softshrink, 1)
# The input dtype must be float16, float32, float64.
x_int32 = paddle.fluid.data(name='x_int32', shape=[12, 10], dtype='int32')
x_int32 = paddle.fluid.data(
name='x_int32', shape=[12, 10], dtype='int32')
self.assertRaises(TypeError, F.softshrink, x_int32)
# The threshold must be no less than zero
x_fp32 = paddle.fluid.data(name='x_fp32', shape=[12, 10], dtype='float32')
x_fp32 = paddle.fluid.data(
name='x_fp32', shape=[12, 10], dtype='float32')
self.assertRaises(ValueError, F.softshrink, x_fp32, -1.0)
# support the input dtype is float16
x_fp16 = paddle.fluid.data(name='x_fp16', shape=[12, 10], dtype='float16')
x_fp16 = paddle.fluid.data(
name='x_fp16', shape=[12, 10], dtype='float16')
F.softshrink(x_fp16)


Expand Down Expand Up @@ -1010,10 +1023,12 @@ def test_errors(self):
# The input type must be Variable.
self.assertRaises(TypeError, F.relu, 1)
# The input dtype must be float16, float32, float64.
x_int32 = paddle.fluid.data(name='x_int32', shape=[10, 12], dtype='int32')
x_int32 = paddle.fluid.data(
name='x_int32', shape=[10, 12], dtype='int32')
self.assertRaises(TypeError, F.relu, x_int32)
# support the input dtype is float16
x_fp16 = paddle.fluid.data(name='x_fp16', shape=[10, 12], dtype='float16')
x_fp16 = paddle.fluid.data(
name='x_fp16', shape=[10, 12], dtype='float16')
F.relu(x_fp16)


Expand Down Expand Up @@ -1119,10 +1134,12 @@ def test_errors(self):
# The input type must be Variable.
self.assertRaises(TypeError, F.leaky_relu, 1)
# The input dtype must be float16, float32, float64.
x_int32 = paddle.fluid.data(name='x_int32', shape=[12, 10], dtype='int32')
x_int32 = paddle.fluid.data(
name='x_int32', shape=[12, 10], dtype='int32')
self.assertRaises(TypeError, F.leaky_relu, x_int32)
# support the input dtype is float16
x_fp16 = paddle.fluid.data(name='x_fp16', shape=[12, 10], dtype='float16')
x_fp16 = paddle.fluid.data(
name='x_fp16', shape=[12, 10], dtype='float16')
F.leaky_relu(x_fp16)


Expand Down Expand Up @@ -1218,10 +1235,12 @@ def test_errors(self):
# The input type must be Variable.
self.assertRaises(TypeError, F.gelu, 1)
# The input dtype must be float16, float32, float64.
x_int32 = paddle.fluid.data(name='x_int32', shape=[11, 17], dtype='int32')
x_int32 = paddle.fluid.data(
name='x_int32', shape=[11, 17], dtype='int32')
self.assertRaises(TypeError, F.gelu, x_int32)
# support the input dtype is float16
x_fp16 = paddle.fluid.data(name='x_fp16', shape=[11, 17], dtype='float16')
x_fp16 = paddle.fluid.data(
name='x_fp16', shape=[11, 17], dtype='float16')
F.gelu(x_fp16)


Expand Down Expand Up @@ -1368,10 +1387,12 @@ def test_errors(self):
# The input type must be Variable.
self.assertRaises(TypeError, F.relu6, 1)
# The input dtype must be float16, float32, float64.
x_int32 = paddle.fluid.data(name='x_int32', shape=[12, 10], dtype='int32')
x_int32 = paddle.fluid.data(
name='x_int32', shape=[12, 10], dtype='int32')
self.assertRaises(TypeError, F.relu6, x_int32)
# support the input dtype is float16
x_fp16 = paddle.fluid.data(name='x_fp16', shape=[12, 10], dtype='float16')
x_fp16 = paddle.fluid.data(
name='x_fp16', shape=[12, 10], dtype='float16')
F.relu6(x_fp16)


Expand Down Expand Up @@ -1455,10 +1476,12 @@ def test_errors(self):
# The input type must be Variable.
self.assertRaises(TypeError, F.hardswish, 1)
# The input dtype must be float16, float32, float64.
x_int32 = paddle.fluid.data(name='x_int32', shape=[12, 10], dtype='int32')
x_int32 = paddle.fluid.data(
name='x_int32', shape=[12, 10], dtype='int32')
self.assertRaises(TypeError, F.hardswish, x_int32)
# support the input dtype is float16
x_fp16 = paddle.fluid.data(name='x_fp16', shape=[12, 10], dtype='float16')
x_fp16 = paddle.fluid.data(
name='x_fp16', shape=[12, 10], dtype='float16')
F.hardswish(x_fp16)


Expand Down Expand Up @@ -1572,10 +1595,12 @@ def test_errors(self):
# The input type must be Variable.
self.assertRaises(TypeError, F.elu, 1)
# The input dtype must be float16, float32, float64.
x_int32 = paddle.fluid.data(name='x_int32', shape=[10, 12], dtype='int32')
x_int32 = paddle.fluid.data(
name='x_int32', shape=[10, 12], dtype='int32')
self.assertRaises(TypeError, F.elu, x_int32)
# support the input dtype is float16
x_fp16 = paddle.fluid.data(name='x_fp16', shape=[10, 12], dtype='float16')
x_fp16 = paddle.fluid.data(
name='x_fp16', shape=[10, 12], dtype='float16')
F.elu(x_fp16)


Expand Down Expand Up @@ -1624,6 +1649,55 @@ def test_error(self):
self.assertRaises(TypeError, fluid.layers.log, in2)


class TestLog2(TestActivation):
def setUp(self):
self.op_type = "log2"
self.init_dtype()

x = np.random.uniform(0.1, 1, [11, 17]).astype(self.dtype)
out = np.log2(x)

self.inputs = {'X': OpTest.np_dtype_to_fluid_dtype(x)}
self.outputs = {'Out': out}

def test_check_grad(self):
if self.dtype == np.float16:
return
self.check_grad(['X'], 'Out')

def test_error(self):
in1 = paddle.static.data(name="in1", shape=[11, 17], dtype="int32")
in2 = paddle.static.data(name="in2", shape=[11, 17], dtype="int64")

self.assertRaises(TypeError, paddle.log2, in1)
self.assertRaises(TypeError, paddle.log2, in2)

def test_api(self):
with paddle.static.program_guard(paddle.static.Program(),
paddle.static.Program()):
input_x = np.random.uniform(0.1, 1, [11, 17]).astype("float64")
data_x = paddle.static.data(
name="data_x", shape=[11, 17], dtype="float64")

out1 = paddle.log2(data_x)
exe = paddle.static.Executor(place=fluid.CPUPlace())
exe.run(paddle.static.default_startup_program())
res1 = exe.run(paddle.static.default_main_program(),
feed={"data_x": input_x},
fetch_list=[out1])
expected_res = np.log2(input_x)
self.assertTrue(np.allclose(res1, expected_res))

# dygraph
with fluid.dygraph.guard():
np_x = np.random.uniform(0.1, 1, [11, 17]).astype("float64")
data_x = paddle.to_tensor(np_x)
z = paddle.log2(data_x)
np_z = z.numpy()
z_expected = np.array(np.log2(np_x))
self.assertTrue(np.allclose(np_z, z_expected))


class TestLog1p(TestActivation):
def setUp(self):
self.op_type = "log1p"
Expand Down Expand Up @@ -1895,10 +1969,12 @@ def test_errors(self):
# The input type must be Variable.
self.assertRaises(TypeError, F.softplus, 1)
# The input dtype must be float16, float32, float64.
x_int32 = paddle.fluid.data(name='x_int32', shape=[12, 10], dtype='int32')
x_int32 = paddle.fluid.data(
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

2.0后用paddle.data是更推荐的用法,可否看看这部分能否改成paddle.data?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这个是Log1p, 我之后换一个pr,统一看看这个activate里面有多少要改的公共的这种需要迁移的,我统一迁移

name='x_int32', shape=[12, 10], dtype='int32')
self.assertRaises(TypeError, F.softplus, x_int32)
# support the input dtype is float16
x_fp16 = paddle.fluid.data(name='x_fp16', shape=[12, 10], dtype='float16')
x_fp16 = paddle.fluid.data(
name='x_fp16', shape=[12, 10], dtype='float16')
F.softplus(x_fp16)


Expand Down Expand Up @@ -1972,10 +2048,12 @@ def test_errors(self):
# The input type must be Variable.
self.assertRaises(TypeError, F.softsign, 1)
# The input dtype must be float16, float32, float64.
x_int32 = paddle.fluid.data(name='x_int32', shape=[12, 10], dtype='int32')
x_int32 = paddle.fluid.data(
name='x_int32', shape=[12, 10], dtype='int32')
self.assertRaises(TypeError, F.softsign, x_int32)
# support the input dtype is float16
x_fp16 = paddle.fluid.data(name='x_fp16', shape=[12, 10], dtype='float16')
x_fp16 = paddle.fluid.data(
name='x_fp16', shape=[12, 10], dtype='float16')
F.softsign(x_fp16)


Expand Down Expand Up @@ -2055,10 +2133,12 @@ def test_errors(self):
# The input type must be Variable.
self.assertRaises(TypeError, F.thresholded_relu, 1)
# The input dtype must be float16, float32, float64.
x_int32 = paddle.fluid.data(name='x_int32', shape=[12, 10], dtype='int32')
x_int32 = paddle.fluid.data(
name='x_int32', shape=[12, 10], dtype='int32')
self.assertRaises(TypeError, F.thresholded_relu, x_int32)
# support the input dtype is float16
x_fp16 = paddle.fluid.data(name='x_fp16', shape=[12, 10], dtype='float16')
x_fp16 = paddle.fluid.data(
name='x_fp16', shape=[12, 10], dtype='float16')
F.thresholded_relu(x_fp16)


Expand Down Expand Up @@ -2154,10 +2234,12 @@ def test_errors(self):
# The input type must be Variable.
self.assertRaises(TypeError, F.hardsigmoid, 1)
# The input dtype must be float16, float32, float64.
x_int32 = paddle.fluid.data(name='x_int32', shape=[12, 10], dtype='int32')
x_int32 = paddle.fluid.data(
name='x_int32', shape=[12, 10], dtype='int32')
self.assertRaises(TypeError, F.hardsigmoid, x_int32)
# support the input dtype is float16
x_fp16 = paddle.fluid.data(name='x_fp16', shape=[12, 10], dtype='float16')
x_fp16 = paddle.fluid.data(
name='x_fp16', shape=[12, 10], dtype='float16')
F.hardsigmoid(x_fp16)


Expand Down Expand Up @@ -2232,10 +2314,12 @@ def test_errors(self):
# The input type must be Variable.
self.assertRaises(TypeError, F.swish, 1)
# The input dtype must be float16, float32, float64.
x_int32 = paddle.fluid.data(name='x_int32', shape=[12, 10], dtype='int32')
x_int32 = paddle.fluid.data(
name='x_int32', shape=[12, 10], dtype='int32')
self.assertRaises(TypeError, F.swish, x_int32)
# support the input dtype is float16
x_fp16 = paddle.fluid.data(name='x_fp16', shape=[12, 10], dtype='float16')
x_fp16 = paddle.fluid.data(
name='x_fp16', shape=[12, 10], dtype='float16')
F.swish(x_fp16)


Expand Down Expand Up @@ -2347,6 +2431,7 @@ def test_check_grad(self):
create_test_act_fp16_class(TestELU)
create_test_act_fp16_class(TestReciprocal)
create_test_act_fp16_class(TestLog)
create_test_act_fp16_class(TestLog2, atol=5e-2)
create_test_act_fp16_class(TestLog1p, grad_atol=0.9)
create_test_act_fp16_class(TestSquare)
create_test_act_fp16_class(TestPow, atol=5e-2)
Expand Down
1 change: 1 addition & 0 deletions python/paddle/tensor/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -151,6 +151,7 @@
from .math import atan #DEFINE_ALIAS
from .math import logsumexp #DEFINE_ALIAS
from .math import inverse #DEFINE_ALIAS
from .math import log2 #DEFINE_ALIAS
from .math import log1p #DEFINE_ALIAS
from .math import erf #DEFINE_ALIAS
# from .math import addcmul #DEFINE_ALIAS
Expand Down
Loading