@@ -48,44 +48,7 @@ inline void ExpandAspectRatios(const std::vector<float>& input_aspect_ratior,
4848
4949void PriorBoxCompute::PrepareForRun () {
5050 auto & param = this ->Param <param_t >();
51- std::vector<float > min_size = param.min_sizes ;
52- std::vector<float > max_size = param.max_sizes ;
53- std::vector<float > aspect_ratio = param.aspect_ratios ;
5451 std::vector<float > variance = param.variances_ ;
55- std::vector<float > aspect_ratios_vec;
56- bool is_flip = param.flip ;
57-
58- ExpandAspectRatios (aspect_ratio, is_flip, &aspect_ratios_vec);
59- prior_num = aspect_ratios_vec.size () * min_size.size ();
60- prior_num += max_size.size ();
61-
62- CHECK_LE (aspect_ratios_vec.size (), 16 );
63- xpu_aspect_ratios_guard_ =
64- TargetWrapperXPU::MallocScratchPad (16 * sizeof (float ));
65- XPU_CALL (xpu_memcpy (xpu_aspect_ratios_guard_->addr_ ,
66- aspect_ratios_vec.data (),
67- aspect_ratios_vec.size () * sizeof (float ),
68- XPUMemcpyKind::XPU_HOST_TO_DEVICE));
69- ar_num = aspect_ratios_vec.size ();
70-
71- CHECK_LE (min_size.size (), 8 );
72- xpu_min_sizes_guard_ = TargetWrapperXPU::MallocScratchPad (8 * sizeof (float ));
73- XPU_CALL (xpu_memcpy (xpu_min_sizes_guard_->addr_ ,
74- min_size.data (),
75- min_size.size () * sizeof (float ),
76- XPUMemcpyKind::XPU_HOST_TO_DEVICE));
77- min_size_num = min_size.size ();
78-
79- max_size_num = max_size.size ();
80- if (max_size_num > 0 ) {
81- CHECK_LE (max_size.size (), 8 );
82- xpu_max_sizes_guard_ =
83- TargetWrapperXPU::MallocScratchPad (8 * sizeof (float ));
84- XPU_CALL (xpu_memcpy (xpu_max_sizes_guard_->addr_ ,
85- max_size.data (),
86- max_size.size () * sizeof (float ),
87- XPUMemcpyKind::XPU_HOST_TO_DEVICE));
88- }
8952
9053 CHECK_EQ (variance.size (), 4 );
9154 variance_xpu_guard_ = TargetWrapperXPU::MallocScratchPad (4 * sizeof (float ));
@@ -99,6 +62,21 @@ void PriorBoxCompute::Run() {
9962 auto & param = this ->Param <param_t >();
10063 auto & ctx = this ->ctx_ ->As <XPUContext>();
10164
65+ std::vector<float > aspect_ratio = param.aspect_ratios ;
66+ std::vector<float > aspect_ratios_vec;
67+ bool is_flip = param.flip ;
68+ ExpandAspectRatios (aspect_ratio, is_flip, &aspect_ratios_vec);
69+ CHECK_LE (aspect_ratios_vec.size (), 16 );
70+ prior_num = aspect_ratios_vec.size () * param.min_sizes .size ();
71+ prior_num += param.max_sizes .size ();
72+ ar_num = aspect_ratios_vec.size ();
73+ min_size_num = param.min_sizes .size ();
74+ max_size_num = param.max_sizes .size ();
75+ CHECK_LE (min_size_num, 8 );
76+ if (max_size_num > 0 ) {
77+ CHECK_LE (max_size_num, 8 );
78+ }
79+
10280 bool is_clip = param.clip ;
10381 auto image_dims = param.image ->dims ();
10482 int im_width = static_cast <int >(image_dims[3 ]);
@@ -125,31 +103,21 @@ void PriorBoxCompute::Run() {
125103 param.variances ->Resize ({height, width, prior_num, 4 });
126104
127105 bool min_max_aspect_ratios_order = param.min_max_aspect_ratios_order ;
128- float * xpu_aspect_ratios =
129- reinterpret_cast <float *>(xpu_aspect_ratios_guard_->addr_ );
130- float * xpu_min_sizes = reinterpret_cast <float *>(xpu_min_sizes_guard_->addr_ );
131- float * xpu_max_sizes =
132- (max_size_num > 0 )
133- ? (reinterpret_cast <float *>(xpu_max_sizes_guard_->addr_ ))
134- : nullptr ;
135-
136- int r = xdnn::prior_box_gen (ctx.GetRawContext (),
137- param.boxes ->mutable_data <float >(TARGET (kXPU )),
138- xpu_aspect_ratios,
139- height,
140- width,
141- im_height,
142- im_width,
143- ar_num,
144- offset,
145- step_width,
146- step_height,
147- xpu_min_sizes,
148- xpu_max_sizes,
149- min_size_num,
150- max_size_num,
151- is_clip,
152- min_max_aspect_ratios_order);
106+ int r = xdnn::gen_prior_box<float >(
107+ ctx.GetRawContext (),
108+ param.boxes ->mutable_data <float >(TARGET (kXPU )),
109+ {aspect_ratios_vec.data (), ar_num, nullptr },
110+ {param.min_sizes .data (), min_size_num, nullptr },
111+ {param.max_sizes .data (), max_size_num, nullptr },
112+ height,
113+ width,
114+ im_height,
115+ im_width,
116+ offset,
117+ step_width,
118+ step_height,
119+ is_clip,
120+ min_max_aspect_ratios_order);
153121 CHECK_EQ (r, 0 );
154122
155123 float * xpu_variances_in =
0 commit comments