Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions mlx/backend/cpu/distributed.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -95,4 +95,9 @@ void Recv::eval_cpu(
distributed::detail::recv(group(), outputs[0], src_, stream());
}

// CPU evaluation for ReduceScatter.
//
// No CPU kernel exists yet, so this unconditionally reports the missing
// implementation to the caller instead of silently producing garbage.
void ReduceScatter::eval_cpu(
    const std::vector<array>& inputs,
    std::vector<array>& outputs) {
  (void)inputs;   // unused until a CPU implementation lands
  (void)outputs;  // unused until a CPU implementation lands
  throw std::runtime_error("[ReduceScatter] Not implemented yet.");
}
} // namespace mlx::core::distributed
65 changes: 65 additions & 0 deletions mlx/backend/cuda/distributed.cu
Original file line number Diff line number Diff line change
Expand Up @@ -53,4 +53,69 @@ void AllReduce::eval_gpu(
"Only all reduce sum, max, and min are supported.");
}
}

void AllGather::eval_gpu(
const std::vector<array>& inputs,
std::vector<array>& outputs) {
assert(inputs.size() == 1);
assert(outputs.size() == 1);

auto& s = stream();
auto& encoder = cu::get_command_encoder(s);

auto ensure_contiguous = [&s, &encoder](const array& x) {
if (x.flags().row_contiguous) {
return x;
} else {
array x_copy = contiguous_copy_gpu(x, s);
encoder.add_temporary(x_copy);
return x_copy;
}
};

auto input = ensure_contiguous(inputs[0]);
outputs[0].set_data(allocator::malloc(outputs[0].nbytes()));

encoder.set_input_array(input);
encoder.set_output_array(outputs[0]);

auto capture = encoder.capture_context();
distributed::detail::all_gather(group(), input, outputs[0], s);
}

// GPU (CUDA) evaluation for ReduceScatter.
//
// Ensures the input is row-contiguous (copying if necessary), allocates the
// output buffer, and dispatches the collective inside the encoder's capture
// context. Only the Sum reduction is implemented; Min/Max raise.
void ReduceScatter::eval_gpu(
    const std::vector<array>& inputs,
    std::vector<array>& outputs) {
  assert(inputs.size() == 1);
  assert(outputs.size() == 1);

  auto& s = stream();
  auto& encoder = cu::get_command_encoder(s);

  // The collective requires a row-contiguous buffer; copy if necessary and
  // let the encoder keep the temporary alive until the kernel has run.
  auto ensure_contiguous = [&s, &encoder](const array& x) {
    if (x.flags().row_contiguous) {
      return x;
    }
    array x_copy = contiguous_copy_gpu(x, s);
    encoder.add_temporary(x_copy);
    return x_copy;
  };

  auto input = ensure_contiguous(inputs[0]);
  outputs[0].set_data(allocator::malloc(outputs[0].nbytes()));

  encoder.set_input_array(input);
  encoder.set_output_array(outputs[0]);

  auto capture = encoder.capture_context();

  switch (reduce_type_) {
    case Sum:
      distributed::detail::sum_scatter(group(), input, outputs[0], s);
      break;
    default:
      // Min/Max reduce-scatter have no backend implementation yet.
      // (Message fixed: tagged with the primitive name and trailing
      // whitespace removed, matching the other error messages.)
      throw std::runtime_error(
          "[ReduceScatter::eval_gpu] Only sum reduction is supported.");
  }
}
} // namespace mlx::core::distributed
1 change: 0 additions & 1 deletion mlx/backend/cuda/primitives.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,6 @@ NO_GPU_MULTI(Eig)
NO_GPU_MULTI(Eigh)

namespace distributed {
NO_GPU_MULTI(AllGather)
NO_GPU_MULTI(Send)
NO_GPU_MULTI(Recv)
} // namespace distributed
Expand Down
5 changes: 5 additions & 0 deletions mlx/backend/metal/distributed.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -30,4 +30,9 @@ void Recv::eval_gpu(const std::vector<array>&, std::vector<array>&) {
throw std::runtime_error("[Recv::eval_gpu] has no GPU implementation.");
}

