Skip to content

Commit 2bfbf62

Browse files
committed
enabled split op for inference
1 parent e72f151 commit 2bfbf62

File tree

3 files changed

+16
-7
lines changed

3 files changed

+16
-7
lines changed

paddle/fluid/framework/ir/graph_pattern_detector.cc

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -2262,11 +2262,11 @@ PDNode *patterns::QuantizePlacement::operator()(
22622262
PDNode *patterns::Bfloat16Placement::operator()(
22632263
const std::unordered_set<std::string> &bfloat16_enabled_op_types) {
22642264
std::unordered_set<std::string> supported_op_types =
2265-
std::unordered_set<std::string>({"concat", "conv2d", "conv2d_transpose",
2266-
"elementwise_add", "elementwise_mul",
2267-
"fc", "fusion_gru", "gelu", "layer_norm",
2268-
"matmul", "pool2d", "relu", "reshape2",
2269-
"softmax", "sum", "transpose2"});
2265+
std::unordered_set<std::string>(
2266+
{"concat", "conv2d", "conv2d_transpose", "elementwise_add",
2267+
"elementwise_mul", "fc", "fusion_gru", "gelu", "layer_norm",
2268+
"matmul", "pool2d", "relu", "reshape2", "softmax", "split", "sum",
2269+
"transpose2"});
22702270
if (!bfloat16_enabled_op_types.empty()) {
22712271
supported_op_types = bfloat16_enabled_op_types;
22722272
}

paddle/fluid/operators/split_op.cc

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -148,6 +148,11 @@ This operator splits the input tensor into multiple sub-tensors.
148148
AddAttr<bool>("use_mkldnn",
149149
"(bool, default false) Only used in mkldnn kernel")
150150
.SetDefault(false);
151+
AddAttr<std::string>(
152+
"mkldnn_data_type",
153+
"(string, default \"float32\"). Data type of mkldnn kernel")
154+
.SetDefault("float32")
155+
.InEnum({"float32", "bfloat16"});
151156
}
152157
};
153158

python/paddle/fluid/tests/unittests/mkldnn/test_split_bf16_mkldnn_op.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,11 @@ def setUp(self):
4141
self.num = 0
4242
self.init_data()
4343
self.inputs = {'X': self.x}
44-
self.attrs = {'use_mkldnn': True, 'num': self.num}
44+
self.attrs = {
45+
'use_mkldnn': True,
46+
'num': self.num,
47+
'mkldnn_data_type': "bfloat16"
48+
}
4549

4650
if self.axis is not None:
4751
self.attrs['axis'] = self.axis
@@ -56,7 +60,7 @@ def setUp(self):
5660
for i in range(len(self.out))]}
5761

5862
def test_check_output(self):
59-
self.check_output(check_dygraph=False)
63+
self.check_output_with_place(core.CPUPlace())
6064

6165

6266
# TODO jakpiase enable grad check(concat op)

0 commit comments

Comments
 (0)