diff --git a/xla/service/gpu/model/gpu_indexing_performance_model.cc b/xla/service/gpu/model/gpu_indexing_performance_model.cc index e97451d2bf320..56423fecc2708 100644 --- a/xla/service/gpu/model/gpu_indexing_performance_model.cc +++ b/xla/service/gpu/model/gpu_indexing_performance_model.cc @@ -129,6 +129,22 @@ bool DoesTileFitsInRegisters(int64_t tile_size, device_info.registers_per_block_limit(); } +// Computes and caches the largest tile size in the tiled computation. +int64_t ComputeLargestTileSize(TiledHloComputation& tiled_hlo_computation) { + if (tiled_hlo_computation.GetLargestTileSize().has_value()) { + return *tiled_hlo_computation.GetLargestTileSize(); + } + + int64_t largest_tile_size = 1; + for (const TiledHloInstruction* tiled_hlo : + tiled_hlo_computation.instructions()) { + largest_tile_size = + std::max(largest_tile_size, GetPaddedTileSize(tiled_hlo->tile_sizes())); + } + tiled_hlo_computation.SetLargestTileSize(largest_tile_size); + return largest_tile_size; +} + // Checks if all tiles in the computation fit in registers. // // There is no way to know for sure if emitted computation will not spill @@ -141,12 +157,18 @@ bool DoesComputationFitInRegisters( const se::DeviceDescription& device_info) { // Check that output tiles fit in registers. for (const TiledHloInstruction* root : tiled_hlo_computation.GetRoots()) { - if (!DoesTileFitsInRegisters(GetPaddedTileSize(root->tile_sizes()), - device_info)) { + int64_t padded_tile_size = GetPaddedTileSize(root->tile_sizes()); + if (!DoesTileFitsInRegisters(padded_tile_size, device_info)) { return false; } } + // Check that the largest tile fits in registers. + int64_t largest_tile_size = *tiled_hlo_computation.GetLargestTileSize(); + if (!DoesTileFitsInRegisters(largest_tile_size, device_info)) { + return false; + } + for (const TiledHloInstruction* tiled_hlo : tiled_hlo_computation.instructions()) { bool is_operand = !fusion_adaptor.ContainsInstruction(tiled_hlo->hlo()); @@ -623,11 +645,7 @@ GpuPerformanceModelWithIndexingAnalysis::GetLaunchDimensionsForTiledFusion( // Decide on the number of warps to use based on the largest live tile size // at any given point within the computation. - int64_t largest_live_tile_size = 1; - for (const auto& tiled_hlo : tiled_hlo_computation.instructions()) { - largest_live_tile_size = std::max( - largest_live_tile_size, GetPaddedTileSize(tiled_hlo->tile_sizes())); - } + int64_t largest_live_tile_size = *tiled_hlo_computation.GetLargestTileSize(); int64_t num_warps = GetNumWarps(largest_live_tile_size); return {static_cast(num_blocks), @@ -673,6 +691,7 @@ GpuPerformanceModelWithIndexingAnalysis::TryFindBestTilingForFusion( } auto tiled_hlo_computation = std::move(maybe_tiled_hlo_computation.value()); + ComputeLargestTileSize(tiled_hlo_computation); LaunchDimensions launch_dimensions = GetLaunchDimensionsForTiledFusion(tiled_hlo_computation, *device_info_); @@ -681,6 +700,11 @@ GpuPerformanceModelWithIndexingAnalysis::TryFindBestTilingForFusion( EstimateRunTimeForTiledHloComputation( fusion_adaptor, tiled_hlo_computation, launch_dimensions)); + // Skip tilings with infinite runtime (e.g., due to register spilling). + if (estimate_run_time_data.exec_time == absl::InfiniteDuration()) { + continue; + } + if (!best_tiled_run_time_data.has_value() || estimate_run_time_data.exec_time < best_tiled_run_time_data->runtime_data.exec_time) { diff --git a/xla/service/gpu/model/tiled_hlo_computation.h b/xla/service/gpu/model/tiled_hlo_computation.h index 7b0dc5fa169a6..e6a8a80b771fd 100644 --- a/xla/service/gpu/model/tiled_hlo_computation.h +++ b/xla/service/gpu/model/tiled_hlo_computation.h @@ -18,6 +18,7 @@ limitations under the License. #include #include +#include #include #include #include @@ -116,6 +117,16 @@ class TiledHloComputation { return Product(num_output_tiles_per_dim()); } + // Returns the largest tile size (after padding) across all instructions. + // Returns std::nullopt if not yet computed. + std::optional GetLargestTileSize() const { + return largest_tile_size_; + } + + // Sets the largest tile size. This should be called after all instructions + // are added. + void SetLargestTileSize(int64_t size) { largest_tile_size_ = size; } + // Returns the root instructions of the computation. When a computation has // several outputs (i.e. it has a tuple root), the roots are the operands of // the root tuple. The roots are order by increasing output index, and point @@ -146,6 +157,10 @@ class TiledHloComputation { // Stores the number of output tiles for each dimension. llvm::SmallVector num_output_tiles_per_dim_; + + // Stores the largest tile size (after padding) across all instructions. + // Cached to avoid recomputation. + std::optional largest_tile_size_ = std::nullopt; }; } // namespace gpu