// Metal backend: ReduceScatter has no GPU kernel, so scheduling it here is
// an error that is surfaced immediately.
void ReduceScatter::eval_gpu(const std::vector<array>&, std::vector<array>&) {
  static constexpr const char* kNoGpuImpl =
      "[ReduceScatter::eval_gpu] has no GPU implementation.";
  throw std::runtime_error(kNoGpuImpl);
}

} // namespace mlx::core::distributed
1 change: 1 addition & 0 deletions mlx/backend/no_cpu/primitives.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -138,6 +138,7 @@ NO_CPU_MULTI(AllReduce)
NO_CPU_MULTI(AllGather)
NO_CPU_MULTI(Send)
NO_CPU_MULTI(Recv)
NO_CPU_MULTI(ReduceScatter)
} // namespace distributed

} // namespace mlx::core
1 change: 1 addition & 0 deletions mlx/backend/no_gpu/primitives.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -164,6 +164,7 @@ NO_GPU_MULTI(AllReduce)
NO_GPU_MULTI(AllGather)
NO_GPU_MULTI(Send)
NO_GPU_MULTI(Recv)
NO_GPU_MULTI(ReduceScatter)
} // namespace distributed

} // namespace mlx::core
12 changes: 12 additions & 0 deletions mlx/distributed/distributed.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,14 @@ void recv(Group group, array& out, int src, Stream stream) {
group.raw_group()->recv(out, src, stream);
}

// Forward a sum reduce-scatter to the backend-specific group implementation
// (NCCL, MPI, ring, or the empty group).
void sum_scatter(
    Group group,
    const array& input,
    array& output,
    Stream stream) {
  auto&& impl = group.raw_group();
  impl->sum_scatter(input, output, stream);
}

class EmptyGroup : public GroupImpl {
public:
Stream communication_stream(StreamOrDevice s) override {
Expand Down Expand Up @@ -85,6 +93,10 @@ class EmptyGroup : public GroupImpl {
throw std::runtime_error(
"Communication not implemented in an empty distributed group.");
}
// An EmptyGroup has no peers, so reduce-scatter — like the other
// collectives on this class — is rejected with the same diagnostic.
void sum_scatter(const array&, array&, Stream) override {
  throw std::runtime_error(
      "Communication not implemented in an empty distributed group.");
}
};

} // namespace detail
Expand Down
5 changes: 5 additions & 0 deletions mlx/distributed/distributed_impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,8 @@ class GroupImpl {
virtual void recv(array& out, int src, Stream stream) = 0;
virtual void all_max(const array& input, array& output, Stream stream) = 0;
virtual void all_min(const array& input, array& output, Stream stream) = 0;
virtual void
sum_scatter(const array& input, array& output, Stream stream) = 0;
};

/* Define the MLX stream that the communication should happen in. */
Expand All @@ -51,4 +53,7 @@ void all_max(Group group, const array& input, array& output, Stream stream);
/** Min reduction */
void all_min(Group group, const array& input, array& output, Stream stream);

/** Reduce scatter with sum operation */
void sum_scatter(Group group, const array& input, array& output, Stream stream);

} // namespace mlx::core::distributed::detail
4 changes: 4 additions & 0 deletions mlx/distributed/mpi/mpi.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -465,6 +465,10 @@ class MPIGroup : public GroupImpl {
});
}

// MPI backend: reduce-scatter is not wired up to MPI_Reduce_scatter yet.
// Parameters are intentionally unnamed (like EmptyGroup's overrides) so
// builds with -Wunused-parameter stay clean until an implementation lands.
void sum_scatter(const array&, array&, Stream) override {
  throw std::runtime_error("[mpi] sum_scatter not yet implemented.");
}

private:
MPI_Comm comm_;
bool global_;
Expand Down
49 changes: 45 additions & 4 deletions mlx/distributed/nccl/nccl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -296,8 +296,17 @@ class NCCLGroup : public GroupImpl {
}

// Gather every rank's `input` into `output` with NCCL's all-gather
// collective, issued on the encoder's CUDA stream.
//
// ncclAllGather takes the *per-rank* element count, hence input.size();
// `output` holds group_size * input.size() elements.
// (Fix: the pre-existing "not supported" throw left ahead of the
// implementation was unreachable diff residue and is removed.)
void all_gather(const array& input, array& output, Stream stream) override {
  detail::dispatch_dtype(input, [&](auto type_tag, ncclDataType_t dt) {
    using T = typename decltype(type_tag)::type;
    auto& encoder = cu::get_command_encoder(stream);
    CHECK_NCCL(ncclAllGather(
        input.data<T>(),
        output.data<T>(),
        input.size(),
        dt,
        comm_,
        encoder.stream()));
  });
}

