8 changes: 6 additions & 2 deletions  csrc/multidevice/propagation.cpp
@@ -82,9 +82,13 @@ std::unordered_map<IterDomain*, IterDomain*> getRef2TargetMap(
     const TensorView* target,
     PropagateDirection direction) {
   if (direction == PropagateDirection::kForward) {
-    return PairwiseLogicalDomainMap(ref, target).mapProducerToConsumer();
+    return PairwiseLogicalDomainMap(ref, target)
+        .mapBroadcast(false)
+        .mapProducerToConsumer();
   }
-  return PairwiseLogicalDomainMap(target, ref).mapConsumerToProducer();
+  return PairwiseLogicalDomainMap(target, ref)
+      .mapBroadcast(false)
+      .mapConsumerToProducer();
 }

 // Propagates the given device/stream ids from ref to target.
24 changes: 22 additions & 2 deletions  csrc/preseg_passes/decompose_reshardings.cpp
@@ -313,7 +313,6 @@ void decomposeRowParallelLinearWithBias(Fusion* fusion) {
   }

   auto* without_bias = linear(linear_op->inA(), linear_op->inB());
-  TransformReplay::selfReplay(out->domain(), without_bias->domain());

   TensorView* broadcasted_bias = [&]() {
     const int64_t rank_after_broadcast = std::ssize(
@@ -329,8 +328,29 @@

   TensorView* new_out =
       maybeCastOp(out->dtype(), add(without_bias, broadcasted_bias));
-  TransformReplay::selfReplay(out->domain(), new_out->domain());

   ir_utils::replaceValInAllExprInputsAndFusionOutputs(out, new_out);

+  // Shard without_bias to match new_out so that reduction ID is properly
+  // sharded.
+  TransformReplay::selfReplay(out->domain(), without_bias->domain());
+  TransformReplay::selfReplay(out->domain(), new_out->domain());
+  // Backpropagate shardings to consistently shard all intermediate
+  // expressions. Forward propagating may miss sharding tensorviews
+  // on the path between `bias` and `new_out`.
+  for (Expr* expr : StmtSort::getExprsBetween(
+                        {without_bias, broadcasted_bias}, {new_out}) |
+           std::views::reverse) {
+    for (auto* output : ir_utils::filterByType<TensorView>(expr->outputs())) {
+      for (auto* input : ir_utils::filterByType<TensorView>(expr->inputs())) {
+        shardLoopLike(
+            /*ref=*/output,
+            /*target=*/input,
+            deviceAndStreamParallelTypes(),
+            PropagateDirection::kBackward);
+      }
+    }
+  }
   }
 }
2 changes: 1 addition & 1 deletion  csrc/runtime/communication_executor.cpp
@@ -84,7 +84,7 @@ KernelArgumentHolder CommunicationExecutor::run(
         group_id_);
     SegmentProfiler& sprof = FusionProfiler::segment(group_id_);
     sprof.inputBytesAccessed(computeBytes(args));
-    sprof.scheduler(toString(SchedulerType::ExprEval));

[Inline review note from the PR author]: Caused the wrong scheduler name in profiler output.

+    sprof.scheduler(toString(SchedulerType::Communication));
     sprof.startKernel();
   }
   NVF_ERROR(host_ir_container_, "Need to compile before you can run.");
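Editor's note: a minimal sketch, not part of the PR, of how this scheduler label surfaces on the Python side. It assumes a compiled FusionDefinition `fd` and its input list are already prepared (as in the test below), and it uses only the PythonProfiler fields that the test asserts on.

    with PythonProfiler() as prof:
        (out,) = fd.execute(inputs)  # `fd` and `inputs` as prepared elsewhere

    # Each profiled segment reports the scheduler that produced it. Before this
    # fix the communication segment carried the ExprEval scheduler's name; with
    # the fix it is reported as "communication", which the test below checks.
    for kp in prof.profile.kernel_profiles:
        print(kp.scheduler)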
53 changes: 32 additions & 21 deletions  tests/python/multidevice/test_matmul.py
@@ -6,7 +6,7 @@
 import torch

 import nvfuser_direct as nvfuser
-from nvfuser_direct import DataType, FusionDefinition
+from nvfuser_direct import DataType, FusionDefinition, PythonProfiler


 # Avoid doing this when possible. This test started to exist before nvFuser
@@ -197,50 +197,61 @@ def _multidevice_schedule(fd: FusionDefinition):
 def test_linear_reduce_scatter(multidevice_direct_test):
     d = multidevice_direct_test.size
     mesh = nvfuser.multidevice.DeviceMesh(torch.arange(d))
-    e = 768
+    b, s, e = 3, 5, 7

     def _definition(fd: FusionDefinition):
-        inp = fd.define_tensor([-1, -1, d * e])
-        weight = fd.define_tensor([e, d * e])
-        out = fd.ops.linear(inp, weight, None)
+        inp = fd.define_tensor([-1, d * s, d * e], dtype=DataType.BFloat16)
+        weight = fd.define_tensor([-1, d * e], dtype=DataType.BFloat16)
+        bias = fd.define_tensor([e], dtype=DataType.BFloat16)
+        out = fd.ops.linear(inp, weight, bias)
         fd.add_output(out)

     def _multidevice_schedule(fd: FusionDefinition):
-        inp, weight = fd.fusion.inputs()
+        inp, weight, bias = fd.fusion.inputs()
         (out,) = fd.fusion.outputs()
-        for t in [inp, weight, out]:
-            t.set_device_mesh(mesh)
-            t.split(-1, d, inner_split=False)
-            t.axis(-2).parallelize(nvfuser.ParallelType.mesh_x)
+        bias.set_device_mesh(mesh)
+        for tv in [inp, weight, out]:
+            tv.set_device_mesh(mesh)
+            tv.split(-1, d, inner_split=False)
+            tv.axis(-2).parallelize(nvfuser.ParallelType.mesh_x)

         # Scatter
         out.split(1, d, inner_split=False)
         out.axis(1).parallelize(nvfuser.ParallelType.mesh_x)

     torch.cuda.set_device(multidevice_direct_test.local_rank)

-    # set b=1 as a temporary fix for the test to pass.
-    # TODO: set b>1 once reduce scatter is fixed.
-    b, s = 2, 1024
-    unsharded_inp = torch.randn(b, s, d * e)
-    unsharded_weight = torch.randn(e, d * e)
-
+    unsharded_inp = torch.randint(-2, 3, (b, d * s, d * e)).to(torch.bfloat16)
+    unsharded_weight = torch.randint(-2, 3, (e, d * e)).to(torch.bfloat16)
+    bias = torch.randint(-2, 3, (e,)).to(torch.bfloat16)
     inp = multidevice_direct_test.shard_tensor(unsharded_inp, -1, mesh)
     weight = multidevice_direct_test.shard_tensor(unsharded_weight, -1, mesh)

     with FusionDefinition() as fd:
         _definition(fd)
         _multidevice_schedule(fd)

-    (out,) = fd.execute([inp, weight])
+    with PythonProfiler() as prof:
+        (out,) = fd.execute([inp, weight, bias.cuda()])
[Review thread on the fd.execute call above]

Collaborator: Does this synchronize? Could we miss kernels?

@Priya2698 (Collaborator, Author), Nov 21, 2025: fd.execute should not return until kernels have completed. There is a cudaStreamSynchronize at the end of the nsys trace too. Is this what you are referring to?

@wujingyue (Collaborator), Nov 21, 2025: There's a difference between cudaStreamSynchronize and cudaDeviceSynchronize though. The former blocks the stream and the latter blocks the host.

@Priya2698 (Collaborator, Author), Nov 21, 2025: You're right. I assumed cudaStreamSynchronize would be enough here, but the pointwise kernel and the NCCL call run on different streams. FusionProfiler/PythonProfiler synchronize at start but not on stop, so I will add an explicit call here. Note to self: see if FusionProfiler should synchronize before reading data.
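Editor's note: a minimal sketch of the explicit synchronization discussed above, not part of the diff. It reuses the test's names (fd, inp, weight, bias, PythonProfiler) and assumes torch.cuda.synchronize() as the device-wide barrier; the exact call the author intends to add is not shown in this excerpt.

    with PythonProfiler() as prof:
        (out,) = fd.execute([inp, weight, bias.cuda()])
        # The pointwise kernel and the NCCL reduce-scatter run on different
        # streams, and PythonProfiler only synchronizes on start. Block the
        # host until every stream on the device is idle before the profiler
        # stops and its data is read.
        torch.cuda.synchronize()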


-    unsharded_out = torch.nn.functional.linear(unsharded_inp, unsharded_weight, None)
-    # rtol is the same as the default for fp32. atol is slightly increased.
+    # Only one reduce scatter kernel should be scheduled (none when d == 1).
+    assert (
+        len(
+            [
+                kp
+                for kp in prof.profile.kernel_profiles
+                if kp.scheduler == "communication"
+            ]
+        )
+        == (1 if d > 1 else 0)
+    )
+
+    unsharded_out = torch.nn.functional.linear(unsharded_inp, unsharded_weight, bias)
     torch.testing.assert_close(
         out,
         multidevice_direct_test.shard_tensor(unsharded_out, 1, mesh),
-        rtol=1.3e-6,
-        atol=1e-3,
     )

