Skip to content

Commit 9a9f162

Browse files
committed
refactor: unify and simplify create_process_group implementation.
1 parent f56a907 commit 9a9f162

File tree

9 files changed

+56
-69
lines changed

9 files changed

+56
-69
lines changed

xllm/core/common/global_flags.cpp

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -465,7 +465,6 @@ DEFINE_bool(enable_constrained_decoding,
465465
"that the output meets specific format or structural requirements "
466466
"through pre-defined rules.");
467467

468-
469468
#if defined(USE_NPU)
470469
DEFINE_string(
471470
npu_kernel_backend,

xllm/core/framework/parallel_state/collective_communicator.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@ limitations under the License.
2828
#endif
2929
#include "common/global_flags.h"
3030
#include "parallel_args.h"
31+
#include "process_group.h"
3132
#include "util/net.h"
3233

3334
namespace xllm {

xllm/core/framework/parallel_state/cuda_process_group.h

Lines changed: 6 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -21,12 +21,12 @@ limitations under the License.
2121

2222
namespace xllm {
2323

24-
class ProcessGroupNccl : public ProcessGroup {
24+
class ProcessGroupImpl : public ProcessGroup {
2525
public:
26-
ProcessGroupNccl(int global_rank,
27-
int world_size,
28-
int rank_size,
29-
int port,
26+
ProcessGroupImpl(int32_t global_rank,
27+
int32_t world_size,
28+
int32_t rank_size,
29+
int32_t port,
3030
bool trans,
3131
const std::string& host,
3232
const std::string& group_name,
@@ -38,7 +38,7 @@ class ProcessGroupNccl : public ProcessGroup {
3838
(TORCH_VERSION_MAJOR == 2 && TORCH_VERSION_MINOR >= 7)
3939
pg_options->group_name = group_name;
4040
#endif
41-
int rank = global_rank;
41+
int32_t rank = global_rank;
4242
if (world_size != rank_size) {
4343
auto [local_rank, group_ranks] =
4444
get_group_rank(world_size, global_rank, rank_size, trans);
@@ -52,16 +52,4 @@ class ProcessGroupNccl : public ProcessGroup {
5252
}
5353
};
5454

55-
std::unique_ptr<xllm::ProcessGroup> create_process_group(
56-
int rank,
57-
int world_size,
58-
int rank_size,
59-
int port,
60-
bool trans,
61-
const std::string& host,
62-
const std::string& group_name,
63-
const torch::Device& device) {
64-
return std::make_unique<ProcessGroupNccl>(
65-
rank, world_size, rank_size, port, trans, host, group_name, device);
66-
}
6755
} // namespace xllm

xllm/core/framework/parallel_state/mlu_process_group.h

Lines changed: 2 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -23,9 +23,9 @@ namespace xllm {
2323

2424
constexpr int32_t local_device_count = 8;
2525

26-
class ProcessGroupCncl : public ProcessGroup {
26+
class ProcessGroupImpl : public ProcessGroup {
2727
public:
28-
ProcessGroupCncl(int32_t global_rank,
28+
ProcessGroupImpl(int32_t global_rank,
2929
int32_t world_size,
3030
int32_t rank_size,
3131
int32_t port,
@@ -57,17 +57,4 @@ class ProcessGroupCncl : public ProcessGroup {
5757
}
5858
};
5959

60-
std::unique_ptr<xllm::ProcessGroup> create_process_group(
61-
int32_t rank,
62-
int32_t world_size,
63-
int32_t rank_size,
64-
int32_t port,
65-
bool trans,
66-
const std::string& host,
67-
const std::string& group_name,
68-
const torch::Device& device) {
69-
return std::make_unique<ProcessGroupCncl>(
70-
rank, world_size, rank_size, port, trans, host, group_name, device);
71-
}
72-
7360
} // namespace xllm

xllm/core/framework/parallel_state/npu_process_group.cpp

Lines changed: 7 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -32,10 +32,10 @@ namespace {
3232

3333
namespace xllm {
3434

35-
ProcessGroupHCCL::ProcessGroupHCCL(int global_rank,
36-
int world_size,
37-
int rank_size,
38-
int port,
35+
ProcessGroupImpl::ProcessGroupImpl(int32_t global_rank,
36+
int32_t world_size,
37+
int32_t rank_size,
38+
int32_t port,
3939
bool trans,
4040
const std::string& host,
4141
const std::string& group_name,
@@ -47,7 +47,7 @@ ProcessGroupHCCL::ProcessGroupHCCL(int global_rank,
4747
(TORCH_VERSION_MAJOR == 2 && TORCH_VERSION_MINOR >= 7)
4848
hccl_pg_options->group_name = group_name;
4949
#endif
50-
int rank = global_rank;
50+
int32_t rank = global_rank;
5151
if (world_size != rank_size) {
5252
auto [local_rank, group_ranks] =
5353
get_group_rank(world_size, global_rank, rank_size, trans);
@@ -65,31 +65,18 @@ ProcessGroupHCCL::ProcessGroupHCCL(int global_rank,
6565
}
6666

6767
// Destructor.
68-
ProcessGroupHCCL::~ProcessGroupHCCL() {
68+
ProcessGroupImpl::~ProcessGroupImpl() {
6969
if (pg_) {
7070
pg_->shutdown();
7171
} else {
7272
HCCLCHECK(HcclCommDestroy(comm_));
7373
}
7474
}
7575

76-
ProcessGroupHCCL::ProcessGroupHCCL(int rank,
76+
ProcessGroupImpl::ProcessGroupImpl(int rank,
7777
int world_size,
7878
const torch::Device& device,
7979
HcclComm comm)
8080
: ProcessGroup(device), comm_(comm) {}
8181

82-
std::unique_ptr<xllm::ProcessGroup> create_process_group(
83-
int rank,
84-
int world_size,
85-
int rank_size,
86-
int port,
87-
bool trans,
88-
const std::string& host,
89-
const std::string& group_name,
90-
const torch::Device& device) {
91-
return std::make_unique<ProcessGroupHCCL>(
92-
rank, world_size, rank_size, port, trans, host, group_name, device);
93-
}
94-
9582
} // namespace xllm

xllm/core/framework/parallel_state/npu_process_group.h

Lines changed: 4 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -20,15 +20,15 @@ limitations under the License.
2020

2121
namespace xllm {
2222

23-
class ProcessGroupHCCL : public ProcessGroup {
23+
class ProcessGroupImpl : public ProcessGroup {
2424
public:
2525
// Constructor.
26-
ProcessGroupHCCL(int rank,
26+
ProcessGroupImpl(int rank,
2727
int world_size,
2828
const torch::Device& device,
2929
HcclComm comm);
3030

31-
ProcessGroupHCCL(int rank,
31+
ProcessGroupImpl(int rank,
3232
int world_size,
3333
int rank_size,
3434
int port,
@@ -38,20 +38,10 @@ class ProcessGroupHCCL : public ProcessGroup {
3838
const torch::Device& device);
3939

4040
// Destructor.
41-
~ProcessGroupHCCL() override;
41+
~ProcessGroupImpl() override;
4242

4343
private:
4444
HcclComm comm_ = nullptr;
4545
};
4646

47-
std::unique_ptr<xllm::ProcessGroup> create_process_group(
48-
int rank,
49-
int world_size,
50-
int rank_size,
51-
int port,
52-
bool trans,
53-
const std::string& host,
54-
const std::string& group_name,
55-
const torch::Device& device);
56-
5747
} // namespace xllm

xllm/core/framework/parallel_state/parallel_state.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -215,7 +215,7 @@ std::vector<std::unique_ptr<ProcessGroup>> create_npu_process_groups(
215215
std::vector<std::unique_ptr<ProcessGroup>> process_groups;
216216
process_groups.reserve(devices.size());
217217
for (int i = 0; i < world_size; ++i) {
218-
process_groups.emplace_back(std::make_unique<ProcessGroupHCCL>(
218+
process_groups.emplace_back(std::make_unique<ProcessGroupImpl>(
219219
/*rank=*/i, world_size, devices[i], comms[i]));
220220
}
221221

xllm/core/framework/parallel_state/process_group.cpp

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,14 @@ limitations under the License.
1515

1616
#include "process_group.h"
1717

18+
#if defined(USE_NPU)
19+
#include "npu_process_group.h"
20+
#elif defined(USE_MLU)
21+
#include "mlu_process_group.h"
22+
#elif defined(USE_CUDA)
23+
#include "cuda_process_group.h"
24+
#endif
25+
1826
namespace {
1927
std::pair<int, std::vector<uint64_t>> get_trans_group_rank(int world_size,
2028
int global_rank,
@@ -75,4 +83,18 @@ void ProcessGroup::allgather(const torch::Tensor& input,
7583
std::vector<std::vector<torch::Tensor>> output_tensors = {outputs};
7684
pg_->allgather(output_tensors, input_tensors)->wait();
7785
}
86+
87+
std::unique_ptr<ProcessGroup> create_process_group(
88+
int32_t rank,
89+
int32_t world_size,
90+
int32_t rank_size,
91+
int32_t port,
92+
bool trans,
93+
const std::string& host,
94+
const std::string& group_name,
95+
const torch::Device& device) {
96+
return std::unique_ptr<ProcessGroup>(new ProcessGroupImpl(
97+
rank, world_size, rank_size, port, trans, host, group_name, device));
98+
}
99+
78100
} // namespace xllm

xllm/core/framework/parallel_state/process_group.h

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,9 @@ limitations under the License.
2525
#endif
2626

2727
namespace xllm {
28+
29+
class ProcessGroupImpl;
30+
2831
std::pair<int, std::vector<uint64_t>> get_group_rank(int world_size,
2932
int global_rank,
3033
int split_size,
@@ -77,4 +80,14 @@ class ProcessGroup {
7780
#endif
7881
};
7982

83+
std::unique_ptr<xllm::ProcessGroup> create_process_group(
84+
int32_t rank,
85+
int32_t world_size,
86+
int32_t rank_size,
87+
int32_t port,
88+
bool trans,
89+
const std::string& host,
90+
const std::string& group_name,
91+
const torch::Device& device);
92+
8093
} // namespace xllm

0 commit comments

Comments
 (0)