diff --git a/mlx/backend/cpu/distributed.cpp b/mlx/backend/cpu/distributed.cpp index d641d581ba..22dc4b4cc8 100644 --- a/mlx/backend/cpu/distributed.cpp +++ b/mlx/backend/cpu/distributed.cpp @@ -95,4 +95,9 @@ void Recv::eval_cpu( distributed::detail::recv(group(), outputs[0], src_, stream()); } +void ReduceScatter::eval_cpu( + const std::vector& inputs, + std::vector& outputs) { + throw std::runtime_error("[ReduceScatter] Not implemented yet."); +} } // namespace mlx::core::distributed diff --git a/mlx/backend/cuda/distributed.cu b/mlx/backend/cuda/distributed.cu index 4d2658534f..07a5ed10f7 100644 --- a/mlx/backend/cuda/distributed.cu +++ b/mlx/backend/cuda/distributed.cu @@ -53,4 +53,69 @@ void AllReduce::eval_gpu( "Only all reduce sum, max, and min are supported."); } } + +void AllGather::eval_gpu( + const std::vector& inputs, + std::vector& outputs) { + assert(inputs.size() == 1); + assert(outputs.size() == 1); + + auto& s = stream(); + auto& encoder = cu::get_command_encoder(s); + + auto ensure_contiguous = [&s, &encoder](const array& x) { + if (x.flags().row_contiguous) { + return x; + } else { + array x_copy = contiguous_copy_gpu(x, s); + encoder.add_temporary(x_copy); + return x_copy; + } + }; + + auto input = ensure_contiguous(inputs[0]); + outputs[0].set_data(allocator::malloc(outputs[0].nbytes())); + + encoder.set_input_array(input); + encoder.set_output_array(outputs[0]); + + auto capture = encoder.capture_context(); + distributed::detail::all_gather(group(), input, outputs[0], s); +} + +void ReduceScatter::eval_gpu( + const std::vector& inputs, + std::vector& outputs) { + assert(inputs.size() == 1); + assert(outputs.size() == 1); + + auto& s = stream(); + auto& encoder = cu::get_command_encoder(s); + + auto ensure_contiguous = [&s, &encoder](const array& x) { + if (x.flags().row_contiguous) { + return x; + } else { + array x_copy = contiguous_copy_gpu(x, s); + encoder.add_temporary(x_copy); + return x_copy; + } + }; + + auto input = ensure_contiguous(inputs[0]); + outputs[0].set_data(allocator::malloc(outputs[0].nbytes())); + + encoder.set_input_array(input); + encoder.set_output_array(outputs[0]); + + auto capture = encoder.capture_context(); + + switch (reduce_type_) { + case Sum: + distributed::detail::sum_scatter(group(), input, outputs[0], s); + break; + default: + throw std::runtime_error("Only sum scatter is supported. "); + } +} } // namespace mlx::core::distributed diff --git a/mlx/backend/cuda/primitives.cpp b/mlx/backend/cuda/primitives.cpp index 77c2956655..43a60eedbb 100644 --- a/mlx/backend/cuda/primitives.cpp +++ b/mlx/backend/cuda/primitives.cpp @@ -40,7 +40,6 @@ NO_GPU_MULTI(Eig) NO_GPU_MULTI(Eigh) namespace distributed { -NO_GPU_MULTI(AllGather) NO_GPU_MULTI(Send) NO_GPU_MULTI(Recv) } // namespace distributed diff --git a/mlx/backend/metal/distributed.cpp b/mlx/backend/metal/distributed.cpp index a800d2e0fe..217ee3c946 100644 --- a/mlx/backend/metal/distributed.cpp +++ b/mlx/backend/metal/distributed.cpp @@ -30,4 +30,9 @@ void Recv::eval_gpu(const std::vector&, std::vector&) { throw std::runtime_error("[Recv::eval_gpu] has no GPU implementation."); } +void ReduceScatter::eval_gpu(const std::vector&, std::vector&) { + throw std::runtime_error( + "[ReduceScatter::eval_gpu] has no GPU implementation."); +} + } // namespace mlx::core::distributed diff --git a/mlx/backend/no_cpu/primitives.cpp b/mlx/backend/no_cpu/primitives.cpp index 4d373bd1a2..b32e074e8f 100644 --- a/mlx/backend/no_cpu/primitives.cpp +++ b/mlx/backend/no_cpu/primitives.cpp @@ -138,6 +138,7 @@ NO_CPU_MULTI(AllReduce) NO_CPU_MULTI(AllGather) NO_CPU_MULTI(Send) NO_CPU_MULTI(Recv) +NO_CPU_MULTI(ReduceScatter) } // namespace distributed } // namespace mlx::core diff --git a/mlx/backend/no_gpu/primitives.cpp b/mlx/backend/no_gpu/primitives.cpp index a57df046cd..406a627b9c 100644 --- a/mlx/backend/no_gpu/primitives.cpp +++ b/mlx/backend/no_gpu/primitives.cpp @@ -164,6 +164,7 @@ NO_GPU_MULTI(AllReduce) NO_GPU_MULTI(AllGather) NO_GPU_MULTI(Send) NO_GPU_MULTI(Recv) +NO_GPU_MULTI(ReduceScatter) } // namespace distributed } // namespace mlx::core diff --git a/mlx/distributed/distributed.cpp b/mlx/distributed/distributed.cpp index d71ebb9b12..2f5ea80292 100644 --- a/mlx/distributed/distributed.cpp +++ b/mlx/distributed/distributed.cpp @@ -41,6 +41,14 @@ void recv(Group group, array& out, int src, Stream stream) { group.raw_group()->recv(out, src, stream); } +void sum_scatter( + Group group, + const array& input, + array& output, + Stream stream) { + group.raw_group()->sum_scatter(input, output, stream); +} + class EmptyGroup : public GroupImpl { public: Stream communication_stream(StreamOrDevice s) override { @@ -85,6 +93,10 @@ class EmptyGroup : public GroupImpl { throw std::runtime_error( "Communication not implemented in an empty distributed group."); } + void sum_scatter(const array&, array&, Stream) override { + throw std::runtime_error( + "Communication not implemented in an empty distributed group."); + } }; } // namespace detail diff --git a/mlx/distributed/distributed_impl.h b/mlx/distributed/distributed_impl.h index c90b0ba47b..d889587abc 100644 --- a/mlx/distributed/distributed_impl.h +++ b/mlx/distributed/distributed_impl.h @@ -28,6 +28,8 @@ class GroupImpl { virtual void recv(array& out, int src, Stream stream) = 0; virtual void all_max(const array& input, array& output, Stream stream) = 0; virtual void all_min(const array& input, array& output, Stream stream) = 0; + virtual void + sum_scatter(const array& input, array& output, Stream stream) = 0; }; /* Define the MLX stream that the communication should happen in. */ @@ -51,4 +53,7 @@ void all_max(Group group, const array& input, array& output, Stream stream); /** Min reduction */ void all_min(Group group, const array& input, array& output, Stream stream); +/** Reduce scatter with average operation */ +void sum_scatter(Group group, const array& input, array& output, Stream stream); + } // namespace mlx::core::distributed::detail diff --git a/mlx/distributed/mpi/mpi.cpp b/mlx/distributed/mpi/mpi.cpp index 494fb02dcc..bf87425e48 100644 --- a/mlx/distributed/mpi/mpi.cpp +++ b/mlx/distributed/mpi/mpi.cpp @@ -465,6 +465,10 @@ class MPIGroup : public GroupImpl { }); } + void sum_scatter(const array& input, array& output, Stream stream) override { + throw std::runtime_error("[mpi] sum_scatter not yet implemented."); + } + private: MPI_Comm comm_; bool global_; diff --git a/mlx/distributed/nccl/nccl.cpp b/mlx/distributed/nccl/nccl.cpp index 8a5376242f..71fc8b3bd3 100644 --- a/mlx/distributed/nccl/nccl.cpp +++ b/mlx/distributed/nccl/nccl.cpp @@ -296,8 +296,17 @@ class NCCLGroup : public GroupImpl { } void all_gather(const array& input, array& output, Stream stream) override { - throw std::runtime_error( - "[nccl] All gather not supported in NCCL backend."); + detail::dispatch_dtype(input, [&](auto type_tag, ncclDataType_t dt) { + using T = typename decltype(type_tag)::type; + auto& encoder = cu::get_command_encoder(stream); + CHECK_NCCL(ncclAllGather( + input.data(), + output.data(), + input.size(), + dt, + comm_, + encoder.stream())); + }); } void send(const array& input, int dst, Stream stream) override { @@ -309,11 +318,24 @@ class NCCLGroup : public GroupImpl { } void all_max(const array& input, array& output, Stream stream) override { - throw std::runtime_error("[nccl] All max not supported in NCCL backend."); + detail::dispatch_dtype(input, [&](auto type_tag, ncclDataType_t dt) { + using T = typename decltype(type_tag)::type; + all_reduce_impl(input, output, stream, dt, ncclMax); + }); } void all_min(const array& input, array& output, Stream stream) override { - throw std::runtime_error("[nccl] All min not supported in NCCL backend."); + detail::dispatch_dtype(input, [&](auto type_tag, ncclDataType_t dt) { + using T = typename decltype(type_tag)::type; + all_reduce_impl(input, output, stream, dt, ncclMin); + }); + } + + void sum_scatter(const array& input, array& output, Stream stream) override { + detail::dispatch_dtype(input, [&](auto type_tag, ncclDataType_t dt) { + using T = typename decltype(type_tag)::type; + reduce_scatter_impl(input, output, stream, dt, ncclSum); + }); } template @@ -335,6 +357,25 @@ class NCCLGroup : public GroupImpl { encoder.stream())); } + template + void reduce_scatter_impl( + const array& input, + array& output, + Stream stream, + ncclDataType_t dt, + ncclRedOp_t op) { + auto& encoder = cu::get_command_encoder(stream); + + CHECK_NCCL(ncclReduceScatter( + input.data(), + output.data(), + output.size(), + dt, + op, + comm_, + encoder.stream())); + } + int rank_, size_; std::string initMethod_; ncclUniqueId uniqueId_; diff --git a/mlx/distributed/ops.cpp b/mlx/distributed/ops.cpp index 157bc26129..1762f0e6bc 100644 --- a/mlx/distributed/ops.cpp +++ b/mlx/distributed/ops.cpp @@ -157,4 +157,30 @@ array recv_like( return recv(x.shape(), x.dtype(), src, group_, s); } +array sum_scatter( + const array& x, + std::optional group_ /* = std::nullopt */, + StreamOrDevice s /* = {} */) { + auto group = to_group(group_); + if (group.size() == 1) { + return x; + } + if (x.shape()[0] % group.size() != 0) { + std::ostringstream msg; + msg << "[sum_scatter] Invalid shape=" << x.shape() + << " for a group of size " << group.size() + << ". The first dimension (axis 0) must be divisible by the group size."; + throw std::invalid_argument(msg.str()); + } + + auto result_shape = x.shape(); + result_shape[0] /= group.size(); + auto stream = detail::communication_stream(group, s); + + return array( + std::move(result_shape), + x.dtype(), + std::make_shared(stream, group, ReduceScatter::Sum), + {x}); +} } // namespace mlx::core::distributed diff --git a/mlx/distributed/ops.h b/mlx/distributed/ops.h index edd1fc9f40..7688a5f1c2 100644 --- a/mlx/distributed/ops.h +++ b/mlx/distributed/ops.h @@ -48,4 +48,9 @@ array all_min( std::optional group = std::nullopt, StreamOrDevice s = {}); +array sum_scatter( + const array& x, + std::optional group = std::nullopt, + StreamOrDevice s = {}); + } // namespace mlx::core::distributed diff --git a/mlx/distributed/primitives.h b/mlx/distributed/primitives.h index 7ad00a0d63..18a0d65f5f 100644 --- a/mlx/distributed/primitives.h +++ b/mlx/distributed/primitives.h @@ -127,4 +127,30 @@ class Recv : public DistPrimitive { int src_; }; +class ReduceScatter : public DistPrimitive { + public: + enum ReduceType { Sum, Min, Max }; + ReduceScatter(Stream stream, Group group, ReduceType reduce_type) + : DistPrimitive(stream, group), reduce_type_(reduce_type) {} + + void eval_cpu(const std::vector& inputs, std::vector& outputs) + override; + void eval_gpu(const std::vector& inputs, std::vector& outputs) + override; + + const char* name() const override { + switch (reduce_type_) { + case Sum: + return "Sum ReduceScatter"; + case Min: + return "Min ReduceScatter"; + case Max: + return "Max ReduceScatter"; + } + return ""; + } + + private: + ReduceType reduce_type_; +}; } // namespace mlx::core::distributed diff --git a/mlx/distributed/ring/ring.cpp b/mlx/distributed/ring/ring.cpp index ac55ea30b0..23537c4d74 100644 --- a/mlx/distributed/ring/ring.cpp +++ b/mlx/distributed/ring/ring.cpp @@ -731,6 +731,10 @@ class RingGroup : public GroupImpl { }); } + void sum_scatter(const array& input, array& output, Stream stream) override { + throw std::runtime_error("[ring] sum_scatter not supported."); + } + private: template void all_reduce( diff --git a/python/src/distributed.cpp b/python/src/distributed.cpp index e2e191dbbd..d147c27836 100644 --- a/python/src/distributed.cpp +++ b/python/src/distributed.cpp @@ -229,7 +229,7 @@ void init_distributed(nb::module_& parent_module) { x (array): Input array. dst (int): Rank of the destination process in the group. group (Group): The group of processes that will participate in the - sned. If set to ``None`` the global group is used. Default: + send. If set to ``None`` the global group is used. Default: ``None``. stream (Stream, optional): Stream or device. Defaults to ``None`` in which case the default stream of the default device is used. @@ -301,4 +301,36 @@ void init_distributed(nb::module_& parent_module) { Returns: array: The array that was received from ``src``. )pbdoc"); + + m.def( + "sum_scatter", + [](const ScalarOrArray& x, + std::optional group, + mx::StreamOrDevice s) { + return mx::distributed::sum_scatter(to_array(x), group, s); + }, + "x"_a, + nb::kw_only(), + "group"_a = nb::none(), + "stream"_a = nb::none(), + nb::sig( + "def sum_scatter(x: array, *, group: Optional[Group] = None, stream: Union[None, Stream, Device] = None) -> array"), + R"pbdoc( + Sum ``x`` across all processes in the group and shard the result along the first axis across ranks. + ``x.shape[0]`` must be divisible by the group size. + + The result is equivalent to ``all_sum(x)[rank*chunk_size:(rank+1)*chunk_size]``, where ``chunk_size = x.shape[0] // group.size()`` and ``rank`` is the rank of this process in the group. + Note: ``all_sum`` is mentioned only for illustration; the actual implementation does not perform ``all_sum`` and uses a single reduce-scatter collective instead. + Currently supported only for the NCCL backend. + + Args: + x (array): Input array. + group (Group): The group of processes that will participate in the + sum scatter. If set to ``None`` the global group is used. Default: + ``None``. + stream (Stream, optional): Stream or device. Defaults to ``None`` + in which case the default stream of the default device is used. + Returns: + array: The output array with shape ``[x.shape[0] // group.size(), *x.shape[1:]]``. + )pbdoc"); } diff --git a/python/tests/mlx_distributed_tests.py b/python/tests/mlx_distributed_tests.py index 5feb51bc9c..8db4566625 100644 --- a/python/tests/mlx_distributed_tests.py +++ b/python/tests/mlx_distributed_tests.py @@ -1,7 +1,5 @@ # Copyright © 2025 Apple Inc. -import unittest - import mlx.core as mx import mlx.nn as nn import mlx_tests @@ -63,6 +61,48 @@ def new_all_sum(x, **kwargs): finally: mx.distributed.all_sum = original_all_sum + def test_all_reduce(self): + g = mx.distributed.init() + dtypes = [ + (mx.int8, 0), + (mx.uint8, 0), + (mx.int32, 0), + (mx.uint32, 0), + (mx.float32, 1e-6), + (mx.float16, 5e-3), + (mx.bfloat16, 1e-1), + ] + sizes = [ + (7,), + (10,), + (1024,), + (1024, 1024), + ] + key = mx.random.key(0) + + for dt, rtol in dtypes: + for sh in sizes: + x = (mx.random.uniform(shape=(g.size(),) + sh, key=key) * 10).astype(dt) + + # All sum + y = mx.distributed.all_sum(x[g.rank()], group=g) + z = x.sum(0) + maxrelerror = (y - z).abs() + if rtol > 0: + maxrelerror /= z.abs() + maxrelerror = maxrelerror.max() + self.assertLessEqual(maxrelerror, rtol) + + # All max + y = mx.distributed.all_max(x[g.rank()], group=g) + z = x.max(0) + self.assertTrue(mx.all(y == z)) + + # All min + y = mx.distributed.all_min(x[g.rank()], group=g) + z = x.min(0) + self.assertTrue(mx.all(y == z)) + def test_donation(self): x = mx.random.normal((1024,)) mx.eval(x) @@ -103,18 +143,19 @@ def test_shard_linear(self): y = lin(x) y1 = slin1(x) y2 = slin2(x[part]) - self.assertTrue(mx.allclose(y, y2, atol=1e-6, rtol=1e-4)) - self.assertTrue(mx.allclose(y[part], y1)) - - # And their quant versions - qlin = lin.to_quantized() - slin1 = shard_linear(qlin, "all-to-sharded") - slin2 = shard_linear(qlin, "sharded-to-all") - y = qlin(x) - y1 = slin1(x) - y2 = slin2(x[part]) - self.assertTrue(mx.allclose(y, y2, atol=1e-6, rtol=1e-4)) - self.assertTrue(mx.allclose(y[part], y1)) + self.assertTrue(mx.allclose(y, y2, atol=self.atol, rtol=self.rtol)) + self.assertTrue(mx.allclose(y[part], y1, atol=self.atol, rtol=self.rtol)) + + # And their quant versions (QuintizedMatmul is not supported on CUDA) + if not mx.cuda.is_available(): + qlin = lin.to_quantized() + slin1 = shard_linear(qlin, "all-to-sharded") + slin2 = shard_linear(qlin, "sharded-to-all") + y = qlin(x) + y1 = slin1(x) + y2 = slin2(x[part]) + self.assertTrue(mx.allclose(y, y2, atol=self.atol, rtol=self.rtol)) + self.assertTrue(mx.allclose(y[part], y1)) # Check the backward works as expected def dummy_loss(model, x, y): @@ -197,12 +238,18 @@ def dummy_loss(model, x, y): ) self.assertTrue( mx.allclose( - g1["layers"][1]["bias"], g2["layers"][1]["bias"], atol=1e-6, rtol=1e-4 + g1["layers"][1]["bias"], + g2["layers"][1]["bias"], + atol=self.atol, + rtol=self.rtol, ) ) self.assertTrue( mx.allclose( - g1["layers"][3]["bias"], g2["layers"][3]["bias"], atol=1e-6, rtol=1e-4 + g1["layers"][3]["bias"], + g2["layers"][3]["bias"], + atol=self.atol, + rtol=self.rtol, ) ) @@ -248,3 +295,20 @@ def sharding(path, weight): y1 = mod(x) y2 = smod(x) self.assertTrue(mx.allclose(y1, y2, atol=1e-6, rtol=1e-4)) + + def test_all_gather(self): + world = mx.distributed.init() + dtypes = [ + mx.int8, + mx.uint8, + mx.int32, + mx.uint32, + mx.float32, + mx.float16, + mx.bfloat16, + ] + for dt in dtypes: + x = mx.ones((2, 2, 4), dtype=dt) + y = mx.distributed.all_gather(x) + self.assertEqual(y.shape, (world.size() * 2, 2, 4)) + self.assertTrue(mx.all(y == 1)) diff --git a/python/tests/mpi_test_distributed.py b/python/tests/mpi_test_distributed.py index 26d340dbea..e8a6aaa615 100644 --- a/python/tests/mpi_test_distributed.py +++ b/python/tests/mpi_test_distributed.py @@ -1,15 +1,16 @@ # Copyright © 2024 Apple Inc. -import unittest - import mlx.core as mx import mlx_distributed_tests +import mlx_tests class TestMPIDistributed(mlx_distributed_tests.MLXDistributedCommonTestCase): @classmethod def setUpClass(cls): - world = mx.distributed.init(strict=True, backend="mpi") + _ = mx.distributed.init(strict=True, backend="mpi") + cls.atol = 1e-6 + cls.rtol = 1e-4 def test_groups(self): world = mx.distributed.init() @@ -27,18 +28,11 @@ def test_groups(self): sub = world.split(world.rank() // 2) self.assertEqual(sub.size(), 2) - def test_all_reduce(self): + def test_all_reduce_extra(self): world = mx.distributed.init() dtypes = [ - (mx.int8, 0), - (mx.uint8, 0), (mx.int16, 0), (mx.uint16, 0), - (mx.int32, 0), - (mx.uint32, 0), - (mx.float32, 1e-6), - (mx.float16, 5e-3), - (mx.bfloat16, 1e-1), (mx.complex64, 1e-6), ] sizes = [ @@ -76,16 +70,11 @@ def test_all_reduce(self): z = x.min(0) self.assertTrue(mx.all(y == z)) - def test_all_gather(self): + def test_all_gather_extra(self): world = mx.distributed.init() dtypes = [ - mx.int8, - mx.uint8, mx.int16, mx.uint16, - mx.int32, - mx.uint32, - mx.float32, mx.complex64, ] for dt in dtypes: @@ -150,4 +139,4 @@ def test_send_recv(self): if __name__ == "__main__": - unittest.main() + mlx_tests.MLXTestRunner() diff --git a/python/tests/nccl_test_distributed.py b/python/tests/nccl_test_distributed.py index c9461d118c..e6eed8a99c 100644 --- a/python/tests/nccl_test_distributed.py +++ b/python/tests/nccl_test_distributed.py @@ -1,283 +1,52 @@ # Copyright © 2024 Apple Inc. + import mlx.core as mx -import mlx.nn as nn +import mlx_distributed_tests import mlx_tests -from mlx.nn.layers.distributed import shard_inplace, shard_linear -from mlx.nn.utils import average_gradients -class TestNCCLDistributed(mlx_tests.MLXTestCase): +class TestNCCLDistributed(mlx_distributed_tests.MLXDistributedCommonTestCase): @classmethod def setUpClass(cls): - world = mx.distributed.init(strict=True, backend="nccl") - rank = world.rank() - mx.set_default_device(mx.Device(mx.gpu, rank % 8)) + _ = mx.distributed.init(strict=True, backend="nccl") + cls.atol = 1e-4 + cls.rtol = 1e-4 + + def test_sum_scatter(self): - def test_all_reduce(self): world = mx.distributed.init() + dtypes = [ - (mx.int8, 0), - (mx.uint8, 0), - (mx.int32, 0), - (mx.uint32, 0), (mx.float32, 1e-6), (mx.float16, 5e-3), (mx.bfloat16, 1e-1), ] sizes = [ - (7,), - (10,), + (8,), + (64,), (1024,), (1024, 1024), ] - key = mx.random.key(0) + key = mx.random.key(world.rank()) for dt, rtol in dtypes: for sh in sizes: - x = ( - mx.random.uniform(shape=(world.size(),) + sh, key=key) * 10 - ).astype(dt) + x = (mx.random.uniform(shape=sh, key=key) * 10).astype(dt) # shape=sh - # All sum - y = mx.distributed.all_sum(x[world.rank()]) - z = x.sum(0) - maxrelerror = (y - z).abs() + # Sum scatter + y = mx.distributed.sum_scatter(x) # shape=sh/world.size() + z = mx.distributed.all_sum(x) # shape=sh + chunk = sh[0] // world.size() + start = world.rank() * chunk + stop = start + chunk + z_ref = z[start:stop] + + maxrelerror = (y - z_ref).abs() if rtol > 0: - maxrelerror /= z.abs() + maxrelerror /= z_ref.abs() maxrelerror = maxrelerror.max() self.assertLessEqual(maxrelerror, rtol) - def test_average_gradients(self): - original_all_sum = mx.distributed.all_sum - n_calls = 0 - xtype = None - - def new_all_sum(x, **kwargs): - nonlocal n_calls - nonlocal xtype - - n_calls += 1 - if xtype is not None: - self.assertEqual(xtype, x.dtype) - - return original_all_sum(x, **kwargs) - - mx.distributed.all_sum = new_all_sum - try: - grads = [mx.ones(10) for i in range(10)] - new_grads = average_gradients(grads) - mx.eval(new_grads) - self.assertEqual(len(new_grads), 10) - self.assertTrue(all(mx.all(g == 1) for g in new_grads)) - self.assertEqual(n_calls, 1) - - n_calls = 0 - new_grads = average_gradients(grads, all_reduce_size=4 * 50) - mx.eval(new_grads) - self.assertEqual(len(new_grads), 10) - self.assertTrue(all(mx.all(g == 1) for g in new_grads)) - self.assertEqual(n_calls, 2) - - n_calls = 0 - new_grads = average_gradients(grads, all_reduce_size=0) - mx.eval(new_grads) - self.assertEqual(len(new_grads), 10) - self.assertTrue(all(mx.all(g == 1) for g in new_grads)) - self.assertEqual(n_calls, 10) - - n_calls = 0 - xtype = mx.float16 - new_grads = average_gradients( - grads, - all_reduce_size=2 * 50, - communication_type=mx.float16, - ) - mx.eval(new_grads) - self.assertEqual(len(new_grads), 10) - self.assertTrue(all(g.dtype == mx.float32 for g in new_grads)) - self.assertTrue(all(mx.all(g == 1) for g in new_grads)) - self.assertEqual(n_calls, 2) - - finally: - mx.distributed.all_sum = original_all_sum - - def test_donation(self): - x = mx.random.normal((1024,)) - mx.eval(x) - mx.synchronize() - - mx.reset_peak_memory() - scale = mx.array(2.0) - y = mx.distributed.all_sum(x) - mx.eval(y) - mx.synchronize() - all_sum_only = mx.get_peak_memory() - y = mx.distributed.all_sum(x) * scale - mx.eval(y) - mx.synchronize() - all_sum_with_binary = mx.get_peak_memory() - - self.assertEqual(all_sum_only, all_sum_with_binary) - - def test_shard_linear(self): - # Seed the prng to have the same inputs and weights generated everywhere - mx.random.seed(0xF0F0F0F0) - - # Prepare inputs - world = mx.distributed.init() - part = ( - slice(None), - slice( - world.rank() * 1024 // world.size(), - (world.rank() + 1) * 1024 // world.size(), - ), - ) - x = mx.random.normal((4, 1024)) - - # Create and shard some linear layers - lin = nn.Linear(1024, 1024, bias=True) - slin1 = shard_linear(lin, "all-to-sharded") - slin2 = shard_linear(lin, "sharded-to-all") - y = lin(x) - y1 = slin1(x) - y2 = slin2(x[part]) - self.assertTrue(mx.allclose(y, y2, atol=1e-4, rtol=1e-4)) - self.assertTrue(mx.allclose(y[part], y1, atol=1e-4, rtol=1e-4)) - - # Check the backward works as expected - def dummy_loss(model, x, y): - return (model(x) * y).sum() - - mod = nn.Sequential( - nn.Linear(128, 128), - nn.Linear(128, 128), - nn.Linear(128, 128), - nn.Linear(128, 128), - ) - smod = nn.Sequential( - shard_linear(mod.layers[0], "all-to-sharded"), - shard_linear(mod.layers[1], "sharded-to-all"), - shard_linear(mod.layers[2], "all-to-sharded"), - shard_linear(mod.layers[3], "sharded-to-all"), - ) - - grad1 = nn.value_and_grad(mod, dummy_loss) - grad2 = nn.value_and_grad(smod, dummy_loss) - - x = mx.random.normal((4, 128)) - y = mx.random.normal((4, 128)) - - l1, g1 = grad1(mod, x, y) - l2, g2 = grad2(smod, x, y) - mx.eval(l1, g1, l2, g2) - - part = slice( - world.rank() * 128 // world.size(), (world.rank() + 1) * 128 // world.size() - ) - - self.assertTrue(mx.allclose(l1, l2)) - self.assertTrue( - mx.allclose( - g1["layers"][0]["weight"][part], - g2["layers"][0]["weight"], - atol=1e-6, - rtol=1e-4, - ) - ) - self.assertTrue( - mx.allclose( - g1["layers"][2]["weight"][part], - g2["layers"][2]["weight"], - atol=1e-6, - rtol=1e-4, - ) - ) - self.assertTrue( - mx.allclose( - g1["layers"][1]["weight"][:, part], - g2["layers"][1]["weight"], - atol=1e-6, - rtol=1e-4, - ) - ) - self.assertTrue( - mx.allclose( - g1["layers"][3]["weight"][:, part], - g2["layers"][3]["weight"], - atol=1e-6, - rtol=1e-4, - ) - ) - self.assertTrue( - mx.allclose( - g1["layers"][0]["bias"][part], - g2["layers"][0]["bias"], - atol=1e-6, - rtol=1e-4, - ) - ) - self.assertTrue( - mx.allclose( - g1["layers"][2]["bias"][part], - g2["layers"][2]["bias"], - atol=1e-6, - rtol=1e-4, - ) - ) - self.assertTrue( - mx.allclose( - g1["layers"][1]["bias"], g2["layers"][1]["bias"], atol=1e-6, rtol=1e-4 - ) - ) - self.assertTrue( - mx.allclose( - g1["layers"][3]["bias"], g2["layers"][3]["bias"], atol=1e-6, rtol=1e-4 - ) - ) - - def test_shard_predicate(self): - mx.random.seed(0xF0F0F0F0) - - class MyConv(nn.Module): - def __init__(self, *args, **kwargs): - super().__init__() - self.aggregate = kwargs.pop("aggregate", False) - self.conv = nn.Conv2d(*args, **kwargs) - - def __call__(self, x): - x = self.conv(x) - if self.aggregate: - x = mx.distributed.all_sum(x) - return x - - def sharding(path, weight): - parts = path.split(".") - even = int(parts[1]) % 2 == 0 - if even: - return 0 - else: - return -1 if parts[-1] != "bias" else None - - mod = nn.Sequential( - MyConv(3, 128, kernel_size=3), - MyConv(128, 128, kernel_size=3), - MyConv(128, 128, kernel_size=3), - MyConv(128, 3, kernel_size=3), - ) - smod = nn.Sequential( - MyConv(3, 128, kernel_size=3), - MyConv(128, 128, kernel_size=3, aggregate=True), - MyConv(128, 128, kernel_size=3), - MyConv(128, 3, kernel_size=3, aggregate=True), - ) - smod.update(mod.parameters()) - shard_inplace(smod, sharding) - - x = mx.random.normal((4, 16, 16, 3)) - y1 = mod(x) - y2 = smod(x) - self.assertTrue(mx.allclose(y1, y2, atol=1e-6, rtol=1e-4)) - if __name__ == "__main__": mlx_tests.MLXTestRunner() diff --git a/python/tests/ring_test_distributed.py b/python/tests/ring_test_distributed.py index 6721b08312..dab40e48dd 100644 --- a/python/tests/ring_test_distributed.py +++ b/python/tests/ring_test_distributed.py @@ -1,7 +1,5 @@ # Copyright © 2024 Apple Inc. -import unittest - import mlx.core as mx import mlx_distributed_tests import mlx_tests @@ -10,7 +8,9 @@ class TestRingDistributed(mlx_distributed_tests.MLXDistributedCommonTestCase): @classmethod def setUpClass(cls): - world = mx.distributed.init(strict=True, backend="ring") + _ = mx.distributed.init(strict=True, backend="ring") + cls.atol = 1e-6 + cls.rtol = 1e-4 def test_groups(self): world = mx.distributed.init() @@ -24,18 +24,11 @@ def test_groups(self): with self.assertRaises(RuntimeError): sub = world.split(world.rank() % 2) - def test_all_reduce(self): + def test_all_reduce_extra(self): world = mx.distributed.init() dtypes = [ - (mx.int8, 0), - (mx.uint8, 0), (mx.int16, 0), (mx.uint16, 0), - (mx.int32, 0), - (mx.uint32, 0), - (mx.float32, 1e-6), - (mx.float16, 5e-3), - (mx.bfloat16, 1e-1), (mx.complex64, 1e-6), ] sizes = [ @@ -45,7 +38,6 @@ def test_all_reduce(self): (1024, 1024), ] key = mx.random.key(0) - reductions = ["min", "max", "sum"] for dt, rtol in dtypes: for sh in sizes: @@ -72,16 +64,11 @@ def test_all_reduce(self): z = x.min(0) self.assertTrue(mx.all(y == z)) - def test_all_gather(self): + def test_all_gather_extra(self): world = mx.distributed.init() dtypes = [ - mx.int8, - mx.uint8, mx.int16, mx.uint16, - mx.int32, - mx.uint32, - mx.float32, mx.complex64, ] for dt in dtypes: