Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
100 changes: 65 additions & 35 deletions ggml/src/ggml-webgpu/ggml-webgpu.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -123,11 +123,6 @@ struct webgpu_pool_bufs {
wgpu::Buffer dev_buf;
};

// The futures to wait on for a single queue submission
struct webgpu_submission_futures {
std::vector<wgpu::FutureWaitInfo> futures;
};

// Holds a pool of parameter buffers for WebGPU operations
struct webgpu_buf_pool {
std::vector<webgpu_pool_bufs> free;
Expand Down Expand Up @@ -463,26 +458,60 @@ static void ggml_webgpu_create_buffer(wgpu::Device & device,
/** End WebGPU object initializations */

/** WebGPU Actions */
// Drop every future that has already completed, keeping the pending ones
// (standard C++17 erase-remove idiom over the wait-info vector).
static void erase_completed(std::vector<wgpu::FutureWaitInfo> & futures) {
    auto is_done = [](const wgpu::FutureWaitInfo & info) { return info.completed; };
    auto new_end = std::remove_if(futures.begin(), futures.end(), is_done);
    futures.erase(new_end, futures.end());
}

// Wait for the queue to finish processing all submitted work
static void ggml_backend_webgpu_wait(webgpu_global_context & ctx,
std::vector<webgpu_submission_futures> & futures,
bool block = true) {
static void ggml_backend_webgpu_wait(webgpu_global_context & ctx,
std::vector<wgpu::FutureWaitInfo> & futures,
bool block = true) {
// If we have too many in-flight submissions, wait on the oldest one first.
if (futures.empty()) {
return;
}
uint64_t timeout_ms = block ? UINT64_MAX : 0;
while (futures.size() >= WEBGPU_MAX_INFLIGHT_SUBS_PER_THREAD) {
ctx->instance.WaitAny(futures[0].futures.size(), futures[0].futures.data(), UINT64_MAX);
futures.erase(futures.begin());
Comment on lines -474 to -475
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This would previously wait for any future within the first vector of futures to finish, and then delete the entire first vector (instead of just the completed future) since futures is a vector of vectors of futures. I think this bug surfaced with the param buf diff because by expanding the parameter buffer, we can have multiple futures in flight instead of just 1, so we may delete an inflight future alongside a completed future. If something was waiting for a deleted future, it would then wait forever, causing test-thread-safety to time out.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this condition isn't doing quite what we want anymore, since there is now no separation between param_bufs/set_row_error_bufs/gpu_profile bufs from different batch submissions. But, I have a PR coming up soon which should simplify this further and I think I can split it out into making sure we free enough param bufs for future batches. So this is fine to merge for now.

auto waitStatus = ctx->instance.WaitAny(1, &futures[0], UINT64_MAX);
if (waitStatus == wgpu::WaitStatus::Error) {
GGML_LOG_ERROR("ggml_webgpu: WaitAny returned an error\n");
}
if (futures[0].completed) {
futures.erase(futures.begin());
}
}

if (futures.empty()) {
return;
}
size_t i = 0;
while (i < futures.size()) {
auto waitStatus = ctx->instance.WaitAny(futures[i].futures.size(), futures[i].futures.data(), timeout_ms);

if (block) {
while (!futures.empty()) {
auto waitStatus = ctx->instance.WaitAny(futures.size(), futures.data(), timeout_ms);
switch (waitStatus) {
case wgpu::WaitStatus::Success:
// WaitAny doesn't tell us which future completed, so we must check all futures to see which finished.
erase_completed(futures);
break;
case wgpu::WaitStatus::Error:
GGML_LOG_ERROR("ggml_webgpu: WaitAny returned an error\n");
break;
default:
GGML_LOG_ERROR("ggml_webgpu: WaitAny returned an unknown status\n");
break;
}
}
} else {
// Poll once and return
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is this the intended behavior when block = false btw? Since I think calling WaitAny with a timeout of 0 just checks once and then returns.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yep, the idea is to just check when block=false, in case the implementation isn't good at scheduling callbacks on its own.

auto waitStatus = ctx->instance.WaitAny(futures.size(), futures.data(), timeout_ms);
switch (waitStatus) {
case wgpu::WaitStatus::Success:
futures.erase(futures.begin() + i);
// WaitAny doesn't tell us which future completed, so we must check all futures to see which finished.
erase_completed(futures);
break;
case wgpu::WaitStatus::TimedOut:
i++;
break;
case wgpu::WaitStatus::Error:
GGML_LOG_ERROR("ggml_webgpu: WaitAny returned an error\n");
Expand Down Expand Up @@ -525,10 +554,11 @@ static void ggml_backend_webgpu_debug(webgpu_global_context & ctx) {
}
#endif

static webgpu_submission_futures ggml_backend_webgpu_submit(webgpu_global_context ctx,
std::vector<webgpu_command> commands,
webgpu_buf_pool & param_buf_pool,
webgpu_buf_pool * set_rows_error_buf_pool = nullptr) {
static std::vector<wgpu::FutureWaitInfo> ggml_backend_webgpu_submit(
webgpu_global_context ctx,
std::vector<webgpu_command> commands,
webgpu_buf_pool & param_buf_pool,
webgpu_buf_pool * set_rows_error_buf_pool = nullptr) {
std::vector<wgpu::CommandBuffer> command_buffers;
std::vector<webgpu_pool_bufs> params_bufs;
std::vector<webgpu_pool_bufs> set_rows_error_bufs;
Expand Down Expand Up @@ -600,7 +630,7 @@ static webgpu_submission_futures ggml_backend_webgpu_submit(webgpu_global_contex
futures.push_back({ f });
}
#endif
return { futures };
return futures;
}

static webgpu_command ggml_backend_webgpu_build_multi(
Expand Down Expand Up @@ -727,8 +757,7 @@ static void ggml_backend_webgpu_buffer_memset(webgpu_global_context & ctx,

webgpu_command command =
ggml_backend_webgpu_build(ctx, ctx->memset_buf_pool, ctx->memset_pipelines[0], params, entries, wg_x);
std::vector<webgpu_submission_futures> futures = { ggml_backend_webgpu_submit(ctx, { command },
ctx->memset_buf_pool) };
auto futures = ggml_backend_webgpu_submit(ctx, { command }, ctx->memset_buf_pool);
ggml_backend_webgpu_wait(ctx, futures);
}

Expand Down Expand Up @@ -836,7 +865,7 @@ static binary_overlap_flags ggml_webgpu_detect_binary_overlap(ggml_tensor * src0
binary_overlap_flags flags = {};
flags.inplace = ggml_webgpu_tensor_equal(src0, dst);
flags.overlap = ggml_webgpu_tensor_overlap(src1, dst);
flags.src_overlap = ggml_webgpu_tensor_overlap(src0, src1);
flags.src_overlap = ggml_webgpu_tensor_overlap(src0, src1);

return flags;
}
Expand Down Expand Up @@ -1153,8 +1182,8 @@ static webgpu_command ggml_webgpu_mul_mat(webgpu_context & ctx,
};

// Calculate workgroup dimensions
uint32_t wg_x = 1;
uint32_t wg_y = 1;
uint32_t wg_x = 1;
uint32_t wg_y = 1;
const uint32_t max_wg_per_dim = ctx->global_ctx->capabilities.limits.maxComputeWorkgroupsPerDimension;

if (use_fast && is_vec) {
Expand Down Expand Up @@ -1410,7 +1439,7 @@ static webgpu_command ggml_webgpu_binary_op(webgpu_context & ctx,
uint32_t offset_merged_src0 = 0;
uint32_t offset_merged_src1 = 0;
if (flags.src_overlap) {
size_t min_off = std::min(src0_webgpu_tensor_align_offset, src1_webgpu_tensor_align_offset);
size_t min_off = std::min(src0_webgpu_tensor_align_offset, src1_webgpu_tensor_align_offset);
offset_merged_src0 = (uint32_t) ((src0_webgpu_tensor_align_offset - min_off) / ggml_type_size(src0->type));
offset_merged_src1 = (uint32_t) ((src1_webgpu_tensor_align_offset - min_off) / ggml_type_size(src0->type));
}
Expand All @@ -1419,7 +1448,7 @@ static webgpu_command ggml_webgpu_binary_op(webgpu_context & ctx,
ne,
(uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src0) / ggml_type_size(src0->type)),
(uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src1) / ggml_type_size(src1->type)),
(uint32_t) (ggml_webgpu_tensor_misalignment(ctx, dst) / ggml_type_size(dst->type)),
(uint32_t) (ggml_webgpu_tensor_misalignment(ctx, dst) / ggml_type_size(dst->type)),
offset_merged_src0,
offset_merged_src1,
(uint32_t) (src0->nb[0] / ggml_type_size(src0->type)),
Expand Down Expand Up @@ -2121,29 +2150,30 @@ static ggml_status ggml_backend_webgpu_graph_compute(ggml_backend_t backend, str

WEBGPU_CPU_PROFILE_TOTAL_START(graph_compute);

std::vector<webgpu_command> commands;
std::vector<webgpu_submission_futures> futures;
uint32_t num_batched_kernels = 0;
std::vector<webgpu_command> commands;
std::vector<wgpu::FutureWaitInfo> futures;
uint32_t num_batched_kernels = 0;
for (int i = 0; i < cgraph->n_nodes; i++) {
if (auto cmd = ggml_webgpu_encode_node(ctx, cgraph->nodes[i])) {
commands.push_back(*cmd);
num_batched_kernels += cmd.value().num_kernels;
}

if (num_batched_kernels >= WEBGPU_COMMAND_SUBMIT_BATCH_SIZE) {
num_batched_kernels = 0;
futures.push_back(ggml_backend_webgpu_submit(ctx->global_ctx, commands, ctx->param_buf_pool,
&ctx->set_rows_error_buf_pool));
num_batched_kernels = 0;
std::vector<wgpu::FutureWaitInfo> compute_futures = ggml_backend_webgpu_submit(
ctx->global_ctx, commands, ctx->param_buf_pool, &ctx->set_rows_error_buf_pool);
futures.insert(futures.end(), compute_futures.begin(), compute_futures.end());
// Process events and check for completed submissions
ctx->global_ctx->instance.ProcessEvents();
ggml_backend_webgpu_wait(ctx->global_ctx, futures, false);
commands.clear();
}
}
if (!commands.empty()) {
webgpu_submission_futures new_futures =
auto new_futures =
ggml_backend_webgpu_submit(ctx->global_ctx, commands, ctx->param_buf_pool, &ctx->set_rows_error_buf_pool);
futures.push_back(new_futures);
futures.insert(futures.end(), new_futures.begin(), new_futures.end());
}

ggml_backend_webgpu_wait(ctx->global_ctx, futures);
Expand Down
Loading