Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
172 changes: 131 additions & 41 deletions mlx/backend/metal/conv.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@
#include <algorithm>
#include <cassert>
#include <numeric>
#include <sstream>

#include "mlx/backend/gpu/copy.h"
#include "mlx/backend/metal/device.h"
Expand Down Expand Up @@ -39,10 +38,11 @@ void explicit_gemm_conv_ND_gpu(
in_unfolded.set_data(allocator::malloc(in_unfolded.nbytes()));

// Prepare unfolding kernel
std::ostringstream kname;
kname << "naive_unfold_nd_" << type_to_name(in_unfolded) << "_" << N;
std::string kname;
kname.reserve(32);
concatenate(kname, "naive_unfold_nd_", type_to_name(in_unfolded), "_", N);
auto& compute_encoder = d.get_command_encoder(s.index);
auto kernel = d.get_kernel(kname.str());
auto kernel = d.get_kernel(kname);
compute_encoder.set_compute_pipeline_state(kernel);

compute_encoder.set_input_array(in, 0);
Expand Down Expand Up @@ -117,11 +117,12 @@ void explicit_gemm_conv_group_ND_gpu(
in_unfolded.set_data(allocator::malloc(in_unfolded.nbytes()));

// Prepare unfolding kernel
std::ostringstream kname;
kname << "naive_unfold_transpose_nd_" << type_to_name(in_unfolded) << "_"
<< N;
std::string kname;
kname.reserve(32);
concatenate(
kname, "naive_unfold_transpose_nd_", type_to_name(in_unfolded), "_", N);
auto& compute_encoder = d.get_command_encoder(s.index);
auto kernel = d.get_kernel(kname.str());
auto kernel = d.get_kernel(kname);
compute_encoder.set_compute_pipeline_state(kernel);

compute_encoder.set_input_array(in, 0);
Expand Down Expand Up @@ -252,18 +253,32 @@ void implicit_gemm_conv_2D_gpu(
/* const int swizzle_log = */ swizzle_log};

// Determine kernel
std::ostringstream kname;
kname << "implicit_gemm_conv_2d_" << type_to_name(out) << "_bm" << bm << "_bn"
<< bn << "_bk" << bk << "_wm" << wm << "_wn" << wn << "_channel_"
<< (n_channel_specialization ? std::to_string(n_channel_specialization)
: "l")
<< "_filter_" << (small_filter ? 's' : 'l');
std::string kname;
kname.reserve(64);
concatenate(
kname,
"implicit_gemm_conv_2d_",
type_to_name(out),
"_bm",
bm,
"_bn",
bn,
"_bk",
bk,
"_wm",
wm,
"_wn",
wn,
"_channel_",
n_channel_specialization ? std::to_string(n_channel_specialization) : "l",
"_filter_",
small_filter ? 's' : 'l');

// Encode and dispatch kernel
auto& compute_encoder = d.get_command_encoder(s.index);
auto kernel = get_steel_conv_kernel(
d,
kname.str(),
kname,
out,
bm,
bn,
Expand Down Expand Up @@ -559,11 +574,16 @@ void winograd_conv_2D_gpu(
{
int bc = 32;
int bo = 4;
std::ostringstream kname;
kname << "winograd_conv_2d_weight_transform_" << type_to_name(out) << "_bc"
<< bc;
std::string kname;
kname.reserve(32);
concatenate(
kname,
"winograd_conv_2d_weight_transform_",
type_to_name(out),
"_bc",
bc);
auto& compute_encoder = d.get_command_encoder(s.index);
auto kernel = d.get_kernel(kname.str());
auto kernel = d.get_kernel(kname);
compute_encoder.set_compute_pipeline_state(kernel);

compute_encoder.set_input_array(wt, 0);
Expand All @@ -587,11 +607,16 @@ void winograd_conv_2D_gpu(
int bc = 32;
int wm = 2;
int wn = 2;
std::ostringstream kname;
kname << "winograd_conv_2d_input_transform_" << type_to_name(out) << "_bc"
<< bc;
std::string kname;
kname.reserve(32);
concatenate(
kname,
"winograd_conv_2d_input_transform_",
type_to_name(out),
"_bc",
bc);
auto& compute_encoder = d.get_command_encoder(s.index);
auto kernel = d.get_kernel(kname.str());
auto kernel = d.get_kernel(kname);
compute_encoder.set_compute_pipeline_state(kernel);

compute_encoder.set_input_array(in_padded, 0);
Expand Down Expand Up @@ -634,11 +659,16 @@ void winograd_conv_2D_gpu(
int bc = 32;
int wm = 2;
int wn = 2;
std::ostringstream kname;
kname << "winograd_conv_2d_output_transform_" << type_to_name(out) << "_bo"
<< bc;
std::string kname;
kname.reserve(32);
concatenate(
kname,
"winograd_conv_2d_output_transform_",
type_to_name(out),
"_bo",
bc);
auto& compute_encoder = d.get_command_encoder(s.index);
auto kernel = d.get_kernel(kname.str());
auto kernel = d.get_kernel(kname);
compute_encoder.set_compute_pipeline_state(kernel);

compute_encoder.set_input_array(out_wg, 0);
Expand All @@ -660,9 +690,9 @@ void depthwise_conv_2D_gpu(
const array& wt,
array out,
const MLXConvParams<2>& conv_params) {
std::ostringstream kname;
kname << "depthwise_conv_2d_" << type_to_name(out);
std::string base_name = kname.str();
std::string base_name;
base_name.reserve(32);
concatenate(base_name, "depthwise_conv_2d_", type_to_name(out));

const int N = conv_params.N;
const int ker_h = conv_params.wS[0];
Expand All @@ -685,15 +715,18 @@ void depthwise_conv_2D_gpu(
};

// clang-format off
kname << "_ker_h_" << ker_h
<< "_ker_w_" << ker_w
<< "_str_h_" << str_h
<< "_str_w_" << str_w
<< "_tgp_h_" << th
<< "_tgp_w_" << tw
<< "_do_flip_" << (do_flip ? 't' : 'n'); // clang-format on

std::string hash_name = kname.str();
std::string hash_name;
hash_name.reserve(64);
concatenate(
hash_name,
base_name,
"_ker_h_", ker_h,
"_ker_w_", ker_w,
"_str_h_", str_h,
"_str_w_", str_w,
"_tgp_h_", th,
"_tgp_w_", tw,
"_do_flip_", do_flip ? 't' : 'n'); // clang-format on

auto& compute_encoder = d.get_command_encoder(s.index);
auto kernel = d.get_kernel(base_name, hash_name, func_consts);
Expand Down Expand Up @@ -774,6 +807,56 @@ void dispatch_conv_2D_gpu(
}
}

// Encodes and dispatches the depthwise 1D convolution Metal kernel.
//
// The caller uses this as the fast path for fully separable 1D convolution.
//
// Parameters:
//   s   - stream whose command encoder receives the dispatch.
//   d   - Metal device used to look up the kernel and the encoder.
//   in  - input array; its three strides are forwarded to the kernel.
//   wt  - weights; taken by value because a row-contiguous copy may replace it.
//   out - output array; its dtype selects the kernel instantiation.
//
// NOTE(review): the local names suggest `in` is (B, T, D), `out` is
// (B, Tout, D) and `wt` is (D, K, ...) -- confirm against the caller.
void depthwise_conv_1D_gpu(
    const Stream& s,
    metal::Device& d,
    const array& in,
    array wt,
    array out) {
  // Switch to the 64-bit stride variant when the input exceeds 32-bit range.
  const bool large = in.size() > INT32_MAX || in.data_size() > INT32_MAX;

  // Kernels are instantiated as "depthwise_conv_1d_<type>" and
  // "depthwise_conv_1d_<type>_large", so the "_large" suffix must come AFTER
  // the type name (and must not introduce a second underscore before it).
  std::string base_name;
  base_name.reserve(32);
  concatenate(
      base_name,
      "depthwise_conv_1d_",
      type_to_name(out),
      large ? "_large" : "");

  // The kernel indexes weights as `w + channel * kernel_size`, which requires
  // a row-contiguous layout; copy if needed and keep the copy alive as a
  // temporary on this stream.
  if (!wt.flags().row_contiguous) {
    wt = contiguous_copy_gpu(wt, s);
    d.add_temporary(wt, s.index);
  }

  auto& compute_encoder = d.get_command_encoder(s.index);
  auto kernel = d.get_kernel(base_name);
  compute_encoder.set_compute_pipeline_state(kernel);

  auto B = in.shape(0);
  auto Tout = out.shape(1);
  auto D = in.shape(2);
  auto K = wt.shape(1);

  compute_encoder.set_input_array(in, 0);
  compute_encoder.set_input_array(wt, 1);
  compute_encoder.set_output_array(out, 2);

  // Buffer 3: input strides, with an element width matching the kernel's IdxT.
  if (large) {
    int64_t strides[3] = {in.strides(0), in.strides(1), in.strides(2)};
    compute_encoder.set_bytes(strides, 3, 3);
  } else {
    int strides[3] = {
        static_cast<int>(in.strides(0)),
        static_cast<int>(in.strides(1)),
        static_cast<int>(in.strides(2))};
    compute_encoder.set_bytes(strides, 3, 3);
  }

  // Buffer 4: kernel size.
  compute_encoder.set_bytes(K, 4);

  // One thread per output element: grid is (D, Tout, B).
  auto group_dims = get_block_dims(D, Tout, B);
  MTL::Size grid_dims = MTL::Size(D, Tout, B);

  compute_encoder.dispatch_threads(grid_dims, group_dims);
}

void conv_1D_gpu(
const Stream& s,
metal::Device& d,
Expand All @@ -790,8 +873,15 @@ void conv_1D_gpu(
bool is_idil_one = in_dilation[0] == 1;
int C = in.shape(2);
int O = wt.shape(0);
const int C_per_group = in.shape(2) / groups;
const int O_per_group = wt.shape(0) / groups;
// Fast path for fully separable 1D convolution
if (is_idil_one && (groups == C) && groups == O && wt_strides[0] == 1 &&
wt_dilation[0] == 1 && padding[0] == 0 && !flip) {
depthwise_conv_1D_gpu(s, d, in, wt, out);
return;
}

const int C_per_group = C / groups;
const int O_per_group = O / groups;

// Direct to implicit gemm conv
if (is_idil_one && (C_per_group <= 4 || C_per_group % 16 == 0) &&
Expand Down
34 changes: 34 additions & 0 deletions mlx/backend/metal/kernels/conv.metal
Original file line number Diff line number Diff line change
Expand Up @@ -288,6 +288,40 @@ instantiate_depthconv2d(float32, float);
instantiate_depthconv2d(float16, half);
instantiate_depthconv2d(bfloat16, bfloat16_t);

// Depthwise 1D convolution: each thread produces one output element.
//
// Thread coordinates index the dispatch grid as (x, y, z); the output is
// written densely in (z, y, x) order using the grid dimensions. `strides`
// holds the three input strides (indexed by z, y, x respectively) and the
// filter taps for column x start at `w + x * kernel_size`.
// Accumulation is done in float and cast back to T on store.
template <typename T, typename IdxT>
[[kernel]] void depthwise_conv_1d(
    const device T* in [[buffer(0)]],
    const device T* w [[buffer(1)]],
    device T* out [[buffer(2)]],
    constant const IdxT strides[3],
    constant const int& kernel_size,
    uint3 tid [[thread_position_in_grid]],
    uint3 grid_dim [[threads_per_grid]]) {
  // Linear offset of this thread's output element (IdxT avoids 32-bit
  // overflow in the large instantiation).
  const auto out_offset =
      (tid.z * static_cast<IdxT>(grid_dim.y) + tid.y) * grid_dim.x + tid.x;
  out += out_offset;

  // Start of the input window and of this column's filter taps.
  const auto in_offset =
      tid.z * strides[0] + tid.y * strides[1] + tid.x * strides[2];
  in += in_offset;
  w += tid.x * kernel_size;

  // Walk the window along axis 1 of the input, one tap per step.
  float sum = 0.0;
  for (int tap = 0; tap < kernel_size; ++tap, in += strides[1]) {
    sum += static_cast<float>(*in) * w[tap];
  }

  *out = static_cast<T>(sum);
}

// Instantiates depthwise_conv_1d for element type `itype` under two host
// names:
//   "depthwise_conv_1d_<iname>"        -- 32-bit strides (IdxT = int32_t)
//   "depthwise_conv_1d_<iname>_large"  -- 64-bit strides (IdxT = int64_t)
// Host code must build exactly these names (type first, then the optional
// "_large" suffix) when looking up the kernel.
#define instantiate_depthconv1d(iname, itype) \
instantiate_kernel( \
"depthwise_conv_1d_" #iname, depthwise_conv_1d, itype, int32_t) \
instantiate_kernel( \
"depthwise_conv_1d_" #iname "_large", \
depthwise_conv_1d, \
itype, \
int64_t)

// Supported element types: float32, float16 and bfloat16.
instantiate_depthconv1d(float32, float);
instantiate_depthconv1d(float16, half);
instantiate_depthconv1d(bfloat16, bfloat16_t);

///////////////////////////////////////////////////////////////////////////////
/// Winograd kernels
///////////////////////////////////////////////////////////////////////////////
Expand Down