Skip to content
Merged
Show file tree
Hide file tree
Changes from 26 commits
Commits
Show all changes
32 commits
Select commit Hold shift + click to select a range
90c9d38
11-02/14:35
Zheng-Bicheng Nov 2, 2022
030e4a7
11-03/17:25
Zheng-Bicheng Nov 3, 2022
6ce7bfc
11-03/17:25
Zheng-Bicheng Nov 3, 2022
95a7949
Merge branch 'develop' into picodet
Zheng-Bicheng Nov 3, 2022
5e6248a
11-03/21:48
Zheng-Bicheng Nov 3, 2022
290de55
Merge remote-tracking branch 'origin/picodet' into picodet
Zheng-Bicheng Nov 3, 2022
3960cfe
11-04/01:00
Zheng-Bicheng Nov 3, 2022
7391f50
Merge branch 'develop' into new_picodet
Zheng-Bicheng Nov 4, 2022
b27995e
11-04/13:13
Zheng-Bicheng Nov 4, 2022
a40a79e
11-04/14:03
Zheng-Bicheng Nov 4, 2022
bdeb075
11-04/14:21
Zheng-Bicheng Nov 4, 2022
6614d27
11-04/14:21
Zheng-Bicheng Nov 4, 2022
89ccf9f
11-04/14:21
Zheng-Bicheng Nov 4, 2022
34fdc3f
* 更新配置文件
Zheng-Bicheng Nov 4, 2022
aeef7a0
* 修正配置文件
Zheng-Bicheng Nov 4, 2022
8ec3199
* 添加缺失的python文件
Zheng-Bicheng Nov 4, 2022
3510d48
* 修正文档
Zheng-Bicheng Nov 4, 2022
1a2e3d2
Merge branch 'develop' into new_picodet
Zheng-Bicheng Nov 4, 2022
8144106
* 修正代码格式问题0
Zheng-Bicheng Nov 4, 2022
0db9759
Merge remote-tracking branch 'origin/new_picodet' into new_picodet
Zheng-Bicheng Nov 4, 2022
31824ad
Merge branch 'develop' into new_picodet
Zheng-Bicheng Nov 4, 2022
597f013
Merge branch 'develop' into new_picodet
Zheng-Bicheng Nov 5, 2022
96af67c
* 按照要求修改
Zheng-Bicheng Nov 5, 2022
a5b5655
Merge remote-tracking branch 'origin/new_picodet' into new_picodet
Zheng-Bicheng Nov 5, 2022
de91482
* 按照要求修改
Zheng-Bicheng Nov 5, 2022
3fb9204
* 按照要求修改
Zheng-Bicheng Nov 5, 2022
1cbe0de
* 按照要求修改
Zheng-Bicheng Nov 5, 2022
65581e5
Merge branch 'develop' into new_picodet
Zheng-Bicheng Nov 5, 2022
2a32ba3
* 按照要求修改
Zheng-Bicheng Nov 5, 2022
8793741
Merge remote-tracking branch 'origin/new_picodet' into new_picodet
Zheng-Bicheng Nov 5, 2022
bb2a606
test
Zheng-Bicheng Nov 6, 2022
b3043b1
Merge branch 'develop' into new_picodet
Zheng-Bicheng Nov 6, 2022
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions docs/cn/faq/rknpu2/export.md
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,8 @@ model_path: ./portrait_pp_humansegv2_lite_256x144_pretrained.onnx
output_folder: ./
target_platform: RK3588
normalize:
mean: [0.5,0.5,0.5]
std: [0.5,0.5,0.5]
mean: [[0.5,0.5,0.5]]
std: [[0.5,0.5,0.5]]
outputs: None
```

Expand All @@ -45,4 +45,4 @@ python tools/export.py --config_path=./config.yaml

## 模型导出要注意的事项

* 请不要导出带softmax和argmax的模型,这两个算子存在bug,请在外部进行运算
* 请不要导出带softmax和argmax的模型,这两个算子存在bug,请在外部进行运算
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
# PaddleDetection RKNPU2部署示例

## 支持模型列表

目前FastDeploy支持如下模型的部署
- [PicoDet系列模型](https://github.com/PaddlePaddle/PaddleDetection/tree/release/2.4/configs/picodet)

## 准备PaddleSeg部署模型以及转换模型
RKNPU部署模型前需要将Paddle模型转换成RKNN模型,具体步骤如下:
* Paddle动态图模型转换为ONNX模型,请参考[PaddleDetection导出模型](https://github.com/PaddlePaddle/PaddleDetection/blob/release/2.4/deploy/EXPORT_MODEL.md)
,注意在转换时请设置**export.nms=True**.
* ONNX模型转换RKNN模型的过程,请参考[转换文档](../../../../../docs/cn/faq/rknpu2/export.md)进行转换。


## 模型转换example
下面以Picodet-npu为例子,教大家如何转换PPSeg模型到RKNN模型。
```bash
## 下载Paddle静态图模型并解压
wget https://bj.bcebos.com/fastdeploy/models/rknn2/picodet_s_416_coco_npu.zip
unzip -qo picodet_s_416_coco_npu.zip

