
Conversation

@Priya2698 (Collaborator) commented Nov 18, 2025

This PR modifies the following TensorView methods to support multidevice fusions:

  1. clearReductionIterDomains: For multidevice fusions, the assumptions that the loop domain equals the logical domain and that the allocation domain is a permutation of the logical domain no longer hold. The changes update these checks and explicitly modify the loop domain to avoid losing sharding. This function is also called during segmentation when the allocation domain is sharded.
  2. multiOutputRFactorHelper: The previous code assumed a trivial allocation domain; multidevice fusions have sharded allocation domains.
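As a toy illustration of point 1 (a hypothetical helper, not the nvfuser API): sharding outer-splits one loop axis by the device count, so the loop domain no longer equals the logical domain.

```python
# Hypothetical model of how sharding breaks the loop == logical assumption.
# Domains are plain lists of extents here, not real nvfuser IterDomains.
def shard_loop_domain(logical, axis, num_devices):
    """Outer-split `axis` by the device count, as a device-parallel split would."""
    extent = logical[axis]
    assert extent % num_devices == 0, "extent must divide evenly across devices"
    return logical[:axis] + [num_devices, extent // num_devices] + logical[axis + 1:]

logical = [1, 2048, 12288]  # [b, s, e], matching the Welford test in this PR
loop = shard_loop_domain(logical, axis=1, num_devices=4)
print(logical == loop)      # False: the loop domain now differs from logical
print(loop)                 # [1, 4, 512, 12288]
```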

@Priya2698 (Collaborator, Author)

!test

@github-actions bot commented Nov 18, 2025

Review updated until commit 3276783

Description

  • Refactor multiOutputRFactorHelper to use TransformReplay::selfReplay for simpler multidevice support

  • Overhaul clearReductionIterDomains to handle allocation domains separately from logical domains

  • Add TensorDomain::noReductions helper function for filtering reduction domains

  • Add test_welford test case for multidevice variance/mean operations

  • Fix typo in comment about multi-output reduction scheduling

Changes walkthrough

Relevant files

Enhancement — csrc/tensor_view.cpp: Refactor tensor domain management for multidevice support (+23/-58)

  • Replace complex replay logic in multiOutputRFactorHelper with TransformReplay::selfReplay
  • Completely refactor clearReductionIterDomains to handle allocation domains separately
  • Use the TensorDomain::noReductions helper and properly manage contiguity flags
  • Support cases where the allocation domain differs from the logical domain in multidevice scenarios

Documentation — csrc/ir/interface_nodes.h: Fix typo in comment (+1/-1)

  • Fix typo in the multiOutputRFactorHelper documentation: "wheen" -> "when"

Tests — tests/python/multidevice/test_multidevice.py: Add Welford multidevice test (+30/-0)

  • Add test_welford to exercise variance and mean operations with multidevice sharding
  • Verify correctness of sharded variance/mean computations against PyTorch results

    PR Reviewer Guide

    Here are some key observations to aid the review process:

    🧪 PR contains tests
    ⚡ Recommended focus areas for review
    Significant logic simplification in multiOutputRFactorHelper

    The old implementation had complex domain replay logic with explicit logical domain mapping, replay construction, and tensor domain updates. The new implementation simply uses TransformReplay::selfReplay. This is a substantial change that could affect correctness, especially for complex multi-output reduction scenarios. Need to verify this simplification maintains the same behavior and doesn't break existing use cases.

    TensorView* TensorView::multiOutputRFactorHelper(
        TensorView* tv,
        const std::vector<int64_t>& axes) {
      NVF_ERROR(
          !container()->isA<kir::Kernel>(),
          "Function invalid for kernel container.");
      // Hack:
      // Semantically we should always keep the outputs of multi reduction ops
      // scheduled the same but the user end cannot guarantee that. In order to
      // guarantee that the rFactor is defined meaningfully the scheduling of the
      // output TV that got the rfactor call is force replayed towards the other two
      if (this != tv) {
        TransformReplay::selfReplay(this->domain(), tv->domain());
      }
    Removed validation checks in clearReductionIterDomains

    The old code had explicit error checks ensuring getLoopDomain() == getLogicalDomain() and that allocation domain is a permutation of logical domain. These checks were removed to support multidevice fusions with sharded allocation domains. While this enables the new functionality, it removes important validation that could catch bugs in single-device scenarios. Need to ensure this doesn't introduce regressions for existing single-device code.
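For intuition, the removed single-device preconditions can be sketched as follows (an illustrative model with string IDs, not nvfuser code):

```python
# Sketch of the validation the old clearReductionIterDomains enforced:
# loop must equal logical, and allocation must be a permutation of logical.
# Sharded tensors violate both, which is why the checks were relaxed.
from collections import Counter

def old_preconditions_hold(logical, loop, allocation):
    return loop == logical and Counter(allocation) == Counter(logical)

# Single-device tensor: holds even with a permuted allocation domain.
print(old_preconditions_hold(["i0", "i1"], ["i0", "i1"], ["i1", "i0"]))  # True
# Sharded tensor: the loop domain carries the device split, so the check fails.
print(old_preconditions_hold(
    ["i0", "i1"], ["d", "i0/d", "i1"], ["d", "i0/d", "i1"]))  # False
```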

    void TensorView::clearReductionIterDomains() {
      NVF_ERROR(
          !domain()->hasRoot(),
          "should not call clearReductionIterDomains on rfactor tv");
      const std::vector<std::optional<bool>>& contiguity = getContiguity();
      const std::vector<IterDomain*>& allocation = getMaybeAllocationDomain();
    
      std::vector<IterDomain*> new_logical =
          TensorDomain::noReductions(getLogicalDomain());
      std::vector<IterDomain*> new_loop =
          TensorDomain::noReductions(getLoopDomain());
      std::vector<IterDomain*> new_allocation =
          TensorDomain::noReductions(allocation);
    
      std::vector<std::optional<bool>> new_contiguity;
      new_contiguity.reserve(contiguity.size());
    
      // Fill new_contiguity, ignoring reduction ids
      for (auto&& [alloc_id, contig_value] : zip(allocation, contiguity)) {
        if (!alloc_id->isReduction()) {
          new_contiguity.push_back(contig_value);
        }
      }
    
      if (new_allocation == new_logical) {
        // if new allocation domain is identical to new logical domain, we don't
        // need to specify allocation domain
        setDomain(IrBuilder::createInContainer<TensorDomain>(
            container(), new_logical, new_loop, new_contiguity));
      } else {
        setDomain(IrBuilder::createInContainer<TensorDomain>(
            container(),
            std::vector<IterDomain*>(),
            new_logical,
            new_allocation,
            new_loop,
            new_contiguity));
      }
    }
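The contiguity bookkeeping in the refactored function can be mirrored in a few lines of Python (a simplified stand-in where an IterDomain is a (name, is_reduction) pair, not a real nvfuser object):

```python
# Simplified model of the filtering step above: drop reduction ids from the
# allocation domain and keep contiguity flags aligned with the survivors.
def clear_reductions(allocation, contiguity):
    new_allocation, new_contiguity = [], []
    for (name, is_reduction), contig in zip(allocation, contiguity):
        if not is_reduction:
            new_allocation.append((name, is_reduction))
            new_contiguity.append(contig)
    return new_allocation, new_contiguity

alloc = [("d", False), ("i0", False), ("r1", True), ("i2", False)]
contig = [True, True, None, True]  # reduction ids carry no contiguity flag
print(clear_reductions(alloc, contig))
# ([('d', False), ('i0', False), ('i2', False)], [True, True, True])
```

Dropping each reduction id and its flag in the same pass is what keeps new_contiguity the same length as new_allocation.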
    Limited test coverage for new functionality

    Only one new test was added for welford operations. Given the significant changes to core tensorview methods that affect how domains are handled, there should be more comprehensive testing to ensure the changes don't break existing functionality. Consider adding tests for other multi-output reduction operations and edge cases.

    @pytest.mark.mpi
    def test_welford(multidevice_direct_test):
        d = multidevice_direct_test.size
        mesh = nvfuser.multidevice.DeviceMesh(torch.arange(d))
        b, s, e = 1, 2048, 12288
    
        def _definition(fd: FusionDefinition):
            inp = fd.define_tensor(shape=[b, s, e], contiguity=True)
            var, mean = fd.ops.var_mean(inp, dims=[2], correction=0, keepdim=False)
            fd.add_output(var)
            fd.add_output(mean)
    
        def _multidevice_schedule(fd: FusionDefinition):
            (inp,) = fd.fusion.inputs()
            inp.set_device_mesh(mesh)
            inp.split(1, d, inner_split=False)
            inp.axis(1).parallelize(nvfuser.ParallelType.mesh_x)
    
        unsharded = torch.randn(b, s, e)
        sharded = multidevice_direct_test.shard_tensor(unsharded, 1, mesh)
    
        with FusionDefinition() as fd:
            _definition(fd)
            _multidevice_schedule(fd)
        var, mean = fd.execute([sharded])
    
        torch.testing.assert_close(var, sharded.var(2), rtol=1e-3, atol=1e-3)
        torch.testing.assert_close(mean, sharded.mean(2), rtol=1e-3, atol=1e-3)

    @Priya2698
    Copy link
    Collaborator Author

    !test

    @Priya2698 Priya2698 marked this pull request as ready for review November 18, 2025 14:54

    @Priya2698 (Collaborator, Author) commented on the diff:

        "should not call clearReductionIterDomains on transformed allocation "
        "domain");

        std::vector<IterDomain*> new_logical;

    The only use of this function is in fusion segmenter for TranslateApplicableWelford::translateSingleWelford. I'll move it there in a separate PR.

    @greptile-apps greptile-apps bot (Contributor) commented Nov 18, 2025

    Greptile Summary

    • Extended clearReductionIterDomains and multiOutputRFactorHelper to support multidevice fusions where loop domain differs from logical domain due to sharding
    • Removed assumptions that allocation domain is a permutation of logical domain and that loop equals logical, which are invalid for sharded tensors
    • Added test coverage for Welford reduction operations on sequence-parallel sharded tensors

    Confidence Score: 4/5

    • Safe to merge with minor considerations around removed validation checks
    • The changes correctly handle multidevice sharding by explicitly tracking loop domain separately from logical domain. However, the removed assertions in clearReductionIterDomains could potentially allow incorrect usage in non-multidevice contexts. The test coverage validates the primary use case.
    • Pay close attention to csrc/tensor_view.cpp where validation checks were removed

    Important Files Changed

    csrc/tensor_view.cpp — Updated clearReductionIterDomains to handle sharded loop domains and replaced manual replay logic with selfReplay in multiOutputRFactorHelper

    Sequence Diagram

    sequenceDiagram
        participant User
        participant FusionDefinition
        participant TensorView
        participant TransformReplay
        participant TensorDomain
        
        User->>FusionDefinition: "define_tensor with sharding"
        User->>FusionDefinition: "ops.var_mean (Welford)"
        FusionDefinition->>TensorView: "multiOutputRFactorHelper"
        TensorView->>TransformReplay: "selfReplay(this->domain, tv->domain)"
        TransformReplay->>TensorDomain: "replay loop domain transformations"
        TensorDomain-->>TransformReplay: "updated domain"
        TransformReplay-->>TensorView: "replayed successfully"
        TensorView->>TensorView: "rFactor for reduction"
        TensorView->>TensorView: "clearReductionIterDomains"
        TensorView->>TensorDomain: "noReductions(loop_domain)"
        TensorView->>TensorDomain: "noReductions(allocation_domain)"
        TensorView->>TensorDomain: "create new TensorDomain with preserved loop"
        TensorDomain-->>TensorView: "domain without reduction IDs"
        TensorView-->>User: "execute fusion on sharded tensor"
    
    @greptile-apps greptile-apps bot left a comment

    3 files reviewed, no comments

    A Collaborator commented on the diff:

        std::unordered_map<IterDomain*, IterDomain*> id_map;
        for (const auto i : arange(logical.size())) {
        id_map[this_logical[i]] = logical[i];
        std::unordered_map<IterDomain*, IterDomain*> ref_to_target_map;

    It sounds just like TransformReplay::selfReplay. Am I missing something?

    @Priya2698 (Collaborator, Author) replied:

    🤦 Yes. It is.

    A Collaborator commented on the diff:

        NVF_ERROR(
            !domain()->hasRoot(),
            "should not call clearReductionIterDomains on rfactor tv");
        const std::vector<std::optional<bool>>& contiguity = getContiguity();

    This sounds like another use case for selfReplay:

    1. Create a TensorDomain with logical domain noReductions(getLogicalDomain()), loop empty, and allocation empty
    2. selfReplay this->domain() to that new TensorDomain
    3. setDomain(the new)

    Am I missing something? Do I need to merge #5316 for selfReplay to work here?

    @Priya2698 (Collaborator, Author) replied:

    Yes, selfReplay could be utilized here. However, we are only stripping the reduction IDs, so generating the replay here is unnecessary.

    @Priya2698 (Collaborator, Author)

    !test --diff

    @greptile-apps greptile-apps bot left a comment

    3 files reviewed, no comments

    @Priya2698 Priya2698 merged commit 722fc4f into main Nov 19, 2025
    66 of 67 checks passed
    @Priya2698 Priya2698 deleted the pm/layernorm_sp branch November 19, 2025 14:13