Inline Linear+Allreduce into the same loop #5547

base: wjy/out
Conversation
!test
Review updated until commit 6e7e138

Description

Relevant files

| Category | Files |
|---|---|
| Enhancement | 4 files |
| Tests | 1 file |
| Bug fix | 1 file |
| Formatting | 1 file |
| Cleanup | 1 file |
| Configuration changes | 1 file |
PR Reviewer Guide
Here are some key observations to aid the review process:
🧪 PR contains tests

⚡ Recommended focus areas for review

Loop Scope Management
The lowering code now manipulates innermost_scope, innermost.loop, and innermost.parent_scope. Need to verify that the scope hierarchy is correctly maintained and that allocations are properly placed (inside vs. outside loops) based on sharding configuration.
Test failures (Medium, 2): Profiler kernel count mismatch in tests.python.multidevice.test_overlap

| Test Name | A100 | A100 (dist.) | Source |
|---|---|---|---|
| tests.python.multidevice.test_overlap.test_row_parallel_linear_forward | ❌ | ❌ | |
!test
Greptile Overview

Greptile Summary

This PR implements the ability to inline linear operations and allreduce communications into the same loop for stream-parallelized matmul operations. The main change refactors the host IR lowering (lowerSegmentedFusionToHostIr and lowerSegment) to thread a loop nest through segment lowering, so that compute and communication that share a stream-parallelized iteration domain are emitted into the same loop body.

Key Changes:
Critical Issues:
Minor Issues:
Confidence Score: 1/5
Important Files Changed

File Analysis
Sequence Diagram

    sequenceDiagram
        participant LF as lowerSegmentedFusionToHostIr
        participant LN as LoopNest
        participant LS as lowerSegment
        participant IC as IrCloner
        LF->>LN: Create LoopNest(topLevel)
        loop For each SegmentedGroup
            LF->>LF: computeInlinePosition()
            LF->>LN: closeLoop() until inline_position
            LF->>LN: openLoop() for stream IDs
            LF->>LS: lowerSegment(group, loop_nest)
            alt SchedulerType::Communication
                LS->>IC: clone(group.exprs().front())
                LS->>LS: convertSingleOpToCommunication()
                loop For each Communication
                    LS->>LS: Check if input is stream-parallelized
                    alt Input has Stream loop domain
                        LS->>LS: shardByStream(in, innermost.loop->index())
                        Note right of LS: ⚠️ Potential null pointer if loop_nest empty
                        LS->>LN: push_back(ShardByStream)
                    end
                    LS->>LS: Allocate output buffer
                    alt Output has Stream loop domain
                        LS->>LN: insert allocate before loop
                        LS->>LS: shardByStream(out, innermost.loop->index())
                        Note right of LS: ⚠️ Potential null pointer if loop_nest empty
                        LS->>LN: push_back(ShardByStream)
                    else
                        LS->>LN: push_back(allocate)
                    end
                    LS->>IC: cloneWithNewOperands(communication)
                    LS->>LN: push_back(communication)
                    LS->>LN: push_back(Wait)
                end
            else SchedulerType::ExprEval
                alt loop_nest.empty()
                    LS->>LN: push_back(exprs) to top level
                else
                    loop For each expression
                        LS->>LS: Check inputs for stream parallelization
                        LS->>LS: shardByStream if needed
                        LS->>LS: Check outputs for stream parallelization
                        LS->>LN: insert allocate before loop if needed
                        LS->>IC: cloneWithNewOperands(expr)
                        LS->>LN: push_back(expr)
                    end
                end
            else Default (LaunchKernel)
                LS->>LN: Allocate outputs
                LS->>LN: push_back(LaunchKernel)
            end
        end
        LF->>LF: InsertDeallocations()
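Stepping back from the lowering mechanics, the sketch below illustrates what "inlining linear and allreduce into the same loop" means and why it enables compute/communication overlap. It is an illustration only: plain Python with hypothetical stand-in functions, not nvFuser host IR, and the two-loop "before" shape is an assumption used for contrast rather than the exact baseline schedule.

```python
import torch

# Stand-ins so the sketch runs on its own; in the PR these correspond to host
# IR LaunchKernel and Communication nodes, not Python functions.
def linear_chunk(x, w):
    return x @ w

def allreduce(t):
    # Placeholder: the real op is an allreduce issued on the loop's stream.
    return t

num_streams, k = 4, 64
x_chunks = [torch.randn(8, k) for _ in range(num_streams)]
w = torch.randn(k, 8)

# Not inlined (assumed baseline): one loop issues every matmul, a second loop
# issues every allreduce, so no per-chunk overlap is possible.
partials = [linear_chunk(x_chunks[s], w) for s in range(num_streams)]
outs_before = [allreduce(p) for p in partials]

# Inlined (what this PR targets): the allreduce for chunk s sits in the same
# loop iteration as its matmul, so when each iteration runs on its own stream,
# communication for chunk s can overlap with compute for chunk s+1.
outs_after = []
for s in range(num_streams):
    p = linear_chunk(x_chunks[s], w)
    outs_after.append(allreduce(p))
```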
8 files reviewed, 2 comments
    if (getShardedIterDomain(in, ParallelType::Stream, DomainType::kLoop) !=
            nullptr &&
        getShardedIterDomain(
            in, ParallelType::Stream, DomainType::kAllocation) == nullptr) {
      auto [i, inserted] = replacement_map.try_emplace(
          in, hir::shardByStream(in, innermost.loop->index()));
      if (inserted) {
        innermost_scope.push_back(i->second->definition());
      }
logic: null pointer dereference if loop_nest is empty. innermost.loop is null when loop_nest.empty() is true (lines 161-164), but this code calls innermost.loop->index() on line 192 without checking whether innermost.loop is null.
Suggested change (add a guard before dereferencing innermost.loop):

    if (getShardedIterDomain(in, ParallelType::Stream, DomainType::kLoop) !=
            nullptr &&
        getShardedIterDomain(
            in, ParallelType::Stream, DomainType::kAllocation) == nullptr) {
      NVF_ERROR(
          !loop_nest.empty(),
          "Stream-parallelized input requires a non-empty loop nest");
      auto [i, inserted] = replacement_map.try_emplace(
          in, hir::shardByStream(in, innermost.loop->index()));
    if (getShardedIterDomain(
            out, ParallelType::Stream, DomainType::kLoop) != nullptr &&
        getShardedIterDomain(
            out, ParallelType::Stream, DomainType::kAllocation) ==
            nullptr) {
      innermost.parent_scope->insert(
          innermost.parent_insertion_point, allocate);
      auto [i, inserted] = replacement_map.try_emplace(
          out, hir::shardByStream(out, innermost.loop->index()));
      NVF_ERROR(inserted);
      innermost_scope.push_back(i->second->definition());
logic: null pointer dereference if loop_nest is empty. Multiple accesses to innermost.loop and innermost.parent_scope without null checks
Suggested change (add a guard before using innermost.parent_scope and innermost.loop):

    if (getShardedIterDomain(
            out, ParallelType::Stream, DomainType::kLoop) != nullptr &&
        getShardedIterDomain(
            out, ParallelType::Stream, DomainType::kAllocation) ==
            nullptr) {
      NVF_ERROR(
          !loop_nest.empty(),
          "Stream-parallelized output requires a non-empty loop nest");
      innermost.parent_scope->insert(
          innermost.parent_insertion_point, allocate);
    inp = (multidevice_direct_test.shard_tensor(inp_ref, -1, mesh),)
    weight = (multidevice_direct_test.shard_tensor(weight_ref, -1, mesh),)
    (out,) = fd.execute([inp, weight], _enable_options=["host_ir_lowering"])
    torch.testing.assert_close(out.cpu(), out_ref)
Is there a way to verify inlining actually happened?
How about torch.profiler to count how many kernels are launched?
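For reference, a minimal sketch of that idea, assuming a torch.profiler-based check; run_once is a placeholder for the actual fd.execute call, the expected count is workload-specific, and filtering events by device type is just one way to isolate kernels:

```python
import torch
from torch.autograd import DeviceType
from torch.profiler import ProfilerActivity, profile

def run_once():
    # Placeholder workload; in the test this would be the fd.execute(...) call.
    a = torch.randn(1024, 1024, device="cuda")
    b = torch.randn(1024, 1024, device="cuda")
    return a @ b

run_once()  # warmup, so autotuning/compilation kernels are not counted
torch.cuda.synchronize()

with profile(activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA]) as prof:
    run_once()
    torch.cuda.synchronize()

# Device-side events correspond to launched kernels (plus any device memcpy/
# memset); comparing this count with and without inlining is one way to check
# that the schedule changed.
num_kernels = sum(1 for evt in prof.events() if evt.device_type == DeviceType.CUDA)
print(f"{num_kernels} CUDA kernels launched")
```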
Try using PythonProfiler (python interface of FusionProfiler). It should give us the name of schedulers. It also records the stream id if needed. See https://github.com/NVIDIA/Fuser/pull/5563/files for an example.
Done
    if (getShardedIterDomain(in, ParallelType::Stream, DomainType::kLoop) !=
            nullptr &&
        getShardedIterDomain(
            in, ParallelType::Stream, DomainType::kAllocation) == nullptr) {
      auto [i, inserted] = replacement_map.try_emplace(
          in, hir::shardByStream(in, innermost.loop->index()));
      if (inserted) {
        innermost_scope.push_back(i->second->definition());
      }
This analysis is used at multiple locations. It can be moved to a util function.
Yes. Not for this PR though. I'll try to fix a bug around 187 and then think about refactoring.
Notes for myself: one consideration was that a communication segment always writes to a pre-allocated output and an expr-eval segment may or may not. But I'll think about how to DRY.
> I'll try to fix a bug around 187

What is the bug around 187?
There's one more on top of that which I'll try to fix in the same PR.
8 files reviewed, no comments
9 files reviewed, no comments
!test
9 files reviewed, 3 comments
    if (getShardedIterDomain(in, ParallelType::Stream, DomainType::kLoop) !=
            nullptr &&
        getShardedIterDomain(
            in, ParallelType::Stream, DomainType::kAllocation) == nullptr) {
      auto [i, inserted] = replacement_map.try_emplace(
          in, hir::shardByStream(in, innermost.loop->index()));
      if (inserted) {
        innermost_scope.push_back(i->second->definition());
      }
    }
logic: null pointer dereference if loop_nest is empty on line 193. When getShardedIterDomain returns non-null but loop_nest.empty() is true, innermost.loop is null (set at lines 161-164), causing innermost.loop->index() to crash. The ExprEval case has an early return for an empty loop_nest (line 251), but the Communication case is missing this safeguard.
Suggested change:

    if (getShardedIterDomain(in, ParallelType::Stream, DomainType::kLoop) !=
            nullptr &&
        getShardedIterDomain(
            in, ParallelType::Stream, DomainType::kAllocation) == nullptr) {
      NVF_ERROR(!loop_nest.empty(), "Stream-parallelized input requires a loop nest");
      auto [i, inserted] = replacement_map.try_emplace(
          in, hir::shardByStream(in, innermost.loop->index()));
      if (inserted) {
        innermost_scope.push_back(i->second->definition());
      }
    }
    if (getShardedIterDomain(
            out, ParallelType::Stream, DomainType::kLoop) != nullptr &&
        getShardedIterDomain(
            out, ParallelType::Stream, DomainType::kAllocation) ==
            nullptr) {
      innermost.parent_scope->insert(
          innermost.parent_insertion_point, allocate);
      auto [i, inserted] = replacement_map.try_emplace(
          out, hir::shardByStream(out, innermost.loop->index()));
      NVF_ERROR(inserted);
      innermost_scope.push_back(i->second->definition());
    } else {
      innermost_scope.push_back(allocate);
    }
logic: null pointer dereference if loop_nest is empty. Multiple accesses to innermost.loop->index() (line 210), innermost.parent_scope (line 207), and innermost.parent_insertion_point (line 208) without null checks. Same issue as lines 188-197.
Suggested change:

    if (getShardedIterDomain(
            out, ParallelType::Stream, DomainType::kLoop) != nullptr &&
        getShardedIterDomain(
            out, ParallelType::Stream, DomainType::kAllocation) ==
            nullptr) {
      NVF_ERROR(!loop_nest.empty(), "Stream-parallelized output requires a loop nest");
      innermost.parent_scope->insert(
          innermost.parent_insertion_point, allocate);
      auto [i, inserted] = replacement_map.try_emplace(
          out, hir::shardByStream(out, innermost.loop->index()));
      NVF_ERROR(inserted);
      innermost_scope.push_back(i->second->definition());
    } else {
      innermost_scope.push_back(allocate);
    }
    Expr* e = ir_cloner.clone(group.exprs().front());

    for (auto* c : convertSingleOpToCommunication(e, device_id)) {
      // TODO: `replacement_map` should be associated with the scope os
syntax: typo: 'os' should be 'so'
Suggested change:

    // TODO: `replacement_map` should be associated with the scope so
Priya2698 left a comment:
LGTM. Please also incorporate the Greptile suggestions on spelling and add the error checks.
    # Collect CUDA kernels after a warmup run to exclude autotuning.
    # nvfuser_direct.PythonProfiler failed with host IR lowering. The main
    # reason is that HostIrContainer doesn't keep segments while SegmentProfiler
The torch.profiler use looks good to me.

Is it as simple as adding `FusionProfiler::createSegments(kernel_runtime->executors().size());`?
I'll try to take a look at this to extend FusionProfiler for Host IR.
I did try a couple of superficial things and figured it's more complicated than that. I stashed my changes unfortunately.
Furthermore, when we switch to HostIrJit generating an LLVM IR program that has no nvFuser dependency, what extra value does PythonProfiler provide in addition to torch.profiler?
Fixes #5307