# 静态图转ONNX模型,注意,这里的save_file请和压缩包名对齐
paddle2onnx --model_dir picodet_s_416_coco_npu \
--model_filename model.pdmodel \
--params_filename model.pdiparams \
--save_file picodet_s_416_coco_npu/picodet_s_416_coco_npu.onnx \
--enable_dev_version True

python -m paddle2onnx.optimize --input_model picodet_s_416_coco_npu/picodet_s_416_coco_npu.onnx \
--output_model picodet_s_416_coco_npu/picodet_s_416_coco_npu.onnx \
--input_shape_dict "{'image':[1,3,416,416]}"
# ONNX模型转RKNN模型
# 转换模型,模型将生成在picodet_s_320_coco_lcnet_non_postprocess目录下
python tools/rknpu2/export.py --config_path tools/rknpu2/config/RK3588/picodet_s_416_coco_npu.yaml
```

- [Python部署](./python)
- [视觉模型预测结果](../../../../../docs/api/vision_results/)
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
CMAKE_MINIMUM_REQUIRED(VERSION 3.10)
project(rknpu2_test)

set(CMAKE_CXX_STANDARD 14)

# 指定下载解压后的fastdeploy库路径
set(FASTDEPLOY_INSTALL_DIR "thirdpartys/fastdeploy-0.0.3")

include(${FASTDEPLOY_INSTALL_DIR}/FastDeployConfig.cmake)
include_directories(${FastDeploy_INCLUDE_DIRS})

add_executable(infer_picodet infer_picodet.cc)
target_link_libraries(infer_picodet ${FastDeploy_LIBS})



set(CMAKE_INSTALL_PREFIX ${CMAKE_SOURCE_DIR}/build/install)

install(TARGETS infer_picodet DESTINATION ./)

install(DIRECTORY model DESTINATION ./)
install(DIRECTORY images DESTINATION ./)

file(GLOB FASTDEPLOY_LIBS ${FASTDEPLOY_INSTALL_DIR}/lib/*)
message("${FASTDEPLOY_LIBS}")
install(PROGRAMS ${FASTDEPLOY_LIBS} DESTINATION lib)

file(GLOB ONNXRUNTIME_LIBS ${FASTDEPLOY_INSTALL_DIR}/third_libs/install/onnxruntime/lib/*)
install(PROGRAMS ${ONNXRUNTIME_LIBS} DESTINATION lib)

install(DIRECTORY ${FASTDEPLOY_INSTALL_DIR}/third_libs/install/opencv/lib DESTINATION ./)

file(GLOB PADDLETOONNX_LIBS ${FASTDEPLOY_INSTALL_DIR}/third_libs/install/paddle2onnx/lib/*)
install(PROGRAMS ${PADDLETOONNX_LIBS} DESTINATION lib)

file(GLOB RKNPU2_LIBS ${FASTDEPLOY_INSTALL_DIR}/third_libs/install/rknpu2_runtime/RK3588/lib/*)
install(PROGRAMS ${RKNPU2_LIBS} DESTINATION lib)
Original file line number Diff line number Diff line change
@@ -0,0 +1,71 @@
# PaddleDetection C++部署示例

本目录下提供`infer_xxxxx.cc`快速完成PPDetection模型在Rockchip板子上上通过二代NPU加速部署的示例。

在部署前,需确认以下两个步骤:

1. 软硬件环境满足要求
2. 根据开发环境,下载预编译部署库或者从头编译FastDeploy仓库

以上步骤请参考[RK2代NPU部署库编译](../../../../../../docs/cn/build_and_install/rknpu2.md)实现

## 生成基本目录文件

该例程由以下几个部分组成
```text
.
├── CMakeLists.txt
├── build # 编译文件夹
├── image # 存放图片的文件夹
├── infer_cpu_npu.cc
├── infer_cpu_npu.h
├── main.cc
├── model # 存放模型文件的文件夹
└── thirdpartys # 存放sdk的文件夹
```

首先需要先生成目录结构
```bash
mkdir build
mkdir images
mkdir model
mkdir thirdpartys
```

## 编译

### 编译并拷贝SDK到thirdpartys文件夹

请参考[RK2代NPU部署库编译](../../../../../../docs/cn/build_and_install/rknpu2.md)仓库编译SDK,编译完成后,将在build目录下生成
fastdeploy-0.0.3目录,请移动它至thirdpartys目录下.

### 拷贝模型文件,以及配置文件至model文件夹
在Paddle动态图模型 -> Paddle静态图模型 -> ONNX模型的过程中,将生成ONNX文件以及对应的yaml配置文件,请将配置文件存放到model文件夹内。
转换为RKNN后的模型文件也需要拷贝至model。

### 准备测试图片至image文件夹
```bash
wget https://gitee.com/paddlepaddle/PaddleDetection/raw/release/2.4/demo/000000014439.jpg
cp 000000014439.jpg ./images
```

### 编译example

```bash
cd build
cmake ..
make -j8
make install
```

## 运行例程

```bash
cd ./build/install
./rknpu_test
```


- [模型介绍](../../)
- [Python部署](../python)
- [视觉模型预测结果](../../../../../../docs/api/vision_results/)
Original file line number Diff line number Diff line change
@@ -0,0 +1,86 @@
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <iostream>
#include <string>
#include "fastdeploy/vision.h"

void InferPicodet(const std::string& device = "cpu");

int main() {
InferPicodet("npu");
return 0;
}

fastdeploy::RuntimeOption GetOption(const std::string& device) {
auto option = fastdeploy::RuntimeOption();
if (device == "npu") {
option.UseRKNPU2();
} else {
option.UseCpu();
}
return option;
}

fastdeploy::ModelFormat GetFormat(const std::string& device) {
auto format = fastdeploy::ModelFormat::ONNX;
if (device == "npu") {
format = fastdeploy::ModelFormat::RKNN;
} else {
format = fastdeploy::ModelFormat::ONNX;
}
return format;
}

std::string GetModelPath(std::string& model_path, const std::string& device) {
if (device == "npu") {
model_path += "rknn";
} else {
model_path += "onnx";
}
return model_path;
}

void InferPicodet(const std::string &device) {
std::string model_file = "./model/picodet_s_416_coco_npu/picodet_s_416_coco_npu_rk3588.";
std::string params_file;
std::string config_file = "./model/picodet_s_416_coco_npu/infer_cfg.yml";

fastdeploy::RuntimeOption option = GetOption(device);
fastdeploy::ModelFormat format = GetFormat(device);
model_file = GetModelPath(model_file, device);
auto model = fastdeploy::vision::detection::RKPicoDet(
model_file, params_file, config_file,option,format);

if (!model.Initialized()) {
std::cerr << "Failed to initialize." << std::endl;
return;
}
auto image_file = "./images/000000014439.jpg";
auto im = cv::imread(image_file);

fastdeploy::vision::DetectionResult res;
clock_t start = clock();
if (!model.Predict(&im, &res)) {
std::cerr << "Failed to predict." << std::endl;
return;
}
clock_t end = clock();
auto dur = static_cast<double>(end - start);
printf("picodet_npu use time:%f\n", (dur / CLOCKS_PER_SEC));

std::cout << res.Str() << std::endl;
auto vis_im = fastdeploy::vision::VisDetection(im, res,0.5);
cv::imwrite("picodet_npu_result.jpg", vis_im);
std::cout << "Visualized result saved in ./picodet_npu_result.jpg" << std::endl;
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
# PaddleDetection Python部署示例

在部署前,需确认以下两个步骤

- 1. 软硬件环境满足要求,参考[FastDeploy环境要求](../../../../../../docs/cn/build_and_install/rknpu2.md)

本目录下提供`infer.py`快速完成Picodet在RKNPU上部署的示例。执行如下脚本即可完成

```bash
# 下载部署示例代码
git clone https://github.com/PaddlePaddle/FastDeploy.git
cd FastDeploy/examples/vision/segmentation/paddleseg/python

