
Commit 4ec1c2c

Merge branch 'op2func_refactor' into dev/op2func_refactor_
2 parents 864e602 + 1dd0145 commit 4ec1c2c

28 files changed: +512 / -618 lines

paddle/fluid/framework/operator.cc

Lines changed: 92 additions & 226 deletions
Large diffs are not rendered by default.

paddle/fluid/framework/operator.h

Lines changed: 20 additions & 16 deletions
@@ -116,8 +116,6 @@ inline std::string GradOriginalVarName(const std::string& grad_var_name) {
 const Tensor* GetLoDTensorOrSelectedRowsValueFromVar(const Variable& var);
 Tensor* GetMutableLoDTensorOrSelectedRowsValueFromVar(Variable* var);
 
-OpKernelType TransPtKernelKeyToOpKernelType(const pt::KernelKey& kernel_key);
-
 class ExecutionContext;
 class OperatorBase;
 
@@ -534,13 +532,15 @@ class OperatorWithKernel : public OperatorBase {
   }
 
   /* member functions for adapting to tcmpt lib */
-  // TODO(chenweihang): Temporarily as a class method
-  virtual pt::KernelKey ConstructPtKernelKey(
-      const VariableValueMap& inputs, const AttributeMap& attrs,
-      const platform::Place& ctx_place) const;
-
-  virtual pt::KernelContext ConstructPtKernelContext(
-      const RuntimeContext& ctx, const platform::DeviceContext& dev_ctx) const;
+  /** In the Tensor calculation library, the new Kernel adopts a clearer and
+   * more streamlined design. The arguments of the Kernel and the input and
+   * output arguments registered in the original OpMaker do not match in some
+   * cases, so we use map to record the arguments required by the kernel.
+   * When selecting Kernel during Op execution, select the arguments of the
+   * original Op according to the GetExpectedPtKernelArgs returned arguments.
+   */
+  virtual KernelSignature GetExpectedPtKernelArgs(
+      const ExecutionContext& ctx) const;
 
  private:
   void RunImpl(const Scope& scope, const platform::Place& place) const final;
@@ -563,8 +563,9 @@ class OperatorWithKernel : public OperatorBase {
                             const std::vector<std::string>& inplace_vars,
                             const Scope& exec_scope) const;
 
-  void ChooseKernel(const RuntimeContext& ctx, const Scope& scope,
-                    const platform::Place& place) const;
+  OpKernelType InnerGetExpectedKernelType(const ExecutionContext& ctx) const;
+
+  void ChooseKernel(const ExecutionContext& ctx) const;
 
   void HandleComplexGradToRealGrad(const Scope& scope,
                                    RuntimeContext* ctx) const;
@@ -582,8 +583,10 @@ class OperatorWithKernel : public OperatorBase {
                        const std::string& name) const;
 
   /* member functions for adapting to tcmpt lib */
-  void ChoosePtKernel(const RuntimeContext& ctx,
-                      const platform::DeviceContext& dev_ctx) const;
+  void ChoosePtKernel(const ExecutionContext& ctx) const;
+
+  pt::KernelContext BuildPtKernelContext(
+      const RuntimeContext& ctx, const platform::DeviceContext& dev_ctx) const;
 
  protected:
   mutable std::unique_ptr<OpKernelType> kernel_type_;
@@ -595,10 +598,11 @@ class OperatorWithKernel : public OperatorBase {
   mutable bool all_kernels_must_compute_runtime_shape_ = false;
   mutable std::mutex cache_update_mutex_;
   mutable bool enable_cache_transfer_scope_ = false;
-  // TODO(chenweihang): Similar duplicate members are used for new tcmpt lib,
-  // maybe we have better impl methods
+  // NOTE(chenweihang): Similar op members are used to adapt to
+  // new tcmpt kernel, if there is a better design in the future,
+  // we may polish the implementation here
   mutable bool run_pt_kernel_ = false;
-  mutable std::unique_ptr<pt::KernelKey> pt_kernel_key_;
+  mutable std::unique_ptr<KernelSignature> pt_kernel_signature_;
   mutable std::unique_ptr<pt::Kernel> pt_kernel_;
 };

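A note on the new hook above: GetExpectedPtKernelArgs returns a KernelSignature (the alias added in type_defs.h further down), i.e. the kernel name plus the input, attribute, and output argument names to take from the original OpMaker definition. A rough sketch of how an individual operator might override it, for example when it has dispensable inputs that the default proto-based parsing skips; the op name "scale" and its argument names are purely illustrative, not part of this commit:

  // Hypothetical override inside a concrete OperatorWithKernel subclass.
  // KernelSignature is pair(kernel_name, tuple(inputs, attrs, outputs)).
  framework::KernelSignature GetExpectedPtKernelArgs(
      const framework::ExecutionContext& ctx) const override {
    return std::make_pair(
        std::string("scale"),
        std::make_tuple(paddle::SmallVector<std::string>({"X"}),
                        paddle::SmallVector<std::string>({"scale", "bias"}),
                        paddle::SmallVector<std::string>({"Out"})));
  }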
paddle/fluid/framework/tcmpt_utils.cc

Lines changed: 115 additions & 2 deletions
@@ -12,11 +12,14 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License. */
 
+#include <sstream>
+
 #include "paddle/fluid/framework/tcmpt_utils.h"
 
 #include "paddle/fluid/framework/lod_tensor.h"
 #include "paddle/fluid/framework/selected_rows.h"
 #include "paddle/fluid/framework/variable.h"
+#include "paddle/fluid/string/string_helper.h"
 
 namespace paddle {
 namespace framework {
@@ -62,7 +65,7 @@ std::shared_ptr<pt::DenseTensor> MakeTensorImpl<pt::DenseTensor>(
     proto::VarType::Type type) {
   return MakeTensorImpl<pt::DenseTensor, LoDTensor>(
       tensor, pt::TransToPtBackend(place), pt::TransToPtDataType(type),
-      pt::TransToPtLayout(tensor.layout()));
+      pt::TransToPtDataLayout(tensor.layout()));
 }
 
 template <>
@@ -71,7 +74,7 @@ std::shared_ptr<pt::DenseTensor> MakeTensorImpl<pt::DenseTensor>(
     proto::VarType::Type type) {
   return MakeTensorImpl<pt::DenseTensor, Tensor>(
       tensor, pt::TransToPtBackend(place), pt::TransToPtDataType(type),
-      pt::TransToPtLayout(tensor.layout()));
+      pt::TransToPtDataLayout(tensor.layout()));
 }
 
 std::shared_ptr<tcmpt::TensorBase> InputVariableToPtTensor(
@@ -150,5 +153,115 @@ std::shared_ptr<tcmpt::TensorBase> OutputVariableToPtTensor(
   return nullptr;
 }
 
+OpKernelType TransPtKernelKeyToOpKernelType(const pt::KernelKey& kernel_key) {
+  proto::VarType::Type data_type = pt::TransToProtoVarType(kernel_key.dtype());
+  platform::Place place = pt::TransToFluidPlace(kernel_key.backend());
+  DataLayout data_layout = pt::TransToFluidDataLayout(kernel_key.layout());
+  LibraryType library_type = LibraryType::kPlain;
+  if (kernel_key.backend() == pt::Backend::kMKLDNN) {
+    library_type = LibraryType::kMKLDNN;
+  } else if (kernel_key.backend() == pt::Backend::kCUDNN) {
+    library_type = LibraryType::kCUDNN;
+  } else {
+    // do nothing
+  }
+  // TODO(chenweihang): the customized_type_value is lost
+  return OpKernelType(data_type, place, data_layout, library_type);
+}
+
+pt::KernelKey TransOpKernelTypeToPtKernelKey(const OpKernelType& kernel_type) {
+  pt::Backend backend = pt::TransToPtBackend(kernel_type.place_);
+  if (kernel_type.library_type_ == LibraryType::kMKLDNN) {
+    backend = pt::Backend::kMKLDNN;
+  } else if (kernel_type.library_type_ == LibraryType::kCUDNN) {
+    backend = pt::Backend::kCUDNN;
+  } else {
+    // do
+  }
+  pt::DataLayout layout = pt::TransToPtDataLayout(kernel_type.data_layout_);
+  pt::DataType dtype = pt::TransToPtDataType(kernel_type.data_type_);
+  return pt::KernelKey(backend, layout, dtype);
+}
+
+KernelSignatureMap& KernelSignatureMap::Instance() {
+  static KernelSignatureMap g_kernel_signature_map;
+  return g_kernel_signature_map;
+}
+
+const paddle::SmallVector<std::string>&
+KernelArgsNameMakerByOpProto::GetInputArgsNames() {
+  for (int i = 0; i < op_proto_->inputs_size(); ++i) {
+    auto& in = op_proto_->inputs()[i];
+    auto& in_name = in.name();
+    if ((in.has_extra() && in.extra()) || (in.has_quant() && in.quant())) {
+      VLOG(1) << "Parse PtKernel input: skip extra & quant input - " << in_name;
+      continue;
+    }
+    // If contains dispensable input, we should override the
+    // GetExpectedPtKernelArgs method self
+    if (in.has_dispensable() && in.dispensable()) {
+      VLOG(1) << "Parse PtKernel input: skip dispensable input - " << in_name;
+      continue;
+    }
+    VLOG(1) << "Parse PtKernel input: " << in_name;
+    input_names_.emplace_back(in_name);
+  }
+  return input_names_;
+}
+
+const paddle::SmallVector<std::string>&
+KernelArgsNameMakerByOpProto::GetOutputArgsNames() {
+  for (int i = 0; i < op_proto_->outputs_size(); ++i) {
+    auto& out = op_proto_->outputs()[i];
+    auto& out_name = out.name();
+    // TODO(chenweihang): outputs also need skip some cases
+    VLOG(1) << "Parse PtKernel output: " << out_name;
+    output_names_.emplace_back(out_name);
+  }
+  return output_names_;
+}
+
+const paddle::SmallVector<std::string>&
+KernelArgsNameMakerByOpProto::GetAttrsArgsNames() {
+  for (int i = 0; i < op_proto_->attrs_size(); ++i) {
+    auto& attr = op_proto_->attrs()[i];
+    auto& attr_name = attr.name();
+    if (attr_name == "use_mkldnn" || attr_name == "op_role" ||
+        attr_name == "op_role_var" || attr_name == "op_namescope" ||
+        attr_name == "op_callstack" || attr_name == "op_device") {
+      VLOG(1) << "Parse PtKernel attribute: skip needless attr - " << attr_name;
+      continue;
+    }
+    if ((attr.has_extra() && attr.extra()) ||
+        (attr.has_quant() && attr.quant())) {
+      VLOG(1) << "Parse PtKernel attribute: skip extra & quant attr - "
+              << attr_name;
+      continue;
+    }
+    VLOG(1) << "Parse PtKernel attribute: " << attr_name;
+    attr_names_.emplace_back(attr_name);
+  }
+
+  return attr_names_;
+}
+
+KernelSignature KernelArgsNameMakerByOpProto::GetKernelSignature() {
+  return std::make_pair(
+      op_proto_->type(),
+      std::make_tuple(GetInputArgsNames(), GetAttrsArgsNames(),
+                      GetOutputArgsNames()));
+}
+
+std::string KernelSignatureToString(const KernelSignature& signature) {
+  std::stringstream os;
+  os << "Kernel Signature - name: " << signature.first << "; inputs: "
+     << string::join_strings(std::get<0>(signature.second), ", ")
+     << "; attributes: "
+     << string::join_strings(std::get<1>(signature.second), ", ")
+     << "; outputs: "
+     << string::join_strings(std::get<2>(signature.second), ", ");
+  return os.str();
+}
+
 } // namespace framework
 } // namespace paddle

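The two key-translation helpers added here are intended as near-inverses: backend, layout, and dtype survive a round trip, while LibraryType is folded into the pt::Backend and customized_type_value is dropped (see the TODO in TransPtKernelKeyToOpKernelType). A minimal usage sketch under that assumption; the concrete key values are illustrative only:

  // Illustrative round trip between the fluid and tcmpt kernel key types.
  OpKernelType fluid_type(proto::VarType::FP32, platform::CPUPlace(),
                          DataLayout::kNCHW, LibraryType::kPlain);
  pt::KernelKey pt_key = TransOpKernelTypeToPtKernelKey(fluid_type);
  OpKernelType back = TransPtKernelKeyToOpKernelType(pt_key);
  // back should equal fluid_type except for customized_type_value, which
  // pt::KernelKey cannot represent and therefore resets to the default.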
paddle/fluid/framework/tcmpt_utils.h

Lines changed: 81 additions & 1 deletion
@@ -14,14 +14,25 @@ limitations under the License. */
 
 #pragma once
 
+#include <string>
+#include <unordered_map>
+#include <vector>
+
+#include "paddle/fluid/framework/framework.pb.h"
+#include "paddle/fluid/framework/op_kernel_type.h"
 #include "paddle/fluid/framework/tensor.h"
+#include "paddle/fluid/imperative/type_defs.h"
+#include "paddle/fluid/platform/macros.h"
 #include "paddle/fluid/platform/place.h"
-
 #include "paddle/tcmpt/api/include/core.h"
+#include "paddle/utils/flat_hash_map.h"
+#include "paddle/utils/small_vector.h"
 
 namespace paddle {
 namespace framework {
 
+/* tensor translate */
+
 template <typename PtTensorImplT, typename VariableT>
 std::shared_ptr<PtTensorImplT> MakeTensorImpl(const VariableT& tensor,
                                               pt::Backend backend,
@@ -49,5 +60,74 @@ std::shared_ptr<tcmpt::TensorBase> InputVariableToPtTensor(
 std::shared_ptr<tcmpt::TensorBase> OutputVariableToPtTensor(
     framework::Variable* variable, const pt::TensorArgDef& arg_def);
 
+/* Kernel Key translate */
+
+OpKernelType TransPtKernelKeyToOpKernelType(const pt::KernelKey& kernel_key);
+pt::KernelKey TransOpKernelTypeToPtKernelKey(const OpKernelType& kernel_type);
+
+/* Kernel Args parse */
+
+// TODO(chenweihang): we can generate this map by proto info in compile time
+class KernelSignatureMap {
+ public:
+  static KernelSignatureMap& Instance();
+
+  bool Has(const std::string& op_type) const {
+    return map_.find(op_type) != map_.end();
+  }
+
+  void Insert(const std::string& op_type, const KernelSignature& signature) {
+    if (!Has(op_type)) {
+      map_.insert({op_type, signature});
+    }
+  }
+
+  const KernelSignature* GetNullable(const std::string& op_type) const {
+    auto it = map_.find(op_type);
+    if (it == map_.end()) {
+      return nullptr;
+    } else {
+      return &it->second;
+    }
+  }
+
+ private:
+  KernelSignatureMap() = default;
+  paddle::flat_hash_map<std::string, KernelSignature> map_;
+
+  DISABLE_COPY_AND_ASSIGN(KernelSignatureMap);
+};
+
+class KernelArgsNameMaker {
+ public:
+  virtual ~KernelArgsNameMaker() {}
+  virtual const paddle::SmallVector<std::string>& GetInputArgsNames() = 0;
+  virtual const paddle::SmallVector<std::string>& GetOutputArgsNames() = 0;
+  virtual const paddle::SmallVector<std::string>& GetAttrsArgsNames() = 0;
+};
+
+class KernelArgsNameMakerByOpProto : public KernelArgsNameMaker {
+ public:
+  explicit KernelArgsNameMakerByOpProto(framework::proto::OpProto* op_proto)
+      : op_proto_(op_proto) {}
+
+  ~KernelArgsNameMakerByOpProto() {}
+
+  const paddle::SmallVector<std::string>& GetInputArgsNames() override;
+  const paddle::SmallVector<std::string>& GetOutputArgsNames() override;
+  const paddle::SmallVector<std::string>& GetAttrsArgsNames() override;
+
+  KernelSignature GetKernelSignature();
+
+ private:
+  framework::proto::OpProto* op_proto_;
+
+  paddle::SmallVector<std::string> input_names_;
+  paddle::SmallVector<std::string> output_names_;
+  paddle::SmallVector<std::string> attr_names_;
+};
+
+std::string KernelSignatureToString(const KernelSignature& signature);
+
 } // namespace framework
 } // namespace paddle

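KernelSignatureMap is a process-wide, insert-once cache keyed by op type, so the OpProto parsing done by KernelArgsNameMakerByOpProto only has to run the first time an op type is encountered. A hedged sketch of the intended call pattern; op_proto (a proto::OpProto*) is assumed to come from the surrounding runtime code and is not shown in this commit:

  // Cache the signature on first use, then reuse the cached copy.
  const std::string& op_type = op_proto->type();
  if (!KernelSignatureMap::Instance().Has(op_type)) {
    KernelArgsNameMakerByOpProto maker(op_proto);
    KernelSignatureMap::Instance().Insert(op_type, maker.GetKernelSignature());
  }
  const KernelSignature* sig = KernelSignatureMap::Instance().GetNullable(op_type);
  if (sig != nullptr) {
    VLOG(1) << KernelSignatureToString(*sig);
  }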
paddle/fluid/framework/tcmpt_utils_test.cc

Lines changed: 8 additions & 2 deletions
@@ -49,13 +49,19 @@ TEST(TcmptUtils, VarToPtTensor) {
   auto* data =
       value->mutable_data<int>(make_ddim({1, 1}), paddle::platform::CPUPlace());
   data[0] = 123;
-  auto tensor_def = pt::TensorArgDef(pt::Backend::kCUDA, pt::DataLayout::kNCHW,
+  pt::Backend expect_backend = pt::Backend::kCPU;
+
+#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
+  expect_backend = pt::Backend::kCUDA;
+#endif
+  auto tensor_def = pt::TensorArgDef(expect_backend, pt::DataLayout::kNCHW,
                                      pt::DataType::kINT32);
   // 2. test API
   auto tensor_x = InputVariableToPtTensor(v, tensor_def);
   // 3. check result
-  ASSERT_EQ(tensor_x->backend(), pt::Backend::kCUDA);
+  ASSERT_EQ(tensor_x->backend(), expect_backend);
   ASSERT_EQ(tensor_x->data_type(), pt::DataType::kINT32);
+
 }
 
 } // namespace framework

paddle/fluid/framework/type_defs.h

Lines changed: 10 additions & 0 deletions
@@ -17,11 +17,13 @@ limitations under the License. */
 #include <map>
 #include <memory>
 #include <string>
+#include <tuple>
 #include <unordered_map>
 #include <unordered_set>
 #include <vector>
 #include "paddle/fluid/imperative/type_defs.h"
 #include "paddle/fluid/platform/variant.h"
+#include "paddle/utils/small_vector.h"
 
 namespace paddle {
 namespace framework {
@@ -82,5 +84,13 @@ using InferShapeFN = std::function<void(InferShapeContext*)>;
 using InplacePair = std::unordered_map<std::string, std::string>;
 using InferInplaceOpFN = std::function<InplacePair(bool /*use_cuda*/)>;
 
+// tuple(input_names, attr_names, output_names)
+using KernelArgsTuple = std::tuple<paddle::SmallVector<std::string>,
+                                   paddle::SmallVector<std::string>,
+                                   paddle::SmallVector<std::string>>;
+// TODD(yuanrisheng): impl implicit overload signature, use KernelArgsTuple
+// directly
+using KernelSignature = std::pair<std::string, KernelArgsTuple>;
+
 } // namespace framework
 } // namespace paddle

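For reference, KernelSignature is just the kernel name paired with the three argument-name lists in the order given by the comment above (inputs, attrs, outputs). A minimal constructed example; the op and argument names are made up:

  // Building a KernelSignature by hand, mirroring what
  // KernelArgsNameMakerByOpProto::GetKernelSignature() produces.
  paddle::SmallVector<std::string> inputs = {"X", "Y"};
  paddle::SmallVector<std::string> attrs = {"axis"};
  paddle::SmallVector<std::string> outputs = {"Out"};
  KernelSignature sig(std::string("elementwise_add"),
                      std::make_tuple(inputs, attrs, outputs));
  // sig.first == "elementwise_add"; std::get<0>(sig.second) holds the inputs.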