
Commit af9dcb2

Supplement several interfaces of static Variable to be consistent with dygraph Tensor (#33330)
As the title
1 parent ba7e2a9 commit af9dcb2

File tree

10 files changed: +387 -28 lines changed
paddle/fluid/operators/share_data_op.cc

Lines changed: 72 additions & 0 deletions
@@ -0,0 +1,72 @@
/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/fluid/operators/share_data_op.h"
#include "paddle/fluid/framework/op_registry.h"

namespace paddle {
namespace operators {

class ShareDataOp : public framework::OperatorWithKernel {
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;

  void InferShape(framework::InferShapeContext *ctx) const override {
    OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "ShareData");
    OP_INOUT_CHECK(ctx->HasOutput("Out"), "Output", "Out", "ShareData");
    auto in_type = ctx->GetInputsVarType("X")[0];
    auto out_type = ctx->GetOutputsVarType("Out")[0];

    PADDLE_ENFORCE_EQ(
        in_type == framework::proto::VarType::LOD_TENSOR ||
            in_type == framework::proto::VarType::SELECTED_ROWS,
        true, platform::errors::InvalidArgument(
                  "Type of Variable[X] must be LoDTensor or SelectedRows!"));
    PADDLE_ENFORCE_EQ(
        in_type, out_type,
        platform::errors::InvalidArgument(
            "The type of input (X) and output (Out) are inconsistent."));

    ctx->ShareDim("X", "Out");
  }
};

class ShareDataOpMaker : public framework::OpProtoAndCheckerMaker {
 public:
  void Make() override {
    AddInput("X", "(Tensor), The input tensor of share_data op");
    AddOutput("Out", "(Tensor), The output tensor of share_data op");
    AddComment(R"DOC(
ShareData Operator.

Return a tensor $Out$ that shares data with the input tensor $X$ and without tensor copy.
)DOC");
  }
};

}  // namespace operators
}  // namespace paddle

namespace ops = paddle::operators;
REGISTER_OPERATOR(
    share_data, ops::ShareDataOp, ops::ShareDataOpMaker,
    paddle::framework::EmptyGradOpMaker<paddle::framework::OpDesc>,
    paddle::framework::EmptyGradOpMaker<paddle::imperative::OpBase>);
REGISTER_OP_CPU_KERNEL(share_data, ops::ShareDataKernel<bool>,
                       ops::ShareDataKernel<int>, ops::ShareDataKernel<int8_t>,
                       ops::ShareDataKernel<uint8_t>,
                       ops::ShareDataKernel<paddle::platform::float16>,
                       ops::ShareDataKernel<int64_t>,
                       ops::ShareDataKernel<float>,
                       ops::ShareDataKernel<double>)
paddle/fluid/operators/share_data_op.cu

Lines changed: 25 additions & 0 deletions
@@ -0,0 +1,25 @@
/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/fluid/operators/share_data_op.h"

REGISTER_OP_CUDA_KERNEL(
    share_data, paddle::operators::ShareDataKernel<bool>,
    paddle::operators::ShareDataKernel<int>,
    paddle::operators::ShareDataKernel<int8_t>,
    paddle::operators::ShareDataKernel<uint8_t>,
    paddle::operators::ShareDataKernel<paddle::platform::float16>,
    paddle::operators::ShareDataKernel<int64_t>,
    paddle::operators::ShareDataKernel<float>,
    paddle::operators::ShareDataKernel<double>);
paddle/fluid/operators/share_data_op.h

Lines changed: 41 additions & 0 deletions
@@ -0,0 +1,41 @@
/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once
#include "paddle/fluid/framework/op_registry.h"

namespace paddle {
namespace operators {

template <typename T>
class ShareDataKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext &ctx) const override {
    auto *in_var = ctx.InputVar("X");
    auto *out_var = ctx.OutputVar("Out");
    if (in_var->IsType<framework::LoDTensor>()) {
      const auto &origin_tensor = in_var->Get<framework::LoDTensor>();
      auto *detach_tensor = out_var->GetMutable<framework::LoDTensor>();
      detach_tensor->ShareDataWith(origin_tensor);
    } else {
      const auto &origin_selected_rows = in_var->Get<framework::SelectedRows>();
      auto *detach_selected_rows =
          out_var->GetMutable<framework::SelectedRows>();
      detach_selected_rows->mutable_value()->ShareDataWith(
          origin_selected_rows.value());
    }
  }
};
}  // namespace operators
}  // namespace paddle
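
Together, these three files add a share_data operator whose CPU and GPU kernels call ShareDataWith, so the output Variable aliases the input's storage instead of copying it. The following minimal sketch (not part of the commit) shows the user-visible behaviour this enables through the static Variable.detach() added in framework.py below; the program and Executor setup is illustrative and assumes a Paddle build that includes this change.

import numpy as np
import paddle

paddle.enable_static()

# Build a small static program in which y detaches x.
x = paddle.static.data(name='x', shape=[3, 2], dtype='float32')
y = x.detach()  # appends a share_data op; y aliases x's storage, no copy

exe = paddle.static.Executor(paddle.CPUPlace())
x_np = np.random.rand(3, 2).astype('float32')

# Both fetched values are identical because Out shares data with X.
x_out, y_out = exe.run(paddle.static.default_main_program(),
                       feed={'x': x_np},
                       fetch_list=[x, y])
assert np.array_equal(x_out, y_out)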

python/paddle/fluid/framework.py

Lines changed: 53 additions & 16 deletions
@@ -947,35 +947,43 @@ def __init__(self,
         self._stop_gradient = stop_gradient
         self.is_data = is_data
 
-    @fake_interface_only
     def detach(self):
         """
-        **Notes**:
-            **This API is ONLY available in Dygraph mode**
-
         Returns a new Variable, detached from the current graph.
+        It will share data with origin Variable and without tensor copy.
+        In addition, the detached Variable doesn't provide gradient propagation.
 
         Returns:
             ( :ref:`api_guide_Variable_en` | dtype is same as current Variable): The detached Variable.
 
-
         Examples:
             .. code-block:: python
 
-                import paddle.fluid as fluid
-                from paddle.fluid.dygraph.base import to_variable
-                from paddle.fluid.dygraph import Linear
-                import numpy as np
+                import paddle
 
-                data = np.random.uniform(-1, 1, [30, 10, 32]).astype('float32')
-                with fluid.dygraph.guard():
-                    linear = Linear(32, 64)
-                    data = to_variable(data)
-                    x = linear(data)
-                    y = x.detach()
+                paddle.enable_static()
+
+                # create a static Variable
+                x = paddle.static.data(name='x', shape=[3, 2, 1])
 
+                # create a detached Variable
+                y = x.detach()
         """
-        pass
+
+        assert self.type == core.VarDesc.VarType.SELECTED_ROWS or \
+            self.type == core.VarDesc.VarType.LOD_TENSOR, \
+            "only support a variable with SELECTED_ROWS or LOD_TENSOR to be detached"
+
+        output = self.block.create_var(
+            name=unique_name.generate_with_ignorable_key("detach_" + self.name),
+            dtype=self.dtype,
+            type=self.type,
+            persistable=self.persistable,
+            stop_gradient=True)
+
+        self.block.append_op(
+            type='share_data', inputs={'X': [self]}, outputs={'Out': [output]})
+        return output
 
     @fake_interface_only
     def numpy(self):
@@ -1810,6 +1818,35 @@ def set_value(self, value, scope=None):
 
         t.set(value, place)
 
+    def size(self):
+        """
+        Returns the number of elements for current Variable, which is a int64 Variable with shape [1]
+
+        Returns:
+            Variable: the number of elements for current Variable
+
+        Examples:
+            .. code-block:: python
+
+                import paddle
+
+                paddle.enable_static()
+
+                # create a static Variable
+                x = paddle.static.data(name='x', shape=[3, 2, 1])
+
+                # get the number of elements of the Variable
+                y = x.size()
+        """
+
+        output = self.block.create_var(
+            name=unique_name.generate_with_ignorable_key(self.name + "_size"),
+            dtype=core.VarDesc.VarType.INT64)
+
+        self.block.append_op(
+            type='size', inputs={'Input': [self]}, outputs={'Out': [output]})
+        return output
+
 
 def get_all_op_protos():
     """

python/paddle/fluid/layers/math_op_patch.py

Lines changed: 32 additions & 4 deletions
@@ -45,6 +45,7 @@
     "__rpow__": "A **= B",
     "__floordiv__": "A //B",
     "__mod__": "A % B",
+    "__matmul__": "A @ B",
     "__eq__": "A == B",
     "__ne__": "A != B",
     "__lt__": "A < B",
@@ -195,6 +196,28 @@ def _scalar_op_(var, scale, bias):
     def _neg_(var):
         return _scalar_op_(var, -1.0, 0.0)
 
+    @property
+    def _ndim_(self):
+        """
+        Returns the dimension of current Variable
+
+        Returns:
+            the dimension
+
+        Examples:
+            .. code-block:: python
+
+                import paddle
+
+                paddle.enable_static()
+
+                # create a static Variable
+                x = paddle.static.data(name='x', shape=[3, 2, 1])
+                # print the dimension of the Variable
+                print(x.ndim)
+        """
+        return len(self.shape)
+
     def _scalar_add_(var, value):
         return _scalar_op_(var, 1.0, value)
 
@@ -228,17 +251,17 @@ def __impl__(self, other_var):
                 other_var = float(other_var)
             # division is a special case
             # NOTE(chenweihang): because we cast tensor to float32 instead float64,
-            # the division result can only guarantee the numerical accuracy of 6 digits
-            # after the decimal point. The result of numpy calculation is of float64 type,
-            # so the calculation result here and the calculation result of numpy are
+            # the division result can only guarantee the numerical accuracy of 6 digits
+            # after the decimal point. The result of numpy calculation is of float64 type,
+            # so the calculation result here and the calculation result of numpy are
             # different after 6 decimal point. If necessary, we can also use float64 here.
             # torch's behavior here is consistent with ours
             if op_type == 'elementwise_div' and self.dtype in _supported_int_dtype_:
                 self = astype(self, 'float32')
             # here use `scale` replace `elementwise` to get better performance
             # but only +, -, * can use this method
             # NOTE(chentianyu03): / can not use `scale` method,because the result of
-            # `scale` method (self*(1/other_var)) do not exactly equal with the result
+            # `scale` method (self*(1/other_var)) do not exactly equal with the result
             # of `elementwise_div` method.
             if scalar_method is not None:
                 return scalar_method(self, other_var)
@@ -321,6 +344,9 @@ def __impl__(self, other_var):
         # b=-a
         ('__neg__', _neg_),
         ('astype', astype),
+        ('dim', lambda x: len(x.shape)),
+        ('ndimension', lambda x: len(x.shape)),
+        ('ndim', _ndim_),
         ('__add__', _binary_creator_('__add__', 'elementwise_add', False,
                                      _scalar_add_)),
         # a+b == b+a. Do not need to reverse explicitly
@@ -347,6 +373,8 @@ def __impl__(self, other_var):
                                      'elementwise_floordiv', False, None)),
         ('__mod__', _binary_creator_('__mod__', 'elementwise_mod', False,
                                      None)),
+        ('__matmul__', _binary_creator_('__matmul__', "matmul_v2", False,
+                                        None)),
         # for logical compare
         ('__eq__', _binary_creator_('__eq__', 'equal', False, None)),
         ('__ne__', _binary_creator_('__ne__', 'not_equal', False, None)),
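
The new entries above ('dim', 'ndimension', 'ndim', '__matmul__') are installed on the static Variable class by the same monkey-patching loop that handles the existing operators: each (name, callable) pair is attached to the class at import time. The sketch below is a framework-free illustration of that pattern only; FakeVariable and variable_methods are made-up names for illustration, not Paddle APIs.

# Simplified illustration of the monkey-patching pattern used by math_op_patch.py.
class FakeVariable:
    def __init__(self, shape):
        self.shape = shape

def _ndim_(self):
    return len(self.shape)

variable_methods = [
    ('dim', lambda x: len(x.shape)),
    ('ndimension', lambda x: len(x.shape)),
    ('ndim', property(_ndim_)),  # exposed as a read-only attribute
]

# Attach each (name, callable) pair to the class, as the patch loop does.
for name, func in variable_methods:
    setattr(FakeVariable, name, func)

v = FakeVariable([3, 2, 1])
print(v.dim(), v.ndimension(), v.ndim)  # 3 3 3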

python/paddle/fluid/tests/unittests/test_detach.py

Lines changed: 0 additions & 6 deletions
@@ -149,12 +149,6 @@ def test_NoDetachSingle_DetachMulti(self):
         array_detach_multi = self.detach_multi()
         assert np.array_equal(array_no_detach_single, array_detach_multi)
 
-    def test_detach_exception(self):
-        x = fluid.layers.data(name="a", shape=[3, 4], dtype='float32')
-        y = fluid.layers.fc(input=x, size=10, bias_attr=True)
-        with self.assertRaises(AssertionError):
-            y_detach = y.detach()
-
 
 class TestInplace(unittest.TestCase):
     def test_forward_version(self):

python/paddle/fluid/tests/unittests/test_math_op_patch.py

Lines changed: 22 additions & 1 deletion
@@ -271,7 +271,6 @@ def test_astype(self):
                       fetch_list=[b])
         self.assertTrue(numpy.allclose(a_np.astype('float32'), b_np))
 
-    @prog_scope()
     def test_bitwise_and(self):
         x_np = np.random.randint(-100, 100, [2, 3, 5]).astype("int32")
         y_np = np.random.randint(-100, 100, [2, 3, 5]).astype("int32")
@@ -336,6 +335,28 @@ def test_bitwise_not(self):
                       fetch_list=[z])
         self.assertTrue(np.array_equal(out[0], out_np))
 
+    @prog_scope()
+    def test_ndim(self):
+        a = paddle.static.data(name="a", shape=[10, 1])
+        self.assertEqual(a.dim(), 2)
+        self.assertEqual(a.ndimension(), 2)
+        self.assertEqual(a.ndim, 2)
+
+    @prog_scope()
+    def test_matmul(self):
+        a = paddle.static.data(name='a', shape=[2, 3], dtype='float32')
+        b = paddle.static.data(name='b', shape=[3, 5], dtype='float32')
+        c = a @b  # __matmul__
+        a_np = numpy.random.uniform(-1, 1, size=[2, 3]).astype('float32')
+        b_np = numpy.random.uniform(-1, 1, size=[3, 5]).astype('float32')
+        place = paddle.CPUPlace()
+        exe = paddle.static.Executor(place)
+        c_np = exe.run(paddle.static.default_main_program(),
+                       feed={"a": a_np,
+                             "b": b_np},
+                       fetch_list=[c])
+        self.assertTrue(numpy.allclose(a_np @b_np, c_np))
+
 
 if __name__ == '__main__':
     unittest.main()
