paddle/fluid/pir/transforms/pd_op_to_kernel_pass.cc: 56 additions & 4 deletions
@@ -511,9 +511,7 @@ static pir::Value AddOneDNN2PaddleLayoutTransferOp(
}

block->push_back(op);
auto new_in = op->result(0);

return new_in;
return op->result(0);
}
#endif

@@ -1253,6 +1251,12 @@ phi::KernelKey GetKernelKey(
kernel_backend = paddle::experimental::ParseBackend(place);
}

#ifdef PADDLE_WITH_DNNL
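// A kernel that does not run on the oneDNN backend cannot use the ONEDNN
// layout; reset the layout to ANY so a regular Paddle kernel can be matched.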
if (kernel_backend != phi::Backend::ONEDNN &&
kernel_layout == phi::DataLayout::ONEDNN) {
kernel_layout = phi::DataLayout::ANY;
}
#endif
phi::KernelKey res(kernel_backend, kernel_layout, kernel_dtype);

// kernel backend inferred incorrectly from memcpy op operands,
@@ -1284,6 +1288,11 @@

if (NeedFallBackCpu((op), kernel_fn_str, res)) {
res.set_backend(phi::Backend::CPU);
#ifdef PADDLE_WITH_DNNL
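// After falling back to the CPU backend, an ONEDNN layout in the kernel key
// is no longer valid; reset it to ANY as well.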
if (res.layout() == phi::DataLayout::ONEDNN) {
res.set_layout(phi::DataLayout::ANY);
}
#endif
VLOG(8) << "kernel backend must be on CPU when need fallback";
}

@@ -2375,6 +2384,38 @@ std::vector<pir::Value> BuildInputs(
new_in = AddOneDNN2PaddleLayoutTransferOp(
new_in, phi::DataLayout::ANY, block);
}
} else if (new_in_type.isa<pir::VectorType>() &&
new_in.defining_op()->isa<::pir::CombineOp>()) {
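// The vector input is produced by a pir::CombineOp. Insert a layout transfer
// op for every operand that still carries the ONEDNN layout, and record
// whether the combine op has to be rebuilt with the transferred values.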
bool need_replace_combine_op = false;
std::vector<pir::Value> new_vec_inputs;
std::vector<pir::Type> types_in_vec;
for (auto& in : new_in.defining_op()->operands()) {
auto in_value = in.source();
if (in_value.type().isa<AllocatedDenseTensorType>()) {
if (in_value.type()
.dyn_cast<AllocatedDenseTensorType>()
.data_layout() == phi::DataLayout::ONEDNN) {
need_replace_combine_op = true;
in_value = AddOneDNN2PaddleLayoutTransferOp(
in_value, phi::DataLayout::ANY, block);
}
new_vec_inputs.push_back(in_value);
types_in_vec.push_back(in_value.type());
}
}
if (need_replace_combine_op) {
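// Rebuild the combine op with the transferred operands and the matching
// vector type, redirect all uses to the new op, and erase the original one.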
std::string combine_op_name(pir::CombineOp::name());
pir::OpInfo op_info = ctx->GetRegisteredOpInfo(combine_op_name);

pir::Type target_vec_type = pir::VectorType::get(ctx, types_in_vec);
pir::Operation* operation = pir::Operation::Create(
new_vec_inputs, {}, {target_vec_type}, op_info);
new_in.defining_op()->ReplaceAllUsesWith(operation->results());
block->erase(*new_in.defining_op());

new_in = operation->result(0);
block->push_back(operation);
}
}
}
#endif
@@ -3048,16 +3089,27 @@ void ProcessBlock(
op_item = op_item_inner;
op_info_parser = GetOpYamlInfoParser(op_item_inner);
kernel_key.set_backend(phi::Backend::ONEDNN);
kernel_key.set_layout(phi::DataLayout::ONEDNN);
}
} else if (FLAGS_use_mkldnn && kernel_key.backend() == phi::Backend::CPU &&
!op_item->HasTrait<OneDNNTrait>() &&
SupportsMKLDNN(kernel_name, phi::DataType::BFLOAT16)) {
SupportsMKLDNN(kernel_name, kernel_key.dtype())) {
// Support FLAGS_use_mkldnn
auto op_item_inner = PdOp2OneDNNOp(op_item, block, ctx);
if (op_item_inner != op_item) {
op_item = op_item_inner;
op_info_parser = GetOpYamlInfoParser(op_item_inner);
kernel_key.set_backend(phi::Backend::ONEDNN);
kernel_key.set_layout(phi::DataLayout::ONEDNN);
}
} else if (kernel_key.backend() == phi::Backend::ONEDNN &&
!op_item->HasTrait<OneDNNTrait>()) {
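// The kernel key already points at the oneDNN backend, but the op itself is
// still a plain PD op; convert it to its oneDNN counterpart so the oneDNN
// kernel can actually be dispatched.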
auto op_item_inner = PdOp2OneDNNOp(op_item, block, ctx);
if (op_item_inner != op_item) {
op_item = op_item_inner;
op_info_parser = GetOpYamlInfoParser(op_item_inner);
kernel_key.set_backend(phi::Backend::ONEDNN);
kernel_key.set_layout(phi::DataLayout::ONEDNN);
}
}
#endif
test/legacy_test/test_static_save_load_bf16.py: 131 additions & 0 deletions
@@ -26,6 +26,8 @@
import paddle
from paddle import base
from paddle.base import core, framework
from paddle.framework.io_utils import is_pir_fetch_var
from paddle.pir_utils import IrGuard


@unittest.skipIf(
@@ -162,6 +164,135 @@ def test_ptb_rnn_cpu_bfloat16(self):
base_t = base_map[var.name]
np.testing.assert_array_equal(new_t, base_t)

def test_ptb_rnn_cpu_bfloat16_pir(self):
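# PIR counterpart of test_ptb_rnn_cpu_bfloat16: run the bf16 PTB model under
# IrGuard, save the persistable vars, zero them out, then reload them and
# check they match the values captured before saving.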
with IrGuard():
seed = 90
hidden_size = 10
vocab_size = 500
num_layers = 1
num_steps = 3
init_scale = 0.1
batch_size = 4
batch_num = 100

with new_program_scope():
paddle.seed(seed)
ptb_model = PtbModel(
"ptb_model",
hidden_size=hidden_size,
vocab_size=vocab_size,
num_layers=num_layers,
num_steps=num_steps,
init_scale=init_scale,
)

place = self.set_place()
exe = base.Executor(place)
sgd = paddle.optimizer.SGD(learning_rate=1e-3)
x = paddle.static.data(
name="x", shape=[-1, num_steps], dtype='int64'
)
y = paddle.static.data(name="y", shape=[-1, 1], dtype='float32')
init_hidden = paddle.static.data(
name="init_hidden", shape=[-1, 1], dtype='float32'
)
init_cell = paddle.static.data(
name="init_cell", shape=[-1, 1], dtype='float32'
)

ptb_model, sgd = paddle.amp.decorate(
models=ptb_model,
optimizers=sgd,
level="O2",
dtype='bfloat16',
)

with paddle.amp.auto_cast(
enable=True,
level='O2',
dtype='bfloat16',
custom_black_list={'transpose2', 'concat'},
use_promote=True,
):
(
static_loss,
static_last_hidden,
static_last_cell,
) = ptb_model(x, y, init_hidden, init_cell)
sgd.minimize(static_loss)
exe.run(paddle.static.default_startup_program())

for i in range(batch_num):
x_data = np.arange(12).reshape(4, 3).astype('int64')
y_data = np.arange(1, 13).reshape(4, 3).astype('int64')
x_data = x_data.reshape((-1, num_steps, 1))
y_data = y_data.reshape((-1, 1))
init_hidden_data = np.zeros(
(num_layers, batch_size, hidden_size), dtype='float32'
)
init_cell_data = np.zeros(
(num_layers, batch_size, hidden_size), dtype='float32'
)

fetch_list = [
static_loss,
static_last_hidden,
static_last_cell,
]

out = exe.run(
paddle.static.default_main_program(),
feed={
"x": x_data,
"y": y_data,
"init_hidden": init_hidden_data,
"init_cell": init_cell_data,
},
fetch_list=fetch_list,
)

# get value before save
main_program = paddle.static.default_main_program()
base_map = {}
for var in main_program.list_vars():
if var.persistable and not is_pir_fetch_var(var):
t = np.array(
base.global_scope().find_var(var.name).get_tensor()
)
# make sure all the parameter and optimizer vars have been updated
self.assertTrue(np.sum(np.abs(t)) != 0)
base_map[var.name] = t
save_dir = os.path.join(self.temp_dir.name, "test_1")
paddle.static.save(main_program, save_dir)

# set var to zero
for var in main_program.list_vars():
if var.persistable and not is_pir_fetch_var(var):
ten = (
base.global_scope().find_var(var.name).get_tensor()
)
ten.set(np.zeros_like(np.array(ten)), place)

new_t = np.array(
base.global_scope().find_var(var.name).get_tensor()
)
# make sure all the parameter and optimizer vars have been set to zero
self.assertTrue(np.sum(np.abs(new_t)) == 0)

paddle.static.load(
main_program,
os.path.join(self.temp_dir.name, "test_1.pdparams"),
exe,
)

for var in main_program.list_vars():
if var.persistable and not is_pir_fetch_var(var):
new_t = np.array(
base.global_scope().find_var(var.name).get_tensor()
)
base_t = base_map[var.name]
np.testing.assert_array_equal(new_t, base_t)


if __name__ == '__main__':
paddle.enable_static()