Commit 9722111

adammoody and thomasw21 authored
distributed merge of per-rank Megatron data files (#55)
* add parallel merge using mpi
* handle case where some ranks might have 0 items
* add inclusive scan prefix sum
* report more timing info
* Update megatron/data/indexed_dataset.py Co-authored-by: Thomas Wang <[email protected]>
* Update megatron/data/indexed_dataset.py Co-authored-by: Thomas Wang <[email protected]>
* rename total size variable for clarity
* move translation to bin/idx file names a level deeper
* parallel merge for cached dataset
* add alltrue function
* move collectives to new distdata class, add torch.distributed
* drop unused prefix_sum function
* allow ranks to pass a list of files to be merged
* check that input dataset files exist
* fix: using wrong doc_idx list for mmap
* move init dist and collectives to distdata class
* add --merge option, move parallel/serial to their own functions
* Update megatron/data/distdata.py Co-authored-by: Thomas Wang <[email protected]>
* Update megatron/data/indexed_dataset.py Co-authored-by: Thomas Wang <[email protected]>
* Update megatron/data/indexed_dataset.py Co-authored-by: Thomas Wang <[email protected]>
* Update megatron/data/indexed_dataset.py Co-authored-by: Thomas Wang <[email protected]>
* Update megatron/data/indexed_dataset.py Co-authored-by: Thomas Wang <[email protected]>
* Update megatron/data/indexed_dataset.py Co-authored-by: Thomas Wang <[email protected]>
* Update megatron/data/indexed_dataset.py Co-authored-by: Thomas Wang <[email protected]>
* drop extraneous numpy tolist calls
* rename self.MPI to mpi4py
* handle case where no ranks have elements in their file
* rename tokenize_start to time_start
* drop unrelated comment in distdata.min
* add comment why pointers_shift is not None and add assert
* note why pointers uses sizes count and offset values
* can just rely on rank 0 for the leading 0 element
* add write_list function
* determine element size
* add checks for consistent element_size values
* check that at least one rank has a file to merge
* assert that torch backend is gloo or mpi
* add collectives for assert and raise
* rename to allassert and allraise_if
* check dtype instead of element_size
* add uint32 to element_sizes table
* infer dtype from files being merged
* add write_header function to indexed dataset classes
* call write_header internally from IndexedDataset classes
* return number of bytes written from write calls
* move scatterv to distdata class
* add functions to format status and error messages
* defer merge_files_dist to future PR
* open files using with, refresh comments
* rely on default torch datatypes
* fix some status messages from preprocess script
* fix: exclusive scan computing pointers list
* fix: exclusive scan to compute mmap pointers list
* note about seek
* rename preprocess_dataset_mpi.py to preprocess_data_dist.py
* update usage comments at top of script
* restore commented print_rank_0 statements
* restore status message in mmap merge_file_
* drop mpi4py, sad :(
* add test case for parallel merge
* add preprocess_data_dist test for serial merge
* improve error handling
* refactor get_pointers code
* bug fix in exscan
* further refactor get_pointers
* move exscan collective for pointers outside of try block
* clarify some comments
* include string 1k in name of test files
* use temporary file for index
* fix: implement scatterv from torch.distributed.scatter
* switch to pad method in torch.nn.functional
* return data received in scatterv as new tensor
* raise exception if conflicting scratch and merge options
* use allraise method from distdata in preprocess_data_dist
Co-authored-by: Thomas Wang <[email protected]>
1 parent 6d88ae2 commit 9722111
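
The core idea behind the parallel merge described in the commit message is an exclusive-scan prefix sum over per-rank data sizes: each rank contributes the size of its local shard, and the scan gives that rank the offset at which to write into the shared merged file. A minimal serial sketch of that offset computation (the sizes here are made up for illustration; the distributed version is the exscan method in megatron/data/distdata.py below):

# Illustrative only: per-rank element counts, including ranks with 0 items.
sizes = [3, 0, 5, 2]

# Exclusive scan: each rank's offset is the sum of the counts before it.
offsets = []
total = 0
for s in sizes:
    offsets.append(total)
    total += s

print(offsets)  # [0, 3, 3, 8]
print(total)    # 10, the total size of the merged output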

File tree

4 files changed (+992, -208 lines)


megatron/data/distdata.py

Lines changed: 220 additions & 0 deletions
@@ -0,0 +1,220 @@
import os
import numpy as np

import torch
import torch.nn.functional as F
import torch.distributed as dist

class DistDataError(Exception):
    """Defines an empty exception to throw when some other rank hit a real exception."""
    pass

class DistData(object):
    def __init__(self, backend='gloo'):
        assert backend in ['gloo', 'mpi'], f"torch.distributed backend '{backend}' is not supported, valid options are 'gloo' or 'mpi'"

        dist.init_process_group(backend, init_method="env://")

        # lookup our process rank and the group size
        self.rank = dist.get_rank()
        self.numranks = dist.get_world_size()

    def allassert(self, cond, msg):
        """Check that cond is True on all ranks, assert with msg everywhere if not.

        To prevent deadlocks in cases where an assertion might only fail on one rank,
        this executes an allreduce to ensure that if any rank finds that an assertion
        has been violated, all ranks fail an assertion check.
        The condition must be true on all ranks for this not to assert.
        """
        alltrue = self.alltrue(cond)
        assert alltrue, msg

    def allraise_if(self, err):
        """Raise exception if err is not None on any rank.

        Similarly to allassert, this raises an exception on all ranks if err
        is set to an exception on any rank. Rank(s) where err is not None
        re-raise err as exception, and ranks where err is None raise DistDataError.
        Thus all ranks raise an exception if any rank has an active exception,
        which helps avoid deadlocks in cases where an exception may be raised
        on a subset of ranks.
        """
        alltrue = self.alltrue(err is None)
        if not alltrue:
            # At least one rank raised an exception.
            # Re-raise the actual exception if this rank threw one.
            if err is not None:
                raise err

            # TODO: is there a better exception to use here?
            # On other ranks, raise an "empty" exception to indicate
            # that we're only failing because someone else did.
            raise DistDataError

    def barrier(self):
        """Globally synchronize all processes"""
        dist.barrier()

    def bcast(self, val, root):
        """Broadcast a scalar value from root to all ranks"""
        vals = [val]
        dist.broadcast_object_list(vals, src=root)
        return vals[0]

    def scatterv_(self, invals: np.array, counts: list, root: int = 0):
        """Scatter int64 values from invals according to counts array, return received portion in a new tensor"""

        self.allassert(len(counts) == self.numranks,
            f"Length of counts list {len(counts)} does not match number of ranks {self.numranks}")

        # Define list of tensors to scatter on the root.
        # torch.distributed.scatter requires each tensor to be the same shape,
        # so find the max size across all count values and pad.
        max_size = max(counts)
        scatterlist = None
        if self.rank == root:
            slices = list(torch.split(torch.from_numpy(invals), counts))
            scatterlist = [F.pad(s, (0, max_size - len(s))) for s in slices]

        # Receive a tensor of the max count size from the root,
        # then copy values into output numpy array, which may be smaller.
        recvtensor = torch.zeros(max_size, dtype=torch.int64)
        dist.scatter(recvtensor, scatterlist, src=root)
        return recvtensor[:counts[self.rank]]

    def alltrue(self, val):
        """Returns True if all procs input True, False otherwise"""
        # torch.dist does not support reductions with bool types
        # so we cast to int and cast the result back to bool
        tensor = torch.tensor([int(val)], dtype=torch.int32)
        dist.all_reduce(tensor, op=dist.ReduceOp.BAND)
        return bool(tensor[0])

    def sum(self, val):
        """Compute sum of a scalar val, and return total on all ranks."""
        tensor = torch.tensor([val])
        dist.all_reduce(tensor, op=dist.ReduceOp.SUM)
        return tensor[0]

    def exscan(self, val: int):
        """Compute prefix sum (exclusive scan) of int64 val, and return offset of each rank."""
        # torch.distributed doesn't have a scan, so fallback to allreduce
        tensor = torch.zeros(self.numranks, dtype=torch.int64)
        tensor[self.rank:] = val
        dist.all_reduce(tensor, op=dist.ReduceOp.SUM)
        return int(tensor[self.rank]) - val

    def min(self, val):
        """Return minimum of scalar val to all ranks."""
        tensor = torch.tensor([val])
        dist.all_reduce(tensor, op=dist.ReduceOp.MIN)
        return tensor[0]

    def minrank(self, cond):
        """Find first rank whose condition is True, return that rank if any, None otherwise."""
        minrank = self.numranks
        if cond:
            minrank = self.rank
        minrank = self.min(minrank)

        if minrank < self.numranks:
            return minrank
        return None

    def bcast_first(self, val):
        """Broadcast val from first rank where it is not None, return val if any, None otherwise"""
        # Find the first rank with a valid value.
        minrank = self.minrank(val is not None)

        # If there is no rank with a valid value, return None
        if minrank is None:
            return None

        # Otherwise broadcast the value from the first valid rank.
        val = self.bcast(val, root=minrank)
        return val

    def all_sum_(self, vals: np.array):
        """Sums values in numpy array vals element-wise and update vals in place with final result on all ranks"""
        # Builds torch.tensor with from_numpy to use same underlying memory as numpy array.
        tensor = torch.from_numpy(vals)
        dist.all_reduce(tensor, op=dist.ReduceOp.SUM)

    def open(self, filename, truncate=None):
        """Create, truncate, and open a file shared by all ranks."""

        # Don't truncate existing file until all ranks reach this point
        self.barrier()

        # We'll capture any exception in this variable
        err = None

        # Rank 0 creates and truncates file.
        if self.rank == 0:
            try:
                f = open(filename, 'wb')

                # Some file systems like GPFS deliver faster write speed
                # if the file size is known before data is written to the file.
                if truncate is not None:
                    f.truncate(truncate)

            except Exception as e:
                err = e

        # Verify that rank 0 created the file
        self.allraise_if(err)

        # Wait for rank 0 to open (and truncate) file,
        # then have all ranks open file for writing.
        if self.rank != 0:
            try:
                f = open(filename, 'r+b')
            except Exception as e:
                err = e

        # Verify that all ranks successfully opened the file
        self.allraise_if(err)

        return f

    def remove(self, filename):
        """Remove a shared file."""

        # Don't remove the file until all are ready
        self.barrier()

        # We'll capture any exception in this variable
        err = None

        # Rank 0 removes the file if it exists.
        if self.rank == 0:
            try:
                if os.path.exists(filename):
                    os.remove(filename)
            except Exception as e:
                err = e

        # Verify that rank 0 successfully removed the file.
        self.allraise_if(err)

    def rename(self, srcfile, destfile):
        """Rename a shared file."""

        # Don't rename until all are ready
        self.barrier()

        # We'll capture any exception in this variable
        err = None

        # Rank 0 renames the file.
        if self.rank == 0:
            try:
                if os.path.exists(srcfile):
                    os.rename(srcfile, destfile)
            except Exception as e:
                err = e

        # Verify that the rename succeeded
        self.allraise_if(err)
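
A rough usage sketch of the class above (not part of the commit): each rank computes the byte count of its local shard, exscan turns those counts into a write offset, sum gives the total size used to pre-truncate the shared file, and every rank then seeks to its offset and writes its portion. Variable names and the merged.bin filename are illustrative only, and the snippet assumes the torch.distributed environment variables (RANK, WORLD_SIZE, MASTER_ADDR, MASTER_PORT) are already set, e.g. by torchrun.

import numpy as np
from megatron.data.distdata import DistData

distctx = DistData(backend='gloo')

# Illustrative data: rank r contributes r + 1 int32 values.
local = np.arange(distctx.rank + 1, dtype=np.int32)
local_bytes = local.nbytes

# Offset of this rank's section and total size of the merged file.
offset = distctx.exscan(local_bytes)
total = int(distctx.sum(local_bytes))

# Rank 0 creates and truncates the shared file, then all ranks open it
# and write their own section at their computed offset.
with distctx.open('merged.bin', truncate=total) as f:
    f.seek(offset)
    f.write(local.tobytes())

# Wait for all ranks to finish writing before anyone reads the result.
distctx.barrier()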
