94 changes: 83 additions & 11 deletions nemo/collections/llm/gpt/data/utils.py
@@ -20,8 +20,10 @@
import os
import pickle
import re
import signal
import time
from functools import lru_cache, partial
from queue import Empty
from typing import Any, Callable, List, Optional, Type, Union

import numpy as np
@@ -81,6 +83,77 @@ def build_index_from_memdata(fn, newline_int):
    return midx


def safe_map(fn, iterable, workers=1, ctx="fork"):
    """
    Crash-resilient alternative to multiprocessing.Pool.map() that can handle
    worker process crashes (e.g. segfaults) gracefully without hanging the
    entire operation. Items whose worker crashed come back as None.
    """
    ctx = mp.get_context(ctx)
    input_queue = ctx.Queue()
    output_queue = ctx.Queue()
    indexed_inputs = list(enumerate(iterable))
    for job in indexed_inputs:
        input_queue.put(job)
    for _ in range(workers):
        input_queue.put(None)  # poison pill: one per worker signals shutdown

    # Note: worker_loop closes over fn and the queues, which relies on the
    # "fork" start method; a "spawn" context would need a picklable target.
    def worker_loop():
        while True:
            job = input_queue.get()
            if job is None:
                break
            i, item = job
            try:
                result = fn(item)
                output_queue.put((i, True, result, None))
            except Exception as e:
                output_queue.put((i, False, None, str(e)))

    processes = [ctx.Process(target=worker_loop) for _ in range(workers)]
    for p in processes:
        p.start()

    results = [None] * len(indexed_inputs)
    seen_indices = set()
    expected = len(indexed_inputs)
    received = 0

    # Collect whatever gets returned from live workers
    while received < expected:
        try:
            i, success, result, err = output_queue.get(timeout=0.5)
            seen_indices.add(i)
            results[i] = result if success else None
            if not success:
                logger.warning(f"Item {i} failed: {err}")
            received += 1
        except Empty:
            # If all workers are dead, the missing results will never
            # arrive, so stop waiting.
            if all(not p.is_alive() for p in processes):
                logger.error("All workers exited before completing all tasks.")
                break
            continue

    # Join and check for crashes (a negative exitcode means the process
    # was killed by a signal)
    for p in processes:
        p.join()
        if p.exitcode is not None and p.exitcode < 0:
            sig = -p.exitcode
            try:
                sig_name = signal.Signals(sig).name
            except Exception:
                sig_name = f"signal {sig}"
            logger.warning(f"PID {p.pid} died from {sig_name}")

    # Patch any missing results from crashed workers
    for i in range(len(results)):
        if i not in seen_indices:
            logger.warning(f"No result for item {i}, likely a worker crash")
            results[i] = None

    return results
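
Reviewer note (not part of this diff): a minimal sketch of the failure mode `safe_map` is built for. `crashy` is a hypothetical stand-in for a function whose native code segfaults on some inputs; `multiprocessing.Pool.map()` can hang indefinitely when a worker dies this way, while `safe_map` returns `None` for the lost item and logs the signal:

```python
import ctypes

def crashy(x):
    # Hypothetical function for illustration: item 2 dereferences a NULL
    # pointer, killing the worker process with SIGSEGV.
    if x == 2:
        ctypes.string_at(0)
    return x * x

results = safe_map(crashy, range(5), workers=2)
print(results)  # [0, 1, None, 9, 16]: the crashed item comes back as None,
                # and a "PID ... died from SIGSEGV" warning is logged
```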


class _TextMemMapDataset(Dataset):
"""
Allow per-line lazy access to multiple text files using numpy memmap.
@@ -547,17 +620,16 @@ def build_index_files(
logger.info(f"Processing {len(dataset_paths)} data files using {workers} workers")
# load all files into memmap
start_time = time.time()
ctx = mp.get_context("fork")
with ctx.Pool(workers) as p:
build_status = p.map(
partial(
_build_memmap_index_files,
newline_int,
build_index_fn,
index_mapping_dir=index_mapping_dir,
),
dataset_paths,
)
build_status = safe_map(
partial(
_build_memmap_index_files,
newline_int,
build_index_fn,
index_mapping_dir=index_mapping_dir,
),
dataset_paths,
workers=workers,
)

logger.info(
f"Time building {sum(build_status)} / {len(build_status)} "