-
Notifications
You must be signed in to change notification settings - Fork 70
Inline Linear+Allreduce into the same loop #5547
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: wjy/out
Are you sure you want to change the base?
Changes from 3 commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
|
@@ -20,18 +20,22 @@ namespace nvfuser { | |||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| namespace { | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| struct LoopInfo { | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| hir::ForLoop* loop; | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| hir::ForLoop* loop = nullptr; | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| // The Scope that owns `loop`. It's one level outer than `loop`'s body scope. | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| Scope* parent_scope; | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| Scope* parent_scope = nullptr; | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| // The iterator that points to `loop`. This way, we can insert instructions, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| // e.g. Allocate, right before the loop. | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| Scope::Iterator parent_insertion_point; | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| }; | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| std::ostream& operator<<(std::ostream& os, const LoopInfo& loop_info) { | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| os << loop_info.loop->toInlineString(); | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| if (loop_info.loop == nullptr) { | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| os << "<null>"; | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| } else { | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| os << loop_info.loop->toInlineString(); | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| return os; | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
@@ -57,6 +61,8 @@ class LoopNest { | |||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| return loop_infos_.back(); | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| // Returns the scope of the innermost for-loop or the top-level scope if the | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| // loop nest is empty. | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| Scope& innermostScope() const { | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| return empty() ? top_level_ : innermost().loop->body(); | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
@@ -131,7 +137,7 @@ Expr* cloneWithNewOperands( | |||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| int64_t out_replaced = std::ranges::count_if(new_outs, maybe_replace); | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| if (in_replaced == 0 && out_replaced == 0) { | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| return 0; | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| return e; | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| if (out_replaced > 0) { | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
@@ -151,6 +157,12 @@ void lowerSegment( | |||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| hir::HostIrContainer& hic, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| LoopNest& loop_nest, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| IrCloner& ir_cloner) { | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| Scope& innermost_scope = loop_nest.innermostScope(); | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| LoopInfo innermost; | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| if (!loop_nest.empty()) { | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| innermost = loop_nest.innermost(); | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| switch (group.schedulerType()) { | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| case SchedulerType::Communication: { | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| auto device_id = Communicator::getInstance().deviceId(); | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
@@ -162,24 +174,50 @@ void lowerSegment( | |||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| // without cloning the value again. | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| Expr* e = ir_cloner.clone(group.exprs().front()); | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| for (auto* c : convertSingleOpToCommunication(e, device_id)) { | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| // FIXME: should this be associated with the scope? | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| std::unordered_map<Val*, Val*> replacement_map; | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| for (Expr* c : convertSingleOpToCommunication(e, device_id)) { | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| NVF_ERROR( | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| c->isA<Communication>(), | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| "Exprs in a Communication group should be Communication: ", | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| c); | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| // Allocate the recv buffers of communications | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| auto* communication = c->as<Communication>(); | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| TensorView* tv = communication->out(); | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| if (tv->getDeviceMesh().has(device_id)) { | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| auto* allocate = | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| IrBuilder::create<kir::Allocate>(tv, MemoryType::Global); | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| // TODO: allocation may have to go to the top level. See how | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| // SchedulerType::ExprEval handles allocations. | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| loop_nest.innermostScope().push_back(allocate); | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| TensorView* in = communication->in(); | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| TensorView* out = communication->out(); | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| if (getShardedIterDomain(in, ParallelType::Stream, DomainType::kLoop) != | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| nullptr && | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| getShardedIterDomain( | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| in, ParallelType::Stream, DomainType::kAllocation) == nullptr) { | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| auto [i, inserted] = replacement_map.try_emplace( | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| in, hir::shardByStream(in, innermost.loop->index())); | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| if (inserted) { | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| innermost_scope.push_back(i->second->definition()); | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
Comment on lines
+188
to
+196
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This analysis is used at multiple locations. It can be moved to a util function.
Collaborator
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Yes. Not for this PR though. I'll try to fix a bug around 187 and then think about refactoring.
Collaborator
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Notes for myself: one consideration was that a communication segment always writes to a pre-allocated output and an expr-eval segment may or may not. But I'll think about how to DRY.
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
What is the bug around 187?
Collaborator
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. There's one more on top of that which I'll try to fix in the same PR. |
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
Comment on lines
+188
to
197
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. logic: null pointer dereference if
Suggested change
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| loop_nest.innermostScope().push_back(communication); | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| auto wait = IrBuilder::create<hir::Wait>(communication); | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| loop_nest.innermostScope().push_back(wait); | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| // Allocate the recv buffers of communications | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| auto* allocate = | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| IrBuilder::create<kir::Allocate>(out, MemoryType::Global); | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| if (getShardedIterDomain( | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| out, ParallelType::Stream, DomainType::kLoop) != nullptr && | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| getShardedIterDomain( | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| out, ParallelType::Stream, DomainType::kAllocation) == | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| nullptr) { | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| innermost.parent_scope->insert( | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| innermost.parent_insertion_point, allocate); | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| auto [i, inserted] = replacement_map.try_emplace( | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| out, hir::shardByStream(out, innermost.loop->index())); | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| NVF_ERROR(inserted); | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| innermost_scope.push_back(i->second->definition()); | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
Comment on lines
+202
to
+212
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. logic: null pointer dereference if
Suggested change
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| } else { | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| innermost_scope.push_back(allocate); | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
Comment on lines
+202
to
+215
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. logic: null pointer dereference if
Suggested change
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| Expr* new_c = cloneWithNewOperands(c, replacement_map); | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| innermost_scope.push_back(new_c); | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| auto* wait = IrBuilder::create<hir::Wait>(new_c); | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| innermost_scope.push_back(wait); | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| break; | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
@@ -211,14 +249,11 @@ void lowerSegment( | |||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| // TensorViews. | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| if (loop_nest.empty()) { | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| for (Expr* e : exprs) { | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| loop_nest.innermostScope().push_back(e); | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| innermost_scope.push_back(e); | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| break; | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| auto [for_loop, parent_scope, parent_insertion_point] = | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| loop_nest.innermost(); | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| std::unordered_map<Val*, Val*> replacement_map; | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| for (Expr* e : exprs) { | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| for (auto* in : ir_utils::filterByType<TensorView>(e->inputs())) { | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
@@ -228,9 +263,9 @@ void lowerSegment( | |||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| in, ParallelType::Stream, DomainType::kAllocation) == | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| nullptr) { | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| auto [i, inserted] = replacement_map.try_emplace( | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| in, hir::shardByStream(in, for_loop->index())); | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| in, hir::shardByStream(in, innermost.loop->index())); | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| if (inserted) { | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| for_loop->body().push_back(i->second->definition()); | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| innermost_scope.push_back(i->second->definition()); | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
@@ -241,21 +276,22 @@ void lowerSegment( | |||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| nullptr) { | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| auto* allocate = | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| IrBuilder::create<kir::Allocate>(out, MemoryType::Global); | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| parent_scope->insert(parent_insertion_point, allocate); | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| innermost.parent_scope->insert( | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| innermost.parent_insertion_point, allocate); | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| // Loop is stream parallelized but allocation is not. Therefore, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| // `out` should be allocated outside the loop. | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| // | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| // I use try_emplace here so shardByStream is called only when `out` | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| // is missing. | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| auto [i, inserted] = replacement_map.try_emplace( | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| out, hir::shardByStream(out, for_loop->index())); | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| out, hir::shardByStream(out, innermost.loop->index())); | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| NVF_ERROR(inserted); | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| for_loop->body().push_back(i->second->definition()); | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| innermost_scope.push_back(i->second->definition()); | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| Expr* new_e = cloneWithNewOperands(e, replacement_map); | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| for_loop->body().push_back(new_e); | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| innermost_scope.push_back(new_e); | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| break; | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
@@ -280,7 +316,7 @@ void lowerSegment( | |||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| auto* tv = out->as<TensorView>(); | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| auto* allocate = | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| IrBuilder::create<kir::Allocate>(tv, MemoryType::Global); | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| loop_nest.innermostScope().push_back(allocate); | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| innermost_scope.push_back(allocate); | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| // Add the LaunchKernel instruction. | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
@@ -296,7 +332,7 @@ void lowerSegment( | |||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| ins, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| outs, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| cache_id); | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| loop_nest.innermostScope().push_back(launch_kernel); | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| innermost_scope.push_back(launch_kernel); | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| } // switch | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| } // lowerSegment | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -10,6 +10,83 @@ | |
| from nvfuser_direct import DataType, FusionDefinition, CommunicatorBackend, TensorView | ||
|
|
||
|
|
||
| @pytest.mark.mpi | ||
| def test_row_parallel_linear_forward(multidevice_direct_test): | ||
| # This is a port of CollectiveBasedOverlapTest.RowParallelLinear_Forward. | ||
Priya2698 marked this conversation as resolved.
Show resolved
Hide resolved
|
||
| h, s, t = 2, 3, 6 | ||
| d = multidevice_direct_test.size | ||
| if (h * 4) % d != 0: | ||
| pytest.skip( | ||
| f"Row-parallel linear requires {h * 4} to be divisible by world size {d}." | ||
| ) | ||
| assert t % s == 0 | ||
|
|
||
| mesh = nvfuser.multidevice.DeviceMesh(range(d)) | ||
|
|
||
| with FusionDefinition() as fd: | ||
| inp = fd.define_tensor( | ||
| shape=[-1, h * 4], contiguity=True, dtype=DataType.BFloat16 | ||
| ) | ||
| weight = fd.define_tensor( | ||
| shape=[h, h * 4], contiguity=True, dtype=DataType.BFloat16 | ||
| ) | ||
| out = fd.ops.linear(inp, weight) | ||
| fd.add_output(out) | ||
|
|
||
| for tv in (inp, weight): | ||
| tv.set_device_mesh(mesh) | ||
|
|
||
| inp.split(0, s, inner_split=False) | ||
| inp.axis(0).parallelize(nvfuser.ParallelType.stream) | ||
| inp.split(2, d, inner_split=False) | ||
| inp.axis(2).parallelize(nvfuser.ParallelType.mesh_x) | ||
| weight.split(1, d, inner_split=False) | ||
| weight.axis(1).parallelize(nvfuser.ParallelType.mesh_x) | ||
|
|
||
| # Expected pre-segmentation IR: | ||
| # | ||
| # [t, 4h] [h, 4h] | ||
| # /\ /\ /\. | ||
| # s* d d | ||
| # | | ||
| # | linear | ||
| # | | ||
| # r{4h} | ||
| # / \. | ||
| # [t, h, d, r{4h/d}] | ||
| # /\. | ||
| # s | ||
| # | | ||
| # | sum | ||
| # | | ||
| # [t, h, r{d}] | ||
| # /\. | ||
| # s* | ||
|
|
||
| # Expected host IR: | ||
| # | ||
| # %HostIrContainer { (T0_g___bfloat[istreamIdx7{3}, ideviceIdx.x9{2}, iS8{( ceilDiv(i0, 3) )}, iS10{4}] (DeviceMesh{0 1}), T1_g___bfloat[ideviceIdx.x11{2}, iS2{2}, iS12{4}] (DeviceMesh{0 1})) -> (T2_g___bfloat[istreamIdx27{3}, rdeviceIdx.x26{2}, iS28{( ceilDiv(i0, 3) )}, iS25{2}] (DeviceMesh{0 1})) : | ||
| # T2_g___bfloat[istreamIdx27{3}, rdeviceIdx.x26{2}, iS28{( ceilDiv(i0, 3) )}, iS25{2}] (DeviceMesh{0 1}) = ALLOCATE(buffer=T2_g___bfloat[istreamIdx27{3}, rdeviceIdx.x26{2}, iS28{( ceilDiv(i0, 3) )}, iS25{2}] (DeviceMesh{0 1}), mem_type=global, size=( i0 * 2 ), zero_init=false, resets_to_zero=false) | ||
| # FOR i535 from 0 to 3: | ||
| # T4_l___bfloat[istreamIdx31{3}, ideviceIdx.x33{2}, iS32{( ceilDiv(i0, 3) )}, iS34{4}] (DeviceMesh{0 1}) = ShardByStream(T0_g___bfloat[istreamIdx7{3}, ideviceIdx.x9{2}, iS8{( ceilDiv(i0, 3) )}, iS10{4}] (DeviceMesh{0 1}), stream_index = i535) | ||
| # T3_g___bfloat[istreamIdx20{3}, ideviceIdx.x22{2}rf, iS21{( ceilDiv(i0, 3) )}, iS18{2}, rS23{4}rf] (DeviceMesh{0 1}) | ||
| # = linear(T4_l___bfloat[istreamIdx31{3}, ideviceIdx.x33{2}, iS32{( ceilDiv(i0, 3) )}, iS34{4}] (DeviceMesh{0 1}), | ||
| # T1_g___bfloat[ideviceIdx.x11{2}, iS2{2}, iS12{4}] (DeviceMesh{0 1}) ) | ||
| # T5_l___bfloat[istreamIdx37{3}, iS38{( ceilDiv(i0, 3) )}, iS36{2}] (DeviceMesh{0 1}) = ShardByStream(T2_g___bfloat[istreamIdx27{3}, rdeviceIdx.x26{2}, iS28{( ceilDiv(i0, 3) )}, iS25{2}] (DeviceMesh{0 1}), stream_index = i535) | ||
| # Communication 250 (type=Allreduce, team=(0 1), input=T3_g___bfloat[istreamIdx20{3}, ideviceIdx.x22{2}rf, iS21{( ceilDiv(i0, 3) )}, iS18{2}, rS23{4}rf] (DeviceMesh{0 1}), output=T5_l___bfloat[istreamIdx37{3}, iS38{( ceilDiv(i0, 3) )}, iS36{2}] (DeviceMesh{0 1}), backend=NCCL) | ||
| # Wait Communication 250 | ||
| # } // %HostIrContainer | ||
|
|
||
| inp_ref = torch.randint(-2, 3, (t, h * 4), dtype=torch.int32).to(torch.bfloat16) | ||
| weight_ref = torch.randint(-2, 3, (h, h * 4), dtype=torch.int32).to(torch.bfloat16) | ||
| out_ref = torch.nn.functional.linear(inp_ref, weight_ref) | ||
|
|
||
| inp = (multidevice_direct_test.shard_tensor(inp_ref, -1, mesh),) | ||
| weight = (multidevice_direct_test.shard_tensor(weight_ref, -1, mesh),) | ||
| (out,) = fd.execute([inp, weight], _enable_options=["host_ir_lowering"]) | ||
| torch.testing.assert_close(out.cpu(), out_ref) | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Is there a way to verify inlining actually happened?
Collaborator
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. How about torch.profiler to count how many kernels are launched?
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Try using PythonProfiler (python interface of FusionProfiler). It should give us the name of schedulers. It also records the stream id if needed. See https://github.com/NVIDIA/Fuser/pull/5563/files for an example.
Collaborator
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Done |
||
|
|
||
|
|
||
| @pytest.mark.mpi | ||
| @pytest.mark.parametrize("backend_type", [CommunicatorBackend.nccl]) | ||
| @pytest.mark.parametrize("s", [1, 8]) | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
logic: null pointer dereference if
loop_nestis empty.innermost.loopis null whenloop_nest.empty()is true (line 161-164), but this code callsinnermost.loop->index()on line 192 without checking ifinnermost.loopis null