Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
202 changes: 200 additions & 2 deletions amdsharktuner/amdsharktuner/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
import tempfile

from iree.compiler import ir # type: ignore
from iree.compiler.dialects import iree_codegen, iree_gpu, transform # type: ignore
from iree.compiler.dialects import iree_codegen, iree_gpu, linalg, transform # type: ignore
import iree.compiler as ireec # type: ignore
from iree.compiler._mlir_libs._mlir import ir # type: ignore

Expand Down Expand Up @@ -190,6 +190,23 @@ class ContractionDimensions:
batch: list[int] = field(default_factory=list)


@dataclass
class ConvToIgemmInfo:
    """
    Stores information about convolution to IGEMM transformation.
    This corresponds to the C++ ConvToIgemmInfo struct in IREE.

    Note: In C++, conv_to_igemm_dim_map is DenseMap<int64_t, AffineExpr>,
    but in Python bindings it's dict[int, int] mapping conv dim to IGEMM position.
    """

    # True when the batch dimension is the innermost (last) loop dimension
    # (e.g. CHWN layout); padding logic then pads only the batch dimension.
    is_batch_dim_last: bool = False
    # True when a spatial dimension is last (e.g. NCHW layout); conv padding
    # is skipped entirely for such layouts.
    is_spatial_dim_last: bool = False
    # Convolution dimension classification (batch/spatial/channel dims);
    # None until populated, and asserted non-None by get_padding_conv_sizes.
    conv_dims: Optional[linalg.ConvolutionDimensions] = None
    # Maps each convolution loop dimension to its position in the IGEMM
    # (implicit-GEMM) loop nest.
    conv_to_igemm_dim_map: dict[int, int] = field(default_factory=dict)
    # Maps each input-channel conv dimension to its static size, used to
    # decide whether padding that dimension is worthwhile.
    input_channel_dim_to_size: dict[int, int] = field(default_factory=dict)


@dataclass
class MatmulShapeType:
m: int
Expand Down Expand Up @@ -233,6 +250,24 @@ class AttentionKnobs(KnobAssignment):
pass


def is_affine_expr_function_of_dim(expr: ir.AffineExpr, position: int) -> bool:
    """
    Return True if the expression depends on the dimension at the given position.
    """
    # A plain dimension expression depends only on its own position.
    if ir.AffineDimExpr.isinstance(expr):
        return ir.AffineDimExpr(expr).position == position

    # Binary expressions depend on a dimension if either operand does;
    # recurse into both sides (short-circuits on the first match).
    if ir.AffineBinaryExpr.isinstance(expr):
        binary = ir.AffineBinaryExpr(expr)
        return any(
            is_affine_expr_function_of_dim(operand, position)
            for operand in (binary.lhs, binary.rhs)
        )

    # Constants and symbols do not reference any dimension.
    return False


