Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
68 changes: 68 additions & 0 deletions lite/backends/opencl/cl_kernel/image/grid_sampler_kernel.cl
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,22 @@ __kernel void grid_sampler(__read_only image2d_t input,
grid_y = fmin(fmax(grid_y, 0), y_max);
#endif
#endif

#ifdef NEAREST
// Nearest-neighbour path: snap the sampling coordinate to a single integer
// texel. OpenCL round() rounds halfway cases away from zero.
int in_ind_w = round(grid_x);
int in_ind_h = round(grid_y);
// Image coordinate of that texel: offset by out_c * out_width along x and by
// out_n * out_height along y.
int x_p = out_c * out_width + in_ind_w;
int y_p = out_n * out_height + in_ind_h;

CL_DTYPE4 out_val;
// A rounded index outside [0, out_width - 1] x [0, out_height - 1] produces
// zero instead of an image read.
// NOTE(review): the bounds and offsets use out_width/out_height rather than
// separate input dimensions -- presumably input and output share the same
// spatial size; confirm against the host-side kernel arguments.
if (in_ind_w < 0 || in_ind_w > out_width - 1 || in_ind_h < 0 ||
in_ind_h > out_height - 1) {
out_val = (CL_DTYPE4)(0.0);
} else {
out_val = READ_IMG_TYPE(CL_DTYPE_CHAR, input, SAMPLER, (int2)(x_p, y_p));
}
#endif
#ifdef BILINEAR
int xw = floor(grid_x);
int yn = floor(grid_y);
int x_p = out_c * out_width + xw;
Expand Down Expand Up @@ -119,6 +135,7 @@ __kernel void grid_sampler(__read_only image2d_t input,
in_ne * (CL_DTYPE4)(dw) * (CL_DTYPE4)(ds) +
in_sw * (CL_DTYPE4)(de) * (CL_DTYPE4)(dn) +
in_se * (CL_DTYPE4)(dw) * (CL_DTYPE4)(dn);
#endif
WRITE_IMG_TYPE(CL_DTYPE_CHAR, output, outpoints, out_val);

if (out_hblk_id * 4 + 1 < out_height) { // y
Expand Down Expand Up @@ -168,6 +185,22 @@ __kernel void grid_sampler(__read_only image2d_t input,
grid_y = fmin(fmax(grid_y, 0), y_max);
#endif
#endif

#ifdef NEAREST
// Nearest-neighbour path for this output point: snap the sampling coordinate
// to a single integer texel. OpenCL round() rounds halfway cases away from
// zero.
int in_ind_w = round(grid_x);
int in_ind_h = round(grid_y);
// Image coordinate of that texel: offset by out_c * out_width along x and by
// out_n * out_height along y.
int x_p = out_c * out_width + in_ind_w;
int y_p = out_n * out_height + in_ind_h;

CL_DTYPE4 out_val;
// A rounded index outside [0, out_width - 1] x [0, out_height - 1] produces
// zero instead of an image read.
// NOTE(review): the bounds and offsets use out_width/out_height rather than
// separate input dimensions -- presumably input and output share the same
// spatial size; confirm against the host-side kernel arguments.
if (in_ind_w < 0 || in_ind_w > out_width - 1 || in_ind_h < 0 ||
in_ind_h > out_height - 1) {
out_val = (CL_DTYPE4)(0.0);
} else {
out_val = READ_IMG_TYPE(CL_DTYPE_CHAR, input, SAMPLER, (int2)(x_p, y_p));
}
#endif
#ifdef BILINEAR
xw = floor(grid_x);
yn = floor(grid_y);
x_p = out_c * out_width + xw;
Expand Down Expand Up @@ -202,6 +235,7 @@ __kernel void grid_sampler(__read_only image2d_t input,
in_ne * (CL_DTYPE4)(dw) * (CL_DTYPE4)(ds) +
in_sw * (CL_DTYPE4)(de) * (CL_DTYPE4)(dn) +
in_se * (CL_DTYPE4)(dw) * (CL_DTYPE4)(dn);
#endif
WRITE_IMG_TYPE(
CL_DTYPE_CHAR, output, (int2)(outpoints.x, outpoints.y), out_val);
}
Expand Down Expand Up @@ -253,6 +287,22 @@ __kernel void grid_sampler(__read_only image2d_t input,
grid_y = fmin(fmax(grid_y, 0), y_max);
#endif
#endif

#ifdef NEAREST
// Nearest-neighbour path for this output point: snap the sampling coordinate
// to a single integer texel. OpenCL round() rounds halfway cases away from
// zero.
int in_ind_w = round(grid_x);
int in_ind_h = round(grid_y);
// Image coordinate of that texel: offset by out_c * out_width along x and by
// out_n * out_height along y.
int x_p = out_c * out_width + in_ind_w;
int y_p = out_n * out_height + in_ind_h;

CL_DTYPE4 out_val;
// A rounded index outside [0, out_width - 1] x [0, out_height - 1] produces
// zero instead of an image read.
// NOTE(review): the bounds and offsets use out_width/out_height rather than
// separate input dimensions -- presumably input and output share the same
// spatial size; confirm against the host-side kernel arguments.
if (in_ind_w < 0 || in_ind_w > out_width - 1 || in_ind_h < 0 ||
in_ind_h > out_height - 1) {
out_val = (CL_DTYPE4)(0.0);
} else {
out_val = READ_IMG_TYPE(CL_DTYPE_CHAR, input, SAMPLER, (int2)(x_p, y_p));
}
#endif
#ifdef BILINEAR
xw = floor(grid_x);
yn = floor(grid_y);
x_p = out_c * out_width + xw;
Expand Down Expand Up @@ -287,6 +337,7 @@ __kernel void grid_sampler(__read_only image2d_t input,
in_ne * (CL_DTYPE4)(dw) * (CL_DTYPE4)(ds) +
in_sw * (CL_DTYPE4)(de) * (CL_DTYPE4)(dn) +
in_se * (CL_DTYPE4)(dw) * (CL_DTYPE4)(dn);
#endif
WRITE_IMG_TYPE(
CL_DTYPE_CHAR, output, (int2)(outpoints.x, outpoints.y), out_val);
}
Expand Down Expand Up @@ -338,6 +389,22 @@ __kernel void grid_sampler(__read_only image2d_t input,
grid_y = fmin(fmax(grid_y, 0), y_max);
#endif
#endif

#ifdef NEAREST
// Nearest-neighbour path for this output point: snap the sampling coordinate
// to a single integer texel. OpenCL round() rounds halfway cases away from
// zero.
int in_ind_w = round(grid_x);
int in_ind_h = round(grid_y);
// Image coordinate of that texel: offset by out_c * out_width along x and by
// out_n * out_height along y.
int x_p = out_c * out_width + in_ind_w;
int y_p = out_n * out_height + in_ind_h;

CL_DTYPE4 out_val;
// A rounded index outside [0, out_width - 1] x [0, out_height - 1] produces
// zero instead of an image read.
// NOTE(review): the bounds and offsets use out_width/out_height rather than
// separate input dimensions -- presumably input and output share the same
// spatial size; confirm against the host-side kernel arguments.
if (in_ind_w < 0 || in_ind_w > out_width - 1 || in_ind_h < 0 ||
in_ind_h > out_height - 1) {
out_val = (CL_DTYPE4)(0.0);
} else {
out_val = READ_IMG_TYPE(CL_DTYPE_CHAR, input, SAMPLER, (int2)(x_p, y_p));
}
#endif
#ifdef BILINEAR
xw = floor(grid_x);
yn = floor(grid_y);
x_p = out_c * out_width + xw;
Expand Down Expand Up @@ -372,6 +439,7 @@ __kernel void grid_sampler(__read_only image2d_t input,
in_ne * (CL_DTYPE4)(dw) * (CL_DTYPE4)(ds) +
in_sw * (CL_DTYPE4)(de) * (CL_DTYPE4)(dn) +
in_se * (CL_DTYPE4)(dw) * (CL_DTYPE4)(dn);
#endif
WRITE_IMG_TYPE(
CL_DTYPE_CHAR, output, (int2)(outpoints.x, outpoints.y), out_val);
}
Expand Down
21 changes: 14 additions & 7 deletions lite/kernels/opencl/grid_sampler_image_compute.cc
Original file line number Diff line number Diff line change
Expand Up @@ -47,10 +47,23 @@ class GridSamplerImageCompute : public KernelLite<TARGET(kOpenCL),
bool align_corners = grid_param_->align_corners;
build_options_ = align_corners ? " -DALIGN_CORNER " : "";
std::string padding_mode = grid_param_->padding_mode;
if (padding_mode == "border") {
std::string mode = grid_param_->mode;
if (padding_mode == "zeros") {
build_options_ += "";
} else if (padding_mode == "border") {
build_options_ += " -DBORDER ";
} else if (padding_mode == "reflection") {
build_options_ += " -DREFLECTION ";
} else {
LOG(FATAL) << "Unsupported grid sampler with padding mode:"
<< padding_mode;
}
// Append the define that matches the interpolation mode to build_options_;
// the kernel source guards each interpolation path with #ifdef NEAREST /
// #ifdef BILINEAR. Any other mode string is reported through LOG(FATAL).
if (mode == "nearest") {
build_options_ += " -DNEAREST ";
} else if (mode == "bilinear") {
build_options_ += " -DBILINEAR ";
} else {
LOG(FATAL) << "Unsupported grid sampler with interp mode:" << mode;
}
auto& context = ctx_->As<OpenCLContext>();
context.cl_context()->AddKernel(kernel_func_name_,
Expand All @@ -67,12 +80,6 @@ class GridSamplerImageCompute : public KernelLite<TARGET(kOpenCL),

void ReInitWhenNeeded() override {
grid_param_ = param_.get_mutable<param_t>();
bool align_corners = grid_param_->align_corners;
std::string padding_mode = grid_param_->padding_mode;
std::string mode = grid_param_->mode;
if (mode != "bilinear") {
LOG(FATAL) << "Unsupported grid samper with interpolate mode:" << mode;
}
auto x_dims = grid_param_->x->dims();
if ((!first_epoch_for_reinit_ && x_dims != last_x_dims_) ||
first_epoch_for_reinit_) {
Expand Down
10 changes: 1 addition & 9 deletions lite/tests/unittest_py/op/test_grid_sampler_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,15 +101,7 @@ def sample_predictor_configs(self):
return self.get_predictor_configs(), ["grid_sampler"], (1e-5, 1e-5)

def add_ignore_pass_case(self):
def teller1(program_config, predictor_config):
if predictor_config.target() == TargetType.OpenCL:
if program_config.ops[0].attrs["mode"] != "bilinear":
return True

self.add_ignore_check_case(
teller1, IgnoreReasons.PADDLELITE_NOT_SUPPORT,
"Lite does not support this op in a specific case on opencl. We need to fix it as soon as possible."
)
pass

def test(self, *args, **kwargs):
self.run_and_statis(quant=False, max_examples=300)
Expand Down