Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 7 additions & 7 deletions onnxruntime/contrib_ops/cpu/nchwc_ops.cc
Original file line number Diff line number Diff line change
Expand Up @@ -98,7 +98,7 @@ Status NchwcConv::Compute(OpKernelContext* context) const {
const auto* B = context->Input<Tensor>(2);
const auto* Sum = context->Input<Tensor>(3);

ORT_RETURN_IF_ERROR(ConvBase::ValidateInputShape(X, W));
ORT_RETURN_IF_ERROR(conv_attrs_.ValidateInputShape(X, W));

const auto& X_shape = X->Shape();
const auto& W_shape = W->Shape();
Expand All @@ -108,28 +108,28 @@ Status NchwcConv::Compute(OpKernelContext* context) const {
ORT_ENFORCE((static_cast<size_t>(X_shape[1]) < nchwc_block_size) || ((X_shape[1] % nchwc_block_size) == 0));

std::vector<int64_t> kernel_shape;
ORT_RETURN_IF_ERROR(ConvBase::ComputeKernelShape(W_shape, kernel_shape));
ORT_RETURN_IF_ERROR(conv_attrs_.ComputeKernelShape(W_shape, kernel_shape));
if (kernel_shape.size() != 2) {
return Status(common::ONNXRUNTIME, common::INVALID_ARGUMENT, "Unsupported convolution size.");
}

std::vector<int64_t> pads(ConvBase::pads_);
std::vector<int64_t> pads(conv_attrs_.pads);
if (pads.empty()) {
pads.resize(kernel_shape.size() * 2, 0);
}
std::vector<int64_t> dilations(ConvBase::dilations_);
std::vector<int64_t> dilations(conv_attrs_.dilations);
if (dilations.empty()) {
dilations.resize(kernel_shape.size(), 1);
}
std::vector<int64_t> strides(ConvBase::strides_);
std::vector<int64_t> strides(conv_attrs_.strides);
if (strides.empty()) {
strides.resize(kernel_shape.size(), 1);
}

std::vector<int64_t> Y_dims;
Y_dims.insert(Y_dims.begin(), {X_shape[0], W_shape[0]});
TensorShape input_shape = X->Shape().Slice(2);
ORT_RETURN_IF_ERROR(ConvBase::InferOutputShape(input_shape, kernel_shape, strides, dilations, &pads, &Y_dims));
ORT_RETURN_IF_ERROR(conv_attrs_.InferOutputShape(input_shape, kernel_shape, strides, dilations, &pads, &Y_dims));
auto* Y = context->Output(0, Y_dims);
auto* y_data = Y->template MutableData<float>();

Expand All @@ -151,7 +151,7 @@ Status NchwcConv::Compute(OpKernelContext* context) const {
pads.data(),
strides.data(),
Y_dims.data(),
static_cast<size_t>(ConvBase::group_),
static_cast<size_t>(conv_attrs_.group),
X->template Data<float>(),
W->template Data<float>(),
B != nullptr ? B->template Data<float>() : nullptr,
Expand Down
8 changes: 5 additions & 3 deletions onnxruntime/contrib_ops/cpu/nchwc_ops.h
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

#include "core/common/common.h"
#include "core/framework/op_kernel.h"
#include "core/providers/cpu/nn/conv_base.h"
#include "core/providers/cpu/nn/conv_attributes.h"
#include "core/providers/cpu/nn/pool.h"
#include "contrib_ops/cpu/fused_activation.h"

Expand Down Expand Up @@ -35,15 +35,17 @@ class ReorderOutput : public OpKernel {
int64_t channels_;
};

class NchwcConv : public OpKernel, public ConvBase {
class NchwcConv : public OpKernel {
public:
NchwcConv(const OpKernelInfo& info) : OpKernel(info), ConvBase(info) {
NchwcConv(const OpKernelInfo& info) : OpKernel(info), conv_attrs_(info) {
ORT_ENFORCE(GetFusedActivationAttr(info, activation_).IsOK());
}

Status Compute(OpKernelContext* context) const override;

private:
ConvAttributes conv_attrs_;

MLAS_ACTIVATION activation_;
};

Expand Down
64 changes: 32 additions & 32 deletions onnxruntime/core/providers/cpu/nn/conv.cc
Original file line number Diff line number Diff line change
Expand Up @@ -34,39 +34,39 @@ Status Conv<T>::Compute(OpKernelContext* context) const {
const int64_t N = X->Shape()[0];
const int64_t C = X->Shape()[1];
const int64_t M = W->Shape()[0];
ORT_RETURN_IF_ERROR(ValidateInputShape(X, W));
ORT_RETURN_IF_ERROR(conv_attrs_.ValidateInputShape(X, W));

std::vector<int64_t> kernel_shape;
ORT_RETURN_IF_ERROR(ComputeKernelShape(W->Shape(), kernel_shape));
ORT_RETURN_IF_ERROR(conv_attrs_.ComputeKernelShape(W->Shape(), kernel_shape));

bool Is2DKernel = kernel_shape.size() == 2;
std::vector<int64_t> pads(pads_);
std::vector<int64_t> pads(conv_attrs_.pads);
if (pads.empty()) {
pads.resize(kernel_shape.size() * 2, 0);
}
std::vector<int64_t> dilations(dilations_);
std::vector<int64_t> dilations(conv_attrs_.dilations);
if (dilations.empty()) {
dilations.resize(kernel_shape.size(), 1);
}
std::vector<int64_t> strides(strides_);
std::vector<int64_t> strides(conv_attrs_.strides);
if (strides.empty()) {
strides.resize(kernel_shape.size(), 1);
}

std::vector<int64_t> Y_dims;
Y_dims.insert(Y_dims.begin(), {N, M});
TensorShape input_shape = X->Shape().Slice(2);
ORT_RETURN_IF_ERROR(InferOutputShape(input_shape, kernel_shape, strides, dilations, &pads, &Y_dims));
ORT_RETURN_IF_ERROR(conv_attrs_.InferOutputShape(input_shape, kernel_shape, strides, dilations, &pads, &Y_dims));
Tensor* Y = context->Output(0, TensorShape(Y_dims));
TensorShape output_shape = Y->Shape().Slice(2);

const int64_t input_image_size = input_shape.Size();
const int64_t output_image_size = output_shape.Size();
const int64_t kernel_size = TensorShape(kernel_shape).Size();
const int64_t X_offset = C / group_ * input_image_size;
const int64_t Y_offset = Y->Shape().Size() / Y->Shape()[0] / group_;
const int64_t W_offset = W->Shape().Size() / group_;
const int64_t kernel_dim = C / group_ * kernel_size;
const int64_t X_offset = C / conv_attrs_.group * input_image_size;
const int64_t Y_offset = Y->Shape().Size() / Y->Shape()[0] / conv_attrs_.group;
const int64_t W_offset = W->Shape().Size() / conv_attrs_.group;
const int64_t kernel_dim = C / conv_attrs_.group * kernel_size;
const int64_t col_buffer_size = kernel_dim * output_image_size;

AllocatorPtr alloc;
Expand All @@ -85,11 +85,11 @@ Status Conv<T>::Compute(OpKernelContext* context) const {
output_shape.GetDims().end());

for (int image_id = 0; image_id < N; ++image_id) {
for (int group_id = 0; group_id < group_; ++group_id) {
for (int group_id = 0; group_id < conv_attrs_.group; ++group_id) {
if (Is2DKernel) {
math::Im2col<T, CPUMathUtil, StorageOrder::NCHW>(
Xdata + group_id * X_offset,
C / group_,
C / conv_attrs_.group,
input_shape[0],
input_shape[1],
kernel_shape[0],
Expand Down Expand Up @@ -122,7 +122,7 @@ Status Conv<T>::Compute(OpKernelContext* context) const {
math::Gemm<T>(
CblasNoTrans,
CblasNoTrans,
M / group_,
M / conv_attrs_.group,
output_image_size,
kernel_dim,
1,
Expand All @@ -139,8 +139,8 @@ Status Conv<T>::Compute(OpKernelContext* context) const {
Ymatrix.rowwise() += Bvec.transpose();
}

Xdata += X_offset * group_;
Ydata += Y_offset * group_;
Xdata += X_offset * conv_attrs_.group;
Ydata += Y_offset * conv_attrs_.group;
}

return Status::OK();
Expand All @@ -157,28 +157,28 @@ Status Conv<float>::Compute(OpKernelContext* context) const {
const int64_t N = X->Shape()[0];
const int64_t C = X->Shape()[1];
const int64_t M = W->Shape()[0];
ORT_RETURN_IF_ERROR(ValidateInputShape(X, W));
ORT_RETURN_IF_ERROR(conv_attrs_.ValidateInputShape(X, W));

std::vector<int64_t> kernel_shape;
ORT_RETURN_IF_ERROR(ComputeKernelShape(W->Shape(), kernel_shape));
ORT_RETURN_IF_ERROR(conv_attrs_.ComputeKernelShape(W->Shape(), kernel_shape));

std::vector<int64_t> pads(pads_);
std::vector<int64_t> pads(conv_attrs_.pads);
if (pads.empty()) {
pads.resize(kernel_shape.size() * 2, 0);
}
std::vector<int64_t> dilations(dilations_);
std::vector<int64_t> dilations(conv_attrs_.dilations);
if (dilations.empty()) {
dilations.resize(kernel_shape.size(), 1);
}
std::vector<int64_t> strides(strides_);
std::vector<int64_t> strides(conv_attrs_.strides);
if (strides.empty()) {
strides.resize(kernel_shape.size(), 1);
}

std::vector<int64_t> Y_dims;
Y_dims.insert(Y_dims.begin(), {N, M});
TensorShape input_shape = X->Shape().Slice(2);
ORT_RETURN_IF_ERROR(InferOutputShape(input_shape, kernel_shape, strides, dilations, &pads, &Y_dims));
ORT_RETURN_IF_ERROR(conv_attrs_.InferOutputShape(input_shape, kernel_shape, strides, dilations, &pads, &Y_dims));
Tensor* Y = context->Output(0, TensorShape(Y_dims));
TensorShape output_shape = Y->Shape().Slice(2);

Expand All @@ -197,15 +197,15 @@ Status Conv<float>::Compute(OpKernelContext* context) const {
MlasConvPrepare(&Parameters,
kernel_rank,
static_cast<size_t>(N),
static_cast<size_t>(group_),
static_cast<size_t>(C / group_),
static_cast<size_t>(conv_attrs_.group),
static_cast<size_t>(C / conv_attrs_.group),
input_shape.GetDims().data(),
kernel_shape.data(),
dilations.data(),
pads.data(),
strides.data(),
output_shape.GetDims().data(),
static_cast<size_t>(M / group_),
static_cast<size_t>(M / conv_attrs_.group),
&activation_,
&WorkingBufferSize,
tp);
Expand All @@ -224,10 +224,10 @@ Status Conv<float>::Compute(OpKernelContext* context) const {
const int64_t input_image_size = input_shape.Size();
const int64_t output_image_size = output_shape.Size();
const int64_t kernel_size = TensorShape(kernel_shape).Size();
const int64_t X_offset = C / group_ * input_image_size;
const int64_t Y_offset = Y->Shape().Size() / Y->Shape()[0] / group_;
const int64_t W_offset = W->Shape().Size() / group_;
const int64_t kernel_dim = C / group_ * kernel_size;
const int64_t X_offset = C / conv_attrs_.group * input_image_size;
const int64_t Y_offset = Y->Shape().Size() / Y->Shape()[0] / conv_attrs_.group;
const int64_t W_offset = W->Shape().Size() / conv_attrs_.group;
const int64_t kernel_dim = C / conv_attrs_.group * kernel_size;
const int64_t col_buffer_size = kernel_dim * output_image_size;

auto col_data = alloc->Alloc(sizeof(float) * col_buffer_size);
Expand All @@ -240,7 +240,7 @@ Status Conv<float>::Compute(OpKernelContext* context) const {
output_shape.GetDims().end());

for (int image_id = 0; image_id < N; ++image_id) {
for (int group_id = 0; group_id < group_; ++group_id) {
for (int group_id = 0; group_id < conv_attrs_.group; ++group_id) {
math::Im2colNd<float, CPUMathUtil, StorageOrder::NCHW>()(
Xdata + group_id * X_offset,
image_shape.GetDims().data(),
Expand All @@ -257,7 +257,7 @@ Status Conv<float>::Compute(OpKernelContext* context) const {
math::Gemm<float>(
CblasNoTrans,
CblasNoTrans,
M / group_,
M / conv_attrs_.group,
output_image_size,
kernel_dim,
1,
Expand All @@ -270,8 +270,8 @@ Status Conv<float>::Compute(OpKernelContext* context) const {

MlasActivation(&activation_, Ydata, Bdata, M, output_image_size, output_image_size);

Xdata += X_offset * group_;
Ydata += Y_offset * group_;
Xdata += X_offset * conv_attrs_.group;
Ydata += Y_offset * conv_attrs_.group;
}
}

Expand Down
16 changes: 11 additions & 5 deletions onnxruntime/core/providers/cpu/nn/conv.h
Original file line number Diff line number Diff line change
Expand Up @@ -3,31 +3,37 @@

#pragma once

#include "core/providers/cpu/nn/conv_base.h"
#include "core/framework/op_kernel.h"
#include "core/providers/cpu/nn/conv_attributes.h"
#include "core/mlas/inc/mlas.h"

namespace onnxruntime {

template <typename T>
class Conv : public OpKernel, public ConvBase {
class Conv : public OpKernel {
public:
Conv(const OpKernelInfo& info) : OpKernel(info), ConvBase(info) {
Conv(const OpKernelInfo& info) : OpKernel(info), conv_attrs_(info) {
}

Status Compute(OpKernelContext* context) const override;

private:
ConvAttributes conv_attrs_;
};

template <>
class Conv<float> : public OpKernel, public ConvBase {
class Conv<float> : public OpKernel {
public:
Conv<float>(const OpKernelInfo& info) : OpKernel(info), ConvBase(info) {
Conv<float>(const OpKernelInfo& info) : OpKernel(info), conv_attrs_(info) {
activation_.ActivationKind = MlasIdentityActivation;
}

Status Compute(OpKernelContext* context) const override;

protected:
MLAS_ACTIVATION activation_;

ConvAttributes conv_attrs_;
};

} // namespace onnxruntime
Loading