Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion paddle/fluid/pir/dialect/operator/ir/ops_onednn_extra.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,8 @@

# - op : bilinear_interp

# - op : cast
- op : cast
dynamic_fallback : True

- op : clip
extra_args : str mkldnn_data_type="float32"
Expand Down
14 changes: 13 additions & 1 deletion paddle/phi/kernels/onednn/cast_kernel.cc
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,16 @@

namespace phi {

bool CastCheckIfOneDNNSupport(const KernelContext* ctx) {
if ((ctx->InputAt<phi::DenseTensor>(0).dtype() != DataType::FLOAT32 &&
ctx->InputAt<phi::DenseTensor>(0).dtype() != DataType::BFLOAT16) ||
(ctx->AttrAt<DataType>(0) != DataType::FLOAT32 &&
ctx->AttrAt<DataType>(0) != DataType::BFLOAT16)) {
return false;
}
return true;
}

template <typename T, typename Context>
void CastKernel(const Context& dev_ctx,
const DenseTensor& x,
Expand Down Expand Up @@ -56,4 +66,6 @@ void CastKernel(const Context& dev_ctx,
} // namespace phi

PD_REGISTER_KERNEL(
cast, OneDNN, ONEDNN, phi::CastKernel, float, phi::dtype::bfloat16) {}
cast, OneDNN, ONEDNN, phi::CastKernel, float, phi::dtype::bfloat16) {
kernel->check_if_onednn_kernel_support_ = phi::CastCheckIfOneDNNSupport;
}
3 changes: 2 additions & 1 deletion test/mkldnn/test_cast_mkldnn_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ def setUp(self):
self.op_type = 'cast'

def test_check_output(self):
self.check_output(check_dygraph=False)
self.check_output(check_dygraph=False, check_pir_onednn=True)

def test_check_grad(self):
self.check_grad_with_place(
Expand All @@ -57,6 +57,7 @@ def test_check_grad(self):
check_dygraph=False,
user_defined_grads=[self.inputs['X']],
user_defined_grad_outputs=[self.outputs['Out']],
check_pir_onednn=True,
)

def init_shape(self):
Expand Down