Force-pushed f40edfb to 779a2f5 (test=develop)
Force-pushed 779a2f5 to 67badd2 (test=develop)
```cpp
  viz_pass->Set<std::string>("graph_viz_path", new std::string(graph_path));
}

if (strategy.fuse_elewise_add_act_ops_) {
```

It is not deleted, only moved here.
```cpp
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/framework/details/multi_devices_graph_pass.h"
```

This include should come after the system header files.
```cpp
}

bool MultiDevSSAGraphBuilderBase::DealWithSpecialOp(ir::Graph *result,
                                                    ir::Node *node) const {
```

Why add an empty function?

Just move `AllReduceSSAGraphBuilder::DealWithSpecialOp` here instead.
```cpp
outputs.insert(outputs.end(), op_handle.Outputs().begin(),
               op_handle.Outputs().end());
// Remove Input
```

Do we need to remove the pointer in the `kGraphVars` attribute?

Only the original all_reduce nodes should be removed from the graph, and that is done at line 132.
Force-pushed 05d55ec to 0ede8fc (test=develop)
Force-pushed 00646cc to cd49a17 (test=develop)
… fuse_all_reduce test=develop
Do we need a unit test to make sure the fused allreduce holds the same order on all trainers?

@gongweibao
test=develop
… fuse_all_reduce test=develop
Another question: where is the address alignment implemented?
```cpp
auto iter = vars.find(p_g.second);
PADDLE_ENFORCE(iter != vars.end());
PADDLE_ENFORCE_NOT_NULL(iter->second->Var());
PADDLE_ENFORCE_EQ(iter->second->Var()->GetType(),
```

At the current stage the type is required to be LoDTensor; this may be improved in a later stage.
```cpp
// Run Only Once Programs
for (size_t i = 0; i < local_scopes.size(); ++i) {
  for (auto &op_desc : program_desc.Block(0).AllOps()) {
    auto op = OpRegistry::CreateOp(*op_desc);
```

Maybe we can add a new tensor member to avoid the fused tensor being resized.

`fused_all_reduce_op_handle` already has an address check, so that is unnecessary.
… fuse_all_reduce test=develop
test=develop
test=develop
Force-pushed f6b6639 to 601dd3c
… fuse_all_reduce test=develop
test=develop
… fuse_all_reduce test=develop
gongweibao left a comment:

LGTM+++
Gradient fusion is very useful for multi-machine training communication. Thanks!
Code separated from #15497
Fix part of #16061
For ResNet: (benchmark figure)
For Transformer: (benchmark figure)