2121
2222using anakin::graph::GraphGlobalMem;
2323using anakin::AK_FLOAT;
24- using anakin::saber::NV;
2524using anakin::saber::Shape;
2625
2726namespace paddle {
2827namespace inference {
2928namespace anakin {
3029
31- void BatchNormOpConverter::operator ()( const framework::proto::OpDesc &op,
32- const framework::BlockDesc &block_desc,
33- const framework::Scope &scope ,
34- bool test_mode) {
30+ template < typename TargetT>
31+ void BatchNormOpConverter<TargetT>:: operator ()(
32+ const framework::proto::OpDesc &op, const framework::BlockDesc &block_desc ,
33+ const framework::Scope &scope, bool test_mode) {
3534 framework::OpDesc op_desc (op, nullptr );
3635 PADDLE_ENFORCE_EQ (op_desc.Output (" Y" ).size (), 1 );
3736 std::map<std::string, std::string> inputs;
@@ -48,9 +47,9 @@ void BatchNormOpConverter::operator()(const framework::proto::OpDesc &op,
4847
4948 auto bn_op_name = op_name + " :bn" ;
5049 auto bn_output = bn_op_name + " _output" ;
51- engine_->AddOp (bn_op_name, " BatchNorm" , {inputs[" X" ]}, {bn_output});
52- engine_->AddOpAttr (bn_op_name, " epsilon" , epsilon);
53- engine_->AddOpAttr (bn_op_name, " momentum" , static_cast <float >(1.0 ));
50+ this -> engine_ ->AddOp (bn_op_name, " BatchNorm" , {inputs[" X" ]}, {bn_output});
51+ this -> engine_ ->AddOpAttr (bn_op_name, " epsilon" , epsilon);
52+ this -> engine_ ->AddOpAttr (bn_op_name, " momentum" , static_cast <float >(1.0 ));
5453
5554 auto scale_op_name = op_name + " :scale" ;
5655 auto get_lod_tensor = [this , &scope, &op_name](const std::string &var_name,
@@ -81,48 +80,54 @@ void BatchNormOpConverter::operator()(const framework::proto::OpDesc &op,
8180 Shape shape1 (fill_shape (4 , framework::vectorize2int (mean_t .dims ())));
8281 Shape shape2 (fill_shape (4 , framework::vectorize2int (variance_t .dims ())));
8382 auto *weight1 =
84- GraphGlobalMem<NV >::Global ().template new_block <AK_FLOAT>(shape1);
83+ GraphGlobalMem<TargetT >::Global ().template new_block <AK_FLOAT>(shape1);
8584 auto *mean_data = static_cast <float *>(weight1->h_tensor ().mutable_data ());
8685 std::copy_n (mean_t .data <float >(), mean_t .numel (), mean_data);
87- engine_->AddOpAttr (bn_op_name, " weight_1" , *weight1);
86+ this -> engine_ ->AddOpAttr (bn_op_name, " weight_1" , *weight1);
8887
8988 auto *weight2 =
90- GraphGlobalMem<NV >::Global ().template new_block <AK_FLOAT>(shape2);
89+ GraphGlobalMem<TargetT >::Global ().template new_block <AK_FLOAT>(shape2);
9190 auto *variance_data =
9291 static_cast <float *>(weight2->h_tensor ().mutable_data ());
9392 std::copy_n (variance_t .data <float >(), variance_t .numel (), variance_data);
94- engine_->AddOpAttr (bn_op_name, " weight_2" , *weight2);
93+ this -> engine_ ->AddOpAttr (bn_op_name, " weight_2" , *weight2);
9594
9695 Shape shape3 (std::vector<int >({1 , 1 , 1 , 1 }));
9796 auto *weight3 =
98- GraphGlobalMem<NV >::Global ().template new_block <AK_FLOAT>(shape3);
97+ GraphGlobalMem<TargetT >::Global ().template new_block <AK_FLOAT>(shape3);
9998 auto *alpha_data = static_cast <float *>(weight3->h_tensor ().mutable_data ());
10099 float weight3_data[] = {1 };
101100 std::copy (std::begin (weight3_data), std::end (weight3_data), alpha_data);
102- engine_->AddOpAttr (bn_op_name, " weight_3" , *weight3);
101+ this -> engine_ ->AddOpAttr (bn_op_name, " weight_3" , *weight3);
103102
104103 Shape scale_shape (fill_shape (4 , framework::vectorize2int (scale_t .dims ())));
105- auto *scale =
106- GraphGlobalMem<NV>:: Global (). template new_block <AK_FLOAT>( scale_shape);
104+ auto *scale = GraphGlobalMem<TargetT>:: Global (). template new_block <AK_FLOAT>(
105+ scale_shape);
107106 auto *scale_data = static_cast <float *>(scale->h_tensor ().mutable_data ());
108107 std::copy_n (scale_t .data <float >(), scale_t .numel (), scale_data);
109108
110109 Shape bias_shape (fill_shape (4 , framework::vectorize2int (bias_t .dims ())));
111- auto *bias =
112- GraphGlobalMem<NV>:: Global (). template new_block <AK_FLOAT>( bias_shape);
110+ auto *bias = GraphGlobalMem<TargetT>:: Global (). template new_block <AK_FLOAT>(
111+ bias_shape);
113112 auto *bias_data = static_cast <float *>(bias->h_tensor ().mutable_data ());
114113 std::copy_n (bias_t .data <float >(), bias_t .numel (), bias_data);
115114
116- engine_->AddOp (scale_op_name, " Scale" , {bn_output}, {output});
117- engine_->AddOpAttr (scale_op_name, " axis" , 1 );
118- engine_->AddOpAttr (scale_op_name, " num_axes" , 1 );
119- engine_->AddOpAttr (scale_op_name, " bias_term" , true );
120- engine_->AddOpAttr (scale_op_name, " weight_1" , *scale);
121- engine_->AddOpAttr (scale_op_name, " weight_2" , *bias);
115+ this -> engine_ ->AddOp (scale_op_name, " Scale" , {bn_output}, {output});
116+ this -> engine_ ->AddOpAttr (scale_op_name, " axis" , 1 );
117+ this -> engine_ ->AddOpAttr (scale_op_name, " num_axes" , 1 );
118+ this -> engine_ ->AddOpAttr (scale_op_name, " bias_term" , true );
119+ this -> engine_ ->AddOpAttr (scale_op_name, " weight_1" , *scale);
120+ this -> engine_ ->AddOpAttr (scale_op_name, " weight_2" , *bias);
122121}
123122
124123} // namespace anakin
125124} // namespace inference
126125} // namespace paddle
127126
128- REGISTER_ANAKIN_OP_CONVERTER (batch_norm, BatchNormOpConverter);
127+ #ifdef PADDLE_WITH_CUDA
128+ REGISTER_CUDA_ANAKIN_OP_CONVERTER (batch_norm,
129+ BatchNormOpConverter<::anakin::saber::NV>);
130+ #endif
131+
132+ REGISTER_CPU_ANAKIN_OP_CONVERTER (batch_norm,
133+ BatchNormOpConverter<::anakin::saber::X86>);
0 commit comments