Skip to content

Commit f0151d0

Browse files
committed
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into fix_log2
2 parents faedba2 + c42e656 commit f0151d0

File tree

73 files changed

+1768
-491
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

73 files changed

+1768
-491
lines changed

cmake/cblas.cmake

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -102,7 +102,7 @@ if(NOT DEFINED CBLAS_PROVIDER AND WITH_SYSTEM_BLAS)
102102
find_library(REFERENCE_CBLAS_LIBRARY NAMES cblas PATHS
103103
${REFERENCE_CBLAS_LIB_SEARCH_PATHS})
104104
find_library(REFERENCE_BLAS_LIBRARY NAMES blas PATHS
105-
${REFERENCE_BLAS_LIB_SEARCH_PATHS})
105+
${REFERENCE_CBLAS_LIB_SEARCH_PATHS})
106106

107107
if(REFERENCE_CBLAS_INCLUDE_DIR AND REFERENCE_CBLAS_LIBRARY)
108108
set(CBLAS_PROVIDER REFERENCE_CBLAS)
@@ -127,9 +127,9 @@ endif()
127127
# linear algebra libraries for cc_library(xxx SRCS xxx.c DEPS cblas)
128128

129129
include_directories(${CBLAS_INC_DIR})
130-
if(NOT ${CBLAS_PROVIDER} STREQUAL MKLML)
131-
target_link_libraries(cblas ${CBLAS_LIBRARIES})
132-
elseif(${CBLAS_PROVIDER} STREQUAL REFERENCE_CBLAS)
130+
if(${CBLAS_PROVIDER} STREQUAL REFERENCE_CBLAS)
133131
target_link_libraries(cblas gfortran ${CBLAS_LIBRARIES} ${REFERENCE_BLAS_LIBRARY})
132+
elseif(NOT ${CBLAS_PROVIDER} STREQUAL MKLML)
133+
target_link_libraries(cblas ${CBLAS_LIBRARIES})
134134
endif()
135135

cmake/cupti.cmake

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ find_path(CUPTI_INCLUDE_DIR cupti.h
88
PATHS ${CUPTI_ROOT} ${CUPTI_ROOT}/include
99
$ENV{CUPTI_ROOT} $ENV{CUPTI_ROOT}/include
1010
${CUDA_TOOLKIT_ROOT_DIR}/extras/CUPTI/include
11+
${CUDA_TOOLKIT_ROOT_DIR}/targets/x86_64-linux/include
1112
NO_DEFAULT_PATH
1213
)
1314

