Add basic hook classes for dygraph & implement reduce hook#28584
Add basic hook classes for dygraph & implement reduce hook#28584chenwhql merged 5 commits intoPaddlePaddle:developfrom
Conversation
|
Thanks for your contribution! |
| } | ||
|
|
||
| private: | ||
| std::vector<std::unique_ptr<GradAccumulatorPostHook>> hooks_; |
There was a problem hiding this comment.
May be it can call 'leaf_var_hooks_' , and can be better distinguished from 'backward_hooks_' . After all, both of them are hooks for backward. Isn't 'backward_hooks_' here for Allreduce/Reduce only?
There was a problem hiding this comment.
- my opinion: the class name
LeafVarHookPackagealreaady hold theleaf varinfo, thehooksinLeafVarHookPackageareleaf_var_hooks_, using long member name cause information redundancy and also make the interface name longer, such asLeafVarHookPackage.add_leaf_var_hook() backward_hooks_mean the hooks ofwhole backward process, because it relay on leaf var, so we can only put it here now, may be we should addAccumulateGrad dummy OpNodeand movebackward_hooks_outside, I wiil perfect the comments here
There was a problem hiding this comment.
And backward_hooks_ may not only used for Allreduce/Reduce, we should keep scalability here
| << ref_cnt_; | ||
| // After all tmp gradient being accumulated to grad var, run hooks | ||
| if (AccumulateCompleted() && HasPostHooks()) { | ||
| CallBackwardPostHooks(); |
There was a problem hiding this comment.
Here call backward_hooks_, how about when AccumulateCompleted, first call_hooks_ , then gradient_accumulation between batch, last call backward_hooks_ .
So We must have two function: CallPostHooks, and CallBackwardPostHooks. And this can changed after this PR merged.
| } | ||
|
|
||
| private: | ||
| std::vector<std::unique_ptr<GradAccumulatorPostHook>> hooks_; |
|
LGTM for |
PR types
New features
PR changes
Others
Describe
Add basic hook classes for dygraph & implement reduce hook
执行逻辑设计
由前向VarBase拿到前向VariableWrapper, 通过VariableWrapper的接口注册LeafGradHook
反向执行Engine准备执行环境时将hook关联到GradientAccumulator
当反向执行梯度累加完成时,执行关联的hook
简单hook示例