Skip to content

Conversation

@liqiangxl
Copy link
Collaborator

@liqiangxl liqiangxl commented Nov 18, 2025

Following #5553

TMA Pointwise Scheduler: Broadcast Domain Handling

(1) TMA Load vs. General Load (ldg/ld.global)

The current TMA pointwise scheduler does not use TMA load for inputs with concretized broadcast domains.

Example:
Given three inputs: tv0[I1, B2] + tv1[B1, I2] + tv2[I1, I2], where B1 and B2 are broadcast domains concretized to I1 and I2:

  • tv0 and tv1 will NOT be loaded with TMA

Note: This is a performance optimization. These inputs can be loaded with TMA, but only using a one-dimensional tile, as demonstrated in the newly added test.

(2) Break Point Selection

When broadcasts are present, the loop domain of the reference tv is merged to [lhs, rhs] instead of flattening to a single dimension.

Example:
Given tv0[I1, B2] + tv1[B1, I2] + tv2[I1, I2], the break point is at pos-1, which separates [I1, I2] into [lhs, rhs].

Break point selection differs between TMA and non-TMA versions:

  • TMA version: Break point is selected whenever broadcast domains are present
  • Non-TMA version: Break point is selected only when at least 10% of transferred data can be saved

Rationale: In the TMA version, we cannot safely merge broadcast and non-broadcast domains when creating 2D TMA domains and schedules, so we always break when broadcasts are present. See restrictions at #5556

@github-actions
Copy link

github-actions bot commented Nov 18, 2025

Review updated until commit 0b5bd62

Description

  • Refactor pointwise scheduler to extract break point calculation into utility functions

  • Add support for TMA (Tensor Memory Accelerator) scheduling with broadcast tensors

  • Introduce BreakPointInfo and BlockGridConfig structs for better code organization

  • Enhance TMA scheduling to handle 2D break points and broadcast dimensions

  • Add comprehensive tests for TMA broadcast tensor handling

Changes walkthrough

Relevant files
Enhancement
pointwise_non_tma.cpp
Refactor non-TMA pointwise scheduler with utility functions

