-
Notifications
You must be signed in to change notification settings - Fork 2.3k
Expand file tree
/
Copy pathallreduceOp.cpp
More file actions
1233 lines (1078 loc) · 51 KB
/
allreduceOp.cpp
File metadata and controls
1233 lines (1078 loc) · 51 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
/*
* SPDX-FileCopyrightText: Copyright (c) 1993-2024 NVIDIA CORPORATION &
* AFFILIATES. All rights reserved. SPDX-License-Identifier: Apache-2.0
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "tensorrt_llm/common/cudaUtils.h"
#include "tensorrt_llm/common/customAllReduceUtils.h"
#include "tensorrt_llm/common/dataType.h"
#include "tensorrt_llm/common/mcastDevMemUtils.h"
#include "tensorrt_llm/common/opUtils.h"
#include "tensorrt_llm/kernels/communicationKernels/allReduceFusionKernels.h"
#include "tensorrt_llm/kernels/communicationKernels/customLowPrecisionAllReduceKernels.h"
#include "tensorrt_llm/kernels/communicationKernels/mnnvlTwoShotAllreduceKernels.h"
#include "tensorrt_llm/kernels/communicationKernels/moeAllReduceFusionKernels.h"
#include "tensorrt_llm/kernels/customAllReduceKernels.h"
#include "tensorrt_llm/kernels/quantization.h"
#include "tensorrt_llm/kernels/userbuffers/ub_interface.h"
#include "tensorrt_llm/runtime/mcastDeviceMemory.h"
#include "tensorrt_llm/runtime/torchUtils.h"
#include "tensorrt_llm/runtime/utils/mpiUtils.h"
#include "tensorrt_llm/thop/fp4Quantize.h"
#include "tensorrt_llm/thop/fp8Op.h"
#include "tensorrt_llm/thop/thUtils.h"
#include "tensorrt_llm/thop/userbuffersTensor.h"
#if ENABLE_MULTI_DEVICE
#include <ATen/cuda/EmptyTensor.h>
#include <nccl.h>
#endif // ENABLE_MULTI_DEVICE
#include <nvml.h>
#include <torch/extension.h>
#include <cstddef>
#include <cstdint>
#include <unordered_set>
// using namespace nvinfer1;
using tensorrt_llm::kernels::AllReduceFusionOp;
using tensorrt_llm::kernels::AllReduceStrategyType;
using tensorrt_llm::mpi::MpiTag;
namespace torch_ext
{
#if ENABLE_MULTI_DEVICE
namespace
{
/// Scope guard for the NVML library: nvmlInit() on construction, nvmlShutdown() on destruction.
///
/// Copying and moving are deleted. Each instance pairs exactly one nvmlInit() with exactly one
/// nvmlShutdown(). A copied instance would issue an additional nvmlShutdown() with no matching
/// nvmlInit().
class NvmlManager
{
public:
    /// Initializes NVML. Failures are reported through NVML_CHECK_THROW (i.e. by throwing).
    NvmlManager()
    {
        NVML_CHECK_THROW(nvmlInit());
    }

    NvmlManager(NvmlManager const&) = delete;
    NvmlManager& operator=(NvmlManager const&) = delete;
    NvmlManager(NvmlManager&&) = delete;
    NvmlManager& operator=(NvmlManager&&) = delete;

    /// Shuts NVML down. Uses the non-throwing-named NVML_CHECK variant; presumably it only reports
    /// errors so that the destructor cannot throw -- confirm in the macro definition.
    ~NvmlManager()
    {
        NVML_CHECK(nvmlShutdown());
    }
};
/// Returns the node-local ranks of the members of @p group that run on the caller's node.
///
/// @param group Global ranks (COMM_SESSION numbering) forming the communication group. The caller
///              is presumably expected to be a member; otherwise the point-to-point exchange below
///              would block -- TODO confirm with callers.
/// @return LOCAL_COMM_SESSION ranks of the group members found among this node's processes. A result
///         with fewer entries than @p group means the group spans more than one node.
///
/// Collective call. When group.size() >= the local communicator size, every process on the node must
/// enter it (allgather). Otherwise every member of @p group must enter it (point-to-point exchange).
std::set<int> getLocalGroup(std::set<int> const& group)
{
    auto const myRank = COMM_SESSION.getRank();
    auto const myLocalRank = LOCAL_COMM_SESSION.getRank();
    auto const localSize = static_cast<uint32_t>(LOCAL_COMM_SESSION.getSize());

    // Parallel tables: ranks[i] is the global rank and localRanks[i] the node-local rank of the same
    // process.
    std::vector<int32_t> ranks;
    std::vector<int32_t> localRanks;
    if (group.size() >= localSize)
    {
        ranks.assign(localSize, 0);
        localRanks.assign(localSize, 0);
        LOCAL_COMM_SESSION.allgather(&myRank, ranks.data(), 1, tensorrt_llm::mpi::MpiType::kINT32);
        LOCAL_COMM_SESSION.allgather(&myLocalRank, localRanks.data(), 1, tensorrt_llm::mpi::MpiType::kINT32);
    }
    else
    {
        // The first group member (the leader) collects every member's global rank, sends the table
        // back to all members, then does the same for the local ranks. Both sides move exactly
        // group.size() elements. The leader's tables hold exactly that many entries, so any larger
        // count would read past the end of the vectors.
        // NOTE(review): peers are addressed by their global rank (entries of `group`) on
        // LOCAL_COMM_SESSION; confirm this matches that communicator's numbering on multi-node runs.
        auto const groupSize = group.size();
        auto const leader = *group.begin();
        if (myRank == leader)
        {
            ranks.reserve(groupSize);
            localRanks.reserve(groupSize);
            int value = 0;

            ranks.push_back(myRank);
            for (auto it = std::next(group.begin()); it != group.end(); ++it)
            {
                LOCAL_COMM_SESSION.recvValue(value, *it, MpiTag::kDefault);
                ranks.push_back(value);
            }
            for (auto it = std::next(group.begin()); it != group.end(); ++it)
            {
                LOCAL_COMM_SESSION.send(
                    ranks.data(), groupSize, tensorrt_llm::mpi::MpiType::kINT32, *it, MpiTag::kDefault);
            }

            localRanks.push_back(myLocalRank);
            for (auto it = std::next(group.begin()); it != group.end(); ++it)
            {
                LOCAL_COMM_SESSION.recvValue(value, *it, MpiTag::kDefault);
                localRanks.push_back(value);
            }
            for (auto it = std::next(group.begin()); it != group.end(); ++it)
            {
                LOCAL_COMM_SESSION.send(
                    localRanks.data(), groupSize, tensorrt_llm::mpi::MpiType::kINT32, *it, MpiTag::kDefault);
            }
        }
        else
        {
            ranks.resize(groupSize);
            localRanks.resize(groupSize);
            LOCAL_COMM_SESSION.sendValue(myRank, leader, MpiTag::kDefault);
            LOCAL_COMM_SESSION.recv(
                ranks.data(), groupSize, tensorrt_llm::mpi::MpiType::kINT32, leader, MpiTag::kDefault);
            LOCAL_COMM_SESSION.sendValue(myLocalRank, leader, MpiTag::kDefault);
            LOCAL_COMM_SESSION.recv(
                localRanks.data(), groupSize, tensorrt_llm::mpi::MpiType::kINT32, leader, MpiTag::kDefault);
        }
    }

    // Keep the local rank of every process whose global rank belongs to the group.
    std::set<int> localGroup;
    for (size_t i = 0; i < ranks.size(); ++i)
    {
        if (group.find(ranks[i]) != group.end())
        {
            localGroup.insert(localRanks[i]);
        }
    }
    return localGroup;
}
class AllreduceOp
{
public:
/// Records the allreduce configuration. No communication happens here; call initialize() before run().
///
/// @param group    Global ranks that take part in the reduction.
/// @param type     Element data type of the reduced tensors.
/// @param strategy Requested strategy; run() may still pick a different implementation per call.
/// @param op       Fusion pattern applied after the reduction.
/// @param eps      Epsilon forwarded to the RMS-norm based fusion patterns.
AllreduceOp(
    std::set<int> group, nvinfer1::DataType type, AllReduceStrategyType strategy, AllReduceFusionOp op, float eps)
    : mGroup{std::move(group)}
    , mType{type}
    , mStrategy{strategy}
    , mOp{op}
    , mEps{eps}
{
}

/// Nothing to release explicitly; every member cleans up after itself.
~AllreduceOp() = default;
/// Reduces @p input across mGroup, applies the configured fusion (mOp) and returns the outputs.
///
/// getRuntimeStrategy() picks the implementation for each call. The layout of the returned list
/// depends on that implementation and on mOp; see the run*AllReduce helpers.
///
/// Not noexcept: argument validation (TORCH_CHECK) and the helpers report errors by throwing. A
/// throw escaping a noexcept function would call std::terminate instead of reaching the caller.
std::vector<torch::Tensor> run(torch::Tensor const& input, torch::optional<torch::Tensor> const& residual,
    torch::optional<torch::Tensor> const& norm_weight, torch::optional<torch::Tensor> const& scale,
    torch::optional<torch::Tensor> const& bias, bool trigger_completion_at_end,
    torch::optional<torch::Tensor> workspace)
{
    // input.size(0) below needs at least one dimension; fail with a clear message otherwise.
    TORCH_CHECK(input.dim() >= 1, "allreduce input must have at least one dimension");
    size_t size = input.numel();
    size_t seq_len = input.size(0);
    // If strategy is set to UB, UB must be used as UB impl output is special and cannot be used
    // by others.
    AllReduceStrategyType runtime_strategy = getRuntimeStrategy(seq_len, size);
    // Log runtime strategy
    auto const rank = COMM_SESSION.getRank();
    logRunTimeStrategy(runtime_strategy, rank);
    // Dispatch to different allreduce implementations
    switch (runtime_strategy)
    {
    case AllReduceStrategyType::UB: return runUBAllReduce(input, residual, norm_weight, scale, bias);
    case AllReduceStrategyType::NCCL: return runNCCLAllReduce(input, residual, norm_weight, scale, bias);
    case AllReduceStrategyType::MIN_LATENCY:
    case AllReduceStrategyType::ONESHOT:
    case AllReduceStrategyType::TWOSHOT:
        return runFusionAllReduce(
            input, residual, norm_weight, scale, bias, trigger_completion_at_end, workspace, runtime_strategy);
    case AllReduceStrategyType::LOWPRECISION:
        return runLowPrecisionAllReduce(input, residual, norm_weight, scale, bias);
    default: TORCH_CHECK(false, "Invalid runtime strategy"); return {};
    }
}
/// Prepares communication state: fetches the NCCL communicator for mGroup. For every strategy other
/// than NCCL and UB, it also probes the group's P2P/NVLink topology. The topology flags are only
/// consulted by the runtime heuristic, which never runs for those two strategies.
///
/// @return Always 0.
int initialize()
{
    TLLM_LOG_TRACE("%s start for rank %d", __PRETTY_FUNCTION__, COMM_SESSION.getRank());
    mNcclComm = getComm(mGroup);

    bool const needsTopology
        = !(mStrategy == AllReduceStrategyType::NCCL || mStrategy == AllReduceStrategyType::UB);
    if (needsTopology)
    {
        initGroupTopology();
    }

    TLLM_LOG_TRACE("%s stop for rank %d", __PRETTY_FUNCTION__, COMM_SESSION.getRank());
    return 0;
}
private:
/// UserBuffers allreduce fused with residual add + RMS norm, optionally followed by FP8/NVFP4
/// quantization.
///
/// @p input must already live in a registered UserBuffers buffer (checked via search_buffer).
/// Returned tensors per mOp:
///   RESIDUAL_RMS_NORM             -> {norm_out, residual_out}
///   RESIDUAL_RMS_NORM_QUANT_FP8   -> {quant_out (float8_e4m3fn), residual_out}
///   RESIDUAL_RMS_NORM_QUANT_NVFP4 -> {quant_out (bytes, last dim halved), scale_out, residual_out}
std::vector<torch::Tensor> runUBAllReduce(torch::Tensor const& input,
    torch::optional<torch::Tensor> const& residual, torch::optional<torch::Tensor> const& norm_weight,
    torch::optional<torch::Tensor> const& scale, torch::optional<torch::Tensor> const& bias)
{
    auto stream = at::cuda::getCurrentCUDAStream(input.get_device());
    int size = input.numel();
    int hidden_size = input.size(-1);
    TLLM_CHECK(mOp == AllReduceFusionOp::RESIDUAL_RMS_NORM || mOp == AllReduceFusionOp::RESIDUAL_RMS_NORM_QUANT_FP8
        || mOp == AllReduceFusionOp::RESIDUAL_RMS_NORM_QUANT_NVFP4);
    // Every supported pattern reads the residual. Check it explicitly so a missing tensor gives a
    // readable error instead of std::bad_optional_access from residual.value().
    TORCH_CHECK(residual, "residual is required for UB allreduce");
    TLLM_CHECK_WITH_INFO(tensorrt_llm::runtime::ub::ub_is_initialized(), "UserBuffer has not been initialized!");
    auto& ub_manager = tensorrt_llm::runtime::ub::UserBuffersManager::get_instance();
    auto ub_buffer0 = ub_manager.search_buffer(input.data_ptr());
    TLLM_CHECK(!ub_buffer0.invalid());
    auto ub_comm = ub_manager.comm();
    torch::Tensor residual_out = torch::empty_like(input);

    if (mOp == AllReduceFusionOp::RESIDUAL_RMS_NORM)
    {
        TORCH_CHECK(norm_weight, "norm_weight is required for residual rms norm allreduce");
        TORCH_CHECK(!bias, "bias is not supported for residual rms norm allreduce");
        TORCH_CHECK(mType == nvinfer1::DataType::kHALF || mType == nvinfer1::DataType::kBF16);
        auto [norm_out, ub_buffer1] = torch_ext::create_userbuffers_tensor(input.sizes(), input.scalar_type());
        tensorrt_llm::kernels::ub::allreduce2_userbuff_rmsnorm_launcher(ub_buffer0.handle, 0, ub_buffer1.handle, 0,
            size, hidden_size, nullptr, norm_weight.value().data_ptr(), mEps, residual.value().data_ptr(),
            residual_out.data_ptr(), mType, ub_comm, stream);
        return {norm_out, residual_out};
    }
    else if (mOp == AllReduceFusionOp::RESIDUAL_RMS_NORM_QUANT_FP8)
    {
        TORCH_CHECK(scale, "scale is required for FP8 allreduce");
        TORCH_CHECK(norm_weight, "norm_weight is required for FP8 allreduce");
        TORCH_CHECK(!bias, "bias is not supported for FP8 allreduce");
        auto [norm_out, ub_buffer1] = torch_ext::create_userbuffers_tensor(input.sizes(), torch::kFloat8_e4m3fn);
        tensorrt_llm::kernels::ub::allreduce2_userbuff_inplace_rmsnorm_quant_launcher(ub_buffer0.handle, 0,
            ub_buffer1.handle, 0, size, hidden_size, nullptr, norm_weight.value().data_ptr(), mEps,
            static_cast<float*>(scale.value().data_ptr()), residual.value().data_ptr(), residual_out.data_ptr(),
            mType, ub_comm, stream);
        return {norm_out, residual_out};
    }
    else if (mOp == AllReduceFusionOp::RESIDUAL_RMS_NORM_QUANT_NVFP4)
    {
        TORCH_CHECK(scale, "scale is required for FP4 allreduce");
        TORCH_CHECK(norm_weight, "norm_weight is required for FP4 allreduce");
        TORCH_CHECK(!bias, "bias is not supported for FP4 allreduce");
        int const sfVecSize = 16;
        // Validate hidden_size before any size is derived from it (it is a divisor below).
        TORCH_CHECK(hidden_size > 0, "hidden_size must be positive for FP4 allreduce");
        TORCH_CHECK(hidden_size % sfVecSize == 0, "hidden_size must be divisible by 16 for FP4 allreduce");
        int const m = size / hidden_size;
        int const scale_size
            = tensorrt_llm::common::roundUp(m, 128) * tensorrt_llm::common::roundUp(hidden_size / sfVecSize, 4);
        // The quantized output is stored as bytes with the last dimension halved.
        auto output_shape = input.sizes().vec();
        output_shape.back() /= 2;
        auto [quant_out, ub_buffer1] = torch_ext::create_userbuffers_tensor(output_shape, torch::kByte);
        auto [scale_out, ub_buffer2] = torch_ext::create_userbuffers_tensor({scale_size}, torch::kByte);
        tensorrt_llm::kernels::ub::allreduce2_userbuff_inplace_rmsnorm_quant_fp4_launcher(ub_buffer0.handle, 0,
            ub_buffer1.handle, 0, ub_buffer2.handle, 0, size, hidden_size, nullptr, norm_weight.value().data_ptr(),
            mEps, static_cast<float*>(scale.value().data_ptr()), residual.value().data_ptr(),
            residual_out.data_ptr(), mType, ub_comm, stream);
        return {quant_out, scale_out, residual_out};
    }
    TORCH_CHECK(
        false, "UBAllreduce does not support the fusion operation: " + tensorrt_llm::kernels::toString(mOp));
    return {};
}
/// Plain NCCL sum-allreduce of @p input across mGroup, followed (for fused ops) by the unfused
/// fallback kernels.
///
/// @return {reduce_output} for AllReduceFusionOp::NONE; otherwise the outputs produced by
///         fallbackRunSubsequentOps() for mOp.
std::vector<torch::Tensor> runNCCLAllReduce(torch::Tensor const& input,
    torch::optional<torch::Tensor> const& residual, torch::optional<torch::Tensor> const& norm_weight,
    torch::optional<torch::Tensor> const& scale, torch::optional<torch::Tensor> const& bias)
{
    auto stream = at::cuda::getCurrentCUDAStream(input.get_device());
    // ncclAllReduce takes a size_t element count. Passing numel() straight through avoids truncating
    // it via int for tensors with more than INT_MAX elements.
    auto const count = static_cast<size_t>(input.numel());
    torch::Tensor reduce_output = torch::empty_like(input);
    NCCLCHECK_THROW(ncclAllReduce(input.data_ptr(), reduce_output.mutable_data_ptr(), count,
        (*getDtypeMap())[mType], ncclSum, *mNcclComm, stream));
    if (mOp == AllReduceFusionOp::NONE)
    {
        return {reduce_output};
    }
    // Treat any other patterns as fallback cases.
    return fallbackRunSubsequentOps(input, residual, norm_weight, scale, bias, reduce_output);
}
/// Low-precision custom allreduce. The tensor is reduced chunk by chunk (chunk sizes come from
/// splitNumber()); for fused ops, the unfused fallback kernels follow.
///
/// Groups of up to 4 ranks use LowPrecisionAllReduceParams::deserialize; larger groups use
/// deserialize_hier.
///
/// Not noexcept: without ENABLE_FP8 this always throws NotImplementedError, and TORCH_CHECK may
/// throw. Under noexcept either would terminate the process instead of reaching the caller.
std::vector<torch::Tensor> runLowPrecisionAllReduce(torch::Tensor const& input,
    torch::optional<torch::Tensor> const& residual, torch::optional<torch::Tensor> const& norm_weight,
    torch::optional<torch::Tensor> const& scale, torch::optional<torch::Tensor> const& bias)
{
#ifdef ENABLE_FP8
    auto stream = at::cuda::getCurrentCUDAStream(input.get_device());
    int size = input.numel();
    int hidden_size = input.size(-1);
    // hidden_size is a divisor below.
    TORCH_CHECK(hidden_size > 0, "hidden_size must be positive for LOWPRECISION allreduce");
    auto const tp_size = mGroup.size();
    auto const cur_rank = COMM_SESSION.getRank();
    // Position of this rank inside the ordered group (equals tp_size if the rank is not a member).
    int const tp_rank = static_cast<int>(std::distance(mGroup.begin(), mGroup.find(cur_rank)));
    int bytes_per_element = input.element_size();
    int token_num = size / hidden_size;
    auto parts = tensorrt_llm::kernels::splitNumber(size);
    torch::Tensor reduce_output = torch::empty_like(input);
    // Element offset of the current chunk within input / reduce_output.
    size_t global_offset = 0;
    for (size_t i = 0; i < parts.size(); ++i)
    {
        size_t tmp_size = parts[i];
        tensorrt_llm::kernels::LowPrecisionAllReduceParams tmp_param;
        if (tp_size <= 4)
        {
            tmp_param = tensorrt_llm::kernels::LowPrecisionAllReduceParams::deserialize(
                tp_size, tp_rank, mType, token_num, hidden_size);
        }
        else
        {
            tmp_param = tensorrt_llm::kernels::LowPrecisionAllReduceParams::deserialize_hier(
                tp_size, tp_rank, mType, token_num, hidden_size);
        }
        tmp_param.local_input_buffer_ptr = reinterpret_cast<void const*>(
            reinterpret_cast<char const*>(input.data_ptr()) + global_offset * bytes_per_element);
        tmp_param.local_output_buffer_ptr = reinterpret_cast<void*>(
            reinterpret_cast<char*>(reduce_output.mutable_data_ptr()) + global_offset * bytes_per_element);
        tmp_param.elts_total = tmp_size;
        tensorrt_llm::kernels::customLowPrecisionAllReduce(tmp_param, mType, stream);
        global_offset += tmp_size;
    }
    if (mOp == AllReduceFusionOp::NONE)
    {
        return {reduce_output};
    }
    // Treat any other patterns as fallback cases.
    return fallbackRunSubsequentOps(input, residual, norm_weight, scale, bias, reduce_output);
#else
    C10_THROW_ERROR(NotImplementedError, "Can't use LOWPRECISION without compile with ENABLE FP8.");
#endif
}
/// Custom allreduce via ar_fusion::allreduce_fusion_op, with optional fused residual add + RMS norm
/// + FP8/NVFP4 quantization.
///
/// @param workspace Required. Its data pointer is handed to the kernel as AllReduceFusionParams::workspace.
/// @param strategy  MIN_LATENCY selects the one-shot kernel when seq_len <= kOneShotMaxToken and the
///                  two-shot kernel otherwise. ONESHOT selects one-shot; any other value selects two-shot.
/// @return Outputs per mOp:
///   NONE                              -> {reduce_out}
///   RESIDUAL_RMS_NORM                 -> {norm_out, residual_out}
///   RESIDUAL_RMS_NORM_QUANT_FP8       -> {quant_out, residual_out}
///   RESIDUAL_RMS_NORM_OUT_QUANT_FP8   -> {norm_out, quant_out, residual_out}
///   RESIDUAL_RMS_NORM_QUANT_NVFP4     -> {quant_out, scale_out, residual_out}
///   RESIDUAL_RMS_NORM_OUT_QUANT_NVFP4 -> {norm_out, quant_out, scale_out, residual_out}
///
/// Not noexcept: the TORCH_CHECKs below throw, and those errors must reach the caller rather than
/// terminate the process. NOTE(review): @p bias is accepted but not consumed by this path.
std::vector<torch::Tensor> runFusionAllReduce(torch::Tensor const& input,
    torch::optional<torch::Tensor> const& residual, torch::optional<torch::Tensor> const& norm_weight,
    torch::optional<torch::Tensor> const& scale, torch::optional<torch::Tensor> const& bias,
    bool trigger_completion_at_end, torch::optional<torch::Tensor> workspace, AllReduceStrategyType strategy)
{
    namespace arf = tensorrt_llm::kernels::ar_fusion;

    bool const is_fp8_quant = mOp == AllReduceFusionOp::RESIDUAL_RMS_NORM_QUANT_FP8
        || mOp == AllReduceFusionOp::RESIDUAL_RMS_NORM_OUT_QUANT_FP8;
    bool const is_nvfp4_quant = mOp == AllReduceFusionOp::RESIDUAL_RMS_NORM_QUANT_NVFP4
        || mOp == AllReduceFusionOp::RESIDUAL_RMS_NORM_OUT_QUANT_NVFP4;
    bool const is_scale_factor_required = is_fp8_quant || is_nvfp4_quant;

    // Validate the op and every optional the kernel setup dereferences. A missing optional would
    // otherwise surface as std::bad_optional_access from .value().
    TORCH_CHECK(
        mOp == AllReduceFusionOp::NONE || mOp == AllReduceFusionOp::RESIDUAL_RMS_NORM || is_scale_factor_required,
        "Unsupported fusion operation: " + tensorrt_llm::kernels::toString(mOp));
    TORCH_CHECK(workspace.has_value(), "workspace is required for fusion allreduce");
    if (mOp != AllReduceFusionOp::NONE)
    {
        TORCH_CHECK(residual.has_value(), "residual is required for fused allreduce");
        TORCH_CHECK(norm_weight.has_value(), "norm_weight is required for fused allreduce");
    }
    TORCH_CHECK(!is_scale_factor_required || scale.has_value(), "scale is required for quantized fused allreduce");

    // Should handle only Lamport implementation
    auto stream = at::cuda::getCurrentCUDAStream(input.get_device());
    int size = input.numel();
    int hidden_size = input.size(-1);
    int seq_len = input.size(0);
    auto const tp_size = mGroup.size();
    auto const cur_rank = COMM_SESSION.getRank();
    // Position of this rank inside the ordered group (equals tp_size if the rank is not a member).
    int const tp_rank = static_cast<int>(std::distance(mGroup.begin(), mGroup.find(cur_rank)));

    torch::Tensor reduce_out;
    torch::Tensor residual_out;
    torch::Tensor norm_out;
    torch::Tensor quant_out;
    torch::Tensor scale_out;

    arf::AllReduceFusionParams allreduce_fusion_params;
    allreduce_fusion_params.residual_in = nullptr;
    allreduce_fusion_params.rms_gamma = nullptr;
    allreduce_fusion_params.allreduce_out = nullptr;
    allreduce_fusion_params.quant_out = nullptr;
    allreduce_fusion_params.scale_out = nullptr;
    allreduce_fusion_params.residual_out = nullptr;
    allreduce_fusion_params.norm_out = nullptr;
    allreduce_fusion_params.trigger_completion_at_end = trigger_completion_at_end;

    // Determine if using oneshot or twoshot allreduce kernel
    if (strategy == AllReduceStrategyType::MIN_LATENCY)
    {
        allreduce_fusion_params.use_oneshot = seq_len <= arf::kOneShotMaxToken;
    }
    else
    {
        allreduce_fusion_params.use_oneshot = strategy == AllReduceStrategyType::ONESHOT;
    }

    // Check for some kernel constraints if using TWOSHOT kernel
    if (!allreduce_fusion_params.use_oneshot)
    {
        TORCH_CHECK(input.size(0) >= static_cast<int64_t>(tp_size),
            "Sequence length must be greater than or equal to TP size");
    }

    // Allocate the outputs required by each pattern and select the kernel pattern.
    if (mOp == AllReduceFusionOp::NONE)
    {
        reduce_out = torch::empty_like(input);
        allreduce_fusion_params.allreduce_out = reduce_out.mutable_data_ptr();
        allreduce_fusion_params.pattern = arf::AllReduceFusionPattern::kAllReduce;
    }
    else if (mOp == AllReduceFusionOp::RESIDUAL_RMS_NORM)
    {
        norm_out = torch::empty_like(input);
        residual_out = torch::empty_like(residual.value());
        allreduce_fusion_params.norm_out = norm_out.mutable_data_ptr();
        allreduce_fusion_params.residual_out = residual_out.mutable_data_ptr();
        allreduce_fusion_params.pattern = arf::AllReduceFusionPattern::kARResidualRMSNorm;
    }
    else if (is_fp8_quant)
    {
        quant_out = at::detail::empty_cuda(input.sizes(), torch::kFloat8_e4m3fn, input.device(), std::nullopt);
        residual_out = torch::empty_like(residual.value());
        allreduce_fusion_params.quant_out = quant_out.mutable_data_ptr();
        allreduce_fusion_params.residual_out = residual_out.mutable_data_ptr();
        allreduce_fusion_params.pattern = arf::AllReduceFusionPattern::kARResidualRMSNormFP8Quant;
        // The OUT variant additionally materializes the normalized (unquantized) output.
        if (mOp == AllReduceFusionOp::RESIDUAL_RMS_NORM_OUT_QUANT_FP8)
        {
            norm_out = torch::empty_like(input);
            allreduce_fusion_params.norm_out = norm_out.mutable_data_ptr();
            allreduce_fusion_params.pattern = arf::AllReduceFusionPattern::kARResidualRMSNormOutFP8Quant;
        }
    }
    else // is_nvfp4_quant, guaranteed by the validation at the top.
    {
        // TODO: Better check for each pattern
        int64_t sf_vec_size = 16;
        int64_t m = 1;
        auto const& input_shape = input.sizes();
        auto const& r = input_shape.size();
        TORCH_CHECK(r >= 2, "Input should be >=2D tensor.");
        for (size_t i = 0; i < r - 1; i++)
        {
            m *= input_shape[i];
        }
        auto const k = input_shape[r - 1];
        TORCH_CHECK(k % sf_vec_size == 0, "Input should be divisible by sfVecSize.");
        // Two FP4 values share one FLOAT4_E2M1X2 element, so the last dimension halves.
        std::vector<int64_t> output_shape(input_shape.begin(), input_shape.end());
        output_shape[r - 1] = k / 2;
        quant_out = at::detail::empty_cuda(output_shape, FLOAT4_E2M1X2, input.device(), std::nullopt);
        scale_out = at::detail::empty_cuda({tensorrt_llm::computeFP4SwizzledLayoutSFSize(m, k / sf_vec_size)},
            SF_DTYPE, input.device(), std::nullopt);
        residual_out = torch::empty_like(residual.value());
        allreduce_fusion_params.quant_out = quant_out.mutable_data_ptr();
        allreduce_fusion_params.scale_out = scale_out.mutable_data_ptr();
        allreduce_fusion_params.residual_out = residual_out.mutable_data_ptr();
        allreduce_fusion_params.pattern = arf::AllReduceFusionPattern::kARResidualRMSNormFP4Quant;
        // The OUT variant additionally materializes the normalized (unquantized) output.
        if (mOp == AllReduceFusionOp::RESIDUAL_RMS_NORM_OUT_QUANT_NVFP4)
        {
            norm_out = torch::empty_like(input);
            allreduce_fusion_params.norm_out = norm_out.mutable_data_ptr();
            allreduce_fusion_params.pattern = arf::AllReduceFusionPattern::kARResidualRMSNormOutFP4Quant;
        }
    }

    allreduce_fusion_params.nranks = tp_size;
    allreduce_fusion_params.rank = tp_rank;
    allreduce_fusion_params.dtype = mType;
    allreduce_fusion_params.size = size;
    allreduce_fusion_params.hidden_dim = hidden_size;
    allreduce_fusion_params.workspace = reinterpret_cast<void**>(workspace.value().mutable_data_ptr());
    allreduce_fusion_params.allreduce_in = input.data_ptr();
    if (mOp != AllReduceFusionOp::NONE)
    {
        allreduce_fusion_params.residual_in = residual.value().data_ptr();
        allreduce_fusion_params.rms_gamma = norm_weight.value().data_ptr();
        allreduce_fusion_params.rms_eps = mEps;
    }
    allreduce_fusion_params.stream = stream;
    allreduce_fusion_params.scale_factor
        = is_scale_factor_required ? static_cast<float*>(scale.value().data_ptr()) : nullptr;

    arf::allreduce_fusion_op(allreduce_fusion_params);

    // Pack output tensors
    switch (mOp)
    {
    case AllReduceFusionOp::NONE: return {reduce_out};
    case AllReduceFusionOp::RESIDUAL_RMS_NORM: return {norm_out, residual_out};
    case AllReduceFusionOp::RESIDUAL_RMS_NORM_QUANT_FP8: return {quant_out, residual_out};
    case AllReduceFusionOp::RESIDUAL_RMS_NORM_OUT_QUANT_FP8: return {norm_out, quant_out, residual_out};
    case AllReduceFusionOp::RESIDUAL_RMS_NORM_QUANT_NVFP4: return {quant_out, scale_out, residual_out};
    case AllReduceFusionOp::RESIDUAL_RMS_NORM_OUT_QUANT_NVFP4:
        return {norm_out, quant_out, scale_out, residual_out};
    default: TORCH_CHECK(false, "Unsupported fusion operation: " + tensorrt_llm::kernels::toString(mOp));
    }
    return {};
}
/// Unfused tail for strategies that only perform a plain allreduce (NCCL, LOWPRECISION).
///
/// It always runs residualRmsNorm on @p reduce_output (passed as intermediate_buffer) into a fresh
/// norm_out. For quantizing ops it then quantizes norm_out with a separate op. @p reduce_output is
/// returned in the residual position of every output list.
/// Presumably residualRmsNorm leaves the residual-added sum in intermediate_buffer -- confirm.
///
/// @return RESIDUAL_RMS_NORM                 -> {norm_out, reduce_output}
///         RESIDUAL_RMS_NORM_QUANT_FP8       -> {quant_out, reduce_output}
///         RESIDUAL_RMS_NORM_OUT_QUANT_FP8   -> {norm_out, quant_out, reduce_output}
///         RESIDUAL_RMS_NORM_QUANT_NVFP4     -> {quant_out, scale_out, reduce_output}
///         RESIDUAL_RMS_NORM_OUT_QUANT_NVFP4 -> {norm_out, quant_out, scale_out, reduce_output}
std::vector<torch::Tensor> fallbackRunSubsequentOps(torch::Tensor const& input,
    torch::optional<torch::Tensor> const& residual, torch::optional<torch::Tensor> const& norm_weight,
    torch::optional<torch::Tensor> const& scale, torch::optional<torch::Tensor> const& bias,
    torch::Tensor& reduce_output)
{
    // Raw data pointer of an optional tensor, or nullptr when the tensor is absent.
    auto const pointer_or_null = [](torch::optional<torch::Tensor> const& maybe_tensor)
    { return maybe_tensor ? maybe_tensor.value().data_ptr() : nullptr; };

    auto const element_count = input.numel();
    auto const last_dim = input.size(-1);
    auto const stream = at::cuda::getCurrentCUDAStream(input.get_device());
    torch::Tensor norm_out = torch::empty_like(input);

    tensorrt_llm::kernels::AllReduceParams params;
    params.fusion_params.bias_buffer = pointer_or_null(bias);
    params.fusion_params.residual_buffer = pointer_or_null(residual);
    params.fusion_params.weight_buffer = pointer_or_null(norm_weight);
    params.local_output_buffer_ptr = norm_out.mutable_data_ptr();
    params.elts_total = element_count;
    params.fusion_params.hidden_size = last_dim;
    params.fusion_params.eps = mEps;
    params.fusion_params.intermediate_buffer = reduce_output.mutable_data_ptr();
    tensorrt_llm::kernels::residualRmsNorm(params, mType, stream, AllReduceFusionOp::RESIDUAL_RMS_NORM);

    if (mOp == AllReduceFusionOp::RESIDUAL_RMS_NORM)
    {
        return {norm_out, reduce_output};
    }

    TORCH_CHECK(scale, "scale is required for quantization ops");

    // The *_OUT_* variants also hand back the unquantized norm output.
    bool const also_returns_norm = mOp == AllReduceFusionOp::RESIDUAL_RMS_NORM_OUT_QUANT_FP8
        || mOp == AllReduceFusionOp::RESIDUAL_RMS_NORM_OUT_QUANT_NVFP4;

    if (mOp == AllReduceFusionOp::RESIDUAL_RMS_NORM_QUANT_FP8
        || mOp == AllReduceFusionOp::RESIDUAL_RMS_NORM_OUT_QUANT_FP8)
    {
        auto [quant_out, scale_out] = torch_ext::symmetric_static_quantize_per_tensor(norm_out, scale.value());
        if (also_returns_norm)
        {
            return {norm_out, quant_out, reduce_output};
        }
        return {quant_out, reduce_output};
    }

    if (mOp == AllReduceFusionOp::RESIDUAL_RMS_NORM_QUANT_NVFP4
        || mOp == AllReduceFusionOp::RESIDUAL_RMS_NORM_OUT_QUANT_NVFP4)
    {
        int64_t const block_size = 16;
        bool const ue8m0_scales = false;
        bool const swizzled_scales = true;
        auto [quant_out, scale_out]
            = torch_ext::fp4_quantize(norm_out, scale.value(), block_size, ue8m0_scales, swizzled_scales);
        if (also_returns_norm)
        {
            return {norm_out, quant_out, scale_out, reduce_output};
        }
        return {quant_out, scale_out, reduce_output};
    }

    TORCH_CHECK(false, "Unsupported fusion operation: " + tensorrt_llm::kernels::toString(mOp));
    return {};
}
/// Chooses the allreduce implementation for one call.
///
/// UB and NCCL requests are honored verbatim. For any other requested strategy, selectImplementation()
/// decides. The exception is when OVERRIDE_HEURISTIC_ALLREDUCE_STRATEGY is set in the environment
/// (a DEBUG/BENCHMARK aid); then the requested strategy is used unchanged. The variable is read once
/// per process, the first time this path is reached.
AllReduceStrategyType getRuntimeStrategy(size_t seq_len, size_t size)
{
    switch (mStrategy)
    {
    case AllReduceStrategyType::UB:
    case AllReduceStrategyType::NCCL: return mStrategy;
    default: break;
    }

    static char* benchmarkOverride = std::getenv("OVERRIDE_HEURISTIC_ALLREDUCE_STRATEGY");
    bool const skipHeuristic = benchmarkOverride != nullptr;
    return skipHeuristic ? mStrategy : selectImplementation(seq_len, size, mGroup.size(), mType);
}
/// Emits a debug log line naming the allreduce implementation chosen for this call.
///
/// Covers every strategy run() can dispatch to, including ONESHOT and TWOSHOT (reachable when the
/// heuristic is overridden). Any other value is logged by its numeric value instead of being
/// silently dropped.
void logRunTimeStrategy(AllReduceStrategyType strategy, int rank)
{
    char const* name = nullptr;
    switch (strategy)
    {
    case AllReduceStrategyType::NCCL: name = "NCCL"; break;
    case AllReduceStrategyType::MIN_LATENCY: name = "MIN_LATENCY"; break;
    case AllReduceStrategyType::UB: name = "UB"; break;
    case AllReduceStrategyType::LOWPRECISION: name = "LOWPRECISION"; break;
    case AllReduceStrategyType::ONESHOT: name = "ONESHOT"; break;
    case AllReduceStrategyType::TWOSHOT: name = "TWOSHOT"; break;
    default: break;
    }

    if (name != nullptr)
    {
        TLLM_LOG_DEBUG("AllReducePlugin strategy for rank %d: %s", rank, name);
    }
    else
    {
        TLLM_LOG_DEBUG("AllReducePlugin strategy for rank %d: unknown (%d)", rank, static_cast<int>(strategy));
    }
}
/// Fills mIsNVLINKSupported / mIsP2PSupported for mGroup and memoizes the result per group.
///
/// The cache is a function-local static shared by every AllreduceOp in the process. Detection goes
/// through setGroupTopology(), which exchanges messages with the other group members, so it runs
/// at most once per distinct group.
/// NOTE(review): the cache is not synchronized; confirm initialize() is never called concurrently.
void initGroupTopology()
{
    static std::map<std::set<int>, std::tuple<bool, bool>> cache;
    if (auto const hit = cache.find(mGroup); hit != cache.end())
    {
        std::tie(mIsNVLINKSupported, mIsP2PSupported) = hit->second;
        return;
    }
    setGroupTopology();
    cache[mGroup] = {mIsNVLINKSupported, mIsP2PSupported};
}
// Detects the connectivity of mGroup and stores it in mIsP2PSupported / mIsNVLINKSupported.
// - If any group member lives on another node, both flags are false.
// - Otherwise every unordered pair of local devices is probed. P2P requires cudaDeviceCanAccessPeer
//   for every pair. NVLink requires every pair to be joined either by a direct NVLink or through a
//   shared remote NVLink endpoint (presumably an NVSwitch).
// NOTE(review): the node-local ranks from getLocalGroup are used directly as CUDA device ordinals
// and as NVML device indices; this assumes local rank == CUDA ordinal == NVML index (e.g. no
// CUDA_VISIBLE_DEVICES remapping) -- TODO confirm.
void setGroupTopology()
{
    auto const rank = COMM_SESSION.getRank();
    TLLM_LOG_INFO("Detecting local TP group for rank %d", rank);
    std::set<int> local_group = getLocalGroup(mGroup);
    // Fewer local members than group members => the group spans multiple nodes.
    if (mGroup.size() != local_group.size())
    {
        mIsP2PSupported = false;
        mIsNVLINKSupported = false;
        TLLM_LOG_INFO("Found inter-node TP group for rank %d", rank);
        return;
    }
    TLLM_LOG_INFO("TP group is intra-node for rank %d", rank);
    // Keeps NVML initialized for the rest of this function (shut down on every return path).
    NvmlManager nvml_manager;
    // Devices already used as `first_device_id`; pairs involving them were checked earlier, so
    // each unordered pair is probed only once.
    std::unordered_set<int> visited_device;
    mIsP2PSupported = true;
    mIsNVLINKSupported = true;
    // Use cudaDeviceCanAccessPeer to determine whether P2P is supported,
    // and use NVML to determine whether there are NVLink links between ranks.
    for (int first_device_id : local_group)
    {
        for (int second_device_id : local_group)
        {
            if (first_device_id == second_device_id
                || visited_device.find(second_device_id) != visited_device.end())
            {
                continue;
            }
            int can_access_peer = 0;
            TLLM_CUDA_CHECK(cudaDeviceCanAccessPeer(&can_access_peer, first_device_id, second_device_id));
            // A single pair without P2P disables both P2P and NVLink for the whole group.
            if (!can_access_peer)
            {
                mIsP2PSupported = false;
                mIsNVLINKSupported = false;
                return;
            }
            nvmlDevice_t first_device;
            NVML_CHECK_THROW(nvmlDeviceGetHandleByIndex(first_device_id, &first_device));
            bool is_NVLINK = false;
            // Scan every NVLink slot of the first device; slots that report no remote endpoint are skipped.
            for (unsigned int link = 0; link < NVML_NVLINK_MAX_LINKS; link++)
            {
                nvmlPciInfo_t remote_pci_info;
                if (nvmlDeviceGetNvLinkRemotePciInfo_v2(first_device, link, &remote_pci_info) != NVML_SUCCESS)
                {
                    continue;
                }
                nvmlDevice_t remote_device;
                auto const result = nvmlDeviceGetHandleByPciBusId_v2(remote_pci_info.busId, &remote_device);
                if (result == NVML_SUCCESS)
                {
                    // The remote endpoint is a GPU: the two GPUs are connected directly through NVLink
                    // if it is the second device.
                    unsigned int remote_device_id;
                    NVML_CHECK_THROW(nvmlDeviceGetIndex(remote_device, &remote_device_id));
                    if (remote_device_id == static_cast<unsigned int>(second_device_id))
                    {
                        is_NVLINK = true;
                    }
                }
                else if (result == NVML_ERROR_NOT_FOUND)
                {
                    // The remote endpoint is not a GPU, presumably an NVSwitch. remote_pci_info then
                    // describes the switch. The GPUs count as NVLink-connected if the second device
                    // also has a link whose remote endpoint has the same PCI bus id.
                    nvmlDevice_t second_device;
                    NVML_CHECK_THROW(nvmlDeviceGetHandleByIndex(second_device_id, &second_device));
                    for (unsigned int second_link = 0; second_link < NVML_NVLINK_MAX_LINKS; second_link++)
                    {
                        nvmlPciInfo_t second_remote_pci_info;
                        if (nvmlDeviceGetNvLinkRemotePciInfo_v2(second_device, second_link, &second_remote_pci_info)
                            != NVML_SUCCESS)
                        {
                            continue;
                        }
                        if (strcmp(remote_pci_info.busId, second_remote_pci_info.busId) == 0)
                        {
                            is_NVLINK = true;
                            break;
                        }
                    }
                }
                else
                {
                    // Any other NVML failure is reported via NVML_CHECK_THROW.
                    NVML_CHECK_THROW(result);
                }
                if (is_NVLINK)
                {
                    break;
                }
            }
            // NVLink support for the group requires every probed pair to be NVLink-connected.
            mIsNVLINKSupported &= is_NVLINK;
        }
        visited_device.insert(first_device_id);
    }
}
/// Decides whether the custom allreduce kernels must be abandoned in favour of NCCL.
///
/// Falls back, checked in this order, when:
///   1. the message is larger than the custom-allreduce workspace,
///   2. the group lacks peer-to-peer access, or
///   3. the group lacks NVLink connectivity.
/// A warning naming the reason is logged only when the strategy was requested explicitly
/// (@p is_auto == false); under AUTO the fallback is silent.
///
/// @param seq_len            Not consulted by any of the rules.
/// @param message_size_bytes Payload size in bytes.
/// @param max_workspace_size Workspace capacity, compared directly against @p message_size_bytes.
/// @param is_auto            Whether the strategy was chosen automatically.
/// @return true when NCCL should be used.
bool ifFallbackToNCCL(size_t seq_len, size_t message_size_bytes, size_t max_workspace_size, bool is_auto)
{
    char const* fallback_reason = nullptr;
    if (message_size_bytes > max_workspace_size)
    {
        fallback_reason = "Since messageSize is greater than maxWorkspaceSize, fallback to AllReduceStrategy: NCCL";
    }
    else if (!mIsP2PSupported)
    {
        fallback_reason = "Since Peer to Peer not supported, fallback to AllReduceStrategy: NCCL";
    }
    else if (!mIsNVLINKSupported)
    {
        fallback_reason = "Since NVLINK not supported, fallback to AllReduceStrategy: NCCL";
    }

    bool const must_fall_back = fallback_reason != nullptr;
    if (must_fall_back && !is_auto)
    {
        TLLM_LOG_WARNING("%s", fallback_reason);
    }
    return must_fall_back;
}
/// Pick the concrete all-reduce strategy for one call.
///
/// Decision order:
///   1. LOWPRECISION if isUsingLowPrecision() accepts this message. Otherwise a stored LOWPRECISION request is
///      rewritten to AUTO (this mutates mStrategy, so it persists for later calls on this object).
///   2. NCCL if ifFallbackToNCCL() reports the custom kernels cannot be used.
///   3. Per fusion op: the quantizing RMS-norm fusions return MIN_LATENCY; unknown fusion ops return NCCL.
///   4. NONE / RESIDUAL_RMS_NORM: an explicit strategy is honoured (ONESHOT/TWOSHOT map to MIN_LATENCY).
///      Under AUTO, MIN_LATENCY is used for world_size <= 2; otherwise a token-count threshold decides.
///
/// @param seq_len      Token count, compared against the AUTO threshold.
/// @param message_size Element count (converted to bytes with the dtype size below).
/// @param world_size   Number of participating ranks.
/// @param type         Element data type.
AllReduceStrategyType selectImplementation(
    size_t seq_len, size_t message_size, int world_size, nvinfer1::DataType type)
{
    if (isUsingLowPrecision(message_size))
    {
        return AllReduceStrategyType::LOWPRECISION;
    }
    // LOWPRECISION was requested but is not applicable here; downgrade the stored request to AUTO.
    if (mStrategy == AllReduceStrategyType::LOWPRECISION)
    {
        mStrategy = AllReduceStrategyType::AUTO;
    }
    // is_auto controls whether fallback warnings are logged and whether the token-count heuristic runs.
    bool const is_auto = (mStrategy == AllReduceStrategyType::AUTO);
    auto const message_size_bytes = message_size * tensorrt_llm::common::getDTypeSize(type);
    auto const max_workspace_size
        = tensorrt_llm::utils::customAllReduceUtils::getMaxRequiredWorkspaceSize(world_size);
    if (ifFallbackToNCCL(seq_len, message_size_bytes, max_workspace_size, is_auto))
    {
        return AllReduceStrategyType::NCCL;
    }
    // This rule-based heuristic only chooses between NCCL and MIN_LATENCY strategies.
    // The heuristic is only applied to the NONE and RESIDUAL_RMS_NORM fusion types, because NCCL might be
    // faster for some large messageSize cases there.
    // Otherwise, MIN_LATENCY strategy is returned directly because it supports more fusions.
    // TODO: NCCL AllReduce + subsequent quantization ops (as fallback) can also support the fusion types.
    // This should be compared with MIN_LATENCY fused kernels to determine the best strategy.
    switch (mOp)
    {
    case AllReduceFusionOp::NONE:
    case AllReduceFusionOp::RESIDUAL_RMS_NORM: break;
    case AllReduceFusionOp::RESIDUAL_RMS_NORM_QUANT_FP8:
    case AllReduceFusionOp::RESIDUAL_RMS_NORM_OUT_QUANT_FP8:
    case AllReduceFusionOp::RESIDUAL_RMS_NORM_QUANT_NVFP4:
    case AllReduceFusionOp::RESIDUAL_RMS_NORM_OUT_QUANT_NVFP4: return AllReduceStrategyType::MIN_LATENCY;
    // Suppose NCCL has fallback implementations for all fusion types.
    default: return AllReduceStrategyType::NCCL;
    }
    // Every other fusion op returned above; this guards against future edits to the switch.
    TORCH_CHECK(mOp == AllReduceFusionOp::NONE || mOp == AllReduceFusionOp::RESIDUAL_RMS_NORM,
        "Only NONE and RESIDUAL_RMS_NORM are supported for NCCL/MIN_LATENCY heuristic.");
    // Default to NCCL.
    AllReduceStrategyType strategy = AllReduceStrategyType::NCCL;
    // Currently we will not remove ONESHOT and TWOSHOT from the strategy list,
    // but torch flow users should use AUTO or MIN_LATENCY instead.
    // NOTICE: When a fusion type is not supported by the corresponding strategy but strategy is not AUTO,
    // the user must guarantee the correctness of the fusion pattern dispatching.
    if (!is_auto)
    {
        if (mStrategy == AllReduceStrategyType::ONESHOT || mStrategy == AllReduceStrategyType::TWOSHOT)
        {
            strategy = AllReduceStrategyType::MIN_LATENCY;
        }
        else
        {
            strategy = mStrategy;
        }
    }
    else if (world_size <= 2)
    {
        strategy = AllReduceStrategyType::MIN_LATENCY;
    }
    else
    {
        // Token-count threshold above which AUTO prefers NCCL. It can be overridden through
        // ALLREDUCE_AUTO_HEURISTIC_MIN_LATENCY_THRESHOLD_TOKEN_NUM and is parsed once per process.
        // A value with no leading digits or a negative value falls back to the default. std::atoi would
        // silently turn the former into 0 and the latter would wrap to SIZE_MAX when cast to size_t;
        // std::strtoll also saturates instead of invoking UB on out-of-range input.
        static size_t const threshold = []() -> size_t
        {
            constexpr size_t kDefaultThreshold = 128;
            char const* const env = std::getenv("ALLREDUCE_AUTO_HEURISTIC_MIN_LATENCY_THRESHOLD_TOKEN_NUM");
            if (env == nullptr)
            {
                return kDefaultThreshold;
            }
            char* end = nullptr;
            long long const parsed = std::strtoll(env, &end, 10);
            if (end == env || parsed < 0)
            {
                TLLM_LOG_WARNING(
                    "Invalid ALLREDUCE_AUTO_HEURISTIC_MIN_LATENCY_THRESHOLD_TOKEN_NUM, using the default of 128");
                return kDefaultThreshold;
            }
            return static_cast<size_t>(parsed);
        }();
        // Generally, NCCL is faster than MIN_LATENCY when the token number is greater than 256.
        // The default threshold is conservatively set to 128 tokens.
        if (seq_len > threshold)
        {
            strategy = AllReduceStrategyType::NCCL;
        }
        else
        {
            strategy = AllReduceStrategyType::MIN_LATENCY;
        }
    }
    return strategy;
}
/// Whether this call should take the LOWPRECISION (FP8) all-reduce path.
///
/// All of the following must hold:
///   - LOWPRECISION was explicitly requested.
///   - The ranks are not NVLink-connected.
///   - Peer-to-peer access is available.
///   - The message is at least kLowPrecisionMinMessageSize.
/// Always false in builds without ENABLE_FP8.
///
/// @param message_size Size of the message. NOTE(review): selectImplementation passes an element count and
///                     converts to bytes separately, so the threshold below is 2Mi elements rather than the
///                     "2MB" an earlier comment stated. Confirm which was intended.
bool isUsingLowPrecision(size_t message_size) const noexcept
{
#ifdef ENABLE_FP8
    // size_t avoids a signed/unsigned comparison against message_size.
    constexpr size_t kLowPrecisionMinMessageSize = 2 * 1024 * 1024;
    bool const force_low_precision = mStrategy == AllReduceStrategyType::LOWPRECISION;
    return force_low_precision && !mIsNVLINKSupported && mIsP2PSupported
        && message_size >= kLowPrecisionMinMessageSize;
#else
    // Low precision is not available when FP8 is not enabled.
    static_cast<void>(message_size);
    return false;
#endif
}
private:
// Ranks of the all-reduce group (presumably populated from the constructor's `group` argument; TODO confirm).
std::set<int> mGroup;
// AND-accumulated (`&=`) over the device pairs examined during NVLink topology detection; presumably
// initialized to true so it ends false if any checked pair lacks NVLink. Read by ifFallbackToNCCL and
// isUsingLowPrecision.
bool mIsNVLINKSupported;
// Peer-to-peer availability; when false, ifFallbackToNCCL forces NCCL.
bool mIsP2PSupported;
// Element data type of the tensors being reduced.
nvinfer1::DataType mType;
// Requested strategy. selectImplementation may rewrite LOWPRECISION to AUTO when low precision does not apply.
AllReduceStrategyType mStrategy;
// Fusion pattern around the all-reduce; selects the heuristic branch in selectImplementation.
AllReduceFusionOp mOp;
// Epsilon passed at construction (presumably for the RMS-norm fusions; TODO confirm).
float mEps;
// NCCL communicator handle, held with shared ownership.
std::shared_ptr<ncclComm_t> mNcclComm;
};
} // namespace
#endif // ENABLE_MULTI_DEVICE
/// Torch-facing entry point: run one (optionally fused) all-reduce across the ranks in `group_`.
///
/// The integer/double arguments coming from Python are converted to the C++ enum/float types that
/// AllreduceOp expects. A fresh AllreduceOp is built and initialized for each call.
/// In builds without ENABLE_MULTI_DEVICE no communication happens and `input` is returned unchanged.
std::vector<torch::Tensor> allreduce(torch::Tensor const& input, torch::optional<torch::Tensor> const& residual,
    torch::optional<torch::Tensor> const& norm_weight, torch::optional<torch::Tensor> const& scale,
    torch::optional<torch::Tensor> const& bias, torch::optional<torch::Tensor> workspace,
    torch::List<int64_t> const& group_, int64_t const strategy_, int64_t const fusion_op_, double const eps_,
    bool const trigger_completion_at_end_)
{
#if ENABLE_MULTI_DEVICE
    auto const data_type = tensorrt_llm::runtime::TorchUtils::dataType(input.scalar_type());
    // Enum values travel through the op schema as int64; narrow to the enums' 8-bit representation first.
    auto const requested_strategy = static_cast<AllReduceStrategyType>(static_cast<int8_t>(strategy_));
    auto const requested_fusion = static_cast<AllReduceFusionOp>(static_cast<int8_t>(fusion_op_));
    auto const epsilon = static_cast<float>(eps_);

    std::set<int> ranks;
    for (int64_t const member : group_)
    {
        ranks.insert(static_cast<int>(member));
    }

    AllreduceOp reducer(ranks, data_type, requested_strategy, requested_fusion, epsilon);
    reducer.initialize();
    return reducer.run(input, residual, norm_weight, scale, bias, trigger_completion_at_end_, workspace);
#else
    // Single-device build: nothing to reduce across.
    return {input};
#endif // ENABLE_MULTI_DEVICE
}
// residual [m, hidden_dim]
// norm_weight [hidden_dim]
// device_num_experts [1]
// scale_input [global_num_experts, m]
// active_experts_token_input [device_num_experts, m, hidden_dim]
// token_input [m, hidden_dim]
std::vector<torch::Tensor> moe_allreduce(torch::Tensor const& residual, torch::Tensor const& norm_weight,
torch::Tensor const& device_num_experts, torch::Tensor const& scale_input,
torch::Tensor const& active_experts_token_input, torch::Tensor const& token_input, torch::Tensor workspace,
int64_t const rank, int64_t const nranks, double const eps)
{
auto allreduce_fusion_params = tensorrt_llm::kernels::ar_fusion::moe::MoeReductionAllReduceFusionParams();