diff --git a/paddle/fluid/pir/transforms/pd_op_to_kernel_pass.cc b/paddle/fluid/pir/transforms/pd_op_to_kernel_pass.cc
index d81f0e4ed912d9..53e6b31eec940c 100644
--- a/paddle/fluid/pir/transforms/pd_op_to_kernel_pass.cc
+++ b/paddle/fluid/pir/transforms/pd_op_to_kernel_pass.cc
@@ -511,9 +511,7 @@ static pir::Value AddOneDNN2PaddleLayoutTransferOp(
   }
   block->push_back(op);
 
-  auto new_in = op->result(0);
-
-  return new_in;
+  return op->result(0);
 }
 
 #endif
@@ -1253,6 +1251,12 @@ phi::KernelKey GetKernelKey(
     kernel_backend = paddle::experimental::ParseBackend(place);
   }
 
+#ifdef PADDLE_WITH_DNNL
+  if (kernel_backend != phi::Backend::ONEDNN &&
+      kernel_layout == phi::DataLayout::ONEDNN) {
+    kernel_layout = phi::DataLayout::ANY;
+  }
+#endif
   phi::KernelKey res(kernel_backend, kernel_layout, kernel_dtype);
 
   // kernel backend infered incorrectly from memcpy op operands,
@@ -1284,6 +1288,11 @@ phi::KernelKey GetKernelKey(
 
   if (NeedFallBackCpu((op), kernel_fn_str, res)) {
     res.set_backend(phi::Backend::CPU);
+#ifdef PADDLE_WITH_DNNL
+    if (res.layout() == phi::DataLayout::ONEDNN) {
+      res.set_layout(phi::DataLayout::ANY);
+    }
+#endif
     VLOG(8) << "kernel backend must be on CPU when need fallback";
   }
 
@@ -2375,6 +2384,38 @@ std::vector<pir::Value> BuildInputs(
           new_in = AddOneDNN2PaddleLayoutTransferOp(
               new_in, phi::DataLayout::ANY, block);
         }
+      } else if (new_in_type.isa<pir::VectorType>() &&
+                 new_in.defining_op()->isa<::pir::CombineOp>()) {
+        bool need_replace_combine_op = false;
+        std::vector<pir::Value> new_vec_inputs;
+        std::vector<pir::Type> types_in_vec;
+        for (auto& in : new_in.defining_op()->operands()) {
+          auto in_value = in.source();
+          if (in_value.type().isa<AllocatedDenseTensorType>()) {
+            if (in_value.type()
+                    .dyn_cast<AllocatedDenseTensorType>()
+                    .data_layout() == phi::DataLayout::ONEDNN) {
+              need_replace_combine_op = true;
+              in_value = AddOneDNN2PaddleLayoutTransferOp(
+                  in_value, phi::DataLayout::ANY, block);
+            }
+            new_vec_inputs.push_back(in_value);
+            types_in_vec.push_back(in_value.type());
+          }
+        }
+        if (need_replace_combine_op) {
+          std::string combine_op_name(pir::CombineOp::name());
+          pir::OpInfo op_info = ctx->GetRegisteredOpInfo(combine_op_name);
+
+          pir::Type target_vec_type = pir::VectorType::get(ctx, types_in_vec);
+          pir::Operation* operation = pir::Operation::Create(
+              new_vec_inputs, {}, {target_vec_type}, op_info);
+          new_in.defining_op()->ReplaceAllUsesWith(operation->results());
+          block->erase(*new_in.defining_op());
+
+          new_in = operation->result(0);
+          block->push_back(operation);
+        }
       }
     }
 #endif
@@ -3048,16 +3089,27 @@ void ProcessBlock(
         op_item = op_item_inner;
         op_info_parser = GetOpYamlInfoParser(op_item_inner);
         kernel_key.set_backend(phi::Backend::ONEDNN);
+        kernel_key.set_layout(phi::DataLayout::ONEDNN);
       }
     } else if (FLAGS_use_mkldnn && kernel_key.backend() == phi::Backend::CPU &&
               !op_item->HasTrait<OneDNNTrait>() &&
-               SupportsMKLDNN(kernel_name, phi::DataType::BFLOAT16)) {
+               SupportsMKLDNN(kernel_name, kernel_key.dtype())) {
       // Support FLAGS_use_mkldnn
      auto op_item_inner = PdOp2OneDNNOp(op_item, block, ctx);
      if (op_item_inner != op_item) {
        op_item = op_item_inner;
        op_info_parser = GetOpYamlInfoParser(op_item_inner);
        kernel_key.set_backend(phi::Backend::ONEDNN);
+        kernel_key.set_layout(phi::DataLayout::ONEDNN);
+      }
+    } else if (kernel_key.backend() == phi::Backend::ONEDNN &&
+               !op_item->HasTrait<OneDNNTrait>()) {
+      auto op_item_inner = PdOp2OneDNNOp(op_item, block, ctx);
+      if (op_item_inner != op_item) {
+        op_item = op_item_inner;
+        op_info_parser = GetOpYamlInfoParser(op_item_inner);
+        kernel_key.set_backend(phi::Backend::ONEDNN);
+        kernel_key.set_layout(phi::DataLayout::ONEDNN);
      }
    }
 #endif
