Skip to content

Commit a5a57a4

Browse files
yangchen-MSKeDengMS
authored and committed
Changed ConvBase into a class member variable (#1927)
* Changed ConvBase into a class member variable Currently, all Conv-family classes inherit from both ConvBase and OpKernel. Since what ConvBase provides is all about processing convolution attributes, it's more natural to move it as a class member variable. This change renamed ConvBase to ConvAttributes and moved it into a separate file, conv_attributes, of its own. Instead of inheriting from ConvBase, now each Conv-related class has a class member variable that is of type ConvAttributes. Hence, we removed unnecessary multiple inheritance and increased composability. More importantly, the change made it possible for some other providers, such as Nuphar, to re-use the functionality provided by the ConvAttributes class. Note that we also made similar changes to ConvTransposeBase. * fixed cuda build issue
1 parent 7e22ed4 commit a5a57a4

File tree

17 files changed

+442
-403
lines changed

17 files changed

+442
-403
lines changed

onnxruntime/contrib_ops/cpu/nchwc_ops.cc

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -98,7 +98,7 @@ Status NchwcConv::Compute(OpKernelContext* context) const {
9898
const auto* B = context->Input<Tensor>(2);
9999
const auto* Sum = context->Input<Tensor>(3);
100100

101-
ORT_RETURN_IF_ERROR(ConvBase::ValidateInputShape(X, W));
101+
ORT_RETURN_IF_ERROR(conv_attrs_.ValidateInputShape(X, W));
102102

103103
const auto& X_shape = X->Shape();
104104
const auto& W_shape = W->Shape();
@@ -108,28 +108,28 @@ Status NchwcConv::Compute(OpKernelContext* context) const {
108108
ORT_ENFORCE((static_cast<size_t>(X_shape[1]) < nchwc_block_size) || ((X_shape[1] % nchwc_block_size) == 0));
109109

110110
std::vector<int64_t> kernel_shape;
111-
ORT_RETURN_IF_ERROR(ConvBase::ComputeKernelShape(W_shape, kernel_shape));
111+
ORT_RETURN_IF_ERROR(conv_attrs_.ComputeKernelShape(W_shape, kernel_shape));
112112
if (kernel_shape.size() != 2) {
113113
return Status(common::ONNXRUNTIME, common::INVALID_ARGUMENT, "Unsupported convolution size.");
114114
}
115115

116-
std::vector<int64_t> pads(ConvBase::pads_);
116+
std::vector<int64_t> pads(conv_attrs_.pads);
117117
if (pads.empty()) {
118118
pads.resize(kernel_shape.size() * 2, 0);
119119
}
120-
std::vector<int64_t> dilations(ConvBase::dilations_);
120+
std::vector<int64_t> dilations(conv_attrs_.dilations);
121121
if (dilations.empty()) {
122122
dilations.resize(kernel_shape.size(), 1);
123123
}
124-
std::vector<int64_t> strides(ConvBase::strides_);
124+
std::vector<int64_t> strides(conv_attrs_.strides);
125125
if (strides.empty()) {
126126
strides.resize(kernel_shape.size(), 1);
127127
}
128128

129129
std::vector<int64_t> Y_dims;
130130
Y_dims.insert(Y_dims.begin(), {X_shape[0], W_shape[0]});
131131
TensorShape input_shape = X->Shape().Slice(2);
132-
ORT_RETURN_IF_ERROR(ConvBase::InferOutputShape(input_shape, kernel_shape, strides, dilations, &pads, &Y_dims));
132+
ORT_RETURN_IF_ERROR(conv_attrs_.InferOutputShape(input_shape, kernel_shape, strides, dilations, &pads, &Y_dims));
133133
auto* Y = context->Output(0, Y_dims);
134134
auto* y_data = Y->template MutableData<float>();
135135

@@ -151,7 +151,7 @@ Status NchwcConv::Compute(OpKernelContext* context) const {
151151
pads.data(),
152152
strides.data(),
153153
Y_dims.data(),
154-
static_cast<size_t>(ConvBase::group_),
154+
static_cast<size_t>(conv_attrs_.group),
155155
X->template Data<float>(),
156156
W->template Data<float>(),
157157
B != nullptr ? B->template Data<float>() : nullptr,

onnxruntime/contrib_ops/cpu/nchwc_ops.h

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55

66
#include "core/common/common.h"
77
#include "core/framework/op_kernel.h"
8-
#include "core/providers/cpu/nn/conv_base.h"
8+
#include "core/providers/cpu/nn/conv_attributes.h"
99
#include "core/providers/cpu/nn/pool.h"
1010
#include "contrib_ops/cpu/fused_activation.h"
1111

@@ -35,15 +35,17 @@ class ReorderOutput : public OpKernel {
3535
int64_t channels_;
3636
};
3737

38-
class NchwcConv : public OpKernel, public ConvBase {
38+
class NchwcConv : public OpKernel {
3939
public:
40-
NchwcConv(const OpKernelInfo& info) : OpKernel(info), ConvBase(info) {
40+
NchwcConv(const OpKernelInfo& info) : OpKernel(info), conv_attrs_(info) {
4141
ORT_ENFORCE(GetFusedActivationAttr(info, activation_).IsOK());
4242
}
4343

4444
Status Compute(OpKernelContext* context) const override;
4545

4646
private:
47+
ConvAttributes conv_attrs_;
48+
4749
MLAS_ACTIVATION activation_;
4850
};
4951

onnxruntime/core/providers/cpu/nn/conv.cc

Lines changed: 32 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -34,39 +34,39 @@ Status Conv<T>::Compute(OpKernelContext* context) const {
3434
const int64_t N = X->Shape()[0];
3535
const int64_t C = X->Shape()[1];
3636
const int64_t M = W->Shape()[0];
37-
ORT_RETURN_IF_ERROR(ValidateInputShape(X, W));
37+
ORT_RETURN_IF_ERROR(conv_attrs_.ValidateInputShape(X, W));
3838

3939
std::vector<int64_t> kernel_shape;
40-
ORT_RETURN_IF_ERROR(ComputeKernelShape(W->Shape(), kernel_shape));
40+
ORT_RETURN_IF_ERROR(conv_attrs_.ComputeKernelShape(W->Shape(), kernel_shape));
4141

4242
bool Is2DKernel = kernel_shape.size() == 2;
43-
std::vector<int64_t> pads(pads_);
43+
std::vector<int64_t> pads(conv_attrs_.pads);
4444
if (pads.empty()) {
4545
pads.resize(kernel_shape.size() * 2, 0);
4646
}
47-
std::vector<int64_t> dilations(dilations_);
47+
std::vector<int64_t> dilations(conv_attrs_.dilations);
4848
if (dilations.empty()) {
4949
dilations.resize(kernel_shape.size(), 1);
5050
}
51-
std::vector<int64_t> strides(strides_);
51+
std::vector<int64_t> strides(conv_attrs_.strides);
5252
if (strides.empty()) {
5353
strides.resize(kernel_shape.size(), 1);
5454
}
5555

5656
std::vector<int64_t> Y_dims;
5757
Y_dims.insert(Y_dims.begin(), {N, M});
5858
TensorShape input_shape = X->Shape().Slice(2);
59-
ORT_RETURN_IF_ERROR(InferOutputShape(input_shape, kernel_shape, strides, dilations, &pads, &Y_dims));
59+
ORT_RETURN_IF_ERROR(conv_attrs_.InferOutputShape(input_shape, kernel_shape, strides, dilations, &pads, &Y_dims));
6060
Tensor* Y = context->Output(0, TensorShape(Y_dims));
6161
TensorShape output_shape = Y->Shape().Slice(2);
6262

6363
const int64_t input_image_size = input_shape.Size();
6464
const int64_t output_image_size = output_shape.Size();
6565
const int64_t kernel_size = TensorShape(kernel_shape).Size();
66-
const int64_t X_offset = C / group_ * input_image_size;
67-
const int64_t Y_offset = Y->Shape().Size() / Y->Shape()[0] / group_;
68-
const int64_t W_offset = W->Shape().Size() / group_;
69-
const int64_t kernel_dim = C / group_ * kernel_size;
66+
const int64_t X_offset = C / conv_attrs_.group * input_image_size;
67+
const int64_t Y_offset = Y->Shape().Size() / Y->Shape()[0] / conv_attrs_.group;
68+
const int64_t W_offset = W->Shape().Size() / conv_attrs_.group;
69+
const int64_t kernel_dim = C / conv_attrs_.group * kernel_size;
7070
const int64_t col_buffer_size = kernel_dim * output_image_size;
7171

7272
AllocatorPtr alloc;
@@ -85,11 +85,11 @@ Status Conv<T>::Compute(OpKernelContext* context) const {
8585
output_shape.GetDims().end());
8686

8787
for (int image_id = 0; image_id < N; ++image_id) {
88-
for (int group_id = 0; group_id < group_; ++group_id) {
88+
for (int group_id = 0; group_id < conv_attrs_.group; ++group_id) {
8989
if (Is2DKernel) {
9090
math::Im2col<T, CPUMathUtil, StorageOrder::NCHW>(
9191
Xdata + group_id * X_offset,
92-
C / group_,
92+
C / conv_attrs_.group,
9393
input_shape[0],
9494
input_shape[1],
9595
kernel_shape[0],
@@ -122,7 +122,7 @@ Status Conv<T>::Compute(OpKernelContext* context) const {
122122
math::Gemm<T>(
123123
CblasNoTrans,
124124
CblasNoTrans,
125-
M / group_,
125+
M / conv_attrs_.group,
126126
output_image_size,
127127
kernel_dim,
128128
1,
@@ -139,8 +139,8 @@ Status Conv<T>::Compute(OpKernelContext* context) const {
139139
Ymatrix.rowwise() += Bvec.transpose();
140140
}
141141

142-
Xdata += X_offset * group_;
143-
Ydata += Y_offset * group_;
142+
Xdata += X_offset * conv_attrs_.group;
143+
Ydata += Y_offset * conv_attrs_.group;
144144
}
145145

146146
return Status::OK();
@@ -157,28 +157,28 @@ Status Conv<float>::Compute(OpKernelContext* context) const {
157157
const int64_t N = X->Shape()[0];
158158
const int64_t C = X->Shape()[1];
159159
const int64_t M = W->Shape()[0];
160-
ORT_RETURN_IF_ERROR(ValidateInputShape(X, W));
160+
ORT_RETURN_IF_ERROR(conv_attrs_.ValidateInputShape(X, W));
161161

162162
std::vector<int64_t> kernel_shape;
163-
ORT_RETURN_IF_ERROR(ComputeKernelShape(W->Shape(), kernel_shape));
163+
ORT_RETURN_IF_ERROR(conv_attrs_.ComputeKernelShape(W->Shape(), kernel_shape));
164164

165-
std::vector<int64_t> pads(pads_);
165+
std::vector<int64_t> pads(conv_attrs_.pads);
166166
if (pads.empty()) {
167167
pads.resize(kernel_shape.size() * 2, 0);
168168
}
169-
std::vector<int64_t> dilations(dilations_);
169+
std::vector<int64_t> dilations(conv_attrs_.dilations);
170170
if (dilations.empty()) {
171171
dilations.resize(kernel_shape.size(), 1);
172172
}
173-
std::vector<int64_t> strides(strides_);
173+
std::vector<int64_t> strides(conv_attrs_.strides);
174174
if (strides.empty()) {
175175
strides.resize(kernel_shape.size(), 1);
176176
}
177177

178178
std::vector<int64_t> Y_dims;
179179
Y_dims.insert(Y_dims.begin(), {N, M});
180180
TensorShape input_shape = X->Shape().Slice(2);
181-
ORT_RETURN_IF_ERROR(InferOutputShape(input_shape, kernel_shape, strides, dilations, &pads, &Y_dims));
181+
ORT_RETURN_IF_ERROR(conv_attrs_.InferOutputShape(input_shape, kernel_shape, strides, dilations, &pads, &Y_dims));
182182
Tensor* Y = context->Output(0, TensorShape(Y_dims));
183183
TensorShape output_shape = Y->Shape().Slice(2);
184184

@@ -197,15 +197,15 @@ Status Conv<float>::Compute(OpKernelContext* context) const {
197197
MlasConvPrepare(&Parameters,
198198
kernel_rank,
199199
static_cast<size_t>(N),
200-
static_cast<size_t>(group_),
201-
static_cast<size_t>(C / group_),
200+
static_cast<size_t>(conv_attrs_.group),
201+
static_cast<size_t>(C / conv_attrs_.group),
202202
input_shape.GetDims().data(),
203203
kernel_shape.data(),
204204
dilations.data(),
205205
pads.data(),
206206
strides.data(),
207207
output_shape.GetDims().data(),
208-
static_cast<size_t>(M / group_),
208+
static_cast<size_t>(M / conv_attrs_.group),
209209
&activation_,
210210
&WorkingBufferSize,
211211
tp);
@@ -224,10 +224,10 @@ Status Conv<float>::Compute(OpKernelContext* context) const {
224224
const int64_t input_image_size = input_shape.Size();
225225
const int64_t output_image_size = output_shape.Size();
226226
const int64_t kernel_size = TensorShape(kernel_shape).Size();
227-
const int64_t X_offset = C / group_ * input_image_size;
228-
const int64_t Y_offset = Y->Shape().Size() / Y->Shape()[0] / group_;
229-
const int64_t W_offset = W->Shape().Size() / group_;
230-
const int64_t kernel_dim = C / group_ * kernel_size;
227+
const int64_t X_offset = C / conv_attrs_.group * input_image_size;
228+
const int64_t Y_offset = Y->Shape().Size() / Y->Shape()[0] / conv_attrs_.group;
229+
const int64_t W_offset = W->Shape().Size() / conv_attrs_.group;
230+
const int64_t kernel_dim = C / conv_attrs_.group * kernel_size;
231231
const int64_t col_buffer_size = kernel_dim * output_image_size;
232232

233233
auto col_data = alloc->Alloc(sizeof(float) * col_buffer_size);
@@ -240,7 +240,7 @@ Status Conv<float>::Compute(OpKernelContext* context) const {
240240
output_shape.GetDims().end());
241241

242242
for (int image_id = 0; image_id < N; ++image_id) {
243-
for (int group_id = 0; group_id < group_; ++group_id) {
243+
for (int group_id = 0; group_id < conv_attrs_.group; ++group_id) {
244244
math::Im2colNd<float, CPUMathUtil, StorageOrder::NCHW>()(
245245
Xdata + group_id * X_offset,
246246
image_shape.GetDims().data(),
@@ -257,7 +257,7 @@ Status Conv<float>::Compute(OpKernelContext* context) const {
257257
math::Gemm<float>(
258258
CblasNoTrans,
259259
CblasNoTrans,
260-
M / group_,
260+
M / conv_attrs_.group,
261261
output_image_size,
262262
kernel_dim,
263263
1,
@@ -270,8 +270,8 @@ Status Conv<float>::Compute(OpKernelContext* context) const {
270270

271271
MlasActivation(&activation_, Ydata, Bdata, M, output_image_size, output_image_size);
272272

273-
Xdata += X_offset * group_;
274-
Ydata += Y_offset * group_;
273+
Xdata += X_offset * conv_attrs_.group;
274+
Ydata += Y_offset * conv_attrs_.group;
275275
}
276276
}
277277

onnxruntime/core/providers/cpu/nn/conv.h

Lines changed: 11 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -3,31 +3,37 @@
33

44
#pragma once
55

6-
#include "core/providers/cpu/nn/conv_base.h"
6+
#include "core/framework/op_kernel.h"
7+
#include "core/providers/cpu/nn/conv_attributes.h"
78
#include "core/mlas/inc/mlas.h"
89

910
namespace onnxruntime {
1011

1112
template <typename T>
12-
class Conv : public OpKernel, public ConvBase {
13+
class Conv : public OpKernel {
1314
public:
14-
Conv(const OpKernelInfo& info) : OpKernel(info), ConvBase(info) {
15+
Conv(const OpKernelInfo& info) : OpKernel(info), conv_attrs_(info) {
1516
}
1617

1718
Status Compute(OpKernelContext* context) const override;
19+
20+
private:
21+
ConvAttributes conv_attrs_;
1822
};
1923

2024
template <>
21-
class Conv<float> : public OpKernel, public ConvBase {
25+
class Conv<float> : public OpKernel {
2226
public:
23-
Conv<float>(const OpKernelInfo& info) : OpKernel(info), ConvBase(info) {
27+
Conv<float>(const OpKernelInfo& info) : OpKernel(info), conv_attrs_(info) {
2428
activation_.ActivationKind = MlasIdentityActivation;
2529
}
2630

2731
Status Compute(OpKernelContext* context) const override;
2832

2933
protected:
3034
MLAS_ACTIVATION activation_;
35+
36+
ConvAttributes conv_attrs_;
3137
};
3238

3339
} // namespace onnxruntime

0 commit comments

Comments
 (0)