From a5b01fbffe1dd831331182ba86ba61eab6c5653c Mon Sep 17 00:00:00 2001 From: Adam Moody Date: Thu, 12 Aug 2021 11:38:25 -0700 Subject: [PATCH 1/6] shuffle index list with numpy, scatter list, use file for large lists --- tools/preprocess_dataset_mpi.py | 159 ++++++++++++++++---------------- 1 file changed, 82 insertions(+), 77 deletions(-) diff --git a/tools/preprocess_dataset_mpi.py b/tools/preprocess_dataset_mpi.py index b80ad4a25..495f9e28d 100644 --- a/tools/preprocess_dataset_mpi.py +++ b/tools/preprocess_dataset_mpi.py @@ -219,26 +219,21 @@ def barrier(args): else: dist.barrier() -def bcast(args, vals, root=0): - """Broadcast list of vals from root to all ranks, returns newly allocated list""" +def scatterv_(args, invals, counts, outval, root=0): + """Scatter values from invals according to counts array, receive values in outval""" if args.use_mpi: - vals = args.mpi_comm.bcast(vals, root=root) - return vals + displ = [sum(counts[:rank]) for rank in range(args.numranks)] + args.mpi_comm.Scatterv([invals, np.array(counts), np.array(displ), args.MPI.INT64_T], outval, root=root) else: - # broadcast length of vals list - length = [len(vals)] - dist.broadcast_object_list(length, src=root) - - # allocate a tensor of appropriate size - # initialize tensor with list values on root + tensors = [] if args.rank == root: - tvals = torch.tensor(vals, dtype=torch.int64) - else: - tvals = torch.zeros(length[0], dtype=torch.int64) + for rank in range(args.numranks): + start = sum(counts[:rank]) + end = start + counts[rank] + tensors.append(torch.tensor(invals[start:end])) - # broadcast tensor from root, and return as a new list - dist.broadcast(tvals, src=root) - return tvals.tolist() + tensor = torch.from_numpy(outval) + dist.scatter(tensor, tensors, src=root) def all_sum_(args, vals): """Sums values in vals element-wise and updates vals with final result on all ranks""" @@ -330,63 +325,75 @@ def load_dset(args): def select_sample_list(args, dset_size): """Given the total number of samples, select a list of sample index values""" + # determine total number of samples that we'll read + num_samples = dset_size + if args.count is not None and args.count < dset_size: + num_samples = args.count + # create sample index list on rank 0, # optionally shuffle the list, # and optionally limit the sample count - idx = [] + idxlist = None if args.rank == 0: # generate a list of all index values - idx = list(range(dset_size)) + idxlist = np.arange(dset_size, dtype=np.int64) # optionally shuffle if args.shuffle: - if args.seed is not None: - random.seed(args.seed) - random.shuffle(idx) + # args.seed may be an int (to seed) or None (to not) + rng = np.random.default_rng(args.seed) + rng.shuffle(idxlist) # optionally limit the sample count if args.count is not None: - idx = idx[:args.count] - - # broadcast sample index values from rank 0 to all procs - idx = bcast(args, idx, root=0) - return idx - -def get_start_end(num, rank, num_ranks): - """Compute start and end index values to evenly divide num items - among ranks. - - If num is not evenly divisible by num_ranks, ranks from - [0,remainder) will each be assigned one extra item. - Returns a (start, end) tuple, such that the calling rank - should take items in a list[start:end] - - Parameters - ---------- - num : int - Number of items to be divided - rank : int - Rank of the calling process - num_ranks : int - Number of processes among which to divide items - - Returns - ------- - int - start index value - int - end index value - """ - - num_per_rank = num // num_ranks - remainder = num % num_ranks - if rank < remainder: - start = (num_per_rank + 1) * rank; - end = start + (num_per_rank + 1) + idxlist = idxlist[:args.count] + + # get a list of the number of elements each rank will hold + counts = get_proc_counts(num_samples, args.numranks) + idx_start = sum(counts[:args.rank]) + idx_end = idx_start + counts[args.rank] + idx = np.zeros(counts[args.rank], np.int64) + + # scatter index list if small enough + scatter_limit = 20000000 + if num_samples < scatter_limit: + # scatter sample index values from rank 0 to all procs + # based on distribution defined in counts list + scatterv_(args, idxlist, counts, idx, root=0) else: - start = (num_per_rank + 1) * remainder + num_per_rank * (rank - remainder); - end = start + num_per_rank - return start, end + # The index list is too big to send to every process. + # Write it to a shared file to be read by other ranks instead. + indexlistfile = f"{args.output_prefix}_{args.level}.sampleidx" + if args.rank == 0: + with open(indexlistfile, "wb") as f: + #vals = np.array(idx, dtype=np.int64) + f.write(idxlist.tobytes(order='C')) + + # wait for rank 0 to write the file + barrier(args) + + # All ranks read their respective portion + length = counts[args.rank] + if length > 0: + with open(indexlistfile, "rb") as f: + offset = idx_start * 8 + f.seek(offset) + f.readinto(idx) + + # wait for all to read + barrier(args) + + # delete the temporary file + if args.rank == 0: + os.remove(indexlistfile) + + barrier(args) + + return num_samples, idx + +def get_proc_counts(num, num_ranks): + num_per_rank, remainder = divmod(num, num_ranks) + return [num_per_rank + 1 if rank < remainder else num_per_rank for rank in range(num_ranks)] def get_filename(args, key, rank=None): pathname = args.output_prefix @@ -398,7 +405,7 @@ def get_filename(args, key, rank=None): return filename -def rank_files_write(args, dset, idx, encoder): +def rank_files_write(args, dset, num_samples, idx, encoder): tokenize_start = time.time() # we'll total up the number of docs, sentences, and bytes @@ -424,15 +431,13 @@ def rank_files_write(args, dset, idx, encoder): impl=args.dataset_impl, dtype=best_fitting_dtype(args.vocab_size)) - # divide index list evenly among ranks - idx_start, idx_end = get_start_end(len(idx), args.rank, args.numranks) - # each rank tokenizes its samples and writes its own file progress_next = time.time() + float(args.log_interval) - for i in idx[idx_start:idx_end]: + for i in idx: + sample_id = int(i) for key in args.columns: # tokenize text for the given sample index - text = dset[i][key] + text = dset[sample_id][key] doc, bytes_processed = encoder.encode_text(text) # add tokenized sequence to our data file @@ -451,11 +456,11 @@ def rank_files_write(args, dset, idx, encoder): elapsed = current - tokenize_start timestamp = time.strftime("%Y-%m-%dT%H:%M:%S") docs = dset_stats[0] * args.numranks - percent = docs / len(idx) * 100.0 + percent = docs / num_samples * 100.0 docrate = docs / elapsed if elapsed > 0.0 else 0.0 mbs = dset_stats[2] * args.numranks / elapsed / 1024 / 1024 if elapsed > 0.0 else 0.0 - secs_left = int((len(idx) - docs) / docrate if docrate > 0.0 else 0.0) - print(f"{timestamp}: Processed (estimated) {docs} of {len(idx)} docs ({percent:0.2f}%),", + secs_left = int((num_samples - docs) / docrate if docrate > 0.0 else 0.0) + print(f"{timestamp}: Processed (estimated) {docs} of {num_samples} docs ({percent:0.2f}%),", f"{docrate:0.3f} docs/s, {mbs:0.3f} MB/s,", f"{secs_left} secs left ...", flush=True) @@ -581,21 +586,21 @@ def main(): print(dset) print("Selecting features:", args.columns) + args.level = "document" + if args.split_sentences: + args.level = "sentence" + # create sample index list, # optionally shuffle the list, # and optionally limit the sample count - idx = select_sample_list(args, len(dset)) - + num_samples, idx = select_sample_list(args, len(dset)) + if nltk_available and args.split_sentences: nltk.download("punkt", quiet=True) encoder = Encoder(args) args.vocab_size = encoder.tokenizer.vocab_size - args.level = "document" - if args.split_sentences: - args.level = "sentence" - # wait for all ranks before stopping timer barrier(args) startup_end = time.time() @@ -603,7 +608,7 @@ def main(): print("Seconds to startup:", startup_end - startup_start) # have each rank write its file, returns False if any rank had a problem - success, err = rank_files_write(args, dset, idx, encoder) + success, err = rank_files_write(args, dset, num_samples, idx, encoder) if not success: if args.rank == 0: # If any process fails, we skip the merge since the resulting file would be invalid. From b626416ed547d8f7d788c2b2588eb7d1d80a20ec Mon Sep 17 00:00:00 2001 From: Adam Moody Date: Thu, 12 Aug 2021 13:17:41 -0700 Subject: [PATCH 2/6] drop unused idx_end from index scatter --- tools/preprocess_dataset_mpi.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/tools/preprocess_dataset_mpi.py b/tools/preprocess_dataset_mpi.py index 495f9e28d..b06cbbbd7 100644 --- a/tools/preprocess_dataset_mpi.py +++ b/tools/preprocess_dataset_mpi.py @@ -350,8 +350,8 @@ def select_sample_list(args, dset_size): # get a list of the number of elements each rank will hold counts = get_proc_counts(num_samples, args.numranks) - idx_start = sum(counts[:args.rank]) - idx_end = idx_start + counts[args.rank] + + # allocate space to hold its portion of the list idx = np.zeros(counts[args.rank], np.int64) # scatter index list if small enough @@ -366,16 +366,16 @@ def select_sample_list(args, dset_size): indexlistfile = f"{args.output_prefix}_{args.level}.sampleidx" if args.rank == 0: with open(indexlistfile, "wb") as f: - #vals = np.array(idx, dtype=np.int64) f.write(idxlist.tobytes(order='C')) # wait for rank 0 to write the file barrier(args) # All ranks read their respective portion - length = counts[args.rank] - if length > 0: + idx_count = counts[args.rank] + if idx_count > 0: with open(indexlistfile, "rb") as f: + idx_start = sum(counts[:args.rank]) offset = idx_start * 8 f.seek(offset) f.readinto(idx) From 1070dc2db46f337d10e85a69798754995698eb93 Mon Sep 17 00:00:00 2001 From: Adam Moody Date: Sun, 15 Aug 2021 16:54:10 -0700 Subject: [PATCH 3/6] drop scatter list file to simplify, can add back if needed --- tools/preprocess_dataset_mpi.py | 37 +++------------------------------ 1 file changed, 3 insertions(+), 34 deletions(-) diff --git a/tools/preprocess_dataset_mpi.py b/tools/preprocess_dataset_mpi.py index b06cbbbd7..d4a230a18 100644 --- a/tools/preprocess_dataset_mpi.py +++ b/tools/preprocess_dataset_mpi.py @@ -354,40 +354,9 @@ def select_sample_list(args, dset_size): # allocate space to hold its portion of the list idx = np.zeros(counts[args.rank], np.int64) - # scatter index list if small enough - scatter_limit = 20000000 - if num_samples < scatter_limit: - # scatter sample index values from rank 0 to all procs - # based on distribution defined in counts list - scatterv_(args, idxlist, counts, idx, root=0) - else: - # The index list is too big to send to every process. - # Write it to a shared file to be read by other ranks instead. - indexlistfile = f"{args.output_prefix}_{args.level}.sampleidx" - if args.rank == 0: - with open(indexlistfile, "wb") as f: - f.write(idxlist.tobytes(order='C')) - - # wait for rank 0 to write the file - barrier(args) - - # All ranks read their respective portion - idx_count = counts[args.rank] - if idx_count > 0: - with open(indexlistfile, "rb") as f: - idx_start = sum(counts[:args.rank]) - offset = idx_start * 8 - f.seek(offset) - f.readinto(idx) - - # wait for all to read - barrier(args) - - # delete the temporary file - if args.rank == 0: - os.remove(indexlistfile) - - barrier(args) + # scatter sample index values from rank 0 to all procs + # based on distribution defined in counts list + scatterv_(args, idxlist, counts, idx, root=0) return num_samples, idx From 11e4df0c74fc6ec894686788765e318e922bea15 Mon Sep 17 00:00:00 2001 From: Adam Moody Date: Mon, 16 Aug 2021 18:12:21 -0700 Subject: [PATCH 4/6] rework scatterv, recompute num_samples when needed --- tools/preprocess_dataset_mpi.py | 54 +++++++++++++++++++-------------- 1 file changed, 31 insertions(+), 23 deletions(-) diff --git a/tools/preprocess_dataset_mpi.py b/tools/preprocess_dataset_mpi.py index d4a230a18..ba650d00e 100644 --- a/tools/preprocess_dataset_mpi.py +++ b/tools/preprocess_dataset_mpi.py @@ -176,6 +176,10 @@ def get_args(): if not args.split_sentences: print("Bert tokenizer detected, are you sure you don't want to split sentences?") + args.level = "document" + if args.split_sentences: + args.level = "sentence" + # some default/dummy values for the tokenizer args.rank = 0 args.make_vocab_size_divisible_by = 128 @@ -220,20 +224,20 @@ def barrier(args): dist.barrier() def scatterv_(args, invals, counts, outval, root=0): - """Scatter values from invals according to counts array, receive values in outval""" + """Scatter int64 values from invals according to counts array, receive values in outval""" + assert len(counts) == args.numranks, f"Length of counts list {len(counts)} does not match number of ranks {args.numranks}" + assert outval.shape == (counts[args.rank],), f"Rank {args.rank}: output buffer is of shape {outval.shape}, expected {(counts[args.rank],)}" + if args.use_mpi: - displ = [sum(counts[:rank]) for rank in range(args.numranks)] - args.mpi_comm.Scatterv([invals, np.array(counts), np.array(displ), args.MPI.INT64_T], outval, root=root) + counts = np.array(counts) + displs = np.cumsum(counts) - counts + args.mpi_comm.Scatterv([invals, counts, displs, args.MPI.INT64_T], outval, root=root) else: - tensors = [] + scatterlist = [] if args.rank == root: - for rank in range(args.numranks): - start = sum(counts[:rank]) - end = start + counts[rank] - tensors.append(torch.tensor(invals[start:end])) - - tensor = torch.from_numpy(outval) - dist.scatter(tensor, tensors, src=root) + scatterlist = list(torch.split(torch.from_numpy(invals), counts)) + outtensor = torch.from_numpy(outval) + dist.scatter(outtensor, scatterlist, src=root) def all_sum_(args, vals): """Sums values in vals element-wise and updates vals with final result on all ranks""" @@ -323,12 +327,17 @@ def load_dset(args): return dset, err -def select_sample_list(args, dset_size): - """Given the total number of samples, select a list of sample index values""" - # determine total number of samples that we'll read +def get_num_samples(args, dset_size): + """Given a dataset size and optional count argument, return number of samples to process.""" num_samples = dset_size if args.count is not None and args.count < dset_size: num_samples = args.count + return num_samples + +def select_sample_list(args, dset_size): + """Given the total number of samples, select a list of sample index values""" + # determine total number of samples that we'll read + num_samples = get_num_samples(args, dset_size) # create sample index list on rank 0, # optionally shuffle the list, @@ -346,7 +355,7 @@ def select_sample_list(args, dset_size): # optionally limit the sample count if args.count is not None: - idxlist = idxlist[:args.count] + idxlist = idxlist[:num_samples] # get a list of the number of elements each rank will hold counts = get_proc_counts(num_samples, args.numranks) @@ -358,7 +367,7 @@ def select_sample_list(args, dset_size): # based on distribution defined in counts list scatterv_(args, idxlist, counts, idx, root=0) - return num_samples, idx + return idx def get_proc_counts(num, num_ranks): num_per_rank, remainder = divmod(num, num_ranks) @@ -374,9 +383,12 @@ def get_filename(args, key, rank=None): return filename -def rank_files_write(args, dset, num_samples, idx, encoder): +def rank_files_write(args, dset, idx, encoder): tokenize_start = time.time() + # compute total number of samples we'e processing + num_samples = get_num_samples(args, len(dset)) + # we'll total up the number of docs, sentences, and bytes # processed across all ranks dset_stats = np.zeros(3, dtype=np.int64) # docs, sentences, bytes @@ -555,14 +567,10 @@ def main(): print(dset) print("Selecting features:", args.columns) - args.level = "document" - if args.split_sentences: - args.level = "sentence" - # create sample index list, # optionally shuffle the list, # and optionally limit the sample count - num_samples, idx = select_sample_list(args, len(dset)) + idx = select_sample_list(args, len(dset)) if nltk_available and args.split_sentences: nltk.download("punkt", quiet=True) @@ -577,7 +585,7 @@ def main(): print("Seconds to startup:", startup_end - startup_start) # have each rank write its file, returns False if any rank had a problem - success, err = rank_files_write(args, dset, num_samples, idx, encoder) + success, err = rank_files_write(args, dset, idx, encoder) if not success: if args.rank == 0: # If any process fails, we skip the merge since the resulting file would be invalid. From 360ff19a9bc1b1f6d943c7126de59421dc2e84b4 Mon Sep 17 00:00:00 2001 From: Adam Moody Date: Tue, 17 Aug 2021 12:16:18 -0500 Subject: [PATCH 5/6] Update tools/preprocess_dataset_mpi.py Co-authored-by: Thomas Wang <24695242+thomasw21@users.noreply.github.com> --- tools/preprocess_dataset_mpi.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tools/preprocess_dataset_mpi.py b/tools/preprocess_dataset_mpi.py index ba650d00e..cb5481dd6 100644 --- a/tools/preprocess_dataset_mpi.py +++ b/tools/preprocess_dataset_mpi.py @@ -233,7 +233,7 @@ def scatterv_(args, invals, counts, outval, root=0): displs = np.cumsum(counts) - counts args.mpi_comm.Scatterv([invals, counts, displs, args.MPI.INT64_T], outval, root=root) else: - scatterlist = [] + scatterlist = None if args.rank == root: scatterlist = list(torch.split(torch.from_numpy(invals), counts)) outtensor = torch.from_numpy(outval) From 5c0ca62e6b80a3de075eb01a1c2df33beba1b71f Mon Sep 17 00:00:00 2001 From: Adam Moody Date: Tue, 17 Aug 2021 10:18:06 -0700 Subject: [PATCH 6/6] fix spacing --- tools/preprocess_dataset_mpi.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tools/preprocess_dataset_mpi.py b/tools/preprocess_dataset_mpi.py index cb5481dd6..3cd366160 100644 --- a/tools/preprocess_dataset_mpi.py +++ b/tools/preprocess_dataset_mpi.py @@ -233,7 +233,7 @@ def scatterv_(args, invals, counts, outval, root=0): displs = np.cumsum(counts) - counts args.mpi_comm.Scatterv([invals, counts, displs, args.MPI.INT64_T], outval, root=root) else: - scatterlist = None + scatterlist = None if args.rank == root: scatterlist = list(torch.split(torch.from_numpy(invals), counts)) outtensor = torch.from_numpy(outval)