Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
21 commits
Select commit Hold shift + click to select a range
02d29aa
update .gitignore
DefTruth Jul 12, 2022
f9caef2
Merge branch 'develop' of https://github.com/DefTruth/FastDeploy into…
DefTruth Jul 12, 2022
afa8114
Added checking for cmake include dir
DefTruth Jul 12, 2022
659c14c
fixed missing trt_backend option bug when init from trt
DefTruth Jul 12, 2022
17a43ce
remove un-need data layout and add pre-check for dtype
DefTruth Jul 12, 2022
75948f8
changed RGB2BRG to BGR2RGB in ppcls model
DefTruth Jul 12, 2022
b57244d
add yolov6 c++ and yolov6 pybind
DefTruth Jul 13, 2022
71cd83c
Merge branch 'develop' of https://github.com/DefTruth/FastDeploy into…
DefTruth Jul 13, 2022
f847490
add model_zoo yolov6 c++/python demo
DefTruth Jul 14, 2022
c56fdc3
fixed CMakeLists.txt typos
DefTruth Jul 14, 2022
2300f57
update yolov6 cpp/README.md
DefTruth Jul 14, 2022
5670280
Merge branch 'PaddlePaddle:develop' into develop
DefTruth Jul 18, 2022
cb91b3c
add yolox c++/pybind and model_zoo demo
DefTruth Jul 18, 2022
9d7e9d9
move some helpers to private
DefTruth Jul 18, 2022
d2e51a2
fixed CMakeLists.txt typos
DefTruth Jul 18, 2022
eb63a0e
Merge branch 'develop' of https://github.com/DefTruth/FastDeploy into…
DefTruth Jul 18, 2022
df8b6a6
add normalize with alpha and beta
DefTruth Jul 18, 2022
8786384
add version notes for yolov5/yolov6/yolox
DefTruth Jul 18, 2022
fed2953
add copyright to yolov5.cc
DefTruth Jul 18, 2022
6ec3bd5
revert normalize
DefTruth Jul 18, 2022
367dad0
fixed some bugs in yolox
DefTruth Jul 18, 2022
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ option(ENABLE_OPENCV_CUDA "if to enable opencv with cuda, this will allow proces
option(ENABLE_DEBUG "if to enable print debug information, this may reduce performance." OFF)

# Whether to build fastdeply with vision/text/... examples, only for testings.
option(WTIH_VISION_EXAMPLES "Whether to build fastdeply with vision examples" ON)
option(WITH_VISION_EXAMPLES "Whether to build fastdeply with vision examples" ON)

if(ENABLE_DEBUG)
add_definitions(-DFASTDEPLOY_DEBUG)
Expand All @@ -53,7 +53,7 @@ option(BUILD_FASTDEPLOY_PYTHON "if build python lib for fastdeploy." OFF)
include_directories(${PROJECT_SOURCE_DIR})
include_directories(${CMAKE_CURRENT_BINARY_DIR})

if (WTIH_VISION_EXAMPLES AND EXISTS ${PROJECT_SOURCE_DIR}/examples)
if (WITH_VISION_EXAMPLES AND EXISTS ${PROJECT_SOURCE_DIR}/examples)
# ENABLE_VISION and ENABLE_VISION_VISUALIZE must be ON if enable vision examples.
message(STATUS "Found WTIH_VISION_EXAMPLES ON, so, force ENABLE_VISION and ENABLE_VISION_VISUALIZE ON")
set(ENABLE_VISION ON CACHE BOOL "force to enable vision models usage" FORCE)
Expand Down Expand Up @@ -181,8 +181,8 @@ set_target_properties(fastdeploy PROPERTIES VERSION ${FASTDEPLOY_VERSION})
target_link_libraries(fastdeploy ${DEPEND_LIBS})

# add examples after prepare include paths for third-parties
if (WTIH_VISION_EXAMPLES AND EXISTS ${PROJECT_SOURCE_DIR}/examples)
add_definitions(-DWTIH_VISION_EXAMPLES)
if (WITH_VISION_EXAMPLES AND EXISTS ${PROJECT_SOURCE_DIR}/examples)
add_definitions(-DWITH_VISION_EXAMPLES)
set(EXECUTABLE_OUTPUT_PATH ${PROJECT_SOURCE_DIR}/examples/bin)
add_subdirectory(examples)
endif()
Expand Down
3 changes: 2 additions & 1 deletion examples/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,10 @@ function(add_fastdeploy_executable field url model)
endfunction()

# vision examples
if (WTIH_VISION_EXAMPLES)
if (WITH_VISION_EXAMPLES)
add_fastdeploy_executable(vision ultralytics yolov5)
add_fastdeploy_executable(vision meituan yolov6)
add_fastdeploy_executable(vision megvii yolox)
endif()

# other examples ...
52 changes: 52 additions & 0 deletions examples/vision/megvii_yolox.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "fastdeploy/vision.h"

int main() {
namespace vis = fastdeploy::vision;

std::string model_file = "../resources/models/yolox_s.onnx";
std::string img_path = "../resources/images/bus.jpg";
std::string vis_path = "../resources/outputs/megvii_yolox_vis_result.jpg";

auto model = vis::megvii::YOLOX(model_file);
if (!model.Initialized()) {
std::cerr << "Init Failed! Model: " << model_file << std::endl;
return -1;
} else {
std::cout << "Init Done! Model:" << model_file << std::endl;
}
model.EnableDebug();

cv::Mat im = cv::imread(img_path);
cv::Mat vis_im = im.clone();

vis::DetectionResult res;
if (!model.Predict(&im, &res)) {
std::cerr << "Prediction Failed." << std::endl;
return -1;
} else {
std::cout << "Prediction Done!" << std::endl;
}

// 输出预测框结果
std::cout << res.Str() << std::endl;

// 可视化预测结果
vis::Visualize::VisDetection(&vis_im, res);
cv::imwrite(vis_path, vis_im);
std::cout << "Detect Done! Saved: " << vis_path << std::endl;
return 0;
}
5 changes: 2 additions & 3 deletions examples/vision/meituan_yolov6.cc
Original file line number Diff line number Diff line change
Expand Up @@ -23,11 +23,10 @@ int main() {

auto model = vis::meituan::YOLOv6(model_file);
if (!model.Initialized()) {
std::cerr << "Init Failed." << std::endl;
std::cerr << "Init Failed! Model: " << model_file << std::endl;
return -1;
} else {
std::cout << "Init Done! Dynamic Mode: "
<< model.IsDynamicShape() << std::endl;
std::cout << "Init Done! Model:" << model_file << std::endl;
}
model.EnableDebug();

Expand Down
1 change: 1 addition & 0 deletions fastdeploy/vision.h
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
#include "fastdeploy/vision/ppcls/model.h"
#include "fastdeploy/vision/ultralytics/yolov5.h"
#include "fastdeploy/vision/meituan/yolov6.h"
#include "fastdeploy/vision/megvii/yolox.h"
#endif

#include "fastdeploy/vision/visualize/visualize.h"
1 change: 1 addition & 0 deletions fastdeploy/vision/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,4 +17,5 @@
from . import ppcls
from . import ultralytics
from . import meituan
from . import megvii
from . import visualize
1 change: 0 additions & 1 deletion fastdeploy/vision/common/processors/normalize.h
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,6 @@ class Normalize : public Processor {
const std::vector<float>& min = std::vector<float>(),
const std::vector<float>& max = std::vector<float>(),
ProcLib lib = ProcLib::OPENCV_CPU);

private:
std::vector<float> alpha_;
std::vector<float> beta_;
Expand Down
96 changes: 96 additions & 0 deletions fastdeploy/vision/megvii/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,96 @@
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import absolute_import
import logging
from ... import FastDeployModel, Frontend
from ... import fastdeploy_main as C


class YOLOX(FastDeployModel):
def __init__(self,
model_file,
params_file="",
runtime_option=None,
model_format=Frontend.ONNX):
# 调用基函数进行backend_option的初始化
# 初始化后的option保存在self._runtime_option
super(YOLOX, self).__init__(runtime_option)

self._model = C.vision.megvii.YOLOX(
model_file, params_file, self._runtime_option, model_format)
# 通过self.initialized判断整个模型的初始化是否成功
assert self.initialized, "YOLOX initialize failed."

def predict(self, input_image, conf_threshold=0.25, nms_iou_threshold=0.5):
return self._model.predict(input_image, conf_threshold,
nms_iou_threshold)

# 一些跟YOLOX模型有关的属性封装
# 多数是预处理相关,可通过修改如model.size = [1280, 1280]改变预处理时resize的大小(前提是模型支持)
@property
def size(self):
return self._model.size

@property
def padding_value(self):
return self._model.padding_value

@property
def is_decode_exported(self):
return self._model.is_decode_exported

@property
def downsample_strides(self):
return self._model.downsample_strides

@property
def max_wh(self):
return self._model.max_wh

@size.setter
def size(self, wh):
assert isinstance(wh, [list, tuple]),\
"The value to set `size` must be type of tuple or list."
assert len(wh) == 2,\
"The value to set `size` must contatins 2 elements means [width, height], but now it contains {} elements.".format(
len(wh))
self._model.size = wh

@padding_value.setter
def padding_value(self, value):
assert isinstance(
value,
list), "The value to set `padding_value` must be type of list."
self._model.padding_value = value

@is_decode_exported.setter
def is_decode_exported(self, value):
assert isinstance(
value,
bool), "The value to set `is_decode_exported` must be type of bool."
self._model.max_wh = value

@downsample_strides.setter
def downsample_strides(self, value):
assert isinstance(
value,
list), "The value to set `downsample_strides` must be type of list."
self._model.downsample_strides = value

@max_wh.setter
def max_wh(self, value):
assert isinstance(
value, float), "The value to set `max_wh` must be type of float."
self._model.max_wh = value
41 changes: 41 additions & 0 deletions fastdeploy/vision/megvii/megvii_pybind.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "fastdeploy/pybind/main.h"

namespace fastdeploy {
void BindMegvii(pybind11::module& m) {
auto megvii_module =
m.def_submodule("megvii", "https://github.com/megvii/YOLOX");
pybind11::class_<vision::megvii::YOLOX, FastDeployModel>(
megvii_module, "YOLOX")
.def(pybind11::init<std::string, std::string, RuntimeOption, Frontend>())
.def("predict",
[](vision::megvii::YOLOX& self, pybind11::array& data,
float conf_threshold, float nms_iou_threshold) {
auto mat = PyArrayToCvMat(data);
vision::DetectionResult res;
self.Predict(&mat, &res, conf_threshold, nms_iou_threshold);
return res;
})
.def_readwrite("size", &vision::megvii::YOLOX::size)
.def_readwrite("padding_value",
&vision::megvii::YOLOX::padding_value)
.def_readwrite("is_decode_exported",
&vision::megvii::YOLOX::is_decode_exported)
.def_readwrite("downsample_strides",
&vision::megvii::YOLOX::downsample_strides)
.def_readwrite("max_wh", &vision::megvii::YOLOX::max_wh);
}
} // namespace fastdeploy
Loading