159 changes: 82 additions & 77 deletions tools/preprocess_dataset_mpi.py
@@ -219,26 +219,21 @@ def barrier(args):
else:
dist.barrier()

def bcast(args, vals, root=0):
"""Broadcast list of vals from root to all ranks, returns newly allocated list"""
def scatterv_(args, invals, counts, outval, root=0):
"""Scatter values from invals according to counts array, receive values in outval"""
Member: You should mention that this supports only int64 for now.

Contributor Author: Added a note to the docstring about that.
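A minimal sketch of the kind of docstring note being added (the exact wording in the PR may differ):

def scatterv_(args, invals, counts, outval, root=0):
    """Scatter values from invals according to counts array, receive values in outval.

    Note: only int64 values are supported for now, since the MPI path uses
    MPI.INT64_T and the torch path allocates int64 tensors.
    """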

if args.use_mpi:
Member: You could actually assert that len(counts) == args.numranks?

Contributor Author: Yes, thanks.
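A minimal sketch of the suggested assertion (assuming counts is a plain Python list):

assert len(counts) == args.numranks, \
    f"expected {args.numranks} counts (one per rank), got {len(counts)}"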

vals = args.mpi_comm.bcast(vals, root=root)
return vals
displ = [sum(counts[:rank]) for rank in range(args.numranks)]
Member: Since displ is converted to a numpy array anyway, you might as well use cumsum so the computation is linear. You'd convert counts to an array beforehand too.

Contributor Author: Thanks. I made that change.
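A sketch of the cumsum-based version being suggested (variable names are illustrative; the merged code may differ):

counts = np.array(counts, dtype=np.int64)
# exclusive prefix sum: displ[r] = sum(counts[:r]), computed in linear time
displ = np.concatenate(([0], np.cumsum(counts[:-1])))
args.mpi_comm.Scatterv([invals, counts, displ, args.MPI.INT64_T], outval, root=root)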

args.mpi_comm.Scatterv([invals, np.array(counts), np.array(displ), args.MPI.INT64_T], outval, root=root)
else:
# broadcast length of vals list
length = [len(vals)]
dist.broadcast_object_list(length, src=root)

# allocate a tensor of appropriate size
# initialize tensor with list values on root
tensors = []
Member: Suggested change:
- tensors = []
+ chunks = []

if args.rank == root:
tvals = torch.tensor(vals, dtype=torch.int64)
else:
tvals = torch.zeros(length[0], dtype=torch.int64)
for rank in range(args.numranks):
start = sum(counts[:rank])
end = start + counts[rank]
tensors.append(torch.tensor(invals[start:end]))
Member: If you convert it to a torch tensor, you can use cumsum to get end. Or better, there seems to be a built-in function for what you want: https://pytorch.org/docs/stable/generated/torch.split.html

Contributor Author: Cool. Thanks for the tip on torch.split. That looks to be perfect.
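A sketch of what the torch.split variant might look like (assuming invals can be converted to a 1-D int64 tensor; names are illustrative):

invals_t = torch.tensor(invals, dtype=torch.int64)
# torch.split accepts a list of chunk sizes and returns one chunk per rank
chunks = list(torch.split(invals_t, counts))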


# broadcast tensor from root, and return as a new list
dist.broadcast(tvals, src=root)
return tvals.tolist()
tensor = torch.from_numpy(outval)
Member: Nit: tensor doesn't express much.

Suggested change:
- tensor = torch.from_numpy(outval)
+ out = torch.from_numpy(outval)

Member: Isn't there an issue where the tensor sizes no longer match? Could you add a check that outval has the same size as the corresponding count?

Contributor Author: Yes, that name is not very descriptive; I'll rename it. And good idea to check the size of the output array.
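A sketch of the rename plus the size check being requested (illustrative only):

out = torch.from_numpy(outval)
# the receive buffer must match this rank's share of the scattered data
assert out.numel() == counts[args.rank], \
    f"outval holds {out.numel()} elements, expected {counts[args.rank]}"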

dist.scatter(tensor, tensors, src=root)

def all_sum_(args, vals):
"""Sums values in vals element-wise and updates vals with final result on all ranks"""
@@ -330,63 +325,75 @@ def load_dset(args):

def select_sample_list(args, dset_size):
"""Given the total number of samples, select a list of sample index values"""
# determine total number of samples that we'll read
num_samples = dset_size
if args.count is not None and args.count < dset_size:
num_samples = args.count

# create sample index list on rank 0,
# optionally shuffle the list,
# and optionally limit the sample count
idx = []
idxlist = None
if args.rank == 0:
# generate a list of all index values
idx = list(range(dset_size))
idxlist = np.arange(dset_size, dtype=np.int64)

# optionally shuffle
if args.shuffle:
if args.seed is not None:
random.seed(args.seed)
random.shuffle(idx)
# args.seed may be an int (to seed) or None (to not)
rng = np.random.default_rng(args.seed)
rng.shuffle(idxlist)

# optionally limit the sample count
if args.count is not None:
idx = idx[:args.count]

# broadcast sample index values from rank 0 to all procs
idx = bcast(args, idx, root=0)
return idx

def get_start_end(num, rank, num_ranks):
"""Compute start and end index values to evenly divide num items
among ranks.

If num is not evenly divisible by num_ranks, ranks from
[0,remainder) will each be assigned one extra item.
Returns a (start, end) tuple, such that the calling rank
should take items in a list[start:end]

Parameters
----------
num : int
Number of items to be divided
rank : int
Rank of the calling process
num_ranks : int
Number of processes among which to divide items

Returns
-------
int
start index value
int
end index value
"""

num_per_rank = num // num_ranks
remainder = num % num_ranks
if rank < remainder:
start = (num_per_rank + 1) * rank;
end = start + (num_per_rank + 1)
idxlist = idxlist[:args.count]
Member: Suggested change:
- idxlist = idxlist[:args.count]
+ idxlist = idxlist[:num_samples]

Contributor Author: Got it. Thanks.


# get a list of the number of elements each rank will hold
counts = get_proc_counts(num_samples, args.numranks)

# allocate space to hold its portion of the list
idx = np.zeros(counts[args.rank], np.int64)

# scatter index list if small enough
scatter_limit = 20000000
Member: Why do we need this? Also, if you don't mind, can you use 20_000_000? It's easier to read IMO.

Contributor Author: This will go away; that 20 million was an arbitrary number I picked out of the air anyway. Thanks for the tip on the underscores for readability. I'll do that if this has to come back.

if num_samples < scatter_limit:
# scatter sample index values from rank 0 to all procs
# based on distribution defined in counts list
scatterv_(args, idxlist, counts, idx, root=0)
else:
start = (num_per_rank + 1) * remainder + num_per_rank * (rank - remainder);
end = start + num_per_rank
return start, end
# The index list is too big to send to every process.
Member: Is that really a concern, that we can't load all indices in memory? At 8 bytes per int64, 20_000_000 elements is around 160 MB, which should be more than okay.

Contributor Author: I'm more worried about the communication method being used. We can revisit this if it happens to show up.

# Write it to a shared file to be read by other ranks instead.
indexlistfile = f"{args.output_prefix}_{args.level}.sampleidx"
if args.rank == 0:
with open(indexlistfile, "wb") as f:
f.write(idxlist.tobytes(order='C'))

# wait for rank 0 to write the file
barrier(args)

# All ranks read their respective portion
idx_count = counts[args.rank]
if idx_count > 0:
with open(indexlistfile, "rb") as f:
idx_start = sum(counts[:args.rank])
Member: We should compute those offsets only once and dispatch to every rank.
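A sketch of computing those offsets once up front instead of summing counts at read time (illustrative; the reviewer may have a broadcast from rank 0 in mind instead):

counts_arr = np.array(counts, dtype=np.int64)
# offsets[r] is the starting element index for rank r
offsets = np.concatenate(([0], np.cumsum(counts_arr[:-1])))
idx_start = int(offsets[args.rank])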

offset = idx_start * 8
Member: Add a comment noting that this factor of 8 comes from the int64 element size.

f.seek(offset)
f.readinto(idx)

# wait for all to read
barrier(args)

# delete the temporary file
if args.rank == 0:
os.remove(indexlistfile)

barrier(args)
Member: I don't think this last barrier is needed; you could just let the other ranks move on to the next step without worrying, no?


return num_samples, idx

def get_proc_counts(num, num_ranks):
num_per_rank, remainder = divmod(num, num_ranks)
return [num_per_rank + 1 if rank < remainder else num_per_rank for rank in range(num_ranks)]
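For example, the remainder items go to the lowest-numbered ranks:

>>> get_proc_counts(10, 4)
[3, 3, 2, 2]
>>> get_proc_counts(7, 3)
[3, 2, 2]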

def get_filename(args, key, rank=None):
pathname = args.output_prefix
@@ -398,7 +405,7 @@ def get_filename(args, key, rank=None):

return filename

def rank_files_write(args, dset, idx, encoder):
def rank_files_write(args, dset, num_samples, idx, encoder):
tokenize_start = time.time()

# we'll total up the number of docs, sentences, and bytes
@@ -424,15 +431,13 @@ def rank_files_write(args, dset, idx, encoder):
impl=args.dataset_impl,
dtype=best_fitting_dtype(args.vocab_size))

# divide index list evenly among ranks
idx_start, idx_end = get_start_end(len(idx), args.rank, args.numranks)

# each rank tokenizes its samples and writes its own file
progress_next = time.time() + float(args.log_interval)
for i in idx[idx_start:idx_end]:
for i in idx:
sample_id = int(i)
for key in args.columns:
# tokenize text for the given sample index
text = dset[i][key]
text = dset[sample_id][key]
Member: I think you can even run dset[idx] if I'm not wrong, to be confirmed. If so, I'd suggest something like:

for row in dset[idx]:
    for key in args.columns:
        text = row[key]

Contributor Author: Oh, that's useful to know. If you don't mind, I'd like to stick with the harder way for now, because I'm using this same code to support a JSON file, and that particular code only supports an int for the argument in the __get__ call.

doc, bytes_processed = encoder.encode_text(text)

# add tokenized sequence to our data file
@@ -451,11 +456,11 @@ def rank_files_write(args, dset, idx, encoder):
elapsed = current - tokenize_start
timestamp = time.strftime("%Y-%m-%dT%H:%M:%S")
docs = dset_stats[0] * args.numranks
percent = docs / len(idx) * 100.0
percent = docs / num_samples * 100.0
Member: Isn't that just len(dset)? Unless you want to take count into account as well. I don't think it's worth adding a parameter to the signature of the method, since you can recover it from args.count and len(dset). Seeing as it's already an approximation, you could say that percent = dset_stats[0] / len(idx) * 100.

Contributor Author: Right, this is to adjust for count being less than len(dset). I'll create a separate function to compute num_samples and drop it from the param list.
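A sketch of the separate helper the author describes (hypothetical name; it mirrors the logic at the top of select_sample_list):

def get_num_samples(args, dset_size):
    # use the full dataset unless args.count requests fewer samples
    if args.count is not None and args.count < dset_size:
        return args.count
    return dset_size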

docrate = docs / elapsed if elapsed > 0.0 else 0.0
mbs = dset_stats[2] * args.numranks / elapsed / 1024 / 1024 if elapsed > 0.0 else 0.0
secs_left = int((len(idx) - docs) / docrate if docrate > 0.0 else 0.0)
print(f"{timestamp}: Processed (estimated) {docs} of {len(idx)} docs ({percent:0.2f}%),",
secs_left = int((num_samples - docs) / docrate if docrate > 0.0 else 0.0)
Member: Ditto.

print(f"{timestamp}: Processed (estimated) {docs} of {num_samples} docs ({percent:0.2f}%),",
Member: You could extrapolate here as well? Overall, I feel it's a shame to add arguments just for the sake of a logger. If you really want to keep it, I'm also okay with that.

f"{docrate:0.3f} docs/s, {mbs:0.3f} MB/s,",
f"{secs_left} secs left ...",
flush=True)
@@ -581,29 +586,29 @@ def main():
print(dset)
print("Selecting features:", args.columns)

args.level = "document"
if args.split_sentences:
args.level = "sentence"
Member: Let's move this into get_args.

Contributor Author: Yep, done.


# create sample index list,
# optionally shuffle the list,
# and optionally limit the sample count
idx = select_sample_list(args, len(dset))
num_samples, idx = select_sample_list(args, len(dset))

if nltk_available and args.split_sentences:
nltk.download("punkt", quiet=True)

encoder = Encoder(args)
args.vocab_size = encoder.tokenizer.vocab_size

args.level = "document"
if args.split_sentences:
args.level = "sentence"

# wait for all ranks before stopping timer
barrier(args)
startup_end = time.time()
if args.rank == 0:
print("Seconds to startup:", startup_end - startup_start)

# have each rank write its file, returns False if any rank had a problem
success, err = rank_files_write(args, dset, idx, encoder)
success, err = rank_files_write(args, dset, num_samples, idx, encoder)
if not success:
if args.rank == 0:
# If any process fails, we skip the merge since the resulting file would be invalid.