diff --git a/hopper/tile_size.h b/hopper/tile_size.h index d63999c638..24c76b84c2 100644 --- a/hopper/tile_size.h +++ b/hopper/tile_size.h @@ -52,7 +52,11 @@ constexpr std::tuple tile_size_fwd_sm90( } } else { if (headdim <= 64) { - return {192, 160, true, true}; + if (use_one_mma_wg) { + return {64, 128, true, true}; + } else { + return {192, 160, true, true}; + } } else if (headdim <= 96) { return {192, 128, true, true}; } else if (headdim <= 128) {