4 changes: 1 addition & 3 deletions paddle/phi/core/kernel_context.h
@@ -118,9 +118,7 @@ class KernelContext {
     return paddle::none;
   }
 
-  const TensorBase* MutableIutputAt(size_t idx) const {
-    return inputs_.at(idx);
-  }
+  const TensorBase* MutableInputAt(size_t idx) const { return inputs_.at(idx); }
 
   template <typename TensorType>
   TensorType* MutableOutputAt(size_t idx) {
12 changes: 6 additions & 6 deletions paddle/phi/kernels/fusion/onednn/fused_conv_kernel.cc
@@ -34,7 +34,7 @@ void FusedConv2DKernel(const Context& dev_ctx,
       bool fuse_residual_conn,
       bool force_fp32_output,
       DenseTensor* out) {
-  bool is_BFLOAT16 = onednn_data_type == "bfloat16";
+  bool is_bfloat16 = onednn_data_type == "bfloat16";
 
   ConvOnednn<T>(dev_ctx,
       &input,
@@ -48,7 +48,7 @@ void FusedConv2DKernel(const Context& dev_ctx,
       groups,
       data_format,
       true,
-      is_BFLOAT16,
+      is_bfloat16,
       fuse_activation,
       fuse_residual_conn,
       force_fp32_output,
@@ -73,7 +73,7 @@ void FusedDepthwiseConv2DKernel(
       bool fuse_residual_conn,
       bool force_fp32_output,
       DenseTensor* out) {
-  bool is_BFLOAT16 = onednn_data_type == "bfloat16";
+  bool is_bfloat16 = onednn_data_type == "bfloat16";
 
   ConvOnednn<T>(dev_ctx,
       &input,
@@ -87,7 +87,7 @@ void FusedDepthwiseConv2DKernel(
       groups,
       data_format,
       true,
-      is_BFLOAT16,
+      is_bfloat16,
       fuse_activation,
       fuse_residual_conn,
       force_fp32_output,
@@ -111,7 +111,7 @@ void FusedConv3DKernel(const Context& dev_ctx,
       bool fuse_residual_conn,
       bool force_fp32_output,
       DenseTensor* out) {
-  bool is_BFLOAT16 = onednn_data_type == "bfloat16";
+  bool is_bfloat16 = onednn_data_type == "bfloat16";
 
   ConvOnednn<T>(dev_ctx,
       &input,
@@ -125,7 +125,7 @@ void FusedConv3DKernel(const Context& dev_ctx,
       groups,
       data_format,
       true,
-      is_BFLOAT16,
+      is_bfloat16,
       fuse_activation,
       fuse_residual_conn,
       force_fp32_output,
2 changes: 1 addition & 1 deletion paddle/phi/kernels/legacy/gpu/int_bincount.cu
@@ -96,7 +96,7 @@ void IntBincount(const Context &dev_ctx,
 
   auto bins_dtype = TransToDataType(out_dtype);
 
-  // auto x_dytpe = x.dtype();
+  // auto x_dtype = x.dtype();
   auto low_v = static_cast<T>(low);
   auto high_v = static_cast<T>(high);
   PD_CHECK(static_cast<int64_t>(low_v) == low);
@@ -62,7 +62,7 @@ void apply_moe_dispatch_bwd(const T* y_grad,
   // topk_grad_with_mask_launcher<float>(combine_weights_grad,
   // expert_id,
   // combine_weights,
-  // gate_logtis_grad,
+  // gate_logits_grad,
   // num_rows, k, num_experts, stream);
 }
 
2 changes: 1 addition & 1 deletion paddle/phi/kernels/onednn/add_n_kernel.cc
@@ -19,7 +19,7 @@
 namespace phi {
 bool AddNCheckIfOneDNNSupport(const KernelContext* dev_ctx) {
   for (size_t i = 0; i < dev_ctx->InputsSize(); i++) {
-    if (!DenseTensor::classof(dev_ctx->MutableIutputAt(i))) {
+    if (!DenseTensor::classof(dev_ctx->MutableInputAt(i))) {
       return false;
     }
   }
8 changes: 4 additions & 4 deletions paddle/phi/kernels/onednn/conv_function.h
@@ -84,7 +84,7 @@ void ComputeFP32(const OneDNNContext& dev_ctx,
       int groups,
       const std::string& data_format,
       bool is_test,
-      bool is_BFLOAT16,
+      bool is_bfloat16,
       const std::string& fuse_activation,
       bool fuse_residual_conn,
       bool force_fp32_output,
@@ -108,7 +108,7 @@ void ComputeFP32(const OneDNNContext& dev_ctx,
       groups,
       data_format,
       is_test,
-      is_BFLOAT16,
+      is_bfloat16,
       fuse_activation,
       fuse_residual_conn,
       force_fp32_output,
@@ -157,7 +157,7 @@ void ComputeINT8(const OneDNNContext& dev_ctx,
       int groups,
       const std::string& data_format,
       bool is_test,
-      bool is_BFLOAT16,
+      bool is_bfloat16,
       const std::string& fuse_activation,
       bool fuse_residual_conn,
       bool force_fp32_output,
@@ -196,7 +196,7 @@ void ComputeINT8(const OneDNNContext& dev_ctx,
       groups,
       data_format,
       is_test,
-      is_BFLOAT16,
+      is_bfloat16,
       fuse_activation,
       fuse_residual_conn,
       force_fp32_output,
4 changes: 2 additions & 2 deletions paddle/phi/kernels/onednn/conv_handler.h
@@ -54,7 +54,7 @@ class ConvOneDNNHandlerT
       int groups,
       const std::string& data_format UNUSED,
       bool is_test,
-      bool is_BFLOAT16,
+      bool is_bfloat16,
       const std::string& fuse_activation,
       bool fuse_residual_conn,
       bool force_fp32_output,
@@ -183,7 +183,7 @@ class ConvOneDNNHandlerT
     */
     auto chosen_memory_format = funcs::OneDNNMemoryFormat::any;
     auto data_type = dnnl::memory::data_type::f32;
-    if (is_BFLOAT16 || std::is_same<T_out, dtype::bfloat16>::value) {
+    if (is_bfloat16 || std::is_same<T_out, dtype::bfloat16>::value) {
       data_type = dnnl::memory::data_type::bf16;
     }
 
4 changes: 2 additions & 2 deletions paddle/phi/kernels/onednn/conv_kernel.cc
@@ -36,7 +36,7 @@ void ConvKernel(const Context& dev_ctx,
   bool is_test = dev_ctx.HasDnnAttr("is_test")
       ? PADDLE_GET_CONST(bool, dev_ctx.GetDnnAttr("is_test"))
       : false;
-  bool is_BFLOAT16 =
+  bool is_bfloat16 =
       dev_ctx.HasDnnAttr("mkldnn_data_type")
       ? PADDLE_GET_CONST(std::string,
           dev_ctx.GetDnnAttr("mkldnn_data_type")) ==
@@ -47,7 +47,7 @@ void ConvKernel(const Context& dev_ctx,
       ? PADDLE_GET_CONST(std::string,
           dev_ctx.GetDnnAttr("onednn_data_type")) ==
           "bfloat16"
-      : is_BFLOAT16;
+      : is_bfloat16;
   bool force_fp32_output =
       dev_ctx.HasDnnAttr("force_fp32_output")
       ? PADDLE_GET_CONST(bool, dev_ctx.GetDnnAttr("force_fp32_output"))
12 changes: 6 additions & 6 deletions paddle/phi/kernels/onednn/conv_transpose_kernel.cc
@@ -151,7 +151,7 @@ class ConvTransposeOneDNNHandlerT
     */
     auto chosen_memory_format = funcs::OneDNNMemoryFormat::any;
     auto data_type = dnnl::memory::data_type::f32;
-    const bool is_BFLOAT16 =
+    const bool is_bfloat16 =
         dev_ctx.HasDnnAttr("mkldnn_data_type")
         ? PADDLE_GET_CONST(std::string,
             dev_ctx.GetDnnAttr("mkldnn_data_type")) ==
@@ -162,7 +162,7 @@ class ConvTransposeOneDNNHandlerT
         ? PADDLE_GET_CONST(std::string,
             dev_ctx.GetDnnAttr("onednn_data_type")) ==
             "bfloat16"
-        : is_BFLOAT16;
+        : is_bfloat16;
     if (is_onednn_BFLOAT16 || std::is_same<T_out, dtype::bfloat16>::value) {
       data_type = dnnl::memory::data_type::bf16;
     }
@@ -499,7 +499,7 @@ void Conv2dTransposeKernel(const Context& dev_ctx,
       const std::vector<int>& dilations,
       const std::string& data_format UNUSED,
       DenseTensor* out) {
-  const bool is_BFLOAT16 =
+  const bool is_bfloat16 =
       dev_ctx.HasDnnAttr("mkldnn_data_type")
       ? PADDLE_GET_CONST(std::string,
           dev_ctx.GetDnnAttr("mkldnn_data_type")) ==
@@ -510,7 +510,7 @@ void Conv2dTransposeKernel(const Context& dev_ctx,
       ? PADDLE_GET_CONST(std::string,
          dev_ctx.GetDnnAttr("onednn_data_type")) ==
          "bfloat16"
-      : is_BFLOAT16;
+      : is_bfloat16;
   const bool force_fp32_output =
       dev_ctx.HasDnnAttr("force_fp32_output")
       ? PADDLE_GET_CONST(bool, dev_ctx.GetDnnAttr("force_fp32_output"))
@@ -556,7 +556,7 @@ void Conv2dTransposeBiasKernel(const Context& dev_ctx,
      const std::vector<int>& dilations,
      const std::string& data_format UNUSED,
      DenseTensor* out) {
-  const bool is_BFLOAT16 =
+  const bool is_bfloat16 =
      dev_ctx.HasDnnAttr("mkldnn_data_type")
      ? PADDLE_GET_CONST(std::string,
         dev_ctx.GetDnnAttr("mkldnn_data_type")) ==
@@ -567,7 +567,7 @@ void Conv2dTransposeBiasKernel(const Context& dev_ctx,
      ? PADDLE_GET_CONST(std::string,
         dev_ctx.GetDnnAttr("onednn_data_type")) ==
         "bfloat16"
-     : is_BFLOAT16;
+     : is_bfloat16;
   const bool force_fp32_output =
      dev_ctx.HasDnnAttr("force_fp32_output")
      ? PADDLE_GET_CONST(bool, dev_ctx.GetDnnAttr("force_fp32_output"))
10 changes: 5 additions & 5 deletions paddle/phi/kernels/onednn/sgd_kernel.cc
@@ -21,16 +21,16 @@
 namespace phi {
 
 bool SgdCheckIfOneDNNSupport(const KernelContext* dev_ctx) {
-  if (DenseTensor::classof(dev_ctx->MutableIutputAt(0)) &&
-      DenseTensor::classof(dev_ctx->MutableIutputAt(2))) {
+  if (DenseTensor::classof(dev_ctx->MutableInputAt(0)) &&
+      DenseTensor::classof(dev_ctx->MutableInputAt(2))) {
     return true;
   }
   return false;
 }
 
 bool SgdSparseCheckIfOneDNNSupport(const KernelContext* dev_ctx) {
-  if (DenseTensor::classof(dev_ctx->MutableIutputAt(0)) &&
-      SelectedRows::classof(dev_ctx->MutableIutputAt(2))) {
+  if (DenseTensor::classof(dev_ctx->MutableInputAt(0)) &&
+      SelectedRows::classof(dev_ctx->MutableInputAt(2))) {
     return true;
   }
   return false;
@@ -49,7 +49,7 @@ void SGDDenseKernel(const Context& dev_ctx,
   const T* param_data = param.data<T>();
   const auto* grad_data = grad.data<T>();
   const auto* lr = learning_rate.data<T>();
-  // Since denese SGD is not in place operation, first copy params to output
+  // Since dense SGD is not in place operation, first copy params to output
   // tensor and then update it.
   std::memcpy(out_data, param_data, param.memory_size());
   funcs::OneDNNAXPYHandler<T>(param_out->numel(), -lr[0], dev_ctx.GetEngine())(
4 changes: 2 additions & 2 deletions test/ir/inference/auto_scan_test.py
@@ -38,7 +38,7 @@
 from paddle.base.core import PassVersionChecker
 from paddle.static.log_helper import get_logger
 
-# windows and xpu not support tensort
+# windows and xpu not support tensorrt
 if os.name != 'nt' and (not os.getenv('WITH_XPU')):
     try:
         from paddle.tensorrt.export import (
@@ -171,7 +171,7 @@ def transform_to_trt_program(self, pir_program, trt_config):
         trt_config.precision_mode = PrecisionMode.FP16
 
         paddle.framework.set_flags({"FLAGS_trt_min_group_size": 1})
-        # translalte pir program to trt program
+        # translate pir program to trt program
         scope = paddle.static.global_scope()
         program_with_trt = convert_to_trt(pir_program, trt_config, scope)
 