csrc/scheduler/pointwise_non_tma.cpp

  • Extract break point calculation logic into separate getBreakPoint
    function
  • Extract block/grid configuration into getBlockGridConfig function
  • Simplify main heuristics function by using utility functions
  • Move broadcast info calculation for debug output only
  • +32/-131
    pointwise_tma.cpp
    Enhance TMA scheduler with break point support                     

    csrc/scheduler/pointwise_tma.cpp

  • Use new getBreakPoint function for TMA scheduling
  • Add break point to debug output
  • Modify scheduling to handle 2D break points properly
  • Add validation for TMA scheduling with break points
  • +15/-3   
    pointwise_utils.cpp
    Implement utility functions for break point and grid configuration

    csrc/scheduler/pointwise_utils.cpp

  • Add getBreakPoint function for optimal 2D scheduling break point
    calculation
  • Add getBlockGridConfig function for block/grid dimension calculation
  • Handle both TMA and non-TMA scheduling cases
  • Support broadcast tensor handling in break point logic
  • +181/-0 
    pointwise_utils.h
    Add data structures and function declarations for scheduler utilities

    csrc/scheduler/pointwise_utils.h

  • Add BreakPointInfo struct for break point information
  • Add BlockGridConfig struct for complete block/grid configuration
  • Declare getBreakPoint and getBlockGridConfig functions
  • Add documentation for new utility functions
  • +38/-0   
    Tests
    test_pointwise.cpp
    Add comprehensive TMA broadcast tensor tests                         

    tests/cpp/test_pointwise.cpp

  • Rename Tma2dTileTest to TmaPointwiseTest for consistency
  • Add TmaTestBase base class for common TMA test setup
  • Add TmaPointwiseBcastTest for broadcast tensor TMA testing
  • Test TMA scheduling with inner/outer broadcast dimensions
  • Test both auto-scheduler and manual TMA scheduling approaches
  • +215/-17

    PR Reviewer Guide

    Here are some key observations to aid the review process:

    🧪 PR contains tests
    ⚡ Recommended focus areas for review
    TMA Break Point Handling

    The new break point handling logic in TMA scheduling should be validated to ensure it correctly handles edge cases, particularly when break_point=0 and the conditional logic for n_valid_dims >= 2. Verify that TMA domain splitting works correctly for all supported tensor dimensions.

    if (pparams->break_point == 0) {
      reference_tv->split(0, pparams->tma_domain_inner);
    } else {
      NVF_ERROR(
          n_valid_dims >= 2,
          "Required at least 2 valid dimensions for Tma scheduling, but got ",
          n_valid_dims);
    }
    Break Point Calculation Logic

    The extracted getBreakPoint function contains complex logic for calculating optimal break points. Validate that the TMA-specific logic (lines 293-300) correctly prioritizes break points for broadcast dimensions and that the transfer size calculations are accurate for TMA use cases.

    BreakPointInfo getBreakPoint(
        Fusion* fusion,
        const FusionRuntimeProperties& prop,
        HeuristicDataCache* data_cache,
        bool is_tma,
        int64_t max_vect_factor,
        int64_t kThreadX) {
      BreakPointInfo result;
    
      // Calculate dtype_sum_bit from fusion inputs/outputs
      int64_t dtype_sum_bit = 0;
      const auto index_type = prop.index_type;
      for (auto inp : ir_utils::filterByType<TensorView>(fusion->inputs())) {
        dtype_sum_bit += dataTypeSizeBit(inp->getDataType().value(), index_type);
      }
      for (auto out : ir_utils::filterByType<TensorView>(fusion->outputs())) {
        dtype_sum_bit += dataTypeSizeBit(out->getDataType().value(), index_type);
      }
    
      // Get broadcast information
      TensorView* largest_out = prop.largest_out;
      auto broadcast_info_entry = HeuristicDataCacheEntry<
          HeuristicCompileTime::BroadcastMultiples>(
          data_cache, [&largest_out, &index_type]() {
            return std::make_unique<scheduler_utils::BroadcastMultipleInformation>(
                scheduler_utils::getBroadcastMultiples(largest_out, index_type));
          });
      const auto& broadcast_info = broadcast_info_entry.get();
    
      const auto& ref_loop = prop.ref_loop;
      const auto& elem_counts = prop.elem_counts;
      const int64_t n_elems = prop.n_elems;
      const auto& view_disjoint_sets = broadcast_info.view_disjoint_set_ids;
      const auto& broadcast_bit_multiples = broadcast_info.broadcast_multiples;
    
      // Default values for 1D scheduling
      result.break_point = 0;
      result.flip_grid_binding = false;
      result.right_elem_count = 0;
      result.is_outer_broadcast_dominated = false;
    
      // Figure out break point position
      // How much would this transfer cost if it was done as a 1-D schedule
      int64_t transfer_size_1d_bit = 1;
    
      for (const auto i : arange(ref_loop.size())) {
        transfer_size_1d_bit =
            transfer_size_1d_bit * elem_counts[i] * dtype_sum_bit;
      }
    
      // Calculate optimal break point for 2D scheduling
      int64_t min_total_transfer_bit = std::numeric_limits<int64_t>::max();
      // Don't check the inner most dimension, scheduler assumes there's always
      // an rhs
      for (const auto break_point_i : arange((int64_t)ref_loop.size())) {
        // If break point is incoherent with view, don't consider breaking here.
        if (!scheduler_utils::breakIsDisjoint(view_disjoint_sets, break_point_i)) {
          continue;
        }
    
        // Number of elements in the right side of reference tv with
        // break_point_i
        int64_t cur_right_elem_count = 1;
        for (const auto right_i : arange(break_point_i, ref_loop.size())) {
          cur_right_elem_count = cur_right_elem_count * elem_counts[right_i];
        }
    
        // For tma scheduling, allow no element in the left side of break point,
        // e.g. break at pos-0 for non-broadcasted case.
        auto cur_left_elem_count = n_elems / cur_right_elem_count;
        if (!is_tma && cur_left_elem_count <= 1) {
          continue;
        }
    
        auto lhs_bit_multiple = broadcast_bit_multiples[break_point_i].lhs_multiple;
        auto rhs_bit_multiple = broadcast_bit_multiples[break_point_i].rhs_multiple;
    
        // Estimate transfer cost with this break point
        int64_t cur_transfer_size_bit = 1;
        int64_t right_transfer_size_bit = 1;
    
        for (const auto left_i : arange(break_point_i)) {
          cur_transfer_size_bit =
              cur_transfer_size_bit * elem_counts[left_i] * lhs_bit_multiple;
        }
    
        for (const auto right_i : arange(break_point_i, ref_loop.size())) {
          right_transfer_size_bit =
              right_transfer_size_bit * elem_counts[right_i] * rhs_bit_multiple;
        }
        cur_transfer_size_bit *= right_transfer_size_bit;
    
        if (!is_tma) {
          //  Continue if this break point doesn't save at least 10% of 1D
          //  scheduling or isn't better than previous break_points found.
          if (cur_transfer_size_bit >= min_total_transfer_bit ||
              cur_transfer_size_bit * 10 >= transfer_size_1d_bit * 9) {
            continue;
          }
    
          // Need to be able to parallelize, don't use break if there's not
          // at least an unrolled warp.
          if (ceilDiv(cur_right_elem_count, max_vect_factor) <=
              at::cuda::getCurrentDeviceProperties()->warpSize) {
            continue;
          }
          // If outer broadcast, or balanced broadcast:
          if (lhs_bit_multiple <= rhs_bit_multiple &&
              // If right transfer size is bigger than half of L2
              at::cuda::getCurrentDeviceProperties()->l2CacheSize * 8 <
                  right_transfer_size_bit * 2) {
            // flip BIDx and BIDy bindings
            result.flip_grid_binding = true;
          } else {
            result.flip_grid_binding = false;
          }
        } else {
          // If TMA is used, priorize break if it saves transfered size
          // This ensures we break at broadcast dimensions, then we can optionally
          // load tvs with broadcasted dimensions.
          if (cur_transfer_size_bit >= min_total_transfer_bit) {
            continue;
          }
        }
    
        // Use this break point
        result.break_point = static_cast<int>(break_point_i);
        min_total_transfer_bit = cur_transfer_size_bit;
        result.right_elem_count = cur_right_elem_count;
    
        // when lhs byte multiple is smaller than rhs byte multiple,
        // there is broadcast in the lhs, which is outer broadcast.
        result.is_outer_broadcast_dominated = lhs_bit_multiple < rhs_bit_multiple;
      }
    
      return result;
    }

    @liqiangxl
    Copy link
    Collaborator Author

    !test

    1 similar comment
    @liqiangxl
    Copy link
    Collaborator Author

    !test

    @liqiangxl liqiangxl marked this pull request as ready for review November 20, 2025 15:09
    @greptile-apps
    Copy link
    Contributor

    greptile-apps bot commented Nov 20, 2025

    Greptile Overview

    Greptile Summary

    This PR extends the TMA pointwise scheduler to handle broadcast domains. The implementation refactors break point calculation into shared utilities (getBreakPoint() and getBlockGridConfig()), enabling both TMA and non-TMA schedulers to use common logic with different policies.

    Key changes:

    • TMA scheduler now calculates break points and uses 2D scheduling when broadcasts are present (avoiding merging broadcast/non-broadcast domains)
    • Non-TMA scheduler refactored to use the new shared utilities, removing ~130 lines of duplicated code
    • TMA version always breaks when broadcasts reduce transfer size, while non-TMA version requires 10% savings threshold
    • Inputs with concretized broadcast domains are excluded from TMA load (use standard load instead)
    • Comprehensive test coverage added for broadcast scenarios (both auto-scheduler and manual scheduling paths)

    Confidence Score: 4/5

    • This PR is safe to merge after fixing the comment formatting issue
    • The refactoring successfully consolidates break point logic and adds broadcast support. There is one syntax issue (malformed comment) that should be fixed. The logic changes are well-motivated and tested, though the TMA domain splitting logic at line 214-221 could be clearer about why no split is needed when break_point > 0
    • csrc/scheduler/pointwise_tma.cpp requires attention for the comment formatting issue at lines 82-83

    Important Files Changed

    File Analysis

    Filename Score Overview
    csrc/scheduler/pointwise_utils.cpp 4/5 Implemented getBreakPoint() with TMA-specific logic that prioritizes any transfer size savings when broadcasts are present, and getBlockGridConfig() to compute block/grid dimensions
    csrc/scheduler/pointwise_tma.cpp 4/5 Added break point calculation and conditional TMA domain splitting - skips split when break_point > 0 (assumes 2D domains already exist)

    Sequence Diagram

    sequenceDiagram
        participant Scheduler as Pointwise Scheduler
        participant Utils as pointwise_utils
        participant TMA as TMA Scheduler
        participant NonTMA as Non-TMA Scheduler
        
        alt TMA Scheduling
            Scheduler->>TMA: getPointwiseHeuristics()
            TMA->>Utils: getBreakPoint(is_tma=true)
            Utils->>Utils: Calculate broadcast multiples
            Utils->>Utils: Iterate break points
            alt Broadcast detected
                Utils->>Utils: Prioritize any transfer size savings
                Note over Utils: No 10% threshold for TMA
            end
            Utils-->>TMA: BreakPointInfo (break_point, flip_grid_binding)
            TMA->>TMA: Calculate TMA tile sizes
            TMA->>TMA: schedulePointwise()
            alt break_point == 0
                TMA->>TMA: split(0, tma_domain_inner)
                Note over TMA: 1D: [I0*I1] → [Do, Di]
            else break_point > 0
                Note over TMA: 2D: Keep [I0, I1] separate
                TMA->>TMA: NVF_ERROR(n_valid_dims >= 2)
            end
            TMA->>TMA: Filter inputs by isTvSuitableForTma()
            Note over TMA: Reject inputs with broadcast domains
        else Non-TMA Scheduling
            Scheduler->>NonTMA: getPointwiseHeuristics()
            alt Sufficient parallelism
                NonTMA->>Utils: getBreakPoint(is_tma=false)
                Utils->>Utils: Calculate broadcast multiples
                Utils->>Utils: Iterate break points
                alt No broadcast or minimal savings
                    Utils->>Utils: Require 10% transfer size savings
                    Utils->>Utils: Check parallelization threshold
                end
                Utils-->>NonTMA: BreakPointInfo
                NonTMA->>Utils: getBlockGridConfig()
                Utils->>Utils: Calculate bdimx, bdimy, gdim_left, gdim_right
                Utils-->>NonTMA: BlockGridConfig
            else Insufficient parallelism
                NonTMA->>NonTMA: Use 1D scheduling (break_point=0)
            end
        end
    
    Loading

    Copy link
    Contributor

    @greptile-apps greptile-apps bot left a comment

    Choose a reason for hiding this comment

    The reason will be displayed to describe this comment to others. Learn more.

    5 files reviewed, 1 comment

    Edit Code Review Agent Settings | Greptile

    Comment on lines +82 to 83
    // dimension. outer tile size: don't exceed the outer TMA dimension size Both
    // Both are subject to hardware constraints of 256 elements per dimension.
    Copy link
    Contributor

    Choose a reason for hiding this comment

    The reason will be displayed to describe this comment to others. Learn more.

    syntax: comment has formatting issue - appears to be an incomplete edit

    Suggested change
    // dimension. outer tile size: don't exceed the outer TMA dimension size Both
    // Both are subject to hardware constraints of 256 elements per dimension.
    // - Inner tile size: ensure at least 2 tiles in the inner TMA dimension
    // - Outer tile size: don't exceed the outer TMA dimension size

    Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

    Labels

    None yet

    Projects

    None yet

    Development

    Successfully merging this pull request may close these issues.

    2 participants