Merged
78 commits
269af4e
add parallel merge using mpi
adammoody Aug 9, 2021
9ba081b
handle case where some ranks might have 0 items
adammoody Aug 10, 2021
d29a702
add inclusive scan prefix sum
adammoody Aug 11, 2021
ed49713
report more timing info
adammoody Aug 11, 2021
e94f2a0
Update megatron/data/indexed_dataset.py
adammoody Aug 12, 2021
687ff32
Update megatron/data/indexed_dataset.py
adammoody Aug 12, 2021
af59545
rename total size variable for clarity
adammoody Aug 12, 2021
4f648a0
move translation to bin/idx file names a level deeper
adammoody Aug 13, 2021
9f2ba6a
parallel merge for cached dataset
adammoody Aug 13, 2021
72d6c9c
add alltrue function
adammoody Aug 13, 2021
8b67bec
move collectives to new distdata class, add torch.distributed
adammoody Aug 14, 2021
3eca1f3
drop unused prefix_sum function
adammoody Aug 14, 2021
a691b48
allow ranks to pass a list of files to be merged
adammoody Aug 15, 2021
e4a34e2
check that input dataset files exist
adammoody Aug 15, 2021
8b168ca
fix: using wrong doc_idx list for mmap
adammoody Aug 16, 2021
7a02693
move init dist and collectives to distdata class
adammoody Aug 16, 2021
eca2940
add --merge option, move parallel/serial to their own functions
adammoody Aug 16, 2021
b14491d
Merge branch 'main' into pmerge
adammoody Aug 16, 2021
ec11281
Update megatron/data/distdata.py
adammoody Aug 16, 2021
354d13b
Update megatron/data/indexed_dataset.py
adammoody Aug 16, 2021
2dc3f7a
Update megatron/data/indexed_dataset.py
adammoody Aug 16, 2021
980e904
Update megatron/data/indexed_dataset.py
adammoody Aug 16, 2021
ebd20a6
Update megatron/data/indexed_dataset.py
adammoody Aug 16, 2021
69b2f49
Update megatron/data/indexed_dataset.py
adammoody Aug 16, 2021
50de06a
Update megatron/data/indexed_dataset.py
adammoody Aug 16, 2021
af290ad
drop extraneous numpy tolist calls
adammoody Aug 16, 2021
4b58c74
rename self.MPI to mpi4py
adammoody Aug 16, 2021
71a2fdc
handle case where no ranks have elements in their file
adammoody Aug 16, 2021
73d3a24
rename tokenize_start to time_start
adammoody Aug 16, 2021
b9e69be
drop unrelated comment in distdata.min
adammoody Aug 16, 2021
da615c6
add comment why pointers_shift is not None and add assert
adammoody Aug 16, 2021
c42f41f
note why pointers uses sizes count and offset values
adammoody Aug 16, 2021
a3a7d53
can just rely on rank 0 for the leading 0 element
adammoody Aug 17, 2021
163310a
add write_list function
adammoody Aug 17, 2021
01b2be0
determine element size
adammoody Aug 17, 2021
4b6e8ff
add checks for consistent element_size values
adammoody Aug 17, 2021
ea08555
check that at least one rank has a file to merge
adammoody Aug 17, 2021
2524fce
assert that torch backend is gloo or mpi
adammoody Aug 17, 2021
ca14d48
add collectives for assert and raise
adammoody Aug 17, 2021
d482f36
rename to allassert and allraise_if
adammoody Aug 17, 2021
28d76f5
check dtype instead of element_size
adammoody Aug 17, 2021
f706108
add uint32 to element_sizes table
adammoody Aug 17, 2021
f122883
infer dtype from files being merged
adammoody Aug 17, 2021
57c012e
add write_header function to indexed dataset classes
adammoody Aug 17, 2021
eed8327
call write_header internally from IndexedDataset classes
adammoody Aug 17, 2021
a75cfc2
return number of bytes written from write calls
adammoody Aug 17, 2021
afcfcf9
Merge branch 'main' into pmerge
adammoody Aug 17, 2021
74b733a
move scatterv to distdata class
adammoody Aug 17, 2021
dadb51b
add functions to format status and error messages
adammoody Aug 17, 2021
a2f8fa0
defer merge_files_dist to future PR
adammoody Aug 17, 2021
39e6cd7
open files using with, refresh comments
adammoody Aug 18, 2021
2a29d99
rely on default torch datatypes
adammoody Aug 18, 2021
d6fa895
fix some status messages from preprocess script
adammoody Aug 18, 2021
1216c0a
fix: exclusive scan computing pointers list
adammoody Aug 18, 2021
a64d3da
Merge branch 'pointerfix' into pmerge
adammoody Aug 18, 2021
fde439e
fix: exclusive scan to compute mmap pointers list
adammoody Aug 18, 2021
ba14351
note about seek
adammoody Aug 19, 2021
852fdd0
rename preprocess_dataset_mpi.py to preprocess_data_dist.py
adammoody Aug 19, 2021
61f4b46
update usage comments at top of script
adammoody Aug 19, 2021
22400f3
restore commented print_rank_0 statements
adammoody Aug 19, 2021
5cfcb95
restore status message in mmap merge_file_
adammoody Aug 19, 2021
74c4883
drop mpi4py, sad :(
adammoody Aug 19, 2021
373e514
Merge branch 'main' into pmerge
adammoody Aug 19, 2021
78ab715
add test case for parallel merge
adammoody Aug 19, 2021
002b403
add preprocess_data_dist test for serial merge
adammoody Aug 19, 2021
ba763f7
improve error handling
adammoody Aug 20, 2021
fa11159
refactor get_pointers code
adammoody Aug 20, 2021
7e53fd3
bug fix in exscan
adammoody Aug 20, 2021
53df36f
further refactor get_pointers
adammoody Aug 20, 2021
c43348f
move exscan collective for pointers outside of try block
adammoody Aug 20, 2021
81c21dd
clarify some comments
adammoody Aug 20, 2021
adee502
include string 1k in name of test files
adammoody Aug 20, 2021
13ae421
use temporary file for index
adammoody Aug 20, 2021
f3e1b1d
fix: implement scatterv from torch.distributed.scatter
adammoody Aug 23, 2021
42962e1
switch to pad method in torch.nn.functional
adammoody Aug 25, 2021
9a2f383
return data received in scatterv as new tensor
adammoody Aug 25, 2021
15b7603
raise exception if conflicting scratch and merge options
adammoody Aug 25, 2021
4adaddd
use allraise method from distdata in preprocess_data_dist
adammoody Aug 25, 2021
23 changes: 15 additions & 8 deletions megatron/data/distdata.py
@@ -73,16 +73,23 @@ def scatterv_(self, invals: np.array, counts: list, outval: np.array, root:int=0
         self.allassert(outval.dtype == np.int64,
             f"Requires outval to be of type numpy.int64")

-        # Note that we create the torch.tensor outtensor to receive incoming
-        # data by using from_numpy. This sets outtensor to point to the same
-        # underlying memory of the outval numpy array. Receiving data into
-        # outtensor thus writes incoming data directly into the numpy array.
+        # Define list of tensors to scatter on the root.
+        # torch.distributed.scatter requires each tensor to be the same shape,
+        # so find the max size across all count values and pad.
         scatterlist = None
         if self.rank == root:
-            scatterlist = list(torch.split(torch.from_numpy(invals), counts))
-        outtensor = torch.from_numpy(outval)
-        dist.scatter(outtensor, scatterlist, src=root)
+            scatterlist = []
+            slices = list(torch.split(torch.from_numpy(invals), counts))
+            for num, s in zip(counts, slices):
+                padtensor = torch.zeros(max(counts), dtype=torch.int64)
+                padtensor[:num] = s
+                scatterlist.append(padtensor)
Member commented:

If scatter_object_list doesn't work for you, maybe we can improve slightly:

    import torch.nn.functional as F

    slices = torch.split(torch.from_numpy(invals), counts)
    max_size = max(counts)
    scatterlist = [F.pad(slice, (0, max_size - len(slice))) for slice in slices]
Contributor Author (adammoody) replied:

Yep, that's cleaner. Thanks for the tip.

I did that just to try it and pushed a commit in case you want to use it. I can also try scatter_object_list. I suspect that should also work.

I suppose the tensor-based method could be more efficient communication-wise (no pickle step), though this scatter step will not take much time in either case compared to the total time of the script.

The bigger concern might be the memory required on rank 0 to put together the arguments for the scatter. With the mpi4py Scatterv, I know mpi4py sends data from the original numpy array. With torch, it looks like we'll at least be doubling the memory by effectively slicing up the original numpy array into these per-rank tensors/sublists. I don't know if it would be worse than doubling the memory -- it depends on the implementation under the covers. However, even that likely won't be an issue until the input index list is really big, at which point we can always fall back to the file-based scatter.

Whichever you prefer is good with me.
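For comparison, here is a rough sketch of the mpi4py Scatterv path mentioned above. mpi4py was dropped from this PR, so this is illustrative only; the hard-coded counts, the root value, and the displacement setup are assumptions, not code from the repository.

    from mpi4py import MPI
    import numpy as np

    comm = MPI.COMM_WORLD
    rank = comm.Get_rank()
    root = 0

    # Hypothetical per-rank counts (run with 3 ranks); the flat int64 send
    # buffer lives only on the root.
    counts = [3, 1, 2]
    invals = np.arange(sum(counts), dtype=np.int64) if rank == root else None

    # displs gives each rank's starting offset within the flat send buffer.
    displs = np.insert(np.cumsum(counts), 0, 0)[:-1]
    recvbuf = np.zeros(counts[rank], dtype=np.int64)

    # Scatterv sends slices directly out of the root's numpy array, so no
    # per-rank copies are assembled on rank 0 beforehand.
    sendbuf = [invals, counts, displs, MPI.INT64_T] if rank == root else None
    comm.Scatterv(sendbuf, recvbuf, root=root)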

Member replied:

I think the day we can't fit everything on a single process, we'll think of a better way IMO (perhaps bring back mpi4py). Let's stick to the padding strategy.
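As a concrete illustration of the split-and-pad step settled on above, here is a standalone toy example with hypothetical counts (not code from the PR):

    import numpy as np
    import torch
    import torch.nn.functional as F

    counts = [3, 1, 2]
    invals = np.arange(6, dtype=np.int64)   # flat array on the root: [0 1 2 3 4 5]

    # Split into one slice per rank, then right-pad each slice with zeros to the
    # max count so torch.distributed.scatter sees equally shaped tensors.
    slices = torch.split(torch.from_numpy(invals), counts)
    max_size = max(counts)
    scatterlist = [F.pad(s, (0, max_size - len(s))) for s in slices]

    print(scatterlist)
    # [tensor([0, 1, 2]), tensor([3, 0, 0]), tensor([4, 5, 0])]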


+        # Receive a tensor of the max count size from the root,
+        # then copy values into output numpy array, which may be smaller.
+        recvtensor = torch.zeros(max(counts), dtype=torch.int64)
+        dist.scatter(recvtensor, scatterlist, src=root)
thomasw21 (Member) commented on Aug 23, 2021:

What about scatter_object_list? I tried it and it seems to work nicely, though I don't know what the cost of using that method vs padding is.

Contributor Author (adammoody) replied:

Oh, interesting. So I seem to be using an old enough torch.distributed that I don't have scatter_object_list:

    AttributeError: module 'torch.distributed' has no attribute 'scatter_object_list'

It looks like it was added about 9 months ago in this commit:

pytorch/pytorch@02d89f9

I can also see that it internally is just calling broadcast and scatter a couple of times.

After seeing that, I think our pad method is probably the best way to go after all.

Member replied:

Yeah, let's go with padding.
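For reference, a rough sketch of what the scatter_object_list alternative would look like on a newer torch.distributed. It was not adopted in this PR, and the slicing of invals into per-rank lists is an assumption for illustration:

    import numpy as np
    import torch.distributed as dist

    # Root builds one Python object (a plain list) per rank; other ranks pass None.
    # Objects are pickled and sent via broadcast + scatter internally, so no
    # padding is needed, at the cost of a serialization step.
    if dist.get_rank() == root:
        scatter_objects = [s.tolist() for s in np.split(invals, np.cumsum(counts)[:-1])]
    else:
        scatter_objects = None

    recv = [None]
    dist.scatter_object_list(recv, scatter_objects, src=root)
    local_vals = recv[0]    # list of counts[rank] values for this rank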

+        outval[:] = recvtensor[:counts[self.rank]]
thomasw21 (Member) commented on Aug 24, 2021:

If we don't benefit from the inplace operator, let's return the tensor instead of doing an inplace operation? Typically you could return only recvtensor[:counts[self.rank]] which would remove the need to initialise outval.

Contributor Author (adammoody) replied:

Good suggestion. Made that change.


     def alltrue(self, val):
         """Returns True if all procs input True, False otherwise"""
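Putting the review together, a minimal standalone sketch of the padded scatterv in its final, returned-tensor form (assuming torch.distributed is already initialized with a gloo or mpi backend; this is an illustration, not the exact DistData.scatterv_ method, and its error checking is omitted):

    import numpy as np
    import torch
    import torch.distributed as dist
    import torch.nn.functional as F

    def scatterv(invals: np.ndarray, counts: list, root: int = 0) -> torch.Tensor:
        """Scatter counts[rank] int64 values from the flat array on root; return them as a tensor."""
        rank = dist.get_rank()
        max_size = max(counts)

        # Root pads each per-rank slice to the same length, since
        # torch.distributed.scatter requires equally shaped tensors.
        scatterlist = None
        if rank == root:
            slices = torch.split(torch.from_numpy(invals), counts)
            scatterlist = [F.pad(s, (0, max_size - len(s))) for s in slices]

        # Every rank receives a max-sized tensor, then trims it to its own count.
        recvtensor = torch.zeros(max_size, dtype=torch.int64)
        dist.scatter(recvtensor, scatterlist, src=root)
        return recvtensor[:counts[rank]]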