@@ -27,6 +28,7 @@ list(APPEND CUPTI_CHECK_LIBRARY_DIRS
2728
$ENV{CUPTI_ROOT}/lib64
2829
$ENV{CUPTI_ROOT}/lib
2930
/usr/lib
31+
${CUDA_TOOLKIT_ROOT_DIR}/targets/x86_64-linux/lib64
3032
${CUDA_TOOLKIT_ROOT_DIR}/extras/CUPTI/lib64)
3133
find_library(CUPTI_LIBRARY NAMES libcupti.so libcupti.dylib # libcupti_static.a
3234
PATHS ${CUPTI_CHECK_LIBRARY_DIRS} ${CUPTI_INCLUDE_DIR} ${__libpath_hist}

cmake/external/pybind11.cmake

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,11 @@ ExternalProject_Add(
3434
"${PYBIND_DOWNLOAD_CMD}"
3535
PREFIX ${PYBIND_PREFIX_DIR}
3636
SOURCE_DIR ${PYBIND_SOURCE_DIR}
37+
# Leaving the `UPDATE_COMMAND` of ExternalProject_Add explicitly empty
38+
# means that a later change to the GIT_TAG parameter does not trigger an
39+
# update or an incremental rebuild, so version changes of this
40+
# third-party library will not be picked up.
41+
# Reference: https://cmake.org/cmake/help/latest/module/ExternalProject.html
3742
UPDATE_COMMAND ""
3843
CONFIGURE_COMMAND ""
3944
BUILD_COMMAND ""

paddle/fluid/framework/ir/CMakeLists.txt

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -23,9 +23,9 @@ function(pass_library TARGET DEST)
2323

2424
cmake_parse_arguments(pass_library "${options}" "${oneValueArgs}" "${multiValueArgs}" ${ARGN})
2525
if(pass_library_DIR)
26-
cc_library(${TARGET} SRCS ${pass_library_DIR}/${TARGET}.cc DEPS graph_pattern_detector pass fuse_pass_base ${pass_library_DEPS})
26+
cc_library(${TARGET} SRCS ${pass_library_DIR}/${TARGET}.cc DEPS graph_pattern_detector pass fuse_pass_base op_version_registry ${pass_library_DEPS})
2727
else()
28-
cc_library(${TARGET} SRCS ${TARGET}.cc DEPS graph_pattern_detector pass fuse_pass_base ${pass_library_DEPS})
28+
cc_library(${TARGET} SRCS ${TARGET}.cc DEPS graph_pattern_detector pass fuse_pass_base op_version_registry ${pass_library_DEPS})
2929
endif()
3030

3131
# add more DEST here, such as train, dist and collect USE_PASS into a file automatically.

paddle/fluid/framework/ir/graph_pattern_detector.cc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2102,7 +2102,7 @@ PDNode *patterns::Bfloat16Placement::operator()(
21022102
const std::unordered_set<std::string> &bfloat16_enabled_op_types) {
21032103
std::unordered_set<std::string> supported_op_types =
21042104
std::unordered_set<std::string>(
2105-
{"concat", "conv2d", "fusion_gru", "reshape2", "transpose2"});
2105+
{"concat", "conv2d", "fusion_gru", "reshape2", "transpose2", "sum"});
21062106
if (!bfloat16_enabled_op_types.empty()) {
21072107
supported_op_types = bfloat16_enabled_op_types;
21082108
}

paddle/fluid/framework/ir/mkldnn/conv_bias_mkldnn_fuse_pass.cc

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -147,12 +147,19 @@ void ConvBiasFusePass::ApplyImpl(ir::Graph* graph) const {
147147
} // namespace paddle
148148
REGISTER_PASS(conv_bias_mkldnn_fuse_pass,
149149
paddle::framework::ir::ConvBiasFusePass);
150-
REGISTER_PASS(conv_transpose_bias_mkldnn_fuse_pass,
151-
paddle::framework::ir::Conv2DTransposeBiasFusePass);
152-
REGISTER_PASS(conv3d_bias_mkldnn_fuse_pass,
153-
paddle::framework::ir::Conv3DBiasFusePass);
154150
REGISTER_PASS_CAPABILITY(conv_bias_mkldnn_fuse_pass)
155151
.AddCombination(
156152
paddle::framework::compatible::OpVersionComparatorCombination()
157153
.EQ("conv2d", 0)
158154
.EQ("elementwise_add", 0));
155+
156+
REGISTER_PASS(conv_transpose_bias_mkldnn_fuse_pass,
157+
paddle::framework::ir::Conv2DTransposeBiasFusePass);
158+
REGISTER_PASS_CAPABILITY(conv_transpose_bias_mkldnn_fuse_pass)
159+
.AddCombination(
160+
paddle::framework::compatible::OpVersionComparatorCombination()
161+
.EQ("conv2d_transpose", 0)
162+
.EQ("elementwise_add", 0));
163+
164+
REGISTER_PASS(conv3d_bias_mkldnn_fuse_pass,
165+
paddle::framework::ir::Conv3DBiasFusePass);

paddle/fluid/framework/ir/mkldnn/cpu_bfloat16_placement_pass_tester.cc

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,8 @@ void SetOp(ProgramDesc* prog, const std::string& type, const std::string& name,
4444
op->SetInput("X", {inputs[0]});
4545
} else if (type == "reshape2") {
4646
op->SetInput("X", {inputs[0]});
47+
} else if (type == "sum") {
48+
op->SetInput("X", {inputs[0], inputs[1]});
4749
} else {
4850
FAIL() << "Unexpected operator type.";
4951
}
@@ -61,8 +63,9 @@ void SetOp(ProgramDesc* prog, const std::string& type, const std::string& name,
6163
ProgramDesc BuildProgramDesc() {
6264
ProgramDesc prog;
6365

64-
for (auto& v : std::vector<std::string>(
65-
{"a", "b", "c", "f", "g", "h", "k", "l", "m", "n", "o", "p"})) {
66+
for (auto& v :
67+
std::vector<std::string>({"a", "b", "c", "f", "g", "h", "k", "l", "m",
68+
"n", "o", "p", "r", "s"})) {
6669
prog.MutableBlock(0)->Var(v);
6770
}
6871

@@ -75,6 +78,7 @@ ProgramDesc BuildProgramDesc() {
7578
SetOp(&prog, "concat", "concat2", {"l", "m"}, {"n"});
7679
SetOp(&prog, "transpose2", "transpose", {"n"}, {"o"});
7780
SetOp(&prog, "reshape2", "reshape", {"o"}, {"p"});
81+
SetOp(&prog, "sum", "sum", {"p", "r"}, {"s"});
7882

7983
return prog;
8084
}
@@ -122,15 +126,15 @@ void DefaultAttrTest(unsigned expected_bfloat16_data_type_count) {
122126
}
123127

124128
TEST(Bfloat16PlacementPass, enable_all) {
125-
MainTest({"conv2d", "pool2d", "relu", "concat"}, 7);
129+
MainTest({"conv2d", "pool2d", "relu", "concat", "sum"}, 8);
126130
}
127131

128132
TEST(Bfloat16PlacementPass, enabled_conv_and_pool) {
129133
// 2 conv2d + 2 pool2d - 1 orphaned conv2d
130134
MainTest({"conv2d", "pool2d"}, 3);
131135
}
132136

133-
TEST(Bfloat16PlacementPass, default_attr_value) { DefaultAttrTest(5); }
137+
TEST(Bfloat16PlacementPass, default_attr_value) { DefaultAttrTest(6); }
134138

135139
} // namespace ir
136140
} // namespace framework

paddle/fluid/framework/op_version_registry.cc

Lines changed: 72 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,3 +13,75 @@ See the License for the specific language governing permissions and
1313
limitations under the License. */
1414

1515
#include "paddle/fluid/framework/op_version_registry.h"
16+
17+
namespace paddle {
18+
namespace framework {
19+
namespace compatible {
20+
21+
namespace {
22+
template <OpUpdateType type__, typename InfoType>
23+
OpUpdate<InfoType, type__>* new_update(InfoType&& info) {
24+
return new OpUpdate<InfoType, type__>(info);
25+
}
26+
}
27+
28+
OpVersionDesc&& OpVersionDesc::ModifyAttr(const std::string& name,
29+
const std::string& remark,
30+
const OpAttrVariantT& default_value) {
31+
infos_.emplace_back(new_update<OpUpdateType::kModifyAttr>(
32+
OpAttrInfo(name, remark, default_value)));
33+
return std::move(*this);
34+
}
35+
36+
OpVersionDesc&& OpVersionDesc::NewAttr(const std::string& name,
37+
const std::string& remark,
38+
const OpAttrVariantT& default_value) {
39+
infos_.emplace_back(new_update<OpUpdateType::kNewAttr>(
40+
OpAttrInfo(name, remark, default_value)));
41+
return std::move(*this);
42+
}
43+
44+
OpVersionDesc&& OpVersionDesc::NewInput(const std::string& name,
45+
const std::string& remark) {
46+
infos_.emplace_back(
47+
new_update<OpUpdateType::kNewInput>(OpInputOutputInfo(name, remark)));
48+
return std::move(*this);
49+
}
50+
51+
OpVersionDesc&& OpVersionDesc::NewOutput(const std::string& name,
52+
const std::string& remark) {
53+
infos_.emplace_back(
54+
new_update<OpUpdateType::kNewOutput>(OpInputOutputInfo(name, remark)));
55+
return std::move(*this);
56+
}
57+
58+
OpVersionDesc&& OpVersionDesc::BugfixWithBehaviorChanged(
59+
const std::string& remark) {
60+
infos_.emplace_back(new_update<OpUpdateType::kBugfixWithBehaviorChanged>(
61+
OpBugfixInfo(remark)));
62+
return std::move(*this);
63+
}
64+
65+
OpVersion& OpVersionRegistrar::Register(const std::string& op_type) {
66+
PADDLE_ENFORCE_EQ(
67+
op_version_map_.find(op_type), op_version_map_.end(),
68+
platform::errors::AlreadyExists(
69+
"'%s' is registered in operator version more than once.", op_type));
70+
op_version_map_.insert(
71+
std::pair<std::string, OpVersion>{op_type, OpVersion()});
72+
return op_version_map_[op_type];
73+
}
74+
uint32_t OpVersionRegistrar::version_id(const std::string& op_type) const {
75+
PADDLE_ENFORCE_NE(
76+
op_version_map_.count(op_type), 0,
77+
platform::errors::InvalidArgument(
78+
"The version of operator type %s has not been registered.", op_type));
79+
return op_version_map_.find(op_type)->second.version_id();
80+
}
81+
82+
// Provide a fake registration item for pybind testing.
83+
#include "paddle/fluid/framework/op_version_registry.inl"
84+
85+
} // namespace compatible
86+
} // namespace framework
87+
} // namespace paddle

0 commit comments

Comments
 (0)