Skip to content

Commit 6f4fed2

Browse files
committed
fix conflict
2 parents 5e50e53 + 9d02313 commit 6f4fed2

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

54 files changed

+4539
-269
lines changed

paddle/fluid/extension/include/ext_tensor.h

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -88,10 +88,20 @@ class PD_DLL_DECL Tensor {
8888
/// It's usually used to set the input tensor data.
8989
/// \param PlaceType of target place, of which
9090
/// the tensor will copy to.
91-
9291
template <typename T>
9392
Tensor copy_to(const PlaceType& place) const;
9493

94+
/// \brief Return a sub-tensor of the given tensor.
95+
/// It is usually used to extract a sub-tensor (which supports
96+
/// modifying the data of the original tensor) to perform further
97+
/// operations.
98+
/// \param begin_idx The index of the start row (inclusive) to slice.
99+
/// The index number begins from 0.
100+
/// \param end_idx The index of the end row (exclusive) to slice.
101+
/// The index number begins from begin_idx + 1.
102+
/// \return The sliced tensor.
103+
Tensor slice(const int64_t begin_idx, const int64_t end_idx) const;
104+
95105
/// \brief Return the shape of the Tensor.
96106
std::vector<int64_t> shape() const;
97107

paddle/fluid/extension/src/ext_tensor.cc

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -124,6 +124,21 @@ void DeviceCopy(T *src, T *dst, PlaceType src_plc, PlaceType dst_plc,
124124
} \
125125
auto *tensor = static_cast<framework::LoDTensor *>(tensor_.get());
126126

127+
#define GET_INNER_PLACE \
128+
platform::Place place; \
129+
switch (place_) { \
130+
case PlaceType::kCPU: \
131+
place = platform::CPUPlace(); \
132+
break; \
133+
case PlaceType::kGPU: \
134+
place = platform::CUDAPlace(); \
135+
break; \
136+
default: \
137+
PADDLE_THROW(platform::errors::Unavailable( \
138+
"Custom operator unsupported place id(%d)", \
139+
static_cast<int>(place_))); \
140+
}
141+
127142
void Tensor::reshape(const std::vector<int64_t> &shape) {
128143
GET_CASTED_TENSOR
129144
auto new_dim = framework::make_ddim(shape);
@@ -257,6 +272,16 @@ Tensor Tensor::copy_to(const PlaceType &target_place) const {
257272
return target;
258273
}
259274

275+
Tensor Tensor::slice(const int64_t begin_idx, const int64_t end_idx) const {
276+
GET_CASTED_TENSOR
277+
GET_INNER_PLACE
278+
framework::Tensor intermediate = tensor->Slice(begin_idx, end_idx);
279+
Tensor target = Tensor(place_);
280+
framework::CustomTensorUtils::ShareDataFrom(
281+
static_cast<const void *>(&intermediate), target);
282+
return target;
283+
}
284+
260285
template PD_DLL_DECL Tensor
261286
Tensor::copy_to<float>(const PlaceType &target_place) const;
262287
template PD_DLL_DECL Tensor

paddle/fluid/framework/CMakeLists.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -202,7 +202,7 @@ cc_test(operator_exception_test SRCS operator_exception_test.cc DEPS operator op
202202
cc_library(version SRCS version.cc)
203203
cc_test(version_test SRCS version_test.cc DEPS version)
204204

205-
cc_library(proto_desc SRCS var_desc.cc op_desc.cc block_desc.cc program_desc.cc DEPS attribute shape_inference op_info operator glog version)
205+
cc_library(proto_desc SRCS var_desc.cc op_desc.cc block_desc.cc program_desc.cc process_mesh_desc.cc DEPS attribute shape_inference op_info operator glog version)
206206

207207
cc_library(op_registry SRCS op_registry.cc DEPS op_proto_maker op_info operator glog proto_desc)
208208

paddle/fluid/framework/custom_tensor_test.cc

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -92,6 +92,41 @@ void TestAPISizeAndShape() {
9292
CHECK(t1.shape() == tensor_shape);
9393
}
9494

95+
void TestAPISlice() {
96+
std::vector<int64_t> tensor_shape_origin1 = {5, 5};
97+
std::vector<int64_t> tensor_shape_sub1 = {3, 5};
98+
std::vector<int64_t> tensor_shape_origin2 = {5, 5, 5};
99+
std::vector<int64_t> tensor_shape_sub2 = {1, 5, 5};
100+
#ifdef PADDLE_WITH_CUDA
101+
auto t1 = paddle::Tensor(paddle::PlaceType::kGPU, tensor_shape_origin1);
102+
t1.mutable_data<float>();
103+
CHECK(t1.slice(0, 5).shape() == tensor_shape_origin1);
104+
CHECK(t1.slice(0, 3).shape() == tensor_shape_sub1);
105+
auto t2 = paddle::Tensor(paddle::PlaceType::kGPU, tensor_shape_origin2);
106+
t2.mutable_data<float>();
107+
CHECK(t2.slice(4, 5).shape() == tensor_shape_sub2);
108+
#endif
109+
auto t3 = paddle::Tensor(paddle::PlaceType::kCPU, tensor_shape_origin1);
110+
t3.mutable_data<float>();
111+
CHECK(t3.slice(0, 5).shape() == tensor_shape_origin1);
112+
CHECK(t3.slice(0, 3).shape() == tensor_shape_sub1);
113+
auto t4 = paddle::Tensor(paddle::PlaceType::kCPU, tensor_shape_origin2);
114+
t4.mutable_data<float>();
115+
CHECK(t4.slice(4, 5).shape() == tensor_shape_sub2);
116+
117+
// Test writing function for sliced tensor
118+
auto t = InitCPUTensorForTest<float>();
119+
auto t_sliced = t.slice(0, 1);
120+
auto* t_sliced_data_ptr = t_sliced.mutable_data<float>();
121+
for (int64_t i = 0; i < t_sliced.size(); i++) {
122+
t_sliced_data_ptr[i] += static_cast<float>(5);
123+
}
124+
auto* t_data_ptr = t.mutable_data<float>();
125+
for (int64_t i = 0; i < t_sliced.size(); i++) {
126+
CHECK_EQ(t_data_ptr[i], static_cast<float>(10));
127+
}
128+
}
129+
95130
template <typename T>
96131
paddle::DataType TestDtype() {
97132
std::vector<int64_t> tensor_shape = {5, 5};
@@ -261,6 +296,8 @@ TEST(CustomTensor, copyTest) {
261296
TestAPISizeAndShape();
262297
VLOG(2) << "TestPlace";
263298
TestAPIPlace();
299+
VLOG(2) << "TestSlice";
300+
TestAPISlice();
264301
VLOG(2) << "TestCast";
265302
GroupTestCast();
266303
VLOG(2) << "TestDtypeConvert";

paddle/fluid/framework/framework.proto

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,13 @@ enum AttrType {
3838
FLOAT64S = 12;
3939
}
4040

41+
message ProcessMeshDesc {
42+
required int32 id = 1;
43+
required int32 parent_id = 2;
44+
repeated int32 topology = 3;
45+
repeated int32 process_group = 4;
46+
};
47+
4148
// OpDesc describes an instance of a C++ framework::OperatorBase
4249
// derived class type.
4350
message OpDesc {
@@ -167,6 +174,15 @@ message VarType {
167174
}
168175

169176
message VarDesc {
177+
178+
message Attr {
179+
required string name = 1;
180+
required AttrType type = 2;
181+
optional int32 i = 3;
182+
optional string s = 4;
183+
repeated int32 ints = 5;
184+
};
185+
170186
required string name = 1;
171187
required VarType type = 2;
172188
optional bool persistable = 3 [ default = false ];
@@ -175,6 +191,7 @@ message VarDesc {
175191
optional bool need_check_feed = 4 [ default = false ];
176192
optional bool is_parameter = 5 [ default = false ];
177193
optional bool stop_gradient = 6 [ default = false ];
194+
repeated Attr attrs = 7;
178195
}
179196

180197
message BlockDesc {

paddle/fluid/framework/ir/quant_conv2d_dequant_fuse_pass.cc

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -153,6 +153,43 @@ QuantDequantFusePass::QuantDequantFusePass() {
153153
.AddAttr("data_format")
154154
.IsStringIn({"NCHW", "NHWC", "AnyLayout"})
155155
.End();
156+
AddOpCompat(OpCompat("depthwise_conv2d"))
157+
.AddInput("Input")
158+
.IsTensor()
159+
.End()
160+
.AddInput("Filter")
161+
.IsTensor()
162+
.End()
163+
.AddInput("Bias")
164+
.IsTensor()
165+
.IsOptional()
166+
.End()
167+
.AddInput("ResidualData")
168+
.IsTensor()
169+
.IsOptional()
170+
.End()
171+
.AddOutput("Output")
172+
.IsTensor()
173+
.End()
174+
.AddAttr("strides")
175+
.IsType<std::vector<int>>()
176+
.End()
177+
.AddAttr("paddings")
178+
.IsType<std::vector<int>>()
179+
.End()
180+
.AddAttr("padding_algorithm")
181+
.IsOptional()
182+
.IsStringIn({"EXPLICIT", "SAME", "VALID"})
183+
.End()
184+
.AddAttr("groups")
185+
.IsNumGE(1)
186+
.End()
187+
.AddAttr("dilations")
188+
.IsType<std::vector<int>>()
189+
.End()
190+
.AddAttr("data_format")
191+
.IsStringIn({"NCHW", "NHWC", "AnyLayout"})
192+
.End();
156193
AddOpCompat(OpCompat("mul"))
157194
.AddInput("X")
158195
.IsTensor()
Lines changed: 61 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,61 @@
1+
/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
2+
3+
Licensed under the Apache License, Version 2.0 (the "License");
4+
you may not use this file except in compliance with the License.
5+
You may obtain a copy of the License at
6+
7+
http://www.apache.org/licenses/LICENSE-2.0
8+
9+
Unless required by applicable law or agreed to in writing, software
10+
distributed under the License is distributed on an "AS IS" BASIS,
11+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
See the License for the specific language governing permissions and
13+
limitations under the License. */
14+
15+
#include "paddle/fluid/framework/process_mesh_desc.h"
16+
17+
namespace paddle {
18+
namespace framework {
19+
20+
int32_t ProcessMeshDesc::next_id = -1;
21+
22+
ProcessMeshDesc::ProcessMeshDesc(const std::vector<int32_t> &topo,
23+
const std::vector<int32_t> &process_group,
24+
int32_t parent_id) {
25+
int32_t cur_id = ++next_id;
26+
desc_.set_id(cur_id);
27+
desc_.set_parent_id(parent_id);
28+
for (size_t i = 0; i != topo.size(); ++i) {
29+
desc_.add_topology(topo[i]);
30+
}
31+
for (size_t i = 0; i != process_group.size(); ++i) {
32+
desc_.add_process_group(process_group[i]);
33+
}
34+
ProcessMeshDescMap::GetInstance().Insert(cur_id, this);
35+
}
36+
37+
std::vector<int32_t> ProcessMeshDesc::Topology() const {
38+
size_t size = desc_.topology_size();
39+
std::vector<int32_t> ret(size);
40+
for (auto i = 0; i != desc_.topology_size(); ++i) {
41+
ret[i] = desc_.topology(i);
42+
}
43+
return ret;
44+
}
45+
46+
std::vector<int32_t> ProcessMeshDesc::ProcessGroup() const {
47+
size_t size = desc_.process_group_size();
48+
std::vector<int32_t> ret(size);
49+
for (auto i = 0; i != desc_.process_group_size(); ++i) {
50+
ret[i] = desc_.process_group(i);
51+
}
52+
return ret;
53+
}
54+
55+
ProcessMeshDescMap &ProcessMeshDescMap::GetInstance() {
56+
static ProcessMeshDescMap g_process_mesh_desc_map;
57+
return g_process_mesh_desc_map;
58+
}
59+
60+
} // namespace framework
61+
} // namespace paddle
Lines changed: 65 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,65 @@
1+
/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
2+
3+
Licensed under the Apache License, Version 2.0 (the "License");
4+
you may not use this file except in compliance with the License.
5+
You may obtain a copy of the License at
6+
7+
http://www.apache.org/licenses/LICENSE-2.0
8+
9+
Unless required by applicable law or agreed to in writing, software
10+
distributed under the License is distributed on an "AS IS" BASIS,
11+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
See the License for the specific language governing permissions and
13+
limitations under the License. */
14+
15+
#pragma once
16+
17+
#include <unordered_map>
18+
#include <vector>
19+
20+
#include "paddle/fluid/framework/framework.pb.h"
21+
#include "paddle/fluid/framework/proto_desc.h"
22+
#include "paddle/fluid/platform/enforce.h"
23+
#include "paddle/fluid/platform/macros.h"
24+
25+
namespace paddle {
26+
namespace framework {
27+
28+
class ProcessMeshDesc {
29+
public:
30+
ProcessMeshDesc(const std::vector<int32_t>& topo,
31+
const std::vector<int32_t>& process_group, int32_t parent_id);
32+
33+
int32_t ID() const { return desc_.id(); }
34+
int32_t Parent() const { return desc_.parent_id(); }
35+
36+
std::vector<int32_t> Topology() const;
37+
std::vector<int32_t> ProcessGroup() const;
38+
39+
static int32_t next_id;
40+
41+
private:
42+
proto::ProcessMeshDesc desc_; // not_own
43+
};
44+
45+
class ProcessMeshDescMap {
46+
public:
47+
static ProcessMeshDescMap& GetInstance();
48+
49+
bool Has(int32_t index) const { return map_.find(index) != map_.end(); }
50+
51+
void Insert(int32_t index, ProcessMeshDesc* mesh) {
52+
PADDLE_ENFORCE_NE(
53+
Has(index), true,
54+
platform::errors::AlreadyExists("Index (%d) has been used.", index));
55+
map_.insert(std::make_pair(index, mesh));
56+
}
57+
58+
private:
59+
ProcessMeshDescMap() = default;
60+
// Use raw pointer to avoid double free
61+
std::unordered_map<int32_t, ProcessMeshDesc*> map_;
62+
DISABLE_COPY_AND_ASSIGN(ProcessMeshDescMap);
63+
};
64+
} // namespace framework
65+
} // namespace paddle

paddle/fluid/framework/proto_desc.h

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,5 +22,13 @@ constexpr int kRootBlockIndex = 0;
2222
// The Parent Index of root Block, this block does not exist.
2323
constexpr int kNoneBlockIndex = -1;
2424

25+
// The Parent Index of root ProcessMesh, this ProcessMesh does not exist.
26+
constexpr int kNoneProcessMeshIndex = -1;
27+
28+
// If a attribute name has a certain suffix, it means that the
29+
// atrribute is a distributed-related attribute for auto parallel.
30+
// e.g., "mesh_id@PARALLEL".
31+
constexpr char kAutoParallelSuffix[] = "@PARALLEL";
32+
2533
} // namespace framework
2634
} // namespace paddle

paddle/fluid/framework/var_desc.cc

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -280,6 +280,46 @@ std::vector<proto::VarType::TensorDesc *> VarDesc::mutable_tensor_descs() {
280280
}
281281
}
282282

283+
std::vector<std::string> VarDesc::AttrNames() const {
284+
std::vector<std::string> retv;
285+
retv.reserve(attrs_.size());
286+
for (auto &attr : attrs_) {
287+
retv.push_back(attr.first);
288+
}
289+
return retv;
290+
}
291+
292+
void VarDesc::RemoveAttr(const std::string &name) { attrs_.erase(name); }
293+
294+
void VarDesc::SetAttr(const std::string &name, const Attribute &v) {
295+
// NOTICE(sandyhouse): pybind11 will take the empty list in python as
296+
// the std::vector<int> type in C++; so we have to change the attr's type
297+
// here if we meet this issue
298+
proto::AttrType attr_type = static_cast<proto::AttrType>(v.which() - 1);
299+
if (attr_type == proto::AttrType::INTS &&
300+
BOOST_GET_CONST(std::vector<int>, v).size() == 0u) {
301+
// Find current attr via attr name and set the correct attribute value
302+
this->attrs_[name] = std::vector<int>();
303+
return;
304+
}
305+
bool valid = attr_type == proto::AttrType::INT ||
306+
attr_type == proto::AttrType::STRING ||
307+
attr_type == proto::AttrType::INTS;
308+
PADDLE_ENFORCE_EQ(valid, true, platform::errors::InvalidArgument(
309+
"The value for attr (%s) must be "
310+
"one of list or int or string.",
311+
name));
312+
313+
this->attrs_[name] = v;
314+
}
315+
316+
Attribute VarDesc::GetAttr(const std::string &name) const {
317+
auto it = attrs_.find(name);
318+
PADDLE_ENFORCE_NE(it, attrs_.end(), platform::errors::NotFound(
319+
"Attribute %s is not found.", name));
320+
return it->second;
321+
}
322+
283323
bool operator==(const VarDesc &left, const VarDesc &right) {
284324
return left.Proto()->SerializeAsString() ==
285325
right.Proto()->SerializeAsString();

0 commit comments

Comments
 (0)