Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
190 changes: 133 additions & 57 deletions mlx/backend/metal/device.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -233,8 +233,32 @@ MTL::Library* load_library(

} // namespace

// RAII guard that serializes command-recording operations on a stream.
// Acquires `mtx` for the guard's lifetime, snapshots the stream's current
// submission sequence number, and marks the debug lock-owner so that
// assert_stream_lock_held_() can verify callers actually hold the lock.
StreamOpLock::StreamOpLock(
Device& device,
DeviceStream& stream,
std::mutex& mtx)
: device_(device),
stream_(stream),
lock_(mtx),
sequence_(stream.submission.sequence) {
#ifdef MLX_METAL_GLOBAL_OP_LOCK
// Global-lock build: a single debug owner is shared by all streams.
device_.global_debug_owner_.set();
#else
// Per-stream-lock build: each stream tracks its own debug owner.
stream_.debug_owner.set();
#endif
}

// Clears the debug ownership marker. The mutex itself is released by the
// `lock_` member's destructor (RAII), so no explicit unlock appears here.
StreamOpLock::~StreamOpLock() {
#ifdef MLX_METAL_GLOBAL_OP_LOCK
device_.global_debug_owner_.clear();
#else
stream_.debug_owner.clear();
#endif
}

// Creates a concurrent compute command encoder on the stream's active
// command buffer and retains it so the encoder outlives autorelease.
CommandEncoder::CommandEncoder(DeviceStream& stream) : stream_(stream) {
// NOTE(review): this span is a PR-diff render — the next line is the
// pre-change assignment and the two lines after it are its replacement
// (buffer moved into stream.submission). Only one assignment exists in
// the actual file.
enc_ = stream_.buffer->computeCommandEncoder(MTL::DispatchTypeConcurrent);
enc_ = stream_.submission.buffer->computeCommandEncoder(
MTL::DispatchTypeConcurrent);
enc_->retain();
}

Expand All @@ -259,7 +283,7 @@ void CommandEncoder::set_input_array(
int idx,
int64_t offset /* = 0 */) {
if (all_inputs_.insert(a.buffer().ptr()).second) {
stream_.buffer_sizes += a.data_size();
stream_.submission.buffer_sizes += a.data_size();
}
auto r_buf = static_cast<MTL::Resource*>(const_cast<void*>(a.buffer().ptr()));
needs_barrier_ =
Expand Down Expand Up @@ -303,15 +327,15 @@ void CommandEncoder::dispatch_threadgroups(
MTL::Size grid_dims,
MTL::Size group_dims) {
maybeInsertBarrier();
stream_.buffer_ops++;
stream_.submission.buffer_ops++;
enc_->dispatchThreadgroups(grid_dims, group_dims);
}

// Encodes a dispatchThreads call covering `grid_dims` total threads in
// groups of `group_dims`, inserting any pending barrier first and
// counting the op toward the buffer-commit threshold.
void CommandEncoder::dispatch_threads(
MTL::Size grid_dims,
MTL::Size group_dims) {
maybeInsertBarrier();
// NOTE(review): diff artifact — the next line is the pre-change
// counter; the line after it is the replacement. Exactly one op-count
// increment exists in the actual file.
stream_.buffer_ops++;
stream_.submission.buffer_ops++;
enc_->dispatchThreads(grid_dims, group_dims);
}

Expand Down Expand Up @@ -371,7 +395,10 @@ Device::~Device() {
k->release();
}
}
stream_map_.clear();
{
std::unique_lock wlock(stream_map_mtx_);
stream_map_.clear();
}
device_->release();
}

Expand All @@ -383,108 +410,149 @@ void Device::new_queue(int index) {
throw std::runtime_error(
"[metal::Device] Failed to make new command queue.");
}
std::unique_lock wlock(stream_map_mtx_);
stream_map_.emplace(index, q);
if (residency_set_ != nullptr) {
q->addResidencySet(residency_set_);
}
}