def get_map_result_dim_positions(map: ir.AffineMap) -> Optional[list[int]]:
if not map.is_projected_permutation:
return None
Expand Down Expand Up @@ -281,7 +316,7 @@ def get_lowering_config(
# A local variable to hold the transformed value.
promoted_value = value
match key:
case "workgroup" | "reduction" | "subgroup" | "promote_operands" | "padding":
case "workgroup" | "reduction" | "subgroup" | "promote_operands" | "padding" | "padding_conv":
if isinstance(value, Sequence):
promoted_value = ir.ArrayAttr.get(
[tuner_ctx.type.getI64(x) for x in value]
Expand Down Expand Up @@ -523,3 +558,166 @@ def get_target_info(input_module: ir.Module) -> iree_gpu.TargetInfo:
target = executable_variant_op.target

return iree_gpu.TargetInfo.get_gpu_target_info(target)


# The following two functions are from IREE side for padding utility:
# https://github.com/iree-org/iree/blob/8ae91ebb0e555e660b8a6898f6071476f7a1f20b/compiler/src/iree/compiler/Codegen/Dialect/GPU/TargetUtils/ConfigUtils.cpp#L631-L671
def maybe_padded_bounds(original_bound: int, alignment: int) -> tuple[int, bool]:
    """Round ``original_bound`` up to the next multiple of ``alignment``.

    Returns:
        A tuple of (possibly padded bound, whether padding was applied).
    """
    remainder = original_bound % alignment
    if remainder == 0:
        return original_bound, False
    return original_bound + alignment - remainder, True


def get_dim_bounds(
    dims: list[int],
    padding_can_be_expensive: bool,
) -> tuple[list[int], bool]:
    """Pad dimension bounds to alignment-friendly sizes for tile selection.

    Bounds larger than 128 are padded to a multiple of 128, bounds in
    (32, 128] to a multiple of 32, and bounds of 32 or less are left
    unchanged. When ``padding_can_be_expensive`` is True (e.g. the LHS is
    transposed), all bounds are returned unpadded.

    Args:
        dims: Loop bounds to (potentially) pad.
        padding_can_be_expensive: Skip all padding when True.

    Returns:
        A tuple of (padded bounds, whether any bound was actually padded).
    """
    result = []
    any_padding_applied = False

    for dim in dims:
        # TODO: Make over-padding a tunable parameter. This logic allows
        # over-padding to get larger tile sizes, which may result in better
        # performance despite doing more padded computation.
        if padding_can_be_expensive or dim <= 32:
            result.append(dim)
            continue

        # Single alignment choice replaces the previously duplicated
        # >128 / >32 branches; behavior is unchanged.
        alignment = 128 if dim > 128 else 32
        padded, was_padded = maybe_padded_bounds(dim, alignment)
        result.append(padded)
        any_padding_applied = any_padding_applied or was_padded

    return result, any_padding_applied


# Implemented padding logic from IREE side:
# https://github.com/iree-org/iree/blob/8ae91ebb0e555e660b8a6898f6071476f7a1f20b/compiler/src/iree/compiler/Codegen/Dialect/GPU/TargetUtils/ConfigUtils.cpp#L382-L467
def get_padding_conv_sizes(
    bounds: list[int],
    padding_sizes: list[int],
    workgroup_tile_sizes: list[int],
    reduction_tile_sizes: list[int],
    conv_to_igemm_info: ConvToIgemmInfo,
) -> Optional[list[int]]:
    """
    Calculate padding sizes for convolution dimensions when using IGEMM.
    This corresponds to C++ getPaddingConvSizes function in IREE.

    Args:
        bounds: Loop bounds for each dimension
        padding_sizes: Padding sizes for IGEMM dimensions
        workgroup_tile_sizes: Workgroup tile sizes
        reduction_tile_sizes: Reduction tile sizes
        conv_to_igemm_info: Convolution to IGEMM transformation info (must not be None)

    Returns:
        List of padding sizes for convolution dimensions, or None if padding should be skipped.
        Caller should convert to ArrayAttr if needed.
    """
    # Skip padding convolution for NCHW layout (spatial dimension last).
    if conv_to_igemm_info.is_spatial_dim_last:
        return None

    conv_to_igemm_map = conv_to_igemm_info.conv_to_igemm_dim_map
    # IGEMM positions already accounted for, either because a conv dim was
    # padded for them or because their bound is already divisible.
    padded_igemm_dims = set()
    conv_dims = conv_to_igemm_info.conv_dims

    assert conv_dims is not None, "Expected conv_dims to be set in ConvToIgemmInfo"

    input_channel_dims = set(conv_dims.input_channel)

    # Indexed by conv loop dimension.
    # NOTE(review): assumes conv dims densely cover [0, len) — confirm
    # against the upstream C++ implementation.
    padding_conv_sizes = [0] * len(conv_to_igemm_map)

    # For batch-last layout (e.g., CHWN), only pad the batch dimension to avoid
    # introducing pad op as the producer of collapse_shape op which may cause fusion problem.
    if conv_to_igemm_info.is_batch_dim_last:
        last_batch_dim = conv_dims.batch[-1]
        # The map stores integer positions, use them directly.
        igemm_batch_pos = conv_to_igemm_map[last_batch_dim]

        # Bound already divisible by the padding size: no padding needed.
        if (
            padding_sizes[igemm_batch_pos]
            and bounds[igemm_batch_pos] % padding_sizes[igemm_batch_pos] == 0
        ):
            return None

        padding_conv_sizes[last_batch_dim] = padding_sizes[igemm_batch_pos]
        return padding_conv_sizes

    # Process each convolution dimension mapping.
    for conv_dim, igemm_pos in conv_to_igemm_map.items():
        # The map stores integer positions directly.

        if reduction_tile_sizes[igemm_pos] != 0:
            # For reduction dimensions, avoid setting padding on the convolution
            # if the product of the corresponding conv sizes are already divisible by the padding size.
            if (
                padding_sizes[igemm_pos]
                and bounds[igemm_pos] % padding_sizes[igemm_pos] == 0
            ):
                padded_igemm_dims.add(igemm_pos)
            continue

        # Only pad input channel dims. If we need to pad filter dims, then we
        # would rather just do padding on the IGEMM instead.
        if conv_dim in input_channel_dims:
            # Multiple input channel dims for a single IGEMMPos is not supported.
            if igemm_pos in padded_igemm_dims:
                return None

            input_channel_size = conv_to_igemm_info.input_channel_dim_to_size.get(
                conv_dim, 0
            )
            # "Small" means the padding size is more than 2x the channel size,
            # so padding would more than triple that dimension's work.
            is_input_channel_size_small = (
                padding_sizes[igemm_pos] // input_channel_size > 2
            )

            # If the input channel dimension is much smaller than the padding size,
            # skip padding along that dimension while still padding the others.
            if is_input_channel_size_small:
                padding_conv_sizes[conv_dim] = 0
            else:
                padding_conv_sizes[conv_dim] = padding_sizes[igemm_pos]

            padded_igemm_dims.add(igemm_pos)
            continue

        # Multiple padded parallel dims mapping to the same IGEMM dim is not supported.
        if workgroup_tile_sizes[igemm_pos] != 0 and igemm_pos in padded_igemm_dims:
            return None

        padding_conv_sizes[conv_dim] = padding_sizes[igemm_pos]
        padded_igemm_dims.add(igemm_pos)

    # Ensure that all dimensions have been padded.
    if len(padded_igemm_dims) != len(padding_sizes):
        return None

    return padding_conv_sizes


def calculate_padded_dimensions(
    M: list[int],
    N: list[int],
    contraction_dims: ContractionDimensions,
    contraction_maps: list[ir.AffineMap],
) -> tuple[list[int], list[int], bool]:
    """Pad the M and N loop bounds unless the LHS is transposed.

    Returns the (possibly padded) M bounds, N bounds, and whether any
    padding was applied to either.
    """
    # The LHS counts as transposed when the innermost result of its indexing
    # map is not the innermost contraction (k) dimension. Padding is disabled
    # only in that case.
    innermost_lhs_result = contraction_maps[0].results[-1]
    innermost_lhs_pos = ir.AffineDimExpr(innermost_lhs_result).position
    transposed_lhs = contraction_dims.k[-1] != innermost_lhs_pos

    padded_m, m_changed = get_dim_bounds(M, transposed_lhs)
    padded_n, n_changed = get_dim_bounds(N, transposed_lhs)

    return padded_m, padded_n, m_changed or n_changed
63 changes: 49 additions & 14 deletions amdsharktuner/amdsharktuner/constraint_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from abc import ABC, abstractmethod
from typing import Iterator, Optional

from iree.compiler import ir # type: ignore
from iree.compiler.dialects import iree_codegen, iree_gpu, linalg # type: ignore

from . import common, dispatch_constraints, dispatch_parser
Expand Down Expand Up @@ -57,7 +58,6 @@ def adjust_problem_size_for_pipeline(
matmul_size.N = [bounds[i] for i in contraction_dims.n]
matmul_size.K = [bounds[i] for i in contraction_dims.k]
matmul_size.B = [bounds[i] for i in contraction_dims.batch]

return

# Fallback: Manual flattening for legacy path when IGEMM details are unavailable.
Expand All @@ -76,11 +76,13 @@ def generate_generic_contraction_solutions(
rhs_type: common.ShapedType,
res_type: common.ShapedType,
dispatch_kind: common.DispatchKind,
indexing_maps: list[ir.AffineMap],
codegen_pipeline: iree_codegen.DispatchLoweringPassPipeline = iree_codegen.DispatchLoweringPassPipeline.LLVMGPUVectorDistribute,
num_subgroups: int = 4,
allowed_waves_per_eu: list[int] = [2],
pipeline_options_search_space: dispatch_constraints.PipelineOptionsSearchSpace = dispatch_constraints.PipelineOptionsSearchSpace(),
igemm_details: Optional[iree_codegen.IGEMMGenericConvDetails] = None,
conv_to_igemm_info: Optional[common.ConvToIgemmInfo] = None,
) -> Iterator[list[common.TuningConfiguration]]:
adjust_problem_size_for_pipeline(
contraction_dims,
Expand All @@ -94,6 +96,23 @@ def generate_generic_contraction_solutions(
M, N, K = matmul_size.M, matmul_size.N, matmul_size.K
tuner_ctx.logger.debug(f"{M},{N},{K}")

# Apply padding for TileAndFuse pipeline to get better tile sizes.
padding_applied = False
if codegen_pipeline == iree_codegen.DispatchLoweringPassPipeline.LLVMGPUTileAndFuse:
# Use IGEMM maps if available (dimensions were restructured), otherwise use original indexing maps.
padding_maps = indexing_maps
if igemm_details:
padding_maps = [
map_attr.value for map_attr in igemm_details.igemm_contraction_maps
]

M_padded, N_padded, padding_applied = common.calculate_padded_dimensions(
M, N, contraction_dims, padding_maps
)
M, N = M_padded, N_padded
matmul_size.M = M
matmul_size.N = N

m_vars = [z3.Int(f"m{i}") for i in range(len(M))]
n_vars = [z3.Int(f"n{i}") for i in range(len(N))]
k_vars = [z3.Int(f"k{i}") for i in range(len(K))]
Expand Down Expand Up @@ -238,23 +257,33 @@ def set_cdim_tile_sizes(tile_sizes, contraction_dims, csizes):
[lookup(v) for v in k_vars],
)

required_padding = any(
p[-1] % i != 0 for p, i in zip((M, N, K), intrinsic_mnk_shape, strict=True)
)
promote_operands = [0, 1]
padding = None
if required_padding:
padding_conv = None
if padding_applied:
# TODO: Remove promotion of operand 2 once codegen supports handling padded outputs without promotion.
promote_operands = [0, 1, 2]
_, _, mma_intrinsic_k = mma_attr.mnk_shape
padding = [
*(workgroup_tile_sizes[d] for d in contraction_dims.m),
*(workgroup_tile_sizes[d] for d in contraction_dims.n),
*(
reduction_tile_sizes[d] * mma_intrinsic_k
for d in contraction_dims.k
),
]
padding_tile_sizes = list(workgroup_tile_sizes)
for k_dim in contraction_dims.k:
padding_tile_sizes[k_dim] = reduction_tile_sizes[k_dim]

mma_intrinsic_k = mma_attr.mnk_shape[2]
inner_k_dim = contraction_dims.k[-1]
padding_tile_sizes[inner_k_dim] *= mma_intrinsic_k

padding = padding_tile_sizes

# Calculate padding_conv sizes for convolutions when using IGEMM.
if conv_to_igemm_info and igemm_details:
# Use IGEMM loop bounds directly from igemm_details.
bounds = list(igemm_details.igemm_loop_bounds)
padding_conv = common.get_padding_conv_sizes(
bounds,
padding_tile_sizes,
workgroup_tile_sizes,
reduction_tile_sizes,
conv_to_igemm_info,
)
# Setting subgroup basis.
# TODO(Bangtian): Sync changes from IREE PR: https://github.com/iree-org/iree/pull/22000.
subgroup_basis_counts = [1] * num_loops
Expand All @@ -279,6 +308,7 @@ def set_cdim_tile_sizes(tile_sizes, contraction_dims, csizes):
pipeline_options_search_space,
allowed_waves_per_eu,
padding=padding,
padding_conv=padding_conv,
)

solver.add(z3.simplify(z3.Not(z3.And(list(x == model[x] for x in all_vars)))))
Expand Down Expand Up @@ -562,6 +592,7 @@ def generate_solutions(
rhs_type=self.op_info.rhs_type,
res_type=self.op_info.res_type,
dispatch_kind=common.DispatchKind.contraction,
indexing_maps=self.op_info.indexing_maps,
codegen_pipeline=codegen_pipeline,
**pipeline_constraint_options,
)
Expand All @@ -578,6 +609,8 @@ def generate_solutions(
codegen_pipeline: iree_codegen.DispatchLoweringPassPipeline,
**pipeline_constraint_options,
) -> Iterator[list[common.TuningConfiguration]]:
# TODO(Bangtian): Simplify the function signature to accept op_info directly instead of
# unpacking all individual fields.
return generate_generic_contraction_solutions(
tuner_ctx=tuner_context,
gpu_target_info=gpu_target_info,
Expand All @@ -587,8 +620,10 @@ def generate_solutions(
rhs_type=self.op_info.rhs_type,
res_type=self.op_info.res_type,
dispatch_kind=common.DispatchKind.conv,
indexing_maps=self.op_info.indexing_maps,
codegen_pipeline=codegen_pipeline,
igemm_details=self.op_info.igemm_details,
conv_to_igemm_info=self.op_info.conv_to_igemm_info,
**pipeline_constraint_options,
)

Expand Down
7 changes: 4 additions & 3 deletions amdsharktuner/amdsharktuner/dispatch_constraints.py
Original file line number Diff line number Diff line change
Expand Up @@ -284,9 +284,6 @@ def generate_tile_and_fuse_constraints(
M, N, K = list(matmul_size.M), list(matmul_size.N), list(matmul_size.K)
m_tiles, n_tiles, k_tiles, subgroup_m_tiles, subgroup_n_tiles = tile_sizes
intrinsic_mn, intrinsic_k = intrinsic_size
M[-1] = ((M[-1] + intrinsic_mn - 1) / intrinsic_mn) * intrinsic_mn
N[-1] = ((N[-1] + intrinsic_mn - 1) / intrinsic_mn) * intrinsic_mn
K[-1] = ((K[-1] + intrinsic_k - 1) / intrinsic_k) * intrinsic_k
wg_x, wg_y, wg_z = workgroup_size
wg_threads = wg_x
constraints = [wg_y == 1, wg_z == 1]
Expand Down Expand Up @@ -675,6 +672,7 @@ def generate_compilation_infos(
pipeline_options_search_space: PipelineOptionsSearchSpace,
allowed_waves_per_eu: list[int],
padding: Optional[list[int]] = None,
padding_conv: Optional[list[int]] = None,
) -> list[iree_codegen.CompilationInfoAttr]:
subgroup_basis = [subgroup_basis_counts, subgroup_basis_mapping]
# Create the LoweringConfigAttr.
Expand All @@ -691,6 +689,9 @@ def generate_compilation_infos(
if padding is not None:
lowering_config_args["padding"] = padding

if padding_conv is not None:
lowering_config_args["padding_conv"] = padding_conv

if codegen_pipeline == iree_codegen.DispatchLoweringPassPipeline.LLVMGPUTileAndFuse:
lowering_config_args["subgroup"] = subgroup_tile_sizes

Expand Down
Loading
Loading