From 269af4ebe3d839c3c74b76872e3fd192c6a3ffe2 Mon Sep 17 00:00:00 2001 From: Adam Moody Date: Mon, 9 Aug 2021 11:19:27 -0700 Subject: [PATCH 01/74] add parallel merge using mpi --- megatron/data/indexed_dataset.py | 140 +++++++++++++++++++++++++++++-- tools/preprocess_dataset_mpi.py | 57 ++++++++++++- 2 files changed, 189 insertions(+), 8 deletions(-) diff --git a/megatron/data/indexed_dataset.py b/megatron/data/indexed_dataset.py index 831feb81f..1a0b92139 100644 --- a/megatron/data/indexed_dataset.py +++ b/megatron/data/indexed_dataset.py @@ -417,21 +417,21 @@ def __init__(self, path, skip_warmup=False): offset = stream.tell() if not skip_warmup: - print_rank_0(" warming up index mmap file...") +# print_rank_0(" warming up index mmap file...") _warmup_mmap_file(path) self._bin_buffer_mmap = np.memmap(path, mode='r', order='C') self._bin_buffer = memoryview(self._bin_buffer_mmap) - print_rank_0(" reading sizes...") +# print_rank_0(" reading sizes...") self._sizes = np.frombuffer( self._bin_buffer, dtype=np.int32, count=self._len, offset=offset) - print_rank_0(" reading pointers...") +# print_rank_0(" reading pointers...") self._pointers = np.frombuffer(self._bin_buffer, dtype=np.int64, count=self._len, offset=offset + self._sizes.nbytes) - print_rank_0(" reading document index...") +# print_rank_0(" reading document index...") self._doc_idx = np.frombuffer(self._bin_buffer, dtype=np.int64, count=self._doc_count, offset=offset + self._sizes.nbytes + self._pointers.nbytes) @@ -576,8 +576,8 @@ def merge_file_(self, another_file): index = MMapIndexedDataset.Index(index_file_path(another_file)) assert index.dtype == self._dtype - total_len = len(index.sizes)+len(self._sizes) - print(f" concat {another_file} size={len(index.sizes)} for a total size of {total_len}") +# total_len = len(index.sizes)+len(self._sizes) +# print(f" concat {another_file} size={len(index.sizes)} for a total size of {total_len}") offset = len(self._sizes) self._sizes.extend(index.sizes) @@ -592,3 +592,131 @@ def finalize(self, index_file): with MMapIndexedDataset.Index.writer(index_file, self._dtype) as index: index.write(self._sizes, self._doc_idx) + +def mpi_get_sum(val, mpi, comm): + insize = np.array([val], dtype=np.int64) + outsize = np.zeros_like(insize) + comm.Allreduce(insize, outsize, op=mpi.SUM) + return outsize[0] + +def mpi_get_offset(val, mpi, comm): + insize = np.array([val], dtype=np.int64) + outsize = np.zeros_like(insize) + comm.Scan(insize, outsize, op=mpi.SUM) + offset = outsize[0] - insize[0] + return offset + +def merge_files_mpi_bin(outfile, infile, mpi, comm): + import stat + import shutil + + comm.barrier() + + # wait for rank 0 to open (and truncate) file, + # then have all ranks open file for writing + rank = comm.Get_rank() + if rank == 0: + f = open(outfile, 'wb') + comm.barrier() + if rank != 0: + f = open(outfile, 'r+b') + + # get file size of binary file for this rank + filesize = os.stat(infile)[stat.ST_SIZE] + + # compute offset this rank should start copying + # its data into the merged file + offset = mpi_get_offset(filesize, mpi, comm) + #print(rank, infile, filesize, offset) + + # seek to appropriate offset and copy data + f.seek(offset) + with open(infile, "rb") as fsrc: + shutil.copyfileobj(fsrc, f) + + f.close() + + comm.barrier() + +def merge_files_mpi_idx(outfile, infile, mpi, comm, dtype): + rank = comm.Get_rank() + + comm.barrier() + + # wait for rank 0 to open (and truncate) file, + # then have all ranks open file for writing + rank = comm.Get_rank() + if rank == 0: + f = 
open(outfile, 'wb') + comm.barrier() + if rank != 0: + f = open(outfile, 'r+b') + + index = MMapIndexedDataset.Index(infile) + sizes = index.sizes + if rank == 0: + docs = index.doc_idx + if rank != 0: + docs = index.doc_idx[1:] + + numsizes = len(sizes) + numdocs = len(docs) + size_count = mpi_get_sum(numsizes, mpi, comm) + docs_count = mpi_get_sum(numdocs, mpi, comm) + size_offset = mpi_get_offset(numsizes, mpi, comm) + docs_offset = mpi_get_offset(numdocs, mpi, comm) + + # have rank 0 write the file header + if rank == 0: + f.write(MMapIndexedDataset.Index._HDR_MAGIC) + f.write(struct.pack(' 0: + np.cumsum(pointers, axis=0, out=pointers) + pointers *= dtype().itemsize + pointer_last = pointers[-1] + pointer_offset = mpi_get_offset(pointer_last, mpi, comm) + pointers += pointer_offset + + pointers_shift = 0 + if rank == 0 and len(sizes) > 0: + pointers_shift = pointers[0] + pointers_shift = mpi_get_sum(pointers_shift, mpi, comm) + pointers -= pointers_shift + + f.seek(pos + size_offset * np.int64().itemsize) + f.write(pointers.tobytes(order='C')) + del pointers + pos += size_count * np.int64().itemsize + + doc_idx = np.array(docs, dtype=np.int64) + doc_idx += size_offset + + f.seek(pos + docs_offset * np.int64().itemsize) + f.write(doc_idx.tobytes(order='C')) + del doc_idx + + f.close() + + comm.barrier() + +def merge_files_mpi(filemain, filerank, mpi, comm, dtype=np.int64): + merge_files_mpi_bin(data_file_path(filemain), data_file_path(filerank), mpi, comm) + merge_files_mpi_idx(index_file_path(filemain), index_file_path(filerank), mpi, comm, dtype) diff --git a/tools/preprocess_dataset_mpi.py b/tools/preprocess_dataset_mpi.py index b80ad4a25..28cfdd8d5 100644 --- a/tools/preprocess_dataset_mpi.py +++ b/tools/preprocess_dataset_mpi.py @@ -65,7 +65,7 @@ from datasets.utils.file_utils import OfflineModeIsEnabled from megatron.tokenizer import build_tokenizer -from megatron.data.indexed_dataset import data_file_path, index_file_path, make_builder, best_fitting_dtype +from megatron.data.indexed_dataset import data_file_path, index_file_path, make_builder, best_fitting_dtype, merge_files_mpi # https://stackoverflow.com/questions/33139531/preserve-empty-lines-with-nltks-punkt-tokenizer class CustomLanguageVars(nltk.tokenize.punkt.PunktLanguageVars): @@ -166,6 +166,8 @@ def get_args(): help='Select torch.distributed backend.') group.add_argument('--local_rank', type=int, default=None, help='Local rank of calling process on its node (from torch.distributed.launch).') + group.add_argument('--scratch', type=str, default='/dev/shm', + help='Path to local storage on compute nodes to write per-rank files before merging.') group.add_argument('--log-interval', type=int, default=30, help='Seconds between progress updates (0 to disable)') @@ -281,6 +283,23 @@ def load_dset(args): if args.rank != 0: logging.set_verbosity(logging.ERROR) + success = True + dsetname = args.input + if not download: + try: + dset = load_dataset(dsetname, split=args.split, keep_in_memory=None) + except: + # this print might be noisy, but better than nothing + print("ERROR: Unexpected error:", sys.exc_info()[0]) + success = False + + # determine whether everyone succeeded in loading the dataset + success = all_true(args, success) + if not success: + return None + + return dset + # Load the specified HuggingFace dataset. # Give rank 0 a head start in case the dataset is not already cached. 
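The parallel .bin merge above rests on a single collective: mpi_get_offset computes an exclusive prefix sum of per-rank byte counts, which gives each rank the offset at which to write into the shared output file. A minimal standalone sketch of that pattern, assuming mpi4py is available (the helper name and the example byte counts are illustrative, not part of the patch):

    import numpy as np
    from mpi4py import MPI

    def write_offset(comm, nbytes):
        """Offset at which this rank should write its nbytes into the merged file."""
        inval = np.array([nbytes], dtype=np.int64)
        outval = np.zeros_like(inval)
        comm.Scan(inval, outval, op=MPI.SUM)   # inclusive prefix sum across ranks
        return int(outval[0] - inval[0])       # subtract own count -> exclusive scan

    comm = MPI.COMM_WORLD
    offset = write_offset(comm, nbytes=1000 * (comm.Get_rank() + 1))
    # rank 0 writes at byte 0, rank 1 at byte 1000, rank 2 at byte 3000, ...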
success = True @@ -391,6 +410,11 @@ def get_start_end(num, rank, num_ranks): def get_filename(args, key, rank=None): pathname = args.output_prefix + # redirect per-rank file to scratch dir if defined + if args.scratch and rank is not None: + basename = os.path.basename(pathname) + pathname = os.path.join(args.scratch, basename) + if rank is not None: filename = f"{pathname}_{key}_{args.level}_{rank}" else: @@ -408,6 +432,7 @@ def rank_files_write(args, dset, idx, encoder): # we'll set this to false on any problem success = True err = None + times = np.zeros(2, dtype=np.float32) try: # create data file for each rank if args.rank == 0: @@ -432,8 +457,13 @@ def rank_files_write(args, dset, idx, encoder): for i in idx[idx_start:idx_end]: for key in args.columns: # tokenize text for the given sample index + start_read = time.time() text = dset[i][key] + start_encode = time.time() doc, bytes_processed = encoder.encode_text(text) + end_encode = time.time() + times[0] += start_encode - start_read + times[1] += end_encode - start_encode # add tokenized sequence to our data file for key, sentences in doc.items(): @@ -444,6 +474,9 @@ def rank_files_write(args, dset, idx, encoder): dset_stats[1] += len(sentences) dset_stats[2] += bytes_processed +# dset_stats[0] += 1 +# dset_stats[2] += len(text) + if args.rank == 0 and args.log_interval > 0 and time.time() > progress_next: current = time.time() progress_next = current + float(args.log_interval) @@ -480,6 +513,7 @@ def rank_files_write(args, dset, idx, encoder): tokenize_end = time.time() # compute total stats across all processes + all_sum_(args, times) all_sum_(args, dset_stats) if args.rank == 0: secs = tokenize_end - tokenize_start @@ -497,6 +531,25 @@ def rank_files_write(args, dset, idx, encoder): return success, err def rank_files_merge(args): + merge_start = time.time() + numbytes = np.zeros(1, dtype=np.int64) + for key in args.columns: + filemain = get_filename(args, key) + filerank = get_filename(args, key, args.rank) + binfile = data_file_path(filerank) + numbytes[0] += os.stat(binfile)[stat.ST_SIZE] + merge_files_mpi(filemain, filerank, args.MPI, args.mpi_comm, dtype=best_fitting_dtype(args.vocab_size)) + barrier(args) + all_sum_(args, numbytes) + merge_end = time.time() + if args.rank == 0: + secs = merge_end - merge_start + byterate = numbytes[0] / secs if secs > 0.0 else 0.0 + print("Seconds to merge (parallel):", secs, flush=True) + print("Bytes=", numbytes[0], "bytes/sec=", byterate, flush=True) + if args.scratch: + return + # rank 0 merges all per-rank files if args.rank == 0: print("Merging rank files ...", flush=True) @@ -520,7 +573,7 @@ def rank_files_merge(args): for key in args.columns: infile = get_filename(args, key, rank) - print(f"Merging file {infile}", flush=True) +# print(f"Merging file {infile}", flush=True) builders[key].merge_file_(infile) # sum up the number of merged bytes From 9ba081be6c31fb0cf97092bdcca31fc2ee80e7e7 Mon Sep 17 00:00:00 2001 From: Adam Moody Date: Mon, 9 Aug 2021 18:06:00 -0700 Subject: [PATCH 02/74] handle case where some ranks might have 0 items --- megatron/data/indexed_dataset.py | 92 +++++++++++++++++++++++++++----- tools/preprocess_dataset_mpi.py | 29 ++++------ 2 files changed, 90 insertions(+), 31 deletions(-) diff --git a/megatron/data/indexed_dataset.py b/megatron/data/indexed_dataset.py index 1a0b92139..63c1bed31 100644 --- a/megatron/data/indexed_dataset.py +++ b/megatron/data/indexed_dataset.py @@ -593,20 +593,40 @@ def finalize(self, index_file): with 
MMapIndexedDataset.Index.writer(index_file, self._dtype) as index: index.write(self._sizes, self._doc_idx) + def mpi_get_sum(val, mpi, comm): + """Compute sum of val, and return total on all ranks.""" insize = np.array([val], dtype=np.int64) outsize = np.zeros_like(insize) comm.Allreduce(insize, outsize, op=mpi.SUM) return outsize[0] + def mpi_get_offset(val, mpi, comm): + """Compute preifx sum (exclusive scan) of val, and return offset of each rank.""" insize = np.array([val], dtype=np.int64) outsize = np.zeros_like(insize) comm.Scan(insize, outsize, op=mpi.SUM) offset = outsize[0] - insize[0] return offset + +def mpi_get_min(val, mpi, comm): + """Return minimum val to all ranks.""" + insize = np.array([val], dtype=np.int64) + outsize = np.zeros_like(insize) + comm.Allreduce(insize, outsize, op=mpi.MIN) + return outsize[0] + + +# To create the binary files given a set of per-rank binary +# files, one simply concatenates the data from the per-rank +# binary files in rank order. We stat each rank file to determine +# its size, execute a scan to compute the byte offset where +# the calling rank should write its data, seek to proper +# spot, and copy the full file. def merge_files_mpi_bin(outfile, infile, mpi, comm): + """Concatenate per-rank binary files into a new file given by outfile""" import stat import shutil @@ -627,7 +647,6 @@ def merge_files_mpi_bin(outfile, infile, mpi, comm): # compute offset this rank should start copying # its data into the merged file offset = mpi_get_offset(filesize, mpi, comm) - #print(rank, infile, filesize, offset) # seek to appropriate offset and copy data f.seek(offset) @@ -638,8 +657,10 @@ def merge_files_mpi_bin(outfile, infile, mpi, comm): comm.barrier() + def merge_files_mpi_idx(outfile, infile, mpi, comm, dtype): rank = comm.Get_rank() + numranks = comm.Get_size() comm.barrier() @@ -652,6 +673,7 @@ def merge_files_mpi_idx(outfile, infile, mpi, comm, dtype): if rank != 0: f = open(outfile, 'r+b') + # read the index file for the calling rank index = MMapIndexedDataset.Index(infile) sizes = index.sizes if rank == 0: @@ -659,6 +681,11 @@ def merge_files_mpi_idx(outfile, infile, mpi, comm, dtype): if rank != 0: docs = index.doc_idx[1:] + # Compute total number of size and document index + # values across all ranks. Also compute the offset + # of the calling rank for each value considering + # the values of sizes/docs for all ranks before the + # calling rank. numsizes = len(sizes) numdocs = len(docs) size_count = mpi_get_sum(numsizes, mpi, comm) @@ -667,48 +694,88 @@ def merge_files_mpi_idx(outfile, infile, mpi, comm, dtype): docs_offset = mpi_get_offset(numdocs, mpi, comm) # have rank 0 write the file header + pos = 0 if rank == 0: f.write(MMapIndexedDataset.Index._HDR_MAGIC) f.write(struct.pack(' 0: np.cumsum(pointers, axis=0, out=pointers) pointers *= dtype().itemsize pointer_last = pointers[-1] + + # Then account for bytes for all sentences on ranks + # before the calling rank. pointer_offset = mpi_get_offset(pointer_last, mpi, comm) pointers += pointer_offset - pointers_shift = 0 - if rank == 0 and len(sizes) > 0: - pointers_shift = pointers[0] - pointers_shift = mpi_get_sum(pointers_shift, mpi, comm) - pointers -= pointers_shift - + # Finally, zero-base the offset values by subtracting + # the number of bytes of the first sentence. To do that + # we first need to find the rank having the first sentence, + # then bcast that size to all ranks. 
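The code that follows implements this with a MIN reduction over ranks that hold at least one sentence, then uses a SUM reduction as a stand-in for a broadcast; the same pattern is later factored into a bcast_first helper (PATCH 13 below). A compact sketch of the idea, assuming mpi4py (names are illustrative):

    import numpy as np
    from mpi4py import MPI

    def bcast_from_first_nonempty(comm, has_data, value):
        """Broadcast value from the lowest rank where has_data is True, or None."""
        numranks = comm.Get_size()
        candidate = np.array([comm.Get_rank() if has_data else numranks], dtype=np.int64)
        lowest = np.zeros_like(candidate)
        comm.Allreduce(candidate, lowest, op=MPI.MIN)
        if lowest[0] == numranks:
            return None                            # every rank was empty
        return comm.bcast(value, root=int(lowest[0]))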
+ if size_count > 0: + # There is at least one sentence across all ranks, + # figure out which rank has the first sentence which + # is not necessarily rank 0. + minrank = numranks + if len(sizes) > 0: + minrank = rank + minrank = mpi_get_min(minrank, mpi, comm) + + # Broadcast size of the first sentence from minrank. + # We "bcast" using all allreduce. + pointers_shift = 0 + if minrank == rank: + pointers_shift = pointers[0] + pointers_shift = mpi_get_sum(pointers_shift, mpi, comm) + + # Zero-base pointers by subtracting size of first + # sentence from all values. + pointers -= pointers_shift + + # Seek to proper offset for this rank and write + # pointer values into file, stored as int64. f.seek(pos + size_offset * np.int64().itemsize) f.write(pointers.tobytes(order='C')) del pointers pos += size_count * np.int64().itemsize + # The document index points to the position in the sizes + # array for the starting sentence of each document. + # A variable number of sentences can be in each document. + # Adjust document index for number of sentences that + # come before the calling rank. doc_idx = np.array(docs, dtype=np.int64) doc_idx += size_offset + # Seek to proper offset for this rank and write + # document index into file, stored as int64. f.seek(pos + docs_offset * np.int64().itemsize) f.write(doc_idx.tobytes(order='C')) del doc_idx @@ -717,6 +784,7 @@ def merge_files_mpi_idx(outfile, infile, mpi, comm, dtype): comm.barrier() + def merge_files_mpi(filemain, filerank, mpi, comm, dtype=np.int64): merge_files_mpi_bin(data_file_path(filemain), data_file_path(filerank), mpi, comm) merge_files_mpi_idx(index_file_path(filemain), index_file_path(filerank), mpi, comm, dtype) diff --git a/tools/preprocess_dataset_mpi.py b/tools/preprocess_dataset_mpi.py index 28cfdd8d5..25c214efe 100644 --- a/tools/preprocess_dataset_mpi.py +++ b/tools/preprocess_dataset_mpi.py @@ -166,8 +166,8 @@ def get_args(): help='Select torch.distributed backend.') group.add_argument('--local_rank', type=int, default=None, help='Local rank of calling process on its node (from torch.distributed.launch).') - group.add_argument('--scratch', type=str, default='/dev/shm', - help='Path to local storage on compute nodes to write per-rank files before merging.') + group.add_argument('--scratch', type=str, default=None, + help='Path to local storage on compute nodes to write per-rank files before merging, like /dev/shm.') group.add_argument('--log-interval', type=int, default=30, help='Seconds between progress updates (0 to disable)') @@ -284,27 +284,11 @@ def load_dset(args): logging.set_verbosity(logging.ERROR) success = True + err = None dsetname = args.input - if not download: - try: - dset = load_dataset(dsetname, split=args.split, keep_in_memory=None) - except: - # this print might be noisy, but better than nothing - print("ERROR: Unexpected error:", sys.exc_info()[0]) - success = False - - # determine whether everyone succeeded in loading the dataset - success = all_true(args, success) - if not success: - return None - - return dset # Load the specified HuggingFace dataset. # Give rank 0 a head start in case the dataset is not already cached. 
- success = True - err = None - dsetname = args.input if args.rank == 0: print(f"Opening dataset {dsetname}") try: @@ -539,6 +523,13 @@ def rank_files_merge(args): binfile = data_file_path(filerank) numbytes[0] += os.stat(binfile)[stat.ST_SIZE] merge_files_mpi(filemain, filerank, args.MPI, args.mpi_comm, dtype=best_fitting_dtype(args.vocab_size)) + + # rename files for now and also do regular merge so we can time both and "cmp" them + if args.rank == 0: + binfile = data_file_path(filemain) + idxfile = index_file_path(filemain) + os.rename(binfile, binfile + ".par") + os.rename(idxfile, idxfile + ".par") barrier(args) all_sum_(args, numbytes) merge_end = time.time() From d29a702314991c7fe9aeff37f6fcda4f2d14fc9c Mon Sep 17 00:00:00 2001 From: Adam Moody Date: Wed, 11 Aug 2021 09:52:12 -0700 Subject: [PATCH 03/74] add inclusive scan prefix sum --- tools/preprocess_dataset_mpi.py | 14 ++++++++++++++ 1 file changed, 14 insertions(+) diff --git a/tools/preprocess_dataset_mpi.py b/tools/preprocess_dataset_mpi.py index 25c214efe..f3955c43e 100644 --- a/tools/preprocess_dataset_mpi.py +++ b/tools/preprocess_dataset_mpi.py @@ -252,6 +252,20 @@ def all_sum_(args, vals): tensor = torch.from_numpy(vals) dist.all_reduce(tensor, op=dist.ReduceOp.SUM) +def prefix_sum(args, val): + """Returns result of an inclusive scan prefix sum of val across ranks""" + if args.use_mpi: + inval = np.array([val], dtype=np.int64) + outval = np.zeros_like(inval) + args.mpi_comm.Scan(inval, outval, op=args.MPI.SUM) + return outval[0] + else: + # torch.distributed doesn't have a scan, so fallback to allreduce + tensor = torch.zeros(args.numranks, dtype=torch.int64) + tensor[args.rank:args.numranks] = val + dist.all_reduce(tensor, op=dist.ReduceOp.SUM) + return int(tensor[args.rank]) + def all_true(args, val): """Returns True if all procs input True, False otherwise""" if args.use_mpi: From ed4971325f1353784b642dbed553e9ae0484fb7f Mon Sep 17 00:00:00 2001 From: Adam Moody Date: Wed, 11 Aug 2021 12:25:24 -0700 Subject: [PATCH 04/74] report more timing info --- megatron/data/indexed_dataset.py | 52 +++++++++----- tools/preprocess_dataset_mpi.py | 116 ++++++++++++++++++++++--------- 2 files changed, 115 insertions(+), 53 deletions(-) diff --git a/megatron/data/indexed_dataset.py b/megatron/data/indexed_dataset.py index 63c1bed31..417858a7f 100644 --- a/megatron/data/indexed_dataset.py +++ b/megatron/data/indexed_dataset.py @@ -619,6 +619,26 @@ def mpi_get_min(val, mpi, comm): return outsize[0] +def mpi_create_file(filename, mpi, comm): + """Create, truncate, and open a file shared by all ranks.""" + # Don't truncate file until all ranks reach this point + comm.barrier() + + # Wait for rank 0 to open (and truncate) file, + # then have all ranks open file for writing. + rank = comm.Get_rank() + if rank == 0: + f = open(filename, 'wb') + comm.barrier() + if rank != 0: + f = open(filename, 'r+b') + + # TODO: verify that all ranks successfully opened the file + comm.barrier() + + return f + + # To create the binary files given a set of per-rank binary # files, one simply concatenates the data from the per-rank # binary files in rank order. We stat each rank file to determine @@ -632,14 +652,8 @@ def merge_files_mpi_bin(outfile, infile, mpi, comm): comm.barrier() - # wait for rank 0 to open (and truncate) file, - # then have all ranks open file for writing - rank = comm.Get_rank() - if rank == 0: - f = open(outfile, 'wb') - comm.barrier() - if rank != 0: - f = open(outfile, 'r+b') + # Create shared output file. 
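A side note on the prefix_sum helper added above (PATCH 03): torch.distributed has no native scan, so it emulates one with a single all_reduce in which rank i contributes its value to slots i..numranks-1 of a zero tensor and then reads back slot i. The toy simulation below shows why that yields an inclusive prefix sum; it is plain NumPy with the ranks unrolled in a loop, purely for illustration:

    import numpy as np

    vals = [5, 0, 7, 3]                      # per-rank inputs
    numranks = len(vals)

    # Each "rank" contributes its value to slots [rank:] of a zero vector ...
    contribs = np.zeros((numranks, numranks), dtype=np.int64)
    for rank, val in enumerate(vals):
        contribs[rank, rank:] = val

    # ... the all_reduce(SUM) adds those vectors elementwise ...
    reduced = contribs.sum(axis=0)           # -> [5, 5, 12, 15]

    # ... and rank i reads slot i, which is the inclusive prefix sum.
    inclusive = [int(reduced[r]) for r in range(numranks)]
    exclusive = [inc - v for inc, v in zip(inclusive, vals)]
    assert inclusive == [5, 5, 12, 15]
    assert exclusive == [0, 5, 5, 12]        # what an exclusive scan would return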
+ f = mpi_create_file(outfile, mpi, comm) # get file size of binary file for this rank filesize = os.stat(infile)[stat.ST_SIZE] @@ -655,6 +669,7 @@ def merge_files_mpi_bin(outfile, infile, mpi, comm): f.close() + # TODO: check that all ranks wrote successfully comm.barrier() @@ -664,16 +679,10 @@ def merge_files_mpi_idx(outfile, infile, mpi, comm, dtype): comm.barrier() - # wait for rank 0 to open (and truncate) file, - # then have all ranks open file for writing - rank = comm.Get_rank() - if rank == 0: - f = open(outfile, 'wb') - comm.barrier() - if rank != 0: - f = open(outfile, 'r+b') + # Create shared output file + f = mpi_create_file(outfile, mpi, comm) - # read the index file for the calling rank + # Read the index file for the calling rank index = MMapIndexedDataset.Index(infile) sizes = index.sizes if rank == 0: @@ -693,7 +702,7 @@ def merge_files_mpi_idx(outfile, infile, mpi, comm, dtype): size_offset = mpi_get_offset(numsizes, mpi, comm) docs_offset = mpi_get_offset(numdocs, mpi, comm) - # have rank 0 write the file header + # Have rank 0 write the file header pos = 0 if rank == 0: f.write(MMapIndexedDataset.Index._HDR_MAGIC) @@ -703,7 +712,8 @@ def merge_files_mpi_idx(outfile, infile, mpi, comm, dtype): f.write(struct.pack(' 0.0 else 0.0 sentrate = dset_stats[1] / secs if secs > 0.0 else 0.0 byterate = dset_stats[2] / secs if secs > 0.0 else 0.0 - print("Tokenize stats:", secs) - print(f" Seconds to tokenize: {secs}") + print("Process stats:") + print(f" Seconds to process: {secs}") print(f" {dset_stats[0]} docs {docrate} docs/sec") print(f" {dset_stats[1]} sents {sentrate} sents/sec") print(f" {dset_stats[2]} bytes {format_byterate(byterate)}") + print(f" Total read seconds {times[0]}, {times[0]/dset_stats[0]} sec/sample") + print(f" Total encode seconds {times[1]}, {times[1]/dset_stats[0]} sec/sample") + print(f" Total write seconds {times[2]}, {times[2]/dset_stats[0]} sec/sample") # allreduce to check whether all ranks wrote their part successfully success = all_true(args, success) return success, err def rank_files_merge(args): - merge_start = time.time() - numbytes = np.zeros(1, dtype=np.int64) - for key in args.columns: - filemain = get_filename(args, key) - filerank = get_filename(args, key, args.rank) - binfile = data_file_path(filerank) - numbytes[0] += os.stat(binfile)[stat.ST_SIZE] - merge_files_mpi(filemain, filerank, args.MPI, args.mpi_comm, dtype=best_fitting_dtype(args.vocab_size)) - - # rename files for now and also do regular merge so we can time both and "cmp" them + # only try parallel merge when using MPI + if args.use_mpi: + merge_start = time.time() + numbytes = np.zeros(1, dtype=np.int64) + for key in args.columns: + filemain = get_filename(args, key) + filerank = get_filename(args, key, args.rank) + merge_files_mpi(filemain, filerank, args.MPI, args.mpi_comm, dtype=best_fitting_dtype(args.vocab_size)) + + # total up bytes read in merge + binfile = data_file_path(filerank) + idxfile = data_file_path(filerank) + numbytes[0] += os.stat(binfile)[stat.ST_SIZE] + numbytes[0] += os.stat(idxfile)[stat.ST_SIZE] + + # rename files for now and also do regular merge so we can time both and "cmp" them + if args.rank == 0: + binfile = data_file_path(filemain) + idxfile = index_file_path(filemain) + os.rename(binfile, binfile + ".par") + os.rename(idxfile, idxfile + ".par") + + barrier(args) + all_sum_(args, numbytes) + merge_end = time.time() if args.rank == 0: - binfile = data_file_path(filemain) - idxfile = index_file_path(filemain) - os.rename(binfile, binfile + 
".par") - os.rename(idxfile, idxfile + ".par") - barrier(args) - all_sum_(args, numbytes) - merge_end = time.time() - if args.rank == 0: - secs = merge_end - merge_start - byterate = numbytes[0] / secs if secs > 0.0 else 0.0 - print("Seconds to merge (parallel):", secs, flush=True) - print("Bytes=", numbytes[0], "bytes/sec=", byterate, flush=True) - if args.scratch: - return + secs = merge_end - merge_start + byterate = numbytes[0] / secs if secs > 0.0 else 0.0 + print("Parallel merge stats:") + print(f" Scratch: {args.scratch}") + print(f" Seconds to merge: {secs}") + print(f" {int(numbytes)} bytes {format_byterate(byterate)}") + + # if using node-local storage, skip sequential merge test + if args.scratch: + return # rank 0 merges all per-rank files if args.rank == 0: @@ -658,7 +706,7 @@ def main(): barrier(args) startup_end = time.time() if args.rank == 0: - print("Seconds to startup:", startup_end - startup_start) + print(f"Seconds to startup: {startup_end - startup_start}") # have each rank write its file, returns False if any rank had a problem success, err = rank_files_write(args, dset, idx, encoder) From e94f2a0cea0fd1188f80233520c1ade16a371768 Mon Sep 17 00:00:00 2001 From: Adam Moody Date: Thu, 12 Aug 2021 13:58:13 -0500 Subject: [PATCH 05/74] Update megatron/data/indexed_dataset.py Co-authored-by: Thomas Wang <24695242+thomasw21@users.noreply.github.com> --- megatron/data/indexed_dataset.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/megatron/data/indexed_dataset.py b/megatron/data/indexed_dataset.py index 417858a7f..5efca3b12 100644 --- a/megatron/data/indexed_dataset.py +++ b/megatron/data/indexed_dataset.py @@ -603,7 +603,7 @@ def mpi_get_sum(val, mpi, comm): def mpi_get_offset(val, mpi, comm): - """Compute preifx sum (exclusive scan) of val, and return offset of each rank.""" + """Compute prefix sum (exclusive scan) of val, and return offset of each rank.""" insize = np.array([val], dtype=np.int64) outsize = np.zeros_like(insize) comm.Scan(insize, outsize, op=mpi.SUM) From 687ff32f347f47b960bac2edd07b0f8940b3a1bc Mon Sep 17 00:00:00 2001 From: Adam Moody Date: Thu, 12 Aug 2021 14:11:49 -0500 Subject: [PATCH 06/74] Update megatron/data/indexed_dataset.py Co-authored-by: Thomas Wang <24695242+thomasw21@users.noreply.github.com> --- megatron/data/indexed_dataset.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/megatron/data/indexed_dataset.py b/megatron/data/indexed_dataset.py index 5efca3b12..0d5c20439 100644 --- a/megatron/data/indexed_dataset.py +++ b/megatron/data/indexed_dataset.py @@ -697,8 +697,8 @@ def merge_files_mpi_idx(outfile, infile, mpi, comm, dtype): # calling rank. 
numsizes = len(sizes) numdocs = len(docs) - size_count = mpi_get_sum(numsizes, mpi, comm) - docs_count = mpi_get_sum(numdocs, mpi, comm) + total_size_count = mpi_get_sum(numsizes, mpi, comm) + total_docs_count = mpi_get_sum(numdocs, mpi, comm) size_offset = mpi_get_offset(numsizes, mpi, comm) docs_offset = mpi_get_offset(numdocs, mpi, comm) From af595454a4f2f6bc7dd5ee509d84d2adeafab151 Mon Sep 17 00:00:00 2001 From: Adam Moody Date: Thu, 12 Aug 2021 12:26:49 -0700 Subject: [PATCH 07/74] rename total size variable for clarity --- megatron/data/indexed_dataset.py | 14 +++++++++----- 1 file changed, 9 insertions(+), 5 deletions(-) diff --git a/megatron/data/indexed_dataset.py b/megatron/data/indexed_dataset.py index 0d5c20439..3b1737693 100644 --- a/megatron/data/indexed_dataset.py +++ b/megatron/data/indexed_dataset.py @@ -708,8 +708,8 @@ def merge_files_mpi_idx(outfile, infile, mpi, comm, dtype): f.write(MMapIndexedDataset.Index._HDR_MAGIC) f.write(struct.pack(' 0: + if total_size_count > 0: # There is at least one sentence across all ranks, # figure out which rank has the first sentence which # is not necessarily rank 0. @@ -774,7 +776,9 @@ def merge_files_mpi_idx(outfile, infile, mpi, comm, dtype): f.seek(pos + size_offset * np.int64().itemsize) f.write(pointers.tobytes(order='C')) del pointers - pos += size_count * np.int64().itemsize + + # Advance past list of pointer values + pos += total_size_count * np.int64().itemsize # The document index points to the position in the sizes # array for the starting sentence of each document. From 4f648a0ae930ba0e324f4a14ae523483da6c9333 Mon Sep 17 00:00:00 2001 From: Adam Moody Date: Fri, 13 Aug 2021 11:27:44 -0700 Subject: [PATCH 08/74] move translation to bin/idx file names a level deeper --- megatron/data/indexed_dataset.py | 18 +++++++----------- tools/preprocess_dataset_mpi.py | 2 +- 2 files changed, 8 insertions(+), 12 deletions(-) diff --git a/megatron/data/indexed_dataset.py b/megatron/data/indexed_dataset.py index 3b1737693..70037d8d6 100644 --- a/megatron/data/indexed_dataset.py +++ b/megatron/data/indexed_dataset.py @@ -650,13 +650,11 @@ def merge_files_mpi_bin(outfile, infile, mpi, comm): import stat import shutil - comm.barrier() - # Create shared output file. 
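For reference, the prefix-to-filename translation that PATCH 08 pushes down into the merge routines (visible in the hunk just below) is handled by two small helpers that already exist in indexed_dataset.py and map a common prefix to the paired data and index files, effectively:

    def data_file_path(prefix_path):
        return prefix_path + '.bin'

    def index_file_path(prefix_path):
        return prefix_path + '.idx'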
- f = mpi_create_file(outfile, mpi, comm) + f = mpi_create_file(data_file_path(outfile), mpi, comm) # get file size of binary file for this rank - filesize = os.stat(infile)[stat.ST_SIZE] + filesize = os.stat(data_file_path(infile))[stat.ST_SIZE] # compute offset this rank should start copying # its data into the merged file @@ -664,7 +662,7 @@ def merge_files_mpi_bin(outfile, infile, mpi, comm): # seek to appropriate offset and copy data f.seek(offset) - with open(infile, "rb") as fsrc: + with open(data_file_path(infile), "rb") as fsrc: shutil.copyfileobj(fsrc, f) f.close() @@ -677,13 +675,11 @@ def merge_files_mpi_idx(outfile, infile, mpi, comm, dtype): rank = comm.Get_rank() numranks = comm.Get_size() - comm.barrier() - # Create shared output file - f = mpi_create_file(outfile, mpi, comm) + f = mpi_create_file(index_file_path(outfile), mpi, comm) # Read the index file for the calling rank - index = MMapIndexedDataset.Index(infile) + index = MMapIndexedDataset.Index(index_file_path(infile)) sizes = index.sizes if rank == 0: docs = index.doc_idx @@ -802,7 +798,7 @@ def merge_files_mpi_idx(outfile, infile, mpi, comm, dtype): def merge_files_mpi(filemain, filerank, mpi, comm, dtype=np.int64): # Concatenate the data files - merge_files_mpi_bin(data_file_path(filemain), data_file_path(filerank), mpi, comm) + merge_files_mpi_bin(filemain, filerank, mpi, comm) # Combine index files into a single index file - merge_files_mpi_idx(index_file_path(filemain), index_file_path(filerank), mpi, comm, dtype) + merge_files_mpi_idx(filemain, filerank, mpi, comm, dtype) diff --git a/tools/preprocess_dataset_mpi.py b/tools/preprocess_dataset_mpi.py index 19d6cfdf9..6504e6589 100644 --- a/tools/preprocess_dataset_mpi.py +++ b/tools/preprocess_dataset_mpi.py @@ -577,7 +577,7 @@ def rank_files_merge(args): # total up bytes read in merge binfile = data_file_path(filerank) - idxfile = data_file_path(filerank) + idxfile = index_file_path(filerank) numbytes[0] += os.stat(binfile)[stat.ST_SIZE] numbytes[0] += os.stat(idxfile)[stat.ST_SIZE] From 9f2ba6ae3204d50576740a9516fe2f7fc5a9e9a8 Mon Sep 17 00:00:00 2001 From: Adam Moody Date: Fri, 13 Aug 2021 15:17:49 -0700 Subject: [PATCH 09/74] parallel merge for cached dataset --- megatron/data/indexed_dataset.py | 140 ++++++++++++++++++++++++++++++- tools/preprocess_dataset_mpi.py | 6 +- 2 files changed, 142 insertions(+), 4 deletions(-) diff --git a/megatron/data/indexed_dataset.py b/megatron/data/indexed_dataset.py index 70037d8d6..944903f8c 100644 --- a/megatron/data/indexed_dataset.py +++ b/megatron/data/indexed_dataset.py @@ -671,7 +671,129 @@ def merge_files_mpi_bin(outfile, infile, mpi, comm): comm.barrier() -def merge_files_mpi_idx(outfile, infile, mpi, comm, dtype): +def merge_files_mpi_idx_cached(outfile, infile, mpi, comm, dtype): + rank = comm.Get_rank() + numranks = comm.Get_size() + + # Create shared output file + f = mpi_create_file(index_file_path(outfile), mpi, comm) + + # Read the index file for the calling rank + index = IndexedDataset(infile) + sizes = index.sizes + if rank == 0: + data_offsets = index.data_offsets + dim_offsets = index.dim_offsets + docs = index.doc_idx + if rank != 0: + data_offsets = index.data_offsets[1:] + dim_offsets = index.dim_offsets[1:] + docs = index.doc_idx[1:] + + # Compute total number of size and document index + # values across all ranks. Also compute the offset + # of the calling rank for each value considering + # the values of sizes/docs for all ranks before the + # calling rank. 
+ numdata = len(data_offsets) + numsize = len(sizes) + numdim = len(dim_offsets) + numdoc = len(docs) + + global_data_count = mpi_get_sum(numdata, mpi, comm) + global_size_count = mpi_get_sum(numsize, mpi, comm) + global_dim_count = mpi_get_sum(numdim, mpi, comm) + global_doc_count = mpi_get_sum(numdoc, mpi, comm) + + global_data_offset = mpi_get_offset(numdata, mpi, comm) + global_size_offset = mpi_get_offset(numsize, mpi, comm) + global_dim_offset = mpi_get_offset(numdim, mpi, comm) + global_doc_offset = mpi_get_offset(numdoc, mpi, comm) + + # Have rank 0 write the file header + pos = 0 + if rank == 0: + f.write(IndexedDataset._HDR_MAGIC) + f.write(struct.pack(' 0 else 0 + dim_offset = mpi_get_offset(dim_last, mpi, comm) + dim_offsets64 += dim_offset + + # Seek to proper offset for this rank and write + # dom offset values into file, stored as int64. + f.seek(pos + global_dim_offset * np.int64().itemsize) + f.write(dim_offsets64.tobytes(order='C')) + del dim_offsets64 + + # Advance past list of dim offset values + pos += global_dim_count * np.int64().itemsize + + # The data index records the byte offset to the start of each + # sentence within the binary data file. + # Adjust our data index values for number of bytes that + # come before the calling rank. + data_offsets64 = np.array(data_offsets, dtype=np.int64) + byte_last = data_offsets[-1] if numdata > 0 else 0 + byte_offset = mpi_get_offset(byte_last, mpi, comm) + data_offsets64 += byte_offset + + # Seek to proper offset for this rank and write + # data (byte) offset index into file, stored as int64. + f.seek(pos + global_data_offset * np.int64().itemsize) + f.write(data_offsets64.tobytes(order='C')) + del data_offsets64 + + # Advance past list of data (byte) offset values + pos += global_data_count * np.int64().itemsize + + # Each sentence is stored as a tensor. + # The tensor for each sentence can be multidimensional. + # The number of tensor dimensions per sentence is variable, + # and the size of each dimension of a sentence is arbitrary. + # The size list records a flattened list of the sizes + # for each dimension of a sentence. + # The list of size values from each rank are + # concatenated and stored as int64. + f.seek(pos + global_size_offset * np.int64().itemsize) + sizes64 = np.array(sizes, dtype=np.int64) + f.write(sizes64.tobytes(order='C')) + del sizes64 + + # Advance past list of size values + pos += global_size_count * np.int64().itemsize + + # The document index points to the position in the sizes + # array for the first sentence of the sample. + docs64 = np.array(docs, dtype=np.int64) + docs64 += global_size_offset + + # Seek to proper offset for this rank and write + # document index into file, stored as int64. 
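The sequence of seeks above implies a simple sectioned layout for the merged "cached" index: the header, then the dim_offsets, data_offsets, sizes, and doc_idx arrays, each stored as int64 and each written cooperatively at rank-specific offsets. A rough sketch of the bookkeeping, with made-up counts and an assumed header length, purely for illustration:

    ITEM = 8                          # every index entry is written as int64
    header_len = 56                   # assumed size of the header rank 0 writes
    global_dim_count, global_data_count = 7, 6
    global_size_count, global_doc_count = 9, 4

    dim_start  = header_len                                # dim_offsets section
    data_start = dim_start  + global_dim_count  * ITEM     # data_offsets section
    size_start = data_start + global_data_count * ITEM     # sizes section
    doc_start  = size_start + global_size_count * ITEM     # doc_idx section

    # A rank whose slice begins at local offset k within a section writes at
    # section_start + k * ITEM, which is what the f.seek() calls above compute.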
+ f.seek(pos + global_doc_offset * np.int64().itemsize) + f.write(docs64.tobytes(order='C')) + del docs64 + + f.close() + + # TODO: check that all ranks wrote successfully + comm.barrier() + + +def merge_files_mpi_idx_mmap(outfile, infile, mpi, comm, dtype): rank = comm.Get_rank() numranks = comm.Get_size() @@ -797,8 +919,22 @@ def merge_files_mpi_idx(outfile, infile, mpi, comm, dtype): def merge_files_mpi(filemain, filerank, mpi, comm, dtype=np.int64): + # read index file of this rank to determine its type + indexstr = infer_dataset_impl(filerank) + + # check that all ranks have the same type + indexstrmap = {"cached": 1, "mmap": 2} + indextype = indexstrmap[indexstr] if indexstr in indexstrmap else 0 + rank0type = comm.bcast(indextype, root=0) + #allsame = all_true(indextype == rank0type) + #if not allsame: + # error + # Concatenate the data files merge_files_mpi_bin(filemain, filerank, mpi, comm) # Combine index files into a single index file - merge_files_mpi_idx(filemain, filerank, mpi, comm, dtype) + if indexstr == "cached": + merge_files_mpi_idx_cached(filemain, filerank, mpi, comm, dtype) + elif indexstr == "mmap": + merge_files_mpi_idx_mmap(filemain, filerank, mpi, comm, dtype) diff --git a/tools/preprocess_dataset_mpi.py b/tools/preprocess_dataset_mpi.py index 6504e6589..c0e1defe3 100644 --- a/tools/preprocess_dataset_mpi.py +++ b/tools/preprocess_dataset_mpi.py @@ -475,9 +475,10 @@ def rank_files_write(args, dset, idx, encoder): filebase = get_filename(args, key, args.rank) output_bin_files[key] = data_file_path(filebase) output_idx_files[key] = index_file_path(filebase) + best_dtype = best_fitting_dtype(args.vocab_size) if args.dataset_impl == "mmap" else None builders[key] = make_builder(output_bin_files[key], impl=args.dataset_impl, - dtype=best_fitting_dtype(args.vocab_size)) + dtype=best_dtype) # divide index list evenly among ranks idx_start, idx_end = get_start_end(len(idx), args.rank, args.numranks) @@ -617,9 +618,10 @@ def rank_files_merge(args): filebase = get_filename(args, key) output_bin_files[key] = data_file_path(filebase) output_idx_files[key] = index_file_path(filebase) + best_dtype = best_fitting_dtype(args.vocab_size) if args.dataset_impl == "mmap" else None builders[key] = make_builder(output_bin_files[key], impl=args.dataset_impl, - dtype=best_fitting_dtype(args.vocab_size)) + dtype=best_dtype) # merge all ranks into one file for rank in range(args.numranks): From 72d6c9c276cea6b7839989f6bc47848d85800826 Mon Sep 17 00:00:00 2001 From: Adam Moody Date: Fri, 13 Aug 2021 15:38:39 -0700 Subject: [PATCH 10/74] add alltrue function --- megatron/data/indexed_dataset.py | 61 ++++++++++++++++++++++++-------- 1 file changed, 47 insertions(+), 14 deletions(-) diff --git a/megatron/data/indexed_dataset.py b/megatron/data/indexed_dataset.py index 944903f8c..861271d23 100644 --- a/megatron/data/indexed_dataset.py +++ b/megatron/data/indexed_dataset.py @@ -619,22 +619,53 @@ def mpi_get_min(val, mpi, comm): return outsize[0] +def mpi_alltrue(val, mpi, comm): + """Returns True if all procs input True, False otherwise""" + inval = np.array([val], dtype=np.bool_) + outval = np.zeros_like(inval) + comm.Allreduce(inval, outval, op=mpi.LAND) + return bool(outval[0]) + + def mpi_create_file(filename, mpi, comm): """Create, truncate, and open a file shared by all ranks.""" - # Don't truncate file until all ranks reach this point + success = True + err = None + + # Don't truncate existing file until all ranks reach this point comm.barrier() - # Wait for rank 0 to open (and truncate) 
file, - # then have all ranks open file for writing. + # Rank 0 creates and truncates file. rank = comm.Get_rank() if rank == 0: - f = open(filename, 'wb') - comm.barrier() - if rank != 0: - f = open(filename, 'r+b') + try: + f = open(filename, 'wb') + except Exception as e: + success = False + err = e + + # Verify that rank 0 created the file + success = mpi_alltrue(success, mpi, comm) + if not success: + if err is not None: + raise err + return None - # TODO: verify that all ranks successfully opened the file - comm.barrier() + # Wait for rank 0 to open (and truncate) file, + # then have all ranks open file for writing. + if rank != 0: + try: + f = open(filename, 'r+b') + except Exception as e: + success = False + err = e + + # Verify that all ranks successfully opened the file + success = mpi_alltrue(success, mpi, comm) + if not success: + if err is not None: + raise err + return None return f @@ -919,16 +950,18 @@ def merge_files_mpi_idx_mmap(outfile, infile, mpi, comm, dtype): def merge_files_mpi(filemain, filerank, mpi, comm, dtype=np.int64): - # read index file of this rank to determine its type + # read header of index file of this rank to determine its type indexstr = infer_dataset_impl(filerank) - # check that all ranks have the same type + # map type string to an integer for easier bcast indexstrmap = {"cached": 1, "mmap": 2} indextype = indexstrmap[indexstr] if indexstr in indexstrmap else 0 + + # check that all ranks have the same type rank0type = comm.bcast(indextype, root=0) - #allsame = all_true(indextype == rank0type) - #if not allsame: - # error + sametype = mpi_alltrue(indextype == rank0type and rank0type != 0, mpi, comm) + if not sametype: + assert False, "Cannot merge dataset files of different types" # Concatenate the data files merge_files_mpi_bin(filemain, filerank, mpi, comm) From 8b67becb26010c90bc024259fe1120cb799b8978 Mon Sep 17 00:00:00 2001 From: Adam Moody Date: Fri, 13 Aug 2021 20:12:32 -0700 Subject: [PATCH 11/74] move collectives to new distdata class, add torch.distributed --- megatron/data/distdata.py | 143 +++++++++++++++++++++++++ megatron/data/indexed_dataset.py | 173 +++++++++---------------------- tools/preprocess_dataset_mpi.py | 13 ++- 3 files changed, 199 insertions(+), 130 deletions(-) create mode 100644 megatron/data/distdata.py diff --git a/megatron/data/distdata.py b/megatron/data/distdata.py new file mode 100644 index 000000000..ed0d63668 --- /dev/null +++ b/megatron/data/distdata.py @@ -0,0 +1,143 @@ +import numpy as np +import torch +import torch.distributed as dist + +class DistData(object): + def __init__(self, mpi=None): + self.MPI = mpi + + if self.MPI: + self.comm = mpi.COMM_WORLD + self.rank = self.comm.Get_rank() + self.numranks = self.comm.Get_rank() + else: + self.rank = dist.get_rank() + self.numranks = dist.get_world_size() + + def barrier(self): + """Globally synchronize all processes""" + if self.MPI: + self.comm.barrier() + else: + dist.barrier() + + def bcast(self, val, root): + """Broadcast a scalar value from root to all ranks""" + if self.MPI: + return self.comm.bcast(val, root=root) + else: + vals = [val] + dist.broadcast_object_list(vals, src=root) + return vals[0] + + def bcast_list(self, vals, root=0): + """Broadcast list of vals from root to all ranks, returns newly allocated list""" + if self.MPI: + return self.comm.bcast(vals, root=root) + else: + # broadcast length of vals list + length = [len(vals)] + dist.broadcast_object_list(length, src=root) + + # allocate a tensor of appropriate size + # initialize tensor 
with list values on root + if self.rank == root: + tvals = torch.tensor(vals, dtype=torch.int64) + else: + tvals = torch.zeros(length[0], dtype=torch.int64) + + # broadcast tensor from root, and return as a new list + dist.broadcast(tvals, src=root) + return tvals.tolist() + + def alltrue(self, val): + """Returns True if all procs input True, False otherwise""" + if self.MPI: + inval = np.array([val], dtype=np.bool_) + outval = np.zeros_like(inval) + self.comm.Allreduce(inval, outval, op=self.MPI.LAND) + return bool(outval[0]) + else: + tensor = torch.tensor([int(val)], dtype=torch.int32) + dist.all_reduce(tensor, op=dist.ReduceOp.BAND) + return bool(tensor[0]) + + def sum(self, val): + """Compute sum of val, and return total on all ranks.""" + if self.MPI: + insize = np.array([val], dtype=np.int64) + outsize = np.zeros_like(insize) + self.comm.Allreduce(insize, outsize, op=self.MPI.SUM) + return outsize[0] + else: + tensor = torch.tensor([val], dtype=torch.int64) + dist.all_reduce(tensor, op=dist.ReduceOp.SUM) + return tensor[0] + + def exscan(self, val): + """Compute prefix sum (exclusive scan) of val, and return offset of each rank.""" + if self.MPI: + insize = np.array([val], dtype=np.int64) + outsize = np.zeros_like(insize) + self.comm.Scan(insize, outsize, op=self.MPI.SUM) + return outsize[0] - insize[0] + else: + # torch.distributed doesn't have a scan, so fallback to allreduce + tensor = torch.zeros(self.numranks, dtype=torch.int64) + tensor[self.rank:] = val + dist.all_reduce(tensor, op=dist.ReduceOp.SUM) + return int(tensor[self.rank]) - val + + def min(self, val): + """Return minimum val to all ranks.""" + if self.MPI: + insize = np.array([val], dtype=np.int64) + outsize = np.zeros_like(insize) + self.comm.Allreduce(insize, outsize, op=self.MPI.MIN) + return outsize[0] + else: + # torch.distributed doesn't have a scan, so fallback to allreduce + tensor = torch.tensor([val], dtype=torch.int64) + dist.all_reduce(tensor, op=dist.ReduceOp.MIN) + return tensor[0] + + def open(self, filename): + """Create, truncate, and open a file shared by all ranks.""" + success = True + err = None + + # Don't truncate existing file until all ranks reach this point + self.barrier() + + # Rank 0 creates and truncates file. + if self.rank == 0: + try: + f = open(filename, 'wb') + except Exception as e: + success = False + err = e + + # Verify that rank 0 created the file + success = self.alltrue(success) + if not success: + if err is not None: + raise err + return None + + # Wait for rank 0 to open (and truncate) file, + # then have all ranks open file for writing. 
+ if self.rank != 0: + try: + f = open(filename, 'r+b') + except Exception as e: + success = False + err = e + + # Verify that all ranks successfully opened the file + success = self.alltrue(success) + if not success: + if err is not None: + raise err + return None + + return f diff --git a/megatron/data/indexed_dataset.py b/megatron/data/indexed_dataset.py index 861271d23..b1a37d904 100644 --- a/megatron/data/indexed_dataset.py +++ b/megatron/data/indexed_dataset.py @@ -594,102 +594,26 @@ def finalize(self, index_file): index.write(self._sizes, self._doc_idx) -def mpi_get_sum(val, mpi, comm): - """Compute sum of val, and return total on all ranks.""" - insize = np.array([val], dtype=np.int64) - outsize = np.zeros_like(insize) - comm.Allreduce(insize, outsize, op=mpi.SUM) - return outsize[0] - - -def mpi_get_offset(val, mpi, comm): - """Compute prefix sum (exclusive scan) of val, and return offset of each rank.""" - insize = np.array([val], dtype=np.int64) - outsize = np.zeros_like(insize) - comm.Scan(insize, outsize, op=mpi.SUM) - offset = outsize[0] - insize[0] - return offset - - -def mpi_get_min(val, mpi, comm): - """Return minimum val to all ranks.""" - insize = np.array([val], dtype=np.int64) - outsize = np.zeros_like(insize) - comm.Allreduce(insize, outsize, op=mpi.MIN) - return outsize[0] - - -def mpi_alltrue(val, mpi, comm): - """Returns True if all procs input True, False otherwise""" - inval = np.array([val], dtype=np.bool_) - outval = np.zeros_like(inval) - comm.Allreduce(inval, outval, op=mpi.LAND) - return bool(outval[0]) - - -def mpi_create_file(filename, mpi, comm): - """Create, truncate, and open a file shared by all ranks.""" - success = True - err = None - - # Don't truncate existing file until all ranks reach this point - comm.barrier() - - # Rank 0 creates and truncates file. - rank = comm.Get_rank() - if rank == 0: - try: - f = open(filename, 'wb') - except Exception as e: - success = False - err = e - - # Verify that rank 0 created the file - success = mpi_alltrue(success, mpi, comm) - if not success: - if err is not None: - raise err - return None - - # Wait for rank 0 to open (and truncate) file, - # then have all ranks open file for writing. - if rank != 0: - try: - f = open(filename, 'r+b') - except Exception as e: - success = False - err = e - - # Verify that all ranks successfully opened the file - success = mpi_alltrue(success, mpi, comm) - if not success: - if err is not None: - raise err - return None - - return f - - # To create the binary files given a set of per-rank binary # files, one simply concatenates the data from the per-rank # binary files in rank order. We stat each rank file to determine # its size, execute a scan to compute the byte offset where # the calling rank should write its data, seek to proper # spot, and copy the full file. -def merge_files_mpi_bin(outfile, infile, mpi, comm): +def merge_files_dist_bin(outfile, infile, distctx): """Concatenate per-rank binary files into a new file given by outfile""" import stat import shutil # Create shared output file. 
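In the hunks that follow, distctx is an instance of the new DistData wrapper introduced above. A minimal usage sketch, assuming the mpi4py path (the torch.distributed path instead requires an initialized process group; the file name and byte counts are illustrative):

    from mpi4py import MPI
    from megatron.data.distdata import DistData

    distctx = DistData(mpi=MPI)

    nbytes = 1024 * (distctx.rank + 1)   # pretend this rank holds this many bytes
    offset = distctx.exscan(nbytes)      # where this rank writes in the shared file
    total  = distctx.sum(nbytes)         # size of the fully merged file

    if distctx.alltrue(True):            # collective success check
        f = distctx.open("merged.bin")   # rank 0 truncates, then all ranks open
        f.seek(offset)
        f.write(bytes(nbytes))           # placeholder payload
        f.close()
    distctx.barrier()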
- f = mpi_create_file(data_file_path(outfile), mpi, comm) + f = distctx.open(data_file_path(outfile)) # get file size of binary file for this rank filesize = os.stat(data_file_path(infile))[stat.ST_SIZE] # compute offset this rank should start copying # its data into the merged file - offset = mpi_get_offset(filesize, mpi, comm) + offset = distctx.exscan(filesize) # seek to appropriate offset and copy data f.seek(offset) @@ -699,15 +623,15 @@ def merge_files_mpi_bin(outfile, infile, mpi, comm): f.close() # TODO: check that all ranks wrote successfully - comm.barrier() + distctx.barrier() -def merge_files_mpi_idx_cached(outfile, infile, mpi, comm, dtype): - rank = comm.Get_rank() - numranks = comm.Get_size() +def merge_files_dist_idx_cached(outfile, infile, distctx, dtype): + rank = distctx.rank + numranks = distctx.numranks # Create shared output file - f = mpi_create_file(index_file_path(outfile), mpi, comm) + f = distctx.open(index_file_path(outfile)) # Read the index file for the calling rank index = IndexedDataset(infile) @@ -731,15 +655,15 @@ def merge_files_mpi_idx_cached(outfile, infile, mpi, comm, dtype): numdim = len(dim_offsets) numdoc = len(docs) - global_data_count = mpi_get_sum(numdata, mpi, comm) - global_size_count = mpi_get_sum(numsize, mpi, comm) - global_dim_count = mpi_get_sum(numdim, mpi, comm) - global_doc_count = mpi_get_sum(numdoc, mpi, comm) + global_data_count = distctx.sum(numdata) + global_size_count = distctx.sum(numsize) + global_dim_count = distctx.sum(numdim) + global_doc_count = distctx.sum(numdoc) - global_data_offset = mpi_get_offset(numdata, mpi, comm) - global_size_offset = mpi_get_offset(numsize, mpi, comm) - global_dim_offset = mpi_get_offset(numdim, mpi, comm) - global_doc_offset = mpi_get_offset(numdoc, mpi, comm) + global_data_offset = distctx.exscan(numdata) + global_size_offset = distctx.exscan(numsize) + global_dim_offset = distctx.exscan(numdim) + global_doc_offset = distctx.exscan(numdoc) # Have rank 0 write the file header pos = 0 @@ -753,7 +677,7 @@ def merge_files_mpi_idx_cached(outfile, infile, mpi, comm, dtype): # Broadcast value of pos from rank 0, # and advance file position past file header on all ranks. - pos = mpi_get_sum(pos, mpi, comm) + pos = distctx.bcast(pos, root=0) # The dimension list records the offset within # the sizes list for each sentence. Adjust dimension @@ -761,7 +685,7 @@ def merge_files_mpi_idx_cached(outfile, infile, mpi, comm, dtype): # that come before the calling rank. dim_offsets64 = np.array(dim_offsets, dtype=np.int64) dim_last = dim_offsets[-1] if numdim > 0 else 0 - dim_offset = mpi_get_offset(dim_last, mpi, comm) + dim_offset = distctx.exscan(dim_last) dim_offsets64 += dim_offset # Seek to proper offset for this rank and write @@ -779,7 +703,7 @@ def merge_files_mpi_idx_cached(outfile, infile, mpi, comm, dtype): # come before the calling rank. 
data_offsets64 = np.array(data_offsets, dtype=np.int64) byte_last = data_offsets[-1] if numdata > 0 else 0 - byte_offset = mpi_get_offset(byte_last, mpi, comm) + byte_offset = distctx.exscan(byte_last) data_offsets64 += byte_offset # Seek to proper offset for this rank and write @@ -821,15 +745,15 @@ def merge_files_mpi_idx_cached(outfile, infile, mpi, comm, dtype): f.close() # TODO: check that all ranks wrote successfully - comm.barrier() + distctx.barrier() -def merge_files_mpi_idx_mmap(outfile, infile, mpi, comm, dtype): - rank = comm.Get_rank() - numranks = comm.Get_size() +def merge_files_dist_idx_mmap(outfile, infile, distctx, dtype): + rank = distctx.rank + numranks = distctx.numranks # Create shared output file - f = mpi_create_file(index_file_path(outfile), mpi, comm) + f = distctx.open(index_file_path(outfile)) # Read the index file for the calling rank index = MMapIndexedDataset.Index(index_file_path(infile)) @@ -846,10 +770,10 @@ def merge_files_mpi_idx_mmap(outfile, infile, mpi, comm, dtype): # calling rank. numsizes = len(sizes) numdocs = len(docs) - total_size_count = mpi_get_sum(numsizes, mpi, comm) - total_docs_count = mpi_get_sum(numdocs, mpi, comm) - size_offset = mpi_get_offset(numsizes, mpi, comm) - docs_offset = mpi_get_offset(numdocs, mpi, comm) + global_size_count = distctx.sum(numsizes) + global_docs_count = distctx.sum(numdocs) + global_size_offset = distctx.exscan(numsizes) + global_docs_offset = distctx.exscan(numdocs) # Have rank 0 write the file header pos = 0 @@ -857,23 +781,23 @@ def merge_files_mpi_idx_mmap(outfile, infile, mpi, comm, dtype): f.write(MMapIndexedDataset.Index._HDR_MAGIC) f.write(struct.pack(' 0: + if global_size_count > 0: # There is at least one sentence across all ranks, # figure out which rank has the first sentence which # is not necessarily rank 0. minrank = numranks if len(sizes) > 0: minrank = rank - minrank = mpi_get_min(minrank, mpi, comm) + minrank = distctx.min(minrank) # Broadcast size of the first sentence from minrank. - # We "bcast" using all allreduce. pointers_shift = 0 if minrank == rank: pointers_shift = pointers[0] - pointers_shift = mpi_get_sum(pointers_shift, mpi, comm) + pointers_shift = distctx.bcast(pointers_shift, root=minrank) # Zero-base pointers by subtracting size of first # sentence from all values. @@ -922,12 +845,12 @@ def merge_files_mpi_idx_mmap(outfile, infile, mpi, comm, dtype): # Seek to proper offset for this rank and write # pointer values into file, stored as int64. - f.seek(pos + size_offset * np.int64().itemsize) + f.seek(pos + global_size_offset * np.int64().itemsize) f.write(pointers.tobytes(order='C')) del pointers # Advance past list of pointer values - pos += total_size_count * np.int64().itemsize + pos += global_size_count * np.int64().itemsize # The document index points to the position in the sizes # array for the starting sentence of each document. @@ -935,21 +858,21 @@ def merge_files_mpi_idx_mmap(outfile, infile, mpi, comm, dtype): # Adjust document index for number of sentences that # come before the calling rank. doc_idx = np.array(docs, dtype=np.int64) - doc_idx += size_offset + doc_idx += global_size_offset # Seek to proper offset for this rank and write # document index into file, stored as int64. 
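A small worked example of the document-index shift above, with made-up numbers: suppose rank 0 wrote 5 sentences spread over 2 documents and rank 1 wrote 3 sentences in 1 document.

    import numpy as np

    # Per-rank doc_idx as the builder writes it: it starts at 0 and records
    # len(sizes) after each end_document().
    docs_rank0 = [0, 2, 5]        # 2 documents over local sentences 0..4
    docs_rank1 = [0, 3]           # 1 document over local sentences 0..2

    # Rank 1 drops the leading 0 (only the first rank keeps it) and shifts by
    # the number of sentences on earlier ranks -- its size offset, here 5.
    shifted_rank1 = (np.array(docs_rank1[1:], dtype=np.int64) + 5).tolist()   # [8]

    merged = docs_rank0 + shifted_rank1
    assert merged == [0, 2, 5, 8]   # valid doc_idx for the 8 merged sentences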
- f.seek(pos + docs_offset * np.int64().itemsize) + f.seek(pos + global_docs_offset * np.int64().itemsize) f.write(doc_idx.tobytes(order='C')) del doc_idx f.close() # TODO: check that all ranks wrote successfully - comm.barrier() + distctx.barrier() -def merge_files_mpi(filemain, filerank, mpi, comm, dtype=np.int64): +def merge_files_dist(filemain, filerank, distctx, dtype=np.int64): # read header of index file of this rank to determine its type indexstr = infer_dataset_impl(filerank) @@ -958,16 +881,16 @@ def merge_files_mpi(filemain, filerank, mpi, comm, dtype=np.int64): indextype = indexstrmap[indexstr] if indexstr in indexstrmap else 0 # check that all ranks have the same type - rank0type = comm.bcast(indextype, root=0) - sametype = mpi_alltrue(indextype == rank0type and rank0type != 0, mpi, comm) + rank0type = distctx.bcast(indextype, root=0) + sametype = distctx.alltrue(indextype == rank0type and rank0type != 0) if not sametype: assert False, "Cannot merge dataset files of different types" # Concatenate the data files - merge_files_mpi_bin(filemain, filerank, mpi, comm) + merge_files_dist_bin(filemain, filerank, distctx) # Combine index files into a single index file if indexstr == "cached": - merge_files_mpi_idx_cached(filemain, filerank, mpi, comm, dtype) + merge_files_dist_idx_cached(filemain, filerank, distctx, dtype) elif indexstr == "mmap": - merge_files_mpi_idx_mmap(filemain, filerank, mpi, comm, dtype) + merge_files_dist_idx_mmap(filemain, filerank, distctx, dtype) diff --git a/tools/preprocess_dataset_mpi.py b/tools/preprocess_dataset_mpi.py index c0e1defe3..55c451c7c 100644 --- a/tools/preprocess_dataset_mpi.py +++ b/tools/preprocess_dataset_mpi.py @@ -65,7 +65,8 @@ from datasets.utils.file_utils import OfflineModeIsEnabled from megatron.tokenizer import build_tokenizer -from megatron.data.indexed_dataset import data_file_path, index_file_path, make_builder, best_fitting_dtype, merge_files_mpi +from megatron.data.indexed_dataset import data_file_path, index_file_path, make_builder, best_fitting_dtype, merge_files_dist +from megatron.data.distdata import DistData # https://stackoverflow.com/questions/33139531/preserve-empty-lines-with-nltks-punkt-tokenizer class CustomLanguageVars(nltk.tokenize.punkt.PunktLanguageVars): @@ -277,7 +278,7 @@ def prefix_sum(args, val): else: # torch.distributed doesn't have a scan, so fallback to allreduce tensor = torch.zeros(args.numranks, dtype=torch.int64) - tensor[args.rank:args.numranks] = val + tensor[args.rank:] = val dist.all_reduce(tensor, op=dist.ReduceOp.SUM) return int(tensor[args.rank]) @@ -567,14 +568,16 @@ def rank_files_write(args, dset, idx, encoder): return success, err def rank_files_merge(args): - # only try parallel merge when using MPI - if args.use_mpi: + mpictx = args.MPI if args.use_mpi else None + distctx = DistData(mpi=mpictx) + args.dist_merge = True + if args.dist_merge: merge_start = time.time() numbytes = np.zeros(1, dtype=np.int64) for key in args.columns: filemain = get_filename(args, key) filerank = get_filename(args, key, args.rank) - merge_files_mpi(filemain, filerank, args.MPI, args.mpi_comm, dtype=best_fitting_dtype(args.vocab_size)) + merge_files_dist(filemain, filerank, distctx, dtype=best_fitting_dtype(args.vocab_size)) # total up bytes read in merge binfile = data_file_path(filerank) From 3eca1f35557963a119410178038d12f701ac4829 Mon Sep 17 00:00:00 2001 From: Adam Moody Date: Fri, 13 Aug 2021 20:29:25 -0700 Subject: [PATCH 12/74] drop unused prefix_sum function --- megatron/data/indexed_dataset.py | 6 
+++--- tools/preprocess_dataset_mpi.py | 14 -------------- 2 files changed, 3 insertions(+), 17 deletions(-) diff --git a/megatron/data/indexed_dataset.py b/megatron/data/indexed_dataset.py index b1a37d904..17d6a30a6 100644 --- a/megatron/data/indexed_dataset.py +++ b/megatron/data/indexed_dataset.py @@ -627,8 +627,8 @@ def merge_files_dist_bin(outfile, infile, distctx): def merge_files_dist_idx_cached(outfile, infile, distctx, dtype): + # get our rank rank = distctx.rank - numranks = distctx.numranks # Create shared output file f = distctx.open(index_file_path(outfile)) @@ -749,8 +749,8 @@ def merge_files_dist_idx_cached(outfile, infile, distctx, dtype): def merge_files_dist_idx_mmap(outfile, infile, distctx, dtype): + # get our rank rank = distctx.rank - numranks = distctx.numranks # Create shared output file f = distctx.open(index_file_path(outfile)) @@ -828,7 +828,7 @@ def merge_files_dist_idx_mmap(outfile, infile, distctx, dtype): # There is at least one sentence across all ranks, # figure out which rank has the first sentence which # is not necessarily rank 0. - minrank = numranks + minrank = distctx.numranks if len(sizes) > 0: minrank = rank minrank = distctx.min(minrank) diff --git a/tools/preprocess_dataset_mpi.py b/tools/preprocess_dataset_mpi.py index 55c451c7c..d464755ba 100644 --- a/tools/preprocess_dataset_mpi.py +++ b/tools/preprocess_dataset_mpi.py @@ -268,20 +268,6 @@ def all_sum_(args, vals): tensor = torch.from_numpy(vals) dist.all_reduce(tensor, op=dist.ReduceOp.SUM) -def prefix_sum(args, val): - """Returns result of an inclusive scan prefix sum of val across ranks""" - if args.use_mpi: - inval = np.array([val], dtype=np.int64) - outval = np.zeros_like(inval) - args.mpi_comm.Scan(inval, outval, op=args.MPI.SUM) - return outval[0] - else: - # torch.distributed doesn't have a scan, so fallback to allreduce - tensor = torch.zeros(args.numranks, dtype=torch.int64) - tensor[args.rank:] = val - dist.all_reduce(tensor, op=dist.ReduceOp.SUM) - return int(tensor[args.rank]) - def all_true(args, val): """Returns True if all procs input True, False otherwise""" if args.use_mpi: From a691b481f1e91dc852d991f0848b51d8f85e5a19 Mon Sep 17 00:00:00 2001 From: Adam Moody Date: Sun, 15 Aug 2021 12:19:41 -0700 Subject: [PATCH 13/74] allow ranks to pass a list of files to be merged --- megatron/data/distdata.py | 26 +++- megatron/data/indexed_dataset.py | 247 +++++++++++++++++++++---------- tools/preprocess_dataset_mpi.py | 4 +- 3 files changed, 193 insertions(+), 84 deletions(-) diff --git a/megatron/data/distdata.py b/megatron/data/distdata.py index ed0d63668..c1fb6f0d3 100644 --- a/megatron/data/distdata.py +++ b/megatron/data/distdata.py @@ -9,7 +9,7 @@ def __init__(self, mpi=None): if self.MPI: self.comm = mpi.COMM_WORLD self.rank = self.comm.Get_rank() - self.numranks = self.comm.Get_rank() + self.numranks = self.comm.Get_size() else: self.rank = dist.get_rank() self.numranks = dist.get_world_size() @@ -101,6 +101,30 @@ def min(self, val): dist.all_reduce(tensor, op=dist.ReduceOp.MIN) return tensor[0] + def minrank(self, cond): + """Find first rank whose condition is True, return that rank if any, None otherwise.""" + minrank = self.numranks + if cond: + minrank = self.rank + minrank = self.min(minrank) + + if minrank < self.numranks: + return minrank + return None + + def bcast_first(self, val): + """Broadcast val from first rank where it is not None, return val if any, None otherwise""" + # Find the first rank with a valid value. 
+ minrank = self.minrank(val is not None) + + # If there is no rank with a valid value, return None + if minrank is None: + return None + + # Otherwise broadcast the value from the first valid rank. + val = self.bcast(val, root=minrank) + return val + def open(self, filename): """Create, truncate, and open a file shared by all ranks.""" success = True diff --git a/megatron/data/indexed_dataset.py b/megatron/data/indexed_dataset.py index 17d6a30a6..6b125a118 100644 --- a/megatron/data/indexed_dataset.py +++ b/megatron/data/indexed_dataset.py @@ -600,50 +600,79 @@ def finalize(self, index_file): # its size, execute a scan to compute the byte offset where # the calling rank should write its data, seek to proper # spot, and copy the full file. -def merge_files_dist_bin(outfile, infile, distctx): +def gather_files_dist_bin(outfile, filelist, distctx): """Concatenate per-rank binary files into a new file given by outfile""" import stat import shutil # Create shared output file. - f = distctx.open(data_file_path(outfile)) + fout = distctx.open(data_file_path(outfile)) - # get file size of binary file for this rank - filesize = os.stat(data_file_path(infile))[stat.ST_SIZE] + # lookup size of each of our binary files + filesizes = [] + for f in filelist: + filesize = os.stat(data_file_path(f))[stat.ST_SIZE] + filesizes.append(filesize) # compute offset this rank should start copying # its data into the merged file - offset = distctx.exscan(filesize) + numbytes = sum(filesizes) + offset = distctx.exscan(numbytes) - # seek to appropriate offset and copy data - f.seek(offset) - with open(data_file_path(infile), "rb") as fsrc: - shutil.copyfileobj(fsrc, f) + # seek to appropriate starting offset in the merged file + fout.seek(offset) - f.close() + # copy in contents of each of our files + for f in filelist: + with open(data_file_path(f), "rb") as fsrc: + shutil.copyfileobj(fsrc, fout) + + fout.close() # TODO: check that all ranks wrote successfully distctx.barrier() -def merge_files_dist_idx_cached(outfile, infile, distctx, dtype): +def gather_files_dist_idx_cached(outfile, filelist, distctx, dtype): # get our rank rank = distctx.rank # Create shared output file - f = distctx.open(index_file_path(outfile)) + fout = distctx.open(index_file_path(outfile)) # Read the index file for the calling rank - index = IndexedDataset(infile) - sizes = index.sizes - if rank == 0: - data_offsets = index.data_offsets - dim_offsets = index.dim_offsets - docs = index.doc_idx - if rank != 0: - data_offsets = index.data_offsets[1:] - dim_offsets = index.dim_offsets[1:] - docs = index.doc_idx[1:] + sizes = [] + data_offsets = [0] + dim_offsets = [0] + docs = [0] + for f in filelist: + index = IndexedDataset(f) + + doc_offset = len(sizes) + + sizes.extend(index.sizes.tolist()) + + data_offset = data_offsets[-1] + tmpdata_offsets = np.copy(index.data_offsets[1:]) + tmpdata_offsets += data_offset + data_offsets.extend(tmpdata_offsets.tolist()) + + dim_offset = dim_offsets[-1] + tmpdim_offsets = np.copy(index.dim_offsets[1:]) + tmpdim_offsets += dim_offset + dim_offsets.extend(tmpdim_offsets.tolist()) + + tmpdocs = np.copy(index.doc_idx[1:]) + tmpdocs += doc_offset + docs.extend(tmpdocs.tolist()) + + # Drop first entry from the lists that start with + # a "0" value if we're not the first rank with some size. + minrank = distctx.minrank(len(sizes) > 0) + if rank != minrank: + del data_offsets[0] + del dim_offsets[0] + del docs[0] # Compute total number of size and document index # values across all ranks. 
Also compute the offset @@ -668,12 +697,12 @@ def merge_files_dist_idx_cached(outfile, infile, distctx, dtype): # Have rank 0 write the file header pos = 0 if rank == 0: - f.write(IndexedDataset._HDR_MAGIC) - f.write(struct.pack(' 0) + if rank != minrank: + del docs[0] # Compute total number of size and document index # values across all ranks. Also compute the offset @@ -778,12 +819,12 @@ def merge_files_dist_idx_mmap(outfile, infile, distctx, dtype): # Have rank 0 write the file header pos = 0 if rank == 0: - f.write(MMapIndexedDataset.Index._HDR_MAGIC) - f.write(struct.pack(' 0: - minrank = rank - minrank = distctx.min(minrank) - - # Broadcast size of the first sentence from minrank. - pointers_shift = 0 - if minrank == rank: - pointers_shift = pointers[0] - pointers_shift = distctx.bcast(pointers_shift, root=minrank) + pointers_shift = pointers[0] if len(sizes) > 0 else None + pointers_shift = distctx.bcast_first(pointers_shift) # Zero-base pointers by subtracting size of first # sentence from all values. @@ -845,8 +878,8 @@ def merge_files_dist_idx_mmap(outfile, infile, distctx, dtype): # Seek to proper offset for this rank and write # pointer values into file, stored as int64. - f.seek(pos + global_size_offset * np.int64().itemsize) - f.write(pointers.tobytes(order='C')) + fout.seek(pos + global_size_offset * np.int64().itemsize) + fout.write(pointers.tobytes(order='C')) del pointers # Advance past list of pointer values @@ -862,35 +895,87 @@ def merge_files_dist_idx_mmap(outfile, infile, distctx, dtype): # Seek to proper offset for this rank and write # document index into file, stored as int64. - f.seek(pos + global_docs_offset * np.int64().itemsize) - f.write(doc_idx.tobytes(order='C')) + fout.seek(pos + global_docs_offset * np.int64().itemsize) + fout.write(doc_idx.tobytes(order='C')) del doc_idx - f.close() + fout.close() # TODO: check that all ranks wrote successfully distctx.barrier() -def merge_files_dist(filemain, filerank, distctx, dtype=np.int64): - # read header of index file of this rank to determine its type - indexstr = infer_dataset_impl(filerank) +# Verify that all files in filelist are of the same index type. +# Returns the identified type {cached, mmap} as a string. +def gather_files_dist_check_type(filelist, distctx): + # map type string to an integer for easier bcast, use 0 for unknown + implmap = {"cached": 1, "mmap": 2} - # map type string to an integer for easier bcast - indexstrmap = {"cached": 1, "mmap": 2} - indextype = indexstrmap[indexstr] if indexstr in indexstrmap else 0 + # check that all files in filelist are of the same type + sametype = True + ourtype = None + for f in filelist: + # read header of index file to determine its type + impl = infer_dataset_impl(f) + implval = implmap[impl] if impl in implmap else 0 - # check that all ranks have the same type - rank0type = distctx.bcast(indextype, root=0) - sametype = distctx.alltrue(indextype == rank0type and rank0type != 0) + if ourtype is None: + ourtype = implval + + if implval != ourtype: + sametype = False + + # check that all ranks have the same type, + # and that there is no unknown type + bcasttype = distctx.bcast_first(ourtype) + sametype = distctx.alltrue(sametype and ourtype == bcasttype and bcasttype != 0) if not sametype: assert False, "Cannot merge dataset files of different types" + # map back to return index string name + for key in implmap.keys(): + if implmap[key] == bcasttype: + return key + + +# Collectively merge files into a new output file specified in filemain. 
+# Each rank contributes a distinct list of zero or more files in filelist. +# Each rank merges its set of files into filemain collectively with all +# other ranks. +def gather_files_dist(filemain, filelist, distctx, dtype=np.int64): + # TODO: seems like this could be relaxed + # Check that files are all of the same index type + indexstr = gather_files_dist_check_type(filelist, distctx) + # Concatenate the data files - merge_files_dist_bin(filemain, filerank, distctx) + gather_files_dist_bin(filemain, filelist, distctx) # Combine index files into a single index file if indexstr == "cached": - merge_files_dist_idx_cached(filemain, filerank, distctx, dtype) + gather_files_dist_idx_cached(filemain, filelist, distctx, dtype) elif indexstr == "mmap": - merge_files_dist_idx_mmap(filemain, filerank, distctx, dtype) + gather_files_dist_idx_mmap(filemain, filelist, distctx, dtype) + + +def get_start_end(count, rank, numranks): + num, remainder = divmod(count, numranks) + if rank < remainder: + start = (num + 1) * rank + end = start + num + 1 + else: + start = (num + 1) * remainder + num * (rank - remainder) + end = start + num + return start, end + + +# Given a global list of files in filelist, and a set of processed defined +# by the distributed environment in distctx, collectively merge files into +# a new output specified in filemain. +def merge_files_dist(filemain, filelist, distctx, dtype=np.int64): + # TODO: if file sizes vary significantly, it might be better to consider + # file size when splitting the list to different ranks. + + # evenly divide list of files among ranks + start, end = get_start_end(len(filelist), distctx.rank, distctx.numranks) + sublist = filelist[start:end] + return gather_files_dist(filemain, sublist, distctx, dtype) diff --git a/tools/preprocess_dataset_mpi.py b/tools/preprocess_dataset_mpi.py index d464755ba..9b2171f70 100644 --- a/tools/preprocess_dataset_mpi.py +++ b/tools/preprocess_dataset_mpi.py @@ -65,7 +65,7 @@ from datasets.utils.file_utils import OfflineModeIsEnabled from megatron.tokenizer import build_tokenizer -from megatron.data.indexed_dataset import data_file_path, index_file_path, make_builder, best_fitting_dtype, merge_files_dist +from megatron.data.indexed_dataset import data_file_path, index_file_path, make_builder, best_fitting_dtype, gather_files_dist from megatron.data.distdata import DistData # https://stackoverflow.com/questions/33139531/preserve-empty-lines-with-nltks-punkt-tokenizer @@ -563,7 +563,7 @@ def rank_files_merge(args): for key in args.columns: filemain = get_filename(args, key) filerank = get_filename(args, key, args.rank) - merge_files_dist(filemain, filerank, distctx, dtype=best_fitting_dtype(args.vocab_size)) + gather_files_dist(filemain, [filerank], distctx, dtype=best_fitting_dtype(args.vocab_size)) # total up bytes read in merge binfile = data_file_path(filerank) From e4a34e2a759e9c40c81ec66bb9571868514874e8 Mon Sep 17 00:00:00 2001 From: Adam Moody Date: Sun, 15 Aug 2021 16:22:05 -0700 Subject: [PATCH 14/74] check that input dataset files exist --- megatron/data/indexed_dataset.py | 28 ++++++++++++++++++++++++---- 1 file changed, 24 insertions(+), 4 deletions(-) diff --git a/megatron/data/indexed_dataset.py b/megatron/data/indexed_dataset.py index 6b125a118..11c425656 100644 --- a/megatron/data/indexed_dataset.py +++ b/megatron/data/indexed_dataset.py @@ -908,6 +908,21 @@ def gather_files_dist_idx_mmap(outfile, filelist, distctx, dtype): # Verify that all files in filelist are of the same index type. 
# Returns the identified type {cached, mmap} as a string. def gather_files_dist_check_type(filelist, distctx): + # Sanity check for typos in file names. + # Check that a data file exists for each of our files. + exists = True + for f in filelist: + binfile = data_file_path(f) + if not os.path.exists(binfile): + exists = False + + # Check that all ranks have all of their files. + allexist = distctx.alltrue(exists) + if not allexist: + if not exists: + assert False, f"At least one of the following names was not found: {filelist}" + assert False, f"Some rank is missing its input file" + # map type string to an integer for easier bcast, use 0 for unknown implmap = {"cached": 1, "mmap": 2} @@ -925,11 +940,16 @@ def gather_files_dist_check_type(filelist, distctx): if implval != ourtype: sametype = False - # check that all ranks have the same type, - # and that there is no unknown type + # Check that all ranks have the same type, + # and that there is no unknown type. + # This checks that: + # - all of our own files (if any) are of the same type AND + # - either we have no files or the type of our files match the broadcast type AND + # - the broadcast type is of a known type: {cached, mmap} bcasttype = distctx.bcast_first(ourtype) - sametype = distctx.alltrue(sametype and ourtype == bcasttype and bcasttype != 0) - if not sametype: + matchtype = sametype and (ourtype is None or ourtype == bcasttype) and bcasttype != 0 + allsame = distctx.alltrue(matchtype) + if not allsame: assert False, "Cannot merge dataset files of different types" # map back to return index string name From 8b168cab90c2b1d653bec9149b1645c0b12567f5 Mon Sep 17 00:00:00 2001 From: Adam Moody Date: Sun, 15 Aug 2021 17:39:05 -0700 Subject: [PATCH 15/74] fix: using wrong doc_idx list for mmap --- megatron/data/indexed_dataset.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/megatron/data/indexed_dataset.py b/megatron/data/indexed_dataset.py index 11c425656..021dde329 100644 --- a/megatron/data/indexed_dataset.py +++ b/megatron/data/indexed_dataset.py @@ -796,7 +796,7 @@ def gather_files_dist_idx_mmap(outfile, filelist, distctx, dtype): tmpdocs = np.copy(index.doc_idx[1:]) tmpdocs += docs_offset - docs.extend(index.doc_idx.tolist()) + docs.extend(tmpdocs.tolist()) # Drop first entry from the lists that start with # a "0" value if we're not the first rank with some size. 
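For reference, a minimal sketch of how the merge entry points introduced above (gather_files_dist / merge_files_dist) might be driven once every rank has written its own file pair. It assumes an mpi4py launch, and the file prefixes are hypothetical placeholders rather than names produced by the tool:

    from mpi4py import MPI
    import numpy as np

    from megatron.data.distdata import DistData
    from megatron.data.indexed_dataset import merge_files_dist

    # At this point in the series DistData still takes the mpi4py module directly.
    distctx = DistData(mpi=MPI)

    # Same global list on every rank (hypothetical prefixes); merge_files_dist
    # splits it evenly across ranks and each rank merges its share.
    filelist = [f"mydataset_text_sentence_{r}" for r in range(distctx.numranks)]

    # Collective call: ranks copy their pieces into the shared .bin/.idx pair
    # named by the first argument and synchronize on internal barriers.
    merge_files_dist("mydataset_text_sentence", filelist, distctx, dtype=np.int32)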
From 7a0269383f369764a074164c57c17e6d6e069318 Mon Sep 17 00:00:00 2001 From: Adam Moody Date: Sun, 15 Aug 2021 17:39:40 -0700 Subject: [PATCH 16/74] move init dist and collectives to distdata class --- megatron/data/distdata.py | 28 +++++++- tools/preprocess_dataset_mpi.py | 123 ++++++-------------------------- 2 files changed, 45 insertions(+), 106 deletions(-) diff --git a/megatron/data/distdata.py b/megatron/data/distdata.py index c1fb6f0d3..072d19884 100644 --- a/megatron/data/distdata.py +++ b/megatron/data/distdata.py @@ -1,16 +1,28 @@ import numpy as np + import torch import torch.distributed as dist class DistData(object): - def __init__(self, mpi=None): - self.MPI = mpi + def __init__(self, backend='gloo', use_mpi4py=False): + self.MPI = None + + # use mpi4py instead of torch.distributed if requested + if use_mpi4py: + try: + from mpi4py import MPI + self.MPI = MPI + except: + #print(f"ERROR: mpi4py requested, but failed to import, falling back to torch.distributed.", flush=True) + pass + # lookup our process rank and the group size if self.MPI: - self.comm = mpi.COMM_WORLD + self.comm = self.MPI.COMM_WORLD self.rank = self.comm.Get_rank() self.numranks = self.comm.Get_size() else: + dist.init_process_group(backend, init_method="env://") self.rank = dist.get_rank() self.numranks = dist.get_world_size() @@ -125,6 +137,16 @@ def bcast_first(self, val): val = self.bcast(val, root=minrank) return val + def all_sum_(self, vals): + """Sums values in vals element-wise and updates vals with final result on all ranks""" + if self.MPI: + outval = np.zeros_like(vals) + self.comm.Allreduce(vals, outval, op=self.MPI.SUM) + vals[:] = outval + else: + tensor = torch.from_numpy(vals) + dist.all_reduce(tensor, op=dist.ReduceOp.SUM) + def open(self, filename): """Create, truncate, and open a file shared by all ranks.""" success = True diff --git a/tools/preprocess_dataset_mpi.py b/tools/preprocess_dataset_mpi.py index 9b2171f70..bd8746cea 100644 --- a/tools/preprocess_dataset_mpi.py +++ b/tools/preprocess_dataset_mpi.py @@ -185,23 +185,13 @@ def get_args(): args.tensor_model_parallel_size = 1 args.vocab_extra_ids = 0 + # initialize our distributed environment # use mpi4py instead of torch.distributed if requested - args.use_mpi = False - if args.mpi4py: - try: - start_mpi = time.time() - - from mpi4py import MPI - args.MPI = MPI - args.use_mpi = True + args.distctx = DistData(use_mpi4py=args.mpi4py, backend=args.torch_backend) - if args.use_mpi: - args.MPI.COMM_WORLD.barrier() - end_mpi = time.time() - if args.MPI.COMM_WORLD.Get_rank() == 0: - print(f"Seconds to import MPI: {end_mpi - start_mpi}") - except: - print(f"ERROR: mpi4py requested, but failed to import, falling back to torch.distributed.", flush=True) + # some functions like build_tokenizer use args.rank to filter stdout messages + args.rank = args.distctx.rank + args.numranks = args.distctx.numranks return args @@ -209,77 +199,6 @@ def format_byterate(byterate): mbps = byterate / (1024.0 * 1024.0) return f"{mbps:0.3f} MB/s" -def init_distributed(args): - """Determine which distributed runtime to use and connect up processes""" - # select our distributed runtime (MPI or torch.distributed) - # lookup our process rank and the group size - # some functions like build_tokenizer use args.rank to filter stdout messages - if args.use_mpi: - args.mpi_comm = args.MPI.COMM_WORLD - args.rank = args.mpi_comm.Get_rank() - args.numranks = args.mpi_comm.Get_size() - else: - time_start = time.time() - - dist.init_process_group(args.torch_backend, 
init_method="env://") - args.rank = dist.get_rank() - args.numranks = dist.get_world_size() - - barrier(args) - time_end = time.time() - if args.rank == 0: - print(f"Seconds to init_process_group: {time_end - time_start}") - -def barrier(args): - """Globally synchronize all processes.""" - if args.use_mpi: - args.mpi_comm.barrier() - else: - dist.barrier() - -def bcast(args, vals, root=0): - """Broadcast list of vals from root to all ranks, returns newly allocated list""" - if args.use_mpi: - vals = args.mpi_comm.bcast(vals, root=root) - return vals - else: - # broadcast length of vals list - length = [len(vals)] - dist.broadcast_object_list(length, src=root) - - # allocate a tensor of appropriate size - # initialize tensor with list values on root - if args.rank == root: - tvals = torch.tensor(vals, dtype=torch.int64) - else: - tvals = torch.zeros(length[0], dtype=torch.int64) - - # broadcast tensor from root, and return as a new list - dist.broadcast(tvals, src=root) - return tvals.tolist() - -def all_sum_(args, vals): - """Sums values in vals element-wise and updates vals with final result on all ranks""" - if args.use_mpi: - outval = np.zeros_like(vals) - args.mpi_comm.Allreduce(vals, outval, op=args.MPI.SUM) - vals[:] = outval - else: - tensor = torch.from_numpy(vals) - dist.all_reduce(tensor, op=dist.ReduceOp.SUM) - -def all_true(args, val): - """Returns True if all procs input True, False otherwise""" - if args.use_mpi: - inval = np.array([val], dtype=np.bool_) - outval = np.zeros_like(inval) - args.mpi_comm.Allreduce(inval, outval, op=args.MPI.LAND) - return bool(outval[0]) - else: - tensor = torch.tensor([int(val)], dtype=torch.int32) - dist.all_reduce(tensor, op=dist.ReduceOp.BAND) - return bool(tensor[0]) - def load_dset(args): # Avoid downloading datasets unless explicitly requested. # We allow the user to override this behavior if they set $HF_DATASETS_OFFLINE. 
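With the collectives folded into DistData, call sites only need the context object. A minimal sketch of the new usage, assuming the usual torch.distributed environment variables are set when use_mpi4py is False; the status flag and counters below are placeholders, not values taken from the tool:

    import numpy as np
    from megatron.data.distdata import DistData

    distctx = DistData(backend='gloo', use_mpi4py=False)

    local_ok = True                        # placeholder per-rank status
    all_ok = distctx.alltrue(local_ok)     # logical AND across ranks

    counts = np.zeros(3, dtype=np.int64)   # placeholder per-rank counters
    distctx.all_sum_(counts)               # elementwise sum, result left in counts

    distctx.barrier()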
@@ -324,7 +243,7 @@ def load_dset(args): err = e # determine whether rank 0 succeeded in loading the dataset - success = all_true(args, success) + success = args.distctx.alltrue(success) if not success: return None, err @@ -340,7 +259,7 @@ def load_dset(args): err = e # verify that all ranks loaded the dataset - success = all_true(args, success) + success = args.distctx.alltrue(success) if not success: if args.rank == 0: print(f"ERROR: At least one process failed to load {dsetname}", flush=True) @@ -375,9 +294,9 @@ def select_sample_list(args, dset_size): # broadcast sample index values from rank 0 to all procs time_bcast = time.time() - idx = bcast(args, idx, root=0) + idx = args.distctx.bcast(idx, root=0) - barrier(args) + args.distctx.barrier() time_end = time.time() if args.rank == 0: print(f"Select index stats:") @@ -529,12 +448,12 @@ def rank_files_write(args, dset, idx, encoder): print(f"{timestamp}: Waiting for ranks to finalize files ...", flush=True) # wait for all ranks to finish their files - barrier(args) + args.distctx.barrier() tokenize_end = time.time() # compute total stats across all processes - all_sum_(args, times) - all_sum_(args, dset_stats) + args.distctx.all_sum_(times) + args.distctx.all_sum_(dset_stats) if args.rank == 0: secs = tokenize_end - tokenize_start docrate = dset_stats[0] / secs if secs > 0.0 else 0.0 @@ -550,12 +469,10 @@ def rank_files_write(args, dset, idx, encoder): print(f" Total write seconds {times[2]}, {times[2]/dset_stats[0]} sec/sample") # allreduce to check whether all ranks wrote their part successfully - success = all_true(args, success) + success = args.distctx.alltrue(success) return success, err def rank_files_merge(args): - mpictx = args.MPI if args.use_mpi else None - distctx = DistData(mpi=mpictx) args.dist_merge = True if args.dist_merge: merge_start = time.time() @@ -563,7 +480,7 @@ def rank_files_merge(args): for key in args.columns: filemain = get_filename(args, key) filerank = get_filename(args, key, args.rank) - gather_files_dist(filemain, [filerank], distctx, dtype=best_fitting_dtype(args.vocab_size)) + gather_files_dist(filemain, [filerank], args.distctx, dtype=best_fitting_dtype(args.vocab_size)) # total up bytes read in merge binfile = data_file_path(filerank) @@ -578,8 +495,8 @@ def rank_files_merge(args): os.rename(binfile, binfile + ".par") os.rename(idxfile, idxfile + ".par") - barrier(args) - all_sum_(args, numbytes) + args.distctx.barrier() + args.distctx.all_sum_(numbytes) merge_end = time.time() if args.rank == 0: secs = merge_end - merge_start @@ -640,7 +557,7 @@ def rank_files_merge(args): print(f" {numbytes} bytes {format_byterate(byterate)}") # hold everyone until rank 0 is done - barrier(args) + args.distctx.barrier() def rank_files_delete(args): # delete per-rank files @@ -659,14 +576,14 @@ def rank_files_delete(args): os.remove(idxfile) # hold everyone until all are done - barrier(args) + args.distctx.barrier() def main(): args = get_args() startup_start = time.time() # connect processes and cache our rank and number of procs in args - init_distributed(args) + #init_distributed(args) # load the dataset dset, err = load_dset(args) @@ -694,7 +611,7 @@ def main(): args.level = "sentence" # wait for all ranks before stopping timer - barrier(args) + args.distctx.barrier() startup_end = time.time() if args.rank == 0: print(f"Seconds to startup: {startup_end - startup_start}") From eca2940f0984edb4182e1629b1047af7dcb338f2 Mon Sep 17 00:00:00 2001 From: Adam Moody Date: Mon, 16 Aug 2021 13:39:45 -0700 Subject: [PATCH 
17/74] add --merge option, move parallel/serial to their own functions --- tools/preprocess_dataset_mpi.py | 130 +++++++++++++++++++------------- 1 file changed, 76 insertions(+), 54 deletions(-) diff --git a/tools/preprocess_dataset_mpi.py b/tools/preprocess_dataset_mpi.py index bd8746cea..8076df75d 100644 --- a/tools/preprocess_dataset_mpi.py +++ b/tools/preprocess_dataset_mpi.py @@ -161,24 +161,29 @@ def get_args(): choices=['lazy', 'cached', 'mmap']) group = parser.add_argument_group(title='runtime') - group.add_argument('--mpi4py', action='store_true', - help='Assume script has been launched as an MPI job, and use MPI for communication.') - group.add_argument('--torch-backend', type=str, default='gloo', choices = ['gloo', 'mpi'], + group.add_argument('--torch-backend', type=str, default='gloo', choices=['gloo', 'mpi'], help='Select torch.distributed backend.') group.add_argument('--local_rank', type=int, default=None, help='Local rank of calling process on its node (from torch.distributed.launch).') + group.add_argument('--mpi4py', action='store_true', + help='Assume script has been launched as an MPI job, and use mpi4py for communication.') + group.add_argument('--merge', type=str, default='parallel', choices=['parallel', 'serial', 'both'], + help=('Method to merge intermediate per-rank files into the final data files. ' + 'With "parallel", each rank writes directly to the final files, ' + 'while rank 0 copies data from all per-rank files with "serial". ' + 'A parallel merge can be faster, but for correctness, it requires the underlying file system ' + 'to support parallel write operations to a file shared among multiple processes. ' + 'One can choose "both" for testing purposes, in which case the final files written ' + 'by the parallel method are given an additional ".par" extension.')) group.add_argument('--scratch', type=str, default=None, - help='Path to local storage on compute nodes to write per-rank files before merging, like /dev/shm.') + help=('Path to local storage on compute nodes to write per-rank files before merging, like /dev/shm. ' + 'One can only use this option with a parallel merge.')) group.add_argument('--log-interval', type=int, default=30, help='Seconds between progress updates (0 to disable)') args = parser.parse_args() args.keep_empty = False - if args.tokenizer_type.lower().startswith('bert'): - if not args.split_sentences: - print("Bert tokenizer detected, are you sure you don't want to split sentences?") - # some default/dummy values for the tokenizer args.rank = 0 args.make_vocab_size_divisible_by = 128 @@ -193,6 +198,16 @@ def get_args(): args.rank = args.distctx.rank args.numranks = args.distctx.numranks + if args.tokenizer_type.lower().startswith('bert'): + if not args.split_sentences: + if args.rank == 0: + print("Bert tokenizer detected, are you sure you don't want to split sentences?") + + # TODO: perhaps more user friendly to disable scratch and print a warning? 
+ # check that serial merge is not attempted with scratch + if args.scratch is not None and args.merge != 'parallel': + assert False, "The --scratch option is only valid with --merge=parallel" + return args def format_byterate(byterate): @@ -347,7 +362,7 @@ def get_filename(args, key, rank=None): pathname = args.output_prefix # redirect per-rank file to scratch dir if defined - if args.scratch and rank is not None: + if args.scratch is not None and rank is not None: basename = os.path.basename(pathname) pathname = os.path.join(args.scratch, basename) @@ -413,8 +428,6 @@ def rank_files_write(args, dset, idx, encoder): times[0] += start_encode - start_read times[1] += start_write - start_encode times[2] += end_write - start_write -# dset_stats[0] += 1 -# dset_stats[2] += len(text) if args.rank == 0 and args.log_interval > 0 and time.time() > progress_next: current = time.time() @@ -472,45 +485,43 @@ def rank_files_write(args, dset, idx, encoder): success = args.distctx.alltrue(success) return success, err -def rank_files_merge(args): - args.dist_merge = True - if args.dist_merge: - merge_start = time.time() - numbytes = np.zeros(1, dtype=np.int64) - for key in args.columns: - filemain = get_filename(args, key) - filerank = get_filename(args, key, args.rank) - gather_files_dist(filemain, [filerank], args.distctx, dtype=best_fitting_dtype(args.vocab_size)) - - # total up bytes read in merge - binfile = data_file_path(filerank) - idxfile = index_file_path(filerank) - numbytes[0] += os.stat(binfile)[stat.ST_SIZE] - numbytes[0] += os.stat(idxfile)[stat.ST_SIZE] - - # rename files for now and also do regular merge so we can time both and "cmp" them - if args.rank == 0: - binfile = data_file_path(filemain) - idxfile = index_file_path(filemain) - os.rename(binfile, binfile + ".par") - os.rename(idxfile, idxfile + ".par") +def rank_files_merge_parallel(args): + """Each process directly writes its portion of the data from its per-rank file into the final file.""" + merge_start = time.time() + numbytes = np.zeros(1, dtype=np.int64) + for key in args.columns: + filemain = get_filename(args, key) + filerank = get_filename(args, key, args.rank) + gather_files_dist(filemain, [filerank], args.distctx, dtype=best_fitting_dtype(args.vocab_size)) + + # total up bytes read in merge + binfile = data_file_path(filerank) + idxfile = index_file_path(filerank) + numbytes[0] += os.stat(binfile)[stat.ST_SIZE] + numbytes[0] += os.stat(idxfile)[stat.ST_SIZE] + + # If user want to use both a parallel and serial merge (for testing), + # rename the parallel output files so that the serial merge does not clobber them. + if args.merge == 'both' and args.rank == 0: + binfile = data_file_path(filemain) + idxfile = index_file_path(filemain) + os.rename(binfile, binfile + ".par") + os.rename(idxfile, idxfile + ".par") + + # Total up number of bytes read across all ranks, + # and wait on all ranks before stopping the timer. 
+ args.distctx.all_sum_(numbytes) + merge_end = time.time() + if args.rank == 0: + secs = merge_end - merge_start + byterate = numbytes[0] / secs if secs > 0.0 else 0.0 + print("Parallel merge stats:") + print(f" Scratch: {args.scratch}") + print(f" Seconds to merge: {secs}") + print(f" {int(numbytes)} bytes {format_byterate(byterate)}") - args.distctx.barrier() - args.distctx.all_sum_(numbytes) - merge_end = time.time() - if args.rank == 0: - secs = merge_end - merge_start - byterate = numbytes[0] / secs if secs > 0.0 else 0.0 - print("Parallel merge stats:") - print(f" Scratch: {args.scratch}") - print(f" Seconds to merge: {secs}") - print(f" {int(numbytes)} bytes {format_byterate(byterate)}") - - # if using node-local storage, skip sequential merge test - if args.scratch: - return - - # rank 0 merges all per-rank files +def rank_files_merge_serial(args): + """Rank 0 merges data from all per-rank files into the final file.""" if args.rank == 0: print("Merging rank files ...", flush=True) merge_start = time.time() @@ -539,8 +550,9 @@ def rank_files_merge(args): # sum up the number of merged bytes binfile = data_file_path(infile) - filesize = os.stat(binfile)[stat.ST_SIZE] - numbytes += filesize + idxfile = index_file_path(infile) + numbytes += os.stat(binfile)[stat.ST_SIZE] + numbytes += os.stat(idxfile)[stat.ST_SIZE] # finalize the merged file print("Finalizing merged file ...", flush=True) @@ -559,6 +571,19 @@ def rank_files_merge(args): # hold everyone until rank 0 is done args.distctx.barrier() +def rank_files_merge(args): + # use parallel merge if asked + if args.merge in ['parallel', 'both']: + rank_files_merge_parallel(args) + + # if using node-local storage, skip sequential merge + if args.scratch is not None: + return + + # can fall back to a serial merge + if args.merge in ['serial', 'both']: + rank_files_merge_serial(args) + def rank_files_delete(args): # delete per-rank files if args.rank == 0: @@ -582,9 +607,6 @@ def main(): args = get_args() startup_start = time.time() - # connect processes and cache our rank and number of procs in args - #init_distributed(args) - # load the dataset dset, err = load_dset(args) if dset is None: From ec11281f765d088b4c05b4bbd8cac3d802893aa0 Mon Sep 17 00:00:00 2001 From: Adam Moody Date: Mon, 16 Aug 2021 15:50:33 -0500 Subject: [PATCH 18/74] Update megatron/data/distdata.py Co-authored-by: Thomas Wang <24695242+thomasw21@users.noreply.github.com> --- megatron/data/distdata.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/megatron/data/distdata.py b/megatron/data/distdata.py index 072d19884..14d491dfe 100644 --- a/megatron/data/distdata.py +++ b/megatron/data/distdata.py @@ -17,7 +17,7 @@ def __init__(self, backend='gloo', use_mpi4py=False): pass # lookup our process rank and the group size - if self.MPI: + if self.MPI is not None: self.comm = self.MPI.COMM_WORLD self.rank = self.comm.Get_rank() self.numranks = self.comm.Get_size() From 354d13bd6dd2804bcec7d5631fbb17b17903649f Mon Sep 17 00:00:00 2001 From: Adam Moody Date: Mon, 16 Aug 2021 16:07:01 -0500 Subject: [PATCH 19/74] Update megatron/data/indexed_dataset.py Co-authored-by: Thomas Wang <24695242+thomasw21@users.noreply.github.com> --- megatron/data/indexed_dataset.py | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/megatron/data/indexed_dataset.py b/megatron/data/indexed_dataset.py index 430eac52c..3b77a972f 100644 --- a/megatron/data/indexed_dataset.py +++ b/megatron/data/indexed_dataset.py @@ -609,10 +609,7 @@ def 
gather_files_dist_bin(outfile, filelist, distctx): fout = distctx.open(data_file_path(outfile)) # lookup size of each of our binary files - filesizes = [] - for f in filelist: - filesize = os.stat(data_file_path(f))[stat.ST_SIZE] - filesizes.append(filesize) + filesizes = [os.stat(data_file_path(f))[stat.ST_SIZE] for f in filelist] # compute offset this rank should start copying # its data into the merged file From 2dc3f7adc8a4e6dd315ee171e3fcdfee3f64f018 Mon Sep 17 00:00:00 2001 From: Adam Moody Date: Mon, 16 Aug 2021 16:20:11 -0500 Subject: [PATCH 20/74] Update megatron/data/indexed_dataset.py Co-authored-by: Thomas Wang <24695242+thomasw21@users.noreply.github.com> --- megatron/data/indexed_dataset.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/megatron/data/indexed_dataset.py b/megatron/data/indexed_dataset.py index 3b77a972f..efdd00dac 100644 --- a/megatron/data/indexed_dataset.py +++ b/megatron/data/indexed_dataset.py @@ -715,7 +715,7 @@ def gather_files_dist_idx_cached(outfile, filelist, distctx, dtype): dim_offsets64 += dim_offset # Seek to proper offset for this rank and write - # dom offset values into file, stored as int64. + # dim offset values into file, stored as int64. fout.seek(pos + global_dim_offset * np.int64().itemsize) fout.write(dim_offsets64.tobytes(order='C')) del dim_offsets64 From 980e9043e3c26575cd9c7771176721a6744b9e24 Mon Sep 17 00:00:00 2001 From: Adam Moody Date: Mon, 16 Aug 2021 17:19:13 -0500 Subject: [PATCH 21/74] Update megatron/data/indexed_dataset.py Co-authored-by: Thomas Wang <24695242+thomasw21@users.noreply.github.com> --- megatron/data/indexed_dataset.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/megatron/data/indexed_dataset.py b/megatron/data/indexed_dataset.py index efdd00dac..85a688a61 100644 --- a/megatron/data/indexed_dataset.py +++ b/megatron/data/indexed_dataset.py @@ -649,9 +649,7 @@ def gather_files_dist_idx_cached(outfile, filelist, distctx, dtype): sizes.extend(index.sizes.tolist()) - data_offset = data_offsets[-1] - tmpdata_offsets = np.copy(index.data_offsets[1:]) - tmpdata_offsets += data_offset + tmpdata_offsets = index.data_offsets[1:] + data_offsets[-1] data_offsets.extend(tmpdata_offsets.tolist()) dim_offset = dim_offsets[-1] From ebd20a6fadcfd7f189f0e1d484e53babbf002f0b Mon Sep 17 00:00:00 2001 From: Adam Moody Date: Mon, 16 Aug 2021 17:20:08 -0500 Subject: [PATCH 22/74] Update megatron/data/indexed_dataset.py Co-authored-by: Thomas Wang <24695242+thomasw21@users.noreply.github.com> --- megatron/data/indexed_dataset.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/megatron/data/indexed_dataset.py b/megatron/data/indexed_dataset.py index 85a688a61..049b4e176 100644 --- a/megatron/data/indexed_dataset.py +++ b/megatron/data/indexed_dataset.py @@ -652,9 +652,7 @@ def gather_files_dist_idx_cached(outfile, filelist, distctx, dtype): tmpdata_offsets = index.data_offsets[1:] + data_offsets[-1] data_offsets.extend(tmpdata_offsets.tolist()) - dim_offset = dim_offsets[-1] - tmpdim_offsets = np.copy(index.dim_offsets[1:]) - tmpdim_offsets += dim_offset + tmpdim_offsets = index.dim_offsets[1:] + dim_offsets[-1] dim_offsets.extend(tmpdim_offsets.tolist()) tmpdocs = np.copy(index.doc_idx[1:]) From 69b2f49b113661a836369c4b8e87b9f8a7b3044e Mon Sep 17 00:00:00 2001 From: Adam Moody Date: Mon, 16 Aug 2021 17:20:37 -0500 Subject: [PATCH 23/74] Update megatron/data/indexed_dataset.py Co-authored-by: Thomas Wang <24695242+thomasw21@users.noreply.github.com> --- 
megatron/data/indexed_dataset.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/megatron/data/indexed_dataset.py b/megatron/data/indexed_dataset.py index 049b4e176..93f6e69f8 100644 --- a/megatron/data/indexed_dataset.py +++ b/megatron/data/indexed_dataset.py @@ -655,8 +655,7 @@ def gather_files_dist_idx_cached(outfile, filelist, distctx, dtype): tmpdim_offsets = index.dim_offsets[1:] + dim_offsets[-1] dim_offsets.extend(tmpdim_offsets.tolist()) - tmpdocs = np.copy(index.doc_idx[1:]) - tmpdocs += doc_offset + tmpdocs = index.doc_idx[1:] + doc_offset docs.extend(tmpdocs.tolist()) # Drop first entry from the lists that start with From 50de06ace09d156219e6a532e6e5992cd2f1d93d Mon Sep 17 00:00:00 2001 From: Adam Moody Date: Mon, 16 Aug 2021 17:21:29 -0500 Subject: [PATCH 24/74] Update megatron/data/indexed_dataset.py Co-authored-by: Thomas Wang <24695242+thomasw21@users.noreply.github.com> --- megatron/data/indexed_dataset.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/megatron/data/indexed_dataset.py b/megatron/data/indexed_dataset.py index 93f6e69f8..bfee48f98 100644 --- a/megatron/data/indexed_dataset.py +++ b/megatron/data/indexed_dataset.py @@ -786,8 +786,7 @@ def gather_files_dist_idx_mmap(outfile, filelist, distctx, dtype): sizes.extend(index.sizes.tolist()) - tmpdocs = np.copy(index.doc_idx[1:]) - tmpdocs += docs_offset + tmpdocs = index.doc_idx[1:] + docs_offset docs.extend(tmpdocs.tolist()) # Drop first entry from the lists that start with From af290ad94ca9fa0747243ee411c92d5e1a665059 Mon Sep 17 00:00:00 2001 From: Adam Moody Date: Mon, 16 Aug 2021 16:03:02 -0700 Subject: [PATCH 25/74] drop extraneous numpy tolist calls --- megatron/data/indexed_dataset.py | 20 ++++++-------------- 1 file changed, 6 insertions(+), 14 deletions(-) diff --git a/megatron/data/indexed_dataset.py b/megatron/data/indexed_dataset.py index bfee48f98..968a43f2b 100644 --- a/megatron/data/indexed_dataset.py +++ b/megatron/data/indexed_dataset.py @@ -647,16 +647,10 @@ def gather_files_dist_idx_cached(outfile, filelist, distctx, dtype): doc_offset = len(sizes) - sizes.extend(index.sizes.tolist()) - - tmpdata_offsets = index.data_offsets[1:] + data_offsets[-1] - data_offsets.extend(tmpdata_offsets.tolist()) - - tmpdim_offsets = index.dim_offsets[1:] + dim_offsets[-1] - dim_offsets.extend(tmpdim_offsets.tolist()) - - tmpdocs = index.doc_idx[1:] + doc_offset - docs.extend(tmpdocs.tolist()) + sizes.extend(index.sizes) + data_offsets.extend(index.data_offsets[1:] + data_offsets[-1]) + dim_offsets.extend(index.dim_offsets[1:] + dim_offsets[-1]) + docs.extend(index.doc_idx[1:] + doc_offset) # Drop first entry from the lists that start with # a "0" value if we're not the first rank with some size. @@ -784,10 +778,8 @@ def gather_files_dist_idx_mmap(outfile, filelist, distctx, dtype): docs_offset = len(sizes) - sizes.extend(index.sizes.tolist()) - - tmpdocs = index.doc_idx[1:] + docs_offset - docs.extend(tmpdocs.tolist()) + sizes.extend(index.sizes) + docs.extend(index.doc_idx[1:] + docs_offset) # Drop first entry from the lists that start with # a "0" value if we're not the first rank with some size. 
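The shortened form above leans on numpy broadcasting: slicing off the leading entry, adding the running total as a scalar, and letting list.extend consume the shifted array yields the same values as the earlier copy/shift/tolist sequence. A small illustration with made-up numbers:

    import numpy as np

    data_offsets = [0, 3, 7]                     # offsets accumulated from files merged so far
    index_data_offsets = np.array([0, 2, 5, 9])  # offsets read from the next index file

    # Skip the leading 0 and shift the rest past the data already accounted for.
    data_offsets.extend(index_data_offsets[1:] + data_offsets[-1])
    # data_offsets is now [0, 3, 7, 9, 12, 16]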
From 4b58c74c569ee62e5a3d51544f38d4b0181ab70e Mon Sep 17 00:00:00 2001 From: Adam Moody Date: Mon, 16 Aug 2021 16:03:55 -0700 Subject: [PATCH 26/74] rename self.MPI to mpi4py --- megatron/data/distdata.py | 35 +++++++++++++++++------------------ 1 file changed, 17 insertions(+), 18 deletions(-) diff --git a/megatron/data/distdata.py b/megatron/data/distdata.py index 14d491dfe..fc89249aa 100644 --- a/megatron/data/distdata.py +++ b/megatron/data/distdata.py @@ -5,20 +5,19 @@ class DistData(object): def __init__(self, backend='gloo', use_mpi4py=False): - self.MPI = None - # use mpi4py instead of torch.distributed if requested + self.mpi4py = None if use_mpi4py: try: from mpi4py import MPI - self.MPI = MPI + self.mpi4py = MPI except: #print(f"ERROR: mpi4py requested, but failed to import, falling back to torch.distributed.", flush=True) pass # lookup our process rank and the group size - if self.MPI is not None: - self.comm = self.MPI.COMM_WORLD + if self.mpi4py is not None: + self.comm = self.mpi4py.COMM_WORLD self.rank = self.comm.Get_rank() self.numranks = self.comm.Get_size() else: @@ -28,14 +27,14 @@ def __init__(self, backend='gloo', use_mpi4py=False): def barrier(self): """Globally synchronize all processes""" - if self.MPI: + if self.mpi4py is not None: self.comm.barrier() else: dist.barrier() def bcast(self, val, root): """Broadcast a scalar value from root to all ranks""" - if self.MPI: + if self.mpi4py is not None: return self.comm.bcast(val, root=root) else: vals = [val] @@ -44,7 +43,7 @@ def bcast(self, val, root): def bcast_list(self, vals, root=0): """Broadcast list of vals from root to all ranks, returns newly allocated list""" - if self.MPI: + if self.mpi4py is not None: return self.comm.bcast(vals, root=root) else: # broadcast length of vals list @@ -64,10 +63,10 @@ def bcast_list(self, vals, root=0): def alltrue(self, val): """Returns True if all procs input True, False otherwise""" - if self.MPI: + if self.mpi4py is not None: inval = np.array([val], dtype=np.bool_) outval = np.zeros_like(inval) - self.comm.Allreduce(inval, outval, op=self.MPI.LAND) + self.comm.Allreduce(inval, outval, op=self.mpi4py.LAND) return bool(outval[0]) else: tensor = torch.tensor([int(val)], dtype=torch.int32) @@ -76,10 +75,10 @@ def alltrue(self, val): def sum(self, val): """Compute sum of val, and return total on all ranks.""" - if self.MPI: + if self.mpi4py is not None: insize = np.array([val], dtype=np.int64) outsize = np.zeros_like(insize) - self.comm.Allreduce(insize, outsize, op=self.MPI.SUM) + self.comm.Allreduce(insize, outsize, op=self.mpi4py.SUM) return outsize[0] else: tensor = torch.tensor([val], dtype=torch.int64) @@ -88,10 +87,10 @@ def sum(self, val): def exscan(self, val): """Compute prefix sum (exclusive scan) of val, and return offset of each rank.""" - if self.MPI: + if self.mpi4py is not None: insize = np.array([val], dtype=np.int64) outsize = np.zeros_like(insize) - self.comm.Scan(insize, outsize, op=self.MPI.SUM) + self.comm.Scan(insize, outsize, op=self.mpi4py.SUM) return outsize[0] - insize[0] else: # torch.distributed doesn't have a scan, so fallback to allreduce @@ -102,10 +101,10 @@ def exscan(self, val): def min(self, val): """Return minimum val to all ranks.""" - if self.MPI: + if self.mpi4py is not None: insize = np.array([val], dtype=np.int64) outsize = np.zeros_like(insize) - self.comm.Allreduce(insize, outsize, op=self.MPI.MIN) + self.comm.Allreduce(insize, outsize, op=self.mpi4py.MIN) return outsize[0] else: # torch.distributed doesn't have a scan, so fallback to 
allreduce @@ -139,9 +138,9 @@ def bcast_first(self, val): def all_sum_(self, vals): """Sums values in vals element-wise and updates vals with final result on all ranks""" - if self.MPI: + if self.mpi4py is not None: outval = np.zeros_like(vals) - self.comm.Allreduce(vals, outval, op=self.MPI.SUM) + self.comm.Allreduce(vals, outval, op=self.mpi4py.SUM) vals[:] = outval else: tensor = torch.from_numpy(vals) From 71a2fdcfcacf923f4ef2413b795db7e8b1d71e9f Mon Sep 17 00:00:00 2001 From: Adam Moody Date: Mon, 16 Aug 2021 16:05:44 -0700 Subject: [PATCH 27/74] handle case where no ranks have elements in their file --- megatron/data/indexed_dataset.py | 14 ++++++++++---- tools/preprocess_dataset_mpi.py | 9 ++++++--- 2 files changed, 16 insertions(+), 7 deletions(-) diff --git a/megatron/data/indexed_dataset.py b/megatron/data/indexed_dataset.py index 968a43f2b..aa70ae0a4 100644 --- a/megatron/data/indexed_dataset.py +++ b/megatron/data/indexed_dataset.py @@ -655,7 +655,8 @@ def gather_files_dist_idx_cached(outfile, filelist, distctx, dtype): # Drop first entry from the lists that start with # a "0" value if we're not the first rank with some size. minrank = distctx.minrank(len(sizes) > 0) - if rank != minrank: + firstrank = minrank if minrank is not None else 0 + if rank != firstrank: del data_offsets[0] del dim_offsets[0] del docs[0] @@ -685,8 +686,10 @@ def gather_files_dist_idx_cached(outfile, filelist, distctx, dtype): if rank == 0: fout.write(IndexedDataset._HDR_MAGIC) fout.write(struct.pack(' 0) - if rank != minrank: + firstrank = minrank if minrank is not None else 0 + if rank != firstrank: del docs[0] # Compute total number of size and document index @@ -794,8 +798,10 @@ def gather_files_dist_idx_mmap(outfile, filelist, distctx, dtype): # calling rank. 
numsizes = len(sizes) numdocs = len(docs) + global_size_count = distctx.sum(numsizes) global_docs_count = distctx.sum(numdocs) + global_size_offset = distctx.exscan(numsizes) global_docs_offset = distctx.exscan(numdocs) diff --git a/tools/preprocess_dataset_mpi.py b/tools/preprocess_dataset_mpi.py index 8076df75d..85256e3d6 100644 --- a/tools/preprocess_dataset_mpi.py +++ b/tools/preprocess_dataset_mpi.py @@ -472,14 +472,17 @@ def rank_files_write(args, dset, idx, encoder): docrate = dset_stats[0] / secs if secs > 0.0 else 0.0 sentrate = dset_stats[1] / secs if secs > 0.0 else 0.0 byterate = dset_stats[2] / secs if secs > 0.0 else 0.0 + secs_read_per_sample = times[0] / dset_stats[0] if dset_stats[0] > 0 else 0.0 + secs_encode_per_sample = times[1] / dset_stats[0] if dset_stats[0] > 0 else 0.0 + secs_write_per_sample = times[2] / dset_stats[0] if dset_stats[0] > 0 else 0.0 print("Process stats:") print(f" Seconds to process: {secs}") print(f" {dset_stats[0]} docs {docrate} docs/sec") print(f" {dset_stats[1]} sents {sentrate} sents/sec") print(f" {dset_stats[2]} bytes {format_byterate(byterate)}") - print(f" Total read seconds {times[0]}, {times[0]/dset_stats[0]} sec/sample") - print(f" Total encode seconds {times[1]}, {times[1]/dset_stats[0]} sec/sample") - print(f" Total write seconds {times[2]}, {times[2]/dset_stats[0]} sec/sample") + print(f" Total read seconds {times[0]}, {secs_read_per_sample} sec/sample") + print(f" Total encode seconds {times[1]}, {secs_encode_per_sample} sec/sample") + print(f" Total write seconds {times[2]}, {secs_write_per_sample} sec/sample") # allreduce to check whether all ranks wrote their part successfully success = args.distctx.alltrue(success) From 73d3a2470ae5714dd8a7851af7a34122b71142ac Mon Sep 17 00:00:00 2001 From: Adam Moody Date: Mon, 16 Aug 2021 16:09:58 -0700 Subject: [PATCH 28/74] rename tokenize_start to time_start --- tools/preprocess_dataset_mpi.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/tools/preprocess_dataset_mpi.py b/tools/preprocess_dataset_mpi.py index 85256e3d6..1145dd978 100644 --- a/tools/preprocess_dataset_mpi.py +++ b/tools/preprocess_dataset_mpi.py @@ -374,7 +374,7 @@ def get_filename(args, key, rank=None): return filename def rank_files_write(args, dset, idx, encoder): - tokenize_start = time.time() + time_start = time.time() # we'll total up the number of docs, sentences, and bytes # processed across all ranks @@ -433,7 +433,7 @@ def rank_files_write(args, dset, idx, encoder): current = time.time() progress_next = current + float(args.log_interval) - elapsed = current - tokenize_start + elapsed = current - time_start timestamp = time.strftime("%Y-%m-%dT%H:%M:%S") docs = dset_stats[0] * args.numranks percent = docs / len(idx) * 100.0 @@ -462,13 +462,13 @@ def rank_files_write(args, dset, idx, encoder): # wait for all ranks to finish their files args.distctx.barrier() - tokenize_end = time.time() + time_end = time.time() # compute total stats across all processes args.distctx.all_sum_(times) args.distctx.all_sum_(dset_stats) if args.rank == 0: - secs = tokenize_end - tokenize_start + secs = time_end - time_start docrate = dset_stats[0] / secs if secs > 0.0 else 0.0 sentrate = dset_stats[1] / secs if secs > 0.0 else 0.0 byterate = dset_stats[2] / secs if secs > 0.0 else 0.0 From b9e69bea3d9b8af7ab589565ae2660ce6e75d429 Mon Sep 17 00:00:00 2001 From: Adam Moody Date: Mon, 16 Aug 2021 16:38:24 -0700 Subject: [PATCH 29/74] drop unrelated comment in distdata.min --- megatron/data/distdata.py | 1 - 1 file 
changed, 1 deletion(-) diff --git a/megatron/data/distdata.py b/megatron/data/distdata.py index fc89249aa..932a913c8 100644 --- a/megatron/data/distdata.py +++ b/megatron/data/distdata.py @@ -107,7 +107,6 @@ def min(self, val): self.comm.Allreduce(insize, outsize, op=self.mpi4py.MIN) return outsize[0] else: - # torch.distributed doesn't have a scan, so fallback to allreduce tensor = torch.tensor([val], dtype=torch.int64) dist.all_reduce(tensor, op=dist.ReduceOp.MIN) return tensor[0] From da615c6dd428ec791c0af99679501459d81cf68b Mon Sep 17 00:00:00 2001 From: Adam Moody Date: Mon, 16 Aug 2021 16:45:55 -0700 Subject: [PATCH 30/74] add comment why pointers_shift is not None and add assert --- megatron/data/indexed_dataset.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/megatron/data/indexed_dataset.py b/megatron/data/indexed_dataset.py index aa70ae0a4..3aa0025d5 100644 --- a/megatron/data/indexed_dataset.py +++ b/megatron/data/indexed_dataset.py @@ -855,12 +855,14 @@ def gather_files_dist_idx_mmap(outfile, filelist, distctx, dtype): # we first need to find the rank having the first sentence, # then bcast that size to all ranks. if global_size_count > 0: - # There is at least one sentence across all ranks, - # figure out which rank has the first sentence which - # is not necessarily rank 0. + # Since global_size_count > 0, there is at least one sentence across all ranks. + # Get the value from the first rank that has a value, which may not be rank 0. pointers_shift = pointers[0] if len(sizes) > 0 else None pointers_shift = distctx.bcast_first(pointers_shift) + # Since there is at least one, bcast_first should return some value other than None. + assert pointers_shift is not None, "Expected at least one rank to have a valid element" + # Zero-base pointers by subtracting size of first # sentence from all values. pointers -= pointers_shift From c42f41f576ade2fe9e5cd513680b8873fb2cede7 Mon Sep 17 00:00:00 2001 From: Adam Moody Date: Mon, 16 Aug 2021 16:48:27 -0700 Subject: [PATCH 31/74] note why pointers uses sizes count and offset values --- megatron/data/indexed_dataset.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/megatron/data/indexed_dataset.py b/megatron/data/indexed_dataset.py index 3aa0025d5..719abe157 100644 --- a/megatron/data/indexed_dataset.py +++ b/megatron/data/indexed_dataset.py @@ -867,6 +867,10 @@ def gather_files_dist_idx_mmap(outfile, filelist, distctx, dtype): # sentence from all values. pointers -= pointers_shift + # Since the pointers array is the same length as the sizes array, + # we use global_size_offset and global_size_count to position + # within the file for writing the pointer values. + # Seek to proper offset for this rank and write # pointer values into file, stored as int64. 
fout.seek(pos + global_size_offset * np.int64().itemsize) From a3a7d539dcf6f96045a7768c553b5fdcaca248cf Mon Sep 17 00:00:00 2001 From: Adam Moody Date: Mon, 16 Aug 2021 19:44:16 -0700 Subject: [PATCH 32/74] can just rely on rank 0 for the leading 0 element --- megatron/data/indexed_dataset.py | 16 ++++++---------- 1 file changed, 6 insertions(+), 10 deletions(-) diff --git a/megatron/data/indexed_dataset.py b/megatron/data/indexed_dataset.py index 719abe157..14f6861e8 100644 --- a/megatron/data/indexed_dataset.py +++ b/megatron/data/indexed_dataset.py @@ -652,11 +652,9 @@ def gather_files_dist_idx_cached(outfile, filelist, distctx, dtype): dim_offsets.extend(index.dim_offsets[1:] + dim_offsets[-1]) docs.extend(index.doc_idx[1:] + doc_offset) - # Drop first entry from the lists that start with - # a "0" value if we're not the first rank with some size. - minrank = distctx.minrank(len(sizes) > 0) - firstrank = minrank if minrank is not None else 0 - if rank != firstrank: + # Drop the zero entry from the lists that start with + # a "0" value unless we're rank 0 + if rank != 0: del data_offsets[0] del dim_offsets[0] del docs[0] @@ -784,11 +782,9 @@ def gather_files_dist_idx_mmap(outfile, filelist, distctx, dtype): sizes.extend(index.sizes) docs.extend(index.doc_idx[1:] + docs_offset) - # Drop first entry from the lists that start with - # a "0" value if we're not the first rank with some size. - minrank = distctx.minrank(len(sizes) > 0) - firstrank = minrank if minrank is not None else 0 - if rank != firstrank: + # Drop the zero entry from the lists that start with + # a "0" value unless we're rank 0 + if rank != 0: del docs[0] # Compute total number of size and document index From 163310aa244f4516bba4b039aedc618ec640fc99 Mon Sep 17 00:00:00 2001 From: Adam Moody Date: Mon, 16 Aug 2021 21:58:36 -0700 Subject: [PATCH 33/74] add write_list function --- megatron/data/indexed_dataset.py | 105 +++++++++++++------------------ 1 file changed, 42 insertions(+), 63 deletions(-) diff --git a/megatron/data/indexed_dataset.py b/megatron/data/indexed_dataset.py index 14f6861e8..b4b1fa2a1 100644 --- a/megatron/data/indexed_dataset.py +++ b/megatron/data/indexed_dataset.py @@ -630,6 +630,29 @@ def gather_files_dist_bin(outfile, filelist, distctx): distctx.barrier() +def write_list(fout, pos, vals, shift, offset, total, dtype): + """Write list of values to fout and return new file position. + + Copy list of values in vals to a numpy array of type dtype. + Add a constant shift value to all elements. + Write array to file at given offset and scaled by size of the datatype. + Compute and return new file position given a total number of elements written. + """ + + # Make a copy of the vals list using the requested datatype. + npvals = np.array(vals, dtype=dtype) + + # Shift values in the list by a constant value. + npvals += shift + + # Seek to proper offset for this rank and write + # values into file, stored as given datatype. + fout.seek(pos + offset * dtype().itemsize) + fout.write(npvals.tobytes(order='C')) + + # Advance file pointer past end of this section. 
+ return pos + total * dtype().itemsize + def gather_files_dist_idx_cached(outfile, filelist, distctx, dtype): # get our rank rank = distctx.rank @@ -637,21 +660,26 @@ def gather_files_dist_idx_cached(outfile, filelist, distctx, dtype): # Create shared output file fout = distctx.open(index_file_path(outfile)) - # Read the index file for the calling rank + # Read each index file and append items to our lists sizes = [] data_offsets = [0] dim_offsets = [0] docs = [0] for f in filelist: - index = IndexedDataset(f) - doc_offset = len(sizes) + index = IndexedDataset(f) sizes.extend(index.sizes) data_offsets.extend(index.data_offsets[1:] + data_offsets[-1]) dim_offsets.extend(index.dim_offsets[1:] + dim_offsets[-1]) docs.extend(index.doc_idx[1:] + doc_offset) + # Capture the last value in each array before we delete any items. + # Note this may be zero on any rank that has no items, + # but zero is the correct value in that case. + dim_last = dim_offsets[-1] + data_last = data_offsets[-1] + # Drop the zero entry from the lists that start with # a "0" value unless we're rank 0 if rank != 0: @@ -699,37 +727,15 @@ def gather_files_dist_idx_cached(outfile, filelist, distctx, dtype): # the sizes list for each sentence. Adjust dimension # offset values based on the number of offsets # that come before the calling rank. - dim_offsets64 = np.array(dim_offsets, dtype=np.int64) - dim_last = dim_offsets[-1] if numdim > 0 else 0 - dim_offset = distctx.exscan(dim_last) - dim_offsets64 += dim_offset - - # Seek to proper offset for this rank and write - # dim offset values into file, stored as int64. - fout.seek(pos + global_dim_offset * np.int64().itemsize) - fout.write(dim_offsets64.tobytes(order='C')) - del dim_offsets64 - - # Advance past list of dim offset values - pos += global_dim_count * np.int64().itemsize + dim_shift = distctx.exscan(dim_last) + pos = write_list(fout, pos, dim_offsets, dim_shift, global_dim_offset, global_dim_count, np.int64) # The data index records the byte offset to the start of each # sentence within the binary data file. # Adjust our data index values for number of bytes that # come before the calling rank. - data_offsets64 = np.array(data_offsets, dtype=np.int64) - byte_last = data_offsets[-1] if numdata > 0 else 0 - byte_offset = distctx.exscan(byte_last) - data_offsets64 += byte_offset - - # Seek to proper offset for this rank and write - # data (byte) offset index into file, stored as int64. - fout.seek(pos + global_data_offset * np.int64().itemsize) - fout.write(data_offsets64.tobytes(order='C')) - del data_offsets64 - - # Advance past list of data (byte) offset values - pos += global_data_count * np.int64().itemsize + data_shift = distctx.exscan(data_last) + pos = write_list(fout, pos, data_offsets, data_shift, global_data_offset, global_data_count, np.int64) # Each sentence is stored as a tensor. # The tensor for each sentence can be multidimensional. @@ -739,24 +745,11 @@ def gather_files_dist_idx_cached(outfile, filelist, distctx, dtype): # for each dimension of a sentence. # The list of size values from each rank are # concatenated and stored as int64. 
- fout.seek(pos + global_size_offset * np.int64().itemsize) - sizes64 = np.array(sizes, dtype=np.int64) - fout.write(sizes64.tobytes(order='C')) - del sizes64 - - # Advance past list of size values - pos += global_size_count * np.int64().itemsize + pos = write_list(fout, pos, sizes, 0, global_size_offset, global_size_count, np.int64) # The document index points to the position in the sizes # array for the first sentence of the sample. - docs64 = np.array(docs, dtype=np.int64) - docs64 += global_size_offset - - # Seek to proper offset for this rank and write - # document index into file, stored as int64. - fout.seek(pos + global_doc_offset * np.int64().itemsize) - fout.write(docs64.tobytes(order='C')) - del docs64 + pos = write_list(fout, pos, docs, global_size_offset, global_doc_offset, global_doc_count, np.int64) fout.close() @@ -771,14 +764,13 @@ def gather_files_dist_idx_mmap(outfile, filelist, distctx, dtype): # Create shared output file fout = distctx.open(index_file_path(outfile)) - # Read the index file for each of our files + # Read each index file and append items to the size and docs lists sizes = [] docs = [0] for f in filelist: - index = MMapIndexedDataset.Index(index_file_path(f)) - docs_offset = len(sizes) + index = MMapIndexedDataset.Index(index_file_path(f)) sizes.extend(index.sizes) docs.extend(index.doc_idx[1:] + docs_offset) @@ -817,13 +809,7 @@ def gather_files_dist_idx_mmap(outfile, filelist, distctx, dtype): # The list of size values from each rank are # concatenated and stored as int32. - fout.seek(pos + global_size_offset * np.int32().itemsize) - sizes32 = np.array(sizes, dtype=np.int32) - fout.write(sizes32.tobytes(order='C')) - del sizes32 - - # Advance past list of size values - pos += global_size_count * np.int32().itemsize + pos = write_list(fout, pos, sizes, 0, global_size_offset, global_size_count, np.int32) # The pointer values store the byte offset to each sentence. # A sentence has a variable number of tokens, given by @@ -848,7 +834,7 @@ def gather_files_dist_idx_mmap(outfile, filelist, distctx, dtype): # Finally, zero-base the offset values by subtracting # the number of bytes of the first sentence. To do that - # we first need to find the rank having the first sentence, + # we need to find the rank having the first sentence, # then bcast that size to all ranks. if global_size_count > 0: # Since global_size_count > 0, there is at least one sentence across all ranks. @@ -881,14 +867,7 @@ def gather_files_dist_idx_mmap(outfile, filelist, distctx, dtype): # A variable number of sentences can be in each document. # Adjust document index for number of sentences that # come before the calling rank. - doc_idx = np.array(docs, dtype=np.int64) - doc_idx += global_size_offset - - # Seek to proper offset for this rank and write - # document index into file, stored as int64. 
- fout.seek(pos + global_docs_offset * np.int64().itemsize) - fout.write(doc_idx.tobytes(order='C')) - del doc_idx + pos = write_list(fout, pos, docs, global_size_offset, global_docs_offset, global_docs_count, np.int64) fout.close() From 01b2be076a1be647305c57c753d814d35f80e94c Mon Sep 17 00:00:00 2001 From: Adam Moody Date: Mon, 16 Aug 2021 22:19:04 -0700 Subject: [PATCH 34/74] determine element size --- megatron/data/indexed_dataset.py | 12 ++++++++++-- 1 file changed, 10 insertions(+), 2 deletions(-) diff --git a/megatron/data/indexed_dataset.py b/megatron/data/indexed_dataset.py index b4b1fa2a1..11d7b59b8 100644 --- a/megatron/data/indexed_dataset.py +++ b/megatron/data/indexed_dataset.py @@ -665,6 +665,7 @@ def gather_files_dist_idx_cached(outfile, filelist, distctx, dtype): data_offsets = [0] dim_offsets = [0] docs = [0] + element_size = None for f in filelist: doc_offset = len(sizes) @@ -673,10 +674,17 @@ def gather_files_dist_idx_cached(outfile, filelist, distctx, dtype): data_offsets.extend(index.data_offsets[1:] + data_offsets[-1]) dim_offsets.extend(index.dim_offsets[1:] + dim_offsets[-1]) docs.extend(index.doc_idx[1:] + doc_offset) + element_size = index.element_size + + # TODO: verify that element size is the same on all ranks for all files + # take first valid element size we can find + element_size = distctx.bcast_first(element_size) + assert element_size is not None, "Failed to find a valid element size in index files" # Capture the last value in each array before we delete any items. # Note this may be zero on any rank that has no items, # but zero is the correct value in that case. + # These values are used in scan operations to determine a shift value. dim_last = dim_offsets[-1] data_last = data_offsets[-1] @@ -712,8 +720,8 @@ def gather_files_dist_idx_cached(outfile, filelist, distctx, dtype): if rank == 0: fout.write(IndexedDataset._HDR_MAGIC) fout.write(struct.pack(' Date: Mon, 16 Aug 2021 22:53:18 -0700 Subject: [PATCH 35/74] add checks for consistent element_size values --- megatron/data/indexed_dataset.py | 28 ++++++++++++++++++++++------ 1 file changed, 22 insertions(+), 6 deletions(-) diff --git a/megatron/data/indexed_dataset.py b/megatron/data/indexed_dataset.py index 11d7b59b8..35dc29ca6 100644 --- a/megatron/data/indexed_dataset.py +++ b/megatron/data/indexed_dataset.py @@ -665,7 +665,8 @@ def gather_files_dist_idx_cached(outfile, filelist, distctx, dtype): data_offsets = [0] dim_offsets = [0] docs = [0] - element_size = None + element_size_valid = True # whether rank identifies inconsistent values in its files + element_size_value = None # the current element size value, if any for f in filelist: doc_offset = len(sizes) @@ -674,12 +675,27 @@ def gather_files_dist_idx_cached(outfile, filelist, distctx, dtype): data_offsets.extend(index.data_offsets[1:] + data_offsets[-1]) dim_offsets.extend(index.dim_offsets[1:] + dim_offsets[-1]) docs.extend(index.doc_idx[1:] + doc_offset) - element_size = index.element_size - # TODO: verify that element size is the same on all ranks for all files - # take first valid element size we can find - element_size = distctx.bcast_first(element_size) - assert element_size is not None, "Failed to find a valid element size in index files" + # check that the element value in this index matches other our other files + if element_size_value is None: + element_size_value = index.element_size + if index.element_size != element_size_value: + element_size_valid = False + + # verify that no rank has found an inconsistent value in their own 
set of files + allvalid = distctx.alltrue(element_size_valid) + if not allvalid: + if not element_size_valid: + print(f"Rank {rank}: found different element_size values in {filelist}") + assert allvalid, f"Some rank found inconsistent element_size values" + + # verify that at least one rank found an element size + element_size = distctx.bcast_first(element_size_value) + assert element_size is not None, "Failed to find any element size in index files" + + # verify that all ranks that have an element size are consistent with each other + allsame = distctx.alltrue(element_size_value == element_size or element_size_value is None) + assert allsame, "Failed to find a valid element size in index files" # Capture the last value in each array before we delete any items. # Note this may be zero on any rank that has no items, From ea0855550f1bdb7d1c211f6d07fadbdc7d2241c4 Mon Sep 17 00:00:00 2001 From: Adam Moody Date: Mon, 16 Aug 2021 23:36:40 -0700 Subject: [PATCH 36/74] check that at least one rank has a file to merge --- megatron/data/indexed_dataset.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/megatron/data/indexed_dataset.py b/megatron/data/indexed_dataset.py index 35dc29ca6..44bb6bc1b 100644 --- a/megatron/data/indexed_dataset.py +++ b/megatron/data/indexed_dataset.py @@ -957,6 +957,10 @@ def gather_files_dist_check_type(filelist, distctx): # Each rank merges its set of files into filemain collectively with all # other ranks. def gather_files_dist(filemain, filelist, distctx, dtype=np.int64): + # check that at least one input file is listed + filecount = distctx.sum(len(filelist)) + assert filecount > 0, "No rank has any input files to merge" + # TODO: seems like this could be relaxed # Check that files are all of the same index type indexstr = gather_files_dist_check_type(filelist, distctx) From 2524fce6f6a06e3eb08e073584a2f024452c680c Mon Sep 17 00:00:00 2001 From: Adam Moody Date: Tue, 17 Aug 2021 09:13:45 -0700 Subject: [PATCH 37/74] assert that torch backend is gloo or mpi --- megatron/data/distdata.py | 1 + 1 file changed, 1 insertion(+) diff --git a/megatron/data/distdata.py b/megatron/data/distdata.py index 932a913c8..254b3ae57 100644 --- a/megatron/data/distdata.py +++ b/megatron/data/distdata.py @@ -21,6 +21,7 @@ def __init__(self, backend='gloo', use_mpi4py=False): self.rank = self.comm.Get_rank() self.numranks = self.comm.Get_size() else: + assert backend in ['gloo', 'mpi'], f"torch.distributed backend '{backend}' is not supported, valid options are 'gloo' or 'mpi'" dist.init_process_group(backend, init_method="env://") self.rank = dist.get_rank() self.numranks = dist.get_world_size() From ca14d48d7bfbd2d951ef264a07c4de7f7cd5f738 Mon Sep 17 00:00:00 2001 From: Adam Moody Date: Tue, 17 Aug 2021 09:30:10 -0700 Subject: [PATCH 38/74] add collectives for assert and raise --- megatron/data/distdata.py | 26 ++++++++++++++++++++++++++ 1 file changed, 26 insertions(+) diff --git a/megatron/data/distdata.py b/megatron/data/distdata.py index 254b3ae57..9e35b4db8 100644 --- a/megatron/data/distdata.py +++ b/megatron/data/distdata.py @@ -3,6 +3,9 @@ import torch import torch.distributed as dist +class DistDataException(Exception): + pass + class DistData(object): def __init__(self, backend='gloo', use_mpi4py=False): # use mpi4py instead of torch.distributed if requested @@ -26,6 +29,29 @@ def __init__(self, backend='gloo', use_mpi4py=False): self.rank = dist.get_rank() self.numranks = dist.get_world_size() + def assert(self, cond, msg): + """Check that condition cond is True on all 
ranks, assert with message everywhere if not.""" + alltrue = self.alltrue(cond) + assert alltrue, msg + + def raise(self, err): + """Raise exception if err is not None on any rank.""" + alltrue = self.alltrue(err is None) + if not alltrue: + # at least on rank raised an exception, + # re-raise the actual exception on any rank that threw one + if err is not None: + raise err + + # on other ranks, raise an "empty" exception to indicate + # that we're failing because someone else did + raise DistDataException + + def assert(self, cond, msg): + """Check that condition cond is True on all ranks.""" + alltrue = self.alltrue(cond) + assert alltrue, msg + def barrier(self): """Globally synchronize all processes""" if self.mpi4py is not None: From d482f36f2c7a5bb5d2e2a6e09fec86cd7afbf66d Mon Sep 17 00:00:00 2001 From: Adam Moody Date: Tue, 17 Aug 2021 10:12:23 -0700 Subject: [PATCH 39/74] rename to allassert and allraise_if --- megatron/data/distdata.py | 42 ++++++++++++++------------------------- 1 file changed, 15 insertions(+), 27 deletions(-) diff --git a/megatron/data/distdata.py b/megatron/data/distdata.py index 9e35b4db8..b042c2839 100644 --- a/megatron/data/distdata.py +++ b/megatron/data/distdata.py @@ -3,7 +3,8 @@ import torch import torch.distributed as dist -class DistDataException(Exception): +class DistDataError(Exception): + """Defines an empty exception to throw when some other rank hit a real exception.""" pass class DistData(object): @@ -29,28 +30,24 @@ def __init__(self, backend='gloo', use_mpi4py=False): self.rank = dist.get_rank() self.numranks = dist.get_world_size() - def assert(self, cond, msg): + def allassert(self, cond, msg): """Check that condition cond is True on all ranks, assert with message everywhere if not.""" alltrue = self.alltrue(cond) assert alltrue, msg - def raise(self, err): + def allraise_if(self, err): """Raise exception if err is not None on any rank.""" alltrue = self.alltrue(err is None) if not alltrue: - # at least on rank raised an exception, - # re-raise the actual exception on any rank that threw one + # At least one rank raised an exception. + # Re-raise the actual exception if this rank threw one. if err is not None: raise err - # on other ranks, raise an "empty" exception to indicate - # that we're failing because someone else did - raise DistDataException - - def assert(self, cond, msg): - """Check that condition cond is True on all ranks.""" - alltrue = self.alltrue(cond) - assert alltrue, msg + # TODO: is there a better exception to use here? + # On other ranks, raise an "empty" exception to indicate + # that we're only failing because someone else did. + raise DistDataError def barrier(self): """Globally synchronize all processes""" @@ -174,26 +171,22 @@ def all_sum_(self, vals): def open(self, filename): """Create, truncate, and open a file shared by all ranks.""" - success = True - err = None # Don't truncate existing file until all ranks reach this point self.barrier() + # We'll capture any exception in this variable + err = None + # Rank 0 creates and truncates file. if self.rank == 0: try: f = open(filename, 'wb') except Exception as e: - success = False err = e # Verify that rank 0 created the file - success = self.alltrue(success) - if not success: - if err is not None: - raise err - return None + self.allraise_if(err) # Wait for rank 0 to open (and truncate) file, # then have all ranks open file for writing. 
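The point of allraise_if is that an error on one rank must not leave the other ranks blocked in a later collective: every rank votes on whether anyone failed, the failing rank re-raises its real exception, and the rest raise a placeholder so the whole job unwinds together. A rough standalone sketch of that pattern using mpi4py directly (the real class also has a torch.distributed path, and the file name here is a placeholder):

    from mpi4py import MPI

    class DistDataError(Exception):
        """Placeholder raised on ranks that only fail because another rank failed."""

    def allraise_if(comm, err):
        # Count how many ranks are holding a real exception.
        failed = comm.allreduce(1 if err is not None else 0)
        if failed > 0:
            if err is not None:
                raise err             # re-raise the real exception locally
            raise DistDataError()     # fail everywhere else too, so no rank blocks forever

    comm = MPI.COMM_WORLD
    err = None
    try:
        fout = open(f"/tmp/shard_{comm.rank}.bin", "wb")    # placeholder per-rank file
    except Exception as e:
        err = e
    allraise_if(comm, err)    # either every rank proceeds, or every rank raises

The next hunk applies exactly this pattern to the shared-file open.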
@@ -201,14 +194,9 @@ def open(self, filename): try: f = open(filename, 'r+b') except Exception as e: - success = False err = e # Verify that all ranks successfully opened the file - success = self.alltrue(success) - if not success: - if err is not None: - raise err - return None + self.allraise_if(err) return f From 28d76f57877dc6393f5ba743fe3378b288405d03 Mon Sep 17 00:00:00 2001 From: Adam Moody Date: Tue, 17 Aug 2021 11:23:56 -0700 Subject: [PATCH 40/74] check dtype instead of element_size --- megatron/data/indexed_dataset.py | 69 ++++++++++++++++++++------------ 1 file changed, 44 insertions(+), 25 deletions(-) diff --git a/megatron/data/indexed_dataset.py b/megatron/data/indexed_dataset.py index 44bb6bc1b..e13f566d2 100644 --- a/megatron/data/indexed_dataset.py +++ b/megatron/data/indexed_dataset.py @@ -653,7 +653,25 @@ def write_list(fout, pos, vals, shift, offset, total, dtype): # Advance file pointer past end of this section. return pos + total * dtype().itemsize -def gather_files_dist_idx_cached(outfile, filelist, distctx, dtype): + +def gather_files_dist_check_dtype(filelist, dtype_valid, dtype_code, distctx): + # verify that no rank has found an inconsistent value in their own set of files + allvalid = distctx.alltrue(dtype_valid) + if not allvalid: + if not dtype_valid: + print(f"Rank {distctx.rank}: found different dtype values in {filelist}") + assert allvalid, f"Some rank found inconsistent dtype values" + + # verify that at least one rank found a dtype + first_dtype_code = distctx.bcast_first(dtype_code) + assert first_dtype_code is not None, "Failed to find a dtype value in any index file" + + # verify that all ranks that have a dtype that is consistent with each other + allsame = distctx.alltrue(dtype_code == first_dtype_code or dtype_code is None) + assert allsame, "Different dtype values detected in index files" + + +def gather_files_dist_idx_cached(outfile, filelist, distctx, dtype=np.int32): # get our rank rank = distctx.rank @@ -665,8 +683,8 @@ def gather_files_dist_idx_cached(outfile, filelist, distctx, dtype): data_offsets = [0] dim_offsets = [0] docs = [0] - element_size_valid = True # whether rank identifies inconsistent values in its files - element_size_value = None # the current element size value, if any + dtype_valid = True # whether rank identifies inconsistent values in its files + dtype_value = None # the current dtype code, if any for f in filelist: doc_offset = len(sizes) @@ -676,26 +694,15 @@ def gather_files_dist_idx_cached(outfile, filelist, distctx, dtype): dim_offsets.extend(index.dim_offsets[1:] + dim_offsets[-1]) docs.extend(index.doc_idx[1:] + doc_offset) - # check that the element value in this index matches other our other files - if element_size_value is None: - element_size_value = index.element_size - if index.element_size != element_size_value: - element_size_valid = False + # check that the dtype in this index matches the dtype in our other files + dtype_code = code(index.dtype) + if dtype_value is None: + dtype_value = dtype_code + if dtype_value != dtype_code: + dtype_valid = False - # verify that no rank has found an inconsistent value in their own set of files - allvalid = distctx.alltrue(element_size_valid) - if not allvalid: - if not element_size_valid: - print(f"Rank {rank}: found different element_size values in {filelist}") - assert allvalid, f"Some rank found inconsistent element_size values" - - # verify that at least one rank found an element size - element_size = distctx.bcast_first(element_size_value) - assert element_size 
is not None, "Failed to find any element size in index files" - - # verify that all ranks that have an element size are consistent with each other - allsame = distctx.alltrue(element_size_value == element_size or element_size_value is None) - assert allsame, "Failed to find a valid element size in index files" + # Check that we have consistent dtypes in all files from all ranks + gather_files_dist_check_dtype(filelist, dtype_valid, dtype_value, distctx) # Capture the last value in each array before we delete any items. # Note this may be zero on any rank that has no items, @@ -737,7 +744,7 @@ def gather_files_dist_idx_cached(outfile, filelist, distctx, dtype): fout.write(IndexedDataset._HDR_MAGIC) fout.write(struct.pack(' Date: Tue, 17 Aug 2021 11:25:29 -0700 Subject: [PATCH 41/74] add uint32 to element_sizes table --- megatron/data/indexed_dataset.py | 1 + 1 file changed, 1 insertion(+) diff --git a/megatron/data/indexed_dataset.py b/megatron/data/indexed_dataset.py index e13f566d2..b0438444a 100644 --- a/megatron/data/indexed_dataset.py +++ b/megatron/data/indexed_dataset.py @@ -268,6 +268,7 @@ class IndexedDatasetBuilder(object): element_sizes = { np.uint8: 1, np.int8: 1, + np.uint16: 2, np.int16: 2, np.int32: 4, np.int64: 8, From f1228837605d42d3327207c467b64044fbed9f78 Mon Sep 17 00:00:00 2001 From: Adam Moody Date: Tue, 17 Aug 2021 12:29:57 -0700 Subject: [PATCH 42/74] infer dtype from files being merged --- megatron/data/indexed_dataset.py | 21 ++++++++++++--------- tools/preprocess_dataset_mpi.py | 2 +- 2 files changed, 13 insertions(+), 10 deletions(-) diff --git a/megatron/data/indexed_dataset.py b/megatron/data/indexed_dataset.py index b0438444a..c92290e8b 100644 --- a/megatron/data/indexed_dataset.py +++ b/megatron/data/indexed_dataset.py @@ -671,8 +671,11 @@ def gather_files_dist_check_dtype(filelist, dtype_valid, dtype_code, distctx): allsame = distctx.alltrue(dtype_code == first_dtype_code or dtype_code is None) assert allsame, "Different dtype values detected in index files" + # return the dtype that is used in all files + return dtypes[first_dtype_code] -def gather_files_dist_idx_cached(outfile, filelist, distctx, dtype=np.int32): + +def gather_files_dist_idx_cached(outfile, filelist, distctx): # get our rank rank = distctx.rank @@ -703,7 +706,7 @@ def gather_files_dist_idx_cached(outfile, filelist, distctx, dtype=np.int32): dtype_valid = False # Check that we have consistent dtypes in all files from all ranks - gather_files_dist_check_dtype(filelist, dtype_valid, dtype_value, distctx) + dtype = gather_files_dist_check_dtype(filelist, dtype_valid, dtype_value, distctx) # Capture the last value in each array before we delete any items. 
# Note this may be zero on any rank that has no items, @@ -789,7 +792,7 @@ def gather_files_dist_idx_cached(outfile, filelist, distctx, dtype=np.int32): distctx.barrier() -def gather_files_dist_idx_mmap(outfile, filelist, distctx, dtype): +def gather_files_dist_idx_mmap(outfile, filelist, distctx): # get our rank rank = distctx.rank @@ -816,7 +819,7 @@ def gather_files_dist_idx_mmap(outfile, filelist, distctx, dtype): dtype_valid = False # Check that we have consistent dtypes in all files from all ranks - gather_files_dist_check_dtype(filelist, dtype_valid, dtype_value, distctx) + dtype = gather_files_dist_check_dtype(filelist, dtype_valid, dtype_value, distctx) # Drop the zero entry from the lists that start with # a "0" value unless we're rank 0 @@ -976,7 +979,7 @@ def gather_files_dist_check_impltype(filelist, distctx): # Each rank contributes a distinct list of zero or more files in filelist. # Each rank merges its set of files into filemain collectively with all # other ranks. -def gather_files_dist(filemain, filelist, distctx, dtype=np.int64): +def gather_files_dist(filemain, filelist, distctx): # check that at least one input file is listed filecount = distctx.sum(len(filelist)) assert filecount > 0, "No rank has any input files to merge" @@ -990,9 +993,9 @@ def gather_files_dist(filemain, filelist, distctx, dtype=np.int64): # Combine index files into a single index file if indexstr == "cached": - gather_files_dist_idx_cached(filemain, filelist, distctx, dtype) + gather_files_dist_idx_cached(filemain, filelist, distctx) elif indexstr == "mmap": - gather_files_dist_idx_mmap(filemain, filelist, distctx, dtype) + gather_files_dist_idx_mmap(filemain, filelist, distctx) def get_start_end(count, rank, numranks): @@ -1009,11 +1012,11 @@ def get_start_end(count, rank, numranks): # Given a global list of files in filelist, and a set of processed defined # by the distributed environment in distctx, collectively merge files into # a new output specified in filemain. -def merge_files_dist(filemain, filelist, distctx, dtype=np.int64): +def merge_files_dist(filemain, filelist, distctx): # TODO: if file sizes vary significantly, it might be better to consider # file size when splitting the list to different ranks. 
# evenly divide list of files among ranks start, end = get_start_end(len(filelist), distctx.rank, distctx.numranks) sublist = filelist[start:end] - return gather_files_dist(filemain, sublist, distctx, dtype) + return gather_files_dist(filemain, sublist, distctx) diff --git a/tools/preprocess_dataset_mpi.py b/tools/preprocess_dataset_mpi.py index 1145dd978..47fa11da8 100644 --- a/tools/preprocess_dataset_mpi.py +++ b/tools/preprocess_dataset_mpi.py @@ -495,7 +495,7 @@ def rank_files_merge_parallel(args): for key in args.columns: filemain = get_filename(args, key) filerank = get_filename(args, key, args.rank) - gather_files_dist(filemain, [filerank], args.distctx, dtype=best_fitting_dtype(args.vocab_size)) + gather_files_dist(filemain, [filerank], args.distctx) # total up bytes read in merge binfile = data_file_path(filerank) From 57c012e0f1144b27c3fe197b7710d646520004ca Mon Sep 17 00:00:00 2001 From: Adam Moody Date: Tue, 17 Aug 2021 13:04:46 -0700 Subject: [PATCH 43/74] add write_header function to indexed dataset classes --- megatron/data/indexed_dataset.py | 34 +++++++++++++++++++++----------- 1 file changed, 22 insertions(+), 12 deletions(-) diff --git a/megatron/data/indexed_dataset.py b/megatron/data/indexed_dataset.py index c92290e8b..32bef837e 100644 --- a/megatron/data/indexed_dataset.py +++ b/megatron/data/indexed_dataset.py @@ -276,6 +276,17 @@ class IndexedDatasetBuilder(object): np.double: 8 } + @staticmethod + def write_header(fout, dtype, numdata, numsize, numdoc): + """Writes header for cached indexed dataset to given file handle.""" + fout.write(IndexedDataset._HDR_MAGIC) + fout.write(struct.pack(' Date: Tue, 17 Aug 2021 13:13:00 -0700 Subject: [PATCH 44/74] call write_header internally from IndexedDataset classes --- megatron/data/indexed_dataset.py | 14 ++------------ 1 file changed, 2 insertions(+), 12 deletions(-) diff --git a/megatron/data/indexed_dataset.py b/megatron/data/indexed_dataset.py index 32bef837e..66b08c797 100644 --- a/megatron/data/indexed_dataset.py +++ b/megatron/data/indexed_dataset.py @@ -332,11 +332,7 @@ def merge_file_(self, another_file): def finalize(self, index_file): self.out_file.close() index = open(index_file, 'wb') - index.write(b'TNTIDX\x00\x00') - index.write(struct.pack(' Date: Tue, 17 Aug 2021 13:44:21 -0700 Subject: [PATCH 45/74] return number of bytes written from write calls --- megatron/data/indexed_dataset.py | 82 ++++++++++++++++++++++---------- 1 file changed, 56 insertions(+), 26 deletions(-) diff --git a/megatron/data/indexed_dataset.py b/megatron/data/indexed_dataset.py index 66b08c797..a43c40853 100644 --- a/megatron/data/indexed_dataset.py +++ b/megatron/data/indexed_dataset.py @@ -278,7 +278,9 @@ class IndexedDatasetBuilder(object): @staticmethod def write_header(fout, dtype, numdata, numsize, numdoc): - """Writes header for cached indexed dataset to given file handle.""" + """Writes header for cached indexed dataset to given file handle, return number of bytes written.""" + startpos = fout.tell() + fout.write(IndexedDataset._HDR_MAGIC) fout.write(struct.pack(' Date: Tue, 17 Aug 2021 15:04:52 -0700 Subject: [PATCH 46/74] move scatterv to distdata class --- megatron/data/distdata.py | 23 +++++++++++++++++++++++ tools/preprocess_dataset_mpi.py | 18 +----------------- 2 files changed, 24 insertions(+), 17 deletions(-) diff --git a/megatron/data/distdata.py b/megatron/data/distdata.py index b042c2839..bd837bbc9 100644 --- a/megatron/data/distdata.py +++ b/megatron/data/distdata.py @@ -85,6 +85,29 @@ def bcast_list(self, vals, 
root=0): dist.broadcast(tvals, src=root) return tvals.tolist() + def scatterv_(self, invals, counts, outval, root=0): + """Scatter int64 values from invals according to counts array, receive values in outval""" + + self.allassert(len(counts) == self.numranks, + f"Length of counts list {len(counts)} does not match number of ranks {self.numranks}") + + self.allassert(outval.shape == (counts[self.rank],), + f"Rank {self.rank}: output buffer is of shape {outval.shape}, expected {(counts[self.rank],)}") + + self.allassert(outval.dtype == np.int64, + f"Requires outval to be of type numpy.int64") + + if self.mpi4py is not None: + counts = np.array(counts) + displs = np.cumsum(counts) - counts + self.comm.Scatterv([invals, counts, displs, self.mpi4py.INT64_T], outval, root=root) + else: + scatterlist = None + if self.rank == root: + scatterlist = list(torch.split(torch.from_numpy(invals), counts)) + outtensor = torch.from_numpy(outval) + dist.scatter(outtensor, scatterlist, src=root) + def alltrue(self, val): """Returns True if all procs input True, False otherwise""" if self.mpi4py is not None: diff --git a/tools/preprocess_dataset_mpi.py b/tools/preprocess_dataset_mpi.py index 5197bcc50..a7de0489a 100644 --- a/tools/preprocess_dataset_mpi.py +++ b/tools/preprocess_dataset_mpi.py @@ -218,22 +218,6 @@ def format_byterate(byterate): mbps = byterate / (1024.0 * 1024.0) return f"{mbps:0.3f} MB/s" -def scatterv_(args, invals, counts, outval, root=0): - """Scatter int64 values from invals according to counts array, receive values in outval""" - assert len(counts) == args.numranks, f"Length of counts list {len(counts)} does not match number of ranks {args.numranks}" - assert outval.shape == (counts[args.rank],), f"Rank {args.rank}: output buffer is of shape {outval.shape}, expected {(counts[args.rank],)}" - - if args.use_mpi: - counts = np.array(counts) - displs = np.cumsum(counts) - counts - args.mpi_comm.Scatterv([invals, counts, displs, args.MPI.INT64_T], outval, root=root) - else: - scatterlist = None - if args.rank == root: - scatterlist = list(torch.split(torch.from_numpy(invals), counts)) - outtensor = torch.from_numpy(outval) - dist.scatter(outtensor, scatterlist, src=root) - def load_dset(args): # Avoid downloading datasets unless explicitly requested. # We allow the user to override this behavior if they set $HF_DATASETS_OFFLINE. 
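With scatterv_ now living on DistData, the caller keeps two responsibilities: build a counts list with one entry per rank, and preallocate an int64 receive buffer sized to its own count. A minimal sketch of that contract (dataset size and backend are illustrative; non-root ranks can pass None for the send buffer):

    import numpy as np
    from megatron.data.distdata import DistData

    distctx = DistData(backend='gloo')     # assumes the script's usual distributed launch
    dset_size = 1000                       # illustrative dataset size

    # Split dset_size sample indices as evenly as possible across ranks.
    base, rem = divmod(dset_size, distctx.numranks)
    counts = [base + 1 if r < rem else base for r in range(distctx.numranks)]

    # Rank 0 holds the full (possibly shuffled) index list; other ranks pass None.
    idxlist = np.arange(dset_size, dtype=np.int64) if distctx.rank == 0 else None

    # Every rank allocates its own slice and receives it from rank 0.
    idx = np.zeros(counts[distctx.rank], dtype=np.int64)
    distctx.scatterv_(idxlist, counts, idx, root=0)

The next hunk shows the actual call site in select_sample_list switching to this method.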
@@ -346,7 +330,7 @@ def select_sample_list(args, dset_size): # scatter sample index values from rank 0 to all procs # based on distribution defined in counts list time_bcast = time.time() - scatterv_(args, idxlist, counts, idx, root=0) + args.distctx.scatterv_(idxlist, counts, idx, root=0) args.distctx.barrier() time_end = time.time() From dadb51b4a8c2473c52e340665b6f85f38cc5c49f Mon Sep 17 00:00:00 2001 From: Adam Moody Date: Tue, 17 Aug 2021 15:19:33 -0700 Subject: [PATCH 47/74] add functions to format status and error messages --- tools/preprocess_dataset_mpi.py | 103 +++++++++++++++++--------------- 1 file changed, 54 insertions(+), 49 deletions(-) diff --git a/tools/preprocess_dataset_mpi.py b/tools/preprocess_dataset_mpi.py index a7de0489a..6f57da1a0 100644 --- a/tools/preprocess_dataset_mpi.py +++ b/tools/preprocess_dataset_mpi.py @@ -68,6 +68,13 @@ from megatron.data.indexed_dataset import data_file_path, index_file_path, make_builder, best_fitting_dtype, gather_files_dist from megatron.data.distdata import DistData +def msg(msg, flush=False): + timestamp = time.strftime("%Y-%m-%dT%H:%M:%S") + print(f"{timestamp}: {msg}", flush=flush) + +def msgerr(msg, flush=False): + print(f"ERROR: {msg}", flush=flush) + # https://stackoverflow.com/questions/33139531/preserve-empty-lines-with-nltks-punkt-tokenizer class CustomLanguageVars(nltk.tokenize.punkt.PunktLanguageVars): @@ -93,7 +100,7 @@ def __init__(self, args): if self.args.split_sentences: if not nltk_available: - print("NLTK is not available to split sentences.") + msgerr("NLTK is not available to split sentences.") exit() splitter = nltk.load("tokenizers/punkt/english.pickle") if self.args.keep_newlines: @@ -201,7 +208,7 @@ def get_args(): if args.tokenizer_type.lower().startswith('bert'): if not args.split_sentences: if args.rank == 0: - print("Bert tokenizer detected, are you sure you don't want to split sentences?") + msg("Bert tokenizer detected, are you sure you don't want to split sentences?") args.level = "document" if args.split_sentences: @@ -245,19 +252,19 @@ def load_dset(args): err = None dsetname = args.input if args.rank == 0: - print(f"Opening dataset {dsetname}") + msg(f"Opening dataset {dsetname}") try: dset = load_dataset(dsetname, split=args.split, keep_in_memory=None) except OfflineModeIsEnabled as e: - print(f"ERROR: Cannot download '{dsetname}' since running in offline mode.") - print(f"ERROR: If the dataset is large, it may be more efficient to download with a single process:") - print(f"ERROR: from datasets import load_dataset") - print(f"ERROR: dset = load_dataset('{dsetname}')") - print(f"ERROR: Alternatively, one can force this script to download by setting $HF_DATASETS_OFFLINE=0", flush=True) + msgerr(f"Cannot download '{dsetname}' since running in offline mode.") + msgerr(f"If the dataset is large, it may be more efficient to download with a single process:") + msgerr(f" from datasets import load_dataset") + msgerr(f" dset = load_dataset('{dsetname}')") + msgerr(f"Alternatively, one can force this script to download by setting $HF_DATASETS_OFFLINE=0", flush=True) success = False err = e except Exception as e: - print("ERROR: Unexpected error:", sys.exc_info()[0], flush=True) + msgerr("Unexpected error:", sys.exc_info()[0], flush=True) success = False err = e @@ -273,7 +280,7 @@ def load_dset(args): dset = load_dataset(dsetname, split=args.split, keep_in_memory=None) except Exception as e: # this print might be noisy, but better than nothing - print("ERROR: Unexpected error:", sys.exc_info()[0], 
flush=True) + msgerr("Unexpected error:", sys.exc_info()[0], flush=True) success = False err = e @@ -281,12 +288,12 @@ def load_dset(args): success = args.distctx.alltrue(success) if not success: if args.rank == 0: - print(f"ERROR: At least one process failed to load {dsetname}", flush=True) + msgerr(f"At least one process failed to load {dsetname}", flush=True) return None, err time_end = time.time() if args.rank == 0: - print(f"Seconds to load dataset: {time_end - time_start}", flush=True) + msg(f"Seconds to load dataset: {time_end - time_start}", flush=True) return dset, err @@ -335,11 +342,11 @@ def select_sample_list(args, dset_size): args.distctx.barrier() time_end = time.time() if args.rank == 0: - print(f"Select index stats:") - print(f" Shuffle: {args.shuffle}") - print(f" Seconds to select: {time_bcast - time_select}") - print(f" Seconds to broadcast: {time_end - time_bcast}") - print(f" Seconds total: {time_end - time_select}", flush=True) + msg(f"Select index stats:") + msg(f" Shuffle: {args.shuffle}") + msg(f" Seconds to select: {time_bcast - time_select}") + msg(f" Seconds to broadcast: {time_end - time_bcast}") + msg(f" Seconds total: {time_end - time_select}", flush=True) return idx @@ -379,8 +386,8 @@ def rank_files_write(args, dset, idx, encoder): try: # create data file for each rank if args.rank == 0: - print(f"Vocab size: {args.vocab_size}") - print(f"Output prefix: {args.output_prefix}") + msg(f"Vocab size: {args.vocab_size}") + msg(f"Output prefix: {args.output_prefix}") output_bin_files = {} output_idx_files = {} builders = {} @@ -424,16 +431,15 @@ def rank_files_write(args, dset, idx, encoder): progress_next = current + float(args.log_interval) elapsed = current - time_start - timestamp = time.strftime("%Y-%m-%dT%H:%M:%S") docs = dset_stats[0] * args.numranks percent = docs / num_samples * 100.0 docrate = docs / elapsed if elapsed > 0.0 else 0.0 mbs = dset_stats[2] * args.numranks / elapsed / 1024 / 1024 if elapsed > 0.0 else 0.0 secs_left = int((num_samples - docs) / docrate if docrate > 0.0 else 0.0) - print(f"{timestamp}: Processed (estimated) {docs} of {num_samples} docs ({percent:0.2f}%),", - f"{docrate:0.3f} docs/s, {mbs:0.3f} MB/s,", - f"{secs_left} secs left ...", - flush=True) + msg(f"Processed (estimated) {docs} of {num_samples} docs ({percent:0.2f}%),", + f"{docrate:0.3f} docs/s, {mbs:0.3f} MB/s,", + f"{secs_left} secs left ...", + flush=True) # finalize file of each rank for key in args.columns: @@ -447,8 +453,7 @@ def rank_files_write(args, dset, idx, encoder): # In case rank 0 finishes early and stops printing progress messages, # inform user that it's waiting for other ranks to finish. 
if args.rank == 0 and args.log_interval > 0: - timestamp = time.strftime("%Y-%m-%dT%H:%M:%S") - print(f"{timestamp}: Waiting for ranks to finalize files ...", flush=True) + msg(f"Waiting for ranks to finalize files ...", flush=True) # wait for all ranks to finish their files args.distctx.barrier() @@ -465,14 +470,14 @@ def rank_files_write(args, dset, idx, encoder): secs_read_per_sample = times[0] / dset_stats[0] if dset_stats[0] > 0 else 0.0 secs_encode_per_sample = times[1] / dset_stats[0] if dset_stats[0] > 0 else 0.0 secs_write_per_sample = times[2] / dset_stats[0] if dset_stats[0] > 0 else 0.0 - print("Process stats:") - print(f" Seconds to process: {secs}") - print(f" {dset_stats[0]} docs {docrate} docs/sec") - print(f" {dset_stats[1]} sents {sentrate} sents/sec") - print(f" {dset_stats[2]} bytes {format_byterate(byterate)}") - print(f" Total read seconds {times[0]}, {secs_read_per_sample} sec/sample") - print(f" Total encode seconds {times[1]}, {secs_encode_per_sample} sec/sample") - print(f" Total write seconds {times[2]}, {secs_write_per_sample} sec/sample") + msg("Process stats:") + msg(f" Seconds to process: {secs}") + msg(f" {dset_stats[0]} docs {docrate} docs/sec") + msg(f" {dset_stats[1]} sents {sentrate} sents/sec") + msg(f" {dset_stats[2]} bytes {format_byterate(byterate)}") + msg(f" Total read seconds {times[0]}, {secs_read_per_sample} sec/sample") + msg(f" Total encode seconds {times[1]}, {secs_encode_per_sample} sec/sample") + msg(f" Total write seconds {times[2]}, {secs_write_per_sample} sec/sample") # allreduce to check whether all ranks wrote their part successfully success = args.distctx.alltrue(success) @@ -508,15 +513,15 @@ def rank_files_merge_parallel(args): if args.rank == 0: secs = merge_end - merge_start byterate = numbytes[0] / secs if secs > 0.0 else 0.0 - print("Parallel merge stats:") - print(f" Scratch: {args.scratch}") - print(f" Seconds to merge: {secs}") - print(f" {int(numbytes)} bytes {format_byterate(byterate)}") + msg("Parallel merge stats:") + msg(f" Scratch: {args.scratch}") + msg(f" Seconds to merge: {secs}") + msg(f" {int(numbytes)} bytes {format_byterate(byterate)}") def rank_files_merge_serial(args): """Rank 0 merges data from all per-rank files into the final file.""" if args.rank == 0: - print("Merging rank files ...", flush=True) + msg("Merging rank files ...", flush=True) merge_start = time.time() numbytes = 0 @@ -538,7 +543,7 @@ def rank_files_merge_serial(args): for key in args.columns: infile = get_filename(args, key, rank) -# print(f"Merging file {infile}", flush=True) +# msg(f"Merging file {infile}", flush=True) builders[key].merge_file_(infile) # sum up the number of merged bytes @@ -548,7 +553,7 @@ def rank_files_merge_serial(args): numbytes += os.stat(idxfile)[stat.ST_SIZE] # finalize the merged file - print("Finalizing merged file ...", flush=True) + msg("Finalizing merged file ...", flush=True) for key in args.columns: builders[key].finalize(output_idx_files[key]) del builders[key] # file closed in __del__ @@ -556,10 +561,10 @@ def rank_files_merge_serial(args): merge_end = time.time() secs = merge_end - merge_start byterate = numbytes / secs if secs > 0.0 else 0.0 - print(f"Merged {args.numranks} files into {args.output_prefix}") - print("Merge stats:") - print(f" Seconds to merge: {secs}") - print(f" {numbytes} bytes {format_byterate(byterate)}") + msg(f"Merged {args.numranks} files into {args.output_prefix}") + msg("Merge stats:") + msg(f" Seconds to merge: {secs}") + msg(f" {numbytes} bytes {format_byterate(byterate)}") # 
hold everyone until rank 0 is done args.distctx.barrier() @@ -580,7 +585,7 @@ def rank_files_merge(args): def rank_files_delete(args): # delete per-rank files if args.rank == 0: - print("Deleting rank files ...", flush=True) + msg("Deleting rank files ...", flush=True) for key in args.columns: filebase = get_filename(args, key, args.rank) @@ -608,7 +613,7 @@ def main(): return if args.rank == 0: print(dset) - print("Selecting features:", args.columns) + msg("Selecting features:", args.columns) # create sample index list, # optionally shuffle the list, @@ -625,7 +630,7 @@ def main(): args.distctx.barrier() startup_end = time.time() if args.rank == 0: - print(f"Seconds to startup: {startup_end - startup_start}") + msg(f"Seconds to startup: {startup_end - startup_start}") # have each rank write its file, returns False if any rank had a problem success, err = rank_files_write(args, dset, idx, encoder) @@ -633,7 +638,7 @@ def main(): if args.rank == 0: # If any process fails, we skip the merge since the resulting file would be invalid. # We still delete files to clean up, since those might be invalid anyway. - print(f"ERROR: At least one process failed to write its file, skipping merge and cleaning up", flush=True) + msgerr(f"At least one process failed to write its file, skipping merge and cleaning up", flush=True) # delete per-rank files, do this even on error rank_files_delete(args) From a2f8fa0f1ac1b521fb10b5c9429d28f45c556d12 Mon Sep 17 00:00:00 2001 From: Adam Moody Date: Tue, 17 Aug 2021 16:12:23 -0700 Subject: [PATCH 48/74] defer merge_files_dist to future PR --- megatron/data/indexed_dataset.py | 24 ------------------------ 1 file changed, 24 deletions(-) diff --git a/megatron/data/indexed_dataset.py b/megatron/data/indexed_dataset.py index a43c40853..2a6b6126b 100644 --- a/megatron/data/indexed_dataset.py +++ b/megatron/data/indexed_dataset.py @@ -1026,27 +1026,3 @@ def gather_files_dist(filemain, filelist, distctx): gather_files_dist_idx_cached(filemain, filelist, distctx) elif indexstr == "mmap": gather_files_dist_idx_mmap(filemain, filelist, distctx) - - -def get_start_end(count, rank, numranks): - num, remainder = divmod(count, numranks) - if rank < remainder: - start = (num + 1) * rank - end = start + num + 1 - else: - start = (num + 1) * remainder + num * (rank - remainder) - end = start + num - return start, end - - -# Given a global list of files in filelist, and a set of processed defined -# by the distributed environment in distctx, collectively merge files into -# a new output specified in filemain. -def merge_files_dist(filemain, filelist, distctx): - # TODO: if file sizes vary significantly, it might be better to consider - # file size when splitting the list to different ranks. 
- - # evenly divide list of files among ranks - start, end = get_start_end(len(filelist), distctx.rank, distctx.numranks) - sublist = filelist[start:end] - return gather_files_dist(filemain, sublist, distctx) From 39e6cd74ac63db3e7485dd4fa86ec5be39516232 Mon Sep 17 00:00:00 2001 From: Adam Moody Date: Tue, 17 Aug 2021 17:26:46 -0700 Subject: [PATCH 49/74] open files using with, refresh comments --- megatron/data/indexed_dataset.py | 335 ++++++++++++++----------------- 1 file changed, 155 insertions(+), 180 deletions(-) diff --git a/megatron/data/indexed_dataset.py b/megatron/data/indexed_dataset.py index 2a6b6126b..802299f4c 100644 --- a/megatron/data/indexed_dataset.py +++ b/megatron/data/indexed_dataset.py @@ -615,20 +615,15 @@ def finalize(self, index_file): index.write(self._sizes, self._doc_idx) -# To create the binary files given a set of per-rank binary -# files, one simply concatenates the data from the per-rank -# binary files in rank order. We stat each rank file to determine -# its size, execute a scan to compute the byte offset where -# the calling rank should write its data, seek to proper -# spot, and copy the full file. +# To merge a set of binary files, one can simply concatenate them in order. +# We stat each binary file to determine its size, execute a scan to compute +# the byte offset where the calling rank should write its data, seek to proper +# spot, and copy each file. def gather_files_dist_bin(outfile, filelist, distctx): - """Concatenate per-rank binary files into a new file given by outfile""" + """Concatenate binary files in filelist into a new file given by outfile""" import stat import shutil - # Create shared output file. - fout = distctx.open(data_file_path(outfile)) - # lookup size of each of our binary files filesizes = [os.stat(data_file_path(f))[stat.ST_SIZE] for f in filelist] @@ -637,15 +632,15 @@ def gather_files_dist_bin(outfile, filelist, distctx): numbytes = sum(filesizes) offset = distctx.exscan(numbytes) - # seek to appropriate starting offset in the merged file - fout.seek(offset) - - # copy in contents of each of our files - for f in filelist: - with open(data_file_path(f), "rb") as fsrc: - shutil.copyfileobj(fsrc, fout) + # Create shared output file. + with distctx.open(data_file_path(outfile)) as fout: + # Seek to appropriate starting offset in the merged file. + fout.seek(offset) - fout.close() + # Copy in contents of each of our files. 
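Stripped of the DistData wrapper, the binary merge above boils down to an exclusive scan over per-rank byte counts followed by a positioned copy. A rough equivalent written against mpi4py directly (file names are placeholders, and error handling is omitted):

    import os
    import shutil
    from mpi4py import MPI

    comm = MPI.COMM_WORLD
    myfile = f"part_{comm.rank}.bin"        # placeholder per-rank shard
    nbytes = os.path.getsize(myfile)

    # Exclusive prefix sum of shard sizes gives each rank its starting byte.
    offset = comm.scan(nbytes) - nbytes

    # Rank 0 creates and truncates the shared output, then all ranks open it for update.
    if comm.rank == 0:
        open("merged.bin", "wb").close()
    comm.barrier()

    with open("merged.bin", "r+b") as fout:
        fout.seek(offset)
        with open(myfile, "rb") as fsrc:
            shutil.copyfileobj(fsrc, fout)
    comm.barrier()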
+ for f in filelist: + with open(data_file_path(f), "rb") as fsrc: + shutil.copyfileobj(fsrc, fout) # TODO: check that all ranks wrote successfully distctx.barrier() @@ -700,32 +695,21 @@ def write_list(fout, pos, vals, shift, offset, total, dtype): def gather_files_dist_check_dtype(filelist, dtype_valid, dtype_code, distctx): - # verify that no rank has found an inconsistent value in their own set of files - allvalid = distctx.alltrue(dtype_valid) - if not allvalid: - if not dtype_valid: - print(f"Rank {distctx.rank}: found different dtype values in {filelist}") - assert allvalid, f"Some rank found inconsistent dtype values" - - # verify that at least one rank found a dtype + # verify that no rank has found an inconsistent value in its own set of files + distctx.allassert(dtype_valid, "Some rank found inconsistent dtype values") + + # verify that at least one rank found a dtype value first_dtype_code = distctx.bcast_first(dtype_code) assert first_dtype_code is not None, "Failed to find a dtype value in any index file" - # verify that all ranks that have a dtype that is consistent with each other - allsame = distctx.alltrue(dtype_code == first_dtype_code or dtype_code is None) - assert allsame, "Different dtype values detected in index files" + # verify that the dtype is consistent on all ranks, if a rank has a dtype value + distctx.allassert(dtype_code == first_dtype_code or dtype_code is None, "Different dtype values detected in index files") - # return the dtype that is used in all files + # return the dtype return dtypes[first_dtype_code] def gather_files_dist_idx_cached(outfile, filelist, distctx): - # get our rank - rank = distctx.rank - - # Create shared output file - fout = distctx.open(index_file_path(outfile)) - # Read each index file and append items to our lists sizes = [] data_offsets = [0] @@ -734,9 +718,11 @@ def gather_files_dist_idx_cached(outfile, filelist, distctx): dtype_valid = True # whether rank identifies inconsistent values in its files dtype_value = None # the current dtype code, if any for f in filelist: - doc_offset = len(sizes) - + # read index file for this file index = IndexedDataset(f) + + # append its size, data, dim, and doc entries to our lists + doc_offset = len(sizes) sizes.extend(index.sizes) data_offsets.extend(index.data_offsets[1:] + data_offsets[-1]) dim_offsets.extend(index.dim_offsets[1:] + dim_offsets[-1]) @@ -752,25 +738,25 @@ def gather_files_dist_idx_cached(outfile, filelist, distctx): # Check that we have consistent dtypes in all files from all ranks dtype = gather_files_dist_check_dtype(filelist, dtype_valid, dtype_value, distctx) - # Capture the last value in each array before we delete any items. + # Capture the last value in dim and data arrays before we delete any items. # Note this may be zero on any rank that has no items, # but zero is the correct value in that case. - # These values are used in scan operations to determine a shift value. - dim_last = dim_offsets[-1] - data_last = data_offsets[-1] + # We use this last value to compute a shift value that will + # later be added to each element in our lists. + dim_shift = distctx.exscan(dim_offsets[-1]) + data_shift = distctx.exscan(data_offsets[-1]) # Drop the zero entry from the lists that start with # a "0" value unless we're rank 0 - if rank != 0: + if distctx.rank != 0: del data_offsets[0] del dim_offsets[0] del docs[0] - # Compute total number of size and document index - # values across all ranks. 
Also compute the offset - # of the calling rank for each value considering - # the values of sizes/docs for all ranks before the - # calling rank. + # Compute total number of entires in data, size, dim, + # and docs lists across all ranks. Also compute the offset + # of the calling rank for each list considering the number + # of entries for all ranks before the calling rank. numdata = len(data_offsets) numsize = len(sizes) numdim = len(dim_offsets) @@ -786,64 +772,59 @@ def gather_files_dist_idx_cached(outfile, filelist, distctx): global_dim_offset = distctx.exscan(numdim) global_doc_offset = distctx.exscan(numdoc) - # Have rank 0 write the file header - # Broadcast value of pos from rank 0, - # and advance file position past file header on all ranks. - pos = 0 - if rank == 0: - pos = IndexedDatasetBuilder.write_header(fout, dtype, global_data_count, global_size_count, global_doc_count) - pos = distctx.bcast(pos, root=0) - - # The dimension list records the offset within - # the sizes list for each sentence. Adjust dimension - # offset values based on the number of offsets - # that come before the calling rank. - dim_shift = distctx.exscan(dim_last) - pos += write_list(fout, pos, dim_offsets, dim_shift, global_dim_offset, global_dim_count, np.int64) - - # The data index records the byte offset to the start of each - # sentence within the binary data file. - # Adjust our data index values for number of bytes that - # come before the calling rank. - data_shift = distctx.exscan(data_last) - pos += write_list(fout, pos, data_offsets, data_shift, global_data_offset, global_data_count, np.int64) - - # Each sentence is stored as a tensor. - # The tensor for each sentence can be multidimensional. - # The number of tensor dimensions per sentence is variable, - # and the size of each dimension of a sentence is arbitrary. - # The size list records a flattened list of the sizes - # for each dimension of a sentence. - # The list of size values from each rank are - # concatenated and stored as int64. - pos += write_list(fout, pos, sizes, 0, global_size_offset, global_size_count, np.int64) - - # The document index points to the position in the sizes - # array for the first sentence of the sample. - pos += write_list(fout, pos, docs, global_size_offset, global_doc_offset, global_doc_count, np.int64) - - fout.close() + # Create shared output file + with distctx.open(index_file_path(outfile)) as fout: + # Have rank 0 write the file header + # Broadcast number of bytes written from rank 0, + # and advance file position past file header on all ranks. + pos = 0 + if distctx.rank == 0: + pos = IndexedDatasetBuilder.write_header(fout, dtype, global_data_count, global_size_count, global_doc_count) + pos = distctx.bcast(pos, root=0) + + # TODO: is dim_shift == global_size_offset? + # The dimension list records the offset within + # the sizes list for each sentence. + # Adjust dimension index values for number of size values that + # come before the calling rank which is in dim_shift. + pos += write_list(fout, pos, dim_offsets, dim_shift, global_dim_offset, global_dim_count, np.int64) + + # The data index records the element offset to the start of each + # sentence within the binary data file, expressed in units of dtype().itemsize. + # Adjust data index values for number of elements that + # come before the calling rank, which is in data_shift. + pos += write_list(fout, pos, data_offsets, data_shift, global_data_offset, global_data_count, np.int64) + + # Each sentence is stored as a tensor. 
+ # The tensor for each sentence can be multidimensional. + # The number of tensor dimensions per sentence is variable, + # and the size of each dimension of a sentence is arbitrary. + # The size list records a flattened list of the sizes + # for each dimension of a sentence. + pos += write_list(fout, pos, sizes, 0, global_size_offset, global_size_count, np.int64) + + # The document index points to the position in the sizes + # array for the first sentence of each document. + # Adjust document index for number of sentences that + # come before the calling rank which is in global_size_offset. + pos += write_list(fout, pos, docs, global_size_offset, global_doc_offset, global_doc_count, np.int64) # TODO: check that all ranks wrote successfully distctx.barrier() def gather_files_dist_idx_mmap(outfile, filelist, distctx): - # get our rank - rank = distctx.rank - - # Create shared output file - fout = distctx.open(index_file_path(outfile)) - # Read each index file and append items to the size and docs lists sizes = [] docs = [0] dtype_valid = True # whether rank identifies inconsistent values in its files dtype_value = None # the current dtype code, if any for f in filelist: - docs_offset = len(sizes) - + # read index file for this file index = MMapIndexedDataset.Index(index_file_path(f)) + + # append its size and doc entries to our lists + docs_offset = len(sizes) sizes.extend(index.sizes) docs.extend(index.doc_idx[1:] + docs_offset) @@ -859,7 +840,7 @@ def gather_files_dist_idx_mmap(outfile, filelist, distctx): # Drop the zero entry from the lists that start with # a "0" value unless we're rank 0 - if rank != 0: + if distctx.rank != 0: del docs[0] # Compute total number of size and document index @@ -876,77 +857,77 @@ def gather_files_dist_idx_mmap(outfile, filelist, distctx): global_size_offset = distctx.exscan(numsizes) global_docs_offset = distctx.exscan(numdocs) - # Have rank 0 write the file header - # Broadcast value of pos from rank 0, - # and advance file position past file header on all ranks. - pos = 0 - if rank == 0: - pos = MMapIndexedDataset.Index.write_header(fout, dtype, global_size_count, global_docs_count) - pos = distctx.bcast(pos, root=0) - - # The list of size values from each rank are - # concatenated and stored as int32. - pos += write_list(fout, pos, sizes, 0, global_size_offset, global_size_count, np.int32) - - # The pointer values store the byte offset to each sentence. - # A sentence has a variable number of tokens, given by - # its corresponding entry in the size array. Each token - # of a sentence is stored in units of type dtype, which consumes - # dtype().itemsize bytes (often a standard type that is just - # large enough to represent all elements of the vocabulary). - - # First compute byte offsets for each sentence of our - # local set of sentences. - pointers = np.array(sizes, dtype=np.int64) - pointer_last = 0 - if len(sizes) > 0: - np.cumsum(pointers, axis=0, out=pointers) - pointers *= dtype().itemsize - pointer_last = pointers[-1] - - # Then account for bytes for all sentences on ranks - # before the calling rank. - pointer_offset = distctx.exscan(pointer_last) - pointers += pointer_offset - - # Finally, zero-base the offset values by subtracting - # the number of bytes of the first sentence. To do that - # we need to find the rank having the first sentence, - # then bcast that size to all ranks. - if global_size_count > 0: - # Since global_size_count > 0, there is at least one sentence across all ranks. 
- # Get the value from the first rank that has a value, which may not be rank 0. - pointers_shift = pointers[0] if len(sizes) > 0 else None - pointers_shift = distctx.bcast_first(pointers_shift) - - # Since there is at least one, bcast_first should return some value other than None. - assert pointers_shift is not None, "Expected at least one rank to have a valid element" - - # Zero-base pointers by subtracting size of first - # sentence from all values. - pointers -= pointers_shift - - # Since the pointers array is the same length as the sizes array, - # we use global_size_offset and global_size_count to position - # within the file for writing the pointer values. - - # Seek to proper offset for this rank and write - # pointer values into file, stored as int64. - fout.seek(pos + global_size_offset * np.int64().itemsize) - fout.write(pointers.tobytes(order='C')) - del pointers - - # Advance past list of pointer values - pos += global_size_count * np.int64().itemsize - - # The document index points to the position in the sizes - # array for the starting sentence of each document. - # A variable number of sentences can be in each document. - # Adjust document index for number of sentences that - # come before the calling rank. - pos += write_list(fout, pos, docs, global_size_offset, global_docs_offset, global_docs_count, np.int64) - - fout.close() + # Create shared output file + with distctx.open(index_file_path(outfile)) as fout: + # Have rank 0 write the file header + # Broadcast number of bytes written from rank 0, + # and advance file position past file header on all ranks. + pos = 0 + if distctx.rank == 0: + pos = MMapIndexedDataset.Index.write_header(fout, dtype, global_size_count, global_docs_count) + pos = distctx.bcast(pos, root=0) + + # The list of size values from each rank are + # concatenated and stored as int32. + pos += write_list(fout, pos, sizes, 0, global_size_offset, global_size_count, np.int32) + + # The pointer values store the byte offset to each sentence when in memory. + # A sentence has a variable number of tokens, given by + # its corresponding entry in the size array. Each token + # of a sentence is stored in units of type dtype, which consumes + # dtype().itemsize bytes (often a standard type that is just + # large enough to represent all elements of the vocabulary). + + # First compute byte offsets for each sentence of our + # local set of sentences. + pointers = np.array(sizes, dtype=np.int64) + pointer_last = 0 + if len(sizes) > 0: + np.cumsum(pointers, axis=0, out=pointers) + pointers *= dtype().itemsize + pointer_last = pointers[-1] + + # Then account for bytes for all sentences on ranks + # before the calling rank. + pointer_offset = distctx.exscan(pointer_last) + pointers += pointer_offset + + # Finally, zero-base the offset values by subtracting + # the number of bytes of the first sentence. To do that + # we need to find the rank having the first sentence, + # then bcast that size to all ranks. + if global_size_count > 0: + # Since global_size_count > 0, there is at least one sentence across all ranks. + # Get the value from the first rank that has a value, which may not be rank 0. + pointers_shift = pointers[0] if len(sizes) > 0 else None + pointers_shift = distctx.bcast_first(pointers_shift) + + # Since there is at least one, bcast_first should return some value other than None. + assert pointers_shift is not None, "Expected at least one rank to have a valid element" + + # Zero-base pointers by subtracting size of first + # sentence from all values. 
+ pointers -= pointers_shift + + # Since the pointers array is the same length as the sizes array, + # we use global_size_offset and global_size_count to position + # within the file for writing the pointer values. + + # Seek to proper offset for this rank and write + # pointer values into file, stored as int64. + fout.seek(pos + global_size_offset * np.int64().itemsize) + fout.write(pointers.tobytes(order='C')) + del pointers + + # Advance past list of pointer values + pos += global_size_count * np.int64().itemsize + + # The document index points to the position in the sizes + # array for the starting sentence of each document. + # A variable number of sentences can be in each document. + # Adjust document index for number of sentences that + # come before the calling rank which is in global_size_offset. + pos += write_list(fout, pos, docs, global_size_offset, global_docs_offset, global_docs_count, np.int64) # TODO: check that all ranks wrote successfully distctx.barrier() @@ -964,11 +945,7 @@ def gather_files_dist_check_impltype(filelist, distctx): exists = False # Check that all ranks have all of their files. - allexist = distctx.alltrue(exists) - if not allexist: - if not exists: - assert False, f"At least one of the following names was not found: {filelist}" - assert False, f"Some rank is missing its input file" + distctx.allassert(exists, "Some rank is missing its input file") # map type string to an integer for easier bcast, use 0 for unknown implmap = {"cached": 1, "mmap": 2} @@ -981,10 +958,10 @@ def gather_files_dist_check_impltype(filelist, distctx): impl = infer_dataset_impl(f) implval = implmap[impl] if impl in implmap else 0 + # check that the type matches our other files if ourtype is None: ourtype = implval - - if implval != ourtype: + if ourtype != implval: sametype = False # Check that all ranks have the same type, @@ -995,9 +972,7 @@ def gather_files_dist_check_impltype(filelist, distctx): # - the broadcast type is of a known type: {cached, mmap} bcasttype = distctx.bcast_first(ourtype) matchtype = sametype and (ourtype is None or ourtype == bcasttype) and bcasttype != 0 - allsame = distctx.alltrue(matchtype) - if not allsame: - assert False, "Cannot merge dataset files of different types" + distctx.allassert(matchtype, "Cannot merge dataset files of different types") # map back to return index string name for key in implmap.keys(): @@ -1006,15 +981,15 @@ def gather_files_dist_check_impltype(filelist, distctx): # Collectively merge files into a new output file specified in filemain. -# Each rank contributes a distinct list of zero or more files in filelist. -# Each rank merges its set of files into filemain collectively with all -# other ranks. +# Each rank contributes a distinct list of zero or more files in filelist, +# and each rank directly merges its set of files into filemain. +# It is allowed for the input files in filelist to only be readable from the calling process. +# The output file in filemain should be in a location that is writable by all processes. 
def gather_files_dist(filemain, filelist, distctx): - # check that at least one input file is listed + # Check that at least one input file is listed filecount = distctx.sum(len(filelist)) assert filecount > 0, "No rank has any input files to merge" - # TODO: seems like this could be relaxed # Check that files are all of the same index type indexstr = gather_files_dist_check_impltype(filelist, distctx) From 2a29d9962c9f44bd5952b8448dd338690d0e4ff9 Mon Sep 17 00:00:00 2001 From: Adam Moody Date: Tue, 17 Aug 2021 17:48:16 -0700 Subject: [PATCH 50/74] rely on default torch datatypes --- megatron/data/distdata.py | 16 +++++++++------- 1 file changed, 9 insertions(+), 7 deletions(-) diff --git a/megatron/data/distdata.py b/megatron/data/distdata.py index bd837bbc9..b7549bed5 100644 --- a/megatron/data/distdata.py +++ b/megatron/data/distdata.py @@ -31,7 +31,7 @@ def __init__(self, backend='gloo', use_mpi4py=False): self.numranks = dist.get_world_size() def allassert(self, cond, msg): - """Check that condition cond is True on all ranks, assert with message everywhere if not.""" + """Check that cond is True on all ranks, assert with msg everywhere if not.""" alltrue = self.alltrue(cond) assert alltrue, msg @@ -116,24 +116,26 @@ def alltrue(self, val): self.comm.Allreduce(inval, outval, op=self.mpi4py.LAND) return bool(outval[0]) else: + # torch.dist does not support reductions with bool types + # so we cast to int and cast the result back to bool tensor = torch.tensor([int(val)], dtype=torch.int32) dist.all_reduce(tensor, op=dist.ReduceOp.BAND) return bool(tensor[0]) def sum(self, val): - """Compute sum of val, and return total on all ranks.""" + """Compute sum of a scalar val, and return total on all ranks.""" if self.mpi4py is not None: insize = np.array([val], dtype=np.int64) outsize = np.zeros_like(insize) self.comm.Allreduce(insize, outsize, op=self.mpi4py.SUM) return outsize[0] else: - tensor = torch.tensor([val], dtype=torch.int64) + tensor = torch.tensor([val]) dist.all_reduce(tensor, op=dist.ReduceOp.SUM) return tensor[0] def exscan(self, val): - """Compute prefix sum (exclusive scan) of val, and return offset of each rank.""" + """Compute prefix sum (exclusive scan) of scalar val, and return offset of each rank.""" if self.mpi4py is not None: insize = np.array([val], dtype=np.int64) outsize = np.zeros_like(insize) @@ -147,14 +149,14 @@ def exscan(self, val): return int(tensor[self.rank]) - val def min(self, val): - """Return minimum val to all ranks.""" + """Return minimum of scalar val to all ranks.""" if self.mpi4py is not None: insize = np.array([val], dtype=np.int64) outsize = np.zeros_like(insize) self.comm.Allreduce(insize, outsize, op=self.mpi4py.MIN) return outsize[0] else: - tensor = torch.tensor([val], dtype=torch.int64) + tensor = torch.tensor([val]) dist.all_reduce(tensor, op=dist.ReduceOp.MIN) return tensor[0] @@ -183,7 +185,7 @@ def bcast_first(self, val): return val def all_sum_(self, vals): - """Sums values in vals element-wise and updates vals with final result on all ranks""" + """Sums values in numpy array vals element-wise and update vals in place with final result on all ranks""" if self.mpi4py is not None: outval = np.zeros_like(vals) self.comm.Allreduce(vals, outval, op=self.mpi4py.SUM) From d6fa895933e70c80f386d59b8d9187d52c9bbebc Mon Sep 17 00:00:00 2001 From: Adam Moody Date: Tue, 17 Aug 2021 17:49:10 -0700 Subject: [PATCH 51/74] fix some status messages from preprocess script --- tools/preprocess_dataset_mpi.py | 13 +++++++++---- 1 file changed, 9 
insertions(+), 4 deletions(-) diff --git a/tools/preprocess_dataset_mpi.py b/tools/preprocess_dataset_mpi.py index 6f57da1a0..331764a29 100644 --- a/tools/preprocess_dataset_mpi.py +++ b/tools/preprocess_dataset_mpi.py @@ -436,8 +436,8 @@ def rank_files_write(args, dset, idx, encoder): docrate = docs / elapsed if elapsed > 0.0 else 0.0 mbs = dset_stats[2] * args.numranks / elapsed / 1024 / 1024 if elapsed > 0.0 else 0.0 secs_left = int((num_samples - docs) / docrate if docrate > 0.0 else 0.0) - msg(f"Processed (estimated) {docs} of {num_samples} docs ({percent:0.2f}%),", - f"{docrate:0.3f} docs/s, {mbs:0.3f} MB/s,", + msg(f"Processed (estimated) {docs} of {num_samples} docs ({percent:0.2f}%) in {int(elapsed)} secs, " + f"{docrate:0.3f} docs/s, {mbs:0.3f} MB/s, " f"{secs_left} secs left ...", flush=True) @@ -562,7 +562,7 @@ def rank_files_merge_serial(args): secs = merge_end - merge_start byterate = numbytes / secs if secs > 0.0 else 0.0 msg(f"Merged {args.numranks} files into {args.output_prefix}") - msg("Merge stats:") + msg("Serial merge stats:") msg(f" Seconds to merge: {secs}") msg(f" {numbytes} bytes {format_byterate(byterate)}") @@ -613,7 +613,7 @@ def main(): return if args.rank == 0: print(dset) - msg("Selecting features:", args.columns) + msg(f"Processing features: {args.columns}") # create sample index list, # optionally shuffle the list, @@ -654,5 +654,10 @@ def main(): # delete per-rank files rank_files_delete(args) + end_time = time.time() + if args.rank == 0: + msg(f"Runtime: {end_time - startup_start} secs", flush=True) + msg(f"Done") + if __name__ == '__main__': main() From 1216c0ab8bde83e213b8c230057efba113c7d1d1 Mon Sep 17 00:00:00 2001 From: Adam Moody Date: Tue, 17 Aug 2021 22:39:35 -0700 Subject: [PATCH 52/74] fix: exclusive scan computing pointers list --- megatron/data/indexed_dataset.py | 20 +++++++------------- 1 file changed, 7 insertions(+), 13 deletions(-) diff --git a/megatron/data/indexed_dataset.py b/megatron/data/indexed_dataset.py index 2b2c1f405..acdf36246 100644 --- a/megatron/data/indexed_dataset.py +++ b/megatron/data/indexed_dataset.py @@ -359,22 +359,16 @@ def _get_pointers(sizes, npdtype): """Return a numpy array of byte offsets given a list of sizes. Multiplies values in the sizes array by dtype size (bytes), - and then computes a zero-based prefix scan. + and then computes an exclusive scan to get byte offsets. 
""" - # create numpy array of desired numpy datatype - pointers = np.array(sizes, dtype=npdtype) + # compute element sizes in bytes + bytesizes = np.array(sizes, dtype=npdtype) + bytesizes *= dtype().itemsize - if len(sizes) > 0: - # scale each element by its dtype size - dtype_size = dtype().itemsize - pointers *= dtype_size - - # in-place prefix scan to compute byte offsets - np.cumsum(pointers, axis=0, out=pointers) - - # zero-base the prefix scan (exclusive scan) - pointers -= pointers[0] + # exclusive scan to get byte offsets + pointers = np.cumsum(bytesizes, axis=0) + pointers -= bytesizes return pointers From fde439ec758aa15c3918b70b2893967dc2650e0e Mon Sep 17 00:00:00 2001 From: Adam Moody Date: Tue, 17 Aug 2021 23:20:45 -0700 Subject: [PATCH 53/74] fix: exclusive scan to compute mmap pointers list --- megatron/data/indexed_dataset.py | 41 ++++++++++---------------------- 1 file changed, 13 insertions(+), 28 deletions(-) diff --git a/megatron/data/indexed_dataset.py b/megatron/data/indexed_dataset.py index 57afd50e2..766f1b73e 100644 --- a/megatron/data/indexed_dataset.py +++ b/megatron/data/indexed_dataset.py @@ -872,36 +872,22 @@ def gather_files_dist_idx_mmap(outfile, filelist, distctx): # dtype().itemsize bytes (often a standard type that is just # large enough to represent all elements of the vocabulary). - # First compute byte offsets for each sentence of our - # local set of sentences. - pointers = np.array(sizes, dtype=np.int64) - pointer_last = 0 - if len(sizes) > 0: - np.cumsum(pointers, axis=0, out=pointers) - pointers *= dtype().itemsize - pointer_last = pointers[-1] - - # Then account for bytes for all sentences on ranks - # before the calling rank. - pointer_offset = distctx.exscan(pointer_last) - pointers += pointer_offset + # Compute byte sizes for each of our sentences given + # the token count and vocab dtype. + bytesizes = np.array(sizes, dtype=np.int64) + bytesizes *= dtype().itemsize - # Finally, zero-base the offset values by subtracting - # the number of bytes of the first sentence. To do that - # we need to find the rank having the first sentence, - # then bcast that size to all ranks. - if global_size_count > 0: - # Since global_size_count > 0, there is at least one sentence across all ranks. - # Get the value from the first rank that has a value, which may not be rank 0. - pointers_shift = pointers[0] if len(sizes) > 0 else None - pointers_shift = distctx.bcast_first(pointers_shift) + # Inclusive scan to sum number of bytes over sentences. + pointers = np.cumsum(bytesizes, axis=0) - # Since there is at least one, bcast_first should return some value other than None. - assert pointers_shift is not None, "Expected at least one rank to have a valid element" + # Account for bytes for all sentences on ranks + # before the calling rank. + bytes_last = pointers[-1] if len(sizes) > 0 else 0 + pointer_offset = distctx.exscan(bytes_last) + pointers += pointer_offset - # Zero-base pointers by subtracting size of first - # sentence from all values. - pointers -= pointers_shift + # Convert to exclusive scan to get global offset. + pointers -= bytesizes # Since the pointers array is the same length as the sizes array, # we use global_size_offset and global_size_count to position @@ -911,7 +897,6 @@ def gather_files_dist_idx_mmap(outfile, filelist, distctx): # pointer values into file, stored as int64. 
fout.seek(pos + global_size_offset * np.int64().itemsize) fout.write(pointers.tobytes(order='C')) - del pointers # Advance past list of pointer values pos += global_size_count * np.int64().itemsize From ba14351e1420c3061ced347a43316d78bf07e430 Mon Sep 17 00:00:00 2001 From: Adam Moody Date: Wed, 18 Aug 2021 19:18:32 -0700 Subject: [PATCH 54/74] note about seek --- megatron/data/indexed_dataset.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/megatron/data/indexed_dataset.py b/megatron/data/indexed_dataset.py index 766f1b73e..1370af411 100644 --- a/megatron/data/indexed_dataset.py +++ b/megatron/data/indexed_dataset.py @@ -964,6 +964,12 @@ def gather_files_dist_check_impltype(filelist, distctx): # and each rank directly merges its set of files into filemain. # It is allowed for the input files in filelist to only be readable from the calling process. # The output file in filemain should be in a location that is writable by all processes. +# +# NOTE: This uses parallel writes to a shared file to achieve high write bandwidth. +# To do so, this implementation seeks beyond the end of the file to write at different +# offsets from different processes via the seek() method on a python file handle. +# The behavior of seek() is not well documented, but it seems to map to fseek()/lseek(), +# and it works as desired on POSIX-compliant file systems like Lustre and GPFS. def gather_files_dist(filemain, filelist, distctx): # Check that at least one input file is listed filecount = distctx.sum(len(filelist)) From 852fdd0cc7ebc313ff1c3999bfd26a897eb8ebbd Mon Sep 17 00:00:00 2001 From: Adam Moody Date: Wed, 18 Aug 2021 19:21:09 -0700 Subject: [PATCH 55/74] rename preprocess_dataset_mpi.py to preprocess_data_dist.py --- tools/{preprocess_dataset_mpi.py => preprocess_data_dist.py} | 0 1 file changed, 0 insertions(+), 0 deletions(-) rename tools/{preprocess_dataset_mpi.py => preprocess_data_dist.py} (100%) diff --git a/tools/preprocess_dataset_mpi.py b/tools/preprocess_data_dist.py similarity index 100% rename from tools/preprocess_dataset_mpi.py rename to tools/preprocess_data_dist.py From 61f4b467f4035cc9c55da51002d2d3823f9c43ef Mon Sep 17 00:00:00 2001 From: Adam Moody Date: Wed, 18 Aug 2021 19:31:25 -0700 Subject: [PATCH 56/74] update usage comments at top of script --- tools/preprocess_data_dist.py | 16 ++++++++++------ 1 file changed, 10 insertions(+), 6 deletions(-) diff --git a/tools/preprocess_data_dist.py b/tools/preprocess_data_dist.py index 331764a29..bd12b70fc 100644 --- a/tools/preprocess_data_dist.py +++ b/tools/preprocess_data_dist.py @@ -20,19 +20,24 @@ from datasets import load_dataset dset = load_dataset('openwebtext', split='train') -The implementation can use `mpi4py` or `torch.distributed` for node communication, and it assumes that -files are written to a global file system, such that one process +The implementation can use `mpi4py` or `torch.distributed` for node communication, +and it assumes that files are written to a global file system, such that one process can read a file written by another process. A list of sample index values from the source dataset are selected -by rank 0 and broadcast to all ranks. +by rank 0 and scattered to all ranks. Each process tokenizes a subset of samples and writes its output to a part file. -After all ranks have finished, rank 0 merges and deletes the part files. +After all ranks have finished, the part files are merged into a final output file. + +One may optionally use storage local to each process to store the part file. 
+For example, on a Linux cluster, one might write the part file to /dev/shm. To run: -mpiexec -np 320 python preprocess_dataset_mpi.py \ +mpiexec -np 320 python preprocess_data_dist.py \ --input openwebtext \ + --count 1_000_000 \ + --scratch /dev/shm \ --shuffle \ --seed 100 \ --output-prefix openwebtext-bert \ @@ -54,7 +59,6 @@ import random import torch -import torch.distributed as dist try: import nltk nltk_available = True From 22400f373cf2118080509ee18763a18ba9c6f26a Mon Sep 17 00:00:00 2001 From: Adam Moody Date: Wed, 18 Aug 2021 19:59:49 -0700 Subject: [PATCH 57/74] restore commented print_rank_0 statements --- megatron/data/indexed_dataset.py | 8 ++++---- tools/preprocess_data_dist.py | 2 -- 2 files changed, 4 insertions(+), 6 deletions(-) diff --git a/megatron/data/indexed_dataset.py b/megatron/data/indexed_dataset.py index 1370af411..ad5343db5 100644 --- a/megatron/data/indexed_dataset.py +++ b/megatron/data/indexed_dataset.py @@ -432,21 +432,21 @@ def __init__(self, path, skip_warmup=False): offset = stream.tell() if not skip_warmup: -# print_rank_0(" warming up index mmap file...") + print_rank_0(" warming up index mmap file...") _warmup_mmap_file(path) self._bin_buffer_mmap = np.memmap(path, mode='r', order='C') self._bin_buffer = memoryview(self._bin_buffer_mmap) -# print_rank_0(" reading sizes...") + print_rank_0(" reading sizes...") self._sizes = np.frombuffer( self._bin_buffer, dtype=np.int32, count=self._len, offset=offset) -# print_rank_0(" reading pointers...") + print_rank_0(" reading pointers...") self._pointers = np.frombuffer(self._bin_buffer, dtype=np.int64, count=self._len, offset=offset + self._sizes.nbytes) -# print_rank_0(" reading document index...") + print_rank_0(" reading document index...") self._doc_idx = np.frombuffer(self._bin_buffer, dtype=np.int64, count=self._doc_count, offset=offset + self._sizes.nbytes + self._pointers.nbytes) diff --git a/tools/preprocess_data_dist.py b/tools/preprocess_data_dist.py index bd12b70fc..9af3721c2 100644 --- a/tools/preprocess_data_dist.py +++ b/tools/preprocess_data_dist.py @@ -546,8 +546,6 @@ def rank_files_merge_serial(args): for rank in range(args.numranks): for key in args.columns: infile = get_filename(args, key, rank) - -# msg(f"Merging file {infile}", flush=True) builders[key].merge_file_(infile) # sum up the number of merged bytes From 5cfcb955b32ca6855158f2861dbae1b73cc57eca Mon Sep 17 00:00:00 2001 From: Adam Moody Date: Wed, 18 Aug 2021 20:01:35 -0700 Subject: [PATCH 58/74] restore status message in mmap merge_file_ --- megatron/data/indexed_dataset.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/megatron/data/indexed_dataset.py b/megatron/data/indexed_dataset.py index ad5343db5..ffbc7ae0a 100644 --- a/megatron/data/indexed_dataset.py +++ b/megatron/data/indexed_dataset.py @@ -591,8 +591,8 @@ def merge_file_(self, another_file): index = MMapIndexedDataset.Index(index_file_path(another_file)) assert index.dtype == self._dtype -# total_len = len(index.sizes)+len(self._sizes) -# print(f" concat {another_file} size={len(index.sizes)} for a total size of {total_len}") + total_len = len(index.sizes)+len(self._sizes) + print(f" concat {another_file} size={len(index.sizes)} for a total size of {total_len}") offset = len(self._sizes) self._sizes.extend(index.sizes) From 74c48831aab50d74836c47dab022147aefa58013 Mon Sep 17 00:00:00 2001 From: Adam Moody Date: Thu, 19 Aug 2021 09:23:04 -0700 Subject: [PATCH 59/74] drop mpi4py, sad :( --- megatron/data/distdata.py | 146 
+++++++++++----------------------- tools/preprocess_data_dist.py | 13 +-- 2 files changed, 49 insertions(+), 110 deletions(-) diff --git a/megatron/data/distdata.py b/megatron/data/distdata.py index b7549bed5..b6f455198 100644 --- a/megatron/data/distdata.py +++ b/megatron/data/distdata.py @@ -8,27 +8,14 @@ class DistDataError(Exception): pass class DistData(object): - def __init__(self, backend='gloo', use_mpi4py=False): - # use mpi4py instead of torch.distributed if requested - self.mpi4py = None - if use_mpi4py: - try: - from mpi4py import MPI - self.mpi4py = MPI - except: - #print(f"ERROR: mpi4py requested, but failed to import, falling back to torch.distributed.", flush=True) - pass + def __init__(self, backend='gloo'): + assert backend in ['gloo', 'mpi'], f"torch.distributed backend '{backend}' is not supported, valid options are 'gloo' or 'mpi'" + + dist.init_process_group(backend, init_method="env://") # lookup our process rank and the group size - if self.mpi4py is not None: - self.comm = self.mpi4py.COMM_WORLD - self.rank = self.comm.Get_rank() - self.numranks = self.comm.Get_size() - else: - assert backend in ['gloo', 'mpi'], f"torch.distributed backend '{backend}' is not supported, valid options are 'gloo' or 'mpi'" - dist.init_process_group(backend, init_method="env://") - self.rank = dist.get_rank() - self.numranks = dist.get_world_size() + self.rank = dist.get_rank() + self.numranks = dist.get_world_size() def allassert(self, cond, msg): """Check that cond is True on all ranks, assert with msg everywhere if not.""" @@ -51,39 +38,30 @@ def allraise_if(self, err): def barrier(self): """Globally synchronize all processes""" - if self.mpi4py is not None: - self.comm.barrier() - else: - dist.barrier() + dist.barrier() def bcast(self, val, root): """Broadcast a scalar value from root to all ranks""" - if self.mpi4py is not None: - return self.comm.bcast(val, root=root) - else: - vals = [val] - dist.broadcast_object_list(vals, src=root) - return vals[0] + vals = [val] + dist.broadcast_object_list(vals, src=root) + return vals[0] def bcast_list(self, vals, root=0): """Broadcast list of vals from root to all ranks, returns newly allocated list""" - if self.mpi4py is not None: - return self.comm.bcast(vals, root=root) + # broadcast length of vals list + length = [len(vals)] + dist.broadcast_object_list(length, src=root) + + # allocate a tensor of appropriate size + # initialize tensor with list values on root + if self.rank == root: + tvals = torch.tensor(vals, dtype=torch.int64) else: - # broadcast length of vals list - length = [len(vals)] - dist.broadcast_object_list(length, src=root) - - # allocate a tensor of appropriate size - # initialize tensor with list values on root - if self.rank == root: - tvals = torch.tensor(vals, dtype=torch.int64) - else: - tvals = torch.zeros(length[0], dtype=torch.int64) + tvals = torch.zeros(length[0], dtype=torch.int64) - # broadcast tensor from root, and return as a new list - dist.broadcast(tvals, src=root) - return tvals.tolist() + # broadcast tensor from root, and return as a new list + dist.broadcast(tvals, src=root) + return tvals.tolist() def scatterv_(self, invals, counts, outval, root=0): """Scatter int64 values from invals according to counts array, receive values in outval""" @@ -97,68 +75,39 @@ def scatterv_(self, invals, counts, outval, root=0): self.allassert(outval.dtype == np.int64, f"Requires outval to be of type numpy.int64") - if self.mpi4py is not None: - counts = np.array(counts) - displs = np.cumsum(counts) - counts - 
self.comm.Scatterv([invals, counts, displs, self.mpi4py.INT64_T], outval, root=root) - else: - scatterlist = None - if self.rank == root: - scatterlist = list(torch.split(torch.from_numpy(invals), counts)) - outtensor = torch.from_numpy(outval) - dist.scatter(outtensor, scatterlist, src=root) + scatterlist = None + if self.rank == root: + scatterlist = list(torch.split(torch.from_numpy(invals), counts)) + outtensor = torch.from_numpy(outval) + dist.scatter(outtensor, scatterlist, src=root) def alltrue(self, val): """Returns True if all procs input True, False otherwise""" - if self.mpi4py is not None: - inval = np.array([val], dtype=np.bool_) - outval = np.zeros_like(inval) - self.comm.Allreduce(inval, outval, op=self.mpi4py.LAND) - return bool(outval[0]) - else: - # torch.dist does not support reductions with bool types - # so we cast to int and cast the result back to bool - tensor = torch.tensor([int(val)], dtype=torch.int32) - dist.all_reduce(tensor, op=dist.ReduceOp.BAND) - return bool(tensor[0]) + # torch.dist does not support reductions with bool types + # so we cast to int and cast the result back to bool + tensor = torch.tensor([int(val)], dtype=torch.int32) + dist.all_reduce(tensor, op=dist.ReduceOp.BAND) + return bool(tensor[0]) def sum(self, val): """Compute sum of a scalar val, and return total on all ranks.""" - if self.mpi4py is not None: - insize = np.array([val], dtype=np.int64) - outsize = np.zeros_like(insize) - self.comm.Allreduce(insize, outsize, op=self.mpi4py.SUM) - return outsize[0] - else: - tensor = torch.tensor([val]) - dist.all_reduce(tensor, op=dist.ReduceOp.SUM) - return tensor[0] + tensor = torch.tensor([val]) + dist.all_reduce(tensor, op=dist.ReduceOp.SUM) + return tensor[0] def exscan(self, val): """Compute prefix sum (exclusive scan) of scalar val, and return offset of each rank.""" - if self.mpi4py is not None: - insize = np.array([val], dtype=np.int64) - outsize = np.zeros_like(insize) - self.comm.Scan(insize, outsize, op=self.mpi4py.SUM) - return outsize[0] - insize[0] - else: - # torch.distributed doesn't have a scan, so fallback to allreduce - tensor = torch.zeros(self.numranks, dtype=torch.int64) - tensor[self.rank:] = val - dist.all_reduce(tensor, op=dist.ReduceOp.SUM) - return int(tensor[self.rank]) - val + # torch.distributed doesn't have a scan, so fallback to allreduce + tensor = torch.zeros(self.numranks, dtype=torch.int64) + tensor[self.rank:] = val + dist.all_reduce(tensor, op=dist.ReduceOp.SUM) + return int(tensor[self.rank]) - val def min(self, val): """Return minimum of scalar val to all ranks.""" - if self.mpi4py is not None: - insize = np.array([val], dtype=np.int64) - outsize = np.zeros_like(insize) - self.comm.Allreduce(insize, outsize, op=self.mpi4py.MIN) - return outsize[0] - else: - tensor = torch.tensor([val]) - dist.all_reduce(tensor, op=dist.ReduceOp.MIN) - return tensor[0] + tensor = torch.tensor([val]) + dist.all_reduce(tensor, op=dist.ReduceOp.MIN) + return tensor[0] def minrank(self, cond): """Find first rank whose condition is True, return that rank if any, None otherwise.""" @@ -186,13 +135,8 @@ def bcast_first(self, val): def all_sum_(self, vals): """Sums values in numpy array vals element-wise and update vals in place with final result on all ranks""" - if self.mpi4py is not None: - outval = np.zeros_like(vals) - self.comm.Allreduce(vals, outval, op=self.mpi4py.SUM) - vals[:] = outval - else: - tensor = torch.from_numpy(vals) - dist.all_reduce(tensor, op=dist.ReduceOp.SUM) + tensor = torch.from_numpy(vals) + 
dist.all_reduce(tensor, op=dist.ReduceOp.SUM) def open(self, filename): """Create, truncate, and open a file shared by all ranks.""" diff --git a/tools/preprocess_data_dist.py b/tools/preprocess_data_dist.py index 9af3721c2..7c10a9b56 100644 --- a/tools/preprocess_data_dist.py +++ b/tools/preprocess_data_dist.py @@ -20,7 +20,7 @@ from datasets import load_dataset dset = load_dataset('openwebtext', split='train') -The implementation can use `mpi4py` or `torch.distributed` for node communication, +The implementation uses `torch.distributed` for inter-process communication, and it assumes that files are written to a global file system, such that one process can read a file written by another process. @@ -34,12 +34,10 @@ To run: -mpiexec -np 320 python preprocess_data_dist.py \ +python -m torch.distributed.launch --nproc_per_node 8 preprocess_data_dist.py \ --input openwebtext \ - --count 1_000_000 \ + --count 1_000 \ --scratch /dev/shm \ - --shuffle \ - --seed 100 \ --output-prefix openwebtext-bert \ --vocab bert-large-uncased-vocab.txt \ --dataset-impl mmap \ @@ -176,8 +174,6 @@ def get_args(): help='Select torch.distributed backend.') group.add_argument('--local_rank', type=int, default=None, help='Local rank of calling process on its node (from torch.distributed.launch).') - group.add_argument('--mpi4py', action='store_true', - help='Assume script has been launched as an MPI job, and use mpi4py for communication.') group.add_argument('--merge', type=str, default='parallel', choices=['parallel', 'serial', 'both'], help=('Method to merge intermediate per-rank files into the final data files. ' 'With "parallel", each rank writes directly to the final files, ' @@ -202,8 +198,7 @@ def get_args(): args.vocab_extra_ids = 0 # initialize our distributed environment - # use mpi4py instead of torch.distributed if requested - args.distctx = DistData(use_mpi4py=args.mpi4py, backend=args.torch_backend) + args.distctx = DistData(backend=args.torch_backend) # some functions like build_tokenizer use args.rank to filter stdout messages args.rank = args.distctx.rank From 78ab7158d57caf7ef2eeda54567a82851f266292 Mon Sep 17 00:00:00 2001 From: Adam Moody Date: Thu, 19 Aug 2021 09:56:30 -0700 Subject: [PATCH 60/74] add test case for parallel merge --- tests/test_preprocessing.py | 32 ++++++++++++++++++++++++++++++++ 1 file changed, 32 insertions(+) diff --git a/tests/test_preprocessing.py b/tests/test_preprocessing.py index 4d27d94b3..f543065fb 100644 --- a/tests/test_preprocessing.py +++ b/tests/test_preprocessing.py @@ -110,3 +110,35 @@ def test_process_data_microsoft(self): self.assertTrue(Path(tgt_path).exists(), ) self.assertTrue(filecmp.cmp(tgt_path, ref_path, shallow=False)) + def test_process_data_dist_microsoft(self): + """We want to be stable to Microsoft version.""" + src_dir = self.src_dir + data_dir = f"{self.data_dir}/gpt2" + output_dir = self.get_auto_remove_tmp_dir() # "./xxx", after=False) + + input_path = f"{self.tests_dir}/tools/openwebtext-1000.jsonl" + + output_prefix = f"{output_dir}/test-ds-meg-gpt2-openwebtext" + + cmd = f""" + python -m torch.distributed.launch --nproc_per_node 2 {src_dir}/tools/preprocess_data_dist.py + --input openwebtext-10k + --count 1000 + --output-prefix {output_prefix} + --dataset-impl mmap + --tokenizer-type GPT2BPETokenizer + --merge-file {data_dir}/gpt2-tiny-merges.txt + --vocab {data_dir}/gpt2-tiny-vocab.json + --append-eod + """.split() + + # keep for quick debug + # print(" ".join([f"\nPYTHONPATH={self.src_dir_str}"] +cmd)); die + 
execute_subprocess_async(cmd, env=self.get_env()) + + for ext in ["bin", "idx"]: + tgt_path = f"{output_prefix}_text_document.{ext}" + ref_path = f"{data_dir}/meg-gpt2-openwebtext_text_document.{ext}" + self.assertTrue(Path(tgt_path).exists(), ) + self.assertTrue(filecmp.cmp(tgt_path, ref_path, shallow=False)) + From 002b40322200693c06934894d478166ba837598c Mon Sep 17 00:00:00 2001 From: Adam Moody Date: Thu, 19 Aug 2021 10:08:41 -0700 Subject: [PATCH 61/74] add preprocess_data_dist test for serial merge --- tests/test_preprocessing.py | 40 +++++++++++++++++++++++++++++++------ 1 file changed, 34 insertions(+), 6 deletions(-) diff --git a/tests/test_preprocessing.py b/tests/test_preprocessing.py index f543065fb..1f9454574 100644 --- a/tests/test_preprocessing.py +++ b/tests/test_preprocessing.py @@ -110,13 +110,44 @@ def test_process_data_microsoft(self): self.assertTrue(Path(tgt_path).exists(), ) self.assertTrue(filecmp.cmp(tgt_path, ref_path, shallow=False)) + def compare_meg_data_files(self, tgt, ref): + for ext in ["bin", "idx"]: + tgt_path = f"{tgt}.{ext}" + ref_path = f"{ref}.{ext}" + self.assertTrue(Path(tgt_path).exists(), ) + self.assertTrue(filecmp.cmp(tgt_path, ref_path, shallow=False)) + def test_process_data_dist_microsoft(self): """We want to be stable to Microsoft version.""" src_dir = self.src_dir data_dir = f"{self.data_dir}/gpt2" output_dir = self.get_auto_remove_tmp_dir() # "./xxx", after=False) - input_path = f"{self.tests_dir}/tools/openwebtext-1000.jsonl" + output_prefix = f"{output_dir}/test-ds-meg-gpt2-openwebtext" + + cmd = f""" + python -m torch.distributed.launch --nproc_per_node 2 {src_dir}/tools/preprocess_data_dist.py + --input openwebtext-10k + --count 1000 + --output-prefix {output_prefix} + --dataset-impl mmap + --tokenizer-type GPT2BPETokenizer + --merge-file {data_dir}/gpt2-tiny-merges.txt + --vocab {data_dir}/gpt2-tiny-vocab.json + --append-eod + """.split() + + # keep for quick debug + # print(" ".join([f"\nPYTHONPATH={self.src_dir_str}"] +cmd)); die + execute_subprocess_async(cmd, env=self.get_env()) + + self.compare_meg_data_files(f"{output_prefix}_text_document", f"{data_dir}/meg-gpt2-openwebtext_text_document") + + def test_process_data_dist_serial_microsoft(self): + """We want to be stable to Microsoft version.""" + src_dir = self.src_dir + data_dir = f"{self.data_dir}/gpt2" + output_dir = self.get_auto_remove_tmp_dir() # "./xxx", after=False) output_prefix = f"{output_dir}/test-ds-meg-gpt2-openwebtext" @@ -124,6 +155,7 @@ def test_process_data_dist_microsoft(self): python -m torch.distributed.launch --nproc_per_node 2 {src_dir}/tools/preprocess_data_dist.py --input openwebtext-10k --count 1000 + --merge serial --output-prefix {output_prefix} --dataset-impl mmap --tokenizer-type GPT2BPETokenizer @@ -136,9 +168,5 @@ def test_process_data_dist_microsoft(self): # print(" ".join([f"\nPYTHONPATH={self.src_dir_str}"] +cmd)); die execute_subprocess_async(cmd, env=self.get_env()) - for ext in ["bin", "idx"]: - tgt_path = f"{output_prefix}_text_document.{ext}" - ref_path = f"{data_dir}/meg-gpt2-openwebtext_text_document.{ext}" - self.assertTrue(Path(tgt_path).exists(), ) - self.assertTrue(filecmp.cmp(tgt_path, ref_path, shallow=False)) + self.compare_meg_data_files(f"{output_prefix}_text_document", f"{data_dir}/meg-gpt2-openwebtext_text_document") From ba763f7ce8edcd298162abe6fb8481fc4a3b0258 Mon Sep 17 00:00:00 2001 From: Adam Moody Date: Fri, 20 Aug 2021 12:31:59 -0700 Subject: [PATCH 62/74] improve error handling --- megatron/data/distdata.py | 98 
++++++-- megatron/data/indexed_dataset.py | 379 +++++++++++++++++-------------- tools/preprocess_data_dist.py | 28 ++- 3 files changed, 296 insertions(+), 209 deletions(-) diff --git a/megatron/data/distdata.py b/megatron/data/distdata.py index b6f455198..155a4b658 100644 --- a/megatron/data/distdata.py +++ b/megatron/data/distdata.py @@ -1,3 +1,4 @@ +import os import numpy as np import torch @@ -18,12 +19,26 @@ def __init__(self, backend='gloo'): self.numranks = dist.get_world_size() def allassert(self, cond, msg): - """Check that cond is True on all ranks, assert with msg everywhere if not.""" + """Check that cond is True on all ranks, assert with msg everywhere if not. + + To prevent deadlocks in cases where an assertion might only fail on one rank, + this executes an allreduce to ensure that if any rank finds that an assertion + has been violated, all ranks fail an assertion check. + The condition must be true on all ranks for this not to assert. + """ alltrue = self.alltrue(cond) assert alltrue, msg def allraise_if(self, err): - """Raise exception if err is not None on any rank.""" + """Raise exception if err is not None on any rank. + + Similarly to allassert, this raises an exception on all ranks if err + is set to an exception on any rank. Rank(s) where err is not None + re-raise err as exception, and ranks where err is None raise DistDataError. + Thus all ranks raise an exception if any rank has an active exception, + which helps avoid deadlocks in cases where an exception may be raised + on a subset of ranks. + """ alltrue = self.alltrue(err is None) if not alltrue: # At least one rank raised an exception. @@ -46,24 +61,7 @@ def bcast(self, val, root): dist.broadcast_object_list(vals, src=root) return vals[0] - def bcast_list(self, vals, root=0): - """Broadcast list of vals from root to all ranks, returns newly allocated list""" - # broadcast length of vals list - length = [len(vals)] - dist.broadcast_object_list(length, src=root) - - # allocate a tensor of appropriate size - # initialize tensor with list values on root - if self.rank == root: - tvals = torch.tensor(vals, dtype=torch.int64) - else: - tvals = torch.zeros(length[0], dtype=torch.int64) - - # broadcast tensor from root, and return as a new list - dist.broadcast(tvals, src=root) - return tvals.tolist() - - def scatterv_(self, invals, counts, outval, root=0): + def scatterv_(self, invals: np.array, counts: list, outval: np.array, root:int=0): """Scatter int64 values from invals according to counts array, receive values in outval""" self.allassert(len(counts) == self.numranks, @@ -75,6 +73,11 @@ def scatterv_(self, invals, counts, outval, root=0): self.allassert(outval.dtype == np.int64, f"Requires outval to be of type numpy.int64") + # Note that we create the torch.tensor outtensor to receive incoming + # data by using from_numpy. This sets outtensor to point to the same + # underlying memory of the outval numpy array. Receiving data into + # outtensor thus writes incoming data directly into the numpy array. 
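# ---------------------------------------------------------------------------
# Illustrative sketch (editorial, not part of the patch): why receiving into a
# tensor built with torch.from_numpy fills the caller's numpy array. The tensor
# shares memory with the array, so any in-place update of the tensor (such as a
# dist.scatter or dist.all_reduce receive) is visible through the numpy view.
# Single-process demo; no process group needed.
import numpy as np
import torch

outval = np.zeros(4, dtype=np.int64)
outtensor = torch.from_numpy(outval)    # no copy: shares outval's memory
outtensor += torch.arange(4)            # stand-in for a collective receive
print(outval)                           # -> [0 1 2 3]
# ---------------------------------------------------------------------------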
+ scatterlist = None if self.rank == root: scatterlist = list(torch.split(torch.from_numpy(invals), counts)) @@ -95,8 +98,8 @@ def sum(self, val): dist.all_reduce(tensor, op=dist.ReduceOp.SUM) return tensor[0] - def exscan(self, val): - """Compute prefix sum (exclusive scan) of scalar val, and return offset of each rank.""" + def exscan(self, val: int): + """Compute prefix sum (exclusive scan) of int64 val, and return offset of each rank.""" # torch.distributed doesn't have a scan, so fallback to allreduce tensor = torch.zeros(self.numranks, dtype=torch.int64) tensor[self.rank:] = val @@ -133,12 +136,13 @@ def bcast_first(self, val): val = self.bcast(val, root=minrank) return val - def all_sum_(self, vals): + def all_sum_(self, vals: np.array): """Sums values in numpy array vals element-wise and update vals in place with final result on all ranks""" + # Builds torch.tensor with from_numpy to use same underlying memory as numpy array. tensor = torch.from_numpy(vals) dist.all_reduce(tensor, op=dist.ReduceOp.SUM) - def open(self, filename): + def open(self, filename, truncate=None): """Create, truncate, and open a file shared by all ranks.""" # Don't truncate existing file until all ranks reach this point @@ -151,6 +155,12 @@ def open(self, filename): if self.rank == 0: try: f = open(filename, 'wb') + + # Some file systems like GPFS deliver faster write speed + # if the file size is known before data is written to the file. + if truncate is not None: + f.truncate(truncate) + except Exception as e: err = e @@ -169,3 +179,43 @@ def open(self, filename): self.allraise_if(err) return f + + def remove(self, filename): + """Remove a shared file.""" + + # Don't remove the file until all are ready + self.barrier() + + # We'll capture any exception in this variable + err = None + + # Rank 0 removes the file if it exists. + if self.rank == 0: + try: + if os.path.exists(filename): + os.remove(filename) + except Exception as e: + err = e + + # Verify that rank 0 successfully removed the file. + self.allraise_if(err) + + def rename(self, srcfile, destfile): + """Rename a shared file.""" + + # Don't rename until all are ready + self.barrier() + + # We'll capture any exception in this variable + err = None + + # Rank 0 renames the file. + if self.rank == 0: + try: + if os.path.exists(srcfile): + os.rename(srcfile, destfile) + except Exception as e: + err = e + + # Verify that the rename succeeded + self.allraise_if(err) diff --git a/megatron/data/indexed_dataset.py b/megatron/data/indexed_dataset.py index ffbc7ae0a..9b91aaaf0 100644 --- a/megatron/data/indexed_dataset.py +++ b/megatron/data/indexed_dataset.py @@ -12,6 +12,7 @@ from functools import lru_cache import os +import stat import shutil import struct from itertools import accumulate @@ -615,62 +616,72 @@ def finalize(self, index_file): # spot, and copy each file. def gather_files_dist_bin(outfile, filelist, distctx): """Concatenate binary files in filelist into a new file given by outfile""" - import stat - import shutil - # lookup size of each of our binary files filesizes = [os.stat(data_file_path(f))[stat.ST_SIZE] for f in filelist] - # compute offset this rank should start copying - # its data into the merged file + # compute total bytes of the merged file and the offset + # at which this rank will write data from its files numbytes = sum(filesizes) + count = distctx.sum(numbytes) offset = distctx.exscan(numbytes) - # Create shared output file. 
- with distctx.open(data_file_path(outfile)) as fout: - # Seek to appropriate starting offset in the merged file. - fout.seek(offset) + # We first write to a temporary file name. We rename to the final name + # if successful or delete the temporary file if not. + # This way if the final name appears, the user knows it's a valid file. + finalname = data_file_path(outfile) + finaltmp = finalname + ".tmp" + + # First delete the final file if it already exists + distctx.remove(finalname) + + # Catch I/O errors from any process + err = None + try: + # Create shared output file and pre-truncate to its final size. + with distctx.open(finaltmp, truncate=count) as fout: + # Seek to appropriate starting offset in the merged file. + fout.seek(offset) - # Copy in contents of each of our files. - for f in filelist: - with open(data_file_path(f), "rb") as fsrc: - shutil.copyfileobj(fsrc, fout) + # Copy in contents of each of our files. + for f in filelist: + with open(data_file_path(f), "rb") as fsrc: + shutil.copyfileobj(fsrc, fout) + except Exception as e: + err = e - # TODO: check that all ranks wrote successfully - distctx.barrier() + # Check that all ranks wrote successfully. + # This will raise an exception all on ranks if we detect + # an exception on any rank. + distctx.allraise_if(err) + # Everyone wrote their part successfully. + # Rename the temporary file to the final file. + distctx.rename(finaltmp, finalname) -def write_list(fout, pos, vals, shift, offset, total, dtype): - """Write list of values to fout and return the number of bytes written assuming the total list size. + +def write_list_at_offset(fout, file_offset, vals, shift, elem_offset, dtype): + """Write list of values to fout. Copies list of values in vals to a numpy array of type dtype. Adds a constant value to all elements as given in shift. Writes the numpy array to the file handle at given offset and scaled by size of the datatype. byteoffset = pos + vals * dtype().itemsize - Computes and return the total bytes written to write total elements of type dtype. Parameters ---------- fout : file handle Opened file handle to which to write vals - pos : int + file_offset : int Byte offset within the file where the global list starts - vals : list(int) + vals : list[int] List of values to be written shift : int Value to add to each element in vals before writing (use 0 for no change) - offset : int - Zero-based element index where vals starts within the global list - total : int - Total number of elements within the global list - dtype : numpy datatype + elem_offset : int + Zero-based element index where vals starts within the global list. + This value will be scaled by dtype().itemsize to convert to the corresponding number of bytes. + dtype : np.dtype numpy datatype to be used when writing the list to the file - - Returns - ------- - int - Number of bytes that would be required to write the global list - of length 'total' and of type 'dtype' """ # Make a copy of the vals list using the requested datatype. @@ -681,22 +692,21 @@ def write_list(fout, pos, vals, shift, offset, total, dtype): # Seek to proper offset for this rank and write # values into file, stored as given datatype. 
- fout.seek(pos + offset * dtype().itemsize) + fout.seek(file_offset + elem_offset * dtype().itemsize) fout.write(npvals.tobytes(order='C')) - # Return number of bytes written - return total * dtype().itemsize - -def gather_files_dist_check_dtype(filelist, dtype_valid, dtype_code, distctx): - # verify that no rank has found an inconsistent value in its own set of files - distctx.allassert(dtype_valid, "Some rank found inconsistent dtype values") +def gather_files_dist_check_dtype(filelist, dtype_rank_consistent, dtype_code, distctx): + # Verify that no rank has found an inconsistent value in its own set of files. + # This includes an allreduce to verify that dtype_rank_consistent is True everywhere. + distctx.allassert(dtype_rank_consistent, "Some rank found inconsistent dtype values") - # verify that at least one rank found a dtype value + # Verify that at least one rank found a dtype value. + # Because of the bcast, the the value of first_dtype_code will be the same on all ranks. first_dtype_code = distctx.bcast_first(dtype_code) assert first_dtype_code is not None, "Failed to find a dtype value in any index file" - # verify that the dtype is consistent on all ranks, if a rank has a dtype value + # Verify that the dtype is consistent on all ranks, if a rank has a dtype value. distctx.allassert(dtype_code == first_dtype_code or dtype_code is None, "Different dtype values detected in index files") # return the dtype @@ -708,9 +718,9 @@ def gather_files_dist_idx_cached(outfile, filelist, distctx): sizes = [] data_offsets = [0] dim_offsets = [0] - docs = [0] - dtype_valid = True # whether rank identifies inconsistent values in its files - dtype_value = None # the current dtype code, if any + doc_idx = [0] + dtype_rank_consistent = True # whether this rank identifies inconsistent dtype values in its files + dtype_value = None # the current dtype code to compare against, if any for f in filelist: # read index file for this file index = IndexedDataset(f) @@ -720,41 +730,41 @@ def gather_files_dist_idx_cached(outfile, filelist, distctx): sizes.extend(index.sizes) data_offsets.extend(index.data_offsets[1:] + data_offsets[-1]) dim_offsets.extend(index.dim_offsets[1:] + dim_offsets[-1]) - docs.extend(index.doc_idx[1:] + doc_offset) + doc_idx.extend(index.doc_idx[1:] + doc_offset) # check that the dtype in this index matches the dtype in our other files dtype_code = code(index.dtype) if dtype_value is None: dtype_value = dtype_code if dtype_value != dtype_code: - dtype_valid = False + dtype_rank_consistent = False - # Check that we have consistent dtypes in all files from all ranks - dtype = gather_files_dist_check_dtype(filelist, dtype_valid, dtype_value, distctx) + # Check that we have consistent dtypes in all files from all ranks, + # and return the dtype being used. + dtype = gather_files_dist_check_dtype(filelist, dtype_rank_consistent, dtype_value, distctx) - # Capture the last value in dim and data arrays before we delete any items. + # Capture the last value in the data array before we delete any items. # Note this may be zero on any rank that has no items, # but zero is the correct value in that case. - # We use this last value to compute a shift value that will - # later be added to each element in our lists. - dim_shift = distctx.exscan(dim_offsets[-1]) + # We use this last value to compute a shift value that + # is later be added to each element in our data list. 
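# ---------------------------------------------------------------------------
# Illustrative sketch (editorial, not part of the patch): the shift arithmetic
# with concrete numbers, plain Python only. Rank 0 holds sentences of 3 and 2
# elements, rank 1 holds one sentence of 4 elements; rank 1 drops its leading
# zero and adds the exclusive-scan shift (rank 0's last offset) before writing.
rank0_data_offsets = [0, 3, 5]                 # local element offsets on rank 0
rank1_data_offsets = [0, 4]                    # local element offsets on rank 1
data_shift_rank1 = rank0_data_offsets[-1]      # exscan of the last local value -> 5
merged = rank0_data_offsets + [v + data_shift_rank1 for v in rank1_data_offsets[1:]]
assert merged == [0, 3, 5, 9]                  # offsets as they appear in the merged index
# ---------------------------------------------------------------------------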
data_shift = distctx.exscan(data_offsets[-1]) # Drop the zero entry from the lists that start with - # a "0" value unless we're rank 0 + # a "0" value unless we're rank 0. if distctx.rank != 0: del data_offsets[0] del dim_offsets[0] - del docs[0] + del doc_idx[0] # Compute total number of entires in data, size, dim, - # and docs lists across all ranks. Also compute the offset + # and doc_idx lists across all ranks. Also compute the offset # of the calling rank for each list considering the number # of entries for all ranks before the calling rank. numdata = len(data_offsets) numsize = len(sizes) numdim = len(dim_offsets) - numdoc = len(docs) + numdoc = len(doc_idx) global_data_count = distctx.sum(numdata) global_size_count = distctx.sum(numsize) @@ -766,53 +776,71 @@ def gather_files_dist_idx_cached(outfile, filelist, distctx): global_dim_offset = distctx.exscan(numdim) global_doc_offset = distctx.exscan(numdoc) - # Create shared output file - with distctx.open(index_file_path(outfile)) as fout: - # Have rank 0 write the file header - # Broadcast number of bytes written from rank 0, - # and advance file position past file header on all ranks. - pos = 0 - if distctx.rank == 0: - pos = IndexedDatasetBuilder.write_header(fout, dtype, global_data_count, global_size_count, global_doc_count) - pos = distctx.bcast(pos, root=0) - - # TODO: is dim_shift == global_size_offset? - # The dimension list records the offset within - # the sizes list for each sentence. - # Adjust dimension index values for number of size values that - # come before the calling rank which is in dim_shift. - pos += write_list(fout, pos, dim_offsets, dim_shift, global_dim_offset, global_dim_count, np.int64) - - # The data index records the element offset to the start of each - # sentence within the binary data file, expressed in units of dtype().itemsize. - # Adjust data index values for number of elements that - # come before the calling rank, which is in data_shift. - pos += write_list(fout, pos, data_offsets, data_shift, global_data_offset, global_data_count, np.int64) - - # Each sentence is stored as a tensor. - # The tensor for each sentence can be multidimensional. - # The number of tensor dimensions per sentence is variable, - # and the size of each dimension of a sentence is arbitrary. - # The size list records a flattened list of the sizes - # for each dimension of a sentence. - pos += write_list(fout, pos, sizes, 0, global_size_offset, global_size_count, np.int64) - - # The document index points to the position in the sizes - # array for the first sentence of each document. - # Adjust document index for number of sentences that - # come before the calling rank which is in global_size_offset. - pos += write_list(fout, pos, docs, global_size_offset, global_doc_offset, global_doc_count, np.int64) - - # TODO: check that all ranks wrote successfully - distctx.barrier() + # Catch and I/O errors to later determine whether all ranks wrote successfully. + err = None + try: + # Create shared output file + with distctx.open(index_file_path(outfile)) as fout: + # Have rank 0 write the file header + file_offset = 0 + if distctx.rank == 0: + try: + file_offset = fout.tell() + file_offset += IndexedDatasetBuilder.write_header(fout, dtype, global_data_count, global_size_count, global_doc_count) + except Exception as e: + err = e + distctx.allraise_if(err) + + # Broadcast current file position from rank 0. 
+ file_offset = distctx.bcast(file_offset, root=0) + + # The dimension list records the offset within + # the sizes list for each sentence. + # We shift our dimension index values to account for the number of size values + # that come before the calling rank which is in global_size_offset. + write_list_at_offset(fout, file_offset, dim_offsets, global_size_offset, global_dim_offset, np.int64) + file_offset += global_dim_count * np.int64().itemsize + + # The data index records the element offset to the start of each + # sentence within the binary data file. Note that this is an + # element offset, not a byte offset. Each element is pyhsically stored + # in the data file as dtype().itemsize bytes. + # We shift the data index values according to the number of elements that + # come before the calling rank, which is stored in data_shift. + write_list_at_offset(fout, file_offset, data_offsets, data_shift, global_data_offset, np.int64) + file_offset += global_data_count * np.int64().itemsize + + # Each sentence is stored as a tensor. + # The tensor for each sentence can be multidimensional. + # The number of tensor dimensions per sentence is variable, + # and the size of each dimension of a sentence is arbitrary. + # The size list records a flattened list of the sizes + # for each dimension of a sentence. + # No shift value is needed. + write_list_at_offset(fout, file_offset, sizes, 0, global_size_offset, np.int64) + file_offset += global_size_count * np.int64().itemsize + + # The document index records the offset within the sizes + # array for the first sentence of each document. + # We shift the document index values for number of size values that + # come before the calling rank which is in global_size_offset. + write_list_at_offset(fout, file_offset, doc_idx, global_size_offset, global_doc_offset, np.int64) + file_offset += global_doc_count * np.int64().itemsize + + except Exception as e: + # if we encounter any exception while writing, store it for later + err = e + + # Check that all ranks wrote successfully + distctx.allraise_if(err) def gather_files_dist_idx_mmap(outfile, filelist, distctx): - # Read each index file and append items to the size and docs lists + # Read each index file and append items to the size and doc_idx lists sizes = [] - docs = [0] - dtype_valid = True # whether rank identifies inconsistent values in its files - dtype_value = None # the current dtype code, if any + doc_idx = [0] + dtype_rank_consistent = True # whether rank identifies inconsistent dtype values in its files + dtype_value = None # the current dtype code to compare against, if any for f in filelist: # read index file for this file index = MMapIndexedDataset.Index(index_file_path(f)) @@ -820,22 +848,23 @@ def gather_files_dist_idx_mmap(outfile, filelist, distctx): # append its size and doc entries to our lists docs_offset = len(sizes) sizes.extend(index.sizes) - docs.extend(index.doc_idx[1:] + docs_offset) + doc_idx.extend(index.doc_idx[1:] + docs_offset) # check that the dtype in this index matches the dtype in our other files dtype_code = code(index.dtype) if dtype_value is None: dtype_value = dtype_code if dtype_value != dtype_code: - dtype_valid = False + dtype_rank_consistent = False - # Check that we have consistent dtypes in all files from all ranks - dtype = gather_files_dist_check_dtype(filelist, dtype_valid, dtype_value, distctx) + # Check that we have consistent dtypes in all files from all ranks, + # and return the dtype being used. 
+ dtype = gather_files_dist_check_dtype(filelist, dtype_rank_consistent, dtype_value, distctx) # Drop the zero entry from the lists that start with # a "0" value unless we're rank 0 if distctx.rank != 0: - del docs[0] + del doc_idx[0] # Compute total number of size and document index # values across all ranks. Also compute the offset @@ -843,7 +872,7 @@ def gather_files_dist_idx_mmap(outfile, filelist, distctx): # the values of sizes/docs for all ranks before the # calling rank. numsizes = len(sizes) - numdocs = len(docs) + numdocs = len(doc_idx) global_size_count = distctx.sum(numsizes) global_docs_count = distctx.sum(numdocs) @@ -851,65 +880,76 @@ def gather_files_dist_idx_mmap(outfile, filelist, distctx): global_size_offset = distctx.exscan(numsizes) global_docs_offset = distctx.exscan(numdocs) - # Create shared output file - with distctx.open(index_file_path(outfile)) as fout: - # Have rank 0 write the file header - # Broadcast number of bytes written from rank 0, - # and advance file position past file header on all ranks. - pos = 0 - if distctx.rank == 0: - pos = MMapIndexedDataset.Index.write_header(fout, dtype, global_size_count, global_docs_count) - pos = distctx.bcast(pos, root=0) - - # The list of size values from each rank are - # concatenated and stored as int32. - pos += write_list(fout, pos, sizes, 0, global_size_offset, global_size_count, np.int32) - - # The pointer values store the byte offset to each sentence when in memory. - # A sentence has a variable number of tokens, given by - # its corresponding entry in the size array. Each token - # of a sentence is stored in units of type dtype, which consumes - # dtype().itemsize bytes (often a standard type that is just - # large enough to represent all elements of the vocabulary). - - # Compute byte sizes for each of our sentences given - # the token count and vocab dtype. - bytesizes = np.array(sizes, dtype=np.int64) - bytesizes *= dtype().itemsize - - # Inclusive scan to sum number of bytes over sentences. - pointers = np.cumsum(bytesizes, axis=0) - - # Account for bytes for all sentences on ranks - # before the calling rank. - bytes_last = pointers[-1] if len(sizes) > 0 else 0 - pointer_offset = distctx.exscan(bytes_last) - pointers += pointer_offset - - # Convert to exclusive scan to get global offset. - pointers -= bytesizes - - # Since the pointers array is the same length as the sizes array, - # we use global_size_offset and global_size_count to position - # within the file for writing the pointer values. - - # Seek to proper offset for this rank and write - # pointer values into file, stored as int64. - fout.seek(pos + global_size_offset * np.int64().itemsize) - fout.write(pointers.tobytes(order='C')) - - # Advance past list of pointer values - pos += global_size_count * np.int64().itemsize - - # The document index points to the position in the sizes - # array for the starting sentence of each document. - # A variable number of sentences can be in each document. - # Adjust document index for number of sentences that - # come before the calling rank which is in global_size_offset. - pos += write_list(fout, pos, docs, global_size_offset, global_docs_offset, global_docs_count, np.int64) - - # TODO: check that all ranks wrote successfully - distctx.barrier() + # Catch and I/O errors to later determine whether all ranks wrote successfully. 
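# ---------------------------------------------------------------------------
# Illustrative sketch (editorial, not part of the patch): the collective
# error-handling pattern in isolation. Each rank catches its local exception,
# then an allreduce on a success flag tells every rank whether anyone failed,
# so no rank is left waiting in a later barrier. This mirrors the idea behind
# DistData.allraise_if, not its exact API; the torchrun-style launch and the
# per-rank file names are assumptions of the sketch.
import torch
import torch.distributed as dist

def raise_together(err):
    ok = torch.tensor([int(err is None)], dtype=torch.int32)
    dist.all_reduce(ok, op=dist.ReduceOp.MIN)      # becomes 0 if any rank failed
    if int(ok[0]) == 0:
        if err is not None:
            raise err                              # re-raise the local error
        raise RuntimeError("another rank hit an error; aborting collectively")

if __name__ == '__main__':
    dist.init_process_group('gloo', init_method='env://')
    err = None
    try:
        with open(f'part_{dist.get_rank()}.bin', 'rb') as f:
            f.read()
    except Exception as e:
        err = e
    raise_together(err)
# ---------------------------------------------------------------------------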
+ err = None + try: + # Create shared output file + with distctx.open(index_file_path(outfile)) as fout: + # Have rank 0 write the file header + file_offset = 0 + if distctx.rank == 0: + try: + file_offset = fout.tell() + file_offset += MMapIndexedDataset.Index.write_header(fout, dtype, global_size_count, global_docs_count) + except Exception as e: + err = e + distctx.allraise_if(err) + + # Broadcast current file position from rank 0. + file_offset = distctx.bcast(file_offset, root=0) + + # The list of size values from each rank are + # concatenated and stored as int32. + write_list_at_offset(fout, file_offset, sizes, 0, global_size_offset, np.int32) + file_offset += global_size_count * np.int32().itemsize + + # The pointer values store the byte offset to each sentence when in memory. + # A sentence has a variable number of tokens, given by + # its corresponding entry in the size array. Each token + # of a sentence is stored in units of type dtype, which consumes + # dtype().itemsize bytes (often a standard type that is just + # large enough to represent all elements of the vocabulary). + + # Compute byte sizes for each of our sentences given + # the token count and vocab dtype. + bytesizes = np.array(sizes, dtype=np.int64) + bytesizes *= dtype().itemsize + + # Inclusive scan to sum number of bytes over sentences. + pointers = np.cumsum(bytesizes, axis=0) + + # Account for bytes for all sentences on ranks + # before the calling rank. + bytes_last = pointers[-1] if len(sizes) > 0 else 0 + pointer_offset = distctx.exscan(bytes_last) + pointers += pointer_offset + + # Convert to exclusive scan to get global offset. + pointers -= bytesizes + + # Since the pointers array is the same length as the sizes array, + # we use global_size_offset and global_size_count to position + # within the file for writing the pointer values. + + # Seek to proper offset for this rank and write + # pointer values into file, stored as int64. + write_list_at_offset(fout, file_offset, pointers, 0, global_size_offset, np.int64) + file_offset += global_size_count * np.int64().itemsize + + # The document index points to the position in the sizes + # array for the starting sentence of each document. + # A variable number of sentences can be in each document. + # We shift the document index for number of sentences that + # come before the calling rank which is in global_size_offset. + write_list_at_offset(fout, file_offset, doc_idx, global_size_offset, global_docs_offset, np.int64) + file_offset += global_docs_count * np.int64().itemsize + + except Exception as e: + # if we encounter any exception while writing, store it for later + err = e + + # Check that all ranks wrote successfully + distctx.allraise_if(err) # Verify that all files in filelist are of the same index type. @@ -917,14 +957,10 @@ def gather_files_dist_idx_mmap(outfile, filelist, distctx): def gather_files_dist_check_impltype(filelist, distctx): # Sanity check for typos in file names. # Check that a data file exists for each of our files. - exists = True - for f in filelist: - binfile = data_file_path(f) - if not os.path.exists(binfile): - exists = False + all_files_exist = all([os.path.exists(data_file_path(f)) for f in filelist]) # Check that all ranks have all of their files. 
- distctx.allassert(exists, "Some rank is missing its input file") + distctx.allassert(all_files_exist, "Some rank is missing its input file") # map type string to an integer for easier bcast, use 0 for unknown implmap = {"cached": 1, "mmap": 2} @@ -958,6 +994,9 @@ def gather_files_dist_check_impltype(filelist, distctx): if implmap[key] == bcasttype: return key + # raise exception if key for bcasttype was not found + raise UnreachableCode + # Collectively merge files into a new output file specified in filemain. # Each rank contributes a distinct list of zero or more files in filelist, @@ -973,7 +1012,7 @@ def gather_files_dist_check_impltype(filelist, distctx): def gather_files_dist(filemain, filelist, distctx): # Check that at least one input file is listed filecount = distctx.sum(len(filelist)) - assert filecount > 0, "No rank has any input files to merge" + assert filecount > 0, "All ranks have no input files to merge" # Check that files are all of the same index type indexstr = gather_files_dist_check_impltype(filelist, distctx) diff --git a/tools/preprocess_data_dist.py b/tools/preprocess_data_dist.py index 7c10a9b56..0c37e2bc6 100644 --- a/tools/preprocess_data_dist.py +++ b/tools/preprocess_data_dist.py @@ -34,15 +34,14 @@ To run: -python -m torch.distributed.launch --nproc_per_node 8 preprocess_data_dist.py \ - --input openwebtext \ - --count 1_000 \ - --scratch /dev/shm \ - --output-prefix openwebtext-bert \ - --vocab bert-large-uncased-vocab.txt \ - --dataset-impl mmap \ - --tokenizer-type BertWordPieceLowerCase \ - --split-sentences +python -m torch.distributed.launch --nproc_per_node 40 --nnodes 8 \ + preprocess_data_dist.py \ + --input openwebtext \ + --output-prefix openwebtext-bert \ + --vocab bert-large-uncased-vocab.txt \ + --dataset-impl mmap \ + --tokenizer-type BertWordPieceLowerCase \ + --split-sentences """ import argparse @@ -191,12 +190,6 @@ def get_args(): args = parser.parse_args() args.keep_empty = False - # some default/dummy values for the tokenizer - args.rank = 0 - args.make_vocab_size_divisible_by = 128 - args.tensor_model_parallel_size = 1 - args.vocab_extra_ids = 0 - # initialize our distributed environment args.distctx = DistData(backend=args.torch_backend) @@ -204,6 +197,11 @@ def get_args(): args.rank = args.distctx.rank args.numranks = args.distctx.numranks + # some default/dummy values for the tokenizer + args.make_vocab_size_divisible_by = 128 + args.tensor_model_parallel_size = 1 + args.vocab_extra_ids = 0 + if args.tokenizer_type.lower().startswith('bert'): if not args.split_sentences: if args.rank == 0: From fa111591879fadedeeb3fec985418e03499eadb0 Mon Sep 17 00:00:00 2001 From: Adam Moody Date: Fri, 20 Aug 2021 13:35:50 -0700 Subject: [PATCH 63/74] refactor get_pointers code --- megatron/data/indexed_dataset.py | 50 ++++++++++++++++++++------------ 1 file changed, 32 insertions(+), 18 deletions(-) diff --git a/megatron/data/indexed_dataset.py b/megatron/data/indexed_dataset.py index 9b91aaaf0..22eabd18a 100644 --- a/megatron/data/indexed_dataset.py +++ b/megatron/data/indexed_dataset.py @@ -352,6 +352,26 @@ def _warmup_mmap_file(path): pass +def get_pointers(sizes, elemsize, dtype): + """Return a numpy array of type dtype giving the inclusive scan of byte sizes. + + Multiplies values in the sizes array by elemsize (bytes), + and then computes an inclusive scan to get cumulative byte sizes. 
+ """ + pointers = np.array(sizes, dtype=dtype) + pointers *= elemsize + np.cumsum(pointers, axis=0, out=pointers) + return pointers + +def exscan_from_cumsum_(arr): + # given an array holding the result of an inclusive scan (cumsum), + # convert to an exclusive scan (shift to the right) + # [10, 30, 35, 50] --> [0, 10, 30, 35] + if arr.size > 0: + arr[1:] = arr[:-1] + arr[0] = 0 + + class MMapIndexedDataset(torch.utils.data.Dataset): class Index(object): _HDR_MAGIC = b'MMIDIDX\x00\x00' @@ -386,12 +406,10 @@ def _get_pointers(sizes, npdtype): """ # compute element sizes in bytes - bytesizes = np.array(sizes, dtype=npdtype) - bytesizes *= dtype().itemsize - - # exclusive scan to get byte offsets - pointers = np.cumsum(bytesizes, axis=0) - pointers -= bytesizes + pointers = np.array(sizes, dtype=npdtype) + pointers *= dtype().itemsize + np.cumsum(pointers, axis=0, out=pointers) + exscan_from_cumsum_(pointers) return pointers @@ -402,7 +420,8 @@ def write(self, sizes, doc_idx): self._file.write(sizes32.tobytes(order='C')) del sizes32 - pointers = self._get_pointers(sizes, np.int64) + pointers = get_pointers(sizes, dtype().itemsize, np.int64) + exscan_from_cumsum_(pointers) self._file.write(pointers.tobytes(order='C')) del pointers @@ -910,22 +929,17 @@ def gather_files_dist_idx_mmap(outfile, filelist, distctx): # dtype().itemsize bytes (often a standard type that is just # large enough to represent all elements of the vocabulary). - # Compute byte sizes for each of our sentences given + # Compute cumulative byte sizes for each of our sentences given # the token count and vocab dtype. - bytesizes = np.array(sizes, dtype=np.int64) - bytesizes *= dtype().itemsize - - # Inclusive scan to sum number of bytes over sentences. - pointers = np.cumsum(bytesizes, axis=0) + pointers = get_pointers(sizes, dtype().itemsize, np.int64) - # Account for bytes for all sentences on ranks + # Determine total number of bytes for all sentences on ranks # before the calling rank. bytes_last = pointers[-1] if len(sizes) > 0 else 0 pointer_offset = distctx.exscan(bytes_last) - pointers += pointer_offset - # Convert to exclusive scan to get global offset. - pointers -= bytesizes + # Convert to (local) byte offsets. + exscan_from_cumsum_(pointers) # Since the pointers array is the same length as the sizes array, # we use global_size_offset and global_size_count to position @@ -933,7 +947,7 @@ def gather_files_dist_idx_mmap(outfile, filelist, distctx): # Seek to proper offset for this rank and write # pointer values into file, stored as int64. 
- write_list_at_offset(fout, file_offset, pointers, 0, global_size_offset, np.int64) + write_list_at_offset(fout, file_offset, pointers, pointer_offset, global_size_offset, np.int64) file_offset += global_size_count * np.int64().itemsize # The document index points to the position in the sizes From 7e53fd34b60365c2b0cf8cd3621b2ba8d8e81540 Mon Sep 17 00:00:00 2001 From: Adam Moody Date: Fri, 20 Aug 2021 13:49:42 -0700 Subject: [PATCH 64/74] bug fix in exscan --- megatron/data/indexed_dataset.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/megatron/data/indexed_dataset.py b/megatron/data/indexed_dataset.py index 22eabd18a..47740e94c 100644 --- a/megatron/data/indexed_dataset.py +++ b/megatron/data/indexed_dataset.py @@ -367,8 +367,9 @@ def exscan_from_cumsum_(arr): # given an array holding the result of an inclusive scan (cumsum), # convert to an exclusive scan (shift to the right) # [10, 30, 35, 50] --> [0, 10, 30, 35] - if arr.size > 0: + if arr.size > 1: arr[1:] = arr[:-1] + if arr.size > 0: arr[0] = 0 From 53df36f49a48a9fc76a0ea0cb6e5a1b58caacfd0 Mon Sep 17 00:00:00 2001 From: Adam Moody Date: Fri, 20 Aug 2021 14:05:54 -0700 Subject: [PATCH 65/74] further refactor get_pointers --- megatron/data/indexed_dataset.py | 58 +++++++++++++++++--------------- 1 file changed, 30 insertions(+), 28 deletions(-) diff --git a/megatron/data/indexed_dataset.py b/megatron/data/indexed_dataset.py index 47740e94c..b1407771c 100644 --- a/megatron/data/indexed_dataset.py +++ b/megatron/data/indexed_dataset.py @@ -352,17 +352,6 @@ def _warmup_mmap_file(path): pass -def get_pointers(sizes, elemsize, dtype): - """Return a numpy array of type dtype giving the inclusive scan of byte sizes. - - Multiplies values in the sizes array by elemsize (bytes), - and then computes an inclusive scan to get cumulative byte sizes. - """ - pointers = np.array(sizes, dtype=dtype) - pointers *= elemsize - np.cumsum(pointers, axis=0, out=pointers) - return pointers - def exscan_from_cumsum_(arr): # given an array holding the result of an inclusive scan (cumsum), # convert to an exclusive scan (shift to the right) @@ -373,6 +362,28 @@ def exscan_from_cumsum_(arr): arr[0] = 0 +def get_pointers_with_total(sizes, elemsize, dtype): + """Return a numpy array of type np.dtype giving the byte offsets. + + Multiplies values in the sizes array by elemsize (bytes), + and then computes an exclusive scan to get byte offsets. + Returns the total number of bytes as second item in a tuple. 
+ """ + + # scale values in sizes array by elemsize to get sizes in bytes + pointers = np.array(sizes, dtype=dtype) + pointers *= elemsize + np.cumsum(pointers, axis=0, out=pointers) + + # get total number of bytes from all sizes (last element) + bytes_last = pointers[-1] if len(sizes) > 0 else 0 + + # convert to byte offsets + exscan_from_cumsum_(pointers) + + return pointers, bytes_last + + class MMapIndexedDataset(torch.utils.data.Dataset): class Index(object): _HDR_MAGIC = b'MMIDIDX\x00\x00' @@ -407,11 +418,7 @@ def _get_pointers(sizes, npdtype): """ # compute element sizes in bytes - pointers = np.array(sizes, dtype=npdtype) - pointers *= dtype().itemsize - np.cumsum(pointers, axis=0, out=pointers) - exscan_from_cumsum_(pointers) - + pointers, _ = get_pointers_with_total(sizes, dtype().itemsize, npdtype) return pointers def write(self, sizes, doc_idx): @@ -421,8 +428,7 @@ def write(self, sizes, doc_idx): self._file.write(sizes32.tobytes(order='C')) del sizes32 - pointers = get_pointers(sizes, dtype().itemsize, np.int64) - exscan_from_cumsum_(pointers) + pointers = self._get_pointers(sizes, np.int64) self._file.write(pointers.tobytes(order='C')) del pointers @@ -930,21 +936,17 @@ def gather_files_dist_idx_mmap(outfile, filelist, distctx): # dtype().itemsize bytes (often a standard type that is just # large enough to represent all elements of the vocabulary). + # Since the pointers array is the same length as the sizes array, + # we use global_size_offset and global_size_count to position + # within the file for writing the pointer values. + # Compute cumulative byte sizes for each of our sentences given # the token count and vocab dtype. - pointers = get_pointers(sizes, dtype().itemsize, np.int64) + pointers, pointers_bytes = get_pointers_with_total(sizes, dtype().itemsize, np.int64) # Determine total number of bytes for all sentences on ranks # before the calling rank. - bytes_last = pointers[-1] if len(sizes) > 0 else 0 - pointer_offset = distctx.exscan(bytes_last) - - # Convert to (local) byte offsets. - exscan_from_cumsum_(pointers) - - # Since the pointers array is the same length as the sizes array, - # we use global_size_offset and global_size_count to position - # within the file for writing the pointer values. + pointer_offset = distctx.exscan(pointers_bytes) # Seek to proper offset for this rank and write # pointer values into file, stored as int64. From c43348ff08ed6bcc95eb2b82a5c23142e71bd5ce Mon Sep 17 00:00:00 2001 From: Adam Moody Date: Fri, 20 Aug 2021 14:52:19 -0700 Subject: [PATCH 66/74] move exscan collective for pointers outside of try block --- megatron/data/indexed_dataset.py | 20 ++++++++------------ 1 file changed, 8 insertions(+), 12 deletions(-) diff --git a/megatron/data/indexed_dataset.py b/megatron/data/indexed_dataset.py index b1407771c..5b58658b0 100644 --- a/megatron/data/indexed_dataset.py +++ b/megatron/data/indexed_dataset.py @@ -906,6 +906,14 @@ def gather_files_dist_idx_mmap(outfile, filelist, distctx): global_size_offset = distctx.exscan(numsizes) global_docs_offset = distctx.exscan(numdocs) + # Compute local byte offsets for each of our sentences given + # the token count and byte size of the vocab dtype. + pointers, pointers_bytes = get_pointers_with_total(sizes, dtype().itemsize, np.int64) + + # Determine total number of bytes for all sentences on ranks + # before the calling rank. + pointer_offset = distctx.exscan(pointers_bytes) + # Catch and I/O errors to later determine whether all ranks wrote successfully. 
err = None try: @@ -935,21 +943,9 @@ def gather_files_dist_idx_mmap(outfile, filelist, distctx): # of a sentence is stored in units of type dtype, which consumes # dtype().itemsize bytes (often a standard type that is just # large enough to represent all elements of the vocabulary). - # Since the pointers array is the same length as the sizes array, # we use global_size_offset and global_size_count to position # within the file for writing the pointer values. - - # Compute cumulative byte sizes for each of our sentences given - # the token count and vocab dtype. - pointers, pointers_bytes = get_pointers_with_total(sizes, dtype().itemsize, np.int64) - - # Determine total number of bytes for all sentences on ranks - # before the calling rank. - pointer_offset = distctx.exscan(pointers_bytes) - - # Seek to proper offset for this rank and write - # pointer values into file, stored as int64. write_list_at_offset(fout, file_offset, pointers, pointer_offset, global_size_offset, np.int64) file_offset += global_size_count * np.int64().itemsize From 81c21dd5552556471968f2abed408bc73f9b97ea Mon Sep 17 00:00:00 2001 From: Adam Moody Date: Fri, 20 Aug 2021 15:21:08 -0700 Subject: [PATCH 67/74] clarify some comments --- megatron/data/indexed_dataset.py | 13 +++++++------ tools/preprocess_data_dist.py | 21 +++++++++++---------- 2 files changed, 18 insertions(+), 16 deletions(-) diff --git a/megatron/data/indexed_dataset.py b/megatron/data/indexed_dataset.py index 5b58658b0..972b3a652 100644 --- a/megatron/data/indexed_dataset.py +++ b/megatron/data/indexed_dataset.py @@ -672,6 +672,7 @@ def gather_files_dist_bin(outfile, filelist, distctx): for f in filelist: with open(data_file_path(f), "rb") as fsrc: shutil.copyfileobj(fsrc, fout) + except Exception as e: err = e @@ -686,17 +687,17 @@ def gather_files_dist_bin(outfile, filelist, distctx): def write_list_at_offset(fout, file_offset, vals, shift, elem_offset, dtype): - """Write list of values to fout. + """Write list of vals to fout starting at an offset given by file_offset, elem_offset, and dtype. Copies list of values in vals to a numpy array of type dtype. - Adds a constant value to all elements as given in shift. + Adds a constant shift value to all elements. Writes the numpy array to the file handle at given offset and scaled by size of the datatype. - byteoffset = pos + vals * dtype().itemsize + offset = file_offset + elem_offset * dtype().itemsize Parameters ---------- fout : file handle - Opened file handle to which to write vals + Open file handle to which to write list of vals file_offset : int Byte offset within the file where the global list starts vals : list[int] @@ -705,7 +706,7 @@ def write_list_at_offset(fout, file_offset, vals, shift, elem_offset, dtype): Value to add to each element in vals before writing (use 0 for no change) elem_offset : int Zero-based element index where vals starts within the global list. - This value will be scaled by dtype().itemsize to convert to the corresponding number of bytes. + This value is scaled by dtype().itemsize to convert to a corresponding byte offset. dtype : np.dtype numpy datatype to be used when writing the list to the file """ @@ -728,7 +729,7 @@ def gather_files_dist_check_dtype(filelist, dtype_rank_consistent, dtype_code, d distctx.allassert(dtype_rank_consistent, "Some rank found inconsistent dtype values") # Verify that at least one rank found a dtype value. - # Because of the bcast, the the value of first_dtype_code will be the same on all ranks. 
+ # Because of the bcast, the the value of first_dtype_code is the same on all ranks. first_dtype_code = distctx.bcast_first(dtype_code) assert first_dtype_code is not None, "Failed to find a dtype value in any index file" diff --git a/tools/preprocess_data_dist.py b/tools/preprocess_data_dist.py index 0c37e2bc6..a3e697a00 100644 --- a/tools/preprocess_data_dist.py +++ b/tools/preprocess_data_dist.py @@ -178,7 +178,7 @@ def get_args(): 'With "parallel", each rank writes directly to the final files, ' 'while rank 0 copies data from all per-rank files with "serial". ' 'A parallel merge can be faster, but for correctness, it requires the underlying file system ' - 'to support parallel write operations to a file shared among multiple processes. ' + 'to support parallel write operations to a file that is shared among multiple processes. ' 'One can choose "both" for testing purposes, in which case the final files written ' 'by the parallel method are given an additional ".par" extension.')) group.add_argument('--scratch', type=str, default=None, @@ -485,23 +485,24 @@ def rank_files_merge_parallel(args): merge_start = time.time() numbytes = np.zeros(1, dtype=np.int64) for key in args.columns: + # merge the per-rank file from each process into a single shared file filemain = get_filename(args, key) filerank = get_filename(args, key, args.rank) gather_files_dist(filemain, [filerank], args.distctx) - # total up bytes read in merge - binfile = data_file_path(filerank) - idxfile = index_file_path(filerank) - numbytes[0] += os.stat(binfile)[stat.ST_SIZE] - numbytes[0] += os.stat(idxfile)[stat.ST_SIZE] + # total up bytes read during the merge + binfilerank = data_file_path(filerank) + idxfilerank = index_file_path(filerank) + numbytes[0] += os.stat(binfilerank)[stat.ST_SIZE] + numbytes[0] += os.stat(idxfilerank)[stat.ST_SIZE] # If user want to use both a parallel and serial merge (for testing), # rename the parallel output files so that the serial merge does not clobber them. if args.merge == 'both' and args.rank == 0: - binfile = data_file_path(filemain) - idxfile = index_file_path(filemain) - os.rename(binfile, binfile + ".par") - os.rename(idxfile, idxfile + ".par") + binfilemain = data_file_path(filemain) + idxfilemain = index_file_path(filemain) + os.rename(binfilemain, binfilemain + ".par") + os.rename(idxfilemain, idxfilemain + ".par") # Total up number of bytes read across all ranks, # and wait on all ranks before stopping the timer. 
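
For reference, the write_list_at_offset docstring clarified above pins down exactly where each rank lands in the shared file: the byte position is file_offset + elem_offset * dtype().itemsize, and the constant shift is applied to the values before they are written. A rough standalone sketch of that behavior (the helper name write_list_at_offset_sketch is hypothetical and illustrative, not code taken from this series):

    import numpy as np

    def write_list_at_offset_sketch(fout, file_offset, vals, shift, elem_offset, dtype):
        # Copy the values into an array of the on-disk dtype and apply the constant shift.
        arr = np.array(vals, dtype=dtype)
        if shift:
            arr += shift
        # Element index -> byte offset: each element occupies dtype().itemsize bytes.
        fout.seek(file_offset + elem_offset * dtype().itemsize)
        fout.write(arr.tobytes(order='C'))

Because ranks write disjoint element ranges of one global list, each rank only needs its own values plus the element offset obtained from the exclusive scan; no data moves between ranks during the write itself.
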
From adee502b6957647f95e48348bc75a8e524847e81 Mon Sep 17 00:00:00 2001 From: Adam Moody Date: Fri, 20 Aug 2021 15:29:02 -0700 Subject: [PATCH 68/74] include string 1k in name of test files --- tests/test_preprocessing.py | 24 ++++++++++-------------- 1 file changed, 10 insertions(+), 14 deletions(-) diff --git a/tests/test_preprocessing.py b/tests/test_preprocessing.py index 1f9454574..3723f97bc 100644 --- a/tests/test_preprocessing.py +++ b/tests/test_preprocessing.py @@ -78,6 +78,13 @@ def test_preprocess_data(self): tgt_path = f"{output_prefix}_text_document.{ext}" self.assertTrue(Path(tgt_path).exists(), ) + def compare_meg_data_files(self, tgt, ref): + for ext in ["bin", "idx"]: + tgt_path = f"{tgt}.{ext}" + ref_path = f"{ref}.{ext}" + self.assertTrue(Path(tgt_path).exists(), ) + self.assertTrue(filecmp.cmp(tgt_path, ref_path, shallow=False)) + def test_process_data_microsoft(self): """We want to be stable to Microsoft version.""" src_dir = self.src_dir @@ -104,18 +111,7 @@ def test_process_data_microsoft(self): # print(" ".join([f"\nPYTHONPATH={self.src_dir_str}"] +cmd)); die execute_subprocess_async(cmd, env=self.get_env()) - for ext in ["bin", "idx"]: - tgt_path = f"{output_prefix}_text_document.{ext}" - ref_path = f"{data_dir}/meg-gpt2-openwebtext_text_document.{ext}" - self.assertTrue(Path(tgt_path).exists(), ) - self.assertTrue(filecmp.cmp(tgt_path, ref_path, shallow=False)) - - def compare_meg_data_files(self, tgt, ref): - for ext in ["bin", "idx"]: - tgt_path = f"{tgt}.{ext}" - ref_path = f"{ref}.{ext}" - self.assertTrue(Path(tgt_path).exists(), ) - self.assertTrue(filecmp.cmp(tgt_path, ref_path, shallow=False)) + self.compare_meg_data_files(f"{output_prefix}_text_document", f"{data_dir}/meg-gpt2-openwebtext_text_document") def test_process_data_dist_microsoft(self): """We want to be stable to Microsoft version.""" @@ -123,7 +119,7 @@ def test_process_data_dist_microsoft(self): data_dir = f"{self.data_dir}/gpt2" output_dir = self.get_auto_remove_tmp_dir() # "./xxx", after=False) - output_prefix = f"{output_dir}/test-ds-meg-gpt2-openwebtext" + output_prefix = f"{output_dir}/test-ds-meg-gpt2-openwebtext_1k" cmd = f""" python -m torch.distributed.launch --nproc_per_node 2 {src_dir}/tools/preprocess_data_dist.py @@ -149,7 +145,7 @@ def test_process_data_dist_serial_microsoft(self): data_dir = f"{self.data_dir}/gpt2" output_dir = self.get_auto_remove_tmp_dir() # "./xxx", after=False) - output_prefix = f"{output_dir}/test-ds-meg-gpt2-openwebtext" + output_prefix = f"{output_dir}/test-ds-meg-gpt2-openwebtext_1k" cmd = f""" python -m torch.distributed.launch --nproc_per_node 2 {src_dir}/tools/preprocess_data_dist.py From 13ae421db1f06c59fb52fd8399e7ba6d92f86b0b Mon Sep 17 00:00:00 2001 From: Adam Moody Date: Fri, 20 Aug 2021 16:12:02 -0700 Subject: [PATCH 69/74] use temporary file for index --- megatron/data/indexed_dataset.py | 36 +++++++++++++++++++++++++++----- 1 file changed, 31 insertions(+), 5 deletions(-) diff --git a/megatron/data/indexed_dataset.py b/megatron/data/indexed_dataset.py index 972b3a652..025e9e333 100644 --- a/megatron/data/indexed_dataset.py +++ b/megatron/data/indexed_dataset.py @@ -655,7 +655,7 @@ def gather_files_dist_bin(outfile, filelist, distctx): # if successful or delete the temporary file if not. # This way if the final name appears, the user knows it's a valid file. 
finalname = data_file_path(outfile) - finaltmp = finalname + ".tmp" + finalnametmp = finalname + ".tmp" # First delete the final file if it already exists distctx.remove(finalname) @@ -664,7 +664,7 @@ def gather_files_dist_bin(outfile, filelist, distctx): err = None try: # Create shared output file and pre-truncate to its final size. - with distctx.open(finaltmp, truncate=count) as fout: + with distctx.open(finalnametmp, truncate=count) as fout: # Seek to appropriate starting offset in the merged file. fout.seek(offset) @@ -683,7 +683,7 @@ def gather_files_dist_bin(outfile, filelist, distctx): # Everyone wrote their part successfully. # Rename the temporary file to the final file. - distctx.rename(finaltmp, finalname) + distctx.rename(finalnametmp, finalname) def write_list_at_offset(fout, file_offset, vals, shift, elem_offset, dtype): @@ -803,11 +803,20 @@ def gather_files_dist_idx_cached(outfile, filelist, distctx): global_dim_offset = distctx.exscan(numdim) global_doc_offset = distctx.exscan(numdoc) + # We first write to a temporary file name. We rename to the final name + # if successful or delete the temporary file if not. + # This way if the final name appears, the user knows it's a valid file. + finalname = index_file_path(outfile) + finalnametmp = finalname + ".tmp" + + # First delete the final file if it already exists + distctx.remove(finalname) + # Catch and I/O errors to later determine whether all ranks wrote successfully. err = None try: # Create shared output file - with distctx.open(index_file_path(outfile)) as fout: + with distctx.open(finalnametmp) as fout: # Have rank 0 write the file header file_offset = 0 if distctx.rank == 0: @@ -861,6 +870,10 @@ def gather_files_dist_idx_cached(outfile, filelist, distctx): # Check that all ranks wrote successfully distctx.allraise_if(err) + # Everyone wrote their part successfully. + # Rename the temporary file to the final file. + distctx.rename(finalnametmp, finalname) + def gather_files_dist_idx_mmap(outfile, filelist, distctx): # Read each index file and append items to the size and doc_idx lists @@ -915,11 +928,20 @@ def gather_files_dist_idx_mmap(outfile, filelist, distctx): # before the calling rank. pointer_offset = distctx.exscan(pointers_bytes) + # We first write to a temporary file name. We rename to the final name + # if successful or delete the temporary file if not. + # This way if the final name appears, the user knows it's a valid file. + finalname = index_file_path(outfile) + finalnametmp = finalname + ".tmp" + + # First delete the final file if it already exists + distctx.remove(finalname) + # Catch and I/O errors to later determine whether all ranks wrote successfully. err = None try: # Create shared output file - with distctx.open(index_file_path(outfile)) as fout: + with distctx.open(finalnametmp) as fout: # Have rank 0 write the file header file_offset = 0 if distctx.rank == 0: @@ -965,6 +987,10 @@ def gather_files_dist_idx_mmap(outfile, filelist, distctx): # Check that all ranks wrote successfully distctx.allraise_if(err) + # Everyone wrote their part successfully. + # Rename the temporary file to the final file. + distctx.rename(finalnametmp, finalname) + # Verify that all files in filelist are of the same index type. # Returns the identified type {cached, mmap} as a string. 
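
The temporary-file pattern applied above (write to "<name>.tmp", rename to the final name only after the error check passes) is what guarantees that a file carrying the final name is always complete. A minimal single-process sketch of the idea, with plain os calls standing in for the distctx.remove/open/rename collectives and with a hypothetical write_body callback supplying the contents:

    import os

    def publish_file_sketch(finalname, write_body):
        tmpname = finalname + ".tmp"
        # Drop any stale copy of the final file before starting.
        if os.path.exists(finalname):
            os.remove(finalname)
        try:
            with open(tmpname, "wb") as fout:
                write_body(fout)          # caller writes the actual contents
        except Exception:
            # Discard the partial temporary file so it is never mistaken for valid output.
            if os.path.exists(tmpname):
                os.remove(tmpname)
            raise
        # Publish the finished file under its final name.
        os.rename(tmpname, finalname)

In the distributed version the rename only happens after distctx.allraise_if confirms that every rank wrote its part successfully, so no rank publishes the file while another rank still holds an error.
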
From f3e1b1dce3b00e5ca46a095f9b4e637167505c5d Mon Sep 17 00:00:00 2001 From: Adam Moody Date: Mon, 23 Aug 2021 15:51:45 -0700 Subject: [PATCH 70/74] fix: implement scatterv from torch.distributed.scatter --- megatron/data/distdata.py | 23 +++++++++++++++-------- 1 file changed, 15 insertions(+), 8 deletions(-) diff --git a/megatron/data/distdata.py b/megatron/data/distdata.py index 155a4b658..6d9c1cff9 100644 --- a/megatron/data/distdata.py +++ b/megatron/data/distdata.py @@ -73,16 +73,23 @@ def scatterv_(self, invals: np.array, counts: list, outval: np.array, root:int=0 self.allassert(outval.dtype == np.int64, f"Requires outval to be of type numpy.int64") - # Note that we create the torch.tensor outtensor to receive incoming - # data by using from_numpy. This sets outtensor to point to the same - # underlying memory of the outval numpy array. Receiving data into - # outtensor thus writes incoming data directly into the numpy array. - + # Define list of tensors to scatter on the root. + # torch.distributed.scatter requires each tensor to be the same shape, + # so find the max size across all count values and pad. scatterlist = None if self.rank == root: - scatterlist = list(torch.split(torch.from_numpy(invals), counts)) - outtensor = torch.from_numpy(outval) - dist.scatter(outtensor, scatterlist, src=root) + scatterlist = [] + slices = list(torch.split(torch.from_numpy(invals), counts)) + for num, s in zip(counts, slices): + padtensor = torch.zeros(max(counts), dtype=torch.int64) + padtensor[:num] = s + scatterlist.append(padtensor) + + # Receive a tensor of the max count size from the root, + # then copy values into output numpy array, which may be smaller. + recvtensor = torch.zeros(max(counts), dtype=torch.int64) + dist.scatter(recvtensor, scatterlist, src=root) + outval[:] = recvtensor[:counts[self.rank]] def alltrue(self, val): """Returns True if all procs input True, False otherwise""" From 42962e1b5df1a12d414fd7efb5699214ee8d5439 Mon Sep 17 00:00:00 2001 From: Adam Moody Date: Tue, 24 Aug 2021 20:38:48 -0700 Subject: [PATCH 71/74] switch to pad method in torch.nn.functional --- megatron/data/distdata.py | 10 ++++------ 1 file changed, 4 insertions(+), 6 deletions(-) diff --git a/megatron/data/distdata.py b/megatron/data/distdata.py index 6d9c1cff9..5770733ce 100644 --- a/megatron/data/distdata.py +++ b/megatron/data/distdata.py @@ -2,6 +2,7 @@ import numpy as np import torch +import torch.nn.functional as F import torch.distributed as dist class DistDataError(Exception): @@ -76,18 +77,15 @@ def scatterv_(self, invals: np.array, counts: list, outval: np.array, root:int=0 # Define list of tensors to scatter on the root. # torch.distributed.scatter requires each tensor to be the same shape, # so find the max size across all count values and pad. + max_size = max(counts) scatterlist = None if self.rank == root: - scatterlist = [] slices = list(torch.split(torch.from_numpy(invals), counts)) - for num, s in zip(counts, slices): - padtensor = torch.zeros(max(counts), dtype=torch.int64) - padtensor[:num] = s - scatterlist.append(padtensor) + scatterlist = [F.pad(s, (0, max_size - len(s))) for s in slices] # Receive a tensor of the max count size from the root, # then copy values into output numpy array, which may be smaller. 
- recvtensor = torch.zeros(max(counts), dtype=torch.int64) + recvtensor = torch.zeros(max_size, dtype=torch.int64) dist.scatter(recvtensor, scatterlist, src=root) outval[:] = recvtensor[:counts[self.rank]] From 9a2f3838cea7ba9ef13e767aa6b254976a8114d3 Mon Sep 17 00:00:00 2001 From: Adam Moody Date: Tue, 24 Aug 2021 21:05:12 -0700 Subject: [PATCH 72/74] return data received in scatterv as new tensor --- megatron/data/distdata.py | 12 +++--------- tools/preprocess_data_dist.py | 5 +---- 2 files changed, 4 insertions(+), 13 deletions(-) diff --git a/megatron/data/distdata.py b/megatron/data/distdata.py index 5770733ce..659d51bb9 100644 --- a/megatron/data/distdata.py +++ b/megatron/data/distdata.py @@ -62,18 +62,12 @@ def bcast(self, val, root): dist.broadcast_object_list(vals, src=root) return vals[0] - def scatterv_(self, invals: np.array, counts: list, outval: np.array, root:int=0): - """Scatter int64 values from invals according to counts array, receive values in outval""" + def scatterv_(self, invals: np.array, counts: list, root:int=0): + """Scatter int64 values from invals according to counts array, return received portion in a new tensor""" self.allassert(len(counts) == self.numranks, f"Length of counts list {len(counts)} does not match number of ranks {self.numranks}") - self.allassert(outval.shape == (counts[self.rank],), - f"Rank {self.rank}: output buffer is of shape {outval.shape}, expected {(counts[self.rank],)}") - - self.allassert(outval.dtype == np.int64, - f"Requires outval to be of type numpy.int64") - # Define list of tensors to scatter on the root. # torch.distributed.scatter requires each tensor to be the same shape, # so find the max size across all count values and pad. @@ -87,7 +81,7 @@ def scatterv_(self, invals: np.array, counts: list, outval: np.array, root:int=0 # then copy values into output numpy array, which may be smaller. recvtensor = torch.zeros(max_size, dtype=torch.int64) dist.scatter(recvtensor, scatterlist, src=root) - outval[:] = recvtensor[:counts[self.rank]] + return recvtensor[:counts[self.rank]] def alltrue(self, val): """Returns True if all procs input True, False otherwise""" diff --git a/tools/preprocess_data_dist.py b/tools/preprocess_data_dist.py index a3e697a00..5050177b6 100644 --- a/tools/preprocess_data_dist.py +++ b/tools/preprocess_data_dist.py @@ -328,13 +328,10 @@ def select_sample_list(args, dset_size): # get a list of the number of elements each rank will hold counts = get_proc_counts(num_samples, args.numranks) - # allocate space to hold its portion of the list - idx = np.zeros(counts[args.rank], np.int64) - # scatter sample index values from rank 0 to all procs # based on distribution defined in counts list time_bcast = time.time() - args.distctx.scatterv_(idxlist, counts, idx, root=0) + idx = args.distctx.scatterv_(idxlist, counts, root=0) args.distctx.barrier() time_end = time.time() From 15b7603ab0099911acb970f38071d1d98ba7d658 Mon Sep 17 00:00:00 2001 From: Adam Moody Date: Tue, 24 Aug 2021 21:39:01 -0700 Subject: [PATCH 73/74] raise exception if conflicting scratch and merge options --- tools/preprocess_data_dist.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tools/preprocess_data_dist.py b/tools/preprocess_data_dist.py index 5050177b6..a1334cd1b 100644 --- a/tools/preprocess_data_dist.py +++ b/tools/preprocess_data_dist.py @@ -214,7 +214,7 @@ def get_args(): # TODO: perhaps more user friendly to disable scratch and print a warning? 
# check that serial merge is not attempted with scratch if args.scratch is not None and args.merge != 'parallel': - assert False, "The --scratch option is only valid with --merge=parallel" + raise ValueError("The --scratch option is only valid with --merge=parallel") return args From 4adaddd7660b586497a362ba3dffd1afa9dcc2f9 Mon Sep 17 00:00:00 2001 From: Adam Moody Date: Tue, 24 Aug 2021 22:00:02 -0700 Subject: [PATCH 74/74] use allraise method from distdata in preprocess_data_dist --- tools/preprocess_data_dist.py | 51 ++++++++++++----------------------- 1 file changed, 17 insertions(+), 34 deletions(-) diff --git a/tools/preprocess_data_dist.py b/tools/preprocess_data_dist.py index a1334cd1b..a2c9c0970 100644 --- a/tools/preprocess_data_dist.py +++ b/tools/preprocess_data_dist.py @@ -245,7 +245,6 @@ def load_dset(args): # Load the specified HuggingFace dataset. # Give rank 0 a head start in case the dataset is not already cached. - success = True err = None dsetname = args.input if args.rank == 0: @@ -258,17 +257,13 @@ def load_dset(args): msgerr(f" from datasets import load_dataset") msgerr(f" dset = load_dataset('{dsetname}')") msgerr(f"Alternatively, one can force this script to download by setting $HF_DATASETS_OFFLINE=0", flush=True) - success = False err = e except Exception as e: - msgerr("Unexpected error:", sys.exc_info()[0], flush=True) - success = False + msgerr(f"Unexpected error: {sys.exc_info()[0]}", flush=True) err = e # determine whether rank 0 succeeded in loading the dataset - success = args.distctx.alltrue(success) - if not success: - return None, err + args.distctx.allraise_if(err) # Rank 0 succeeded, attempt to load dataset on all other ranks. # This should load from cache now. @@ -277,22 +272,17 @@ def load_dset(args): dset = load_dataset(dsetname, split=args.split, keep_in_memory=None) except Exception as e: # this print might be noisy, but better than nothing - msgerr("Unexpected error:", sys.exc_info()[0], flush=True) - success = False + msgerr(f"Unexpected error: {sys.exc_info()[0]}", flush=True) err = e # verify that all ranks loaded the dataset - success = args.distctx.alltrue(success) - if not success: - if args.rank == 0: - msgerr(f"At least one process failed to load {dsetname}", flush=True) - return None, err + args.distctx.allraise_if(err) time_end = time.time() if args.rank == 0: msg(f"Seconds to load dataset: {time_end - time_start}", flush=True) - return dset, err + return dset def get_num_samples(args, dset_size): """Given a dataset size and optional count argument, return number of samples to process.""" @@ -374,7 +364,6 @@ def rank_files_write(args, dset, idx, encoder): dset_stats = np.zeros(3, dtype=np.int64) # docs, sentences, bytes # we'll set this to false on any problem - success = True err = None times = np.zeros(3, dtype=np.float32) # read, tokenize, write try: @@ -441,7 +430,6 @@ def rank_files_write(args, dset, idx, encoder): del builders[key] # file closed in __del__ except Exception as e: # caught an exception, assume our file is invalid - success = False err = e # In case rank 0 finishes early and stops printing progress messages, @@ -473,9 +461,8 @@ def rank_files_write(args, dset, idx, encoder): msg(f" Total encode seconds {times[1]}, {secs_encode_per_sample} sec/sample") msg(f" Total write seconds {times[2]}, {secs_write_per_sample} sec/sample") - # allreduce to check whether all ranks wrote their part successfully - success = args.distctx.alltrue(success) - return success, err + # check whether all ranks wrote their part successfully 
+ args.distctx.allraise_if(err) def rank_files_merge_parallel(args): """Each process directly writes its portion of the data from its per-rank file into the final file.""" @@ -599,11 +586,7 @@ def main(): startup_start = time.time() # load the dataset - dset, err = load_dset(args) - if dset is None: - if err is not None: - raise err - return + dset = load_dset(args) if args.rank == 0: print(dset) msg(f"Processing features: {args.columns}") @@ -625,21 +608,21 @@ def main(): if args.rank == 0: msg(f"Seconds to startup: {startup_end - startup_start}") - # have each rank write its file, returns False if any rank had a problem - success, err = rank_files_write(args, dset, idx, encoder) - if not success: + # have each rank write its file, + # all ranks should raise an exception if any rank has a problem + try: + rank_files_write(args, dset, idx, encoder) + except Exception as e: + # If any process fails, we skip the merge since the resulting file would be invalid. + # We still delete files to clean up, since those might be invalid anyway. if args.rank == 0: - # If any process fails, we skip the merge since the resulting file would be invalid. - # We still delete files to clean up, since those might be invalid anyway. msgerr(f"At least one process failed to write its file, skipping merge and cleaning up", flush=True) # delete per-rank files, do this even on error rank_files_delete(args) - # raise exception caught during write phase - if err is not None: - raise err - return + # re-raise exception caught during write phase + raise e # all ranks were successful writing their file, merge them into one rank_files_merge(args)
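
With this final patch the write phase no longer threads success flags back through return values: rank_files_write raises through distctx.allraise_if(err), the failure becomes collective, and every rank still reaches the per-rank file cleanup in main() before the exception propagates. As a rough sketch of what such a collective error check could look like (this is an assumption for illustration, not the DistData code from these patches), assuming an initialized torch.distributed process group:

    import torch
    import torch.distributed as dist

    def allraise_if_sketch(err):
        # Reduce a per-rank failure flag so every rank learns whether any rank failed.
        flag = torch.tensor([1 if err is not None else 0], dtype=torch.int64)
        dist.all_reduce(flag, op=dist.ReduceOp.SUM)
        if flag.item() > 0:
            if err is not None:
                raise err  # re-raise the local exception on the rank that hit it
            # Ranks that succeeded raise a generic error so they unwind too.
            raise RuntimeError("another rank failed during a collective step")

The real method may differ in its error type, message, or choice of collective, but the pattern is the one the merge and write paths rely on: detect failure everywhere, clean up the per-rank files, then re-raise.
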