void send(const array& input, int dst, Stream stream) override {
Expand All @@ -309,11 +318,24 @@ class NCCLGroup : public GroupImpl {
}

// Elementwise max across all ranks via an NCCL allreduce with ncclMax.
// (Fix: the stale "not supported" throw ahead of the implementation was
// unreachable diff residue and is removed.)
void all_max(const array& input, array& output, Stream stream) override {
  detail::dispatch_dtype(input, [&](auto type_tag, ncclDataType_t dt) {
    using T = typename decltype(type_tag)::type;
    all_reduce_impl<T>(input, output, stream, dt, ncclMax);
  });
}

// Elementwise min across all ranks via an NCCL allreduce with ncclMin.
// (Fix: the stale "not supported" throw ahead of the implementation was
// unreachable diff residue and is removed.)
void all_min(const array& input, array& output, Stream stream) override {
  detail::dispatch_dtype(input, [&](auto type_tag, ncclDataType_t dt) {
    using T = typename decltype(type_tag)::type;
    all_reduce_impl<T>(input, output, stream, dt, ncclMin);
  });
}

// Reduce-scatter with a sum reduction: every rank contributes `input`, the
// elementwise sum is computed across the group, and each rank receives one
// equal contiguous shard of the result in `output`.
void sum_scatter(const array& input, array& output, Stream stream) override {
  // Dispatch on dtype so the collective is issued with the matching
  // ncclDataType_t and element type T.
  detail::dispatch_dtype(input, [&](auto type_tag, ncclDataType_t dt) {
    using T = typename decltype(type_tag)::type;
    reduce_scatter_impl<T>(input, output, stream, dt, ncclSum);
  });
}

template <typename T>
Expand All @@ -335,6 +357,25 @@ class NCCLGroup : public GroupImpl {
encoder.stream()));
}

// Issue a single ncclReduceScatter on the encoder's CUDA stream.
//
// NCCL's recvcount parameter is the number of elements *each rank receives*,
// which is exactly output.size(); the input is therefore expected to hold
// group_size * output.size() elements (the op layer validates divisibility
// of the leading axis before constructing the primitive).
template <typename T>
void reduce_scatter_impl(
    const array& input,
    array& output,
    Stream stream,
    ncclDataType_t dt,  // NCCL dtype matching T
    ncclRedOp_t op) {   // reduction (currently only ncclSum is used)
  auto& encoder = cu::get_command_encoder(stream);

  CHECK_NCCL(ncclReduceScatter(
      input.data<T>(),
      output.data<T>(),
      output.size(),  // per-rank receive count
      dt,
      op,
      comm_,
      encoder.stream()));
}

int rank_, size_;
std::string initMethod_;
ncclUniqueId uniqueId_;
Expand Down
26 changes: 26 additions & 0 deletions mlx/distributed/ops.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -157,4 +157,30 @@ array recv_like(
return recv(x.shape(), x.dtype(), src, group_, s);
}

// Sum `x` across all processes in the group and scatter the result so each
// rank keeps one equal shard along axis 0.
//
// Args:
//   x: Input array; shape[0] must be divisible by the group size.
//   group_: Optional process group (global group when nullopt).
//   s: Stream or device for the communication.
// Returns: an array of shape [x.shape[0] / group.size(), *x.shape[1:]].
// Throws: std::invalid_argument on a 0-dim input or a non-divisible axis 0.
array sum_scatter(
    const array& x,
    std::optional<Group> group_ /* = std::nullopt */,
    StreamOrDevice s /* = {} */) {
  auto group = to_group(group_);

  // A single process already holds the full sum; nothing to communicate.
  if (group.size() == 1) {
    return x;
  }

  // Guard before indexing shape()[0]: a 0-dim array has no axis to scatter.
  if (x.ndim() == 0) {
    throw std::invalid_argument(
        "[sum_scatter] Input must have at least one dimension.");
  }
  if (x.shape()[0] % group.size() != 0) {
    std::ostringstream msg;
    msg << "[sum_scatter] Invalid shape=" << x.shape()
        << " for a group of size " << group.size()
        << ". The first dimension (axis 0) must be divisible by the group size.";
    throw std::invalid_argument(msg.str());
  }

  // Each rank receives 1/group.size() of axis 0.
  auto result_shape = x.shape();
  result_shape[0] /= group.size();
  auto stream = detail::communication_stream(group, s);

  return array(
      std::move(result_shape),
      x.dtype(),
      std::make_shared<ReduceScatter>(stream, group, ReduceScatter::Sum),
      {x});
}
} // namespace mlx::core::distributed
5 changes: 5 additions & 0 deletions mlx/distributed/ops.h
Original file line number Diff line number Diff line change
Expand Up @@ -48,4 +48,9 @@ array all_min(
std::optional<Group> group = std::nullopt,
StreamOrDevice s = {});

