[slim] Refine framework of slim and add filter pruning strategy #16226
wanghaoshuang merged 26 commits into PaddlePaddle:develop from
Conversation
wanghaoshuang
commented
Mar 15, 2019
- Add the framework of paddle slim
- Add filter pruning strategy
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
Didn't we say that pruning and distillation no longer use graph? Is graph_wrapper still needed?

To make it easier to switch to IrGraph later.
@@ -13,9 +13,10 @@
# limitations under the License.
Quantization Strategy is not included?

Quantization Strategy will be submitted in a separate PR.
from collections import OrderedDict
from ..prune import *
from .compress_pass import *
from ..quantization import *
I don't see where the quantization is?

quantization is a module that already exists on the develop branch: https://github.com/PaddlePaddle/Paddle/tree/develop/python/paddle/fluid/contrib/slim/quantization
cached_id(int): The id of dataset sampled. Evaluations with same cached_id use the same sampled dataset. default: 0.
"""
np.random.seed(cached_id)
cache_path = cache_path + "/" + str(cached_id)
Suggest using os.path.join(dir0, dir1, ..., file) so that the "/" separator is filled in automatically.
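To illustrate the reviewer's suggestion, a minimal sketch of replacing manual "/" concatenation with os.path.join (the `cache_path`/`cached_id` names follow the snippet above):

```python
import os

cache_path = "cache"
cached_id = 0

# Manual concatenation breaks when cache_path already ends with "/";
# os.path.join inserts exactly one separator either way.
file_path = os.path.join(cache_path, str(cached_id))
also_ok = os.path.join("cache/", str(cached_id))
assert file_path == also_ok
```

os.path.join also picks the platform-appropriate separator, which the hand-written "/" does not.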
Load the context from file.
"""
with open(file_name) as context_file:
    data = pickle.load(context_file)
Has pickle been tested under Python 3? Last time we submitted the video code, pickle.load raised an error under Python 3, and we changed it to the following form:

if python_ver < (3, 0):
    data_loaded = pickle.load(open(pickle_path, 'rb'))
else:
    data_loaded = pickle.load(open(pickle_path, 'rb'), encoding='bytes')
Fix log and comments. test=develop
batch += 1
yield data

return s_reader
Why are these readers needed?

To speed up the data feeder, although it does not help in every case. Agreed, this logic does not really belong inside the compression tool.
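A minimal sketch of the kind of caching reader discussed here, assuming a Paddle-style reader (a zero-argument callable returning a sample generator); the name `cached_reader` is illustrative, not the PR's exact implementation:

```python
def cached_reader(reader, cache):
    """Wrap a sample reader: fill `cache` on the first full pass,
    then replay the cached samples on later passes, skipping the
    original (possibly slow) data source."""
    def s_reader():
        if cache:
            # Later epochs: serve samples from memory.
            for data in cache:
                yield data
        else:
            # First epoch: read from the source and remember each sample.
            for data in reader():
                cache.append(data)
                yield data
    return s_reader

cache = []
r = cached_reader(lambda: iter([1, 2, 3]), cache)
assert list(r()) == [1, 2, 3]  # first pass reads the source
assert list(r()) == [1, 2, 3]  # second pass replays the cache
```

Note the caveat from the thread: this only helps when the source reader is the bottleneck, and a partially consumed first pass would leave the cache incomplete.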
teacher_graphs: The teacher graphs used in distillation strategies.
train_optimizer: The optimizer used to append backward ops and
    optimization ops into train_graph.
distiller_optimizer: The optimizer used by distillation strategies.
In distillation, how should train_optimizer and distiller_optimizer be set?

In the distiller framework there is no unified user interface; each task has its own entry point, e.g. the classification task uses one interface and the language model uses another. The classification interface provides distillation; the language model one does not.
In the classification entry point compress_classifier.py, the optimizer is fixed to SGD, and the distillation strategy can only use that optimizer. If users want to switch optimizers, they have to modify compress_classifier.py, which mixes in scheduling logic for all the compression strategies and is very unfriendly to users.
For the classification task, distillation does not use an independent optimizer; its steps are:
- Call the forward of the knowledge distillation policy to do the forward pass.
- Use the optimizer shared by all compression strategies to do the backward pass.
In summary, distiller's user-interface design, its use of the optimizer, and the way it invokes distillation strategies are not worth using as a reference.

Most of the demos distiller provides are based on compress_classifier.py; all users can really adjust is a few parameters of the SGD optimizer.

Sorry, I misunderstood the question. I will add comments here explaining the purpose of each of the two optimizers. Thanks.

Sorry, what I meant above is that it needs to be clearly explained how these two optimizers should be set when using the distillation strategy.
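To make the two-optimizer split concrete, a hypothetical sketch (the `Compressor` name and the two parameters come from this PR; the `optimizer_for` method and string stand-ins for optimizers are assumptions for illustration):

```python
class Compressor:
    """Sketch: distiller_optimizer drives updates while a distillation
    strategy is active; train_optimizer drives ordinary training.
    Falls back to train_optimizer if no distiller_optimizer is given."""

    def __init__(self, train_optimizer, distiller_optimizer=None):
        self.train_optimizer = train_optimizer
        self.distiller_optimizer = distiller_optimizer

    def optimizer_for(self, distilling):
        if distilling and self.distiller_optimizer is not None:
            return self.distiller_optimizer
        return self.train_optimizer

c = Compressor(train_optimizer="sgd", distiller_optimizer="adam")
assert c.optimizer_for(distilling=True) == "adam"
assert c.optimizer_for(distilling=False) == "sgd"
```

This is only the dispatch logic the docstring should spell out; the real Compressor appends the corresponding backward and optimization ops into the graph.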
logger.info('Latest evaluations: {}'.format(results))
return abs(results[1] - results[0]) / results[0] < delta

def run_eval_graph(self, sampled_rate=None, cached_id=0):
Why is run_eval_graph in the Context?

So that every strategy can evaluate the performance of the current graph at any time. This method could be placed in:
- Each strategy class: duplicated implementations, inconvenient for strategy authors.
- The CompressPass class: not reachable from the strategies.
- GraphWrapper?
- A separate ExecutorHelper?
if 'init_model' in factory.compress_pass:
    self.init_model = factory.compress_pass['init_model']

def _init_model(self, context):
As discussed before, a Pass is only used to transform the graph; it does not include running (train/eval).

CompressPass is not actually a pass. Shall I rename it to Compressor?

I agree to change the name, to keep the meaning of Pass consistent between Python and C++. What do you think? @panyx0718 @wzzju

Renamed CompressPass to Compressor.
Pruner used to prune parameters by groups.
"""

def __init__(self, pruning_axis, criterions):
You might want to add some comments for this and many others; the code is not easy to understand.
class: 'SensitivePruneStrategy'
pruner: 'pruner_1'
start_epoch: 1
delta_rate: 0.2
You might want to document these metrics.
num_steps: 1
eval_rate: 0.5
pruned_params: 'conv6_sep_weights'
sensitivities_file: 'mobilenet_acc_top1_sensitive.data'

num_steps: 1
eval_rate: 0.5
pruned_params: '.*_sep_weights'
sensitivities_file: 'mobilenet_acc_top1_sensitive.data'
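The two configs above differ only in `pruned_params`: the first names one parameter exactly, the second selects a group by pattern. Assuming the strategy matches parameter names with Python regular expressions (an assumption; the PR does not show the matching code here), a sketch of what `'.*_sep_weights'` selects:

```python
import re

# Hypothetical parameter names, in the style of the MobileNet demo above.
params = ["conv6_sep_weights", "conv7_sep_weights", "fc_weights"]

pattern = re.compile(".*_sep_weights")
pruned = [p for p in params if pattern.match(p)]
assert pruned == ["conv6_sep_weights", "conv7_sep_weights"]
```

So the second config prunes every depthwise-separable weight, while the first touches only conv6_sep_weights.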
strategies = self.strategies
if self.checkpoint_path:
    if not os.path.exists(self.checkpoint_path):
        os.makedirs(self.checkpoint_path)
This is load_checkpoint; the path may not even exist yet. Why makedirs?
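A small sketch of the reviewer's point: loading should treat a missing checkpoint directory as "nothing to restore" rather than create it (directory creation belongs on the save path). The `load_checkpoint` function here is illustrative, not the PR's implementation:

```python
import os
import tempfile

def load_checkpoint(path):
    """When loading, a missing checkpoint path means no checkpoint;
    do NOT os.makedirs() here."""
    if not os.path.exists(path):
        return None  # nothing to restore
    return os.listdir(path)  # placeholder for real deserialization

base = tempfile.mkdtemp()
missing = os.path.join(base, "ckpt")
assert load_checkpoint(missing) is None
assert not os.path.exists(missing)  # loading created nothing
```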
executor = SlimGraphExecutor(self.place)

for epoch in range(self.epoch):
    reader = feed_reader(
Suggest removing feed_reader; the reader the user passes in from outside may already be multi-threaded/multi-process.
| """ | ||
| Runing evaluation. | ||
| """ | ||
| results, names = context.run_eval_graph() |
This class is shared by all the compression algorithms, right? If so, could run_eval_graph in Context be put directly here?

The Compressor (CompressPass) object contains the strategies, so in the hierarchy a strategy must not access the Compressor object directly.
If we used strategy.on_compression_epoch(compressor):
- The two objects would depend on each other.
- A strategy could access too much; for example, it could call compressor.run() directly and create an infinite loop.
The Context wraps up the information and capabilities of the Compressor that strategies are allowed to and need to access, for the strategies to use. The Compressor itself can also use it, as on this line.
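A stripped-down sketch of this mediator pattern, assuming the `Context`/`run_eval_graph` names from this PR (the `eval_fn` wiring and the strategy's hook name are illustrative):

```python
class Context:
    """Exposes only what strategies may use; a strategy never holds
    the Compressor itself, so no circular dependency and no way to
    re-enter compressor.run()."""

    def __init__(self, eval_fn):
        self._eval_fn = eval_fn

    def run_eval_graph(self):
        return self._eval_fn()

class Strategy:
    def on_compression_epoch(self, context):
        # A strategy can evaluate the current graph at any time ...
        # ... but cannot reach the Compressor through `context`.
        return context.run_eval_graph()

ctx = Context(eval_fn=lambda: ([0.9], ["acc_top1"]))
assert Strategy().on_compression_epoch(ctx) == ([0.9], ["acc_top1"])
```

The same Context object also serves the Compressor's own evaluation calls, as the line quoted above shows.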
graph.program.global_block().var(name) for name in fetches
]
results = self.exe.run(graph.program,
def run(self, graph, scope, data=None, feed=None, fetches=None):
Add comments: what is the type of data?
Add comments: what are the types of data, feed, and fetches?
Args:
    context(slim.core.Context): The context storing all information used to evaluate the current model.
    sampled_rate(float): The sampled rate used to sample partial data for evaluation. None means using all data in eval_reader. default: None.
    cached_id(int): The id of dataset sampled. Evaluations with same cached_id use the same sampled dataset. default: 0.
cached_id is the unique identifier of a randomly sampled data subset.
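This matches the `np.random.seed(cached_id)` call in the snippet above: seeding with cached_id makes the sampled subset reproducible, so evaluations sharing a cached_id see identical data. A sketch (the `sample_ids` helper is illustrative):

```python
import numpy as np

def sample_ids(n, sampled_rate, cached_id):
    """Same cached_id -> same seed -> same sampled subset of [0, n)."""
    np.random.seed(cached_id)
    return np.random.choice(n, int(n * sampled_rate), replace=False)

a = sample_ids(100, 0.5, cached_id=0)
b = sample_ids(100, 0.5, cached_id=0)
assert (a == b).all()  # identical subset for the same cached_id
```

A different cached_id reseeds the generator and (in general) draws a different subset, which is why it identifies the sample rather than just caching it.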
2. Fix cache reader 3. Rename CompressPass to Compressor 4. Add comments for distiller optimizer 5. Remove unused pruner currently 6. Add some comments. 7. Change API.spec test=develop