diff --git a/test/legacy_test/test_static_save_load_bf16.py b/test/legacy_test/test_static_save_load_bf16.py
index d898136bbde6ab..fe088936f671f3 100644
--- a/test/legacy_test/test_static_save_load_bf16.py
+++ b/test/legacy_test/test_static_save_load_bf16.py
@@ -26,6 +26,8 @@
 import paddle
 from paddle import base
 from paddle.base import core, framework
+from paddle.framework.io_utils import is_pir_fetch_var
+from paddle.pir_utils import IrGuard
 
 
 @unittest.skipIf(
@@ -162,6 +164,135 @@ def test_ptb_rnn_cpu_bfloat16(self):
                     base_t = base_map[var.name]
                     np.testing.assert_array_equal(new_t, base_t)
 
+    def test_ptb_rnn_cpu_bfloat16_pir(self):
+        with IrGuard():
+            seed = 90
+            hidden_size = 10
+            vocab_size = 500
+            num_layers = 1
+            num_steps = 3
+            init_scale = 0.1
+            batch_size = 4
+            batch_num = 100
+
+            with new_program_scope():
+                paddle.seed(seed)
+                ptb_model = PtbModel(
+                    "ptb_model",
+                    hidden_size=hidden_size,
+                    vocab_size=vocab_size,
+                    num_layers=num_layers,
+                    num_steps=num_steps,
+                    init_scale=init_scale,
+                )
+
+                place = self.set_place()
+                exe = base.Executor(place)
+                sgd = paddle.optimizer.SGD(learning_rate=1e-3)
+                x = paddle.static.data(
+                    name="x", shape=[-1, num_steps], dtype='int64'
+                )
+                y = paddle.static.data(name="y", shape=[-1, 1], dtype='float32')
+                init_hidden = paddle.static.data(
+                    name="init_hidden", shape=[-1, 1], dtype='float32'
+                )
+                init_cell = paddle.static.data(
+                    name="init_cell", shape=[-1, 1], dtype='float32'
+                )
+
+                ptb_model, sgd = paddle.amp.decorate(
+                    models=ptb_model,
+                    optimizers=sgd,
+                    level="O2",
+                    dtype='bfloat16',
+                )
+
+                with paddle.amp.auto_cast(
+                    enable=True,
+                    level='O2',
+                    dtype='bfloat16',
+                    custom_black_list={'transpose2', 'concat'},
+                    use_promote=True,
+                ):
+                    (
+                        static_loss,
+                        static_last_hidden,
+                        static_last_cell,
+                    ) = ptb_model(x, y, init_hidden, init_cell)
+                sgd.minimize(static_loss)
+                exe.run(paddle.static.default_startup_program())
+
+                for i in range(batch_num):
+                    x_data = np.arange(12).reshape(4, 3).astype('int64')
+                    y_data = np.arange(1, 13).reshape(4, 3).astype('int64')
+                    x_data = x_data.reshape((-1, num_steps, 1))
+                    y_data = y_data.reshape((-1, 1))
+                    init_hidden_data = np.zeros(
+                        (num_layers, batch_size, hidden_size), dtype='float32'
+                    )
+                    init_cell_data = np.zeros(
+                        (num_layers, batch_size, hidden_size), dtype='float32'
+                    )
+
+                    fetch_list = [
+                        static_loss,
+                        static_last_hidden,
+                        static_last_cell,
+                    ]
+
+                    out = exe.run(
+                        paddle.static.default_main_program(),
+                        feed={
+                            "x": x_data,
+                            "y": y_data,
+                            "init_hidden": init_hidden_data,
+                            "init_cell": init_cell_data,
+                        },
+                        fetch_list=fetch_list,
+                    )
+
+                # get value before save
+                main_program = paddle.static.default_main_program()
+                base_map = {}
+                for var in main_program.list_vars():
+                    if var.persistable and not is_pir_fetch_var(var):
+                        t = np.array(
+                            base.global_scope().find_var(var.name).get_tensor()
+                        )
+                        # make sure all the parameter or optimizer var have been update
+                        self.assertTrue(np.sum(np.abs(t)) != 0)
+                        base_map[var.name] = t
+                save_dir = os.path.join(self.temp_dir.name, "test_1")
+                paddle.static.save(main_program, save_dir)
+
+                # set var to zero
+                for var in main_program.list_vars():
+                    if var.persistable and not is_pir_fetch_var(var):
+                        ten = (
+                            base.global_scope().find_var(var.name).get_tensor()
+                        )
+                        ten.set(np.zeros_like(np.array(ten)), place)
+
+                        new_t = np.array(
+                            base.global_scope().find_var(var.name).get_tensor()
+                        )
+                        # make sure all the parameter or optimizer var have been set to zero
+                        self.assertTrue(np.sum(np.abs(new_t)) == 0)
+
+                paddle.static.load(
+                    main_program,
+                    os.path.join(self.temp_dir.name, "test_1.pdparams"),
+                    exe,
+                )
+
+                for var in main_program.list_vars():
+                    if var.persistable and not is_pir_fetch_var(var):
+                        new_t = np.array(
+                            base.global_scope().find_var(var.name).get_tensor()
+                        )
+                        base_t = base_map[var.name]
+                        np.testing.assert_array_equal(new_t, base_t)
+
 
 
 if __name__ == '__main__':
     paddle.enable_static()