# 下载图片
wget https://gitee.com/paddlepaddle/PaddleDetection/raw/release/2.4/demo/000000014439.jpg

# copy model
cp -r ./picodet_s_416_coco_npu /path/to/FastDeploy/examples/vision/detection/rknpu2detection/paddledetection/python

# 推理
python3 infer.py --model_file ./picodet_s_416_coco_npu/picodet_s_416_coco_npu_3588.rknn \
--config_file ./picodet_s_416_coco_npu/infer_cfg.yml \
--image 000000014439.jpg
```


## 注意事项
RKNPU上对模型的输入要求是使用NHWC格式,且图片归一化操作会在转RKNN模型时,内嵌到模型中,因此我们在使用FastDeploy部署时,
需要先调用DisableNormalizePermute(C++)或`disable_normalize_permute(Python),在预处理阶段禁用归一化以及数据格式的转换。
## 其它文档

- [PaddleSeg 模型介绍](..)
- [PaddleSeg C++部署](../cpp)
- [模型预测结果说明](../../../../../../docs/api/vision_results/)
- [转换PPSeg RKNN模型文档](../README.md)
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import fastdeploy as fd
import cv2
import os


def parse_arguments():
import argparse
import ast
parser = argparse.ArgumentParser()
parser.add_argument(
"--model_file", required=True, help="Path of PaddleSeg model.")
parser.add_argument(
"--config_file", required=True, help="Path of PaddleSeg config.")
parser.add_argument(
"--image", type=str, required=True, help="Path of test image file.")
return parser.parse_args()


