distributed merge of per-rank Megatron data files #55
The diff replaces the direct scatter into `outval` with a padded scatter, since `torch.distributed.scatter` requires every tensor in the scatter list to have the same shape:

```diff
@@ -73,16 +73,23 @@ def scatterv_(self, invals: np.array, counts: list, outval: np.array, root: int = 0):
         self.allassert(outval.dtype == np.int64,
             f"Requires outval to be of type numpy.int64")
 
-        # Note that we create the torch.tensor outtensor to receive incoming
-        # data by using from_numpy. This sets outtensor to point to the same
-        # underlying memory of the outval numpy array. Receiving data into
-        # outtensor thus writes incoming data directly into the numpy array.
-
         # Define list of tensors to scatter on the root.
+        # torch.distributed.scatter requires each tensor to be the same shape,
+        # so find the max size across all count values and pad.
         scatterlist = None
         if self.rank == root:
-            scatterlist = list(torch.split(torch.from_numpy(invals), counts))
-        outtensor = torch.from_numpy(outval)
-        dist.scatter(outtensor, scatterlist, src=root)
+            scatterlist = []
+            slices = list(torch.split(torch.from_numpy(invals), counts))
+            for num, s in zip(counts, slices):
+                padtensor = torch.zeros(max(counts), dtype=torch.int64)
+                padtensor[:num] = s
+                scatterlist.append(padtensor)
+
+        # Receive a tensor of the max count size from the root,
+        # then copy values into output numpy array, which may be smaller.
+        recvtensor = torch.zeros(max(counts), dtype=torch.int64)
+        dist.scatter(recvtensor, scatterlist, src=root)
+
+        outval[:] = recvtensor[:counts[self.rank]]
 
     def alltrue(self, val):
         """Returns True if all procs input True, False otherwise"""
```

Review thread on this hunk:

Member:
What about `scatter_object_list`?

Contributor (Author):
Oh, interesting. So I seem to be using an old enough `torch` that it doesn't have `scatter_object_list`. It looks like it was added about 9 months ago in this commit. I can also see that internally it is just calling broadcast and scatter a couple of times. After seeing that, I think our pad method is probably the best way to go after all.

Member:
Yeah, let's go with padding.
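For reference, a minimal sketch of the `scatter_object_list` alternative raised in the thread (it requires a newer `torch` than this PR targets; the helper name and its arguments here are hypothetical, not code from the PR). Object collectives pickle each item, so per-rank chunks of different lengths need no padding:

```python
import numpy as np
import torch.distributed as dist

def scatterv_objects(rank, invals, counts, outval, root=0):
    """Hypothetical variant of scatterv_ using scatter_object_list.

    Each rank receives counts[rank] values from the root's invals
    array without any padding, at the cost of a pickle round-trip.
    """
    scatterlist = None
    if rank == root:
        # Split the root's numpy array into one chunk per rank.
        scatterlist = np.split(invals, np.cumsum(counts)[:-1])
    # scatter_object_list stores the object scattered to this rank
    # in the first slot of a pre-sized output list.
    recvlist = [None]
    dist.scatter_object_list(recvlist, scatterlist, src=root)
    outval[:] = recvlist[0]
```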
Member:
If `scatter_object_list` doesn't work for you, maybe we can improve slightly.

Contributor (Author):
Yep, that's cleaner. Thanks for the tip. I did that just to try it and pushed a commit in case you want to use it. I can also try `scatter_object_list`; I suspect that should also work.

I suppose the tensor-based method could be more efficient communication-wise (no pickle step), though this scatter step will not take much time in either case compared to the total time of the script.

The bigger concern might be the memory required on rank 0 to put together the arguments for the scatter. With the `mpi4py` `Scatterv`, I know `mpi4py` sends data from the original numpy array. With torch, it looks like we'll at least be doubling the memory by effectively slicing up the original numpy array into these per-rank tensors/sublists. I don't know if it would be worse than doubling the memory -- it depends on the implementation under the covers. However, even that likely won't be an issue until the input index list is really big, at which point we can always fall back to the file-based scatter.

Whichever you prefer is good with me.

Member:
I think the day we can't fit everything on a single process, we'll think of a better way IMO (perhaps bring back `mpi4py`). Let's stick to the padding strategy.
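To make the agreed-upon padding strategy concrete, here is a self-contained sketch (illustrative only; the `scatterv` and `run_demo` helpers and the gloo setup are not from the PR). Each chunk is padded to `max(counts)` before `dist.scatter`, and the receiver trims the padding:

```python
import os
import numpy as np
import torch
import torch.distributed as dist
import torch.multiprocessing as mp

def scatterv(rank, invals, counts, root=0):
    """Scatter counts[i] int64 values to rank i, padding so every
    tensor passed to dist.scatter has the same shape."""
    scatterlist = None
    if rank == root:
        scatterlist = []
        for num, s in zip(counts, torch.split(torch.from_numpy(invals), counts)):
            padtensor = torch.zeros(max(counts), dtype=torch.int64)
            padtensor[:num] = s
            scatterlist.append(padtensor)
    recvtensor = torch.zeros(max(counts), dtype=torch.int64)
    dist.scatter(recvtensor, scatterlist, src=root)
    # Trim the padding: only the first counts[rank] values are real.
    return recvtensor[:counts[rank]].numpy()

def run_demo(rank, world_size):
    os.environ["MASTER_ADDR"] = "127.0.0.1"
    os.environ["MASTER_PORT"] = "29500"
    dist.init_process_group("gloo", rank=rank, world_size=world_size)
    counts = [1, 2, 3, 4][:world_size]
    # Only the root needs the full input array.
    invals = np.arange(sum(counts), dtype=np.int64) if rank == 0 else None
    out = scatterv(rank, invals, counts)
    print(f"rank {rank} received {out}")
    dist.destroy_process_group()

if __name__ == "__main__":
    world = 4
    mp.spawn(run_demo, args=(world,), nprocs=world)
```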