diff --git a/paddle/fluid/pir/dialect/operator/ir/ops_onednn_extra.yaml b/paddle/fluid/pir/dialect/operator/ir/ops_onednn_extra.yaml
index 290b20357824eb..835926a4c377f0 100644
--- a/paddle/fluid/pir/dialect/operator/ir/ops_onednn_extra.yaml
+++ b/paddle/fluid/pir/dialect/operator/ir/ops_onednn_extra.yaml
@@ -27,7 +27,8 @@
 # - op : bilinear_interp
 
-# - op : cast
+- op : cast
+  dynamic_fallback : True
 
 - op : clip
   extra_args : str mkldnn_data_type="float32"
diff --git a/paddle/phi/kernels/onednn/cast_kernel.cc b/paddle/phi/kernels/onednn/cast_kernel.cc
index 9bf0a3e8a875fa..fdee92c2d4f167 100644
--- a/paddle/phi/kernels/onednn/cast_kernel.cc
+++ b/paddle/phi/kernels/onednn/cast_kernel.cc
@@ -19,6 +19,16 @@
 
 namespace phi {
 
+// OneDNN cast only handles fp32 <-> bf16; anything else must fall back to the
+// CPU kernel (paired with dynamic_fallback in ops_onednn_extra.yaml).
+bool CastCheckIfOneDNNSupport(const KernelContext* ctx) {
+  if ((ctx->InputAt(0).dtype() != DataType::FLOAT32 &&
+       ctx->InputAt(0).dtype() != DataType::BFLOAT16) ||
+      (ctx->AttrAt<DataType>(0) != DataType::FLOAT32 &&
+       ctx->AttrAt<DataType>(0) != DataType::BFLOAT16)) {
+    return false;
+  }
+  return true;
+}
+
 template <typename T, typename Context>
 void CastKernel(const Context& dev_ctx,
                 const DenseTensor& x,
@@ -56,4 +66,6 @@
 }  // namespace phi
 
 PD_REGISTER_KERNEL(
-    cast, OneDNN, ONEDNN, phi::CastKernel, float, phi::dtype::bfloat16) {}
+    cast, OneDNN, ONEDNN, phi::CastKernel, float, phi::dtype::bfloat16) {
+  kernel->check_if_onednn_kernel_support_ = phi::CastCheckIfOneDNNSupport;
+}
diff --git a/test/mkldnn/test_cast_mkldnn_op.py b/test/mkldnn/test_cast_mkldnn_op.py
index 8856fd9071391b..a5a9562cf78130 100644
--- a/test/mkldnn/test_cast_mkldnn_op.py
+++ b/test/mkldnn/test_cast_mkldnn_op.py
@@ -47,7 +47,7 @@ def setUp(self):
         self.op_type = 'cast'
 
     def test_check_output(self):
-        self.check_output(check_dygraph=False)
+        self.check_output(check_dygraph=False, check_pir_onednn=True)
 
     def test_check_grad(self):
         self.check_grad_with_place(
@@ -57,6 +57,7 @@
             check_dygraph=False,
             user_defined_grads=[self.inputs['X']],
             user_defined_grad_outputs=[self.outputs['Out']],
+            check_pir_onednn=True,
         )
 
     def init_shape(self):