def build_option(args):
option = fd.RuntimeOption()
option.use_rknpu2()
return option


args = parse_arguments()

# 配置runtime,加载模型
runtime_option = build_option(args)
model_file = args.model_file
params_file = ""
config_file = args.config_file
model = fd.vision.detection.RKPicoDet(
model_file,
params_file,
config_file,
runtime_option=runtime_option,
model_format=fd.ModelFormat.RKNN)

# 预测图片分割结果
im = cv2.imread(args.image)
result = model.predict(im.copy())
print(result)

# 可视化结果
vis_im = fd.vision.vis_detection(im, result, score_threshold=0.5)
cv2.imwrite("visualized_result.jpg", vis_im)
print("Visualized result save in ./visualized_result.jpg")
1 change: 1 addition & 0 deletions fastdeploy/vision.h
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
#include "fastdeploy/vision/detection/contrib/yolov7end2end_trt.h"
#include "fastdeploy/vision/detection/contrib/yolox.h"
#include "fastdeploy/vision/detection/ppdet/model.h"
#include "fastdeploy/vision/detection/rknpu2/model.h"
#include "fastdeploy/vision/facedet/contrib/retinaface.h"
#include "fastdeploy/vision/facedet/contrib/scrfd.h"
#include "fastdeploy/vision/facedet/contrib/ultraface.h"
Expand Down
3 changes: 3 additions & 0 deletions fastdeploy/vision/detection/detection_pybind.cc
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,8 @@ void BindNanoDetPlus(pybind11::module& m);
void BindPPDet(pybind11::module& m);
void BindYOLOv7End2EndTRT(pybind11::module& m);
void BindYOLOv7End2EndORT(pybind11::module& m);
void BindRKDet(pybind11::module& m);


void BindDetection(pybind11::module& m) {
auto detection_module =
Expand All @@ -42,5 +44,6 @@ void BindDetection(pybind11::module& m) {
BindNanoDetPlus(detection_module);
BindYOLOv7End2EndTRT(detection_module);
BindYOLOv7End2EndORT(detection_module);
BindRKDet(detection_module);
}
} // namespace fastdeploy
16 changes: 16 additions & 0 deletions fastdeploy/vision/detection/rknpu2/model.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once
#include "fastdeploy/vision/detection/rknpu2/rkpicodet.h"
Loading