4 changes: 2 additions & 2 deletions CMakeLists.txt
@@ -1173,17 +1173,17 @@ if(BUILD_TEST)
list(APPEND MULTIDEVICE_TEST_SRCS
${NVFUSER_ROOT}/tests/cpp/multidevice.cpp
${NVFUSER_ROOT}/tests/cpp/multidevice_transformer.cpp
${NVFUSER_ROOT}/tests/cpp/test_multidevice_host_ir_overlap.cpp
${NVFUSER_ROOT}/tests/cpp/test_multidevice_communications.cpp
${NVFUSER_ROOT}/tests/cpp/test_multidevice_communicator.cpp
${NVFUSER_ROOT}/tests/cpp/test_multidevice_host_ir.cpp
${NVFUSER_ROOT}/tests/cpp/test_multidevice_host_ir_overlap.cpp
${NVFUSER_ROOT}/tests/cpp/test_multidevice_ipc.cpp
${NVFUSER_ROOT}/tests/cpp/test_multidevice_lower_communication.cpp
${NVFUSER_ROOT}/tests/cpp/test_multidevice_matmul.cpp
${NVFUSER_ROOT}/tests/cpp/test_multidevice_pipeline.cpp
${NVFUSER_ROOT}/tests/cpp/test_multidevice_sharding.cpp
${NVFUSER_ROOT}/tests/cpp/test_multidevice_stream_parallel_type.cpp
${NVFUSER_ROOT}/tests/cpp/test_multidevice_transformer.cpp
${NVFUSER_ROOT}/tests/cpp/test_multidevice_ipc.cpp
)
add_test_without_main(test_multidevice "${MULTIDEVICE_TEST_SRCS}" "")
list(APPEND TEST_BINARIES test_multidevice)
2 changes: 1 addition & 1 deletion csrc/host_ir/evaluator.cpp
@@ -777,7 +777,7 @@ void HostIrEvaluator::handle(ShardByStream* shard) {
IterDomain* stream_id = *i;

auto in_tensor = getKnownConcreteValue(shard->in()).as<at::Tensor>();
int64_t stream_index =
auto stream_index =
expr_evaluator_.evaluate(shard->stream_index()).as<int64_t>();
at::Tensor out_tensor =
in_tensor
2 changes: 1 addition & 1 deletion csrc/host_ir/host_ir.cpp
@@ -271,7 +271,7 @@ Wait::Wait(IrBuilderPasskey passkey, Expr* expr)
NVF_ERROR(
(expr->isOneOf<Communication, P2PCommunication, EndCoalescing>()),
expr,
"must be a Communication, a P2PCommunication, or a EndCoalescing");
" must be a Communication, a P2PCommunication, or a EndCoalescing");
}

NVFUSER_DEFINE_CLONE_AND_CREATE(Wait)
92 changes: 64 additions & 28 deletions csrc/host_ir/lowering.cpp
@@ -20,18 +20,22 @@ namespace nvfuser {
namespace {

struct LoopInfo {
hir::ForLoop* loop;
hir::ForLoop* loop = nullptr;

// The Scope that owns `loop`. It's one level outer than `loop`'s body scope.
Scope* parent_scope;
Scope* parent_scope = nullptr;

// The iterator that points to `loop`. This way, we can insert instructions,
// e.g. Allocate, right before the loop.
Scope::Iterator parent_insertion_point;
};

std::ostream& operator<<(std::ostream& os, const LoopInfo& loop_info) {
os << loop_info.loop->toInlineString();
if (loop_info.loop == nullptr) {
os << "<null>";
} else {
os << loop_info.loop->toInlineString();
}
return os;
}

Expand All @@ -57,6 +61,8 @@ class LoopNest {
return loop_infos_.back();
}

// Returns the scope of the innermost for-loop or the top-level scope if the
// loop nest is empty.
Scope& innermostScope() const {
return empty() ? top_level_ : innermost().loop->body();
}
@@ -131,7 +137,7 @@ Expr* cloneWithNewOperands(
int64_t out_replaced = std::ranges::count_if(new_outs, maybe_replace);

if (in_replaced == 0 && out_replaced == 0) {
return 0;
return e;
}

if (out_replaced > 0) {
@@ -151,6 +157,12 @@ void lowerSegment(
hir::HostIrContainer& hic,
LoopNest& loop_nest,
IrCloner& ir_cloner) {
Scope& innermost_scope = loop_nest.innermostScope();
LoopInfo innermost;
if (!loop_nest.empty()) {
innermost = loop_nest.innermost();
}

switch (group.schedulerType()) {
case SchedulerType::Communication: {
auto device_id = Communicator::getInstance().deviceId();
@@ -162,24 +174,50 @@ void lowerSegment(
// without cloning the value again.
Expr* e = ir_cloner.clone(group.exprs().front());

for (auto* c : convertSingleOpToCommunication(e, device_id)) {
// FIXME: should this be associated with the scope?
std::unordered_map<Val*, Val*> replacement_map;
for (Expr* c : convertSingleOpToCommunication(e, device_id)) {
NVF_ERROR(
c->isA<Communication>(),
"Exprs in a Communication group should be Communication: ",
c);
// Allocate the recv buffers of communications
auto* communication = c->as<Communication>();
TensorView* tv = communication->out();
if (tv->getDeviceMesh().has(device_id)) {
auto* allocate =
IrBuilder::create<kir::Allocate>(tv, MemoryType::Global);
// TODO: allocation may have to go to the top level. See how
// SchedulerType::ExprEval handles allocations.
loop_nest.innermostScope().push_back(allocate);
TensorView* in = communication->in();
TensorView* out = communication->out();
if (getShardedIterDomain(in, ParallelType::Stream, DomainType::kLoop) !=
nullptr &&
getShardedIterDomain(
in, ParallelType::Stream, DomainType::kAllocation) == nullptr) {
auto [i, inserted] = replacement_map.try_emplace(
in, hir::shardByStream(in, innermost.loop->index()));
if (inserted) {
innermost_scope.push_back(i->second->definition());
}
Comment on lines +188 to +196
Contributor:
logic: null pointer dereference if loop_nest is empty. innermost.loop is null when loop_nest.empty() is true (lines 161-164), but this code calls innermost.loop->index() on line 192 without checking whether innermost.loop is null.

Suggested change
if (getShardedIterDomain(in, ParallelType::Stream, DomainType::kLoop) !=
nullptr &&
getShardedIterDomain(
in, ParallelType::Stream, DomainType::kAllocation) == nullptr) {
auto [i, inserted] = replacement_map.try_emplace(
in, hir::shardByStream(in, innermost.loop->index()));
if (inserted) {
innermost_scope.push_back(i->second->definition());
}
if (getShardedIterDomain(in, ParallelType::Stream, DomainType::kLoop) !=
nullptr &&
getShardedIterDomain(
in, ParallelType::Stream, DomainType::kAllocation) == nullptr) {
NVF_ERROR(
!loop_nest.empty(),
"Stream-parallelized input requires a non-empty loop nest");
auto [i, inserted] = replacement_map.try_emplace(
in, hir::shardByStream(in, innermost.loop->index()));

Comment on lines +188 to +196
Collaborator:
This analysis is used at multiple locations. It can be moved to a util function.

Collaborator (Author):
Yes. Not for this PR though. I'll try to fix a bug around 187 and then think about refactoring.

Collaborator (Author):
Notes for myself: one consideration was that a communication segment always writes to a pre-allocated output and an expr-eval segment may or may not. But I'll think about how to DRY.

Collaborator:
> I'll try to fix a bug around 187

What is the bug around 187?

Collaborator (Author):
#5562

There's one more on top of that which I'll try to fix in the same PR.
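
Editor's note, not part of the PR: a minimal sketch of the util function discussed above, assuming the getShardedIterDomain, ParallelType::Stream, and DomainType::{kLoop, kAllocation} names exactly as they appear in this diff; the helper's name and placement are hypothetical.

```cpp
// Hypothetical helper: true when `tv`'s loop domain is stream-parallelized but
// its allocation domain is not, i.e. the tensor is produced/consumed per
// stream iteration (and therefore needs hir::shardByStream inside the loop)
// while its backing buffer lives outside the loop.
bool isStreamLoopedButNotStreamAllocated(TensorView* tv) {
  return getShardedIterDomain(tv, ParallelType::Stream, DomainType::kLoop) !=
      nullptr &&
      getShardedIterDomain(tv, ParallelType::Stream, DomainType::kAllocation) ==
      nullptr;
}
```

Both branches of the Communication case and the ExprEval case below could then share this predicate, along with a single NVF_ERROR(!loop_nest.empty(), ...) guard.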

}
Comment on lines +188 to 197
Contributor:
logic: null pointer dereference if loop_nest is empty on line 193. When getShardedIterDomain returns non-null but loop_nest.empty() is true, innermost.loop is null (set at lines 161-164), causing innermost.loop->index() to crash. The ExprEval case has an early return for empty loop_nest (line 251), but the Communication case is missing this safeguard.

Suggested change
if (getShardedIterDomain(in, ParallelType::Stream, DomainType::kLoop) !=
nullptr &&
getShardedIterDomain(
in, ParallelType::Stream, DomainType::kAllocation) == nullptr) {
auto [i, inserted] = replacement_map.try_emplace(
in, hir::shardByStream(in, innermost.loop->index()));
if (inserted) {
innermost_scope.push_back(i->second->definition());
}
}
if (getShardedIterDomain(in, ParallelType::Stream, DomainType::kLoop) !=
nullptr &&
getShardedIterDomain(
in, ParallelType::Stream, DomainType::kAllocation) == nullptr) {
NVF_ERROR(!loop_nest.empty(), "Stream-parallelized input requires a loop nest");
auto [i, inserted] = replacement_map.try_emplace(
in, hir::shardByStream(in, innermost.loop->index()));
if (inserted) {
innermost_scope.push_back(i->second->definition());
}
}

loop_nest.innermostScope().push_back(communication);
auto wait = IrBuilder::create<hir::Wait>(communication);
loop_nest.innermostScope().push_back(wait);

// Allocate the recv buffers of communications
auto* allocate =
IrBuilder::create<kir::Allocate>(out, MemoryType::Global);
if (getShardedIterDomain(
out, ParallelType::Stream, DomainType::kLoop) != nullptr &&
getShardedIterDomain(
out, ParallelType::Stream, DomainType::kAllocation) ==
nullptr) {
innermost.parent_scope->insert(
innermost.parent_insertion_point, allocate);
auto [i, inserted] = replacement_map.try_emplace(
out, hir::shardByStream(out, innermost.loop->index()));
NVF_ERROR(inserted);
innermost_scope.push_back(i->second->definition());
Comment on lines +202 to +212
Contributor:
logic: null pointer dereference if loop_nest is empty. Multiple accesses to innermost.loop and innermost.parent_scope without null checks.

Suggested change
if (getShardedIterDomain(
out, ParallelType::Stream, DomainType::kLoop) != nullptr &&
getShardedIterDomain(
out, ParallelType::Stream, DomainType::kAllocation) ==
nullptr) {
innermost.parent_scope->insert(
innermost.parent_insertion_point, allocate);
auto [i, inserted] = replacement_map.try_emplace(
out, hir::shardByStream(out, innermost.loop->index()));
NVF_ERROR(inserted);
innermost_scope.push_back(i->second->definition());
if (getShardedIterDomain(
out, ParallelType::Stream, DomainType::kLoop) != nullptr &&
getShardedIterDomain(
out, ParallelType::Stream, DomainType::kAllocation) ==
nullptr) {
NVF_ERROR(
!loop_nest.empty(),
"Stream-parallelized output requires a non-empty loop nest");
innermost.parent_scope->insert(
innermost.parent_insertion_point, allocate);

} else {
innermost_scope.push_back(allocate);
}
Comment on lines +202 to +215
Contributor:
logic: null pointer dereference if loop_nest is empty. Multiple accesses to innermost.loop->index() (line 210), innermost.parent_scope (line 207), and innermost.parent_insertion_point (line 208) without null checks. Same issue as lines 188-197.

Suggested change
if (getShardedIterDomain(
out, ParallelType::Stream, DomainType::kLoop) != nullptr &&
getShardedIterDomain(
out, ParallelType::Stream, DomainType::kAllocation) ==
nullptr) {
innermost.parent_scope->insert(
innermost.parent_insertion_point, allocate);
auto [i, inserted] = replacement_map.try_emplace(
out, hir::shardByStream(out, innermost.loop->index()));
NVF_ERROR(inserted);
innermost_scope.push_back(i->second->definition());
} else {
innermost_scope.push_back(allocate);
}
if (getShardedIterDomain(
out, ParallelType::Stream, DomainType::kLoop) != nullptr &&
getShardedIterDomain(
out, ParallelType::Stream, DomainType::kAllocation) ==
nullptr) {
NVF_ERROR(!loop_nest.empty(), "Stream-parallelized output requires a loop nest");
innermost.parent_scope->insert(
innermost.parent_insertion_point, allocate);
auto [i, inserted] = replacement_map.try_emplace(
out, hir::shardByStream(out, innermost.loop->index()));
NVF_ERROR(inserted);
innermost_scope.push_back(i->second->definition());
} else {
innermost_scope.push_back(allocate);
}


Expr* new_c = cloneWithNewOperands(c, replacement_map);
innermost_scope.push_back(new_c);

auto* wait = IrBuilder::create<hir::Wait>(new_c);
innermost_scope.push_back(wait);
}
break;
}
@@ -211,14 +249,11 @@ void lowerSegment(
// TensorViews.
if (loop_nest.empty()) {
for (Expr* e : exprs) {
loop_nest.innermostScope().push_back(e);
innermost_scope.push_back(e);
}
break;
}

auto [for_loop, parent_scope, parent_insertion_point] =
loop_nest.innermost();

std::unordered_map<Val*, Val*> replacement_map;
for (Expr* e : exprs) {
for (auto* in : ir_utils::filterByType<TensorView>(e->inputs())) {
Expand All @@ -228,9 +263,9 @@ void lowerSegment(
in, ParallelType::Stream, DomainType::kAllocation) ==
nullptr) {
auto [i, inserted] = replacement_map.try_emplace(
in, hir::shardByStream(in, for_loop->index()));
in, hir::shardByStream(in, innermost.loop->index()));
if (inserted) {
for_loop->body().push_back(i->second->definition());
innermost_scope.push_back(i->second->definition());
}
}
}
@@ -241,21 +276,22 @@
nullptr) {
auto* allocate =
IrBuilder::create<kir::Allocate>(out, MemoryType::Global);
parent_scope->insert(parent_insertion_point, allocate);
innermost.parent_scope->insert(
innermost.parent_insertion_point, allocate);
// Loop is stream parallelized but allocation is not. Therefore,
// `out` should be allocated outside the loop.
//
// I use try_emplace here so shardByStream is called only when `out`
// is missing.
auto [i, inserted] = replacement_map.try_emplace(
out, hir::shardByStream(out, for_loop->index()));
out, hir::shardByStream(out, innermost.loop->index()));
NVF_ERROR(inserted);
for_loop->body().push_back(i->second->definition());
innermost_scope.push_back(i->second->definition());
}
}

Expr* new_e = cloneWithNewOperands(e, replacement_map);
for_loop->body().push_back(new_e);
innermost_scope.push_back(new_e);
}
break;
}
@@ -280,7 +316,7 @@
auto* tv = out->as<TensorView>();
auto* allocate =
IrBuilder::create<kir::Allocate>(tv, MemoryType::Global);
loop_nest.innermostScope().push_back(allocate);
innermost_scope.push_back(allocate);
}

// Add the LaunchKernel instruction.
Expand All @@ -296,7 +332,7 @@ void lowerSegment(
ins,
outs,
cache_id);
loop_nest.innermostScope().push_back(launch_kernel);
innermost_scope.push_back(launch_kernel);
}
} // switch
} // lowerSegment
6 changes: 3 additions & 3 deletions csrc/multidevice/communicator.h
@@ -62,12 +62,12 @@ class NVF_API Communicator {
}

// returns the number of processes in the communicator
auto size() const {
int64_t size() const {
return size_;
}

// returns the local number of processes in the communicator (within the node)
auto local_size() const {
int64_t local_size() const {
return local_size_;
}

@@ -89,7 +89,7 @@
const std::string& prefix = "");

// returns the device associated with the current process
auto device() const {
at::Device device() const {
return at::Device("cuda:" + std::to_string(local_rank_));
}
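
Editor's note, not part of the diff: a minimal caller sketch of the accessors above, to illustrate what the explicit return types pin down. It assumes Communicator::getInstance() returns a reference to the process-wide singleton, as it is used elsewhere in this PR.

```cpp
#include <iostream>

#include <multidevice/communicator.h>

void printCommunicatorTopology() {
  auto& comm = nvfuser::Communicator::getInstance();
  // With the explicit signatures, these types no longer depend on how the
  // private members happen to be declared.
  int64_t world_size = comm.size();
  int64_t local_size = comm.local_size();
  at::Device device = comm.device();
  std::cout << "world=" << world_size << " local=" << local_size
            << " device=" << device << std::endl;
}
```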

4 changes: 2 additions & 2 deletions csrc/runtime/fusion_kernel_runtime.cpp
@@ -7,6 +7,8 @@
// clang-format on
#include <runtime/fusion_kernel_runtime.h>

#include <c10/cuda/CUDAGuard.h>

#include <fusion.h>
#include <fusion_profiler.h>
#include <fusion_segmenter.h>
@@ -25,8 +27,6 @@
#include <serde/fusion_cache_generated.h>
#include <type.h>

#include <c10/cuda/CUDAGuard.h>

namespace nvfuser {

namespace {
3 changes: 0 additions & 3 deletions tests/cpp/test_multidevice_stream_parallel_type.cpp
@@ -5,8 +5,6 @@
* SPDX-License-Identifier: BSD-3-Clause
*/
// clang-format on
#include <iterator>

#include <cuda_profiler_api.h>

#include <fusion.h>
@@ -24,7 +22,6 @@
namespace nvfuser {

using testing::ElementsAre;
using testing::SizeIs;

using MultiDeviceStreamParallelTypeTest = MultiDeviceTest;

77 changes: 77 additions & 0 deletions tests/python/multidevice/test_overlap.py
@@ -10,6 +10,83 @@
from nvfuser_direct import DataType, FusionDefinition, CommunicatorBackend, TensorView


@pytest.mark.mpi
def test_row_parallel_linear_forward(multidevice_direct_test):
# This is a port of CollectiveBasedOverlapTest.RowParallelLinear_Forward.
h, s, t = 2, 3, 6
d = multidevice_direct_test.size
if (h * 4) % d != 0:
pytest.skip(
f"Row-parallel linear requires {h * 4} to be divisible by world size {d}."
)
assert t % s == 0

mesh = nvfuser.multidevice.DeviceMesh(range(d))

with FusionDefinition() as fd:
inp = fd.define_tensor(
shape=[-1, h * 4], contiguity=True, dtype=DataType.BFloat16
)
weight = fd.define_tensor(
shape=[h, h * 4], contiguity=True, dtype=DataType.BFloat16
)
out = fd.ops.linear(inp, weight)
fd.add_output(out)

for tv in (inp, weight):
tv.set_device_mesh(mesh)

inp.split(0, s, inner_split=False)
inp.axis(0).parallelize(nvfuser.ParallelType.stream)
inp.split(2, d, inner_split=False)
inp.axis(2).parallelize(nvfuser.ParallelType.mesh_x)
weight.split(1, d, inner_split=False)
weight.axis(1).parallelize(nvfuser.ParallelType.mesh_x)

# Expected pre-segmentation IR:
#
# [t, 4h] [h, 4h]
# /\ /\ /\.
# s* d d
# |
# | linear
# |
# r{4h}
# / \.
# [t, h, d, r{4h/d}]
# /\.
# s
# |
# | sum
# |
# [t, h, r{d}]
# /\.
# s*

# Expected host IR:
#
# %HostIrContainer { (T0_g___bfloat[istreamIdx7{3}, ideviceIdx.x9{2}, iS8{( ceilDiv(i0, 3) )}, iS10{4}] (DeviceMesh{0 1}), T1_g___bfloat[ideviceIdx.x11{2}, iS2{2}, iS12{4}] (DeviceMesh{0 1})) -> (T2_g___bfloat[istreamIdx27{3}, rdeviceIdx.x26{2}, iS28{( ceilDiv(i0, 3) )}, iS25{2}] (DeviceMesh{0 1})) :
# T2_g___bfloat[istreamIdx27{3}, rdeviceIdx.x26{2}, iS28{( ceilDiv(i0, 3) )}, iS25{2}] (DeviceMesh{0 1}) = ALLOCATE(buffer=T2_g___bfloat[istreamIdx27{3}, rdeviceIdx.x26{2}, iS28{( ceilDiv(i0, 3) )}, iS25{2}] (DeviceMesh{0 1}), mem_type=global, size=( i0 * 2 ), zero_init=false, resets_to_zero=false)
# FOR i535 from 0 to 3:
# T4_l___bfloat[istreamIdx31{3}, ideviceIdx.x33{2}, iS32{( ceilDiv(i0, 3) )}, iS34{4}] (DeviceMesh{0 1}) = ShardByStream(T0_g___bfloat[istreamIdx7{3}, ideviceIdx.x9{2}, iS8{( ceilDiv(i0, 3) )}, iS10{4}] (DeviceMesh{0 1}), stream_index = i535)
# T3_g___bfloat[istreamIdx20{3}, ideviceIdx.x22{2}rf, iS21{( ceilDiv(i0, 3) )}, iS18{2}, rS23{4}rf] (DeviceMesh{0 1})
# = linear(T4_l___bfloat[istreamIdx31{3}, ideviceIdx.x33{2}, iS32{( ceilDiv(i0, 3) )}, iS34{4}] (DeviceMesh{0 1}),
# T1_g___bfloat[ideviceIdx.x11{2}, iS2{2}, iS12{4}] (DeviceMesh{0 1}) )
# T5_l___bfloat[istreamIdx37{3}, iS38{( ceilDiv(i0, 3) )}, iS36{2}] (DeviceMesh{0 1}) = ShardByStream(T2_g___bfloat[istreamIdx27{3}, rdeviceIdx.x26{2}, iS28{( ceilDiv(i0, 3) )}, iS25{2}] (DeviceMesh{0 1}), stream_index = i535)
# Communication 250 (type=Allreduce, team=(0 1), input=T3_g___bfloat[istreamIdx20{3}, ideviceIdx.x22{2}rf, iS21{( ceilDiv(i0, 3) )}, iS18{2}, rS23{4}rf] (DeviceMesh{0 1}), output=T5_l___bfloat[istreamIdx37{3}, iS38{( ceilDiv(i0, 3) )}, iS36{2}] (DeviceMesh{0 1}), backend=NCCL)
# Wait Communication 250
# } // %HostIrContainer

inp_ref = torch.randint(-2, 3, (t, h * 4), dtype=torch.int32).to(torch.bfloat16)
weight_ref = torch.randint(-2, 3, (h, h * 4), dtype=torch.int32).to(torch.bfloat16)
out_ref = torch.nn.functional.linear(inp_ref, weight_ref)

inp = (multidevice_direct_test.shard_tensor(inp_ref, -1, mesh),)
weight = (multidevice_direct_test.shard_tensor(weight_ref, -1, mesh),)
(out,) = fd.execute([inp, weight], _enable_options=["host_ir_lowering"])
torch.testing.assert_close(out.cpu(), out_ref)
Collaborator:
Is there a way to verify inlining actually happened?

Collaborator (Author):
How about torch.profiler to count how many kernels are launched?

Collaborator:
Try using PythonProfiler (the Python interface of FusionProfiler). It should give us the names of the schedulers. It also records the stream id if needed. See https://github.com/NVIDIA/Fuser/pull/5563/files for an example.

Collaborator (Author):
Done



@pytest.mark.mpi
@pytest.mark.parametrize("backend_type", [CommunicatorBackend.nccl])
@pytest.mark.parametrize("s", [1, 8])