NCCL reduce_scatter, all_gather #2727
Conversation
python/src/distributed.cpp
| )pbdoc"); | ||
|
|
||
| m.def( | ||
| "reduce_scatter", |
I'm wondering about the name. We use all_sum (instead of all_reduce) to indicate it's a sum. Maybe we should use sum_scatter here to be more consistent? Wdyt?
Yes, I would agree. I think it will be more consistent.
I did the same as for AllReduce by adding a reduction op; let me know if you think it is not needed and a single sum_scatter is enough.
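For reference, a minimal pure-Python sketch of the sum-reduce-scatter semantics being discussed (no MLX or NCCL involved; the rank count and buffer sizes are made up for illustration): each rank contributes a full buffer, the buffers are summed elementwise, and rank i receives only the i-th shard of the result.

```python
# Simulate a sum reduce_scatter across ranks, NCCL-style.

def reduce_scatter_sum(buffers):
    """buffers: one flat list per rank, all the same length,
    divisible by the number of ranks. Returns one shard per rank."""
    n_ranks = len(buffers)
    length = len(buffers[0])
    assert length % n_ranks == 0
    shard = length // n_ranks
    # Elementwise sum across all ranks' buffers.
    summed = [sum(vals) for vals in zip(*buffers)]
    # Rank r keeps only its own contiguous shard of the sum.
    return [summed[r * shard:(r + 1) * shard] for r in range(n_ranks)]

# 2 ranks, 4 elements each -> each rank ends with 2 elements of the sum.
shards = reduce_scatter_sum([[1, 2, 3, 4], [10, 20, 30, 40]])
print(shards)  # [[11, 22], [33, 44]]
```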
Co-authored-by: Awni Hannun <awni.hannun@gmail.com>
…astya236/mlx into nccl-reduce-scatter-all-gather
awni left a comment
Looks great! Will merge when tests clear!
I fixed a typo, sorry about that. It should pass now. Thanks for reviewing!
The tests failed since your last push. It looks like it's trying to initialize NCCL on Mac. Can you see the failures?
Finally everything is fixed :)
Proposed changes
This adds all_gather and reduce_scatter to the NCCL backend. The all_reduce test is kept for ring/MPI, but the following tests are now shared across all backends: test_all_reduce, test_average_gradients, test_donation, test_shard_linear, test_all_gather. In test_shard_linear, since we don't have quantized matmuls on CUDA yet, the quantized variant runs only when CUDA is not available.
Test:
mlx.launch -n 8 mlx/python/tests/nccl_test_distributed.py
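To contrast with reduce_scatter above, here is a pure-Python sketch of all_gather semantics (again, no MLX or NCCL; shapes are illustrative): every rank contributes its local shard, and every rank receives the full concatenation of all shards.

```python
# Simulate all_gather: each rank holds one shard; after the
# collective, every rank holds the concatenation of all shards.

def all_gather(shards):
    """shards: one flat list per rank. Returns, for every rank,
    the same concatenated result."""
    gathered = [x for shard in shards for x in shard]
    return [list(gathered) for _ in shards]

# 2 ranks with 2 elements each -> both ranks end with all 4 elements.
results = all_gather([[1, 2], [3, 4]])
print(results)  # [[1, 2, 3, 4], [1, 2, 3, 4]]
```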