Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
38 changes: 31 additions & 7 deletions xla/service/gpu/model/gpu_indexing_performance_model.cc
Original file line number Diff line number Diff line change
Expand Up @@ -129,6 +129,22 @@ bool DoesTileFitsInRegisters(int64_t tile_size,
device_info.registers_per_block_limit();
}

// Returns the largest padded tile size found among the instructions of
// `tiled_hlo_computation`, never less than 1.
//
// The result is memoized on the computation itself: if a value has already
// been stored there it is returned as-is, otherwise it is computed from the
// instructions' tile sizes and stored via SetLargestTileSize() before
// returning.
int64_t ComputeLargestTileSize(TiledHloComputation& tiled_hlo_computation) {
  // Fast path: reuse the previously stored value, if any.
  if (const auto cached = tiled_hlo_computation.GetLargestTileSize();
      cached.has_value()) {
    return *cached;
  }

  // Start from 1 so that a computation without instructions yields 1.
  int64_t max_padded_size = 1;
  for (const TiledHloInstruction* instruction :
       tiled_hlo_computation.instructions()) {
    const int64_t padded_size = GetPaddedTileSize(instruction->tile_sizes());
    if (max_padded_size < padded_size) {
      max_padded_size = padded_size;
    }
  }

  tiled_hlo_computation.SetLargestTileSize(max_padded_size);
  return max_padded_size;
}

// Checks if all tiles in the computation fit in registers.
//
// There is no way to know for sure if emitted computation will not spill
Expand All @@ -141,12 +157,18 @@ bool DoesComputationFitInRegisters(
const se::DeviceDescription& device_info) {
// Check that output tiles fit in registers.
for (const TiledHloInstruction* root : tiled_hlo_computation.GetRoots()) {
if (!DoesTileFitsInRegisters(GetPaddedTileSize(root->tile_sizes()),
device_info)) {
int64_t padded_tile_size = GetPaddedTileSize(root->tile_sizes());
if (!DoesTileFitsInRegisters(padded_tile_size, device_info)) {
return false;
}
}

// Check that the largest tile fits in registers.
int64_t largest_tile_size = *tiled_hlo_computation.GetLargestTileSize();
if (!DoesTileFitsInRegisters(largest_tile_size, device_info)) {
return false;
}

for (const TiledHloInstruction* tiled_hlo :
tiled_hlo_computation.instructions()) {
bool is_operand = !fusion_adaptor.ContainsInstruction(tiled_hlo->hlo());
Expand Down Expand Up @@ -623,11 +645,7 @@ GpuPerformanceModelWithIndexingAnalysis::GetLaunchDimensionsForTiledFusion(

// Decide on the number of warps to use based on the largest live tile size
// at any given point within the computation.
int64_t largest_live_tile_size = 1;
for (const auto& tiled_hlo : tiled_hlo_computation.instructions()) {
largest_live_tile_size = std::max(
largest_live_tile_size, GetPaddedTileSize(tiled_hlo->tile_sizes()));
}
int64_t largest_live_tile_size = *tiled_hlo_computation.GetLargestTileSize();
int64_t num_warps = GetNumWarps(largest_live_tile_size);

return {static_cast<uint64_t>(num_blocks),
Expand Down Expand Up @@ -673,6 +691,7 @@ GpuPerformanceModelWithIndexingAnalysis::TryFindBestTilingForFusion(
}

auto tiled_hlo_computation = std::move(maybe_tiled_hlo_computation.value());
ComputeLargestTileSize(tiled_hlo_computation);
LaunchDimensions launch_dimensions =
GetLaunchDimensionsForTiledFusion(tiled_hlo_computation, *device_info_);

Expand All @@ -681,6 +700,11 @@ GpuPerformanceModelWithIndexingAnalysis::TryFindBestTilingForFusion(
EstimateRunTimeForTiledHloComputation(
fusion_adaptor, tiled_hlo_computation, launch_dimensions));

// Skip tilings with infinite runtime (e.g., due to register spilling).
if (estimate_run_time_data.exec_time == absl::InfiniteDuration()) {
continue;
}

if (!best_tiled_run_time_data.has_value() ||
estimate_run_time_data.exec_time <
best_tiled_run_time_data->runtime_data.exec_time) {
Expand Down
15 changes: 15 additions & 0 deletions xla/service/gpu/model/tiled_hlo_computation.h
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ limitations under the License.

#include <cstdint>
#include <memory>
#include <optional>
#include <string>
#include <utility>
#include <vector>
Expand Down Expand Up @@ -116,6 +117,16 @@ class TiledHloComputation {
return Product(num_output_tiles_per_dim());
}

// Accessor for the cached largest (padded) tile size across all instructions.
// Yields std::nullopt as long as no value has been stored through
// SetLargestTileSize().
std::optional<int64_t> GetLargestTileSize() const { return largest_tile_size_; }

// Records `size` as the cached largest tile size, overwriting any previously
// stored value. Intended to be invoked only once every instruction has been
// added to the computation.
void SetLargestTileSize(int64_t size) {
  largest_tile_size_ = size;
}

// Returns the root instructions of the computation. When a computation has
// several outputs (i.e. it has a tuple root), the roots are the operands of
// the root tuple. The roots are ordered by increasing output index, and point
Expand Down Expand Up @@ -146,6 +157,10 @@ class TiledHloComputation {

// Stores the number of output tiles for each dimension.
llvm::SmallVector<int64_t> num_output_tiles_per_dim_;

// Stores the largest tile size (after padding) across all instructions.
// Cached to avoid recomputation. Starts out as std::nullopt and only holds a
// value once SetLargestTileSize() has been called; readers obtain it through
// GetLargestTileSize().
std::optional<int64_t> largest_tile_size_ = std::nullopt;
};

} // namespace gpu
Expand Down