shuffle index list with numpy, scatter list, use file for large lists #63
```diff
@@ -176,6 +176,10 @@ def get_args():
     if not args.split_sentences:
         print("Bert tokenizer detected, are you sure you don't want to split sentences?")
 
+    args.level = "document"
+    if args.split_sentences:
+        args.level = "sentence"
+
     # some default/dummy values for the tokenizer
     args.rank = 0
     args.make_vocab_size_divisible_by = 128
```
```diff
@@ -219,26 +223,21 @@ def barrier(args):
     else:
         dist.barrier()
 
-def bcast(args, vals, root=0):
-    """Broadcast list of vals from root to all ranks, returns newly allocated list"""
+def scatterv_(args, invals, counts, outval, root=0):
+    """Scatter int64 values from invals according to counts array, receive values in outval"""
+    assert len(counts) == args.numranks, f"Length of counts list {len(counts)} does not match number of ranks {args.numranks}"
+    assert outval.shape == (counts[args.rank],), f"Rank {args.rank}: output buffer is of shape {outval.shape}, expected {(counts[args.rank],)}"
     if args.use_mpi:
```
**Member:** You could actually assert that …

**Author:** Yes, thanks.
```diff
-        vals = args.mpi_comm.bcast(vals, root=root)
-        return vals
+        counts = np.array(counts)
+        displs = np.cumsum(counts) - counts
+        args.mpi_comm.Scatterv([invals, counts, displs, args.MPI.INT64_T], outval, root=root)
```
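The displacement line above is the usual exclusive prefix sum for a varying-count scatter: rank `r` receives `invals[displs[r] : displs[r] + counts[r]]`. A self-contained sketch with made-up counts (illustrative, not from the PR):

```python
# Illustrative only: deriving Scatterv displacements from per-rank counts.
import numpy as np

counts = np.array([3, 3, 2, 2])        # hypothetical elements per rank
displs = np.cumsum(counts) - counts    # start offset of each rank's chunk
print(displs.tolist())                 # [0, 3, 6, 8]
```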
```diff
     else:
-        # broadcast length of vals list
-        length = [len(vals)]
-        dist.broadcast_object_list(length, src=root)
-
-        # allocate a tensor of appropriate size
-        # initialize tensor with list values on root
+        scatterlist = []
         if args.rank == root:
-            tvals = torch.tensor(vals, dtype=torch.int64)
-        else:
-            tvals = torch.zeros(length[0], dtype=torch.int64)
-
-        # broadcast tensor from root, and return as a new list
-        dist.broadcast(tvals, src=root)
-        return tvals.tolist()
+            scatterlist = list(torch.split(torch.from_numpy(invals), counts))
```
**Member:** So leaving the tuple didn't work here?

**Author:** Without the …
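For context on this exchange: `torch.split` returns a tuple of tensor views, and the diff wraps it in `list(...)` before handing it to `dist.scatter`. A standalone sketch (the counts are illustrative):

```python
# torch.split returns a tuple, not a list.
import torch

counts = [3, 3, 2, 2]                          # hypothetical per-rank counts
chunks = torch.split(torch.arange(10), counts)
print(type(chunks).__name__)                   # tuple
print([c.tolist() for c in chunks])            # [[0, 1, 2], [3, 4, 5], [6, 7], [8, 9]]
```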
```diff
+        outtensor = torch.from_numpy(outval)
+        dist.scatter(outtensor, scatterlist, src=root)
 
 def all_sum_(args, vals):
     """Sums values in vals element-wise and updates vals with final result on all ranks"""
```
```diff
@@ -328,65 +327,51 @@ def load_dset(args):
 
     return dset, err
 
+def get_num_samples(args, dset_size):
+    """Given a dataset size and optional count argument, return number of samples to process."""
+    num_samples = dset_size
+    if args.count is not None and args.count < dset_size:
+        num_samples = args.count
+    return num_samples
+
 def select_sample_list(args, dset_size):
     """Given the total number of samples, select a list of sample index values"""
+    # determine total number of samples that we'll read
+    num_samples = get_num_samples(args, dset_size)
+
     # create sample index list on rank 0,
     # optionally shuffle the list,
     # and optionally limit the sample count
-    idx = []
+    idxlist = None
     if args.rank == 0:
         # generate a list of all index values
-        idx = list(range(dset_size))
+        idxlist = np.arange(dset_size, dtype=np.int64)
 
         # optionally shuffle
         if args.shuffle:
-            if args.seed is not None:
-                random.seed(args.seed)
-            random.shuffle(idx)
+            # args.seed may be an int (to seed) or None (to not)
+            rng = np.random.default_rng(args.seed)
+            rng.shuffle(idxlist)
 
         # optionally limit the sample count
        if args.count is not None:
-            idx = idx[:args.count]
+            idxlist = idxlist[:num_samples]
 
+    # get a list of the number of elements each rank will hold
+    counts = get_proc_counts(num_samples, args.numranks)
+
+    # allocate space to hold its portion of the list
+    idx = np.zeros(counts[args.rank], np.int64)
+
+    # scatter sample index values from rank 0 to all procs
+    # based on distribution defined in counts list
+    scatterv_(args, idxlist, counts, idx, root=0)
 
-    # broadcast sample index values from rank 0 to all procs
-    idx = bcast(args, idx, root=0)
     return idx
 
-def get_start_end(num, rank, num_ranks):
-    """Compute start and end index values to evenly divide num items
-    among ranks.
-
-    If num is not evenly divisible by num_ranks, ranks from
-    [0,remainder) will each be assigned one extra item.
-    Returns a (start, end) tuple, such that the calling rank
-    should take items in a list[start:end]
-
-    Parameters
-    ----------
-    num : int
-        Number of items to be divided
-    rank : int
-        Rank of the calling process
-    num_ranks : int
-        Number of processes among which to divide items
-
-    Returns
-    -------
-    int
-        start index value
-    int
-        end index value
-    """
-
-    num_per_rank = num // num_ranks
-    remainder = num % num_ranks
-    if rank < remainder:
-        start = (num_per_rank + 1) * rank;
-        end = start + (num_per_rank + 1)
-    else:
-        start = (num_per_rank + 1) * remainder + num_per_rank * (rank - remainder);
-        end = start + num_per_rank
-    return start, end
+def get_proc_counts(num, num_ranks):
+    num_per_rank, remainder = divmod(num, num_ranks)
+    return [num_per_rank + 1 if rank < remainder else num_per_rank for rank in range(num_ranks)]
 
 def get_filename(args, key, rank=None):
     pathname = args.output_prefix
```
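A quick worked example of the new `get_proc_counts` helper (numbers are illustrative): `divmod(10, 4)` yields `(2, 2)`, so the first two ranks each take one extra item, and the counts always sum back to `num`:

```python
# Worked example for get_proc_counts with illustrative inputs.
def get_proc_counts(num, num_ranks):
    num_per_rank, remainder = divmod(num, num_ranks)
    return [num_per_rank + 1 if rank < remainder else num_per_rank for rank in range(num_ranks)]

print(get_proc_counts(10, 4))       # [3, 3, 2, 2]
print(sum(get_proc_counts(10, 4)))  # 10
```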
```diff
@@ -401,6 +386,9 @@ def get_filename(args, key, rank=None):
 def rank_files_write(args, dset, idx, encoder):
     tokenize_start = time.time()
 
+    # compute total number of samples we're processing
+    num_samples = get_num_samples(args, len(dset))
+
     # we'll total up the number of docs, sentences, and bytes
     # processed across all ranks
     dset_stats = np.zeros(3, dtype=np.int64) # docs, sentences, bytes
```
```diff
@@ -424,15 +412,13 @@ def rank_files_write(args, dset, idx, encoder):
                                           impl=args.dataset_impl,
                                           dtype=best_fitting_dtype(args.vocab_size))
 
-    # divide index list evenly among ranks
-    idx_start, idx_end = get_start_end(len(idx), args.rank, args.numranks)
-
     # each rank tokenizes its samples and writes its own file
     progress_next = time.time() + float(args.log_interval)
-    for i in idx[idx_start:idx_end]:
+    for i in idx:
+        sample_id = int(i)
         for key in args.columns:
             # tokenize text for the given sample index
-            text = dset[i][key]
+            text = dset[sample_id][key]
```
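A note on the new `sample_id = int(i)`: iterating a numpy array yields numpy scalar types, and some dataset backends may reject numpy integers as indices, so converting to a plain Python int is the safe choice. A toy sketch (the array stands in for the scattered `idx`):

```python
# Toy sketch: elements of a numpy index array are numpy scalars.
import numpy as np

idx = np.array([2, 0, 1], dtype=np.int64)
for i in idx:
    print(type(i).__name__)   # int64, a numpy scalar
    sample_id = int(i)        # plain Python int for __getitem__
```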
**Member:** I think you can even run

```python
for row in dset[idx]:
    for key in args.columns:
        text = row[key]
```

**Author:** Oh, that's useful to know. If you don't mind, I'd like to stick with the harder way for now, because I'm using this same code to support a JSON file and that particular code only supports an …
```diff
             doc, bytes_processed = encoder.encode_text(text)
 
             # add tokenized sequence to our data file
```
```diff
@@ -451,11 +437,11 @@ def rank_files_write(args, dset, idx, encoder):
             elapsed = current - tokenize_start
             timestamp = time.strftime("%Y-%m-%dT%H:%M:%S")
             docs = dset_stats[0] * args.numranks
-            percent = docs / len(idx) * 100.0
+            percent = docs / num_samples * 100.0
```
**Member:** Isn't that just …? Seeing as it's already an approximation, you could say that …

**Author:** Right, this is to adjust for count being less than …
```diff
             docrate = docs / elapsed if elapsed > 0.0 else 0.0
             mbs = dset_stats[2] * args.numranks / elapsed / 1024 / 1024 if elapsed > 0.0 else 0.0
-            secs_left = int((len(idx) - docs) / docrate if docrate > 0.0 else 0.0)
-            print(f"{timestamp}: Processed (estimated) {docs} of {len(idx)} docs ({percent:0.2f}%),",
+            secs_left = int((num_samples - docs) / docrate if docrate > 0.0 else 0.0)
```
**Member:** ditto
```diff
+            print(f"{timestamp}: Processed (estimated) {docs} of {num_samples} docs ({percent:0.2f}%),",
```
**Member:** you could extrapolate also? Overall I feel it's a shame to add arguments for the sake of a logger. If you really want to keep it, I'm also okay with that.
```diff
                   f"{docrate:0.3f} docs/s, {mbs:0.3f} MB/s,",
                   f"{secs_left} secs left ...",
                   flush=True)
```
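To make the estimate being debated concrete: each rank extrapolates global progress by multiplying its own doc counter by the number of ranks, then measures that against `num_samples`. Illustrative arithmetic, all numbers made up:

```python
# Illustrative progress-estimate arithmetic, mirroring the logging code above.
local_docs = 250                     # this rank's dset_stats[0]
numranks = 8
num_samples = 10000                  # from get_num_samples
elapsed = 100.0                      # seconds of tokenizing so far

docs = local_docs * numranks                      # estimated global docs: 2000
percent = docs / num_samples * 100.0              # 20.00%
docrate = docs / elapsed                          # 20.0 docs/s
secs_left = int((num_samples - docs) / docrate)   # 400
print(docs, percent, docrate, secs_left)
```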
```diff
@@ -585,17 +571,13 @@ def main():
     # optionally shuffle the list,
     # and optionally limit the sample count
     idx = select_sample_list(args, len(dset))
 
     if nltk_available and args.split_sentences:
         nltk.download("punkt", quiet=True)
 
     encoder = Encoder(args)
     args.vocab_size = encoder.tokenizer.vocab_size
 
-    args.level = "document"
-    if args.split_sentences:
-        args.level = "sentence"
-
     # wait for all ranks before stopping timer
     barrier(args)
     startup_end = time.time()
```
**Member:** Just to make sure, if this raises an exception, other ranks will keep going and deadlock? Or do we need the same pattern as everywhere else?

**Author:** Ah, yes. Good point. That would deadlock if one rank failed the assert and the others did not. Let's leave it this way for now in this PR. The `preprocess_dataset_mpi.py` script is the only one calling `scatterv_` right now, and it should be fine. Longer term, I've got an idea for the `DistData` class to help with these collective assert checks.
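For readers following the deadlock concern: the usual fix is to make the check itself collective, agreeing on a pass/fail flag before any rank raises. A hedged sketch of what such a helper might look like (hypothetical code, not from this PR; `collective_assert` is an invented name):

```python
# Hypothetical collective assert: all ranks agree before anyone raises,
# so a failing rank cannot leave the others blocked in a later collective.
import torch
import torch.distributed as dist

def collective_assert(ok, msg):
    flag = torch.tensor([1 if ok else 0])        # CPU tensor; assumes a CPU-capable backend like gloo
    dist.all_reduce(flag, op=dist.ReduceOp.MIN)  # drops to 0 if any rank failed
    if flag.item() == 0:
        raise RuntimeError(msg)
```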