Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 6 additions & 6 deletions paddle/phi/kernels/fusion/onednn/fused_conv_kernel.cc
Original file line number Diff line number Diff line change
Expand Up @@ -29,12 +29,12 @@ void FusedConv2DKernel(const Context& dev_ctx,
const std::vector<int>& dilations,
int groups,
const std::string& data_format,
const std::string& mkldnn_data_type,
const std::string& onednn_data_type,
const std::string& fuse_activation,
bool fuse_residual_conn,
bool force_fp32_output,
DenseTensor* out) {
bool is_BFLOAT16 = mkldnn_data_type == "bfloat16";
bool is_BFLOAT16 = onednn_data_type == "bfloat16";

ConvOnednn<T>(dev_ctx,
&input,
Expand Down Expand Up @@ -68,12 +68,12 @@ void FusedDepthwiseConv2DKernel(
const std::vector<int>& dilations,
int groups,
const std::string& data_format,
const std::string& mkldnn_data_type,
const std::string& onednn_data_type,
const std::string& fuse_activation,
bool fuse_residual_conn,
bool force_fp32_output,
DenseTensor* out) {
bool is_BFLOAT16 = mkldnn_data_type == "bfloat16";
bool is_BFLOAT16 = onednn_data_type == "bfloat16";

ConvOnednn<T>(dev_ctx,
&input,
Expand Down Expand Up @@ -106,12 +106,12 @@ void FusedConv3DKernel(const Context& dev_ctx,
const std::vector<int>& dilations,
int groups,
const std::string& data_format,
const std::string& mkldnn_data_type,
const std::string& onednn_data_type,
const std::string& fuse_activation,
bool fuse_residual_conn,
bool force_fp32_output,
DenseTensor* out) {
bool is_BFLOAT16 = mkldnn_data_type == "bfloat16";
bool is_BFLOAT16 = onednn_data_type == "bfloat16";

ConvOnednn<T>(dev_ctx,
&input,
Expand Down
16 changes: 8 additions & 8 deletions paddle/phi/kernels/onednn/multi_gru_kernel.cc
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ class MultiGRUHandler {
const std::string& gate_activation,
int layers,
bool origin_mode,
const std::string& mkldnn_data_type,
const std::string& onednn_data_type,
float scale_data,
float shift_data,
bool force_fp32_output,
Expand Down Expand Up @@ -695,7 +695,7 @@ void RunKernel(const Context& dev_ctx,
const std::string& gate_activation,
int layers_in,
bool origin_mode,
const std::string& mkldnn_data_type,
const std::string& onednn_data_type,
float scale_data,
float shift_data,
bool force_fp32_output,
Expand All @@ -710,7 +710,7 @@ void RunKernel(const Context& dev_ctx,
gate_activation,
layers_in,
origin_mode,
mkldnn_data_type,
onednn_data_type,
scale_data,
shift_data,
force_fp32_output,
Expand All @@ -732,7 +732,7 @@ void RunKernel(const Context& dev_ctx,
}

template <typename T, typename Context>
void MultiGRUMKLDNNKernel(
void MultiGRUONEDNNKernel(
const Context& dev_ctx,
const DenseTensor& x,
const std::vector<const DenseTensor*>& weight_x,
Expand All @@ -743,7 +743,7 @@ void MultiGRUMKLDNNKernel(
const std::string& gate_activation,
int layers,
bool origin_mode,
const std::string& mkldnn_data_type,
const std::string& onednn_data_type,
float scale_data,
float shift_data,
bool force_fp32_output,
Expand All @@ -769,7 +769,7 @@ void MultiGRUMKLDNNKernel(
gate_activation,
layers,
origin_mode,
mkldnn_data_type,
onednn_data_type,
scale_data,
shift_data,
force_fp32_output,
Expand All @@ -785,7 +785,7 @@ void MultiGRUMKLDNNKernel(
gate_activation,
layers,
origin_mode,
mkldnn_data_type,
onednn_data_type,
scale_data,
shift_data,
force_fp32_output,
Expand All @@ -795,4 +795,4 @@ void MultiGRUMKLDNNKernel(
} // namespace phi

PD_REGISTER_KERNEL(
multi_gru, OneDNN, ONEDNN, phi::MultiGRUMKLDNNKernel, float, uint8_t) {}
multi_gru, OneDNN, ONEDNN, phi::MultiGRUONEDNNKernel, float, uint8_t) {}
Loading