diff --git a/paddle/fluid/framework/new_executor/interpreter/data_transfer.cc b/paddle/fluid/framework/new_executor/interpreter/data_transfer.cc
index 0d3af1e55c2a01..52c8a68322f217 100644
--- a/paddle/fluid/framework/new_executor/interpreter/data_transfer.cc
+++ b/paddle/fluid/framework/new_executor/interpreter/data_transfer.cc
@@ -536,20 +536,19 @@ void ApplyDataTransform(const OpKernelType& expected_kernel_key,
         // has to be created and registered
         if ((tensor_in->layout() == DataLayout::ONEDNN) &&
             (var->IsType<phi::DenseTensor>() == true) &&
-            (expected_kernel_key.data_layout_ != DataLayout::ONEDNN) &&
-            (phi::OneDNNContext::tls().get_cur_paddle_data_layout() ==
-             DataLayout::kNHWC)) {
+            (expected_kernel_key.data_layout_ != DataLayout::ONEDNN)) {
           VLOG(7) << "Created reshaped dummy input based on MKL-DNN "
                      "phi::DenseTensor , "
                      "but kNHWC layout"
                   << parameter_name << " in Operator " << op_base->Type();
-          auto op = TransferLayout(var_name,
-                                   &new_var_name,
-                                   tensor_in->layout(),
-                                   DataLayout::kNHWC,
-                                   var_scope,
-                                   local_scope,
-                                   op_base->Type() == "fetch_v2");
+          auto op = TransferLayout(
+              var_name,
+              &new_var_name,
+              tensor_in->layout(),
+              phi::OneDNNContext::tls().get_cur_paddle_data_layout(),
+              var_scope,
+              local_scope,
+              op_base->Type() == "fetch_v2");
           if (op) {
             data_transfer_helper.RunAndConstructOpFuncNode(
                 op,
diff --git a/paddle/fluid/framework/operator.cc b/paddle/fluid/framework/operator.cc
index 4ae5e0ebdf8720..a1e4aaf8d318d4 100644
--- a/paddle/fluid/framework/operator.cc
+++ b/paddle/fluid/framework/operator.cc
@@ -1487,6 +1487,37 @@ bool OperatorWithKernel::SupportsCUDNN(const phi::DataType data_type) const {
   }
 }
 
+bool OperatorWithKernel::SupportsCPUBF16() const {
+  auto phi_kernels = phi::KernelFactory::Instance().SelectKernelMap(
+      phi::TransToPhiKernelName(type_));
+  auto has_phi_kernel =
+      std::any_of(phi_kernels.begin(),
+                  phi_kernels.end(),
+                  [](phi::KernelKeyMap::const_reference kern_pair) {
+                    return kern_pair.first.backend() == phi::Backend::CPU &&
+                           kern_pair.first.dtype() == phi::DataType::BFLOAT16;
+                  });
+  if (has_phi_kernel) {
+    return true;
+  } else {
+    auto op_kernel_iter = OperatorWithKernel::AllOpKernels().find(type_);
+    if (op_kernel_iter == OperatorWithKernel::AllOpKernels().end()) {
+      return false;
+    } else {
+      auto& op_kernels = op_kernel_iter->second;
+      return std::any_of(
+          op_kernels.begin(),
+          op_kernels.end(),
+          [](OpKernelMap::const_reference kern_pair) {
+            return platform::is_cpu_place(kern_pair.first.place_) &&
+                   kern_pair.first.place_ == platform::CPUPlace() &&
+                   kern_pair.first.data_type_ ==
+                       proto::VarType::Type::VarType_Type_BF16;
+          });
+    }
+  }
+}
+
 bool OperatorWithKernel::SupportsKernelType(
     const OpKernelType& kernel_type, const ExecutionContext& exe_ctx) const {
   auto& all_op_kernels = AllOpKernels();
@@ -1805,6 +1836,13 @@ void OperatorWithKernel::RunImpl(const Scope& scope,
       this->CanMKLDNNBeUsed(exe_ctx, kernel_type_->data_type_)) {
     kernel_type_->library_type_ = framework::LibraryType::kMKLDNN;
     kernel_type_->data_layout_ = framework::DataLayout::ONEDNN;
+  } else if (platform::is_cpu_place(kernel_type_->place_) &&
+             kernel_type_->data_type_ ==
+                 proto::VarType::Type::VarType_Type_BF16 &&
+             !this->SupportsCPUBF16() &&
+             this->SupportsMKLDNN(phi::DataType::BFLOAT16)) {
+    kernel_type_->library_type_ = framework::LibraryType::kMKLDNN;
+    kernel_type_->data_layout_ = framework::DataLayout::ONEDNN;
   }
 #endif
 
@@ -2131,6 +2169,13 @@ OpKernelType OperatorWithKernel::InnerGetExpectedKernelType(
       this->CanMKLDNNBeUsed(ctx, expected_kernel_key.data_type_)) {
    expected_kernel_key.library_type_ = framework::LibraryType::kMKLDNN;
    expected_kernel_key.data_layout_ = framework::DataLayout::ONEDNN;
+  } else if (platform::is_cpu_place(expected_kernel_key.place_) &&
+             expected_kernel_key.data_type_ ==
+                 proto::VarType::Type::VarType_Type_BF16 &&
+             !this->SupportsCPUBF16() &&
+             this->SupportsMKLDNN(phi::DataType::BFLOAT16)) {
+    expected_kernel_key.library_type_ = framework::LibraryType::kMKLDNN;
+    expected_kernel_key.data_layout_ = framework::DataLayout::ONEDNN;
   }
 #endif
 
diff --git a/paddle/fluid/framework/operator.h b/paddle/fluid/framework/operator.h
index d51c0ce0f415d0..7f47ef640c19ce 100644
--- a/paddle/fluid/framework/operator.h
+++ b/paddle/fluid/framework/operator.h
@@ -774,6 +774,8 @@ class OperatorWithKernel : public OperatorBase {
   bool SupportsKernelType(const OpKernelType& kernel_type,
                           const ExecutionContext& exe_ctx) const;
 
+  bool SupportsCPUBF16() const;
+
   bool CanMKLDNNBeUsed(const framework::ExecutionContext& ctx,
                        phi::DataType data_type) const;
 
diff --git a/paddle/phi/kernels/onednn/concat_grad_kernel.cc b/paddle/phi/kernels/onednn/concat_grad_kernel.cc
index bbc57328ac2d6e..fc36fa4ab0fd86 100644
--- a/paddle/phi/kernels/onednn/concat_grad_kernel.cc
+++ b/paddle/phi/kernels/onednn/concat_grad_kernel.cc
@@ -52,7 +52,7 @@ void ConcatGradKernel(const Context& dev_ctx,
       out_grad.mem_desc(), funcs::to_void_cast(out_grad.data<T>()));
 
   for (auto& grad : x_grad) {
-    if (grad->numel() != 0UL) {
+    if (grad && grad->numel() != 0UL) {
       auto x_grad_vec_dims = common::vectorize(grad->dims());
       auto slice_mem_p = reorder_handler.AcquireSubmemory(
           x_grad_vec_dims, offset, reorder_src_memory_p);
diff --git a/python/paddle/distributed/passes/auto_parallel_amp.py b/python/paddle/distributed/passes/auto_parallel_amp.py
index b8cd7a6b8d5d73..6437d346ff321e 100644
--- a/python/paddle/distributed/passes/auto_parallel_amp.py
+++ b/python/paddle/distributed/passes/auto_parallel_amp.py
@@ -317,10 +317,7 @@ def _cast_block(self, block):
                     )
                 elif self._is_fp16_op(op.desc.original_id()) is True:
                     if self.amp_dtype == "bfloat16":
-                        if op.has_attr('use_mkldnn'):
-                            op._set_attr('use_mkldnn', True)
-                            op._set_attr('mkldnn_data_type', 'bfloat16')
-                        elif (
+                        if (
                             op.has_attr('dtype')
                             and op.attr('dtype') == core.VarDesc.VarType.FP32
                         ):
@@ -361,10 +358,7 @@ def _cast_block(self, block):
                     self._is_fp16_op(op.desc.original_id()) is True
                 ):  # fp16/bf16
                     if self.amp_dtype == "bfloat16":
-                        if op.has_attr('use_mkldnn'):
-                            op._set_attr('use_mkldnn', True)
-                            op._set_attr('mkldnn_data_type', 'bfloat16')
-                        elif (
+                        if (
                             op.has_attr('dtype')
                             and op.attr('dtype') == core.VarDesc.VarType.FP32
                         ):
diff --git a/python/paddle/distributed/passes/auto_parallel_fp16.py b/python/paddle/distributed/passes/auto_parallel_fp16.py
index 81e659bd4d52fc..0578ae163c3160 100644
--- a/python/paddle/distributed/passes/auto_parallel_fp16.py
+++ b/python/paddle/distributed/passes/auto_parallel_fp16.py
@@ -67,12 +67,6 @@ def set_op_dtype_to_fp16(op):
     if op.has_attr('dtype') and op.attr('dtype') == core.VarDesc.VarType.FP32:
         op._set_attr('dtype', __target_dtype__)
 
-    if __target_dtype__ == core.VarDesc.VarType.BF16:
-        if op.has_attr('use_mkldnn'):
-            op._set_attr('use_mkldnn', True)
-        if op.has_attr('mkldnn_data_type'):
-            op._set_attr('mkldnn_data_type', 'bfloat16')
-
 
 # adapot for backward op
 # TODO check if bf16 and fp16 still share the same logic
diff --git a/python/paddle/static/amp/bf16/amp_utils.py b/python/paddle/static/amp/bf16/amp_utils.py
index bb8d62d85b8cd2..8a73b2c1f5514a 100644
--- a/python/paddle/static/amp/bf16/amp_utils.py
+++ b/python/paddle/static/amp/bf16/amp_utils.py
@@ -400,10 +400,6 @@ def cast_model_to_bf16(
                         and op.attr(attr_name) == core.VarDesc.VarType.FP32
                     ):
                         op._set_attr(attr_name, core.VarDesc.VarType.BF16)
-                if op.has_attr('use_mkldnn'):
-                    op._set_attr('use_mkldnn', True)
-                if op.has_attr('mkldnn_data_type'):
-                    op._set_attr('mkldnn_data_type', 'bfloat16')
 
     if startup_prog is not None:
         cast_initializers_to_bf16(
@@ -593,10 +589,7 @@ def rewrite_program_bf16(main_prog, amp_lists=None):
                 core.VarDesc.VarType.FP32,
             )
         elif op in bf16_op_set:
-            if op.has_attr('use_mkldnn'):
-                op._set_attr('use_mkldnn', True)
-                op._set_attr('mkldnn_data_type', 'bfloat16')
-            elif (
+            if (
                 op.has_attr('dtype')
                 and op.attr('dtype') == core.VarDesc.VarType.FP32
             ):
diff --git a/python/paddle/static/amp/bf16/decorator.py b/python/paddle/static/amp/bf16/decorator.py
index 74e896bc7f7906..4b873ebe5d2122 100644
--- a/python/paddle/static/amp/bf16/decorator.py
+++ b/python/paddle/static/amp/bf16/decorator.py
@@ -45,8 +45,6 @@ class OptimizerWithMixedPrecision:
 
     def __init__(self, optimizer, amp_lists, use_pure_bf16, use_bf16_guard):
        self._optimizer = optimizer
-        if optimizer.type == 'sgd':
-            optimizer._use_mkldnn = True
         self._amp_lists = amp_lists
         self._param_grads = None
         self._train_program = None
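
Note (illustrative, not part of the diff): a minimal sketch of one way the Python-side bf16 rewrite touched in amp_utils.py above could be exercised. The toy program below is an assumption for illustration; the import goes straight to the module changed in this diff, and rewrite_program_bf16(main_prog, amp_lists=None) is the signature visible in the hunk header.

import paddle
from paddle.static.amp.bf16.amp_utils import rewrite_program_bf16

paddle.enable_static()

# Hypothetical toy program, only to have some fp32 ops to rewrite.
main_prog = paddle.static.Program()
startup_prog = paddle.static.Program()
with paddle.static.program_guard(main_prog, startup_prog):
    x = paddle.static.data(name='x', shape=[None, 16], dtype='float32')
    y = paddle.static.nn.fc(x, size=8)
    loss = paddle.mean(y)

# rewrite_program_bf16 marks eligible ops to run in bfloat16. With this change
# it no longer sets use_mkldnn / mkldnn_data_type itself; the bf16-on-CPU
# fallback to oneDNN is decided in OperatorWithKernel (SupportsCPUBF16) instead.
rewrite_program_bf16(main_prog)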