// Look up the DeviceStream registered under `index`, holding the stream
// map's reader lock for the duration of the lookup. Throws
// std::runtime_error if no stream with that index exists.
DeviceStream& Device::get_stream_(int index) {
  std::shared_lock rlock(stream_map_mtx_);
  auto it = stream_map_.find(index);
  if (it == stream_map_.end()) {
    throw std::runtime_error("[metal::Device] Unknown stream index.");
  }
  return it->second;
}

// Debug assertion that the caller holds the op lock protecting `stream`:
// the single global debug owner when MLX_METAL_GLOBAL_OP_LOCK is
// defined, otherwise the stream's own debug owner. `where` identifies
// the call site — presumably for the failure diagnostic; confirm in
// the debug-owner implementation.
void Device::assert_stream_lock_held_(DeviceStream& stream, const char* where)
const {
#ifdef MLX_METAL_GLOBAL_OP_LOCK
global_debug_owner_.assert_held(where);
#else
stream.debug_owner.assert_held(where);
#endif
}

// Acquire the op lock guarding command recording on stream `index` and
// return it as an RAII StreamOpLock. When MLX_METAL_GLOBAL_OP_LOCK is
// defined, all streams share a single global mutex; otherwise each
// stream locks its own op_mtx.
StreamOpLock Device::lock_stream_ops(int index) {
  // The stream lookup was duplicated in both preprocessor branches;
  // only the mutex handed to the guard differs, so hoist it out.
  auto& stream = get_stream_(index);
#ifdef MLX_METAL_GLOBAL_OP_LOCK
  return StreamOpLock(*this, stream, global_op_mtx_);
#else
  return StreamOpLock(*this, stream, stream.op_mtx);
#endif
}

// Return the Metal command queue backing the given stream.
MTL::CommandQueue* Device::get_queue(Stream stream) {
  auto& device_stream = get_stream_(stream.index);
  return device_stream.queue;
}

// True when the stream's current submission has accumulated more encoded
// ops than max_ops_per_buffer_ or more input bytes than max_mb_per_buffer_
// (the `>> 20` converts bytes to MB), signalling the command buffer
// should be committed. Requires the stream op lock to be held.
bool Device::command_buffer_needs_commit(int index) {
auto& stream = get_stream_(index);
// NOTE(review): diff artifact — the next two lines are the pre-change
// body; the four lines after them are its replacement. Only the
// epoch-based body exists in the actual file.
return (stream.buffer_ops > max_ops_per_buffer_) ||
((stream.buffer_sizes >> 20) > max_mb_per_buffer_);
assert_stream_lock_held_(stream, "command_buffer_needs_commit");
auto& epoch = stream.submission;
return (epoch.buffer_ops > max_ops_per_buffer_) ||
((epoch.buffer_sizes >> 20) > max_mb_per_buffer_);
}

