Merged · Changes from 4 commits
tools/preprocess_dataset_mpi.py: 128 changes (55 additions, 73 deletions)
@@ -176,6 +176,10 @@ def get_args():
if not args.split_sentences:
print("Bert tokenizer detected, are you sure you don't want to split sentences?")

args.level = "document"
if args.split_sentences:
args.level = "sentence"

# some default/dummy values for the tokenizer
args.rank = 0
args.make_vocab_size_divisible_by = 128
@@ -219,26 +223,21 @@ def barrier(args):
else:
dist.barrier()

def bcast(args, vals, root=0):
"""Broadcast list of vals from root to all ranks, returns newly allocated list"""
def scatterv_(args, invals, counts, outval, root=0):
"""Scatter int64 values from invals according to counts array, receive values in outval"""
assert len(counts) == args.numranks, f"Length of counts list {len(counts)} does not match number of ranks {args.numranks}"
assert outval.shape == (counts[args.rank],), f"Rank {args.rank}: output buffer is of shape {outval.shape}, expected {(counts[args.rank],)}"
Member:

Just to make sure, if this raises an exception, other ranks will keep going and deadlock? Or do we need the same pattern as everywhere else?

Contributor Author:

Ah, yes. Good point. That would deadlock if one rank failed the assert and the others did not.

Let's leave it this way for now in this PR. The preprocess_dataset_mpi.py script is the only one calling scatterv_ right now, and it should be fine. Longer term, I've got an idea for the DistData class to help with these collective assert checks.
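
For illustration, one possible shape for that kind of collective check (a hypothetical sketch, not something this PR adds; the helper name allassert is made up): every rank evaluates its local condition, the flags are combined with an all-reduce, and all ranks raise together so no rank blocks in a later collective.

from mpi4py import MPI

def allassert(comm, cond, msg):
    # Hypothetical collective assert: reduce each rank's pass/fail flag with MIN,
    # so if any rank fails, every rank sees 0 and raises instead of deadlocking.
    ok = comm.allreduce(1 if cond else 0, op=MPI.MIN)
    if ok == 0:
        raise RuntimeError(msg)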


if args.use_mpi:
Member:

You could actually assert that len(counts) == args.numranks?

Contributor Author:

Yes, thanks.

vals = args.mpi_comm.bcast(vals, root=root)
return vals
counts = np.array(counts)
displs = np.cumsum(counts) - counts
args.mpi_comm.Scatterv([invals, counts, displs, args.MPI.INT64_T], outval, root=root)
else:
# broadcast length of vals list
length = [len(vals)]
dist.broadcast_object_list(length, src=root)

# allocate a tensor of appropriate size
# initialize tensor with list values on root
scatterlist = []
if args.rank == root:
tvals = torch.tensor(vals, dtype=torch.int64)
else:
tvals = torch.zeros(length[0], dtype=torch.int64)

# broadcast tensor from root, and return as a new list
dist.broadcast(tvals, src=root)
return tvals.tolist()
scatterlist = list(torch.split(torch.from_numpy(invals), counts))
Member:

So leaving the tuple didn't work here?

Contributor Author:

Without the list() in there, it throws this error:

RuntimeError: Invalid function argument. Expected parameter `scatter_list` to be of type List[torch.Tensor].
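
For context, an illustrative snippet (not part of the diff): torch.split returns a tuple of tensors, while dist.scatter expects its scatter_list argument to be a Python list, hence the explicit conversion.

import torch

chunks = torch.split(torch.arange(10), [4, 3, 3])
print(type(chunks))         # <class 'tuple'>
scatterlist = list(chunks)  # dist.scatter requires List[torch.Tensor]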

outtensor = torch.from_numpy(outval)
dist.scatter(outtensor, scatterlist, src=root)
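
A side note on the MPI branch above (illustrative values only): the displacements are the exclusive prefix sum of the counts, which the cumsum-minus-counts expression computes directly.

import numpy as np

counts = np.array([4, 3, 3])
displs = np.cumsum(counts) - counts  # array([0, 4, 7]): each rank's starting offset in invals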

def all_sum_(args, vals):
"""Sums values in vals element-wise and updates vals with final result on all ranks"""
@@ -328,65 +327,51 @@ def load_dset(args):

return dset, err

def get_num_samples(args, dset_size):
"""Given a dataset size and optional count argument, return number of samples to process."""
num_samples = dset_size
if args.count is not None and args.count < dset_size:
num_samples = args.count
return num_samples

def select_sample_list(args, dset_size):
"""Given the total number of samples, select a list of sample index values"""
# determine total number of samples that we'll read
num_samples = get_num_samples(args, dset_size)

# create sample index list on rank 0,
# optionally shuffle the list,
# and optionally limit the sample count
idx = []
idxlist = None
if args.rank == 0:
# generate a list of all index values
idx = list(range(dset_size))
idxlist = np.arange(dset_size, dtype=np.int64)

# optionally shuffle
if args.shuffle:
if args.seed is not None:
random.seed(args.seed)
random.shuffle(idx)
# args.seed may be an int (to seed) or None (to not)
rng = np.random.default_rng(args.seed)
rng.shuffle(idxlist)

# optionally limit the sample count
if args.count is not None:
idx = idx[:args.count]
idxlist = idxlist[:num_samples]

# get a list of the number of elements each rank will hold
counts = get_proc_counts(num_samples, args.numranks)

# allocate space to hold its portion of the list
idx = np.zeros(counts[args.rank], np.int64)

# scatter sample index values from rank 0 to all procs
# based on distribution defined in counts list
scatterv_(args, idxlist, counts, idx, root=0)

# broadcast sample index values from rank 0 to all procs
idx = bcast(args, idx, root=0)
return idx

def get_start_end(num, rank, num_ranks):
"""Compute start and end index values to evenly divide num items
among ranks.

If num is not evenly divisible by num_ranks, ranks from
[0,remainder) will each be assigned one extra item.
Returns a (start, end) tuple, such that the calling rank
should take items in a list[start:end]

Parameters
----------
num : int
Number of items to be divided
rank : int
Rank of the calling process
num_ranks : int
Number of processes among which to divide items

Returns
-------
int
start index value
int
end index value
"""

num_per_rank = num // num_ranks
remainder = num % num_ranks
if rank < remainder:
start = (num_per_rank + 1) * rank;
end = start + (num_per_rank + 1)
else:
start = (num_per_rank + 1) * remainder + num_per_rank * (rank - remainder);
end = start + num_per_rank
return start, end
def get_proc_counts(num, num_ranks):
num_per_rank, remainder = divmod(num, num_ranks)
return [num_per_rank + 1 if rank < remainder else num_per_rank for rank in range(num_ranks)]
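
A quick illustration of the split that get_proc_counts produces (example values only): ranks below the remainder each receive one extra item.

print(get_proc_counts(10, 3))  # [4, 3, 3]: 10 // 3 = 3 per rank, first 10 % 3 = 1 rank gets one extra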

def get_filename(args, key, rank=None):
pathname = args.output_prefix
@@ -401,6 +386,9 @@ def get_filename(args, key, rank=None):
def rank_files_write(args, dset, idx, encoder):
tokenize_start = time.time()

# compute total number of samples we're processing
num_samples = get_num_samples(args, len(dset))

# we'll total up the number of docs, sentences, and bytes
# processed across all ranks
dset_stats = np.zeros(3, dtype=np.int64) # docs, sentences, bytes
@@ -424,15 +412,13 @@ def rank_files_write(args, dset, idx, encoder):
impl=args.dataset_impl,
dtype=best_fitting_dtype(args.vocab_size))

# divide index list evenly among ranks
idx_start, idx_end = get_start_end(len(idx), args.rank, args.numranks)

# each rank tokenizes its samples and writes its own file
progress_next = time.time() + float(args.log_interval)
for i in idx[idx_start:idx_end]:
for i in idx:
sample_id = int(i)
for key in args.columns:
# tokenize text for the given sample index
text = dset[i][key]
text = dset[sample_id][key]
Member:

I think you can even run dset[idx] if I'm not wrong, to be confirmed. If so I'd suggest something like

for row in dset[idx]:
     for key in args.columns:
          text = row[key]

Contributor Author:

Oh, that's useful to know. If you don't mind, I'd like to stick with the harder way for now, because I'm using this same code to support a JSON file and that particular code only supports an int for the argument in the __getitem__ call.
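
To illustrate the distinction (a hypothetical snippet, assuming dset is a Hugging Face Dataset; the JSON-backed reader mentioned above is not shown in this diff): batch access with a list of indices returns a columnar dict, while a reader whose __getitem__ only accepts an int has to be driven one sample at a time.

batch = dset[[0, 2, 5]]                   # {"text": ["...", "...", "..."]}, one list per column
rows = [dset[int(i)] for i in [0, 2, 5]]  # per-sample dicts via an int-only __getitem__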

doc, bytes_processed = encoder.encode_text(text)

# add tokenized sequence to our data file
@@ -451,11 +437,11 @@ def rank_files_write(args, dset, idx, encoder):
elapsed = current - tokenize_start
timestamp = time.strftime("%Y-%m-%dT%H:%M:%S")
docs = dset_stats[0] * args.numranks
percent = docs / len(idx) * 100.0
percent = docs / num_samples * 100.0
Member:

Isn't that just len(dset)? Unless you want to take count into account as well. I don't think it's worth adding a parameter to the signature of the method, since you can recover it from args.count and len(dset).

Seeing as it's already an approximation, you could say that percent = dset_stats[0] / len(idx) * 100

Contributor Author:

Right, this is to adjust for count being less than len(dset). I'll create a separate function to compute num_samples and drop it from the param list.

docrate = docs / elapsed if elapsed > 0.0 else 0.0
mbs = dset_stats[2] * args.numranks / elapsed / 1024 / 1024 if elapsed > 0.0 else 0.0
secs_left = int((len(idx) - docs) / docrate if docrate > 0.0 else 0.0)
print(f"{timestamp}: Processed (estimated) {docs} of {len(idx)} docs ({percent:0.2f}%),",
secs_left = int((num_samples - docs) / docrate if docrate > 0.0 else 0.0)
Member:

ditto

print(f"{timestamp}: Processed (estimated) {docs} of {num_samples} docs ({percent:0.2f}%),",
Member:

you could extrapolate also? Overall I feel it's a shame to add arguments for the sake of a logger. If you really want to keep it, I'm also okay with that.

f"{docrate:0.3f} docs/s, {mbs:0.3f} MB/s,",
f"{secs_left} secs left ...",
flush=True)
@@ -585,17 +571,13 @@ def main():
# optionally shuffle the list,
# and optionally limit the sample count
idx = select_sample_list(args, len(dset))

if nltk_available and args.split_sentences:
nltk.download("punkt", quiet=True)

encoder = Encoder(args)
args.vocab_size = encoder.tokenizer.vocab_size

args.level = "document"
if args.split_sentences:
args.level = "sentence"

# wait for all ranks before stopping timer
barrier(args)
startup_end = time.time()