Skip to content

Commit e2383b9

Browse files
Add tutorial for sensitivity. (PaddlePaddle#15)
1 parent bdac950 commit e2383b9

1 file changed

Lines changed: 81 additions & 0 deletions

File tree

Lines changed: 81 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,81 @@
1+
该示例介绍如何分析卷积网络中各卷积层的敏感度,以及如何根据计算出的敏感度选择一组合适的剪裁率。
2+
该示例默认会自动下载并使用MNIST数据。支持以下模型:
3+
4+
- MobileNetV1
5+
- MobileNetV2
6+
- ResNet50
7+
8+
## 1. 接口介绍
9+
10+
该示例涉及以下接口:
11+
12+
- [paddleslim.prune.sensitivity](https://paddlepaddle.github.io/PaddleSlim/api/prune_api/#sensitivity)
13+
- [paddleslim.prune.merge_sensitive](https://paddlepaddle.github.io/PaddleSlim/api/prune_api/#merge_sensitive)
14+
- [paddleslim.prune.get_ratios_by_loss](https://paddlepaddle.github.io/PaddleSlim/api/prune_api/#get_ratios_by_losssensitivities-loss)
15+
16+
## 2. 运行示例
17+
18+
19+
在路径`PaddleSlim/demo/sensitive`下执行以下代码运行示例:
20+
21+
```
22+
export CUDA_VISIBLE_DEVICES=0
23+
python train.py --model "MobileNetV1"
24+
```
25+
26+
通过`python train.py --help`查看更多选项。
27+
28+
## 3. 重要步骤说明
29+
30+
### 3.1 计算敏感度
31+
32+
计算敏感度之前,用户需要搭建好用于测试的网络,以及实现评估模型精度的回调函数。
33+
34+
调用`paddleslim.prune.sensitivity`接口计算敏感度。敏感度信息会追加到`sensitivities_file`选项所指定的文件中,如果需要重新计算敏感度,需要先删除`sensitivities_file`文件。
35+
36+
如果模型评估速度较慢,可以通过多进程的方式加速敏感度计算过程。比如在进程1中设置`pruned_ratios=[0.1, 0.2, 0.3, 0.4]`,并将敏感度信息存放在文件`sensitivities_0.data`中,然后在进程2中设置`pruned_ratios=[0.5, 0.6, 0.7]`,并将敏感度信息存储在文件`sensitivities_1.data`中。这样每个进程只会计算指定剪切率下的敏感度信息。多进程可以运行在单机多卡,或多机多卡。
37+
38+
代码如下:
39+
40+
```
41+
# 进程1
42+
sensitivity(
43+
val_program,
44+
place,
45+
params,
46+
test,
47+
sensitivities_file="sensitivities_0.data",
48+
pruned_ratios=[0.1, 0.2, 0.3, 0.4])
49+
```
50+
51+
```
52+
# 进程2
53+
sensitivity(
54+
val_program,
55+
place,
56+
params,
57+
test,
58+
sensitivities_file="sensitivities_1.data",
59+
pruned_ratios=[0.5, 0.6, 0.7])
60+
```
61+
62+
63+
### 3.2 合并敏感度
64+
65+
如果用户通过上一节多进程的方式生成了多个存储敏感度信息的文件,可以通过`paddleslim.prune.merge_sensitive`将其合并,合并后的敏感度信息存储在一个`dict`中。代码如下:
66+
67+
```
68+
sens = merge_sensitive(["./sensitivities_0.data", "./sensitivities_1.data"])
69+
```
70+
71+
### 3.3 计算剪裁率
72+
73+
调用`paddleslim.prune.get_ratios_by_loss`接口计算一组剪裁率。
74+
75+
```
76+
ratios = get_ratios_by_loss(sens, 0.01)
77+
```
78+
79+
其中,`0.01`为一个阈值,对于任意卷积层,其剪裁率为使精度损失低于阈值`0.01`的最大剪裁率。
80+
81+
用户在计算出一组剪裁率之后可以通过接口`paddleslim.prune.Pruner`剪裁网络,并用接口`paddleslim.analysis.flops`计算`FLOPs`。如果`FLOPs`不满足要求,调整阈值重新计算出一组剪裁率。

0 commit comments

Comments
 (0)