// Return the stream's active command buffer, lazily creating one (and
// transitioning the submission epoch IDLE -> OPEN) on first use.
// Requires the stream op lock to be held.
// NOTE(review): this span is a PR-diff render — lines that touch
// `stream.buffer` are the pre-change version; the `epoch.*` lines are
// their replacement. Only the epoch-based body exists post-merge.
MTL::CommandBuffer* Device::get_command_buffer(int index) {
auto& stream = get_stream_(index);
if (stream.buffer == nullptr) {
stream.buffer = stream.queue->commandBufferWithUnretainedReferences();
if (!stream.buffer) {
assert_stream_lock_held_(stream, "get_command_buffer");
auto& epoch = stream.submission;
if (epoch.buffer == nullptr) {
epoch.buffer = stream.queue->commandBufferWithUnretainedReferences();
if (!epoch.buffer) {
throw std::runtime_error(
"[metal::Device] Unable to create new command buffer");
}
// Increment ref count so the buffer is not garbage collected
stream.buffer->retain();
// Increment ref count so the buffer is not garbage collected.
epoch.buffer->retain();
if (epoch.state != SubmissionEpoch::State::IDLE) {
throw std::runtime_error(
"[metal::Device] Invalid state transition to OPEN");
}
epoch.state = SubmissionEpoch::State::OPEN;
}
return stream.buffer;
return epoch.buffer;
}

// Commit the stream's command buffer and close out the submission epoch:
// ENDED -> COMMITTED, reset the op/byte counters, then bump the sequence
// number and return to IDLE. A null buffer is a no-op. Requires the
// stream op lock to be held.
void Device::commit_command_buffer(int index) {
auto& stream = get_stream_(index);
// NOTE(review): diff artifact — the next five lines are the pre-change
// body; the epoch-based code below replaces them in the actual file.
stream.buffer->commit();
stream.buffer->release();
stream.buffer = nullptr;
stream.buffer_ops = 0;
stream.buffer_sizes = 0;
assert_stream_lock_held_(stream, "commit_command_buffer");
auto& epoch = stream.submission;
if (epoch.buffer == nullptr) {
return;
}
if (epoch.state != SubmissionEpoch::State::ENDED) {
throw std::runtime_error(
"[metal::Device] Invalid state transition to COMMITTED");
}
epoch.buffer->commit();
epoch.buffer->release();
epoch.buffer = nullptr;
epoch.buffer_ops = 0;
epoch.buffer_sizes = 0;
epoch.state = SubmissionEpoch::State::COMMITTED;
// Epoch boundary: COMMITTED -> IDLE and increment generation.
epoch.sequence++;
epoch.state = SubmissionEpoch::State::IDLE;
}

// Record `arr` as a temporary kept alive until the current submission
// epoch's work completes. Requires the stream op lock to be held.
void Device::add_temporary(array arr, int index) {
// NOTE(review): diff artifact — the next line is the pre-change body;
// the three lines after it are its replacement in the actual file.
get_stream_(index).temporaries.push_back(std::move(arr));
auto& stream = get_stream_(index);
assert_stream_lock_held_(stream, "add_temporary");
stream.submission.temporaries.push_back(std::move(arr));
}

// Bulk variant of add_temporary: move all of `arrays` into the current
// submission epoch's temporaries list. Empty input is a no-op. Requires
// the stream op lock to be held.
void Device::add_temporaries(std::vector<array> arrays, int index) {
if (arrays.empty()) {
return;
}
auto& stream = get_stream_(index);
// NOTE(review): diff artifact — the next two lines are the pre-change
// insert target; the three lines after them are the replacement.
stream.temporaries.insert(
stream.temporaries.end(),
assert_stream_lock_held_(stream, "add_temporaries");
stream.submission.temporaries.insert(
stream.submission.temporaries.end(),
std::make_move_iterator(arrays.begin()),
std::make_move_iterator(arrays.end()));
}

void Device::end_encoding(int index) {
void Device::end_encoding(int index, StreamOpLock& lk) {
auto& stream = get_stream_(index);
if (stream.encoder != nullptr) {
// Each command encoder has a unique fence. We also store a map of
// all previous outputs of command encoders to their corresponding fence.
// - The command encoder records its inputs and outputs.
// - Wait on a fence if any inputs in the encoder are outputs of a previous
// encoder.
// - Update the map of outputs to include this command encoder's outputs.
// - Always signal this command encoders fence.
// - Add a completion handler for this command encoder that removes outputs
// from the map to limit the growth of the map and avoid unnecessary waits
// - Temporaries are a special case as they do not cross command encoder
// boundaries. These can be removed early from the encoders inputs and
// outputs since they don't need synchronization.
auto& enc = *stream.encoder;
// Remove temporaries from inputs and outputs
for (auto& t : stream.temporaries) {
assert_stream_lock_held_(stream, "end_encoding");
auto& epoch = stream.submission;
if (epoch.encoder != nullptr) {
// Lock ordering invariant: op lock is held before entering this method,
// and fence state is only touched through with_fence_state().
auto& enc = *epoch.encoder;
for (auto& t : epoch.temporaries) {
enc.outputs().erase(t.buffer().ptr());
enc.inputs().erase(t.buffer().ptr());
}

// Keep references to the fences we waited on and put them
// in the completion handler so they are not prematurely released
std::unordered_set<std::shared_ptr<Fence>> waiting_on;
{
std::lock_guard<std::mutex> lk(stream.fence_mtx);
lk.with_fence_state([&](auto& outputs) {
for (auto in : enc.inputs()) {
if (auto it = stream.outputs.find(in); it != stream.outputs.end()) {
// If we've already waited on a fence, don't wait on it again.
if (auto it = outputs.find(in); it != outputs.end()) {
if (waiting_on.find(it->second) == waiting_on.end()) {
enc.wait_for_fence(it->second->fence);
waiting_on.insert(it->second);
}
}
}
for (auto out : enc.outputs()) {
stream.outputs[out] = stream.fence;
outputs[out] = epoch.fence;
}
}
enc.update_fence(stream.fence->fence);
stream.buffer->addCompletedHandler(
});
enc.update_fence(epoch.fence->fence);
epoch.buffer->addCompletedHandler(
[&stream,
waiting_on = std::move(waiting_on),
fence = std::move(stream.fence),
fence = std::move(epoch.fence),
outputs = std::move(enc.outputs()),
temporaries =
std::move(stream.temporaries)](MTL::CommandBuffer*) mutable {
std::move(epoch.temporaries)](MTL::CommandBuffer*) mutable {
// Completion handlers must never take op_mtx. Fence map only.
temporaries.clear();
std::lock_guard<std::mutex> lk(stream.fence_mtx);
for (auto o : outputs) {
Expand All @@ -495,21 +563,28 @@ void Device::end_encoding(int index) {
}
}
});
epoch.state = SubmissionEpoch::State::ENDED;
} else if (
epoch.buffer != nullptr && epoch.state == SubmissionEpoch::State::OPEN) {
// Explicit OPEN -> ENDED transition for empty encoders.
epoch.state = SubmissionEpoch::State::ENDED;
}
stream.encoder = nullptr;
epoch.encoder = nullptr;
}

// Return the stream's active command encoder, lazily creating the
// encoder, its per-epoch fence, and (via get_command_buffer) the command
// buffer, then moving the epoch into the ENCODING state. Requires the
// stream op lock to be held.
CommandEncoder& Device::get_command_encoder(int index) {
auto& stream = get_stream_(index);
// NOTE(review): diff artifact — lines that touch `stream.encoder`,
// `stream.buffer`, or `stream.fence` are the pre-change version; the
// `epoch.*` lines are their replacement in the actual file.
if (stream.encoder == nullptr) {
// Ensure there is an active command buffer
if (stream.buffer == nullptr) {
assert_stream_lock_held_(stream, "get_command_encoder");
auto& epoch = stream.submission;
if (epoch.encoder == nullptr) {
if (epoch.buffer == nullptr) {
get_command_buffer(index);
}
stream.encoder = std::make_unique<CommandEncoder>(stream);
stream.fence = std::make_shared<Fence>(device_->newFence());
epoch.encoder = std::make_unique<CommandEncoder>(stream);
epoch.fence = std::make_shared<Fence>(device_->newFence());
epoch.state = SubmissionEpoch::State::ENCODING;
}
return *stream.encoder;
return *epoch.encoder;
}

MTL::Library* Device::get_library(
Expand Down Expand Up @@ -808,6 +883,7 @@ void Device::set_residency_set(const MTL::ResidencySet* residency_set) {
}
residency_set_ = residency_set;
// Attach residency set to existing command queues
std::shared_lock rlock(stream_map_mtx_);
for (auto& [_, stream] : stream_map_) {
stream.queue->addResidencySet(residency_set_);
}
Expand Down
Loading