diff --git a/mlx/distributed/jaccl/mesh.cpp b/mlx/distributed/jaccl/mesh.cpp index c8df4e6745..52f43b949c 100644 --- a/mlx/distributed/jaccl/mesh.cpp +++ b/mlx/distributed/jaccl/mesh.cpp @@ -5,8 +5,6 @@ #include "mlx/distributed/reduction_ops.h" #include "mlx/dtype_utils.h" -constexpr int MAX_PEERS = 8; - namespace mlx::core::distributed::jaccl { MeshGroup::MeshGroup( @@ -17,9 +15,9 @@ MeshGroup::MeshGroup( size_(device_names.size()), side_channel_(rank_, size_, coordinator_addr), connections_(create_connections(device_names)) { - if (size_ > MAX_PEERS) { + if (size_ > MESH_MAX_PEERS) { std::ostringstream msg; - msg << "[jaccl] The JACCL mesh supports up to " << MAX_PEERS + msg << "[jaccl] The JACCL mesh supports up to " << MESH_MAX_PEERS << " peers but " << size_ << " were provided."; throw std::runtime_error(msg.str()); } @@ -29,6 +27,17 @@ MeshGroup::MeshGroup( // Make sure every node has reached here before continuing side_channel_.all_gather(0); + + // Create the mesh implementation object + mesh_ = MeshImpl(rank_, size_, connections_, buffers_); + ring_ = RingImpl( + rank_, + size_, + &connections_[(rank_ + size_ - 1) % size_], + &connections_[(rank_ + 1) % size_], + 1, + ring_send_buffers_, + ring_recv_buffers_); } void MeshGroup::initialize() { @@ -75,18 +84,27 @@ void MeshGroup::initialize() { void MeshGroup::allocate_buffers() { // Deregister any buffers and free the memory buffers_.clear(); + ring_send_buffers_.clear(); + ring_recv_buffers_.clear(); // Allocate the memory for (int k = 0; k < BUFFER_SIZES; k++) { for (int i = 0; i < NUM_BUFFERS; i++) { + // Mesh buffers for (int j = 0; j < size_; j++) { buffers_.emplace_back(FRAME_SIZE * (1 << k)); } + // Ring buffers (1 for each direction) + for (int j = 0; j < 2; j++) { + ring_send_buffers_.emplace_back(FRAME_SIZE * (1 << k)); + ring_recv_buffers_.emplace_back(FRAME_SIZE * (1 << k)); + } } } for (int k = 0; k < BUFFER_SIZES; k++) { for (int i = 0; i < NUM_BUFFERS; i++) { + // Mesh buffers for (int j = 
0; j < size_; j++) { // This is our send buffer so register it with all pds so we can send // it to all connected devices. @@ -106,6 +124,19 @@ void MeshGroup::allocate_buffers() { .register_to_protection_domain(connections_[j].protection_domain); } } + + // Ring buffers (see ring group for the logic below) + // We register send buffers to both the right and the left. + int left = (rank_ + size_ - 1) % size_; + int right = (rank_ + 1) % size_; + ring_send_buffers_[k * NUM_BUFFERS * 2 + i * 2 + 0] + .register_to_protection_domain(connections_[right].protection_domain); + ring_recv_buffers_[k * NUM_BUFFERS * 2 + i * 2 + 0] + .register_to_protection_domain(connections_[left].protection_domain); + ring_send_buffers_[k * NUM_BUFFERS * 2 + i * 2 + 1] + .register_to_protection_domain(connections_[left].protection_domain); + ring_recv_buffers_[k * NUM_BUFFERS * 2 + i * 2 + 1] + .register_to_protection_domain(connections_[right].protection_domain); } } } @@ -139,83 +170,7 @@ void MeshGroup::all_gather(const array& input, array& output, Stream stream) { encoder.set_input_array(input); encoder.set_output_array(output); encoder.dispatch([in_ptr, out_ptr, n_bytes, this]() { - // Copy our data to the appropriate place - std::memcpy(out_ptr + rank_ * n_bytes, in_ptr, n_bytes); - - // Fully connected all gather - char* data = out_ptr; - char* our_data = out_ptr + rank_ * n_bytes; - auto [sz, N] = buffer_size_from_message(n_bytes); - constexpr int PIPELINE = 2; - constexpr int WC_NUM = PIPELINE * MAX_PEERS * 2; - int64_t total = static_cast(n_bytes); - int num_peers = size_ - 1; - - // Counters to maintain the state of transfers - int in_flight = 0; - int read_offset = 0; - int completed_send_count[PIPELINE] = {0}; - int write_offset[MAX_PEERS] = {0}; - - // Prefill the pipeline - int buff = 0; - while (read_offset < total && buff < PIPELINE) { - post_recv_all(sz, buff); - std::copy( - our_data + read_offset, - our_data + std::min(read_offset + N, total), - send_buffer(sz, 
buff).begin()); - post_send_all(sz, buff); - - buff++; - in_flight += 2 * num_peers; - read_offset += N; - } - - // Main loop - // - // Keep going until we have no longer data in flight. - while (in_flight > 0) { - ibv_wc wc[WC_NUM]; - int n = poll(connections_, WC_NUM, wc); - for (int i = 0; i < n; i++) { - int work_type = wc[i].wr_id >> 16; - int buff = (wc[i].wr_id >> 8) & 0xff; - int rank = wc[i].wr_id & 0xff; - - in_flight--; - - // Send completed. If all sends completed then send the next chunk. - if (work_type == SEND_WR && read_offset < total) { - completed_send_count[buff]++; - if (completed_send_count[buff] == num_peers) { - std::copy( - our_data + read_offset, - our_data + std::min(read_offset + N, total), - send_buffer(sz, buff).begin()); - post_send_all(sz, buff); - - completed_send_count[buff] = 0; - in_flight += num_peers; - read_offset += N; - } - } - - // Recv completed. If we have more chunks then post another recv. - else if (work_type == RECV_WR) { - std::copy( - recv_buffer(sz, buff, rank).begin(), - recv_buffer(sz, buff, rank).begin() + - std::min(N, total - write_offset[rank]), - data + rank * n_bytes + write_offset[rank]); - write_offset[rank] += N; - if (write_offset[rank] + N * (PIPELINE - 1) < total) { - recv_from(sz, rank, buff); - in_flight++; - } - } - } - } + mesh_.all_gather(in_ptr, out_ptr, n_bytes); }); } @@ -224,55 +179,8 @@ void MeshGroup::send(const array& input, int dst, Stream stream) { int64_t n_bytes = input.nbytes(); auto& encoder = cpu::get_command_encoder(stream); encoder.set_input_array(input); - encoder.dispatch([data, n_bytes, dst, this]() { - constexpr int PIPELINE = 2; - constexpr int WC_NUM = PIPELINE; - auto [sz, N] = buffer_size_from_message(n_bytes); - - int in_flight = 0; - int64_t read_offset = 0; - - // Prefill the pipeline - int buff = 0; - while (read_offset < n_bytes && buff < PIPELINE) { - std::copy( - data + read_offset, - data + std::min(read_offset + N, n_bytes), - send_buffer(sz, buff).begin()); - 
send_to(sz, dst, buff); - - buff++; - read_offset += N; - in_flight++; - } - - // Main loop - while (in_flight > 0) { - // Poll the hardware for completions. - // - // If a send was completed and we have more data to send then go ahead - // and send them. - ibv_wc wc[WC_NUM]; - int n = connections_[dst].poll(WC_NUM, wc); - for (int i = 0; i < n; i++) { - int buff = (wc[i].wr_id >> 8) & 0xff; - int rank = wc[i].wr_id & 0xff; - - in_flight--; - - if (read_offset < n_bytes) { - std::copy( - data + read_offset, - data + std::min(read_offset + N, n_bytes), - send_buffer(sz, buff).begin()); - send_to(sz, dst, buff); - - read_offset += N; - in_flight++; - } - } - } - }); + encoder.dispatch( + [data, n_bytes, dst, this]() { mesh_.send(data, n_bytes, dst); }); } void MeshGroup::recv(array& out, int src, Stream stream) { @@ -280,52 +188,8 @@ void MeshGroup::recv(array& out, int src, Stream stream) { int64_t n_bytes = out.nbytes(); auto& encoder = cpu::get_command_encoder(stream); encoder.set_output_array(out); - encoder.dispatch([data, n_bytes, src, this]() { - constexpr int PIPELINE = 2; - constexpr int WC_NUM = PIPELINE; - auto [sz, N] = buffer_size_from_message(n_bytes); - - int in_flight = 0; - int64_t write_offset = 0; - - // Prefill the pipeline - int buff = 0; - while (N * buff < n_bytes && buff < PIPELINE) { - recv_from(sz, src, buff); - - in_flight++; - buff++; - } - - // Main loop - while (in_flight > 0) { - // Poll the hardware for completions. - // - // If a recv was completed copy it to the output and if we have more - // data to fetch post another recv. 
- ibv_wc wc[WC_NUM]; - int n = connections_[src].poll(WC_NUM, wc); - for (int i = 0; i < n; i++) { - int buff = (wc[i].wr_id >> 8) & 0xff; - int rank = wc[i].wr_id & 0xff; - - in_flight--; - - std::copy( - recv_buffer(sz, buff, src).begin(), - recv_buffer(sz, buff, src).begin() + - std::min(n_bytes - write_offset, static_cast(N)), - data + write_offset); - write_offset += N; - - if (write_offset + (PIPELINE - 1) * N < n_bytes) { - recv_from(sz, src, buff); - - in_flight++; - } - } - } - }); + encoder.dispatch( + [data, n_bytes, src, this]() { mesh_.recv(data, n_bytes, src); }); } template @@ -336,114 +200,17 @@ void MeshGroup::all_reduce( ReduceOp reduce_op) { auto in_ptr = input.data(); auto out_ptr = output.data(); + int64_t size = input.size(); auto& encoder = cpu::get_command_encoder(stream); encoder.set_input_array(input); encoder.set_output_array(output); - encoder.dispatch([in_ptr, out_ptr, size = input.size(), this, reduce_op]() { - // If not inplace all reduce then copy the input to the output first - if (in_ptr != out_ptr) { - std::memcpy(out_ptr, in_ptr, size * sizeof(T)); - } - - // Fully connected all reduce - T* data = out_ptr; - auto [sz, buffer_size] = buffer_size_from_message(size * sizeof(T)); - int64_t N = buffer_size / sizeof(T); - constexpr int PIPELINE = 2; - constexpr int WC_NUM = PIPELINE * MAX_PEERS * 2; - int64_t total = static_cast(size); - int num_peers = size_ - 1; - - // Counters to maintain the state of transfers - int in_flight = 0; - int64_t read_offset = 0; - int completed_send_count[PIPELINE] = {0}; - int completed_recv_begin[MAX_PEERS] = {0}; - int completed_recv_end[MAX_PEERS] = {0}; - - // Prefill the pipeline - int buff = 0; - while (read_offset < total && buff < PIPELINE) { - post_recv_all(sz, buff); - std::copy( - data + read_offset, - data + std::min(read_offset + N, total), - send_buffer(sz, buff).begin()); - post_send_all(sz, buff); - - buff++; - in_flight += 2 * num_peers; - read_offset += N; - } - - // Main loop - // - 
// Keep going until we have no longer data in flight. - while (in_flight > 0) { - // Poll the hardware for completions. - // - // If a send was completed mark how many completions we have received - // for that buffer. If we have sent the buffer to all peers we can - // reuse the buffer so copy the next chunk of data and send it to all. - // - // If a receive is completed then advance the pointer of completed - // receives. - ibv_wc wc[WC_NUM]; - int n = poll(connections_, WC_NUM, wc); - for (int i = 0; i < n; i++) { - int work_type = wc[i].wr_id >> 16; - int buff = (wc[i].wr_id >> 8) & 0xff; - int rank = wc[i].wr_id & 0xff; - - in_flight--; - - if (work_type == SEND_WR && read_offset < total) { - completed_send_count[buff]++; - if (completed_send_count[buff] == num_peers) { - std::copy( - data + read_offset, - data + std::min(read_offset + N, total), - send_buffer(sz, buff).begin()); - post_send_all(sz, buff); - - completed_send_count[buff] = 0; - in_flight += num_peers; - read_offset += N; - } - } - - else if (work_type == RECV_WR) { - completed_recv_end[rank]++; - } - } - - // Process the completed recv - // - // For each rank we have a range of completed recv defined by a begin - // and end inclusive and exlusive in standard C++ fashion. - // - // When there is an unprocessed receive we first check if we have - // finished sending the write location. If so then we reduce in-place - // and then check if there is more to be received and post a recv. 
- for (int r = 0; r < size_; r++) { - int s = completed_recv_begin[r]; - int e = completed_recv_end[r]; - int w = s * N; - while (w < read_offset && e - s > 0) { - int buff = s % PIPELINE; - reduce_op( - recv_buffer(sz, buff, r).begin(), - data + w, - std::min(N, total - w)); - w += N; - s++; - if (w + (PIPELINE - 1) * N < total) { - recv_from(sz, r, buff); - in_flight++; - } - } - completed_recv_begin[r] = s; - } + encoder.dispatch([in_ptr, out_ptr, size, this, reduce_op]() { + if (size_ > 2 && + ((std::is_same_v && size > 65536) || + size >= 8 * 1024 * 1024 / sizeof(T))) { + ring_.all_reduce<2>(in_ptr, out_ptr, size, 1, reduce_op); + } else { + mesh_.all_reduce(in_ptr, out_ptr, size, reduce_op); } }); } diff --git a/mlx/distributed/jaccl/mesh.h b/mlx/distributed/jaccl/mesh.h index 6f779e9ccb..ed51361a7f 100644 --- a/mlx/distributed/jaccl/mesh.h +++ b/mlx/distributed/jaccl/mesh.h @@ -3,6 +3,8 @@ #pragma once #include "mlx/distributed/distributed_impl.h" +#include "mlx/distributed/jaccl/mesh_impl.h" +#include "mlx/distributed/jaccl/ring_impl.h" #include "mlx/distributed/jaccl/utils.h" using GroupImpl = mlx::core::distributed::detail::GroupImpl; @@ -72,51 +74,16 @@ class MeshGroup : public GroupImpl { */ void allocate_buffers(); - void send_to(int sz, int rank, int buff) { - connections_[rank].post_send( - send_buffer(sz, buff), SEND_WR << 16 | buff << 8 | rank); - } - - void recv_from(int sz, int rank, int buff) { - connections_[rank].post_recv( - recv_buffer(sz, buff, rank), RECV_WR << 16 | buff << 8 | rank); - } - - SharedBuffer& send_buffer(int sz, int buff) { - return buffers_[sz * NUM_BUFFERS * size_ + buff * size_ + rank_]; - } - - SharedBuffer& recv_buffer(int sz, int buff, int rank) { - return buffers_[sz * NUM_BUFFERS * size_ + buff * size_ + rank]; - } - - void post_send_all(int sz, int buff) { - auto& b = send_buffer(sz, buff); - int wr_id = SEND_WR << 16 | buff << 8; - for (int i = 0; i < size_; i++) { - if (i == rank_) { - continue; - } - 
connections_[i].post_send(b, wr_id | i); - } - } - - void post_recv_all(int sz, int buff) { - int b = sz * NUM_BUFFERS * size_ + buff * size_; - int wr_id = RECV_WR << 16 | buff << 8; - for (int i = 0; i < size_; i++) { - if (i == rank_) { - continue; - } - connections_[i].post_recv(buffers_[b + i], wr_id | i); - } - } - int rank_; int size_; SideChannel side_channel_; std::vector connections_; std::vector buffers_; + std::vector ring_send_buffers_; + std::vector ring_recv_buffers_; + + MeshImpl mesh_; + RingImpl ring_; }; } // namespace mlx::core::distributed::jaccl diff --git a/mlx/distributed/jaccl/mesh_impl.h b/mlx/distributed/jaccl/mesh_impl.h new file mode 100644 index 0000000000..fc486a396b --- /dev/null +++ b/mlx/distributed/jaccl/mesh_impl.h @@ -0,0 +1,358 @@ +// Copyright © 2026 Apple Inc. + +#pragma once + +#include + +#include "mlx/distributed/jaccl/utils.h" + +constexpr int MESH_MAX_PEERS = 8; + +namespace mlx::core::distributed::jaccl { + +class MeshImpl { + public: + MeshImpl( + int rank, + int size, + std::vector& conns, + std::vector& buffers) + : rank_(rank), size_(size), connections_(conns), buffers_(buffers) {} + + MeshImpl() : rank_(0), size_(1) {} + + template + void + all_reduce(const T* in_ptr, T* out_ptr, int64_t size, ReduceOp reduce_op) { + // If not inplace all reduce then copy the input to the output first + if (in_ptr != out_ptr) { + std::memcpy(out_ptr, in_ptr, size * sizeof(T)); + } + + // Fully connected all reduce + T* data = out_ptr; + auto [sz, buffer_size] = buffer_size_from_message(size * sizeof(T)); + int64_t N = buffer_size / sizeof(T); + constexpr int PIPELINE = 2; + constexpr int WC_NUM = PIPELINE * MESH_MAX_PEERS * 2; + int64_t total = static_cast(size); + int num_peers = size_ - 1; + + // Counters to maintain the state of transfers + int in_flight = 0; + int64_t read_offset = 0; + int completed_send_count[PIPELINE] = {0}; + int completed_recv_begin[MESH_MAX_PEERS] = {0}; + int completed_recv_end[MESH_MAX_PEERS] = {0}; + + 
// Prefill the pipeline + int buff = 0; + while (read_offset < total && buff < PIPELINE) { + post_recv_all(sz, buff); + std::copy( + data + read_offset, + data + std::min(read_offset + N, total), + send_buffer(sz, buff).begin()); + post_send_all(sz, buff); + + buff++; + in_flight += 2 * num_peers; + read_offset += N; + } + + // Main loop + // + // Keep going until we no longer have data in flight. + while (in_flight > 0) { + // Poll the hardware for completions. + // + // If a send was completed mark how many completions we have received + // for that buffer. If we have sent the buffer to all peers we can + // reuse the buffer so copy the next chunk of data and send it to all. + // + // If a receive is completed then advance the pointer of completed + // receives. + ibv_wc wc[WC_NUM]; + int n = poll(connections_, WC_NUM, wc); + for (int i = 0; i < n; i++) { + int work_type = wc[i].wr_id >> 16; + int buff = (wc[i].wr_id >> 8) & 0xff; + int rank = wc[i].wr_id & 0xff; + + in_flight--; + + if (work_type == SEND_WR && read_offset < total) { + completed_send_count[buff]++; + if (completed_send_count[buff] == num_peers) { + std::copy( + data + read_offset, + data + std::min(read_offset + N, total), + send_buffer(sz, buff).begin()); + post_send_all(sz, buff); + + completed_send_count[buff] = 0; + in_flight += num_peers; + read_offset += N; + } + } + + else if (work_type == RECV_WR) { + completed_recv_end[rank]++; + } + } + + // Process the completed recv + // + // For each rank we have a range of completed recv defined by a begin + // and end inclusive and exclusive in standard C++ fashion. + // + // When there is an unprocessed receive we first check if we have + // finished sending the write location. If so then we reduce in-place + // and then check if there is more to be received and post a recv 
+ for (int r = 0; r < size_; r++) { + int s = completed_recv_begin[r]; + int e = completed_recv_end[r]; + int w = s * N; + while (w < read_offset && e - s > 0) { + int buff = s % PIPELINE; + reduce_op( + recv_buffer(sz, buff, r).begin(), + data + w, + std::min(N, total - w)); + w += N; + s++; + if (w + (PIPELINE - 1) * N < total) { + recv_from(sz, r, buff); + in_flight++; + } + } + completed_recv_begin[r] = s; + } + } + } + + void all_gather(const char* in_ptr, char* out_ptr, int64_t n_bytes) { + // Copy our data to the appropriate place + std::memcpy(out_ptr + rank_ * n_bytes, in_ptr, n_bytes); + + // Fully connected all gather + char* data = out_ptr; + char* our_data = out_ptr + rank_ * n_bytes; + auto [sz, N] = buffer_size_from_message(n_bytes); + constexpr int PIPELINE = 2; + constexpr int WC_NUM = PIPELINE * MESH_MAX_PEERS * 2; + int64_t total = static_cast(n_bytes); + int num_peers = size_ - 1; + + // Counters to maintain the state of transfers + int in_flight = 0; + int read_offset = 0; + int completed_send_count[PIPELINE] = {0}; + int write_offset[MESH_MAX_PEERS] = {0}; + + // Prefill the pipeline + int buff = 0; + while (read_offset < total && buff < PIPELINE) { + post_recv_all(sz, buff); + std::copy( + our_data + read_offset, + our_data + std::min(read_offset + N, total), + send_buffer(sz, buff).begin()); + post_send_all(sz, buff); + + buff++; + in_flight += 2 * num_peers; + read_offset += N; + } + + // Main loop + // + // Keep going until we no longer have data in flight. + while (in_flight > 0) { + ibv_wc wc[WC_NUM]; + int n = poll(connections_, WC_NUM, wc); + for (int i = 0; i < n; i++) { + int work_type = wc[i].wr_id >> 16; + int buff = (wc[i].wr_id >> 8) & 0xff; + int rank = wc[i].wr_id & 0xff; + + in_flight--; + + // Send completed. If all sends completed then send the next chunk. 
+ if (work_type == SEND_WR && read_offset < total) { + completed_send_count[buff]++; + if (completed_send_count[buff] == num_peers) { + std::copy( + our_data + read_offset, + our_data + std::min(read_offset + N, total), + send_buffer(sz, buff).begin()); + post_send_all(sz, buff); + + completed_send_count[buff] = 0; + in_flight += num_peers; + read_offset += N; + } + } + + // Recv completed. If we have more chunks then post another recv. + else if (work_type == RECV_WR) { + std::copy( + recv_buffer(sz, buff, rank).begin(), + recv_buffer(sz, buff, rank).begin() + + std::min(N, total - write_offset[rank]), + data + rank * n_bytes + write_offset[rank]); + write_offset[rank] += N; + if (write_offset[rank] + N * (PIPELINE - 1) < total) { + recv_from(sz, rank, buff); + in_flight++; + } + } + } + } + } + + void send(const char* in_ptr, int64_t n_bytes, int dst) { + constexpr int PIPELINE = 2; + constexpr int WC_NUM = PIPELINE; + auto [sz, N] = buffer_size_from_message(n_bytes); + + int in_flight = 0; + int64_t read_offset = 0; + + // Prefill the pipeline + int buff = 0; + while (read_offset < n_bytes && buff < PIPELINE) { + std::copy( + in_ptr + read_offset, + in_ptr + std::min(read_offset + N, n_bytes), + send_buffer(sz, buff).begin()); + send_to(sz, dst, buff); + + buff++; + read_offset += N; + in_flight++; + } + + // Main loop + while (in_flight > 0) { + // Poll the hardware for completions. + // + // If a send was completed and we have more data to send then go ahead + // and send them. 
+ ibv_wc wc[WC_NUM]; + int n = connections_[dst].poll(WC_NUM, wc); + for (int i = 0; i < n; i++) { + int buff = (wc[i].wr_id >> 8) & 0xff; + int rank = wc[i].wr_id & 0xff; + + in_flight--; + + if (read_offset < n_bytes) { + std::copy( + in_ptr + read_offset, + in_ptr + std::min(read_offset + N, n_bytes), + send_buffer(sz, buff).begin()); + send_to(sz, dst, buff); + + read_offset += N; + in_flight++; + } + } + } + } + + void recv(char* out_ptr, int64_t n_bytes, int src) { + constexpr int PIPELINE = 2; + constexpr int WC_NUM = PIPELINE; + auto [sz, N] = buffer_size_from_message(n_bytes); + + int in_flight = 0; + int64_t write_offset = 0; + + // Prefill the pipeline + int buff = 0; + while (N * buff < n_bytes && buff < PIPELINE) { + recv_from(sz, src, buff); + + in_flight++; + buff++; + } + + // Main loop + while (in_flight > 0) { + // Poll the hardware for completions. + // + // If a recv was completed copy it to the output and if we have more + // data to fetch post another recv. + ibv_wc wc[WC_NUM]; + int n = connections_[src].poll(WC_NUM, wc); + for (int i = 0; i < n; i++) { + int buff = (wc[i].wr_id >> 8) & 0xff; + int rank = wc[i].wr_id & 0xff; + + in_flight--; + + std::copy( + recv_buffer(sz, buff, src).begin(), + recv_buffer(sz, buff, src).begin() + + std::min(n_bytes - write_offset, static_cast(N)), + out_ptr + write_offset); + write_offset += N; + + if (write_offset + (PIPELINE - 1) * N < n_bytes) { + recv_from(sz, src, buff); + + in_flight++; + } + } + } + } + + private: + void send_to(int sz, int rank, int buff) { + connections_[rank].post_send( + send_buffer(sz, buff), SEND_WR << 16 | buff << 8 | rank); + } + + void recv_from(int sz, int rank, int buff) { + connections_[rank].post_recv( + recv_buffer(sz, buff, rank), RECV_WR << 16 | buff << 8 | rank); + } + + SharedBuffer& send_buffer(int sz, int buff) { + return buffers_[sz * NUM_BUFFERS * size_ + buff * size_ + rank_]; + } + + SharedBuffer& recv_buffer(int sz, int buff, int rank) { + return buffers_[sz 
* NUM_BUFFERS * size_ + buff * size_ + rank]; + } + + void post_send_all(int sz, int buff) { + auto& b = send_buffer(sz, buff); + int wr_id = SEND_WR << 16 | buff << 8; + for (int i = 0; i < size_; i++) { + if (i == rank_) { + continue; + } + connections_[i].post_send(b, wr_id | i); + } + } + + void post_recv_all(int sz, int buff) { + int b = sz * NUM_BUFFERS * size_ + buff * size_; + int wr_id = RECV_WR << 16 | buff << 8; + for (int i = 0; i < size_; i++) { + if (i == rank_) { + continue; + } + connections_[i].post_recv(buffers_[b + i], wr_id | i); + } + } + + int rank_; + int size_; + std::span connections_; + std::span buffers_; +}; + +} // namespace mlx::core::distributed::jaccl diff --git a/mlx/distributed/jaccl/ring.cpp b/mlx/distributed/jaccl/ring.cpp index fe27a2e9c6..4a09d85cc8 100644 --- a/mlx/distributed/jaccl/ring.cpp +++ b/mlx/distributed/jaccl/ring.cpp @@ -15,12 +15,13 @@ RingGroup::RingGroup( const char* coordinator_addr) : rank_(rank), size_(size), + n_conns_(left_devices.size()), side_channel_(rank_, size_, coordinator_addr), left_(create_connections(left_devices)), right_(create_connections(right_devices)) { - if (left_.size() > MAX_CONNS || right_.size() > MAX_CONNS) { + if (left_.size() > RING_MAX_CONNS || right_.size() > RING_MAX_CONNS) { std::ostringstream msg; - msg << "[jaccl] Up to " << MAX_CONNS << " per direction supported but " + msg << "[jaccl] Up to " << RING_MAX_CONNS << " per direction supported but " << left_.size() << " were provided."; throw std::runtime_error(msg.str()); } @@ -30,6 +31,9 @@ RingGroup::RingGroup( // Make sure every node has reached here before continuing side_channel_.all_gather(0); + + // Create the ring implementation object + ring_ = RingImpl(rank_, size_, left_, right_, send_buffers_, recv_buffers_); } void RingGroup::initialize() { @@ -93,7 +97,7 @@ void RingGroup::allocate_buffers() { // Allocate the memory for (int k = 0; k < BUFFER_SIZES; k++) { for (int i = 0; i < NUM_BUFFERS; i++) { - for (int j = 0; j < 
MAX_CONNS * 2; j++) { + for (int j = 0; j < n_conns_ * 2; j++) { send_buffers_.emplace_back(FRAME_SIZE * (1 << k)); recv_buffers_.emplace_back(FRAME_SIZE * (1 << k)); } @@ -103,21 +107,18 @@ void RingGroup::allocate_buffers() { // Register the buffers with the corresponding connections for (int k = 0; k < BUFFER_SIZES; k++) { for (int i = 0; i < NUM_BUFFERS; i++) { - for (int j = 0; j < MAX_CONNS * 2; j++) { - int wire = j % MAX_CONNS; - int lr = j / MAX_CONNS; - if (wire >= left_.size()) { - continue; - } + for (int j = 0; j < n_conns_ * 2; j++) { + int wire = j % n_conns_; + int lr = j / n_conns_; if (lr) { - send_buffers_[k * NUM_BUFFERS * MAX_CONNS * 2 + i * MAX_CONNS * 2 + j] + send_buffers_[k * NUM_BUFFERS * n_conns_ * 2 + i * n_conns_ * 2 + j] .register_to_protection_domain(left_[wire].protection_domain); - recv_buffers_[k * NUM_BUFFERS * MAX_CONNS * 2 + i * MAX_CONNS * 2 + j] + recv_buffers_[k * NUM_BUFFERS * n_conns_ * 2 + i * n_conns_ * 2 + j] .register_to_protection_domain(right_[wire].protection_domain); } else { - send_buffers_[k * NUM_BUFFERS * MAX_CONNS * 2 + i * MAX_CONNS * 2 + j] + send_buffers_[k * NUM_BUFFERS * n_conns_ * 2 + i * n_conns_ * 2 + j] .register_to_protection_domain(right_[wire].protection_domain); - recv_buffers_[k * NUM_BUFFERS * MAX_CONNS * 2 + i * MAX_CONNS * 2 + j] + recv_buffers_[k * NUM_BUFFERS * n_conns_ * 2 + i * n_conns_ * 2 + j] .register_to_protection_domain(left_[wire].protection_domain); } } @@ -149,114 +150,12 @@ void RingGroup::all_min(const array& input, array& output, Stream stream) { void RingGroup::all_gather(const array& input, array& output, Stream stream) { auto in_ptr = input.data(); auto out_ptr = output.data(); - size_t n_bytes = input.nbytes(); + int64_t n_bytes = input.nbytes(); auto& encoder = cpu::get_command_encoder(stream); encoder.set_input_array(input); encoder.set_output_array(output); encoder.dispatch([in_ptr, out_ptr, n_bytes, this]() { - // Copy our data to the appropriate place - 
std::memcpy(out_ptr + rank_ * n_bytes, in_ptr, n_bytes); - - constexpr int PIPELINE = 2; - constexpr int WC_NUM = PIPELINE * MAX_CONNS * 2 * 2; - int n_wires = left_.size(); - size_t n_bytes_per_wire = (n_bytes + (2 * n_wires) - 1) / (2 * n_wires); - size_t out_bytes = n_bytes * size_; - auto [sz, N] = buffer_size_from_message(n_bytes_per_wire); - int n_steps = (n_bytes_per_wire + N - 1) / N; - - // Counters to maintain the state of transfers - int in_flight = 0; - int64_t send_offset[2]; - int64_t recv_offset[2]; - int64_t limits[2]; - int send_count[2 * MAX_CONNS] = {0}; - int recv_count[2 * MAX_CONNS] = {0}; - send_offset[0] = send_offset[1] = rank_ * n_bytes; - recv_offset[0] = ((rank_ + size_ - 1) % size_) * n_bytes; - recv_offset[1] = ((rank_ + 1) % size_) * n_bytes; - limits[0] = n_wires * n_bytes_per_wire; - limits[1] = n_bytes; - - // Possible perf improvement by not syncing at every step but running ahead - // as needed. - for (int k = 0; k < size_ - 1; k++) { - // Prefill the pipeline - int buff = 0; - while (buff < n_steps && buff < PIPELINE) { - post_recv_all(sz, buff); - for (int lr = 0; lr < 2; lr++) { - for (int lw = 0; lw < n_wires; lw++) { - int64_t offset = lw * N + - send_count[lr * MAX_CONNS + lw] * n_wires * N + - lr * n_wires * n_bytes_per_wire; - std::copy( - out_ptr + send_offset[lr] + offset, - out_ptr + send_offset[lr] + - std::max(offset, std::min(offset + N, limits[lr])), - send_buffer(sz, buff, lr, lw).begin()); - send_count[lr * MAX_CONNS + lw]++; - } - } - post_send_all(sz, buff); - - buff++; - in_flight += 2 * 2 * n_wires; - } - - // Main loop - // - // Keep going until we have no longer data in flight. 
- while (in_flight > 0) { - ibv_wc wc[WC_NUM]; - int n = poll(left_, right_, WC_NUM, wc); - for (int i = 0; i < n; i++) { - int work_type = wc[i].wr_id >> 16; - int buff = (wc[i].wr_id >> 8) & 0xff; - int wire = wc[i].wr_id & 0xff; - int lr = wire / MAX_CONNS; - int lw = wire % MAX_CONNS; - - in_flight--; - - if (work_type == SEND_WR && send_count[wire] < n_steps) { - int64_t offset = lw * N + send_count[wire] * n_wires * N + - lr * n_wires * n_bytes_per_wire; - std::copy( - out_ptr + send_offset[lr] + offset, - out_ptr + send_offset[lr] + - std::max(offset, std::min(offset + N, limits[lr])), - send_buffer(sz, buff, lr, lw).begin()); - send_to(sz, buff, lr, lw); - in_flight++; - send_count[wire]++; - } - - else if (work_type == RECV_WR) { - int64_t offset = lw * N + recv_count[wire] * n_wires * N + - lr * n_wires * n_bytes_per_wire; - std::copy( - recv_buffer(sz, buff, lr, lw).begin(), - recv_buffer(sz, buff, lr, lw).begin() + - std::max(0, std::min(N, limits[lr] - offset)), - out_ptr + recv_offset[lr] + offset); - recv_count[wire]++; - if (recv_count[wire] + (PIPELINE - 1) < n_steps) { - recv_from(sz, buff, lr, lw); - in_flight++; - } - } - } - } - - send_offset[0] = (send_offset[0] + out_bytes - n_bytes) % out_bytes; - recv_offset[0] = (recv_offset[0] + out_bytes - n_bytes) % out_bytes; - send_offset[1] = (send_offset[1] + n_bytes) % out_bytes; - recv_offset[1] = (recv_offset[1] + n_bytes) % out_bytes; - for (int i = 0; i < 2 * MAX_CONNS; i++) { - send_count[i] = recv_count[i] = 0; - } - } + ring_.all_gather(in_ptr, out_ptr, n_bytes, n_conns_); }); } @@ -273,71 +172,8 @@ void RingGroup::send(const array& input, int dst, Stream stream) { int64_t n_bytes = input.nbytes(); auto& encoder = cpu::get_command_encoder(stream); encoder.set_input_array(input); - encoder.dispatch([data, n_bytes, dst, left, this]() { - // In the case that size_ == 2 then left == right so we bias send towards - // left and recv towards right so that the selections will be correct for - // the 
2 node case. - auto& conns = (dst == left) ? left_ : right_; - int dir = dst == left; - - constexpr int PIPELINE = 2; - constexpr int WC_NUM = PIPELINE * MAX_CONNS; - - int n_wires = conns.size(); - int64_t bytes_per_wire = (n_bytes + n_wires - 1) / n_wires; - auto [sz, N] = buffer_size_from_message(bytes_per_wire); - - int in_flight = 0; - int64_t read_offset[MAX_CONNS]; - int64_t limits[MAX_CONNS]; - for (int lw = 0; lw < n_wires; lw++) { - read_offset[lw] = std::min(lw * bytes_per_wire, n_bytes); - limits[lw] = std::min((lw + 1) * bytes_per_wire, n_bytes); - } - - // Prefill the pipeline - for (int lw = 0; lw < n_wires; lw++) { - int buff = 0; - while (read_offset[lw] < limits[lw] && buff < PIPELINE) { - std::copy( - data + read_offset[lw], - data + std::min(read_offset[lw] + N, limits[lw]), - send_buffer(sz, buff, dir, lw).begin()); - send_to(sz, buff, dir, lw); - - buff++; - read_offset[lw] += N; - in_flight++; - } - } - - // Main loop - while (in_flight > 0) { - // Poll the hardware for completions. - // - // If a send was completed and we have more data to send then go ahead - // and send them. 
- ibv_wc wc[WC_NUM]; - int n = poll(conns, WC_NUM, wc); - for (int i = 0; i < n; i++) { - int buff = (wc[i].wr_id >> 8) & 0xff; - int wire = wc[i].wr_id & 0xff; - int lw = wire % MAX_CONNS; - - in_flight--; - - if (read_offset[lw] < limits[lw]) { - std::copy( - data + read_offset[lw], - data + std::min(read_offset[lw] + N, limits[lw]), - send_buffer(sz, buff, dir, lw).begin()); - send_to(sz, buff, dir, lw); - - read_offset[lw] += N; - in_flight++; - } - } - } + encoder.dispatch([data, n_bytes, dst, this]() { + ring_.send(data, n_bytes, dst, n_conns_); }); } @@ -354,69 +190,8 @@ void RingGroup::recv(array& out, int src, Stream stream) { int64_t n_bytes = out.nbytes(); auto& encoder = cpu::get_command_encoder(stream); encoder.set_output_array(out); - encoder.dispatch([data, n_bytes, src, right, this]() { - // In the case that size_ == 2 then left == right so we bias send towards - // left and recv towards right so that the selections will be correct for - // the 2 node case. - auto& conns = (src == right) ? right_ : left_; - int dir = src == right; - - constexpr int PIPELINE = 2; - constexpr int WC_NUM = PIPELINE * MAX_CONNS; - - int n_wires = conns.size(); - int64_t bytes_per_wire = (n_bytes + n_wires - 1) / n_wires; - auto [sz, N] = buffer_size_from_message(bytes_per_wire); - - int in_flight = 0; - int64_t write_offset[MAX_CONNS]; - int64_t limits[MAX_CONNS]; - for (int lw = 0; lw < n_wires; lw++) { - write_offset[lw] = std::min(lw * bytes_per_wire, n_bytes); - limits[lw] = std::min((lw + 1) * bytes_per_wire, n_bytes); - } - - // Prefill the pipeline - for (int lw = 0; lw < n_wires; lw++) { - int buff = 0; - while (N * buff < limits[lw] && buff < PIPELINE) { - recv_from(sz, buff, dir, lw); - - buff++; - in_flight++; - } - } - - // Main loop - while (in_flight > 0) { - // Poll the hardware for completions. - // - // If a recv was completed copy it to the output and if we have more - // data to fetch post another recv. 
- ibv_wc wc[WC_NUM]; - int n = poll(conns, WC_NUM, wc); - for (int i = 0; i < n; i++) { - int buff = (wc[i].wr_id >> 8) & 0xff; - int wire = wc[i].wr_id & 0xff; - int lw = wire % MAX_CONNS; - - in_flight--; - - std::copy( - recv_buffer(sz, buff, dir, lw).begin(), - recv_buffer(sz, buff, dir, lw).begin() + - std::max( - 0, std::min(limits[lw] - write_offset[lw], N)), - data + write_offset[lw]); - write_offset[lw] += N; - - if (write_offset[lw] + (PIPELINE - 1) * N < limits[lw]) { - recv_from(sz, buff, dir, lw); - - in_flight++; - } - } - } + encoder.dispatch([data, n_bytes, src, this]() { + ring_.recv(data, n_bytes, src, n_conns_); }); } @@ -428,265 +203,25 @@ void RingGroup::all_reduce( ReduceOp reduce_op) { auto in_ptr = input.data(); auto out_ptr = output.data(); + int64_t size = input.size(); + int64_t n_bytes = input.nbytes(); auto& encoder = cpu::get_command_encoder(stream); encoder.set_input_array(input); encoder.set_output_array(output); - encoder.dispatch([in_ptr, out_ptr, size = input.size(), this, reduce_op]() { - if (size < size_ * 2 * left_.size()) { - all_reduce_impl<1, T, ReduceOp>(in_ptr, out_ptr, size, 1, reduce_op); + encoder.dispatch([in_ptr, out_ptr, size, n_bytes, this, reduce_op]() { + if (size < size_ * 2 * n_conns_) { + ring_.all_reduce<1, T, ReduceOp>(in_ptr, out_ptr, size, 1, reduce_op); return; } - all_reduce_impl<2, T, ReduceOp>( - in_ptr, out_ptr, size, left_.size(), reduce_op); - }); -} - -template -void RingGroup::all_reduce_impl( - const T* in_ptr, - T* out_ptr, - int64_t size, - int n_wires, - ReduceOp reduce_op) { - // If not inplace all reduce then copy the input to the output first - if (in_ptr != out_ptr) { - std::memcpy(out_ptr, in_ptr, size * sizeof(T)); - } - - constexpr int PIPELINE = 2; - constexpr int WC_NUM = PIPELINE * MAX_CONNS * 2 * MAX_DIR; - int64_t chunk_size = (size + size_ - 1) / size_; - int64_t size_per_wire = - (chunk_size + (MAX_DIR * n_wires) - 1) / (MAX_DIR * n_wires); - auto [sz, N] = 
buffer_size_from_message(size_per_wire * sizeof(T)); - N /= sizeof(T); - int64_t n_steps = (size_per_wire + N - 1) / N; - - // Counters to maintain the state of transfers - int in_flight = 0; - int64_t chunk_multiple_size = size_ * chunk_size; - int64_t send_offset[MAX_DIR]; - int64_t recv_offset[MAX_DIR]; - int64_t send_limits[MAX_DIR]; - int64_t recv_limits[MAX_DIR]; - int send_count[MAX_DIR * MAX_CONNS] = {0}; - int recv_count[MAX_DIR * MAX_CONNS] = {0}; - send_offset[0] = rank_ * chunk_size; - recv_offset[0] = ((rank_ + size_ - 1) % size_) * chunk_size; - if constexpr (MAX_DIR == 2) { - send_offset[1] = rank_ * chunk_size; - recv_offset[1] = ((rank_ + 1) % size_) * chunk_size; - send_limits[0] = std::min( - n_wires * size_per_wire, std::max(0, size - send_offset[0])); - send_limits[1] = - std::min(chunk_size, std::max(0, size - send_offset[1])); - recv_limits[0] = std::min( - n_wires * size_per_wire, std::max(0, size - recv_offset[0])); - recv_limits[1] = - std::min(chunk_size, std::max(0, size - recv_offset[1])); - } else { - send_limits[0] = - std::min(chunk_size, std::max(0, size - send_offset[0])); - recv_limits[0] = - std::min(chunk_size, std::max(0, size - recv_offset[0])); - } - - // First reduce scatter - // - // Possible perf improvement by not syncing at every step but running ahead - // as needed. 
- for (int k = 0; k < size_ - 1; k++) { - // Prefill the pipeline - int buff = 0; - while (buff < n_steps && buff < PIPELINE) { - post_recv_all(sz, buff, n_wires); - for (int lr = 0; lr < MAX_DIR; lr++) { - for (int lw = 0; lw < n_wires; lw++) { - int64_t offset = lw * N + - send_count[lr * MAX_CONNS + lw] * n_wires * N + - lr * n_wires * size_per_wire; - std::copy( - out_ptr + send_offset[lr] + offset, - out_ptr + send_offset[lr] + - std::max(offset, std::min(offset + N, send_limits[lr])), - send_buffer(sz, buff, lr, lw).begin()); - send_count[lr * MAX_CONNS + lw]++; - } - } - post_send_all(sz, buff, n_wires); - - buff++; - in_flight += 2 * MAX_DIR * n_wires; - } - - // Main loop - // - // Keep going until we have no longer data in flight. - while (in_flight > 0) { - ibv_wc wc[WC_NUM]; - int n = poll(left_, right_, WC_NUM, wc); - for (int i = 0; i < n; i++) { - int work_type = wc[i].wr_id >> 16; - int buff = (wc[i].wr_id >> 8) & 0xff; - int wire = wc[i].wr_id & 0xff; - int lr = wire / MAX_CONNS; - int lw = wire % MAX_CONNS; - - in_flight--; - - if (work_type == SEND_WR && send_count[wire] < n_steps) { - int64_t offset = lw * N + send_count[wire] * n_wires * N + - lr * n_wires * size_per_wire; - std::copy( - out_ptr + send_offset[lr] + offset, - out_ptr + send_offset[lr] + - std::max(offset, std::min(offset + N, send_limits[lr])), - send_buffer(sz, buff, lr, lw).begin()); - send_to(sz, buff, lr, lw); - in_flight++; - send_count[wire]++; - } - - else if (work_type == RECV_WR) { - int64_t offset = lw * N + recv_count[wire] * n_wires * N + - lr * n_wires * size_per_wire; - reduce_op( - recv_buffer(sz, buff, lr, lw).begin(), - out_ptr + recv_offset[lr] + offset, - std::max(0, std::min(N, recv_limits[lr] - offset))); - recv_count[wire]++; - if (recv_count[wire] + (PIPELINE - 1) < n_steps) { - recv_from(sz, buff, lr, lw); - in_flight++; - } - } - } - } - - send_offset[0] = (send_offset[0] + chunk_multiple_size - chunk_size) % - chunk_multiple_size; - recv_offset[0] = 
(recv_offset[0] + chunk_multiple_size - chunk_size) % - chunk_multiple_size; - if constexpr (MAX_DIR == 2) { - send_offset[1] = (send_offset[1] + chunk_size) % chunk_multiple_size; - recv_offset[1] = (recv_offset[1] + chunk_size) % chunk_multiple_size; - send_limits[0] = std::min( - n_wires * size_per_wire, std::max(0, size - send_offset[0])); - send_limits[1] = - std::min(chunk_size, std::max(0, size - send_offset[1])); - recv_limits[0] = std::min( - n_wires * size_per_wire, std::max(0, size - recv_offset[0])); - recv_limits[1] = - std::min(chunk_size, std::max(0, size - recv_offset[1])); - } else { - send_limits[0] = - std::min(chunk_size, std::max(0, size - send_offset[0])); - recv_limits[0] = - std::min(chunk_size, std::max(0, size - recv_offset[0])); - } - for (int i = 0; i < MAX_DIR * MAX_CONNS; i++) { - send_count[i] = recv_count[i] = 0; - } - } - - // Secondly all gather - // - // The offsets are correct from the scatter reduce - for (int k = 0; k < size_ - 1; k++) { - // Prefill the pipeline - int buff = 0; - while (buff < n_steps && buff < PIPELINE) { - post_recv_all(sz, buff, n_wires); - for (int lr = 0; lr < MAX_DIR; lr++) { - for (int lw = 0; lw < n_wires; lw++) { - int64_t offset = lw * N + - send_count[lr * MAX_CONNS + lw] * n_wires * N + - lr * n_wires * size_per_wire; - std::copy( - out_ptr + send_offset[lr] + offset, - out_ptr + send_offset[lr] + - std::max(offset, std::min(offset + N, send_limits[lr])), - send_buffer(sz, buff, lr, lw).begin()); - send_count[lr * MAX_CONNS + lw]++; - } - } - post_send_all(sz, buff, n_wires); - - buff++; - in_flight += 2 * MAX_DIR * n_wires; - } - - // Main loop - // - // Keep going until we have no longer data in flight. 
- while (in_flight > 0) { - ibv_wc wc[WC_NUM]; - int n = poll(left_, right_, WC_NUM, wc); - for (int i = 0; i < n; i++) { - int work_type = wc[i].wr_id >> 16; - int buff = (wc[i].wr_id >> 8) & 0xff; - int wire = wc[i].wr_id & 0xff; - int lr = wire / MAX_CONNS; - int lw = wire % MAX_CONNS; - - in_flight--; - - if (work_type == SEND_WR && send_count[wire] < n_steps) { - int64_t offset = lw * N + send_count[wire] * n_wires * N + - lr * n_wires * size_per_wire; - std::copy( - out_ptr + send_offset[lr] + offset, - out_ptr + send_offset[lr] + - std::max(offset, std::min(offset + N, send_limits[lr])), - send_buffer(sz, buff, lr, lw).begin()); - send_to(sz, buff, lr, lw); - in_flight++; - send_count[wire]++; - } - - else if (work_type == RECV_WR) { - int64_t offset = lw * N + recv_count[wire] * n_wires * N + - lr * n_wires * size_per_wire; - std::copy( - recv_buffer(sz, buff, lr, lw).begin(), - recv_buffer(sz, buff, lr, lw).begin() + - std::max(0, std::min(N, recv_limits[lr] - offset)), - out_ptr + recv_offset[lr] + offset); - recv_count[wire]++; - if (recv_count[wire] + (PIPELINE - 1) < n_steps) { - recv_from(sz, buff, lr, lw); - in_flight++; - } - } - } + if (n_bytes <= 65536) { + ring_.all_reduce<2, T, ReduceOp>(in_ptr, out_ptr, size, 1, reduce_op); + return; } - send_offset[0] = (send_offset[0] + chunk_multiple_size - chunk_size) % - chunk_multiple_size; - recv_offset[0] = (recv_offset[0] + chunk_multiple_size - chunk_size) % - chunk_multiple_size; - if constexpr (MAX_DIR == 2) { - send_offset[1] = (send_offset[1] + chunk_size) % chunk_multiple_size; - recv_offset[1] = (recv_offset[1] + chunk_size) % chunk_multiple_size; - send_limits[0] = std::min( - n_wires * size_per_wire, std::max(0, size - send_offset[0])); - send_limits[1] = - std::min(chunk_size, std::max(0, size - send_offset[1])); - recv_limits[0] = std::min( - n_wires * size_per_wire, std::max(0, size - recv_offset[0])); - recv_limits[1] = - std::min(chunk_size, std::max(0, size - recv_offset[1])); - } else { 
- send_limits[0] = - std::min(chunk_size, std::max(0, size - send_offset[0])); - recv_limits[0] = - std::min(chunk_size, std::max(0, size - recv_offset[0])); - } - for (int i = 0; i < MAX_DIR * MAX_CONNS; i++) { - send_count[i] = recv_count[i] = 0; - } - } + ring_.all_reduce<2, T, ReduceOp>( + in_ptr, out_ptr, size, n_conns_, reduce_op); + }); } } // namespace mlx::core::distributed::jaccl diff --git a/mlx/distributed/jaccl/ring.h b/mlx/distributed/jaccl/ring.h index a59ceb3dd8..b3ce2f7b72 100644 --- a/mlx/distributed/jaccl/ring.h +++ b/mlx/distributed/jaccl/ring.h @@ -3,12 +3,11 @@ #pragma once #include "mlx/distributed/distributed_impl.h" +#include "mlx/distributed/jaccl/ring_impl.h" #include "mlx/distributed/jaccl/utils.h" using GroupImpl = mlx::core::distributed::detail::GroupImpl; -constexpr int MAX_CONNS = 4; - namespace mlx::core::distributed::jaccl { /** @@ -64,14 +63,6 @@ class RingGroup : public GroupImpl { Stream stream, ReduceOp reduce_op); - template - void all_reduce_impl( - const T* in_ptr, - T* out_ptr, - int64_t size, - int n_wires, - ReduceOp reduce_op); - /** * Performs the connection initialization. 
Namely, after this call all * Connection objects should have a queue pair in RTS state and all buffers @@ -84,95 +75,15 @@ class RingGroup : public GroupImpl { */ void allocate_buffers(); - void send_to(int sz, int buff, int left_right, int wire) { - if (left_right) { - left_[wire].post_send( - send_buffer_left(sz, buff, wire), - SEND_WR << 16 | buff << 8 | (MAX_CONNS + wire)); - } else { - right_[wire].post_send( - send_buffer_right(sz, buff, wire), SEND_WR << 16 | buff << 8 | wire); - } - } - - void recv_from(int sz, int buff, int left_right, int wire) { - if (left_right) { - right_[wire].post_recv( - recv_buffer_right(sz, buff, wire), - RECV_WR << 16 | buff << 8 | (MAX_CONNS + wire)); - } else { - left_[wire].post_recv( - recv_buffer_left(sz, buff, wire), RECV_WR << 16 | buff << 8 | wire); - } - } - - SharedBuffer& send_buffer_right(int sz, int buff, int wire) { - return send_buffers_ - [sz * NUM_BUFFERS * MAX_CONNS * 2 + buff * MAX_CONNS * 2 + wire]; - } - - SharedBuffer& send_buffer_left(int sz, int buff, int wire) { - return send_buffers_ - [sz * NUM_BUFFERS * MAX_CONNS * 2 + buff * MAX_CONNS * 2 + MAX_CONNS + - wire]; - } - - SharedBuffer& send_buffer(int sz, int buff, int left_right, int wire) { - return send_buffers_ - [sz * NUM_BUFFERS * MAX_CONNS * 2 + buff * MAX_CONNS * 2 + - left_right * MAX_CONNS + wire]; - } - - SharedBuffer& recv_buffer_left(int sz, int buff, int wire) { - return recv_buffers_ - [sz * NUM_BUFFERS * MAX_CONNS * 2 + buff * MAX_CONNS * 2 + wire]; - } - - SharedBuffer& recv_buffer_right(int sz, int buff, int wire) { - return recv_buffers_ - [sz * NUM_BUFFERS * MAX_CONNS * 2 + buff * MAX_CONNS * 2 + MAX_CONNS + - wire]; - } - - SharedBuffer& recv_buffer(int sz, int buff, int left_right, int wire) { - return recv_buffers_ - [sz * NUM_BUFFERS * MAX_CONNS * 2 + buff * MAX_CONNS * 2 + - left_right * MAX_CONNS + wire]; - } - - template - void post_recv_all(int sz, int buff, int n_wires) { - for (int lr = 0; lr < MAX_DIR; lr++) { - for (int lw 
= 0; lw < n_wires; lw++) { - recv_from(sz, buff, lr, lw); - } - } - } - - void post_recv_all(int sz, int buff) { - post_recv_all<2>(sz, buff, left_.size()); - } - - template - void post_send_all(int sz, int buff, int n_wires) { - for (int lr = 0; lr < MAX_DIR; lr++) { - for (int lw = 0; lw < n_wires; lw++) { - send_to(sz, buff, lr, lw); - } - } - } - - void post_send_all(int sz, int buff) { - post_send_all<2>(sz, buff, left_.size()); - } - int rank_; int size_; + int n_conns_; SideChannel side_channel_; std::vector left_; std::vector right_; std::vector send_buffers_; std::vector recv_buffers_; + RingImpl ring_; }; } // namespace mlx::core::distributed::jaccl diff --git a/mlx/distributed/jaccl/ring_impl.h b/mlx/distributed/jaccl/ring_impl.h new file mode 100644 index 0000000000..ce883d1fcc --- /dev/null +++ b/mlx/distributed/jaccl/ring_impl.h @@ -0,0 +1,631 @@ +// Copyright © 2026 Apple Inc. + +#pragma once + +#include + +#include "mlx/distributed/jaccl/utils.h" + +constexpr int RING_MAX_CONNS = 4; + +namespace mlx::core::distributed::jaccl { + +class RingImpl { + public: + RingImpl( + int rank, + int size, + std::vector& left, + std::vector& right, + std::vector& send_buffers, + std::vector& recv_buffers) + : rank_(rank), + size_(size), + n_conns_(left.size()), + left_(left), + right_(right), + send_buffers_(send_buffers), + recv_buffers_(recv_buffers) {} + + RingImpl( + int rank, + int size, + Connection* left_begin, + Connection* right_begin, + size_t n_conns, + std::vector& send_buffers, + std::vector& recv_buffers) + : rank_(rank), + size_(size), + n_conns_(n_conns), + left_(left_begin, n_conns), + right_(right_begin, n_conns), + send_buffers_(send_buffers), + recv_buffers_(recv_buffers) {} + + RingImpl() : rank_(0), size_(1), n_conns_(0) {} + + template + void all_reduce( + const T* in_ptr, + T* out_ptr, + int64_t size, + int n_wires, + ReduceOp reduce_op) { + // If not inplace all reduce then copy the input to the output first + if (in_ptr != out_ptr) { + 
std::memcpy(out_ptr, in_ptr, size * sizeof(T)); + } + + constexpr int PIPELINE = 2; + constexpr int WC_NUM = PIPELINE * RING_MAX_CONNS * 2 * MAX_DIR; + int64_t chunk_size = (size + size_ - 1) / size_; + int64_t size_per_wire = + (chunk_size + (MAX_DIR * n_wires) - 1) / (MAX_DIR * n_wires); + auto [sz, N] = buffer_size_from_message(size_per_wire * sizeof(T)); + N /= sizeof(T); + int64_t n_steps = (size_per_wire + N - 1) / N; + + // Counters to maintain the state of transfers + int in_flight = 0; + int64_t chunk_multiple_size = size_ * chunk_size; + int64_t send_offset[MAX_DIR]; + int64_t recv_offset[MAX_DIR]; + int64_t send_limits[MAX_DIR]; + int64_t recv_limits[MAX_DIR]; + int send_count[MAX_DIR * RING_MAX_CONNS] = {0}; + int recv_count[MAX_DIR * RING_MAX_CONNS] = {0}; + send_offset[0] = rank_ * chunk_size; + recv_offset[0] = ((rank_ + size_ - 1) % size_) * chunk_size; + if constexpr (MAX_DIR == 2) { + send_offset[1] = rank_ * chunk_size; + recv_offset[1] = ((rank_ + 1) % size_) * chunk_size; + send_limits[0] = std::min( + n_wires * size_per_wire, std::max(0, size - send_offset[0])); + send_limits[1] = + std::min(chunk_size, std::max(0, size - send_offset[1])); + recv_limits[0] = std::min( + n_wires * size_per_wire, std::max(0, size - recv_offset[0])); + recv_limits[1] = + std::min(chunk_size, std::max(0, size - recv_offset[1])); + } else { + send_limits[0] = + std::min(chunk_size, std::max(0, size - send_offset[0])); + recv_limits[0] = + std::min(chunk_size, std::max(0, size - recv_offset[0])); + } + + // First reduce scatter + // + // Possible perf improvement by not syncing at every step but running ahead + // as needed. 
+ for (int k = 0; k < size_ - 1; k++) { + // Prefill the pipeline + int buff = 0; + while (buff < n_steps && buff < PIPELINE) { + post_recv_all(sz, buff, n_wires); + for (int lr = 0; lr < MAX_DIR; lr++) { + for (int lw = 0; lw < n_wires; lw++) { + int64_t offset = lw * N + + send_count[lr * RING_MAX_CONNS + lw] * n_wires * N + + lr * n_wires * size_per_wire; + std::copy( + out_ptr + send_offset[lr] + offset, + out_ptr + send_offset[lr] + + std::max(offset, std::min(offset + N, send_limits[lr])), + send_buffer(sz, buff, lr, lw).begin()); + send_count[lr * RING_MAX_CONNS + lw]++; + } + } + post_send_all(sz, buff, n_wires); + + buff++; + in_flight += 2 * MAX_DIR * n_wires; + } + + // Main loop + // + // Keep going until we have no longer data in flight. + while (in_flight > 0) { + ibv_wc wc[WC_NUM]; + int n = poll(left_, right_, WC_NUM, wc); + for (int i = 0; i < n; i++) { + int work_type = wc[i].wr_id >> 16; + int buff = (wc[i].wr_id >> 8) & 0xff; + int wire = wc[i].wr_id & 0xff; + int lr = wire / RING_MAX_CONNS; + int lw = wire % RING_MAX_CONNS; + + in_flight--; + + if (work_type == SEND_WR && send_count[wire] < n_steps) { + int64_t offset = lw * N + send_count[wire] * n_wires * N + + lr * n_wires * size_per_wire; + std::copy( + out_ptr + send_offset[lr] + offset, + out_ptr + send_offset[lr] + + std::max(offset, std::min(offset + N, send_limits[lr])), + send_buffer(sz, buff, lr, lw).begin()); + send_to(sz, buff, lr, lw); + in_flight++; + send_count[wire]++; + } + + else if (work_type == RECV_WR) { + int64_t offset = lw * N + recv_count[wire] * n_wires * N + + lr * n_wires * size_per_wire; + reduce_op( + recv_buffer(sz, buff, lr, lw).begin(), + out_ptr + recv_offset[lr] + offset, + std::max(0, std::min(N, recv_limits[lr] - offset))); + recv_count[wire]++; + if (recv_count[wire] + (PIPELINE - 1) < n_steps) { + recv_from(sz, buff, lr, lw); + in_flight++; + } + } + } + } + + send_offset[0] = (send_offset[0] + chunk_multiple_size - chunk_size) % + chunk_multiple_size; + 
recv_offset[0] = (recv_offset[0] + chunk_multiple_size - chunk_size) % + chunk_multiple_size; + if constexpr (MAX_DIR == 2) { + send_offset[1] = (send_offset[1] + chunk_size) % chunk_multiple_size; + recv_offset[1] = (recv_offset[1] + chunk_size) % chunk_multiple_size; + send_limits[0] = std::min( + n_wires * size_per_wire, + std::max(0, size - send_offset[0])); + send_limits[1] = + std::min(chunk_size, std::max(0, size - send_offset[1])); + recv_limits[0] = std::min( + n_wires * size_per_wire, + std::max(0, size - recv_offset[0])); + recv_limits[1] = + std::min(chunk_size, std::max(0, size - recv_offset[1])); + } else { + send_limits[0] = + std::min(chunk_size, std::max(0, size - send_offset[0])); + recv_limits[0] = + std::min(chunk_size, std::max(0, size - recv_offset[0])); + } + for (int i = 0; i < MAX_DIR * RING_MAX_CONNS; i++) { + send_count[i] = recv_count[i] = 0; + } + } + + // Secondly all gather + // + // The offsets are correct from the scatter reduce + for (int k = 0; k < size_ - 1; k++) { + // Prefill the pipeline + int buff = 0; + while (buff < n_steps && buff < PIPELINE) { + post_recv_all(sz, buff, n_wires); + for (int lr = 0; lr < MAX_DIR; lr++) { + for (int lw = 0; lw < n_wires; lw++) { + int64_t offset = lw * N + + send_count[lr * RING_MAX_CONNS + lw] * n_wires * N + + lr * n_wires * size_per_wire; + std::copy( + out_ptr + send_offset[lr] + offset, + out_ptr + send_offset[lr] + + std::max(offset, std::min(offset + N, send_limits[lr])), + send_buffer(sz, buff, lr, lw).begin()); + send_count[lr * RING_MAX_CONNS + lw]++; + } + } + post_send_all(sz, buff, n_wires); + + buff++; + in_flight += 2 * MAX_DIR * n_wires; + } + + // Main loop + // + // Keep going until we have no longer data in flight. 
+ while (in_flight > 0) { + ibv_wc wc[WC_NUM]; + int n = poll(left_, right_, WC_NUM, wc); + for (int i = 0; i < n; i++) { + int work_type = wc[i].wr_id >> 16; + int buff = (wc[i].wr_id >> 8) & 0xff; + int wire = wc[i].wr_id & 0xff; + int lr = wire / RING_MAX_CONNS; + int lw = wire % RING_MAX_CONNS; + + in_flight--; + + if (work_type == SEND_WR && send_count[wire] < n_steps) { + int64_t offset = lw * N + send_count[wire] * n_wires * N + + lr * n_wires * size_per_wire; + std::copy( + out_ptr + send_offset[lr] + offset, + out_ptr + send_offset[lr] + + std::max(offset, std::min(offset + N, send_limits[lr])), + send_buffer(sz, buff, lr, lw).begin()); + send_to(sz, buff, lr, lw); + in_flight++; + send_count[wire]++; + } + + else if (work_type == RECV_WR) { + int64_t offset = lw * N + recv_count[wire] * n_wires * N + + lr * n_wires * size_per_wire; + std::copy( + recv_buffer(sz, buff, lr, lw).begin(), + recv_buffer(sz, buff, lr, lw).begin() + + std::max(0, std::min(N, recv_limits[lr] - offset)), + out_ptr + recv_offset[lr] + offset); + recv_count[wire]++; + if (recv_count[wire] + (PIPELINE - 1) < n_steps) { + recv_from(sz, buff, lr, lw); + in_flight++; + } + } + } + } + + send_offset[0] = (send_offset[0] + chunk_multiple_size - chunk_size) % + chunk_multiple_size; + recv_offset[0] = (recv_offset[0] + chunk_multiple_size - chunk_size) % + chunk_multiple_size; + if constexpr (MAX_DIR == 2) { + send_offset[1] = (send_offset[1] + chunk_size) % chunk_multiple_size; + recv_offset[1] = (recv_offset[1] + chunk_size) % chunk_multiple_size; + send_limits[0] = std::min( + n_wires * size_per_wire, + std::max(0, size - send_offset[0])); + send_limits[1] = + std::min(chunk_size, std::max(0, size - send_offset[1])); + recv_limits[0] = std::min( + n_wires * size_per_wire, + std::max(0, size - recv_offset[0])); + recv_limits[1] = + std::min(chunk_size, std::max(0, size - recv_offset[1])); + } else { + send_limits[0] = + std::min(chunk_size, std::max(0, size - send_offset[0])); + 
recv_limits[0] = + std::min(chunk_size, std::max(0, size - recv_offset[0])); + } + for (int i = 0; i < MAX_DIR * RING_MAX_CONNS; i++) { + send_count[i] = recv_count[i] = 0; + } + } + } + + void + all_gather(const char* in_ptr, char* out_ptr, int64_t n_bytes, int n_wires) { + // Copy our data to the appropriate place + std::memcpy(out_ptr + rank_ * n_bytes, in_ptr, n_bytes); + + constexpr int PIPELINE = 2; + constexpr int WC_NUM = PIPELINE * RING_MAX_CONNS * 2 * 2; + size_t n_bytes_per_wire = (n_bytes + (2 * n_wires) - 1) / (2 * n_wires); + size_t out_bytes = n_bytes * size_; + auto [sz, N] = buffer_size_from_message(n_bytes_per_wire); + int n_steps = (n_bytes_per_wire + N - 1) / N; + + // Counters to maintain the state of transfers + int in_flight = 0; + int64_t send_offset[2]; + int64_t recv_offset[2]; + int64_t limits[2]; + int send_count[2 * RING_MAX_CONNS] = {0}; + int recv_count[2 * RING_MAX_CONNS] = {0}; + send_offset[0] = send_offset[1] = rank_ * n_bytes; + recv_offset[0] = ((rank_ + size_ - 1) % size_) * n_bytes; + recv_offset[1] = ((rank_ + 1) % size_) * n_bytes; + limits[0] = n_wires * n_bytes_per_wire; + limits[1] = n_bytes; + + // Possible perf improvement by not syncing at every step but running ahead + // as needed. + for (int k = 0; k < size_ - 1; k++) { + // Prefill the pipeline + int buff = 0; + while (buff < n_steps && buff < PIPELINE) { + post_recv_all(sz, buff); + for (int lr = 0; lr < 2; lr++) { + for (int lw = 0; lw < n_wires; lw++) { + int64_t offset = lw * N + + send_count[lr * RING_MAX_CONNS + lw] * n_wires * N + + lr * n_wires * n_bytes_per_wire; + std::copy( + out_ptr + send_offset[lr] + offset, + out_ptr + send_offset[lr] + + std::max(offset, std::min(offset + N, limits[lr])), + send_buffer(sz, buff, lr, lw).begin()); + send_count[lr * RING_MAX_CONNS + lw]++; + } + } + post_send_all(sz, buff); + + buff++; + in_flight += 2 * 2 * n_wires; + } + + // Main loop + // + // Keep going until we have no longer data in flight. 
+ while (in_flight > 0) { + ibv_wc wc[WC_NUM]; + int n = poll(left_, right_, WC_NUM, wc); + for (int i = 0; i < n; i++) { + int work_type = wc[i].wr_id >> 16; + int buff = (wc[i].wr_id >> 8) & 0xff; + int wire = wc[i].wr_id & 0xff; + int lr = wire / RING_MAX_CONNS; + int lw = wire % RING_MAX_CONNS; + + in_flight--; + + if (work_type == SEND_WR && send_count[wire] < n_steps) { + int64_t offset = lw * N + send_count[wire] * n_wires * N + + lr * n_wires * n_bytes_per_wire; + std::copy( + out_ptr + send_offset[lr] + offset, + out_ptr + send_offset[lr] + + std::max(offset, std::min(offset + N, limits[lr])), + send_buffer(sz, buff, lr, lw).begin()); + send_to(sz, buff, lr, lw); + in_flight++; + send_count[wire]++; + } + + else if (work_type == RECV_WR) { + int64_t offset = lw * N + recv_count[wire] * n_wires * N + + lr * n_wires * n_bytes_per_wire; + std::copy( + recv_buffer(sz, buff, lr, lw).begin(), + recv_buffer(sz, buff, lr, lw).begin() + + std::max(0, std::min(N, limits[lr] - offset)), + out_ptr + recv_offset[lr] + offset); + recv_count[wire]++; + if (recv_count[wire] + (PIPELINE - 1) < n_steps) { + recv_from(sz, buff, lr, lw); + in_flight++; + } + } + } + } + + send_offset[0] = (send_offset[0] + out_bytes - n_bytes) % out_bytes; + recv_offset[0] = (recv_offset[0] + out_bytes - n_bytes) % out_bytes; + send_offset[1] = (send_offset[1] + n_bytes) % out_bytes; + recv_offset[1] = (recv_offset[1] + n_bytes) % out_bytes; + for (int i = 0; i < 2 * RING_MAX_CONNS; i++) { + send_count[i] = recv_count[i] = 0; + } + } + } + + void send(const char* in_ptr, int64_t n_bytes, int dst, int n_wires) { + int left = (rank_ + size_ - 1) % size_; + + // In the case that size_ == 2 then left == right so we bias send towards + // left and recv towards right so that the selections will be correct for + // the 2 node case. + auto& conns = (dst == left) ? 
left_ : right_; + int dir = dst == left; + + constexpr int PIPELINE = 2; + constexpr int WC_NUM = PIPELINE * RING_MAX_CONNS; + + int64_t bytes_per_wire = (n_bytes + n_wires - 1) / n_wires; + auto [sz, N] = buffer_size_from_message(bytes_per_wire); + + int in_flight = 0; + int64_t read_offset[RING_MAX_CONNS]; + int64_t limits[RING_MAX_CONNS]; + for (int lw = 0; lw < n_wires; lw++) { + read_offset[lw] = std::min(lw * bytes_per_wire, n_bytes); + limits[lw] = std::min((lw + 1) * bytes_per_wire, n_bytes); + } + + // Prefill the pipeline + for (int lw = 0; lw < n_wires; lw++) { + int buff = 0; + while (read_offset[lw] < limits[lw] && buff < PIPELINE) { + std::copy( + in_ptr + read_offset[lw], + in_ptr + std::min(read_offset[lw] + N, limits[lw]), + send_buffer(sz, buff, dir, lw).begin()); + send_to(sz, buff, dir, lw); + + buff++; + read_offset[lw] += N; + in_flight++; + } + } + + // Main loop + while (in_flight > 0) { + // Poll the hardware for completions. + // + // If a send was completed and we have more data to send then go ahead + // and send them. + ibv_wc wc[WC_NUM]; + int n = poll(conns, WC_NUM, wc); + for (int i = 0; i < n; i++) { + int buff = (wc[i].wr_id >> 8) & 0xff; + int wire = wc[i].wr_id & 0xff; + int lw = wire % RING_MAX_CONNS; + + in_flight--; + + if (read_offset[lw] < limits[lw]) { + std::copy( + in_ptr + read_offset[lw], + in_ptr + std::min(read_offset[lw] + N, limits[lw]), + send_buffer(sz, buff, dir, lw).begin()); + send_to(sz, buff, dir, lw); + + read_offset[lw] += N; + in_flight++; + } + } + } + } + + void recv(char* out_ptr, int64_t n_bytes, int src, int n_wires) { + int right = (rank_ + 1) % size_; + + // In the case that size_ == 2 then left == right so we bias send towards + // left and recv towards right so that the selections will be correct for + // the 2 node case. + auto& conns = (src == right) ? 
right_ : left_; + int dir = src == right; + + constexpr int PIPELINE = 2; + constexpr int WC_NUM = PIPELINE * RING_MAX_CONNS; + + int64_t bytes_per_wire = (n_bytes + n_wires - 1) / n_wires; + auto [sz, N] = buffer_size_from_message(bytes_per_wire); + + int in_flight = 0; + int64_t write_offset[RING_MAX_CONNS]; + int64_t limits[RING_MAX_CONNS]; + for (int lw = 0; lw < n_wires; lw++) { + write_offset[lw] = std::min(lw * bytes_per_wire, n_bytes); + limits[lw] = std::min((lw + 1) * bytes_per_wire, n_bytes); + } + + // Prefill the pipeline + for (int lw = 0; lw < n_wires; lw++) { + int buff = 0; + while (N * buff < limits[lw] && buff < PIPELINE) { + recv_from(sz, buff, dir, lw); + + buff++; + in_flight++; + } + } + + // Main loop + while (in_flight > 0) { + // Poll the hardware for completions. + // + // If a recv was completed copy it to the output and if we have more + // data to fetch post another recv. + ibv_wc wc[WC_NUM]; + int n = poll(conns, WC_NUM, wc); + for (int i = 0; i < n; i++) { + int buff = (wc[i].wr_id >> 8) & 0xff; + int wire = wc[i].wr_id & 0xff; + int lw = wire % RING_MAX_CONNS; + + in_flight--; + + std::copy( + recv_buffer(sz, buff, dir, lw).begin(), + recv_buffer(sz, buff, dir, lw).begin() + + std::max( + 0, std::min(limits[lw] - write_offset[lw], N)), + out_ptr + write_offset[lw]); + write_offset[lw] += N; + + if (write_offset[lw] + (PIPELINE - 1) * N < limits[lw]) { + recv_from(sz, buff, dir, lw); + + in_flight++; + } + } + } + } + + private: + void send_to(int sz, int buff, int left_right, int wire) { + if (left_right) { + left_[wire].post_send( + send_buffer_left(sz, buff, wire), + SEND_WR << 16 | buff << 8 | (RING_MAX_CONNS + wire)); + } else { + right_[wire].post_send( + send_buffer_right(sz, buff, wire), SEND_WR << 16 | buff << 8 | wire); + } + } + + void recv_from(int sz, int buff, int left_right, int wire) { + if (left_right) { + right_[wire].post_recv( + recv_buffer_right(sz, buff, wire), + RECV_WR << 16 | buff << 8 | (RING_MAX_CONNS + 
wire)); + } else { + left_[wire].post_recv( + recv_buffer_left(sz, buff, wire), RECV_WR << 16 | buff << 8 | wire); + } + } + + SharedBuffer& send_buffer_right(int sz, int buff, int wire) { + return send_buffers_ + [sz * NUM_BUFFERS * n_conns_ * 2 + buff * n_conns_ * 2 + wire]; + } + + SharedBuffer& send_buffer_left(int sz, int buff, int wire) { + return send_buffers_ + [sz * NUM_BUFFERS * n_conns_ * 2 + buff * n_conns_ * 2 + n_conns_ + + wire]; + } + + SharedBuffer& send_buffer(int sz, int buff, int left_right, int wire) { + return send_buffers_ + [sz * NUM_BUFFERS * n_conns_ * 2 + buff * n_conns_ * 2 + + left_right * n_conns_ + wire]; + } + + SharedBuffer& recv_buffer_left(int sz, int buff, int wire) { + return recv_buffers_ + [sz * NUM_BUFFERS * n_conns_ * 2 + buff * n_conns_ * 2 + wire]; + } + + SharedBuffer& recv_buffer_right(int sz, int buff, int wire) { + return recv_buffers_ + [sz * NUM_BUFFERS * n_conns_ * 2 + buff * n_conns_ * 2 + n_conns_ + + wire]; + } + + SharedBuffer& recv_buffer(int sz, int buff, int left_right, int wire) { + return recv_buffers_ + [sz * NUM_BUFFERS * n_conns_ * 2 + buff * n_conns_ * 2 + + left_right * n_conns_ + wire]; + } + + template + void post_recv_all(int sz, int buff, int n_wires) { + for (int lr = 0; lr < MAX_DIR; lr++) { + for (int lw = 0; lw < n_wires; lw++) { + recv_from(sz, buff, lr, lw); + } + } + } + + void post_recv_all(int sz, int buff) { + post_recv_all<2>(sz, buff, n_conns_); + } + + template + void post_send_all(int sz, int buff, int n_wires) { + for (int lr = 0; lr < MAX_DIR; lr++) { + for (int lw = 0; lw < n_wires; lw++) { + send_to(sz, buff, lr, lw); + } + } + } + + void post_send_all(int sz, int buff) { + post_send_all<2>(sz, buff, n_conns_); + } + + int rank_; + int size_; + int n_conns_; + std::span left_; + std::span right_; + std::span send_buffers_; + std::span recv_buffers_; +}; + +} // namespace mlx::core::distributed::jaccl diff --git a/mlx/distributed/jaccl/utils.h b/mlx/distributed/jaccl/utils.h index 
5b700f49b1..8faa774058 100644 --- a/mlx/distributed/jaccl/utils.h +++ b/mlx/distributed/jaccl/utils.h @@ -4,6 +4,7 @@ #include +#include #include #include @@ -221,7 +222,7 @@ std::vector create_connections( const std::vector& device_names); inline int poll( - const std::vector& connections, + std::span connections, int num_completions, ibv_wc* work_completions) { int completions = 0; @@ -244,8 +245,8 @@ inline int poll( } inline int poll( - const std::vector& connections_1, - const std::vector& connections_2, + std::span connections_1, + std::span connections_2, int num_completions, ibv_wc* work_completions) { int completions = 0;