
Commit b7a1dd0

Merge pull request PaddlePaddle#53 from LokeZhou/APPflow
add appflow;
2 parents c0b8f34 + f73c8fe

18 files changed: +2027 −609 lines

paddlevlp/__init__.py

Lines changed: 2 additions & 1 deletion

@@ -17,4 +17,5 @@
 from .models import *
 from .optimization import *
 from .processors import *
-from .tests import *
+from .tests import *
+from .appflow import *

paddlevlp/appflow/__init__.py

Lines changed: 15 additions & 0 deletions
@@ -0,0 +1,15 @@
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from .appflow import Appflow

paddlevlp/appflow/appflow.py

Lines changed: 112 additions & 0 deletions
@@ -0,0 +1,112 @@
# coding:utf-8
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import paddle

from paddlevlp.utils.tools import get_env_device
from .configuration import APPLICATIONS


class Appflow(object):
    """
    Args:
        app (str): The app name for the Appflow; the task classes are looked up by this name.
        models (list[str], optional): The model names used in the task. If None, the default model is used.
        mode (str, optional): Selects the mode of the task. If None, the default mode is used.
        device_id (int, optional): The device id for gpu, xpu and other devices. The default value is 0.
        kwargs (dict, optional): Additional keyword arguments passed along to the specific task.
    """

    def __init__(self,
                 app,
                 models=None,
                 mode=None,
                 device_id=0,
                 from_hf_hub=False,
                 **kwargs):
        assert app in APPLICATIONS, f"The task name:{app} is not in the Appflow list, please check your task name."
        self.app = app
        # Set the device for the task
        device = get_env_device()
        if device == "cpu" or device_id == -1:
            paddle.set_device("cpu")
        else:
            paddle.set_device(device + ":" + str(device_id))

        tag = "models"
        ind_tag = "model"
        self.models = models
        if isinstance(self.models, list) and len(self.models) > 0:
            for model in self.models:
                assert model in set(APPLICATIONS[app][tag].keys(
                )), f"The {tag} name: {model} is not in task:[{app}]"
        else:
            self.models = [APPLICATIONS[app]["default"][ind_tag]]

        self.task_instances = []
        for model in self.models:
            if "task_priority_path" in APPLICATIONS[self.app][tag][model]:
                priority_path = APPLICATIONS[self.app][tag][model][
                    "task_priority_path"]
            else:
                priority_path = None

            # Merge the task config into kwargs
            config_kwargs = APPLICATIONS[self.app][tag][model]
            kwargs["device_id"] = device_id
            kwargs.update(config_kwargs)
            task_class = APPLICATIONS[self.app][tag][model]["task_class"]
            self.task_instances.append(
                task_class(
                    model=model,
                    task=self.app,
                    priority_path=priority_path,
                    from_hf_hub=from_hf_hub,
                    **kwargs))

        app_list = APPLICATIONS.keys()
        Appflow.app_list = app_list

    def __call__(self, **inputs):
        """
        The main work function in the appflow.
        """
        results = inputs
        for task_instance in self.task_instances:
            # The outputs of one task become the inputs of the next
            results = task_instance(results)
        return results

    def help(self):
        """
        Return the task usage message.
        """
        return self.task_instances[0].help()

    def task_path(self):
        """
        Return the path of the current task.
        """
        return self.task_instances[0]._task_path

    @staticmethod
    def tasks():
        """
        Return the available task list.
        """
        task_list = list(APPLICATIONS.keys())
        return task_list
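
For context, a minimal usage sketch of the new entry point. This is a hedged example, not part of the PR: the registered app names and expected input keys live in configuration.APPLICATIONS, which this diff does not show; 'openset_det_sam' and the model names are taken from apptask.py below, and the image/prompt keyword arguments are assumptions.

# Hedged usage sketch; app/model names come from apptask.py in this PR,
# the input keys are illustrative only.
from paddlevlp.appflow import Appflow

task = Appflow(app="openset_det_sam",
               models=["GroundingDino/groundingdino-swint-ogc",
                       "Sam/SamVitH-1024"])
# Each task instance consumes the previous one's results dict.
result = task(image="path/to/image.jpg", prompt="a dog")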

paddlevlp/appflow/apptask.py

Lines changed: 252 additions & 0 deletions
@@ -0,0 +1,252 @@
# coding:utf-8
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import math
from multiprocessing import cpu_count

import paddle

from paddlevlp.utils.env import PPMIX_HOME
from paddlevlp.utils.log import logger
from paddlenlp.taskflow.utils import dygraph_mode_guard


class AppTask(object):
    """
    The meta class of tasks in Appflow. It declares the abstract methods
    that subclasses need to implement.
    Args:
        task (string): The name of the task.
        model (string): The model name in the task.
        kwargs (dict, optional): Additional keyword arguments passed along to the specific task.
    """

    def __init__(self, model, task, priority_path=None, **kwargs):

        self.model = model
        self.task = task
        self.kwargs = kwargs
        self._priority_path = priority_path
        self.is_static_model = kwargs.get("is_static_model", False)

        self._home_path = self.kwargs[
            "home_path"] if "home_path" in self.kwargs else PPMIX_HOME

        if "task_path" in self.kwargs:
            self._task_path = self.kwargs["task_path"]
            self._model_dir = self._task_path
        elif self._priority_path:
            self._task_path = os.path.join(self._home_path, "models",
                                           self._priority_path)
            self._model_dir = os.path.join(self._home_path, "models")
        else:
            self._task_path = os.path.join(self._home_path, "models",
                                           self.model)
            self._model_dir = os.path.join(self._home_path, "models")

        self._infer_precision = self.kwargs[
            "precision"] if "precision" in self.kwargs else "fp32"
        # Default to use Paddle Inference
        self._predictor_type = "paddle-inference"
        self._num_threads = self.kwargs[
            "num_threads"] if "num_threads" in self.kwargs else math.ceil(
                cpu_count() / 2)

    def _construct_tokenizer(self, model):
        """
        Construct the tokenizer for the predictor.
        """

    def _construct_input_spec(self):
        """
        Construct the input spec for the predictor.
        """

    def _get_static_model_name(self):
        names = []
        for file_name in os.listdir(self._task_path):
            if ".pdparams" in file_name:
                names.append(file_name[:-9])
        if len(names) == 0:
            raise IOError(f"{self._task_path} should include '.pdparams' file.")
        if len(names) > 1:
            logger.warning(
                f"{self._task_path} includes more than one '.pdparams' file.")
        return names[0]

    def _convert_dygraph_to_static(self):
        """
        Convert the dygraph model to a static model.
        """
        assert (
            self._model is not None
        ), "The dygraph model must be created before converting the dygraph model to static model."
        assert (
            self._input_spec is not None
        ), "The input spec must be created before converting the dygraph model to static model."
        logger.info("Converting to the inference model costs a little time.")
        static_model = paddle.jit.to_static(
            self._model, input_spec=self._input_spec)

        paddle.jit.save(static_model, self.inference_model_path)
        logger.info("The inference model is saved in the path: {}".format(
            self.inference_model_path))

    def _prepare_static_mode(self):
        """
        Construct the input data and predictor in PaddlePaddle static mode.
        """
        if paddle.get_device() == "cpu":
            self._config.disable_gpu()
            self._config.enable_mkldnn()
            if self._infer_precision == "int8":
                # EnableMKLDNN() only works when IR optimization is enabled.
                self._config.switch_ir_optim(True)
                self._config.enable_mkldnn_int8()
                logger.info(">>> [InferBackend] INT8 inference on CPU ...")
        elif paddle.get_device().split(":", 1)[0] == "npu":
            self._config.disable_gpu()
            self._config.enable_custom_device("npu", self.kwargs["device_id"])
        else:
            precision_map = {
                'trt_int8': paddle.inference.PrecisionType.Int8,
                'trt_fp32': paddle.inference.PrecisionType.Float32,
                'trt_fp16': paddle.inference.PrecisionType.Half
            }
            self._config.enable_use_gpu(5000, self.kwargs["device_id"])
            self._config.set_cpu_math_library_num_threads(self._num_threads)
            self._config.switch_use_feed_fetch_ops(False)
            self._config.disable_glog_info()
            self._config.switch_ir_optim(True)
            self._config.enable_memory_optim(True)
            if self._infer_precision in precision_map.keys():
                self._config.enable_tensorrt_engine(
                    workspace_size=(1 << 40),
                    max_batch_size=0,
                    min_subgraph_size=30,
                    precision_mode=precision_map[self._infer_precision],
                    use_static=True,
                    use_calib_mode=False)

                if not os.path.exists(self._tuned_trt_shape_file):
                    self._config.collect_shape_range_info(
                        self._tuned_trt_shape_file)
                else:
                    logger.info(f'Use dynamic shape file: '
                                f'{self._tuned_trt_shape_file} for TRT...')
                    self._config.enable_tuned_tensorrt_dynamic_shape(
                        self._tuned_trt_shape_file, True)

                if self.task == 'openset_det_sam':
                    self._config.delete_pass("add_support_int8_pass")

                if self.model == 'GroundingDino/groundingdino-swint-ogc':
                    self._config.exp_disable_tensorrt_ops([
                        "pad3d", "set_value", "reduce_all", "cumsum_8.tmp_0",
                        "linear_296.tmp_1"
                    ])

                if self.model == 'Sam/SamVitH-1024' or self.model == 'Sam/SamVitH-512':
                    self._config.delete_pass("shuffle_channel_detect_pass")
                    self._config.delete_pass("trt_skip_layernorm_fuse_pass")
                    self._config.delete_pass("preln_residual_bias_fuse_pass")
                    self._config.exp_disable_tensorrt_ops([
                        "concat_1.tmp_0", "set_value", "empty_0.tmp_0",
                        "concat_55.tmp_0"
                    ])

        self.predictor = paddle.inference.create_predictor(self._config)
        self.input_names = [name for name in self.predictor.get_input_names()]
        self.input_handles = [
            self.predictor.get_input_handle(name)
            for name in self.predictor.get_input_names()
        ]
        self.output_handle = [
            self.predictor.get_output_handle(name)
            for name in self.predictor.get_output_names()
        ]

    def _get_inference_model(self):
        """
        Return the inference program, inputs and outputs in static mode.
        """

        # When the user-provided model path is already a static model, skip the to_static conversion
        if self.is_static_model:
            self.inference_model_path = os.path.join(self._task_path,
                                                     self._static_model_name)
            if not os.path.exists(self.inference_model_path +
                                  ".pdmodel") or not os.path.exists(
                                      self.inference_model_path + ".pdiparams"):
                raise IOError(
                    f"{self._task_path} should include {self._static_model_name + '.pdmodel'} and {self._static_model_name + '.pdiparams'} while is_static_model is True"
                )
            if self.paddle_quantize_model(self.inference_model_path):
                self._infer_precision = "int8"
                self._predictor_type = "paddle-inference"

        else:
            # Since 'self._task_path' is used to load the HF Hub path when 'from_hf_hub=True', we construct the static model path in a different way
            self.inference_model_path = os.path.join(self._task_path,
                                                     self._static_model_name)
            self._tuned_trt_shape_file = self.inference_model_path + "_shape.txt"
            if not os.path.exists(self.inference_model_path + ".pdiparams"):
                with dygraph_mode_guard():
                    self._construct_model(self.model)
                    self._construct_input_spec()
                    self._convert_dygraph_to_static()

        self._static_model_file = self.inference_model_path + ".pdmodel"
        self._static_params_file = self.inference_model_path + ".pdiparams"

        if paddle.get_device().split(
                ":", 1)[0] == "npu" and self._infer_precision == "fp16":
            # Transform the fp32 model to a fp16 model
            self._static_fp16_model_file = self.inference_model_path + "-fp16.pdmodel"
            self._static_fp16_params_file = self.inference_model_path + "-fp16.pdiparams"
            if not os.path.exists(
                    self._static_fp16_model_file) and not os.path.exists(
                        self._static_fp16_params_file):
                logger.info(
                    "Converting the inference model from fp32 to fp16.")
                paddle.inference.convert_to_mixed_precision(
                    os.path.join(self._static_model_file),
                    os.path.join(self._static_params_file),
                    os.path.join(self._static_fp16_model_file),
                    os.path.join(self._static_fp16_params_file),
                    backend=paddle.inference.PlaceType.CUSTOM,
                    mixed_precision=paddle.inference.PrecisionType.Half,
                    # Here, npu sigmoid will lead to OOM and cpu sigmoid doesn't support fp16.
                    # So, we add sigmoid to the black list temporarily.
                    black_list={"sigmoid"}, )
                logger.info(
                    "The fp16 precision inference model is saved in the path: {}".
                    format(self._static_fp16_model_file))
            self._static_model_file = self._static_fp16_model_file
            self._static_params_file = self._static_fp16_params_file

        if self._predictor_type == "paddle-inference":
            self._config = paddle.inference.Config(self._static_model_file,
                                                   self._static_params_file)
            self._prepare_static_mode()
        else:
            self._prepare_onnx_mode()

    def __call__(self, *args, **kwargs):
        inputs = self._preprocess(*args)
        outputs = self._run_model(inputs, **kwargs)
        results = self._postprocess(outputs)
        return results
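
To make the control flow above concrete, here is a hedged sketch of how a concrete task might plug into AppTask: _construct_model and _construct_input_spec feed _convert_dygraph_to_static, while _preprocess, _run_model and _postprocess are the three stages chained by __call__. The class name and stand-in model are hypothetical and not part of this PR.

# Hypothetical subclass illustrating the hooks AppTask calls; the real
# tasks in this PR implement these methods with actual models.
import paddle
from paddlevlp.appflow.apptask import AppTask

class ToyTask(AppTask):

    def _construct_model(self, model):
        # Stand-in dygraph model; a real task loads pretrained weights here.
        self._model = paddle.nn.Linear(4, 2)

    def _construct_input_spec(self):
        # Consumed by paddle.jit.to_static in _convert_dygraph_to_static.
        self._input_spec = [
            paddle.static.InputSpec(shape=[None, 4], dtype="float32", name="x")
        ]

    def _preprocess(self, *args):
        # Turn raw user inputs into model-ready tensors.
        return args[0]

    def _run_model(self, inputs, **kwargs):
        # A real task feeds self.input_handles and runs self.predictor here.
        return inputs

    def _postprocess(self, outputs):
        # Map raw model outputs back to user-facing results.
        return outputs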
