@@ -34,39 +34,39 @@ Status Conv<T>::Compute(OpKernelContext* context) const {
3434 const int64_t N = X->Shape ()[0 ];
3535 const int64_t C = X->Shape ()[1 ];
3636 const int64_t M = W->Shape ()[0 ];
37- ORT_RETURN_IF_ERROR (ValidateInputShape (X, W));
37+ ORT_RETURN_IF_ERROR (conv_attrs_. ValidateInputShape (X, W));
3838
3939 std::vector<int64_t > kernel_shape;
40- ORT_RETURN_IF_ERROR (ComputeKernelShape (W->Shape (), kernel_shape));
40+ ORT_RETURN_IF_ERROR (conv_attrs_. ComputeKernelShape (W->Shape (), kernel_shape));
4141
4242 bool Is2DKernel = kernel_shape.size () == 2 ;
43- std::vector<int64_t > pads (pads_ );
43+ std::vector<int64_t > pads (conv_attrs_. pads );
4444 if (pads.empty ()) {
4545 pads.resize (kernel_shape.size () * 2 , 0 );
4646 }
47- std::vector<int64_t > dilations (dilations_ );
47+ std::vector<int64_t > dilations (conv_attrs_. dilations );
4848 if (dilations.empty ()) {
4949 dilations.resize (kernel_shape.size (), 1 );
5050 }
51- std::vector<int64_t > strides (strides_ );
51+ std::vector<int64_t > strides (conv_attrs_. strides );
5252 if (strides.empty ()) {
5353 strides.resize (kernel_shape.size (), 1 );
5454 }
5555
5656 std::vector<int64_t > Y_dims;
5757 Y_dims.insert (Y_dims.begin (), {N, M});
5858 TensorShape input_shape = X->Shape ().Slice (2 );
59- ORT_RETURN_IF_ERROR (InferOutputShape (input_shape, kernel_shape, strides, dilations, &pads, &Y_dims));
59+ ORT_RETURN_IF_ERROR (conv_attrs_. InferOutputShape (input_shape, kernel_shape, strides, dilations, &pads, &Y_dims));
6060 Tensor* Y = context->Output (0 , TensorShape (Y_dims));
6161 TensorShape output_shape = Y->Shape ().Slice (2 );
6262
6363 const int64_t input_image_size = input_shape.Size ();
6464 const int64_t output_image_size = output_shape.Size ();
6565 const int64_t kernel_size = TensorShape (kernel_shape).Size ();
66- const int64_t X_offset = C / group_ * input_image_size;
67- const int64_t Y_offset = Y->Shape ().Size () / Y->Shape ()[0 ] / group_ ;
68- const int64_t W_offset = W->Shape ().Size () / group_ ;
69- const int64_t kernel_dim = C / group_ * kernel_size;
66+ const int64_t X_offset = C / conv_attrs_. group * input_image_size;
67+ const int64_t Y_offset = Y->Shape ().Size () / Y->Shape ()[0 ] / conv_attrs_. group ;
68+ const int64_t W_offset = W->Shape ().Size () / conv_attrs_. group ;
69+ const int64_t kernel_dim = C / conv_attrs_. group * kernel_size;
7070 const int64_t col_buffer_size = kernel_dim * output_image_size;
7171
7272 AllocatorPtr alloc;
@@ -85,11 +85,11 @@ Status Conv<T>::Compute(OpKernelContext* context) const {
8585 output_shape.GetDims ().end ());
8686
8787 for (int image_id = 0 ; image_id < N; ++image_id) {
88- for (int group_id = 0 ; group_id < group_ ; ++group_id) {
88+ for (int group_id = 0 ; group_id < conv_attrs_. group ; ++group_id) {
8989 if (Is2DKernel) {
9090 math::Im2col<T, CPUMathUtil, StorageOrder::NCHW>(
9191 Xdata + group_id * X_offset,
92- C / group_ ,
92+ C / conv_attrs_. group ,
9393 input_shape[0 ],
9494 input_shape[1 ],
9595 kernel_shape[0 ],
@@ -122,7 +122,7 @@ Status Conv<T>::Compute(OpKernelContext* context) const {
122122 math::Gemm<T>(
123123 CblasNoTrans,
124124 CblasNoTrans,
125- M / group_ ,
125+ M / conv_attrs_. group ,
126126 output_image_size,
127127 kernel_dim,
128128 1 ,
@@ -139,8 +139,8 @@ Status Conv<T>::Compute(OpKernelContext* context) const {
139139 Ymatrix.rowwise () += Bvec.transpose ();
140140 }
141141
142- Xdata += X_offset * group_ ;
143- Ydata += Y_offset * group_ ;
142+ Xdata += X_offset * conv_attrs_. group ;
143+ Ydata += Y_offset * conv_attrs_. group ;
144144 }
145145
146146 return Status::OK ();
@@ -157,28 +157,28 @@ Status Conv<float>::Compute(OpKernelContext* context) const {
157157 const int64_t N = X->Shape ()[0 ];
158158 const int64_t C = X->Shape ()[1 ];
159159 const int64_t M = W->Shape ()[0 ];
160- ORT_RETURN_IF_ERROR (ValidateInputShape (X, W));
160+ ORT_RETURN_IF_ERROR (conv_attrs_. ValidateInputShape (X, W));
161161
162162 std::vector<int64_t > kernel_shape;
163- ORT_RETURN_IF_ERROR (ComputeKernelShape (W->Shape (), kernel_shape));
163+ ORT_RETURN_IF_ERROR (conv_attrs_. ComputeKernelShape (W->Shape (), kernel_shape));
164164
165- std::vector<int64_t > pads (pads_ );
165+ std::vector<int64_t > pads (conv_attrs_. pads );
166166 if (pads.empty ()) {
167167 pads.resize (kernel_shape.size () * 2 , 0 );
168168 }
169- std::vector<int64_t > dilations (dilations_ );
169+ std::vector<int64_t > dilations (conv_attrs_. dilations );
170170 if (dilations.empty ()) {
171171 dilations.resize (kernel_shape.size (), 1 );
172172 }
173- std::vector<int64_t > strides (strides_ );
173+ std::vector<int64_t > strides (conv_attrs_. strides );
174174 if (strides.empty ()) {
175175 strides.resize (kernel_shape.size (), 1 );
176176 }
177177
178178 std::vector<int64_t > Y_dims;
179179 Y_dims.insert (Y_dims.begin (), {N, M});
180180 TensorShape input_shape = X->Shape ().Slice (2 );
181- ORT_RETURN_IF_ERROR (InferOutputShape (input_shape, kernel_shape, strides, dilations, &pads, &Y_dims));
181+ ORT_RETURN_IF_ERROR (conv_attrs_. InferOutputShape (input_shape, kernel_shape, strides, dilations, &pads, &Y_dims));
182182 Tensor* Y = context->Output (0 , TensorShape (Y_dims));
183183 TensorShape output_shape = Y->Shape ().Slice (2 );
184184
@@ -197,15 +197,15 @@ Status Conv<float>::Compute(OpKernelContext* context) const {
197197 MlasConvPrepare (&Parameters,
198198 kernel_rank,
199199 static_cast <size_t >(N),
200- static_cast <size_t >(group_ ),
201- static_cast <size_t >(C / group_ ),
200+ static_cast <size_t >(conv_attrs_. group ),
201+ static_cast <size_t >(C / conv_attrs_. group ),
202202 input_shape.GetDims ().data (),
203203 kernel_shape.data (),
204204 dilations.data (),
205205 pads.data (),
206206 strides.data (),
207207 output_shape.GetDims ().data (),
208- static_cast <size_t >(M / group_ ),
208+ static_cast <size_t >(M / conv_attrs_. group ),
209209 &activation_,
210210 &WorkingBufferSize,
211211 tp);
@@ -224,10 +224,10 @@ Status Conv<float>::Compute(OpKernelContext* context) const {
224224 const int64_t input_image_size = input_shape.Size ();
225225 const int64_t output_image_size = output_shape.Size ();
226226 const int64_t kernel_size = TensorShape (kernel_shape).Size ();
227- const int64_t X_offset = C / group_ * input_image_size;
228- const int64_t Y_offset = Y->Shape ().Size () / Y->Shape ()[0 ] / group_ ;
229- const int64_t W_offset = W->Shape ().Size () / group_ ;
230- const int64_t kernel_dim = C / group_ * kernel_size;
227+ const int64_t X_offset = C / conv_attrs_. group * input_image_size;
228+ const int64_t Y_offset = Y->Shape ().Size () / Y->Shape ()[0 ] / conv_attrs_. group ;
229+ const int64_t W_offset = W->Shape ().Size () / conv_attrs_. group ;
230+ const int64_t kernel_dim = C / conv_attrs_. group * kernel_size;
231231 const int64_t col_buffer_size = kernel_dim * output_image_size;
232232
233233 auto col_data = alloc->Alloc (sizeof (float ) * col_buffer_size);
@@ -240,7 +240,7 @@ Status Conv<float>::Compute(OpKernelContext* context) const {
240240 output_shape.GetDims ().end ());
241241
242242 for (int image_id = 0 ; image_id < N; ++image_id) {
243- for (int group_id = 0 ; group_id < group_ ; ++group_id) {
243+ for (int group_id = 0 ; group_id < conv_attrs_. group ; ++group_id) {
244244 math::Im2colNd<float , CPUMathUtil, StorageOrder::NCHW>()(
245245 Xdata + group_id * X_offset,
246246 image_shape.GetDims ().data (),
@@ -257,7 +257,7 @@ Status Conv<float>::Compute(OpKernelContext* context) const {
257257 math::Gemm<float >(
258258 CblasNoTrans,
259259 CblasNoTrans,
260- M / group_ ,
260+ M / conv_attrs_. group ,
261261 output_image_size,
262262 kernel_dim,
263263 1 ,
@@ -270,8 +270,8 @@ Status Conv<float>::Compute(OpKernelContext* context) const {
270270
271271 MlasActivation (&activation_, Ydata, Bdata, M, output_image_size, output_image_size);
272272
273- Xdata += X_offset * group_ ;
274- Ydata += Y_offset * group_ ;
273+ Xdata += X_offset * conv_attrs_. group ;
274+ Ydata += Y_offset * conv_attrs_. group ;
275275 }
276276 }
277277
0 commit comments