From ec0be2823c2f4b3a0cb4889c798fecb353dfcf5f Mon Sep 17 00:00:00 2001
From: Yang Chen
Date: Tue, 24 Sep 2019 23:26:44 -0700
Subject: [PATCH 1/2] Changed ConvBase into a class member variable

Currently, all Conv-family classes inherit from both ConvBase and OpKernel.
Since what ConvBase provides is all about processing convolution attributes,
it's more natural to make it a class member variable.

This change renamed ConvBase to ConvAttributes and moved it into its own
file, conv_attributes.h. Instead of inheriting from ConvBase, each
Conv-related class now has a class member variable of type ConvAttributes.
Hence, we removed unnecessary multiple inheritance and increased
composability. More importantly, the change made it possible for other
providers, such as Nuphar, to re-use the functionality provided by the
ConvAttributes class.

Note that we also made similar changes to ConvTransposeBase.
---
 onnxruntime/contrib_ops/cpu/nchwc_ops.cc | 14 +-
 onnxruntime/contrib_ops/cpu/nchwc_ops.h | 8 +-
 onnxruntime/core/providers/cpu/nn/conv.cc | 64 ++---
 onnxruntime/core/providers/cpu/nn/conv.h | 16 +-
 .../cpu/nn/{conv_base.h => conv_attributes.h} | 98 ++++---
 .../core/providers/cpu/nn/conv_integer.cc | 28 +-
 .../core/providers/cpu/nn/conv_integer.h | 9 +-
 .../core/providers/cpu/nn/conv_transpose.cc | 217 +---------------
 .../core/providers/cpu/nn/conv_transpose.h | 45 +---
 .../cpu/nn/conv_transpose_attributes.h | 242 ++++++++++++++++++
 .../core/providers/cpu/nn/qlinearconv.cc | 30 +--
 .../core/providers/cpu/nn/qlinearconv.h | 11 +-
 onnxruntime/core/providers/cuda/nn/conv.cc | 18 +-
 onnxruntime/core/providers/cuda/nn/conv.h | 12 +-
 .../core/providers/cuda/nn/conv_transpose.cc | 8 +-
 .../core/providers/cuda/nn/conv_transpose.h | 8 +-
 onnxruntime/core/providers/mkldnn/nn/conv.cc | 15 +-
 17 files changed, 441 insertions(+), 402 deletions(-)
 rename onnxruntime/core/providers/cpu/nn/{conv_base.h => conv_attributes.h} (70%)
 create mode 100644 onnxruntime/core/providers/cpu/nn/conv_transpose_attributes.h

diff --git a/onnxruntime/contrib_ops/cpu/nchwc_ops.cc b/onnxruntime/contrib_ops/cpu/nchwc_ops.cc
index 0ba035f3470a2..5e16eda1481b4 100644
--- a/onnxruntime/contrib_ops/cpu/nchwc_ops.cc
+++ b/onnxruntime/contrib_ops/cpu/nchwc_ops.cc
@@ -98,7 +98,7 @@ Status NchwcConv::Compute(OpKernelContext* context) const {
 const auto* B = context->Input(2);
 const auto* Sum = context->Input(3);
- ORT_RETURN_IF_ERROR(ConvBase::ValidateInputShape(X, W));
+ ORT_RETURN_IF_ERROR(conv_attrs_.ValidateInputShape(X, W));
 const auto& X_shape = X->Shape();
 const auto& W_shape = W->Shape();
@@ -108,20 +108,20 @@ Status NchwcConv::Compute(OpKernelContext* context) const {
 ORT_ENFORCE((static_cast(X_shape[1]) < nchwc_block_size) || ((X_shape[1] % nchwc_block_size) == 0));
 std::vector kernel_shape;
- ORT_RETURN_IF_ERROR(ConvBase::ComputeKernelShape(W_shape, kernel_shape));
+ ORT_RETURN_IF_ERROR(conv_attrs_.ComputeKernelShape(W_shape, kernel_shape));
 if (kernel_shape.size() != 2) {
 return Status(common::ONNXRUNTIME, common::INVALID_ARGUMENT, "Unsupported convolution size.");
 }
- std::vector pads(ConvBase::pads_);
+ std::vector pads(conv_attrs_.pads);
 if (pads.empty()) {
 pads.resize(kernel_shape.size() * 2, 0);
 }
- std::vector dilations(ConvBase::dilations_);
+ std::vector dilations(conv_attrs_.dilations);
 if (dilations.empty()) {
 dilations.resize(kernel_shape.size(), 1);
 }
- std::vector strides(ConvBase::strides_);
+ std::vector strides(conv_attrs_.strides);
 if (strides.empty()) {
strides.resize(kernel_shape.size(), 1); } @@ -129,7 +129,7 @@ Status NchwcConv::Compute(OpKernelContext* context) const { std::vector Y_dims; Y_dims.insert(Y_dims.begin(), {X_shape[0], W_shape[0]}); TensorShape input_shape = X->Shape().Slice(2); - ORT_RETURN_IF_ERROR(ConvBase::InferOutputShape(input_shape, kernel_shape, strides, dilations, &pads, &Y_dims)); + ORT_RETURN_IF_ERROR(conv_attrs_.InferOutputShape(input_shape, kernel_shape, strides, dilations, &pads, &Y_dims)); auto* Y = context->Output(0, Y_dims); auto* y_data = Y->template MutableData(); @@ -151,7 +151,7 @@ Status NchwcConv::Compute(OpKernelContext* context) const { pads.data(), strides.data(), Y_dims.data(), - static_cast(ConvBase::group_), + static_cast(conv_attrs_.group), X->template Data(), W->template Data(), B != nullptr ? B->template Data() : nullptr, diff --git a/onnxruntime/contrib_ops/cpu/nchwc_ops.h b/onnxruntime/contrib_ops/cpu/nchwc_ops.h index 5fa3a78f09bb9..5a6c606231362 100644 --- a/onnxruntime/contrib_ops/cpu/nchwc_ops.h +++ b/onnxruntime/contrib_ops/cpu/nchwc_ops.h @@ -5,7 +5,7 @@ #include "core/common/common.h" #include "core/framework/op_kernel.h" -#include "core/providers/cpu/nn/conv_base.h" +#include "core/providers/cpu/nn/conv_attributes.h" #include "core/providers/cpu/nn/pool.h" #include "contrib_ops/cpu/fused_activation.h" @@ -35,15 +35,17 @@ class ReorderOutput : public OpKernel { int64_t channels_; }; -class NchwcConv : public OpKernel, public ConvBase { +class NchwcConv : public OpKernel { public: - NchwcConv(const OpKernelInfo& info) : OpKernel(info), ConvBase(info) { + NchwcConv(const OpKernelInfo& info) : OpKernel(info), conv_attrs_(info) { ORT_ENFORCE(GetFusedActivationAttr(info, activation_).IsOK()); } Status Compute(OpKernelContext* context) const override; private: + ConvAttributes conv_attrs_; + MLAS_ACTIVATION activation_; }; diff --git a/onnxruntime/core/providers/cpu/nn/conv.cc b/onnxruntime/core/providers/cpu/nn/conv.cc index c0091936704d8..87ca17fd0c0cc 100644 --- a/onnxruntime/core/providers/cpu/nn/conv.cc +++ b/onnxruntime/core/providers/cpu/nn/conv.cc @@ -34,21 +34,21 @@ Status Conv::Compute(OpKernelContext* context) const { const int64_t N = X->Shape()[0]; const int64_t C = X->Shape()[1]; const int64_t M = W->Shape()[0]; - ORT_RETURN_IF_ERROR(ValidateInputShape(X, W)); + ORT_RETURN_IF_ERROR(conv_attrs_.ValidateInputShape(X, W)); std::vector kernel_shape; - ORT_RETURN_IF_ERROR(ComputeKernelShape(W->Shape(), kernel_shape)); + ORT_RETURN_IF_ERROR(conv_attrs_.ComputeKernelShape(W->Shape(), kernel_shape)); bool Is2DKernel = kernel_shape.size() == 2; - std::vector pads(pads_); + std::vector pads(conv_attrs_.pads); if (pads.empty()) { pads.resize(kernel_shape.size() * 2, 0); } - std::vector dilations(dilations_); + std::vector dilations(conv_attrs_.dilations); if (dilations.empty()) { dilations.resize(kernel_shape.size(), 1); } - std::vector strides(strides_); + std::vector strides(conv_attrs_.strides); if (strides.empty()) { strides.resize(kernel_shape.size(), 1); } @@ -56,17 +56,17 @@ Status Conv::Compute(OpKernelContext* context) const { std::vector Y_dims; Y_dims.insert(Y_dims.begin(), {N, M}); TensorShape input_shape = X->Shape().Slice(2); - ORT_RETURN_IF_ERROR(InferOutputShape(input_shape, kernel_shape, strides, dilations, &pads, &Y_dims)); + ORT_RETURN_IF_ERROR(conv_attrs_.InferOutputShape(input_shape, kernel_shape, strides, dilations, &pads, &Y_dims)); Tensor* Y = context->Output(0, TensorShape(Y_dims)); TensorShape output_shape = Y->Shape().Slice(2); const int64_t input_image_size 
= input_shape.Size(); const int64_t output_image_size = output_shape.Size(); const int64_t kernel_size = TensorShape(kernel_shape).Size(); - const int64_t X_offset = C / group_ * input_image_size; - const int64_t Y_offset = Y->Shape().Size() / Y->Shape()[0] / group_; - const int64_t W_offset = W->Shape().Size() / group_; - const int64_t kernel_dim = C / group_ * kernel_size; + const int64_t X_offset = C / conv_attrs_.group * input_image_size; + const int64_t Y_offset = Y->Shape().Size() / Y->Shape()[0] / conv_attrs_.group; + const int64_t W_offset = W->Shape().Size() / conv_attrs_.group; + const int64_t kernel_dim = C / conv_attrs_.group * kernel_size; const int64_t col_buffer_size = kernel_dim * output_image_size; AllocatorPtr alloc; @@ -85,11 +85,11 @@ Status Conv::Compute(OpKernelContext* context) const { output_shape.GetDims().end()); for (int image_id = 0; image_id < N; ++image_id) { - for (int group_id = 0; group_id < group_; ++group_id) { + for (int group_id = 0; group_id < conv_attrs_.group; ++group_id) { if (Is2DKernel) { math::Im2col( Xdata + group_id * X_offset, - C / group_, + C / conv_attrs_.group, input_shape[0], input_shape[1], kernel_shape[0], @@ -122,7 +122,7 @@ Status Conv::Compute(OpKernelContext* context) const { math::Gemm( CblasNoTrans, CblasNoTrans, - M / group_, + M / conv_attrs_.group, output_image_size, kernel_dim, 1, @@ -139,8 +139,8 @@ Status Conv::Compute(OpKernelContext* context) const { Ymatrix.rowwise() += Bvec.transpose(); } - Xdata += X_offset * group_; - Ydata += Y_offset * group_; + Xdata += X_offset * conv_attrs_.group; + Ydata += Y_offset * conv_attrs_.group; } return Status::OK(); @@ -157,20 +157,20 @@ Status Conv::Compute(OpKernelContext* context) const { const int64_t N = X->Shape()[0]; const int64_t C = X->Shape()[1]; const int64_t M = W->Shape()[0]; - ORT_RETURN_IF_ERROR(ValidateInputShape(X, W)); + ORT_RETURN_IF_ERROR(conv_attrs_.ValidateInputShape(X, W)); std::vector kernel_shape; - ORT_RETURN_IF_ERROR(ComputeKernelShape(W->Shape(), kernel_shape)); + ORT_RETURN_IF_ERROR(conv_attrs_.ComputeKernelShape(W->Shape(), kernel_shape)); - std::vector pads(pads_); + std::vector pads(conv_attrs_.pads); if (pads.empty()) { pads.resize(kernel_shape.size() * 2, 0); } - std::vector dilations(dilations_); + std::vector dilations(conv_attrs_.dilations); if (dilations.empty()) { dilations.resize(kernel_shape.size(), 1); } - std::vector strides(strides_); + std::vector strides(conv_attrs_.strides); if (strides.empty()) { strides.resize(kernel_shape.size(), 1); } @@ -178,7 +178,7 @@ Status Conv::Compute(OpKernelContext* context) const { std::vector Y_dims; Y_dims.insert(Y_dims.begin(), {N, M}); TensorShape input_shape = X->Shape().Slice(2); - ORT_RETURN_IF_ERROR(InferOutputShape(input_shape, kernel_shape, strides, dilations, &pads, &Y_dims)); + ORT_RETURN_IF_ERROR(conv_attrs_.InferOutputShape(input_shape, kernel_shape, strides, dilations, &pads, &Y_dims)); Tensor* Y = context->Output(0, TensorShape(Y_dims)); TensorShape output_shape = Y->Shape().Slice(2); @@ -197,15 +197,15 @@ Status Conv::Compute(OpKernelContext* context) const { MlasConvPrepare(&Parameters, kernel_rank, static_cast(N), - static_cast(group_), - static_cast(C / group_), + static_cast(conv_attrs_.group), + static_cast(C / conv_attrs_.group), input_shape.GetDims().data(), kernel_shape.data(), dilations.data(), pads.data(), strides.data(), output_shape.GetDims().data(), - static_cast(M / group_), + static_cast(M / conv_attrs_.group), &activation_, &WorkingBufferSize, tp); @@ -224,10 +224,10 @@ Status 
Conv::Compute(OpKernelContext* context) const { const int64_t input_image_size = input_shape.Size(); const int64_t output_image_size = output_shape.Size(); const int64_t kernel_size = TensorShape(kernel_shape).Size(); - const int64_t X_offset = C / group_ * input_image_size; - const int64_t Y_offset = Y->Shape().Size() / Y->Shape()[0] / group_; - const int64_t W_offset = W->Shape().Size() / group_; - const int64_t kernel_dim = C / group_ * kernel_size; + const int64_t X_offset = C / conv_attrs_.group * input_image_size; + const int64_t Y_offset = Y->Shape().Size() / Y->Shape()[0] / conv_attrs_.group; + const int64_t W_offset = W->Shape().Size() / conv_attrs_.group; + const int64_t kernel_dim = C / conv_attrs_.group * kernel_size; const int64_t col_buffer_size = kernel_dim * output_image_size; auto col_data = alloc->Alloc(sizeof(float) * col_buffer_size); @@ -240,7 +240,7 @@ Status Conv::Compute(OpKernelContext* context) const { output_shape.GetDims().end()); for (int image_id = 0; image_id < N; ++image_id) { - for (int group_id = 0; group_id < group_; ++group_id) { + for (int group_id = 0; group_id < conv_attrs_.group; ++group_id) { math::Im2colNd()( Xdata + group_id * X_offset, image_shape.GetDims().data(), @@ -257,7 +257,7 @@ Status Conv::Compute(OpKernelContext* context) const { math::Gemm( CblasNoTrans, CblasNoTrans, - M / group_, + M / conv_attrs_.group, output_image_size, kernel_dim, 1, @@ -270,8 +270,8 @@ Status Conv::Compute(OpKernelContext* context) const { MlasActivation(&activation_, Ydata, Bdata, M, output_image_size, output_image_size); - Xdata += X_offset * group_; - Ydata += Y_offset * group_; + Xdata += X_offset * conv_attrs_.group; + Ydata += Y_offset * conv_attrs_.group; } } diff --git a/onnxruntime/core/providers/cpu/nn/conv.h b/onnxruntime/core/providers/cpu/nn/conv.h index 3e366e1b49775..6d8d4ad92e20d 100644 --- a/onnxruntime/core/providers/cpu/nn/conv.h +++ b/onnxruntime/core/providers/cpu/nn/conv.h @@ -3,24 +3,28 @@ #pragma once -#include "core/providers/cpu/nn/conv_base.h" +#include "core/framework/op_kernel.h" +#include "core/providers/cpu/nn/conv_attributes.h" #include "core/mlas/inc/mlas.h" namespace onnxruntime { template -class Conv : public OpKernel, public ConvBase { +class Conv : public OpKernel { public: - Conv(const OpKernelInfo& info) : OpKernel(info), ConvBase(info) { + Conv(const OpKernelInfo& info) : OpKernel(info), conv_attrs_(info) { } Status Compute(OpKernelContext* context) const override; + + private: + ConvAttributes conv_attrs_; }; template <> -class Conv : public OpKernel, public ConvBase { +class Conv : public OpKernel { public: - Conv(const OpKernelInfo& info) : OpKernel(info), ConvBase(info) { + Conv(const OpKernelInfo& info) : OpKernel(info), conv_attrs_(info) { activation_.ActivationKind = MlasIdentityActivation; } @@ -28,6 +32,8 @@ class Conv : public OpKernel, public ConvBase { protected: MLAS_ACTIVATION activation_; + + ConvAttributes conv_attrs_; }; } // namespace onnxruntime diff --git a/onnxruntime/core/providers/cpu/nn/conv_base.h b/onnxruntime/core/providers/cpu/nn/conv_attributes.h similarity index 70% rename from onnxruntime/core/providers/cpu/nn/conv_base.h rename to onnxruntime/core/providers/cpu/nn/conv_attributes.h index 9d6a4316d5730..7286bcaecf59b 100644 --- a/onnxruntime/core/providers/cpu/nn/conv_base.h +++ b/onnxruntime/core/providers/cpu/nn/conv_attributes.h @@ -5,7 +5,7 @@ #include "core/common/common.h" #include "core/common/exceptions.h" -#include "core/framework/op_kernel.h" +#include 
"core/framework/op_node_proto_helper.h" #include "core/providers/cpu/nn/autopad_type.h" #include "core/util/math.h" @@ -58,54 +58,52 @@ Status ComputePadAndOutputShape( return Status::OK(); } -// base class used by Conv and ConvTranspose -class ConvBase { - protected: - explicit ConvBase(const OpKernelInfo& info) { - std::string auto_pad; - auto status = info.GetAttr("auto_pad", &auto_pad); - auto_pad_ = status.IsOK() ? StringToAutoPadType(auto_pad) : AutoPadType::NOTSET; +// A helper struct holding attributes for Conv-family ops +struct ConvAttributes { + explicit ConvAttributes(const OpNodeProtoHelper& info) { + std::string auto_pad_str; + auto status = info.GetAttr("auto_pad", &auto_pad_str); + auto_pad = status.IsOK() ? StringToAutoPadType(auto_pad_str) : AutoPadType::NOTSET; - kernel_shape_specified_ = info.GetAttrs("kernel_shape", kernel_shape_).IsOK(); + kernel_shape_specified = info.GetAttrs("kernel_shape", kernel_shape_).IsOK(); - status = info.GetAttrs("strides", strides_); + status = info.GetAttrs("strides", strides); if (!status.IsOK()) { - strides_.resize(kernel_shape_.size(), 1); + strides.resize(kernel_shape_.size(), 1); } - status = info.GetAttrs("pads", pads_); + status = info.GetAttrs("pads", pads); if (!status.IsOK()) { - pads_.resize(kernel_shape_.size() * 2, 0); + pads.resize(kernel_shape_.size() * 2, 0); } - status = info.GetAttrs("dilations", dilations_); + status = info.GetAttrs("dilations", dilations); if (!status.IsOK()) { - dilations_.resize(kernel_shape_.size(), 1); + dilations.resize(kernel_shape_.size(), 1); } - status = info.GetAttr("group", &group_); + status = info.GetAttr("group", &group); if (!status.IsOK()) { - group_ = 1; + group = 1; } #if false // TODO: Re-enable when attributes values are guaranteed to be filled. 
- std::string auto_pad; - ORT_ENFORCE(info.GetAttr("auto_pad", &auto_pad).IsOK()); - auto_pad_ = StringToAutoPadType(auto_pad); - ORT_ENFORCE(info.GetAttr("group", &group_).IsOK()); + std::string auto_pad_str; + ORT_ENFORCE(info.GetAttr("auto_pad", &auto_pad_str).IsOK()); + auto_pad = StringToAutoPadType(auto_pad_str); + ORT_ENFORCE(info.GetAttr("group", &group).IsOK()); ORT_ENFORCE(info.GetAttrs("kernel_shape", kernel_shape_).IsOK()); - ORT_ENFORCE(info.GetAttrs("strides", strides_).IsOK()); - ORT_ENFORCE(info.GetAttrs("pads", pads_).IsOK()); - ORT_ENFORCE(info.GetAttrs("dilations", dilations_).IsOK()); + ORT_ENFORCE(info.GetAttrs("strides", strides).IsOK()); + ORT_ENFORCE(info.GetAttrs("pads", pads).IsOK()); + ORT_ENFORCE(info.GetAttrs("dilations", dilations).IsOK()); #endif } - ~ConvBase() = default; + ~ConvAttributes() = default; - protected: Status ComputeKernelShape(const TensorShape& weight_shape, std::vector& kernel_shape) const { - if (kernel_shape_specified_) { + if (kernel_shape_specified) { kernel_shape = kernel_shape_; if (kernel_shape.size() + 2 != weight_shape.NumDimensions()) { return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "kernel_shape num_dims is not compatible with W num_dims.", @@ -137,17 +135,17 @@ class ConvBase { " W: ", W->Shape().ToString().c_str()); } - if (C != W->Shape()[1] * group_) { + if (C != W->Shape()[1] * group) { return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Input channels C is not equal to kernel channels * group.", " C: ", C, " kernel channels: ", W->Shape()[1], - " group: ", group_); + " group: ", group); } - if (M % group_ != 0) { + if (M % group != 0) { return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Output channels M is not divisible by group.", " M: ", M, - " group: ", group_); + " group: ", group); } return Status::OK(); } @@ -155,26 +153,26 @@ class ConvBase { template Status InferOutputShape(const TensorShape& input_shape, const std::vector& kernel_shape, - const std::vector& strides, - const std::vector& dilations, - std::vector* pads, + const std::vector& strides_p, + const std::vector& dilations_p, + std::vector* pads_p, std::vector* output_shape) const { size_t rank = input_shape.NumDimensions(); for (size_t dim = 0; dim < rank; ++dim) { - if (dim >= strides.size() || dim >= kernel_shape.size() || - dim >= dilations.size() || dim >= pads->size() || - rank + dim >= pads->size()) { + if (dim >= strides_p.size() || dim >= kernel_shape.size() || + dim >= dilations_p.size() || dim >= pads_p->size() || + rank + dim >= pads_p->size()) { return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Out of bound access to array"); } int64_t dim_size = 0; ORT_RETURN_IF_ERROR(ComputePadAndOutputShape( input_shape[dim], - strides[dim], + strides_p[dim], kernel_shape[dim], - dilations[dim], - auto_pad_, - &pads->at(dim), - &pads->at(input_shape.NumDimensions() + dim), + dilations_p[dim], + auto_pad, + &pads_p->at(dim), + &pads_p->at(input_shape.NumDimensions() + dim), &dim_size)); if (dim_size <= 0) { return Status(common::ONNXRUNTIME, common::INVALID_ARGUMENT, "Invalid input shape: " + input_shape.ToString()); @@ -184,14 +182,14 @@ class ConvBase { return Status::OK(); } - AutoPadType auto_pad_; - int64_t group_; - bool kernel_shape_specified_; - std::vector strides_; - std::vector pads_; - std::vector dilations_; - std::string activation_; - float alpha_; + AutoPadType auto_pad; + int64_t group; + bool kernel_shape_specified; + std::vector strides; + std::vector pads; + std::vector dilations; + std::string activation; + float alpha; private: std::vector kernel_shape_; // must use 
ComputeKernelShape(...), instead of kernel_shape_ diff --git a/onnxruntime/core/providers/cpu/nn/conv_integer.cc b/onnxruntime/core/providers/cpu/nn/conv_integer.cc index 534cb75a6e840..39c01b98cfbc8 100644 --- a/onnxruntime/core/providers/cpu/nn/conv_integer.cc +++ b/onnxruntime/core/providers/cpu/nn/conv_integer.cc @@ -41,20 +41,20 @@ Status ConvInteger::Compute(OpKernelContext* context) const { const int64_t N = X->Shape()[0]; const int64_t C = X->Shape()[1]; const int64_t M = W->Shape()[0]; - ORT_RETURN_IF_ERROR(ValidateInputShape(X, W)); + ORT_RETURN_IF_ERROR(conv_attrs_.ValidateInputShape(X, W)); std::vector kernel_shape; - ORT_RETURN_IF_ERROR(ComputeKernelShape(W->Shape(), kernel_shape)); + ORT_RETURN_IF_ERROR(conv_attrs_.ComputeKernelShape(W->Shape(), kernel_shape)); - std::vector pads(pads_); + std::vector pads(conv_attrs_.pads); if (pads.empty()) { pads.resize(kernel_shape.size() * 2, 0); } - std::vector dilations(dilations_); + std::vector dilations(conv_attrs_.dilations); if (dilations.empty()) { dilations.resize(kernel_shape.size(), 1); } - std::vector strides(strides_); + std::vector strides(conv_attrs_.strides); if (strides.empty()) { strides.resize(kernel_shape.size(), 1); } @@ -62,7 +62,7 @@ Status ConvInteger::Compute(OpKernelContext* context) const { std::vector Y_dims; Y_dims.insert(Y_dims.begin(), {N, M}); TensorShape input_shape = X->Shape().Slice(2); - ORT_RETURN_IF_ERROR(InferOutputShape(input_shape, kernel_shape, strides, dilations, &pads, &Y_dims)); + ORT_RETURN_IF_ERROR(conv_attrs_.InferOutputShape(input_shape, kernel_shape, strides, dilations, &pads, &Y_dims)); Tensor* Y = context->Output(0, TensorShape(Y_dims)); TensorShape output_shape = Y->Shape().Slice(2); @@ -75,10 +75,10 @@ Status ConvInteger::Compute(OpKernelContext* context) const { const int64_t input_image_size = input_shape.Size(); const int64_t output_image_size = output_shape.Size(); const int64_t kernel_size = TensorShape(kernel_shape).Size(); - const int64_t X_offset = C / group_ * input_image_size; - const int64_t Y_offset = Y->Shape().Size() / Y->Shape()[0] / group_; - const int64_t W_offset = W->Shape().Size() / group_; - const int64_t kernel_dim = C / group_ * kernel_size; + const int64_t X_offset = C / conv_attrs_.group * input_image_size; + const int64_t Y_offset = Y->Shape().Size() / Y->Shape()[0] / conv_attrs_.group; + const int64_t W_offset = W->Shape().Size() / conv_attrs_.group; + const int64_t kernel_dim = C / conv_attrs_.group * kernel_size; const int64_t col_buffer_size = kernel_dim * output_image_size; auto col_data = alloc->Alloc(sizeof(uint8_t) * col_buffer_size); @@ -91,7 +91,7 @@ Status ConvInteger::Compute(OpKernelContext* context) const { output_shape.GetDims().end()); for (int image_id = 0; image_id < N; ++image_id) { - for (int group_id = 0; group_id < group_; ++group_id) { + for (int group_id = 0; group_id < conv_attrs_.group; ++group_id) { math::Im2colNd()( Xdata + group_id * X_offset, image_shape.GetDims().data(), @@ -108,7 +108,7 @@ Status ConvInteger::Compute(OpKernelContext* context) const { false, input_offset); - QGemmu8u8_s32(static_cast(M / group_), + QGemmu8u8_s32(static_cast(M / conv_attrs_.group), static_cast(output_image_size), static_cast(kernel_dim), W->template Data() + group_id * W_offset, @@ -122,8 +122,8 @@ Status ConvInteger::Compute(OpKernelContext* context) const { nullptr); } - Xdata += X_offset * group_; - Ydata += Y_offset * group_; + Xdata += X_offset * conv_attrs_.group; + Ydata += Y_offset * conv_attrs_.group; } return Status::OK(); diff --git 
a/onnxruntime/core/providers/cpu/nn/conv_integer.h b/onnxruntime/core/providers/cpu/nn/conv_integer.h index 267e567e20e55..dc603bcb31768 100644 --- a/onnxruntime/core/providers/cpu/nn/conv_integer.h +++ b/onnxruntime/core/providers/cpu/nn/conv_integer.h @@ -3,14 +3,17 @@ #pragma once -#include "core/providers/cpu/nn/conv_base.h" +#include "core/framework/op_kernel.h" +#include "core/providers/cpu/nn/conv_attributes.h" namespace onnxruntime { -class ConvInteger : public OpKernel, public ConvBase { +class ConvInteger : public OpKernel { public: - explicit ConvInteger(const OpKernelInfo& info) : OpKernel(info), ConvBase(info) { + explicit ConvInteger(const OpKernelInfo& info) : OpKernel(info), conv_attrs_(info) { } Status Compute(OpKernelContext* context) const override; + + ConvAttributes conv_attrs_; }; } // namespace onnxruntime diff --git a/onnxruntime/core/providers/cpu/nn/conv_transpose.cc b/onnxruntime/core/providers/cpu/nn/conv_transpose.cc index 9fd9cd1502147..4e3c7b33ea684 100644 --- a/onnxruntime/core/providers/cpu/nn/conv_transpose.cc +++ b/onnxruntime/core/providers/cpu/nn/conv_transpose.cc @@ -29,200 +29,6 @@ ONNX_CPU_OPERATOR_KERNEL( KernelDefBuilder().TypeConstraint("T", DataTypeImpl::GetTensorType()), ConvTranspose); -inline void ComputeTransposePadAndOutputShape( - const int64_t in_size, - const int64_t stride, - const int64_t kernel, - const int64_t dilation, - const int64_t adj, - AutoPadType pad_type, - int64_t* pad_head, - int64_t* pad_tail, - int64_t* out_size) { - if (*out_size != -1) { - ORT_ENFORCE(*out_size >= 0); - // total padding size - int64_t paddings = std::max(0, (in_size - 1) * stride + kernel + dilation - 1 + adj - *out_size); - if (pad_type == AutoPadType::SAME_UPPER) { // pad more on head when paddings are odd. - *pad_head = paddings - paddings / 2; - *pad_tail = paddings / 2; - } else { - // for pad_type is NOTSET, SAME_LOWER or VALID - // set pad_head as paddings/2, pad_tail as paddings-paddings/2. - // That said, we pad more on tail when paddings are odd. - *pad_head = paddings / 2; - *pad_tail = paddings - paddings / 2; - } - return; - } - if (pad_type != AutoPadType::NOTSET) { - switch (pad_type) { - // We handle cases of AutoPadType::VALID and AutoPadType::SAME_UPPER/LOWER, - // the same way - case AutoPadType::VALID: - case AutoPadType::SAME_UPPER: - case AutoPadType::SAME_LOWER: - *pad_head = 0; - *pad_tail = 0; - *out_size = (in_size - 1) * stride + kernel + dilation - 1 + adj; - break; - default: - throw NotImplementedException("pad type not supported"); - } - } else { - *out_size = - (in_size - 1) * stride + kernel + dilation - 1 + adj - *pad_head - *pad_tail; - } -} - -Status ConvTransposeBase::PrepareForCompute(OpKernelContext* context, bool has_bias, ConvTransposeBase::Prepare& p, bool dynamic_padding) const { - const Tensor* X = context->Input(0); - const Tensor* F = context->Input(1); - const Tensor* Pads = dynamic_padding ? context->Input(2) : nullptr; - const Tensor* B = has_bias ? (dynamic_padding ? context->Input(3) : context->Input(2)) : nullptr; - const TensorShape& input_shape = X->Shape(); - - // input validations - if (group_ <= 0) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "group count is <= 0", - " group: ", group_); - } - - if (input_shape.NumDimensions() != 4) { - // This condition is not true for two tests in ONNX tests series: - // test_convtranspose_1d_cpu, test_convtranspose_3d_cpu. - // TODO: the error message should tell which operator raises it. 
- return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input X must be 4-dimensional.", - " X: ", X->Shape().ToString().c_str()); - } - - if (input_shape.NumDimensions() != F->Shape().NumDimensions()) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "X num_dims does not match W num_dims.", - " X: ", X->Shape().ToString().c_str(), - " W: ", F->Shape().ToString().c_str()); - } - - const int64_t num_input_channels = input_shape[1]; - - if (F->Shape()[0] != num_input_channels) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "filter number not equal to input channel number.", - " filter_number: ", F->Shape()[0], - " num_input_channels: ", num_input_channels); - } - - const int64_t N = input_shape[0]; - const int64_t H = input_shape[2]; - const int64_t W = input_shape[3]; - const int64_t num_output_channels_multiplier = F->Shape()[1]; - const int64_t num_output_channels = num_output_channels_multiplier * group_; - - // it looks like num_output_channels is really k*group_ similar to how in the conv case - // num_input_channels is k*group_. hence removing the check for num_output_channels here. - - if (num_input_channels % group_ != 0) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input channels is not divisible by group.", - " num_input_channels: ", num_input_channels, - " group: ", group_); - } - - std::vector kernel_shape; - ORT_RETURN_IF_ERROR(ComputeKernelShape(F->Shape(), kernel_shape)); - - std::vector output_padding(output_padding_); - if (output_padding.empty()) { - output_padding.resize(kernel_shape.size(), 0); - } - std::vector pads; - pads.reserve(2 * (input_shape.NumDimensions() - 2)); - if (dynamic_padding) { - for (int64_t i = 0; i < Pads->Shape().SizeFromDimension(0); ++i) { - pads.push_back(Pads->Data()[i]); - } - } else { - pads.assign(pads_.begin(), pads_.end()); - } - if (pads.empty()) { - pads.resize(kernel_shape.size() * 2, 0); - } - std::vector dilations(dilations_); - if (dilations.empty()) { - dilations.resize(kernel_shape.size(), 1); - } - std::vector strides(strides_); - if (strides.empty()) { - strides.resize(kernel_shape.size(), 1); - } - - std::vector Y_dims; - - ComputePadsAndOutputShape(input_shape, num_output_channels, kernel_shape, strides, dilations, output_padding, &pads, &Y_dims); - TensorShape Yshape(Y_dims); - Tensor* Y = context->Output(0, Yshape); - - p.X = X; - p.F = F; - p.B = B; - p.Y = Y; - p.N = N; - p.H = H; - p.W = W; - p.num_input_channels = num_input_channels; - p.num_output_channels = num_output_channels; - p.kernel_shape = std::move(kernel_shape); - p.pads = std::move(pads); - p.strides = std::move(strides); - p.dilations = std::move(dilations); - return Status::OK(); -} - -void ConvTransposeBase::ComputePadsAndOutputShape( - const TensorShape input_shape, - const int64_t output_channel, - const std::vector& kernel_shape, - const std::vector& strides, - const std::vector& dilations, - const std::vector& output_padding, - std::vector* pads, - std::vector* output_shape) const { - const int64_t N = input_shape[0]; - const int64_t H = input_shape[2]; - const int64_t W = input_shape[3]; - int64_t output_height = -1; - int64_t output_width = -1; - size_t output_shape_size = output_shape_.size(); - - if (output_shape_size != 0) { - output_height = output_shape_[output_shape_size - 2]; - output_width = output_shape_[output_shape_size - 1]; - ORT_ENFORCE(output_height >= H, "Output height cannot be smaller than input height."); - ORT_ENFORCE(output_width >= W, "Output width cannot be smaller than input width."); - } - - 
ComputeTransposePadAndOutputShape( - H, - strides[0], - kernel_shape[0], - dilations[0], - output_padding[0], - auto_pad_, - &pads->at(0), - &pads->at(2), - &output_height); - - ComputeTransposePadAndOutputShape( - W, - strides[1], - kernel_shape[1], - dilations[1], - output_padding[1], - auto_pad_, - &pads->at(1), - &pads->at(3), - &output_width); - - output_shape->insert(output_shape->begin(), {N, output_channel, output_height, output_width}); -} - template Status ConvTranspose::Compute(OpKernelContext* context) const { return ConvTranspose::DoConvTranspose(context, false); @@ -234,15 +40,16 @@ Status ConvTranspose::DoConvTranspose(OpKernelContext* context, bool dynamic_ concurrency::ThreadPool* tp = ctx_internal->GetOperatorThreadPool(); size_t num_inputs = OpKernel::Node().InputDefs().size(); - Prepare p; + ConvTransposeAttributes::Prepare p; bool has_bias = dynamic_padding ? num_inputs == 4 : num_inputs == 3; - ORT_RETURN_IF_ERROR(PrepareForCompute(context, has_bias, p, dynamic_padding)); + ORT_RETURN_IF_ERROR(conv_transpose_attrs_.PrepareForCompute(context, has_bias, p, dynamic_padding)); const int64_t input_image_size = p.H * p.W; - const int64_t X_offset = p.num_input_channels / group_ * input_image_size; - const int64_t Y_offset = p.Y->Shape().Size() / p.Y->Shape()[0] / group_; - const int64_t W_offset = p.F->Shape().Size() / group_; - const int64_t kernel_dim = p.num_output_channels / group_ * p.kernel_shape[0] * p.kernel_shape[1]; + const int64_t X_offset = p.num_input_channels / conv_transpose_attrs_.group * input_image_size; + const int64_t Y_offset = p.Y->Shape().Size() / p.Y->Shape()[0] / conv_transpose_attrs_.group; + const int64_t W_offset = p.F->Shape().Size() / conv_transpose_attrs_.group; + const int64_t kernel_dim = + p.num_output_channels / conv_transpose_attrs_.group * p.kernel_shape[0] * p.kernel_shape[1]; const int64_t output_image_size = p.Y->Shape()[2] * p.Y->Shape()[3]; AllocatorPtr alloc; @@ -257,14 +64,14 @@ Status ConvTranspose::DoConvTranspose(OpKernelContext* context, bool dynamic_ T* Ydata = p.Y->template MutableData(); for (auto image_id = 0; image_id < p.N; ++image_id) { - for (int group_id = 0; group_id < group_; ++group_id) { + for (int group_id = 0; group_id < conv_transpose_attrs_.group; ++group_id) { // Weight term math::Gemm( CblasTrans, CblasNoTrans, kernel_dim, input_image_size, - p.num_input_channels / group_, + p.num_input_channels / conv_transpose_attrs_.group, 1, filter_data + group_id * W_offset, Xdata + group_id * X_offset, @@ -275,7 +82,7 @@ Status ConvTranspose::DoConvTranspose(OpKernelContext* context, bool dynamic_ // Col2im math::Col2im( col_buffer_data, - p.num_output_channels / group_, + p.num_output_channels / conv_transpose_attrs_.group, p.Y->Shape()[2], p.Y->Shape()[3], p.kernel_shape[0], @@ -298,8 +105,8 @@ Status ConvTranspose::DoConvTranspose(OpKernelContext* context, bool dynamic_ Ymatrix.rowwise() += Bvec.transpose(); } - Xdata += X_offset * group_; - Ydata += Y_offset * group_; + Xdata += X_offset * conv_transpose_attrs_.group; + Ydata += Y_offset * conv_transpose_attrs_.group; } return Status::OK(); diff --git a/onnxruntime/core/providers/cpu/nn/conv_transpose.h b/onnxruntime/core/providers/cpu/nn/conv_transpose.h index 19e32f151287a..af2b26bf73ea4 100644 --- a/onnxruntime/core/providers/cpu/nn/conv_transpose.h +++ b/onnxruntime/core/providers/cpu/nn/conv_transpose.h @@ -17,54 +17,23 @@ #pragma once -#include "core/providers/cpu/nn/conv_base.h" +#include "core/framework/op_kernel.h" +#include 
"core/providers/cpu/nn/conv_transpose_attributes.h" namespace onnxruntime { -class ConvTransposeBase : public ConvBase { - protected: - ConvTransposeBase(const OpKernelInfo& info) - : ConvBase(info), - output_padding_(info.GetAttrsOrDefault("output_padding")), - output_shape_(info.GetAttrsOrDefault("output_shape")) { - } - - struct Prepare { - const Tensor* X; - const Tensor* F; - const Tensor* B; - Tensor* Y; - int64_t N; - int64_t H; - int64_t W; - int64_t num_input_channels; - int64_t num_output_channels; - std::vector kernel_shape; - std::vector pads; - std::vector dilations; - std::vector strides; - }; - - Status PrepareForCompute(OpKernelContext* context, bool has_bias, Prepare& p, bool dynamic_padding = false) const; - - void ComputePadsAndOutputShape(TensorShape input_shape, int64_t output_channel, - const std::vector& kernel_shape, const std::vector& strides, - const std::vector& dilations, const std::vector& output_padding, - std::vector* pads, std::vector* output_shape) const; - - const std::vector output_padding_; - const std::vector output_shape_; -}; - template -class ConvTranspose : public OpKernel, public ConvTransposeBase { +class ConvTranspose : public OpKernel { public: - ConvTranspose(const OpKernelInfo& info) : OpKernel(info), ConvTransposeBase(info) {} + ConvTranspose(const OpKernelInfo& info) : OpKernel(info), conv_transpose_attrs_(info) {} Status Compute(OpKernelContext* context) const override; protected: Status DoConvTranspose(OpKernelContext* context, bool dynamic_padding) const; + + private: + ConvTransposeAttributes conv_transpose_attrs_; }; } // namespace onnxruntime diff --git a/onnxruntime/core/providers/cpu/nn/conv_transpose_attributes.h b/onnxruntime/core/providers/cpu/nn/conv_transpose_attributes.h new file mode 100644 index 0000000000000..b1f32032d9fab --- /dev/null +++ b/onnxruntime/core/providers/cpu/nn/conv_transpose_attributes.h @@ -0,0 +1,242 @@ +/** +* Copyright (c) 2016-present, Facebook, Inc. +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*/ +/* Modifications Copyright (c) Microsoft. */ + +#pragma once + +#include "core/providers/cpu/nn/conv_attributes.h" + +namespace onnxruntime { + +struct ConvTransposeAttributes : public ConvAttributes { + explicit ConvTransposeAttributes(const OpNodeProtoHelper& info) + : ConvAttributes(info), + output_padding(info.GetAttrsOrDefault("output_padding")), + output_shape(info.GetAttrsOrDefault("output_shape")) { + } + + struct Prepare { + const Tensor* X; + const Tensor* F; + const Tensor* B; + Tensor* Y; + int64_t N; + int64_t H; + int64_t W; + int64_t num_input_channels; + int64_t num_output_channels; + std::vector kernel_shape; + std::vector pads; + std::vector dilations; + std::vector strides; + }; + + Status PrepareForCompute(OpKernelContext* context, bool has_bias, Prepare& p, bool dynamic_padding = false) const { + const Tensor* X = context->Input(0); + const Tensor* F = context->Input(1); + const Tensor* Pads = dynamic_padding ? context->Input(2) : nullptr; + const Tensor* B = has_bias ? (dynamic_padding ? 
context->Input(3) : context->Input(2)) : nullptr; + const TensorShape& input_shape = X->Shape(); + + // input validations + if (group <= 0) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "group count is <= 0", + " group: ", group); + } + + if (input_shape.NumDimensions() != 4) { + // This condition is not true for two tests in ONNX tests series: + // test_convtranspose_1d_cpu, test_convtranspose_3d_cpu. + // TODO: the error message should tell which operator raises it. + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input X must be 4-dimensional.", + " X: ", X->Shape().ToString().c_str()); + } + + if (input_shape.NumDimensions() != F->Shape().NumDimensions()) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "X num_dims does not match W num_dims.", + " X: ", X->Shape().ToString().c_str(), + " W: ", F->Shape().ToString().c_str()); + } + + const int64_t num_input_channels = input_shape[1]; + + if (F->Shape()[0] != num_input_channels) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "filter number not equal to input channel number.", + " filter_number: ", F->Shape()[0], + " num_input_channels: ", num_input_channels); + } + + const int64_t N = input_shape[0]; + const int64_t H = input_shape[2]; + const int64_t W = input_shape[3]; + const int64_t num_output_channels_multiplier = F->Shape()[1]; + const int64_t num_output_channels = num_output_channels_multiplier * group; + + // it looks like num_output_channels is really k*group similar to how in the conv case + // num_input_channels is k*group. hence removing the check for num_output_channels here. + + if (num_input_channels % group != 0) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input channels is not divisible by group.", + " num_input_channels: ", num_input_channels, + " group: ", group); + } + + std::vector kernel_shape; + ORT_RETURN_IF_ERROR(ComputeKernelShape(F->Shape(), kernel_shape)); + + std::vector local_output_padding(output_padding); + if (local_output_padding.empty()) { + local_output_padding.resize(kernel_shape.size(), 0); + } + std::vector local_pads; + local_pads.reserve(2 * (input_shape.NumDimensions() - 2)); + if (dynamic_padding) { + for (int64_t i = 0; i < Pads->Shape().SizeFromDimension(0); ++i) { + local_pads.push_back(Pads->Data()[i]); + } + } else { + local_pads.assign(pads.begin(), pads.end()); + } + if (local_pads.empty()) { + local_pads.resize(kernel_shape.size() * 2, 0); + } + std::vector local_dilations(dilations); + if (local_dilations.empty()) { + local_dilations.resize(kernel_shape.size(), 1); + } + std::vector local_strides(strides); + if (local_strides.empty()) { + local_strides.resize(kernel_shape.size(), 1); + } + + std::vector Y_dims; + + ComputePadsAndOutputShape(input_shape, num_output_channels, kernel_shape, + local_strides, local_dilations, local_output_padding, &local_pads, &Y_dims); + TensorShape Yshape(Y_dims); + Tensor* Y = context->Output(0, Yshape); + + p.X = X; + p.F = F; + p.B = B; + p.Y = Y; + p.N = N; + p.H = H; + p.W = W; + p.num_input_channels = num_input_channels; + p.num_output_channels = num_output_channels; + p.kernel_shape = std::move(kernel_shape); + p.pads = std::move(local_pads); + p.strides = std::move(local_strides); + p.dilations = std::move(local_dilations); + return Status::OK(); + } + + void ComputePadsAndOutputShape(TensorShape input_shape, int64_t output_channel, + const std::vector& kernel_shape, const std::vector& p_strides, + const std::vector& p_dilations, const std::vector& p_output_padding, + std::vector* p_pads, 
std::vector* output_shape_p) const { + const int64_t N = input_shape[0]; + const int64_t H = input_shape[2]; + const int64_t W = input_shape[3]; + int64_t output_height = -1; + int64_t output_width = -1; + size_t output_shape_size = output_shape.size(); + + if (output_shape_size != 0) { + output_height = output_shape[output_shape_size - 2]; + output_width = output_shape[output_shape_size - 1]; + ORT_ENFORCE(output_height >= H, "Output height cannot be smaller than input height."); + ORT_ENFORCE(output_width >= W, "Output width cannot be smaller than input width."); + } + + ComputeTransposePadAndOutputShape( + H, + p_strides[0], + kernel_shape[0], + p_dilations[0], + p_output_padding[0], + auto_pad, + &p_pads->at(0), + &p_pads->at(2), + &output_height); + + ComputeTransposePadAndOutputShape( + W, + p_strides[1], + kernel_shape[1], + p_dilations[1], + p_output_padding[1], + auto_pad, + &p_pads->at(1), + &p_pads->at(3), + &output_width); + + output_shape_p->insert(output_shape_p->begin(), {N, output_channel, output_height, output_width}); + } + + const std::vector output_padding; + const std::vector output_shape; + +private: + void ComputeTransposePadAndOutputShape ( + const int64_t in_size, + const int64_t stride, + const int64_t kernel, + const int64_t dilation, + const int64_t adj, + AutoPadType pad_type, + int64_t* pad_head, + int64_t* pad_tail, + int64_t* out_size) const { + if (*out_size != -1) { + ORT_ENFORCE(*out_size >= 0); + // total padding size + int64_t paddings = std::max(0, (in_size - 1) * stride + kernel + dilation - 1 + adj - *out_size); + if (pad_type == AutoPadType::SAME_UPPER) { // pad more on head when paddings are odd. + *pad_head = paddings - paddings / 2; + *pad_tail = paddings / 2; + } else { + // for pad_type is NOTSET, SAME_LOWER or VALID + // set pad_head as paddings/2, pad_tail as paddings-paddings/2. + // That said, we pad more on tail when paddings are odd. 
+ *pad_head = paddings / 2; + *pad_tail = paddings - paddings / 2; + } + return; + } + if (pad_type != AutoPadType::NOTSET) { + switch (pad_type) { + // We handle cases of AutoPadType::VALID and AutoPadType::SAME_UPPER/LOWER, + // the same way + case AutoPadType::VALID: + case AutoPadType::SAME_UPPER: + case AutoPadType::SAME_LOWER: + *pad_head = 0; + *pad_tail = 0; + *out_size = (in_size - 1) * stride + kernel + dilation - 1 + adj; + break; + default: + throw NotImplementedException("pad type not supported"); + } + } else { + *out_size = + (in_size - 1) * stride + kernel + dilation - 1 + adj - *pad_head - *pad_tail; + } + } +}; + +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/cpu/nn/qlinearconv.cc b/onnxruntime/core/providers/cpu/nn/qlinearconv.cc index 78a53679325e8..6df671c932085 100644 --- a/onnxruntime/core/providers/cpu/nn/qlinearconv.cc +++ b/onnxruntime/core/providers/cpu/nn/qlinearconv.cc @@ -63,20 +63,20 @@ Status QLinearConv::Compute(OpKernelContext* context) const { const int64_t N = X->Shape()[0]; const int64_t C = X->Shape()[1]; const int64_t M = W->Shape()[0]; - ORT_RETURN_IF_ERROR(ValidateInputShape(X, W)); + ORT_RETURN_IF_ERROR(conv_attrs_.ValidateInputShape(X, W)); std::vector kernel_shape; - ORT_RETURN_IF_ERROR(ComputeKernelShape(W->Shape(), kernel_shape)); + ORT_RETURN_IF_ERROR(conv_attrs_.ComputeKernelShape(W->Shape(), kernel_shape)); - std::vector pads(pads_); + std::vector pads(conv_attrs_.pads); if (pads.empty()) { pads.resize(kernel_shape.size() * 2, 0); } - std::vector dilations(dilations_); + std::vector dilations(conv_attrs_.dilations); if (dilations.empty()) { dilations.resize(kernel_shape.size(), 1); } - std::vector strides(strides_); + std::vector strides(conv_attrs_.strides); if (strides.empty()) { strides.resize(kernel_shape.size(), 1); } @@ -84,7 +84,7 @@ Status QLinearConv::Compute(OpKernelContext* context) const { std::vector Y_dims; Y_dims.insert(Y_dims.begin(), {N, M}); TensorShape input_shape = X->Shape().Slice(2); - ORT_RETURN_IF_ERROR(InferOutputShape(input_shape, kernel_shape, strides, dilations, &pads, &Y_dims)); + ORT_RETURN_IF_ERROR(conv_attrs_.InferOutputShape(input_shape, kernel_shape, strides, dilations, &pads, &Y_dims)); Tensor* Y = context->Output(0, TensorShape(Y_dims)); TensorShape output_shape = Y->Shape().Slice(2); @@ -97,12 +97,12 @@ Status QLinearConv::Compute(OpKernelContext* context) const { const int64_t input_image_size = input_shape.Size(); const int64_t output_image_size = output_shape.Size(); const int64_t kernel_size = TensorShape(kernel_shape).Size(); - const int64_t X_offset = C / group_ * input_image_size; - const int64_t Y_offset = Y->Shape().Size() / Y->Shape()[0] / group_; - const int64_t W_offset = W->Shape().Size() / group_; - const int64_t kernel_dim = C / group_ * kernel_size; + const int64_t X_offset = C / conv_attrs_.group * input_image_size; + const int64_t Y_offset = Y->Shape().Size() / Y->Shape()[0] / conv_attrs_.group; + const int64_t W_offset = W->Shape().Size() / conv_attrs_.group; + const int64_t kernel_dim = C / conv_attrs_.group * kernel_size; const int64_t col_buffer_size = kernel_dim * output_image_size; - const int bias_offset = static_cast(M / group_); + const int bias_offset = static_cast(M / conv_attrs_.group); auto col_data = alloc->Alloc(sizeof(uint8_t) * col_buffer_size); BufferUniquePtr col_buffer(col_data, BufferDeleter(alloc)); @@ -114,7 +114,7 @@ Status QLinearConv::Compute(OpKernelContext* context) const { output_shape.GetDims().end()); for (int image_id = 0; image_id < 
N; ++image_id) { - for (int group_id = 0; group_id < group_; ++group_id) { + for (int group_id = 0; group_id < conv_attrs_.group; ++group_id) { math::Im2colNd()( Xdata + group_id * X_offset, image_shape.GetDims().data(), @@ -137,7 +137,7 @@ Status QLinearConv::Compute(OpKernelContext* context) const { *filter_offset->template Data(), *input_offset->template Data(), *result_offset->template Data(), - static_cast(M / group_), + static_cast(M / conv_attrs_.group), static_cast(output_image_size), static_cast(kernel_dim), integer_multiplier, @@ -145,8 +145,8 @@ Status QLinearConv::Compute(OpKernelContext* context) const { bias == nullptr ? nullptr : bias->template Data() + group_id * bias_offset); } - Xdata += X_offset * group_; - Ydata += Y_offset * group_; + Xdata += X_offset * conv_attrs_.group; + Ydata += Y_offset * conv_attrs_.group; } return Status::OK(); diff --git a/onnxruntime/core/providers/cpu/nn/qlinearconv.h b/onnxruntime/core/providers/cpu/nn/qlinearconv.h index 9179da587c1f4..72d7741bb4788 100644 --- a/onnxruntime/core/providers/cpu/nn/qlinearconv.h +++ b/onnxruntime/core/providers/cpu/nn/qlinearconv.h @@ -3,15 +3,18 @@ #pragma once -#include "core/providers/cpu/nn/conv_base.h" +#include "core/framework/op_kernel.h" +#include "core/providers/cpu/nn/conv_attributes.h" #include "core/util/gemmlowp_common.h" namespace onnxruntime { -class QLinearConv : public OpKernel, public ConvBase { +class QLinearConv : public OpKernel { public: - explicit QLinearConv(const OpKernelInfo& info) : OpKernel(info), ConvBase(info) { + explicit QLinearConv(const OpKernelInfo& info) : OpKernel(info), conv_attrs_(info) { } - Status Compute(OpKernelContext* context) const override; + Status Compute(OpKernelContext* context) const override; + + ConvAttributes conv_attrs_; }; } // namespace onnxruntime diff --git a/onnxruntime/core/providers/cuda/nn/conv.cc b/onnxruntime/core/providers/cuda/nn/conv.cc index be1c24ebed1b5..5360be1832499 100644 --- a/onnxruntime/core/providers/cuda/nn/conv.cc +++ b/onnxruntime/core/providers/cuda/nn/conv.cc @@ -59,27 +59,28 @@ Status Conv::ComputeInternal(OpKernelContext* context) const { const int64_t N = X->Shape()[0]; const int64_t M = W->Shape()[0]; - ORT_RETURN_IF_ERROR(ValidateInputShape(X, W)); + ORT_RETURN_IF_ERROR(conv_attrs_.ValidateInputShape(X, W)); std::vector kernel_shape; - ORT_RETURN_IF_ERROR(ComputeKernelShape(W->Shape(), kernel_shape)); + ORT_RETURN_IF_ERROR(conv_attrs_.ComputeKernelShape(W->Shape(), kernel_shape)); auto rank = kernel_shape.size(); - std::vector pads(pads_); + std::vector pads(conv_attrs_.pads); if (pads.empty()) { pads.resize(rank * 2, 0); } - std::vector dilations(dilations_); + std::vector dilations(conv_attrs_.dilations); if (dilations.empty()) { dilations.resize(rank, 1); } - std::vector strides(strides_); + std::vector strides(conv_attrs_.strides); if (strides.empty()) { strides.resize(rank, 1); } std::vector y_dims; y_dims.insert(y_dims.begin(), {N, M}); - ORT_RETURN_IF_ERROR(InferOutputShape(x_shape.Slice(2), kernel_shape, strides, dilations, &pads, &y_dims)); + ORT_RETURN_IF_ERROR(conv_attrs_.InferOutputShape(x_shape.Slice(2), kernel_shape, + strides, dilations, &pads, &y_dims)); s_.y_dims = y_dims; std::vector x_dims_cudnn = x_dims; @@ -102,8 +103,9 @@ Status Conv::ComputeInternal(OpKernelContext* context) const { ORT_RETURN_IF_ERROR(s_.filter_desc.Set(w_dims, CudnnTensor::GetDataType())); cudnnConvolutionMode_t mode = CUDNN_CROSS_CORRELATION; - ORT_RETURN_IF_ERROR(s_.conv_desc.Set(kernel_shape.size(), pads, strides, dilations, 
mode, CudnnTensor::GetDataType())); - CUDNN_RETURN_IF_ERROR(cudnnSetConvolutionGroupCount(s_.conv_desc, gsl::narrow_cast(group_))); + ORT_RETURN_IF_ERROR(s_.conv_desc.Set(kernel_shape.size(), pads, strides, dilations, + mode, CudnnTensor::GetDataType())); + CUDNN_RETURN_IF_ERROR(cudnnSetConvolutionGroupCount(s_.conv_desc, gsl::narrow_cast(conv_attrs_.group))); if (has_bias) { const Tensor* B = context->Input(2); diff --git a/onnxruntime/core/providers/cuda/nn/conv.h b/onnxruntime/core/providers/cuda/nn/conv.h index f6fe32531a52f..c24e9a9b6e31b 100644 --- a/onnxruntime/core/providers/cuda/nn/conv.h +++ b/onnxruntime/core/providers/cuda/nn/conv.h @@ -7,7 +7,7 @@ #include "core/platform/ort_mutex.h" #include "core/framework/op_kernel.h" #include "core/providers/cuda/cudnn_common.h" -#include "core/providers/cpu/nn/conv_base.h" +#include "core/providers/cpu/nn/conv_attributes.h" #include namespace onnxruntime { @@ -142,20 +142,22 @@ enum : size_t { }; template -class Conv : public CudaKernel, public ConvBase { +class Conv : public CudaKernel { public: - Conv(const OpKernelInfo& info) : CudaKernel(info), ConvBase(info) { - auto pads_size = pads_.size(); + Conv(const OpKernelInfo& info) : CudaKernel(info), conv_attrs_(info) { + auto pads_size = pads.size(); ORT_ENFORCE(pads_size % 2 == 0); auto rank = pads_size / 2; for (size_t i = 0; i < rank; i++) { - ORT_ENFORCE(pads_[i] == pads_[i + rank], "cudnn only supports symmetric padding"); + ORT_ENFORCE(pads[i] == pads[i + rank], "cudnn only supports symmetric padding"); } } Status ComputeInternal(OpKernelContext* context) const override; private: + ConvAttributes conv_attrs_; + mutable CudnnConvState s_; }; diff --git a/onnxruntime/core/providers/cuda/nn/conv_transpose.cc b/onnxruntime/core/providers/cuda/nn/conv_transpose.cc index 9c5865c78fc6d..3f4e140842ef0 100644 --- a/onnxruntime/core/providers/cuda/nn/conv_transpose.cc +++ b/onnxruntime/core/providers/cuda/nn/conv_transpose.cc @@ -59,7 +59,7 @@ Status ConvTranspose::DoConvTranspose(OpKernelContext* context, bool dynamic_ } Prepare p; - ORT_RETURN_IF_ERROR(PrepareForCompute(context, has_bias, p, dynamic_padding)); + ORT_RETURN_IF_ERROR(conv_transpose_attrs_.PrepareForCompute(context, has_bias, p, dynamic_padding)); const auto& y_dims = p.Y->Shape().GetDims(); s_.y_dims = y_dims; @@ -71,8 +71,10 @@ Status ConvTranspose::DoConvTranspose(OpKernelContext* context, bool dynamic_ ORT_RETURN_IF_ERROR(s_.filter_desc.Set(w_dims, CudnnTensor::GetDataType())); cudnnConvolutionMode_t mode = CUDNN_CROSS_CORRELATION; - ORT_RETURN_IF_ERROR(s_.conv_desc.Set(p.kernel_shape.size(), p.pads, p.strides, p.dilations, mode, CudnnTensor::GetDataType())); - CUDNN_RETURN_IF_ERROR(cudnnSetConvolutionGroupCount(s_.conv_desc, gsl::narrow_cast(group_))); + ORT_RETURN_IF_ERROR(s_.conv_desc.Set(p.kernel_shape.size(), p.pads, p.strides, + p.dilations, mode, CudnnTensor::GetDataType())); + CUDNN_RETURN_IF_ERROR(cudnnSetConvolutionGroupCount(s_.conv_desc, + gsl::narrow_cast(conv_transpose_attrs_.group_))); if (has_bias) { const auto& b_shape = p.B->Shape(); diff --git a/onnxruntime/core/providers/cuda/nn/conv_transpose.h b/onnxruntime/core/providers/cuda/nn/conv_transpose.h index 9030e8f856f25..5bc2afb9ed425 100644 --- a/onnxruntime/core/providers/cuda/nn/conv_transpose.h +++ b/onnxruntime/core/providers/cuda/nn/conv_transpose.h @@ -4,20 +4,22 @@ #pragma once #include "core/providers/cuda/cudnn_common.h" -#include "core/providers/cpu/nn/conv_transpose.h" +#include "core/providers/cpu/nn/conv_transpose_attributes.h" #include 
"conv.h" namespace onnxruntime { namespace cuda { template -class ConvTranspose : public CudaKernel, public ConvTransposeBase { +class ConvTranspose : public CudaKernel { public: - ConvTranspose(const OpKernelInfo& info) : CudaKernel(info), ConvTransposeBase(info){}; + ConvTranspose(const OpKernelInfo& info) : CudaKernel(info), conv_transpose_attrs_(info){}; Status ComputeInternal(OpKernelContext* context) const override; Status DoConvTranspose(OpKernelContext* context, bool dynamic_padding) const; private: + ConvTransposeAttributes conv_transpose_attrs_; + mutable CudnnConvState s_; }; diff --git a/onnxruntime/core/providers/mkldnn/nn/conv.cc b/onnxruntime/core/providers/mkldnn/nn/conv.cc index ab1827092d061..e8e25b59bd4f7 100644 --- a/onnxruntime/core/providers/mkldnn/nn/conv.cc +++ b/onnxruntime/core/providers/mkldnn/nn/conv.cc @@ -265,12 +265,12 @@ Status Conv::Compute(OpKernelContext* context) const { const Tensor* B = num_inputs == 3 ? context->Input(2) : nullptr; const int64_t N = X->Shape()[0]; const int64_t M = W->Shape()[0]; - const int group_mkl = static_cast(onnxruntime::ConvBase::group_); + const int group_mkl = static_cast(this->conv_attrs_.group); - ORT_RETURN_IF_ERROR(onnxruntime::ConvBase::ValidateInputShape(X, W)); + ORT_RETURN_IF_ERROR(this->conv_attrs_.ValidateInputShape(X, W)); std::vector kernel_shape; - ORT_RETURN_IF_ERROR(onnxruntime::ConvBase::ComputeKernelShape(W->Shape(), kernel_shape)); + ORT_RETURN_IF_ERROR(this->conv_attrs_.ComputeKernelShape(W->Shape(), kernel_shape)); const size_t kernel_rank = kernel_shape.size(); if (kernel_rank > 3) { @@ -292,15 +292,15 @@ Status Conv::Compute(OpKernelContext* context) const { } } - std::vector pads(onnxruntime::ConvBase::pads_); + std::vector pads(this->conv_attrs_.pads); if (pads.empty()) { pads.resize(kernel_rank * 2, 0); } - std::vector dilations(onnxruntime::ConvBase::dilations_); + std::vector dilations(this->conv_attrs_.dilations); if (dilations.empty()) { dilations.resize(kernel_rank, 1); } - std::vector strides(onnxruntime::ConvBase::strides_); + std::vector strides(this->conv_attrs_.strides); if (strides.empty()) { strides.resize(kernel_rank, 1); } @@ -308,7 +308,8 @@ Status Conv::Compute(OpKernelContext* context) const { std::vector Y_dims; Y_dims.insert(Y_dims.begin(), {N, M}); TensorShape input_shape = X->Shape().Slice(2); - ORT_RETURN_IF_ERROR(onnxruntime::ConvBase::InferOutputShape(input_shape, kernel_shape, strides, dilations, &pads, &Y_dims)); + ORT_RETURN_IF_ERROR(this->conv_attrs_.InferOutputShape(input_shape, kernel_shape, + strides, dilations, &pads, &Y_dims)); Tensor* Y = context->Output(0, TensorShape(Y_dims)); TensorShape output_shape = Y->Shape().Slice(2); From b3577e9e010f54ba34e44c3ffc91644e3ec4c744 Mon Sep 17 00:00:00 2001 From: Yang Chen Date: Thu, 26 Sep 2019 01:45:12 -0700 Subject: [PATCH 2/2] fixed cuda build issue --- onnxruntime/core/providers/cuda/nn/conv.h | 4 ++-- onnxruntime/core/providers/cuda/nn/conv_transpose.cc | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/onnxruntime/core/providers/cuda/nn/conv.h b/onnxruntime/core/providers/cuda/nn/conv.h index c24e9a9b6e31b..ca3266c811a92 100644 --- a/onnxruntime/core/providers/cuda/nn/conv.h +++ b/onnxruntime/core/providers/cuda/nn/conv.h @@ -145,11 +145,11 @@ template class Conv : public CudaKernel { public: Conv(const OpKernelInfo& info) : CudaKernel(info), conv_attrs_(info) { - auto pads_size = pads.size(); + auto pads_size = conv_attrs_.pads.size(); ORT_ENFORCE(pads_size % 2 == 0); auto rank = pads_size / 2; for 
(size_t i = 0; i < rank; i++) { - ORT_ENFORCE(pads[i] == pads[i + rank], "cudnn only supports symmetric padding"); + ORT_ENFORCE(conv_attrs_.pads[i] == conv_attrs_.pads[i + rank], "cudnn only supports symmetric padding"); } } diff --git a/onnxruntime/core/providers/cuda/nn/conv_transpose.cc b/onnxruntime/core/providers/cuda/nn/conv_transpose.cc index 3f4e140842ef0..3e51d36223ac1 100644 --- a/onnxruntime/core/providers/cuda/nn/conv_transpose.cc +++ b/onnxruntime/core/providers/cuda/nn/conv_transpose.cc @@ -58,7 +58,7 @@ Status ConvTranspose::DoConvTranspose(OpKernelContext* context, bool dynamic_ s_.cached_benchmark_results.clear(); } - Prepare p; + ConvTransposeAttributes::Prepare p; ORT_RETURN_IF_ERROR(conv_transpose_attrs_.PrepareForCompute(context, has_bias, p, dynamic_padding)); const auto& y_dims = p.Y->Shape().GetDims(); @@ -74,7 +74,7 @@ Status ConvTranspose::DoConvTranspose(OpKernelContext* context, bool dynamic_ ORT_RETURN_IF_ERROR(s_.conv_desc.Set(p.kernel_shape.size(), p.pads, p.strides, p.dilations, mode, CudnnTensor::GetDataType())); CUDNN_RETURN_IF_ERROR(cudnnSetConvolutionGroupCount(s_.conv_desc, - gsl::narrow_cast(conv_transpose_attrs_.group_))); + gsl::narrow_cast(conv_transpose_attrs_.group))); if (has_bias) { const auto& b_shape = p.B->Shape();
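
For illustration, the kind of re-use described in the commit message could look like the sketch below: a provider kernel that holds a ConvAttributes member by value instead of inheriting from the old ConvBase. This is a hypothetical, minimal sketch and not part of this patch; the kernel name MyProviderConv and the body of its Compute method are made up, and only the ConvAttributes members and helpers it calls (pads, dilations, strides, ValidateInputShape, ComputeKernelShape, InferOutputShape) come from the conv_attributes.h introduced above.

// Hypothetical sketch -- not part of this patch. Shows composition of
// ConvAttributes rather than inheritance from the removed ConvBase.
#include <vector>
#include "core/framework/op_kernel.h"
#include "core/providers/cpu/nn/conv_attributes.h"

namespace onnxruntime {

class MyProviderConv final : public OpKernel {
 public:
  explicit MyProviderConv(const OpKernelInfo& info)
      : OpKernel(info), conv_attrs_(info) {}  // attributes parsed once at construction

  Status Compute(OpKernelContext* context) const override {
    const Tensor* X = context->Input<Tensor>(0);
    const Tensor* W = context->Input<Tensor>(1);
    ORT_RETURN_IF_ERROR(conv_attrs_.ValidateInputShape(X, W));

    std::vector<int64_t> kernel_shape;
    ORT_RETURN_IF_ERROR(conv_attrs_.ComputeKernelShape(W->Shape(), kernel_shape));

    // Fall back to the ONNX defaults when the attributes were not set,
    // mirroring what the CPU Conv kernel does in this patch.
    std::vector<int64_t> pads(conv_attrs_.pads);
    if (pads.empty()) pads.resize(kernel_shape.size() * 2, 0);
    std::vector<int64_t> dilations(conv_attrs_.dilations);
    if (dilations.empty()) dilations.resize(kernel_shape.size(), 1);
    std::vector<int64_t> strides(conv_attrs_.strides);
    if (strides.empty()) strides.resize(kernel_shape.size(), 1);

    // Infer the output shape from the shared attribute helpers.
    std::vector<int64_t> Y_dims{X->Shape()[0], W->Shape()[0]};
    ORT_RETURN_IF_ERROR(conv_attrs_.InferOutputShape(X->Shape().Slice(2), kernel_shape,
                                                     strides, dilations, &pads, &Y_dims));
    Tensor* Y = context->Output(0, TensorShape(Y_dims));
    ORT_ENFORCE(Y != nullptr);
    // A real kernel would run its provider-specific convolution into Y here.
    return Status::OK();
  }

 private:
  ConvAttributes conv_attrs_;  // held by value; no multiple inheritance needed
};

}  // namespace onnxruntime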