shuffle index list with numpy, scatter list, use file for large lists #63
Conversation
@thomasw21 , would you please also take a look at this one if you have some time? It's on the shorter side. Thanks! BTW, I've got a commit ready to push that drops the scatter-via-file implementation. I'm leaning toward dropping that to simplify the code. I think the scatterv should be sufficient, and if not, it's easy to add the file method back.
thomasw21 left a comment:
Interesting. I'm okay with this PR, but I'd really consider looking into HF shuffle. Despite being slow (and we're not even sure it's much slower than the approach here), it requires a lot less code IMO.
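For reference, the HF shuffle being discussed is essentially a one-liner (a sketch, assuming dset is a datasets.Dataset; it builds an indices mapping under the hood rather than moving data):

dset = dset.shuffle(seed=42)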
tools/preprocess_dataset_mpi.py
Outdated
if args.use_mpi:
    vals = args.mpi_comm.bcast(vals, root=root)
    return vals
displ = [sum(counts[:rank]) for rank in range(args.numranks)]
Since displ is converted to a numpy array, you might as well use cumsum instead to get linear-time computation. You'd convert counts to an array beforehand, too.
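A minimal sketch of the suggested change, using the names from the snippet above:

import numpy as np

# counts holds the number of elements destined for each rank
counts = np.array(counts, dtype=np.int64)
# exclusive prefix sum, so that displ[r] == sum(counts[:r])
displ = np.concatenate(([0], np.cumsum(counts)[:-1]))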
Thanks. I made that change.
| """Broadcast list of vals from root to all ranks, returns newly allocated list""" | ||
| def scatterv_(args, invals, counts, outval, root=0): | ||
| """Scatter values from invals according to counts array, receive values in outval""" | ||
| if args.use_mpi: |
You could actually assert that len(counts) == args.numranks?
Yes, thanks.
tools/preprocess_dataset_mpi.py
Outdated
def bcast(args, vals, root=0):
    """Broadcast list of vals from root to all ranks, returns newly allocated list"""
def scatterv_(args, invals, counts, outval, root=0):
    """Scatter values from invals according to counts array, receive values in outval"""
You should mention that this supports only int64 for now.
Added a note to the docstring about that.
tools/preprocess_dataset_mpi.py
Outdated
for rank in range(args.numranks):
    start = sum(counts[:rank])
    end = start + counts[rank]
    tensors.append(torch.tensor(invals[start:end]))
If you convert it into a torch tensor, you can use cumsum to get end. Or better, there seems to be a built-in function for what you want: https://pytorch.org/docs/stable/generated/torch.split.html
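A minimal sketch of the torch.split version, assuming invals is a numpy array and counts is a list of per-rank sizes:

import torch

# split invals into per-rank chunks sized by counts, replacing the
# manual start/end bookkeeping; torch.split returns a tuple of views
tensors = list(torch.split(torch.from_numpy(invals), counts))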
Cool. Thanks for the tip on torch.split. That looks to be perfect.
tools/preprocess_dataset_mpi.py
Outdated
# broadcast tensor from root, and return as a new list
dist.broadcast(tvals, src=root)
return tvals.tolist()
tensor = torch.from_numpy(outval)
Nit: tensor doesn't express much
- tensor = torch.from_numpy(outval)
+ out = torch.from_numpy(outval)
Isn't there an issue where the sizes of the tensors don't match anymore? Could you add a check that outval has the same size as the corresponding count?
Yes, that name is not very descriptive; I'll rename it. And good idea to check the size of the output array.
  for key in args.columns:
      # tokenize text for the given sample index
-     text = dset[i][key]
+     text = dset[sample_id][key]
I think you can even run dset[idx] if I'm not wrong, to be confirmed. If so I'd suggest something like:

for row in dset[idx]:
    for key in args.columns:
        text = row[key]
Oh, that's useful to know. If you don't mind, I'd like to stick with the harder way for now, because I'm using this same code to support a JSON file and that particular code only supports an int for the argument in the __get__ call.
  timestamp = time.strftime("%Y-%m-%dT%H:%M:%S")
  docs = dset_stats[0] * args.numranks
- percent = docs / len(idx) * 100.0
+ percent = docs / num_samples * 100.0
Isn't that just len(dset)? Unless you want to take count into account also. I don't think it's worth adding a parameter to the method signature, since you can recover it from args.count and len(dset).
Seeing as it's already an approximation, you could say that percent = dset_stats[0] / len(idx) * 100
Right, this is to adjust for count being less than len(dset). I'll create a separate function to compute num_samples and drop it from the param list.
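One possible shape for that helper (a sketch; get_num_samples and the args.count semantics are assumptions, with None meaning no limit):

def get_num_samples(args, dset_size):
    # process the full dataset, capped by args.count when one is given
    if args.count is not None:
        return min(args.count, dset_size)
    return dset_size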
tools/preprocess_dataset_mpi.py
Outdated
args.level = "document"
if args.split_sentences:
    args.level = "sentence"
Let's move this into get_args.
Yep, done.
  mbs = dset_stats[2] * args.numranks / elapsed / 1024 / 1024 if elapsed > 0.0 else 0.0
- secs_left = int((len(idx) - docs) / docrate if docrate > 0.0 else 0.0)
+ secs_left = int((num_samples - docs) / docrate if docrate > 0.0 else 0.0)
ditto
- secs_left = int((len(idx) - docs) / docrate if docrate > 0.0 else 0.0)
- print(f"{timestamp}: Processed (estimated) {docs} of {len(idx)} docs ({percent:0.2f}%),",
+ secs_left = int((num_samples - docs) / docrate if docrate > 0.0 else 0.0)
+ print(f"{timestamp}: Processed (estimated) {docs} of {num_samples} docs ({percent:0.2f}%),",
You could extrapolate also? Overall I feel it's a shame to add arguments for the sake of a logger. If you really want to keep it, I'm also okay with that.
Yes, it's still worth looking at the HF dataset shuffle. I can still look at that in the future. However, having some native shuffle support is useful, because PR #60 adds support for JSON files. If we find that the HF dataset shuffle is better, I can use that with an HF dataset and fall back to this for a JSON file. I'll also go ahead and push the change to drop the scatter file, which simplifies the code quite a bit.
You can actually load json files using HF datasets. https://huggingface.co/docs/datasets/loading_datasets.html#json-files
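For reference, that loading path looks roughly like this (the file name is illustrative):

from datasets import load_dataset

# load a JSON-lines file as an HF dataset
dset = load_dataset("json", data_files="data.jsonl", split="train")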
I did give that a try, even by accident a few times. For the oscar json, the load time is very slow. I haven't investigated the reason yet. I can imagine it may be scanning the json file to identify line breaks, meaning it's probably reading the full file from every rank. One of the additions in PR #60 is that it implements a parallel scan of the json file, so the full file can be scanned and indexed in just a few seconds. The ranks scan the file collectively rather than having each rank read the full file.
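A minimal sketch of that collective scan idea, not the actual PR #60 code (the function and variable names are hypothetical): each rank reads only its own byte range of the file and records the newline offsets it finds, and the per-rank lists are then gathered into a global index.

import os

def scan_newlines(filename, rank, numranks):
    """Find newline offsets within this rank's byte range of the file"""
    filesize = os.path.getsize(filename)
    chunk = (filesize + numranks - 1) // numranks
    start = rank * chunk
    end = min(start + chunk, filesize)
    with open(filename, "rb") as f:
        f.seek(start)
        data = f.read(end - start)
    offsets = []
    pos = data.find(b"\n")
    while pos != -1:
        offsets.append(start + pos)  # absolute offset of this newline
        pos = data.find(b"\n", pos + 1)
    return offsets  # ranks gather these lists to build the record index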
@thomasw21 , when you get a chance, would you please take another look at this one? Some of your suggestions became moot since I dropped the scatter file, but I think I've addressed the rest.
thomasw21 left a comment:
LGTM overall! Two small questions before merging!
def scatterv_(args, invals, counts, outval, root=0):
    """Scatter int64 values from invals according to counts array, receive values in outval"""
    assert len(counts) == args.numranks, f"Length of counts list {len(counts)} does not match number of ranks {args.numranks}"
    assert outval.shape == (counts[args.rank],), f"Rank {args.rank}: output buffer is of shape {outval.shape}, expected {(counts[args.rank],)}"
Just to make sure: if this raises an exception, will the other ranks keep going and deadlock? Or do we need the same pattern as everywhere else?
Ah, yes. Good point. That would deadlock if one rank failed the assert and the others did not.
Let's leave it this way for now in this PR. The preprocess_dataset_mpi.py script is the only one calling scatterv_ right now, and it should be fine. Longer term, I've got an idea for the DistData class to help with these collective assert checks.
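One possible shape for such a collective check (a sketch, not the planned DistData API; assumes the default torch.distributed process group is initialized): every rank reports whether its local condition holds, and all ranks raise together if any rank failed, so no rank enters the collective alone.

import torch
import torch.distributed as dist

def allassert(cond, msg):
    """Collective assert: all ranks raise together if any rank fails"""
    flag = torch.tensor([0 if cond else 1], dtype=torch.int64)
    dist.all_reduce(flag, op=dist.ReduceOp.SUM)  # count of failing ranks
    if flag.item() > 0:
        raise AssertionError(msg)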
# broadcast tensor from root, and return as a new list
dist.broadcast(tvals, src=root)
return tvals.tolist()
scatterlist = list(torch.split(torch.from_numpy(invals), counts))
So leaving the tuple didn't work here?
Without the list() in there, it throws this error:
RuntimeError: Invalid function argument. Expected parameter `scatter_list` to be of type List[torch.Tensor].
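That matches torch.split returning a tuple of tensor views while dist.scatter type-checks its scatter_list argument as a list; a quick illustration:

import torch

chunks = torch.split(torch.arange(10), [4, 3, 3])
print(type(chunks))          # <class 'tuple'>
scatterlist = list(chunks)   # the List[torch.Tensor] that dist.scatter expects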
Btw, there's this PR that might pique your interest: huggingface/datasets#2747
thomasw21 left a comment:
I can merge as soon as you give me the go-ahead.
Oh, nice. Thanks for the pointer.
Co-authored-by: Thomas Wang <[email protected]>
* add tensor parallelism for MoE
This updates the index selection logic in a couple of ways:

- random.shuffle to numpy.random.shuffle (timed to be about 10x faster)
- The scatter is implemented in two ways. The first uses a scatter collective. The second writes the full list to a temporary file that is then read back by each process.

I suppose if the scatter collective seems to be reliable, the temporary file is not needed and could be dropped. I've left it in there for now.
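A minimal sketch of the shuffle change (num_samples stands in for the dataset size):

import numpy as np

# build and shuffle the sample index list in place with numpy
idx = np.arange(num_samples, dtype=np.int64)
np.random.shuffle(idx)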
Update:
Decided to drop the scatter file to simplify the code. I think the scatterv should be sufficient, and if not, it's easy to add the file method back.
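Pulling the review threads together, a hypothetical end-to-end sketch of that scatterv path under the torch.distributed backend (names mirror the snippets above, but the real script may differ; only int64 values are supported):

import torch
import torch.distributed as dist

def scatterv_(args, invals, counts, outval, root=0):
    """Scatter int64 values from invals according to counts array, receive values in outval"""
    assert len(counts) == args.numranks
    assert outval.shape == (counts[args.rank],)
    out = torch.from_numpy(outval)  # shares memory with outval
    scatterlist = None
    if args.rank == root:
        # torch.split returns a tuple of views; dist.scatter wants a list
        scatterlist = list(torch.split(torch.from_numpy(invals), counts))
    dist.scatter(out, scatterlist, src=root)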