Skip to content
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 32 additions & 0 deletions fastdeploy/pybind/fastdeploy_runtime.cc
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,12 @@ void BindRuntime(pybind11::module& m) {
pybind11::class_<Runtime>(m, "Runtime")
.def(pybind11::init())
.def("init", &Runtime::Init)
.def("infer",
[](Runtime& self, std::vector<FDTensor>& inputs) {
std::vector<FDTensor> outputs(self.NumOutputs());
self.Infer(inputs, &outputs);
return outputs;
})
.def("infer",
[](Runtime& self, std::map<std::string, pybind11::array>& data) {
std::vector<FDTensor> inputs(data.size());
Expand Down Expand Up @@ -132,6 +138,32 @@ void BindRuntime(pybind11::module& m) {
.value("FP64", FDDataType::FP64)
.value("UINT8", FDDataType::UINT8);

// Python binding for FDTensor, exposed as `FDTensor`:
//   cpu_data()            numpy array over the tensor's CPU data (no copy)
//   resize(...)           the three Resize forms bound below
//   numel(), nbytes()     forwarded to FDTensor::Numel / FDTensor::Nbytes
//   name                  read/write; shape, dtype, device are read-only
pybind11::class_<FDTensor>(m, "FDTensor", pybind11::buffer_protocol())
    .def(pybind11::init())
    .def("cpu_data",
         [](pybind11::object py_self) {
           // Receive the Python wrapper itself so it can be installed as the
           // numpy `base`: the returned array then holds a reference to the
           // tensor object and keeps it alive for as long as the array lives.
           // The previous code allocated a throw-away array as `base`, which
           // cost an extra buffer and did not tie the view to the tensor.
           FDTensor& self = py_self.cast<FDTensor&>();
           auto ptr = self.CpuData();
           auto dtype = FDDataTypeToNumpyDataType(self.dtype);
           // With a non-null base pybind11 wraps `ptr` without copying.
           // NOTE(review): assumes CpuData() points into storage owned by the
           // tensor and that the tensor is not resized while the array is in
           // use -- confirm against FDTensor.
           return pybind11::array(dtype, self.shape, ptr, py_self);
         })
    .def("resize", static_cast<void (FDTensor::*)(size_t)>(&FDTensor::Resize))
    .def("resize",
         static_cast<void (FDTensor::*)(const std::vector<int64_t>&)>(
             &FDTensor::Resize))
    .def(
        "resize",
        [](FDTensor& self, const std::vector<int64_t>& shape,
           const FDDataType& dtype, const std::string& name,
           const Device& device) { self.Resize(shape, dtype, name, device); })
    .def("numel", &FDTensor::Numel)
    .def("nbytes", &FDTensor::Nbytes)
    .def_readwrite("name", &FDTensor::name)
    .def_readonly("shape", &FDTensor::shape)
    .def_readonly("dtype", &FDTensor::dtype)
    .def_readonly("device", &FDTensor::device);

m.def("get_available_backends", []() { return GetAvailableBackends(); });
}

Expand Down
69 changes: 39 additions & 30 deletions fastdeploy/vision/detection/contrib/yolov5.cc
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
// limitations under the License.

#include "fastdeploy/vision/detection/contrib/yolov5.h"

#include "fastdeploy/utils/perf.h"
#include "fastdeploy/vision/utils/utils.h"

Expand Down Expand Up @@ -74,14 +75,14 @@ YOLOv5::YOLOv5(const std::string& model_file, const std::string& params_file,

bool YOLOv5::Initialize() {
// parameters for preprocess
size = {640, 640};
padding_value = {114.0, 114.0, 114.0};
is_mini_pad = false;
is_no_pad = false;
is_scale_up = false;
stride = 32;
max_wh = 7680.0;
multi_label = true;
size_ = {640, 640};
padding_value_ = {114.0, 114.0, 114.0};
is_mini_pad_ = false;
is_no_pad_ = false;
is_scale_up_ = false;
stride_ = 32;
max_wh_ = 7680.0;
multi_label_ = true;

if (!InitRuntime()) {
FDERROR << "Failed to initialize fastdeploy backend." << std::endl;
Expand All @@ -90,23 +91,34 @@ bool YOLOv5::Initialize() {
// Check if the input shape is dynamic after Runtime already initialized,
// Note that, We need to force is_mini_pad 'false' to keep static
// shape after padding (LetterBox) when the is_dynamic_shape is 'false'.
is_dynamic_input_ = false;
auto shape = InputInfoOfRuntime(0).shape;
for (int i = 0; i < shape.size(); ++i) {
// if height or width is dynamic
if (i >= 2 && shape[i] <= 0) {
is_dynamic_input_ = true;
break;
}
}
if (!is_dynamic_input_) {
is_mini_pad = false;
}
// TODO(qiuyanjun): remove
// is_dynamic_input_ = false;
// auto shape = InputInfoOfRuntime(0).shape;
// for (int i = 0; i < shape.size(); ++i) {
// // if height or width is dynamic
// if (i >= 2 && shape[i] <= 0) {
// is_dynamic_input_ = true;
// break;
// }
// }
// if (!is_dynamic_input_) {
// is_mini_pad_ = false;
// }
return true;
}

bool YOLOv5::Preprocess(Mat* mat, FDTensor* output,
std::map<std::string, std::array<float, 2>>* im_info) {
std::map<std::string, std::array<float, 2>>* im_info,
const std::vector<int>& size,
const std::vector<float> padding_value,
bool is_mini_pad, bool is_no_pad, bool is_scale_up,
int stride, float max_wh, bool multi_label) {
// Record the shape of image and the shape of preprocessed image
(*im_info)["input_shape"] = {static_cast<float>(mat->Height()),
static_cast<float>(mat->Width())};
(*im_info)["output_shape"] = {static_cast<float>(mat->Height()),
static_cast<float>(mat->Width())};

// process after image load
double ratio = (size[0] * 1.0) / std::max(static_cast<float>(mat->Height()),
static_cast<float>(mat->Width()));
Expand Down Expand Up @@ -147,7 +159,8 @@ bool YOLOv5::Preprocess(Mat* mat, FDTensor* output,
bool YOLOv5::Postprocess(
FDTensor& infer_result, DetectionResult* result,
const std::map<std::string, std::array<float, 2>>& im_info,
float conf_threshold, float nms_iou_threshold, bool multi_label) {
float conf_threshold, float nms_iou_threshold, bool multi_label,
float max_wh) {
FDASSERT(infer_result.shape[0] == 1, "Only support batch =1 now.");
result->Clear();
if (multi_label) {
Expand Down Expand Up @@ -251,13 +264,9 @@ bool YOLOv5::Predict(cv::Mat* im, DetectionResult* result, float conf_threshold,

std::map<std::string, std::array<float, 2>> im_info;

// Record the shape of image and the shape of preprocessed image
im_info["input_shape"] = {static_cast<float>(mat.Height()),
static_cast<float>(mat.Width())};
im_info["output_shape"] = {static_cast<float>(mat.Height()),
static_cast<float>(mat.Width())};

if (!Preprocess(&mat, &input_tensors[0], &im_info)) {
if (!Preprocess(&mat, &input_tensors[0], &im_info, size_, padding_value_,
is_mini_pad_, is_no_pad_, is_scale_up_, stride_, max_wh_,
multi_label_)) {
FDERROR << "Failed to preprocess input image." << std::endl;
return false;
}
Expand All @@ -279,7 +288,7 @@ bool YOLOv5::Predict(cv::Mat* im, DetectionResult* result, float conf_threshold,
#endif

if (!Postprocess(output_tensors[0], result, im_info, conf_threshold,
nms_iou_threshold, multi_label)) {
nms_iou_threshold, multi_label_)) {
FDERROR << "Failed to post process." << std::endl;
return false;
}
Expand Down
68 changes: 38 additions & 30 deletions fastdeploy/vision/detection/contrib/yolov5.h
Original file line number Diff line number Diff line change
Expand Up @@ -41,57 +41,65 @@ class FASTDEPLOY_DECL YOLOv5 : public FastDeployModel {
float conf_threshold = 0.25,
float nms_iou_threshold = 0.5);

// Preprocess the input image.
// mat: the input image; Mat is the data structure defined by FastDeploy.
// output: receives the preprocessed tensor data (an FDTensor), which is
//         passed to the backend for inference.
// im_info: receives data saved during preprocessing that post-processing
//          needs.
// size, padding_value, is_mini_pad, is_no_pad, is_scale_up, stride, max_wh and
// multi_label are the preprocessing options; each takes the default shown in
// the declaration when omitted.
static bool Preprocess(Mat* mat, FDTensor* output,
std::map<std::string, std::array<float, 2>>* im_info,
const std::vector<int>& size = {640, 640},
const std::vector<float> padding_value = {114.0, 114.0,
114.0},
bool is_mini_pad = false, bool is_no_pad = false,
bool is_scale_up = false, int stride = 32,
float max_wh = 7680.0, bool multi_label = true);

// Post-process the backend inference result and output it to the user.
// infer_result: the output tensor produced by backend inference.
// result: the prediction result of the model.
// im_info: information recorded during preprocessing, used in
//          post-processing to restore the boxes.
// conf_threshold: confidence threshold for filtering boxes in post-processing.
// nms_iou_threshold: IoU threshold used by NMS in post-processing.
// multi_label: whether boxes are selected in multi-label mode during
//              post-processing.
// max_wh: used for offsetting the boxes by class when applying NMS; defaults
//         to 7680.0.
static bool Postprocess(
FDTensor& infer_result, DetectionResult* result,
const std::map<std::string, std::array<float, 2>>& im_info,
float conf_threshold, float nms_iou_threshold, bool multi_label,
float max_wh = 7680.0);

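// The following are parameters used by the model at prediction time, mostly
// needed by pre- and post-processing. After creating the model, users may
// modify them according to the model's requirements and their own needs.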
// tuple of (width, height)
std::vector<int> size;
std::vector<int> size_;
// padding value, size should be same with Channels
std::vector<float> padding_value;
std::vector<float> padding_value_;
// only pad to the minimum rectangle whose height and width are multiples of stride
bool is_mini_pad;
bool is_mini_pad_;
// while is_mini_pad = false and is_no_pad = true, will resize the image to
// the set size
bool is_no_pad;
bool is_no_pad_;
// if is_scale_up is false, the input image only can be zoom out, the maximum
// resize scale cannot exceed 1.0
bool is_scale_up;
bool is_scale_up_;
// padding stride, for is_mini_pad
int stride;
int stride_;
// for offsetting the boxes by classes when using NMS
float max_wh;
float max_wh_;
// for different strategies to get boxes when postprocessing
bool multi_label;
bool multi_label_;

private:
// Initialization function: sets up the backend and performs any other
// operations that model inference requires.
bool Initialize();

// 输入图像预处理操作
// Mat为FastDeploy定义的数据结构
// FDTensor为预处理后的Tensor数据,传给后端进行推理
// im_info为预处理过程保存的数据,在后处理中需要用到
bool Preprocess(Mat* mat, FDTensor* outputs,
std::map<std::string, std::array<float, 2>>* im_info);

// 后端推理结果后处理,输出给用户
// infer_result 为后端推理后的输出Tensor
// result 为模型预测的结果
// im_info 为预处理记录的信息,后处理用于还原box
// conf_threshold 后处理时过滤box的置信度阈值
// nms_iou_threshold 后处理时NMS设定的iou阈值
// multi_label 后处理时box选取是否采用多标签方式
bool Postprocess(FDTensor& infer_result, DetectionResult* result,
const std::map<std::string, std::array<float, 2>>& im_info,
float conf_threshold, float nms_iou_threshold,
bool multi_label);

// Reports whether the input has dynamic dimensions. Direct use is not
// recommended: the logic may differ from model to model.
bool IsDynamicInput() const {
  return is_dynamic_input_;
}

void LetterBox(Mat* mat, std::vector<int> size, std::vector<float> color,
bool _auto, bool scale_fill = false, bool scale_up = true,
int stride = 32);
static void LetterBox(Mat* mat, std::vector<int> size,
std::vector<float> color, bool _auto,
bool scale_fill = false, bool scale_up = true,
int stride = 32);

// whether to inference with dynamic shape (e.g ONNX export with dynamic shape
// or not.)
Expand Down
43 changes: 35 additions & 8 deletions fastdeploy/vision/detection/contrib/yolov5_pybind.cc
Original file line number Diff line number Diff line change
Expand Up @@ -26,13 +26,40 @@ void BindYOLOv5(pybind11::module& m) {
self.Predict(&mat, &res, conf_threshold, nms_iou_threshold);
return res;
})
.def_readwrite("size", &vision::detection::YOLOv5::size)
.def_readwrite("padding_value", &vision::detection::YOLOv5::padding_value)
.def_readwrite("is_mini_pad", &vision::detection::YOLOv5::is_mini_pad)
.def_readwrite("is_no_pad", &vision::detection::YOLOv5::is_no_pad)
.def_readwrite("is_scale_up", &vision::detection::YOLOv5::is_scale_up)
.def_readwrite("stride", &vision::detection::YOLOv5::stride)
.def_readwrite("max_wh", &vision::detection::YOLOv5::max_wh)
.def_readwrite("multi_label", &vision::detection::YOLOv5::multi_label);
.def_static("preprocess",
[](pybind11::array& data, const std::vector<int>& size,
const std::vector<float> padding_value, bool is_mini_pad,
bool is_no_pad, bool is_scale_up, int stride, float max_wh,
bool multi_label) {
auto mat = PyArrayToCvMat(data);
fastdeploy::vision::Mat fd_mat(mat);
FDTensor output;
std::map<std::string, std::array<float, 2>> im_info;
vision::detection::YOLOv5::Preprocess(
&fd_mat, &output, &im_info, size, padding_value,
is_mini_pad, is_no_pad, is_scale_up, stride, max_wh,
multi_label);
return make_pair(output, im_info);
})
.def_static("postprocess",
[](FDTensor& infer_result,
const std::map<std::string, std::array<float, 2>>& im_info,
float conf_threshold, float nms_iou_threshold,
bool multi_label, float max_wh) {
vision::DetectionResult result;
vision::detection::YOLOv5::Postprocess(
infer_result, &result, im_info, conf_threshold,
nms_iou_threshold, multi_label, max_wh);
return result;
})
.def_readwrite("size", &vision::detection::YOLOv5::size_)
.def_readwrite("padding_value",
&vision::detection::YOLOv5::padding_value_)
.def_readwrite("is_mini_pad", &vision::detection::YOLOv5::is_mini_pad_)
.def_readwrite("is_no_pad", &vision::detection::YOLOv5::is_no_pad_)
.def_readwrite("is_scale_up", &vision::detection::YOLOv5::is_scale_up_)
.def_readwrite("stride", &vision::detection::YOLOv5::stride_)
.def_readwrite("max_wh", &vision::detection::YOLOv5::max_wh_)
.def_readwrite("multi_label", &vision::detection::YOLOv5::multi_label_);
}
} // namespace fastdeploy
3 changes: 2 additions & 1 deletion python/fastdeploy/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,10 @@
import sys

from .c_lib_wrap import (Frontend, Backend, FDDataType, TensorInfo, Device,
is_built_with_gpu, is_built_with_ort,
FDTensor, is_built_with_gpu, is_built_with_ort,
is_built_with_paddle, is_built_with_trt,
get_default_cuda_directory)

from .runtime import Runtime, RuntimeOption
from .model import FastDeployModel
from . import c_lib_wrap as C
Expand Down
3 changes: 2 additions & 1 deletion python/fastdeploy/runtime.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,8 @@ def __init__(self, runtime_option):
runtime_option._option), "Initialize Runtime Failed!"

def infer(self, data):
    """Run inference through the wrapped runtime.

    Args:
        data: The input data, either a dict or a list. It is forwarded
            unchanged to ``self._runtime.infer``.

    Returns:
        Whatever ``self._runtime.infer(data)`` returns.

    Raises:
        AssertionError: If ``data`` is neither a dict nor a list.
    """
    # One check covering both accepted input types. A dict-only assertion
    # ahead of this one would reject list input before it is ever considered.
    assert isinstance(
        data, (dict, list)), "The input data should be type of dict or list."
    return self._runtime.infer(data)

def num_inputs(self):
Expand Down
25 changes: 25 additions & 0 deletions python/fastdeploy/vision/detection/contrib/yolov5.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,31 @@ def predict(self, input_image, conf_threshold=0.25, nms_iou_threshold=0.5):
return self._model.predict(input_image, conf_threshold,
nms_iou_threshold)

@staticmethod
def preprocess(input_image,
               size=None,
               padding_value=None,
               is_mini_pad=False,
               is_no_pad=False,
               is_scale_up=False,
               stride=32,
               max_wh=7680.0,
               multi_label=True):
    """Run YOLOv5 preprocessing on an image without creating a model.

    All arguments are forwarded positionally to the native
    ``C.vision.detection.YOLOv5.preprocess`` binding.

    Args:
        input_image: The image as a numpy array.
        size: Target size; defaults to ``[640, 640]`` when None.
        padding_value: Per-channel padding value; defaults to
            ``[114.0, 114.0, 114.0]`` when None.
        is_mini_pad: Forwarded preprocessing flag.
        is_no_pad: Forwarded preprocessing flag.
        is_scale_up: Forwarded preprocessing flag.
        stride: Padding stride.
        max_wh: Forwarded to the native call.
        multi_label: Forwarded to the native call.

    Returns:
        The native binding's result: a pair of the preprocessed output
        tensor and the ``im_info`` mapping recorded during preprocessing.
    """
    # Build the list defaults per call instead of sharing one mutable
    # default object across every invocation.
    if size is None:
        size = [640, 640]
    if padding_value is None:
        padding_value = [114.0, 114.0, 114.0]
    return C.vision.detection.YOLOv5.preprocess(
        input_image, size, padding_value, is_mini_pad, is_no_pad,
        is_scale_up, stride, max_wh, multi_label)

@staticmethod
def postprocess(infer_result,
                im_info,
                conf_threshold=0.25,
                nms_iou_threshold=0.5,
                multi_label=True,
                max_wh=7680.0):
    """Run YOLOv5 post-processing on a raw inference result.

    All arguments are forwarded positionally to the native
    ``C.vision.detection.YOLOv5.postprocess`` binding.

    Args:
        infer_result: The output tensor produced by inference.
        im_info: The mapping recorded during preprocessing.
        conf_threshold: Confidence threshold for filtering boxes.
        nms_iou_threshold: IoU threshold used by NMS.
        multi_label: Whether to select boxes in multi-label mode.
        max_wh: Forwarded to the native call.

    Returns:
        The detection result returned by the native binding.
    """
    native_postprocess = C.vision.detection.YOLOv5.postprocess
    detections = native_postprocess(infer_result, im_info, conf_threshold,
                                    nms_iou_threshold, multi_label, max_wh)
    return detections

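# Property wrappers for attributes related to the YOLOv5 model.
# Most are preprocessing-related; for example, setting model.size = [1280, 1280] changes the resize target used during preprocessing (provided the model supports it).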
@property
Expand Down