Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
42 commits
Select commit Hold shift + click to select a range
e4f2752
fix reducer
ForFishes Mar 11, 2021
5bf0086
add utest
ForFishes Mar 11, 2021
449e090
fix untest
ForFishes Mar 12, 2021
26d1b48
support control flow
ForFishes Mar 15, 2021
a798b97
support control flow
ForFishes Mar 15, 2021
45fed37
fix utest
ForFishes Mar 15, 2021
32b2ab5
fix bug
ForFishes Mar 17, 2021
c6c55f4
add utest
ForFishes Mar 17, 2021
930a58f
fix utest
ForFishes Mar 18, 2021
7fbb5a2
fix untest
ForFishes Mar 18, 2021
2583f1c
fix bckl
ForFishes Mar 18, 2021
70716fc
fix utest
ForFishes Mar 18, 2021
2143ff5
support fleet
ForFishes Mar 18, 2021
7706e9d
fix comment
ForFishes Mar 18, 2021
a4212df
fix fleet api
ForFishes Mar 18, 2021
dae49d7
fix utest
ForFishes Mar 18, 2021
2d4f4d5
fix bug
ForFishes Mar 18, 2021
cc7fb0c
fix comment
ForFishes Mar 18, 2021
ac3e2c0
fix bug
ForFishes Mar 18, 2021
3f541fe
fix files
ForFishes Mar 19, 2021
5f3a3fd
fix utest
ForFishes Mar 19, 2021
97d08eb
fix utest
ForFishes Mar 19, 2021
a356f61
add sync buffer and param
ForFishes Mar 18, 2021
42e0c11
add sync param
ForFishes Mar 18, 2021
39dab1a
fix bug
ForFishes Mar 20, 2021
e4f5a4d
fix utest
ForFishes Mar 20, 2021
a807956
fix cmake
ForFishes Mar 20, 2021
8f67e1f
fix coverage
ForFishes Mar 20, 2021
be4072b
add gradient check
ForFishes Mar 20, 2021
93bac9a
fix coverage
ForFishes Mar 20, 2021
1ce2b2b
fix utest
ForFishes Mar 20, 2021
e619554
add test for nonevar && find_unused_parameters
ForFishes Mar 21, 2021
037c225
supoort xpu in sync_parameters_buffers
ForFishes Mar 21, 2021
8f25efa
fix xpu
ForFishes Mar 21, 2021
9997354
fix ctest
ForFishes Mar 21, 2021
0dee9fd
add test for dataparallel
ForFishes Mar 21, 2021
8ad02cc
fix utest
ForFishes Mar 21, 2021
85a758b
fix comment
ForFishes Mar 21, 2021
7a9c0a9
fix small bug for redcuer
ForFishes Mar 22, 2021
c3af47e
fix parallel_dygraph_dataparallel
ForFishes Mar 23, 2021
d934171
solve compute stream & comm stream conflict
ForFishes Mar 23, 2021
39caee0
fix the bug of sparse embedding
ForFishes Mar 30, 2021
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions paddle/fluid/framework/distributed_strategy.proto
Original file line number Diff line number Diff line change
Expand Up @@ -151,6 +151,7 @@ message DistributedStrategy {
optional bool fp16_allreduce = 25 [ default = false ];
optional bool sharding = 26 [ default = false ];
optional float last_comm_group_size_MB = 27 [ default = 1 ];
optional bool find_unused_parameters = 28 [ default = true ];

optional RecomputeConfig recompute_configs = 101;
optional AMPConfig amp_configs = 102;
Expand Down
8 changes: 6 additions & 2 deletions paddle/fluid/imperative/bkcl_context.cc
Original file line number Diff line number Diff line change
Expand Up @@ -146,8 +146,6 @@ void BKCLParallelContext::WaitCompute(int ring_id) {
platform::errors::OutOfRange("Ring id expected < nrings,"
"but got ring id = %d, nrings = %d",
ring_id, strategy_.nrings_));
// TODO(wangxi16): [Performance optimize] Maybe need to put Wait and
// bkcl_allreduce to comm thread, for bkcl_allreduce is blocking now.
auto compute_dev_ctx = static_cast<platform::XPUDeviceContext *>(
platform::DeviceContextPool::Instance().Get(place_));
compute_dev_ctx->Wait();
Expand All @@ -167,6 +165,12 @@ void BKCLParallelContext::WaitComm(int ring_id) {
comm_dev_ctx->Wait();
}

void BKCLParallelContext::SynchronizeCompute() {
auto compute_dev_ctx = static_cast<platform::XPUDeviceContext *>(
platform::DeviceContextPool::Instance().Get(place_));
compute_dev_ctx->Wait();
}

} // namespace imperative
} // namespace paddle
#endif
2 changes: 2 additions & 0 deletions paddle/fluid/imperative/bkcl_context.h
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,8 @@ class BKCLParallelContext : public ParallelContext {
void WaitCompute(int ring_id) override;

void WaitComm(int ring_id) override;

void SynchronizeCompute() override;
};

} // namespace imperative
Expand Down
6 changes: 6 additions & 0 deletions paddle/fluid/imperative/nccl_context.cc
Original file line number Diff line number Diff line change
Expand Up @@ -149,6 +149,12 @@ void NCCLParallelContext::WaitComm(int ring_id) {
#endif
}

void NCCLParallelContext::SynchronizeCompute() {
auto *compute_dev_ctx = static_cast<platform::CUDADeviceContext *>(
platform::DeviceContextPool::Instance().Get(place_));
compute_dev_ctx->Wait();
}

#endif

} // namespace imperative
Expand Down
2 changes: 2 additions & 0 deletions paddle/fluid/imperative/nccl_context.h
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,8 @@ class NCCLParallelContext : public ParallelContext {

void WaitComm(int ring_id) override;

void SynchronizeCompute() override;

private:
// used for comm wait compute, compute_stream-->event-->comm_stream[ring_id]
std::vector<std::shared_ptr<platform::CudaEventObject>> compute_events_;
Expand Down
3 changes: 3 additions & 0 deletions paddle/fluid/imperative/parallel_context.h
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,9 @@ class ParallelContext {
// if CPU, should do nothing.
virtual void WaitComm(int ring_id) = 0;

// synchorize compute stream
virtual void SynchronizeCompute() = 0;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

亮哥,顺路把SynchronizeCompute在bkcl_context里面也加下哈~

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

已经添加。


inline int GetNRings() const { return strategy_.nrings_; }

inline int64_t GetNRanks() const { return strategy_.nranks_; }
Expand Down
Loading