array sum_scatter(
const array& x,
std::optional<Group> group = std::nullopt,
StreamOrDevice s = {});

} // namespace mlx::core::distributed
26 changes: 26 additions & 0 deletions mlx/distributed/primitives.h
Original file line number Diff line number Diff line change
Expand Up @@ -127,4 +127,30 @@ class Recv : public DistPrimitive {
int src_;
};

// Distributed primitive that sums (or, in the future, min/max-reduces) an
// array across a group and scatters equal shards of the result to each rank.
class ReduceScatter : public DistPrimitive {
 public:
  // Elementwise reduction applied before scattering.
  enum ReduceType { Sum, Min, Max };

  ReduceScatter(Stream stream, Group group, ReduceType reduce_type)
      : DistPrimitive(stream, group), reduce_type_(reduce_type) {}

  void eval_cpu(const std::vector<array>& inputs, std::vector<array>& outputs)
      override;
  void eval_gpu(const std::vector<array>& inputs, std::vector<array>& outputs)
      override;

  // Human-readable name used in graph printing and error messages.
  const char* name() const override {
    switch (reduce_type_) {
      case Sum:
        return "Sum ReduceScatter";
      case Min:
        return "Min ReduceScatter";
      case Max:
        return "Max ReduceScatter";
    }
    // Unreachable for valid enum values; fixed typo ("unknwon" -> "unknown").
    return "<unknown ReduceScatter>";
  }

 private:
  ReduceType reduce_type_;
};
} // namespace mlx::core::distributed
4 changes: 4 additions & 0 deletions mlx/distributed/ring/ring.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -731,6 +731,10 @@ class RingGroup : public GroupImpl {
});
}

// Ring backend: reduce-scatter is not implemented. Parameters are unnamed
// (matching EmptyGroup's stub overrides) so -Wunused-parameter builds stay
// clean until an implementation exists.
void sum_scatter(const array&, array&, Stream) override {
  throw std::runtime_error("[ring] sum_scatter not supported.");
}

private:
template <typename T, typename ReduceOp>
void all_reduce(
Expand Down
34 changes: 33 additions & 1 deletion python/src/distributed.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -229,7 +229,7 @@ void init_distributed(nb::module_& parent_module) {
x (array): Input array.
dst (int): Rank of the destination process in the group.
group (Group): The group of processes that will participate in the
sned. If set to ``None`` the global group is used. Default:
send. If set to ``None`` the global group is used. Default:
``None``.
stream (Stream, optional): Stream or device. Defaults to ``None``
in which case the default stream of the default device is used.
Expand Down Expand Up @@ -301,4 +301,36 @@ void init_distributed(nb::module_& parent_module) {
Returns:
array: The array that was received from ``src``.
)pbdoc");

// Python binding for mx.distributed.sum_scatter (see mlx/distributed/ops.h).
// The pbdoc raw string below is user-visible documentation text and is left
// byte-for-byte unchanged.
m.def(
    "sum_scatter",
    [](const ScalarOrArray& x,
       std::optional<mx::distributed::Group> group,
       mx::StreamOrDevice s) {
      return mx::distributed::sum_scatter(to_array(x), group, s);
    },
    "x"_a,
    nb::kw_only(),
    "group"_a = nb::none(),
    "stream"_a = nb::none(),
    nb::sig(
        "def sum_scatter(x: array, *, group: Optional[Group] = None, stream: Union[None, Stream, Device] = None) -> array"),
    R"pbdoc(
Sum ``x`` across all processes in the group and shard the result along the first axis across ranks.
``x.shape[0]`` must be divisible by the group size.

The result is equivalent to ``all_sum(x)[rank*chunk_size:(rank+1)*chunk_size]``, where ``chunk_size = x.shape[0] // group.size()`` and ``rank`` is the rank of this process in the group.
Note: ``all_sum`` is mentioned only for illustration; the actual implementation does not perform ``all_sum`` and uses a single reduce-scatter collective instead.
Currently supported only for the NCCL backend.

Args:
x (array): Input array.
group (Group): The group of processes that will participate in the
sum scatter. If set to ``None`` the global group is used. Default:
``None``.
stream (Stream, optional): Stream or device. Defaults to ``None``
in which case the default stream of the default device is used.
Returns:
array: The output array with shape ``[x.shape[0] // group.size(), *x.shape[1:]]``.
)pbdoc");
}
Loading