@@ -69,7 +69,42 @@ struct FuseBNIntoConv final : public PredicateBasedPass {
6969 }
7070 }
7171
72- bool modify_conv (Node* conv, Node* bn, Graph& graph) {
72+ void scale_by_dim (Tensor& W, Tensor& s, const int axis) {
73+ ONNX_ASSERT (W.sizes ().size () > 1 && s.sizes ().size () == 1 && s.sizes ()[0 ] == W.sizes ()[axis]);
74+ ONNX_ASSERT (s.elem_type () == W.elem_type ());
75+ const int64_t inner_size = W.size_from_dim (axis+1 );
76+ const int64_t outer_size = axis > 0 ? std::accumulate (W.sizes ().begin (), W.sizes ().begin () + axis, 1 , std::multiplies<int >()) : 1 ;
77+ const int64_t axis_size = W.sizes ()[axis];
78+
79+ #define DO_SCALE (TENSOR_TYPE ) \
80+ TENSOR_TYPE* ptr = W.data <TENSOR_TYPE>(); \
81+ const TENSOR_TYPE* s_ptr = s.data <TENSOR_TYPE>(); \
82+ int64_t counter = 0 ; \
83+ for (int64_t i = 0 ; i < outer_size; ++i) { \
84+ for (int64_t j = 0 ; j < axis_size; ++j) { \
85+ for (int64_t k = 0 ; k < inner_size; ++k) { \
86+ ptr[counter++] *= s_ptr[j]; \
87+ } \
88+ } \
89+ }
90+
91+ switch (s.elem_type ()) {
92+ case ONNX_NAMESPACE::TensorProto_DataType_FLOAT: {
93+ DO_SCALE (float )
94+ break ;
95+ }
96+ case ONNX_NAMESPACE::TensorProto_DataType_DOUBLE: {
97+ DO_SCALE (double )
98+ break ;
99+ }
100+ default :
101+ TENSOR_ASSERTM (
102+ false , " Operation scale_by_dim not supported for data type %s" , to_string (W.elem_type ()).c_str ());
103+ }
104+ #undef DO_SCALE
105+ }
106+
107+ bool modify_conv (Node* conv, Node* bn, Graph& graph, const bool is_conv) {
73108 const auto & bn_inputs = bn->inputs ();
74109 const auto & conv_inputs = conv->inputs ();
75110 auto end_iter = graph.initializers ().end ();
@@ -136,7 +171,6 @@ struct FuseBNIntoConv final : public PredicateBasedPass {
136171 var.add (eps); \
137172 var.sqrt (); \
138173 s.divide (var); \
139- W.scale_by_first_dim (s); \
140174 bc.subtract (m); \
141175 bc.multiply (s); \
142176 bc.add (bbn);
@@ -154,21 +188,38 @@ struct FuseBNIntoConv final : public PredicateBasedPass {
154188 return false ;
155189 }
156190#undef DO_COMPUTATION
191+ if (is_conv) {
192+ scale_by_dim (W, s, 0 );
193+ } else {
194+ scale_by_dim (W, s, 1 );
195+ }
157196 replace_inputs (W, bc, conv, graph);
158197 return true ;
159198 }
160199
161- bool patternMatchPredicate (Node* node) override {
200+ inline bool matchConvBn (Node * node) {
162201 return node->kind () == kBatchNormalization &&
163202 node->inputs ()[0 ]->node ()->kind () == kConv ;
164203 }
204+
205+ inline bool matchConvTransposeBn (Node *node) {
206+ return node->kind () == kBatchNormalization &&
207+ node->inputs ()[0 ]->node ()->kind () == kConvTranspose ;
208+ }
209+
210+ bool patternMatchPredicate (Node *node) override {
211+ return matchConvBn (node) || matchConvTransposeBn (node);
212+ }
213+
165214 bool runTransform (Node* n, Graph& graph,
166215 NodeDestroyType& destroy_current) override {
216+ const bool is_conv = matchConvBn (n);
217+
167218 Node* bn = n;
168219 Node* conv = n->inputs ()[0 ]->node ();
169220 auto origInput = bn->inputs ()[0 ];
170221 if (origInput->uses ().size () > 1 || bn->outputs ().size () > 1 ||
171- !modify_conv (conv, bn, graph)) {
222+ !modify_conv (conv, bn, graph, is_conv )) {
172223 destroy_current = NodeDestroyType::DestroyZero;
173224 return false ;
174225 }