
Conversation

@adammoody (Contributor) commented Aug 12, 2021

This updates the index selection logic in a couple ways:

  • switches from random.shuffle to numpy.random.shuffle (timed to be about 10x faster)
  • replaces the broadcast with a scatter

The scatter is implemented in two ways. The first uses a scatter collective. The second writes the full list to a temporary file that is then read back by each process.

I suppose if the scatter collective proves reliable, the temporary file is not needed and could be dropped. I've left it in there for now.

Update:
Decided to drop the scatter file to simplify the code. I think the scatterv should be sufficient, and if not, it's easy to add the file method back.
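For reference, a minimal sketch of the resulting pattern (illustrative code, not the exact implementation in this PR; `shuffle_and_scatter` and its arguments are hypothetical names):

```python
import numpy as np
import torch
import torch.distributed as dist

def shuffle_and_scatter(num_samples, rank, numranks, root=0):
    # Evenly split num_samples indices across ranks (remainder to low ranks).
    counts = [num_samples // numranks + (1 if r < num_samples % numranks else 0)
              for r in range(numranks)]

    scatter_list = None
    if rank == root:
        # numpy shuffle of an int64 array; much faster than random.shuffle on a list.
        idx = np.arange(num_samples, dtype=np.int64)
        np.random.shuffle(idx)
        scatter_list = list(torch.split(torch.from_numpy(idx), counts))

    # Each rank receives only its own slice instead of a broadcast of the full list.
    out = torch.zeros(counts[rank], dtype=torch.int64)
    dist.scatter(out, scatter_list, src=root)
    return out.numpy()
```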

@adammoody (Contributor, Author) commented:

@thomasw21 , would you please also take a look at this one if you have some time? It's on the shorter side. Thanks!

BTW, I've got a commit ready to push that drops the scatter-via-file implementation. I'm leaning toward dropping that to simplify the code. I think the scatterv_ should be sufficient. I haven't seen a problem yet in my testing on oscar. If we do hit a problem with scatter, the file-based method is easy to add back.

@thomasw21 (Member) left a comment:

Interesting, I'm okay with this PR, but I'd really consider looking into HF shuffle. Despite being slow (and we're not even sure it's much slower than this approach), it requires a lot less code IMO.

if args.use_mpi:
    vals = args.mpi_comm.bcast(vals, root=root)
    return vals

displ = [sum(counts[:rank]) for rank in range(args.numranks)]
@thomasw21 (Member):

Since displ is converted to a numpy array, you might as well use cumsum to get a linear-time computation. You'd convert counts to an array beforehand, too.

@adammoody (Contributor, Author):

Thanks. I made that change.
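For illustration, the cumsum version might look like this (a sketch with example counts, not the exact code from the PR):

```python
import numpy as np

counts = np.asarray([3, 5, 2, 4], dtype=np.int64)  # example per-rank counts
# Exclusive prefix sum: displ[r] == sum(counts[:r]), computed in linear time
# rather than the quadratic per-rank sum above.
displ = np.concatenate(([0], np.cumsum(counts)[:-1]))
print(displ)  # [ 0  3  8 10]
```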

"""Broadcast list of vals from root to all ranks, returns newly allocated list"""
def scatterv_(args, invals, counts, outval, root=0):
"""Scatter values from invals according to counts array, receive values in outval"""
if args.use_mpi:
@thomasw21 (Member):

You could actually assert that len(counts) == args.numranks?

@adammoody (Contributor, Author):

Yes, thanks.

def bcast(args, vals, root=0):
    """Broadcast list of vals from root to all ranks, returns newly allocated list"""

def scatterv_(args, invals, counts, outval, root=0):
    """Scatter values from invals according to counts array, receive values in outval"""
@thomasw21 (Member):

You should mention that this supports only int64 for now.

@adammoody (Contributor, Author):

Added a note to the docstring about that.

Comment on lines 230 to 233
for rank in range(args.numranks):
    start = sum(counts[:rank])
    end = start + counts[rank]
    tensors.append(torch.tensor(invals[start:end]))
@thomasw21 (Member):

If you convert it into a torch tensor, you can use cumsum in order to get end. Or better, there seems to be a built-in function for what you want: https://pytorch.org/docs/stable/generated/torch.split.html

@adammoody (Contributor, Author):

Cool. Thanks for the tip on torch.split. That looks to be perfect.
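For reference, the torch.split version of the loop above might look like this (a sketch with example data):

```python
import torch

invals = torch.arange(14)   # example flat buffer; 14 == sum(counts)
counts = [3, 5, 2, 4]       # per-rank element counts
# torch.split slices invals into consecutive chunks of the given sizes,
# replacing the manual start/end bookkeeping.
tensors = list(torch.split(invals, counts))
```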

    # broadcast tensor from root, and return as a new list
    dist.broadcast(tvals, src=root)
    return tvals.tolist()

tensor = torch.from_numpy(outval)
@thomasw21 (Member):

Nit: tensor doesn't express much.

Suggested change:
- tensor = torch.from_numpy(outval)
+ out = torch.from_numpy(outval)

@thomasw21 (Member):

Isn't there an issue where the sizes of the tensors don't match anymore? Could you add a check that outval has the same size as the corresponding count?

@adammoody (Contributor, Author):

Yes, that name is not very descriptive. Renamed. And good idea to check the size of the output array.

for key in args.columns:
    # tokenize text for the given sample index
-   text = dset[i][key]
+   text = dset[sample_id][key]
@thomasw21 (Member):

I think you can even run dset[idx] if I'm not wrong, to be confirmed. If so I'd suggest something like

for row in dset[idx]:
    for key in args.columns:
        text = row[key]

@adammoody (Contributor, Author):

Oh, that's useful to know. If you don't mind, I'd like to stick with the harder way for now, because I'm using this same code to support a JSON file, and that particular code only supports an int for the argument in the __getitem__ call.
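For anyone following along: with an HF dataset, indexing by a list appears to return a batch as a dict of column name to list of values, so batched access might look like this (a sketch, to be confirmed against the datasets version in use):

```python
from datasets import Dataset

dset = Dataset.from_dict({"text": ["a", "b", "c", "d"]})
idx = [2, 0, 3]
batch = dset[idx]          # dict: {"text": ["c", "a", "d"]}
for text in batch["text"]:
    print(text)
```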

timestamp = time.strftime("%Y-%m-%dT%H:%M:%S")
docs = dset_stats[0] * args.numranks
- percent = docs / len(idx) * 100.0
+ percent = docs / num_samples * 100.0
@thomasw21 (Member):

Isn't that just len(dset)? Unless you want to take count into account also. I don't think it's worth adding a parameter to the signature of the method, since you can recover it from args.count and len(dset).

Seeing as it's already an approximation, you could say that percent = dset_stats[0] / len(idx) * 100

@adammoody (Contributor, Author):

Right, this is to adjust for count being less than len(dset). I'll create a separate function to compute num_samples and drop it from the param list.
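A sketch of what that helper could look like (assuming args.count of None means "process the whole dataset"; the name get_num_samples is illustrative):

```python
def get_num_samples(args, dset_size):
    # Use the requested sample count, capped at the dataset size.
    if args.count is not None and args.count < dset_size:
        return args.count
    return dset_size
```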

Comment on lines 589 to 591
args.level = "document"
if args.split_sentences:
    args.level = "sentence"
@thomasw21 (Member):

Let's move this into get_args

@adammoody (Contributor, Author):

Yep, done.

mbs = dset_stats[2] * args.numranks / elapsed / 1024 / 1024 if elapsed > 0.0 else 0.0
- secs_left = int((len(idx) - docs) / docrate if docrate > 0.0 else 0.0)
+ secs_left = int((num_samples - docs) / docrate if docrate > 0.0 else 0.0)
print(f"{timestamp}: Processed (estimated) {docs} of {len(idx)} docs ({percent:0.2f}%),",
@thomasw21 (Member):

ditto

- secs_left = int((len(idx) - docs) / docrate if docrate > 0.0 else 0.0)
- print(f"{timestamp}: Processed (estimated) {docs} of {len(idx)} docs ({percent:0.2f}%),",
+ secs_left = int((num_samples - docs) / docrate if docrate > 0.0 else 0.0)
+ print(f"{timestamp}: Processed (estimated) {docs} of {num_samples} docs ({percent:0.2f}%),",
@thomasw21 (Member):

You could extrapolate also? Overall I feel it's a shame to add arguments just for the sake of a logger. If you really want to keep it, I'm also okay with that.

@adammoody (Contributor, Author) commented:

> Interesting, I'm okay with this PR, but I'd really consider looking into HF shuffle. Despite being slow (and we're not even sure it's much slower than this approach), it requires a lot less code IMO.

Yes, it's still worth looking at the HF dataset shuffle; I can look at that in the future. However, having some native shuffle support is useful, because PR #60 adds support for JSON files. If we find that the HF dataset shuffle is better, I can use that with an HF dataset and fall back to this for a JSON file.

I'll also go ahead and push the change to drop the scatter file, which simplifies the code quite a bit.

@thomasw21 (Member) commented:

You can actually load json files using HF datasets. https://huggingface.co/docs/datasets/loading_datasets.html#json-files

@adammoody (Contributor, Author) commented:

> You can actually load json files using HF datasets. https://huggingface.co/docs/datasets/loading_datasets.html#json-files

I did give that a try, even by accident a few times. For the oscar json, the load time is very slow. I haven't investigated the reason yet, but I imagine it may be scanning the json file to identify line breaks, meaning it's probably reading the full file from every rank.

One of the additions in PR #60 is that it implements a parallel scan of the json file, so the full file can be scanned and indexed in just a few seconds. The ranks scan the file collectively rather than having each rank read the full file.
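The idea, roughly (a sketch of the collective scan, not the actual PR #60 code): each rank scans only its own byte range of the file for record boundaries.

```python
import os

def scan_chunk(path, rank, numranks):
    """Find byte offsets of records that start within this rank's byte range."""
    size = os.path.getsize(path)
    start = size * rank // numranks
    end = size * (rank + 1) // numranks
    offsets = []
    with open(path, "rb") as f:
        if start == 0:
            pos = 0
        else:
            # A record starts at `start` only if the previous byte is a newline;
            # otherwise skip the partial line (it belongs to the previous rank).
            f.seek(start - 1)
            if f.read(1) == b"\n":
                pos = start
            else:
                f.readline()
                pos = f.tell()
            f.seek(pos)
        while pos < end:
            offsets.append(pos)
            f.readline()
            pos = f.tell()
    return offsets
```

The per-rank offset lists can then be combined with an allgather so every rank knows where each record begins.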

@adammoody (Contributor, Author) commented:

@thomasw21 , when you get a chance, would you please take another look at this one? Some of your suggestions became moot since I dropped the scatter file, but I think I've addressed the rest.

@thomasw21 (Member) left a comment:

LGTM overall! Two small questions before merging!

def scatterv_(args, invals, counts, outval, root=0):
    """Scatter int64 values from invals according to counts array, receive values in outval"""
    assert len(counts) == args.numranks, f"Length of counts list {len(counts)} does not match number of ranks {args.numranks}"
    assert outval.shape == (counts[args.rank],), f"Rank {args.rank}: output buffer is of shape {outval.shape}, expected {(counts[args.rank],)}"
@thomasw21 (Member):

Just to make sure: if this raises an exception, will the other ranks keep going and deadlock? Or do we need the same pattern as everywhere else?

@adammoody (Contributor, Author):

Ah, yes. Good point. That would deadlock if one rank failed the assert and the others did not.

Let's leave it this way for now in this PR. The preprocess_dataset_mpi.py script is the only one calling scatterv_ right now, and it should be fine. Longer term, I've got an idea for the DistData class to help with these collective assert checks.
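One possible shape for that collective check (a sketch of the idea, not the eventual DistData code): every rank contributes its local result, and all ranks raise together if any failed, so no rank is left waiting in the collective.

```python
import torch
import torch.distributed as dist

def collective_assert(condition, msg=""):
    # All ranks must call this together; the all_reduce gives every rank
    # the same count of failures, so they all raise (or all continue).
    flag = torch.tensor([0 if condition else 1])
    dist.all_reduce(flag, op=dist.ReduceOp.SUM)
    if flag.item() > 0:
        raise RuntimeError(f"check failed on {flag.item()} rank(s): {msg}")
```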

    # broadcast tensor from root, and return as a new list
    dist.broadcast(tvals, src=root)
    return tvals.tolist()

scatterlist = list(torch.split(torch.from_numpy(invals), counts))
@thomasw21 (Member):

So leaving the tuple didn't work here?

@adammoody (Contributor, Author):

Without the list() in there, it throws this error:

RuntimeError: Invalid function argument. Expected parameter `scatter_list` to be of type List[torch.Tensor].
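A minimal standalone illustration of the tuple/list distinction (hypothetical snippet, not the PR code):

```python
import torch

chunks = torch.split(torch.arange(10), [3, 3, 4])  # returns a tuple of tensors
scatterlist = list(chunks)  # dist.scatter requires List[torch.Tensor]
# passing the raw tuple to dist.scatter raises the RuntimeError above
```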

@thomasw21 (Member) commented:

Btw there's this PR that might pique your interest: huggingface/datasets#2747

@thomasw21 (Member) left a comment:

I can merge as soon as you give me the go

@adammoody (Contributor, Author) commented:

> Btw there's this PR that might pique your interest: huggingface/datasets#2747

Oh, nice. Thanks for the pointer.

@thomasw21 thomasw21 merged commit a59d93c into bigscience-workshop:main Aug 17, 2021
adammoody pushed a commit to adammoody/Megatron-DeepSpeed that referenced this pull request Oct 27, 2022