Merged
2 changes: 1 addition & 1 deletion mlx/backend/cuda/allocator.cpp
@@ -86,7 +86,7 @@ CudaAllocator::CudaAllocator()
// TODO: Set memory limit for multi-device.
size_t free, total;
CHECK_CUDA_ERROR(cudaMemGetInfo(&free, &total));
memory_limit_ = total * 0.8;
memory_limit_ = total * 0.95;
max_pool_size_ = memory_limit_;
}

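The allocator change simply raises the default pool cap from 80% to 95% of total device memory. A minimal standalone sketch of that computation, assuming only the CUDA runtime API (the main() wrapper and variable names are illustrative, not MLX code):

// Illustrative sketch: query device memory the same way the constructor
// above does and derive the new 95% limit.
#include <cuda_runtime.h>
#include <cstdio>

int main() {
  size_t free_bytes = 0, total_bytes = 0;
  cudaError_t err = cudaMemGetInfo(&free_bytes, &total_bytes);
  if (err != cudaSuccess) {
    std::printf("cudaMemGetInfo failed: %s\n", cudaGetErrorString(err));
    return 1;
  }
  // Previously the pool was capped at 80% of total memory; it is now 95%.
  size_t memory_limit = static_cast<size_t>(total_bytes * 0.95);
  std::printf("free=%zu total=%zu limit=%zu\n", free_bytes, total_bytes, memory_limit);
  return 0;
}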
22 changes: 10 additions & 12 deletions mlx/backend/cuda/device.cpp
@@ -14,10 +14,6 @@ namespace mlx::core::cu {

namespace {

// Can be tuned with MLX_MAX_OPS_PER_BUFFER
// This should be less than 255
constexpr int default_max_nodes_per_graph = 20;

#define CHECK_CUDNN_ERROR(cmd) check_cudnn_error(#cmd, (cmd))

void check_cudnn_error(const char* name, cudnnStatus_t err) {
@@ -95,6 +91,7 @@ CommandEncoder::CaptureContext::CaptureContext(CommandEncoder& enc) : enc(enc) {

CommandEncoder::CaptureContext::~CaptureContext() {
if (!use_cuda_graphs()) {
enc.node_count_++;
return;
}

@@ -221,19 +218,14 @@ void CommandEncoder::set_output_array(const array& arr) {
active_outputs_.push_back(id);
}

void CommandEncoder::maybe_commit() {
if (node_count_ >= env::max_ops_per_buffer(default_max_nodes_per_graph)) {
commit();
}
}

void CommandEncoder::add_kernel_node(
void* func,
dim3 grid_dim,
dim3 block_dim,
uint32_t smem_bytes,
void** params) {
if (!use_cuda_graphs()) {
node_count_++;
CHECK_CUDA_ERROR(cudaLaunchKernel(
func, grid_dim, block_dim, params, smem_bytes, stream()));
return;
@@ -254,6 +246,7 @@ void CommandEncoder::add_kernel_node(
uint32_t smem_bytes,
void** params) {
if (!use_cuda_graphs()) {
node_count_++;
CHECK_CUDA_ERROR(cuLaunchKernel(
func,
grid_dim.x,
@@ -296,6 +289,7 @@ void CommandEncoder::add_kernel_node(const CUDA_KERNEL_NODE_PARAMS& params) {

void CommandEncoder::add_graph_node(cudaGraph_t child) {
if (!use_cuda_graphs()) {
node_count_++;
CudaGraphExec graph_exec;
graph_exec.instantiate(child);
device_.make_current();
@@ -307,12 +301,16 @@ void CommandEncoder::add_graph_node(cudaGraph_t child) {
insert_graph_dependencies(GraphNode{node, 'G'});
}

int CommandEncoder::get_num_ops() {
return node_count_;
}

void CommandEncoder::commit() {
nvtx3::scoped_range r("CommandEncoder::commit");
if (!temporaries_.empty()) {
add_completed_handler([temporaries = std::move(temporaries_)]() {});
}
if (node_count_ > 0) {
if (use_cuda_graphs() && node_count_ > 0) {
if (!from_nodes_.empty()) {
CHECK_CUDA_ERROR(cudaGraphAddDependencies(
graph_,
@@ -355,7 +353,6 @@ void CommandEncoder::commit() {
CHECK_CUDA_ERROR(cudaGraphLaunch(graph_exec, stream_));

// Reset state
node_count_ = 0;
graph_node_count_ = 0;
empty_node_count_ = 0;
from_nodes_.clear();
@@ -367,6 +364,7 @@ void CommandEncoder::commit() {

// Put completion handlers in a batch.
worker_.commit(stream_);
node_count_ = 0;
}

void CommandEncoder::synchronize() {
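Taken together, the device.cpp changes turn node_count_ into a pure op counter: it is bumped for every added node whether or not CUDA graphs are in use, exposed through get_num_ops(), and reset only at the end of commit(). A minimal sketch of that counting pattern, using a simplified stand-in class rather than the real CommandEncoder:

// Simplified stand-in for the counting pattern; the real encoder also
// records CUDA graph nodes, dependencies, and completion handlers.
class TinyEncoder {
 public:
  void add_kernel_node() {
    node_count_++;        // counted in both the graph and non-graph paths
    // ... launch or record the kernel here ...
  }
  int get_num_ops() {
    return node_count_;   // polled by the caller to decide when to flush
  }
  void commit() {
    // ... submit the accumulated work to the stream ...
    node_count_ = 0;      // reset only after the batch is handed off
  }
 private:
  int node_count_{0};
};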
2 changes: 1 addition & 1 deletion mlx/backend/cuda/device.h
@@ -83,7 +83,7 @@ class CommandEncoder {
}

void add_completed_handler(std::function<void()> task);
void maybe_commit();
int get_num_ops();
void commit();

Device& device() {
16 changes: 14 additions & 2 deletions mlx/backend/cuda/eval.cpp
@@ -5,11 +5,15 @@
#include "mlx/backend/cuda/device.h"
#include "mlx/backend/gpu/available.h"
#include "mlx/primitives.h"
#include "mlx/scheduler.h"

#include <nvtx3/nvtx3.hpp>

namespace mlx::core::gpu {

// Can be tuned with MLX_MAX_OPS_PER_BUFFER
constexpr int default_max_nodes_per_graph = 20;

bool is_available() {
return true;
}
@@ -36,7 +40,8 @@ void eval(array& arr) {
arr.primitive().eval_gpu(arr.inputs(), outputs);
}

auto& encoder = cu::get_command_encoder(arr.primitive().stream());
auto& stream = arr.primitive().stream();
auto& encoder = cu::get_command_encoder(stream);
// Keep used buffers alive until kernel finishes running.
for (auto& in : arr.inputs()) {
// Except for the donated one.
@@ -47,7 +52,14 @@ void eval(array& arr) {
for (auto& s : arr.siblings()) {
encoder.add_temporary(s);
}
encoder.maybe_commit();

if (encoder.get_num_ops() >=
env::max_ops_per_buffer(default_max_nodes_per_graph)) {
scheduler::notify_new_task(stream);
encoder.add_completed_handler(
[stream]() { scheduler::notify_task_completion(stream); });
encoder.commit();
}
}

void finalize(Stream s) {
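With maybe_commit() removed, the flush policy now lives in gpu::eval(): once the encoder reports at least MLX_MAX_OPS_PER_BUFFER ops (default 20), eval notifies the scheduler of a new task, registers a completion handler to mark that task done, and commits the batch. A self-contained toy of that policy, where ToyEncoder and its handler list are hypothetical stand-ins for the real encoder and worker:

// Toy illustration of the commit-threshold policy; names are hypothetical.
#include <cstdio>
#include <functional>
#include <utility>
#include <vector>

struct ToyEncoder {
  int node_count = 0;
  std::vector<std::function<void()>> handlers;

  void add_node() { node_count++; }
  int get_num_ops() const { return node_count; }
  void add_completed_handler(std::function<void()> h) {
    handlers.push_back(std::move(h));
  }
  void commit() {
    for (auto& h : handlers) h();  // stand-in for the worker batch
    handlers.clear();
    node_count = 0;
  }
};

int main() {
  constexpr int max_ops = 20;  // default, overridable via MLX_MAX_OPS_PER_BUFFER
  ToyEncoder enc;
  for (int i = 0; i < 50; i++) {
    enc.add_node();
    if (enc.get_num_ops() >= max_ops) {
      // In eval() this is scheduler::notify_new_task / notify_task_completion.
      std::printf("batch of %d ops submitted\n", enc.get_num_ops());
      enc.add_completed_handler([] { std::printf("batch completed\n"); });
      enc.commit();
    }
  }
  return 0;
}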