9 changes: 5 additions & 4 deletions data_juicer/core/data/ray_dataset.py
@@ -237,7 +237,8 @@ def process_batch_arrow(table: pyarrow.Table):

try:
if op.use_ray_actor():
compute = get_compute_strategy(op.__class__, concurrency=op.num_proc)
# Use concurrency= directly for better GPU utilization
# (get_compute_strategy may limit parallelism)
self.data = self.data.map_batches(
op.__class__,
fn_args=None,
@@ -247,7 +248,7 @@ def process_batch_arrow(table: pyarrow.Table):
batch_size=batch_size,
num_cpus=op.num_cpus,
num_gpus=op.num_gpus,
compute=compute,
concurrency=op.num_proc,
batch_format="pyarrow",
runtime_env=op.runtime_env,
)
@@ -280,7 +281,7 @@ def process_batch_arrow(table: pyarrow.Table):
)
cached_columns.add(Fields.stats)
if op.use_ray_actor():
compute = get_compute_strategy(op.__class__, concurrency=op.num_proc)
# Use concurrency= directly for better GPU utilization
self.data = self.data.map_batches(
op.__class__,
fn_args=None,
@@ -290,7 +291,7 @@ def process_batch_arrow(table: pyarrow.Table):
batch_size=batch_size,
num_cpus=op.num_cpus,
num_gpus=op.num_gpus,
compute=compute,
concurrency=op.num_proc,
batch_format="pyarrow",
runtime_env=op.runtime_env,
)
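For orientation (not part of the diff): the change above drops the get_compute_strategy indirection and hands concurrency= straight to Ray Data's map_batches, which sizes the actor pool for a callable-class op. A minimal standalone sketch of that call shape, with an illustrative EmbedBatch class standing in for a Data-Juicer GPU op (the class and all numbers are made up; actually executing it would need a 4-GPU Ray cluster):

import pyarrow
import ray


class EmbedBatch:
    """Stand-in for a GPU actor op; the expensive model load happens once per actor."""

    def __init__(self):
        self.scale = 2  # placeholder for loading a real model

    def __call__(self, table: pyarrow.Table) -> pyarrow.Table:
        # transform the batch; here just scale the "id" column produced by range()
        return table.set_column(0, "id", pyarrow.compute.multiply(table["id"], self.scale))


ds = ray.data.range(1_000)
ds = ds.map_batches(
    EmbedBatch,
    batch_size=64,
    num_gpus=1,           # one GPU per actor
    concurrency=4,        # actor pool of 4, i.e. one actor per GPU
    batch_format="pyarrow",
)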
20 changes: 20 additions & 0 deletions data_juicer/core/executor/concurrency_scoping.py
@@ -0,0 +1,20 @@
"""Utility for scoping op concurrency when running partitions concurrently."""


def scope_op_concurrency(op, max_concurrent_partitions: int) -> int:
"""Returns the concurrency a single partition should use for this op.

When multiple partitions run concurrently, each partition should use a
fraction of the total GPU/actor resources to avoid over-subscription.

Args:
op: An operator instance with ``use_ray_actor()`` and ``num_proc``.
max_concurrent_partitions: How many partitions will run in parallel.

Returns:
The concurrency value the partition should pass through to
``map_batches``.
"""
if not op.use_ray_actor() or not op.num_proc or op.num_proc <= 0:
return op.num_proc # CPU ops or auto-mode unchanged
return max(1, op.num_proc // max_concurrent_partitions)
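A quick usage illustration of this helper (the op objects below are faked with SimpleNamespace for the sketch; real ops expose the same two attributes used here):

from types import SimpleNamespace

from data_juicer.core.executor.concurrency_scoping import scope_op_concurrency

# hypothetical GPU actor op configured with 8 actors in total
gpu_op = SimpleNamespace(use_ray_actor=lambda: True, num_proc=8)
# hypothetical CPU op that runs as plain Ray tasks
cpu_op = SimpleNamespace(use_ray_actor=lambda: False, num_proc=4)

print(scope_op_concurrency(gpu_op, max_concurrent_partitions=4))  # 2: the 8 actors are shared across 4 partitions
print(scope_op_concurrency(cpu_op, max_concurrent_partitions=4))  # 4: CPU ops are left unchanged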
271 changes: 255 additions & 16 deletions data_juicer/core/executor/ray_executor_partitioned.py
@@ -273,16 +273,61 @@ def _configure_partitioning(self):
logger.warning("Legacy num_partitions detected, overriding partition configuration")

self.partition_mode = mode
self.num_partitions = num_of_partitions
self.partition_size = partition_size
self.max_size_mb = max_size_mb

# Resolve max_concurrent_partitions.
# "auto" (default) → detect from Ray cluster GPU count, fall back to 1.
# Explicit int → use as-is.
raw_max_conc = ConfigAccessor.get(partition_cfg, "max_concurrent_partitions", "auto")
self.max_concurrent_partitions = self._resolve_max_concurrent(raw_max_conc)

# Ensure we have at least as many partitions as concurrent slots,
# otherwise some GPUs would sit idle.
if self.max_concurrent_partitions > num_of_partitions:
logger.info(
f"num_of_partitions ({num_of_partitions}) < "
f"max_concurrent_partitions ({self.max_concurrent_partitions}), "
f"raising num_of_partitions to {self.max_concurrent_partitions}"
)
num_of_partitions = self.max_concurrent_partitions

self.num_partitions = num_of_partitions

if mode == "manual":
logger.info(f"Manual partition mode: using {self.num_partitions} partitions")
else: # auto mode
logger.info(f"Auto partition mode: will determine optimal partitioning based on data characteristics")
logger.info(f"Fallback partition size: {self.partition_size} samples, max {self.max_size_mb} MB")

if self.max_concurrent_partitions > 1:
logger.info(
f"Concurrent partition processing enabled: "
f"max_concurrent_partitions={self.max_concurrent_partitions}"
)

@staticmethod
def _resolve_max_concurrent(raw_value) -> int:
"""Resolve max_concurrent_partitions from config value.

* ``"auto"`` → number of GPUs visible to Ray (falls back to 1).
* An explicit int is returned as-is (minimum 1).
"""
if isinstance(raw_value, str) and raw_value.lower() == "auto":
try:
num_gpus = int(ray.cluster_resources().get("GPU", 0))
except Exception as e:
logger.warning(f"Could not get GPU resources from Ray cluster, defaulting to 0. Error: {e}")
num_gpus = 0
if num_gpus > 1:
logger.info(
f"Auto-detected {num_gpus} GPUs in Ray cluster, " f"setting max_concurrent_partitions={num_gpus}"
)
return num_gpus
# No GPUs or single GPU → sequential
return 1
return max(1, int(raw_value))
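As a side note on the "auto" branch in _resolve_max_concurrent above (not part of the diff): ray.cluster_resources() returns a plain dict of resource totals, so the detection boils down to reading its "GPU" entry. A rough standalone sketch with illustrative values:

import ray

ray.init(ignore_reinit_error=True)
resources = ray.cluster_resources()       # e.g. {"CPU": 64.0, "GPU": 8.0, ...}
num_gpus = int(resources.get("GPU", 0))   # 8 on an 8-GPU cluster, 0 on a CPU-only one
max_concurrent_partitions = num_gpus if num_gpus > 1 else 1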

def _configure_auto_partitioning(self, dataset, ops):
"""Configure partitioning using the partition size optimizer for auto mode."""
try:
@@ -498,6 +543,10 @@ def _process_with_simple_partitioning(self, dataset: RayDataset, ops: List):
f"{partitioning_info.total_rows} total rows"
)

# Branch: concurrent vs sequential partition processing
if self.max_concurrent_partitions > 1:
return self._process_partitions_concurrent(partitions, ops, partitioning_info)

# Process each partition separately with checkpointing
logger.info("Processing partitions with checkpointing support...")
processed_partitions = []
@@ -541,6 +590,197 @@ def _process_with_simple_partitioning(self, dataset: RayDataset, ops: List):
# Return as RayDataset wrapper
return RayDataset(merged_dataset, cfg=self.cfg)

def _process_partitions_concurrent(self, partitions, ops, partitioning_info):
"""Process partitions concurrently as Ray remote tasks.

Each partition is submitted as a Ray remote task that independently
loads ops from config, scopes concurrency, and processes data with
its own checkpoint manager. Results are collected and unioned.
"""
max_conc = min(self.max_concurrent_partitions, len(partitions))
logger.info(f"Processing {len(partitions)} partitions concurrently " f"(max_concurrent_partitions={max_conc})")

# Serialisable values extracted from self (avoid serialising the executor)
cfg = self.cfg
ckpt_enabled = self.ckpt_manager.checkpoint_enabled
ckpt_strategy = self.ckpt_manager.checkpoint_strategy
ckpt_dir = self.ckpt_manager.ckpt_dir
ckpt_n_ops = getattr(self.ckpt_manager, "checkpoint_n_ops", 1)
ckpt_op_names = getattr(self.ckpt_manager, "checkpoint_op_names", [])
op_fusion_enabled = getattr(cfg, "op_fusion", False)

@ray.remote(num_cpus=0)
def _process_single_partition_task(
partition_data,
partition_id,
cfg,
max_concurrent_partitions,
ckpt_enabled,
ckpt_strategy,
ckpt_dir,
ckpt_n_ops,
ckpt_op_names,
op_fusion_enabled,
):
"""Ray remote task that processes one partition end-to-end."""
from loguru import logger as task_logger

from data_juicer.core.data.ray_dataset import RayDataset
from data_juicer.core.executor.concurrency_scoping import (
scope_op_concurrency,
)
from data_juicer.ops import load_ops
from data_juicer.ops.op_fusion import fuse_operators
from data_juicer.utils.ckpt_utils import RayCheckpointManager

task_logger.info(f"[Partition {partition_id}] Starting remote processing")

# Re-create ops from config to avoid serialisation issues
task_ops = load_ops(cfg.process)
if op_fusion_enabled:
task_ops = fuse_operators(task_ops)

# Scope concurrency and fix actor mode for each op.
# The remote task has no GPU, so use_cuda() returns False and
# ops default to task mode (model reloads per batch). Force
# actor mode for GPU ops so the model loads once per actor.
for op in task_ops:
if getattr(op, "num_gpus", 0) and op.num_gpus > 0:
op.ray_execution_mode = "actor"
op.num_proc = scope_op_concurrency(op, max_concurrent_partitions)

# Create local checkpoint manager
ckpt_manager = RayCheckpointManager(
ckpt_dir=ckpt_dir,
checkpoint_enabled=ckpt_enabled,
checkpoint_strategy=ckpt_strategy,
checkpoint_n_ops=ckpt_n_ops,
checkpoint_op_names=ckpt_op_names,
)

# Check for existing checkpoint
latest_checkpoint = ckpt_manager.find_latest_checkpoint(partition_id)

# If all ops are already checkpointed, load from checkpoint
if latest_checkpoint and latest_checkpoint[0] >= len(task_ops) - 1:
task_logger.info(f"[Partition {partition_id}] All ops checkpointed, " f"loading from checkpoint")
loaded = ckpt_manager.load_checkpoint(
latest_checkpoint[0],
latest_checkpoint[1],
partition_id,
cfg=cfg,
)
if loaded is not None:
return loaded.data.materialize()

# Determine resume point
start_op_idx = 0
partition_dataset = RayDataset(partition_data, cfg=cfg)

if latest_checkpoint:
loaded = ckpt_manager.load_checkpoint(
latest_checkpoint[0],
latest_checkpoint[1],
partition_id,
cfg=cfg,
)
if loaded is not None:
partition_dataset = loaded
start_op_idx = latest_checkpoint[0] + 1
task_logger.info(f"[Partition {partition_id}] Resuming from op " f"{start_op_idx}")

# Process ops one-by-one with checkpointing
remaining_ops = task_ops[start_op_idx:]
for rel_idx, op in enumerate(remaining_ops):
abs_idx = start_op_idx + rel_idx
task_logger.info(f"[Partition {partition_id}] Processing op {abs_idx}: " f"{op._name}")
partition_dataset = partition_dataset.process([op])

# Checkpoint if needed
if ckpt_manager.should_checkpoint(abs_idx, op._name):
partition_dataset.data = partition_dataset.data.materialize()
ckpt_manager.save_checkpoint(
partition_dataset.data,
abs_idx,
partition_id,
)

# Final materialize
partition_dataset.data = partition_dataset.data.materialize()
return partition_dataset.data

# Submit tasks (skip empty partitions)
futures = {}
for i, partition in enumerate(partitions):
# Skip empty partitions to avoid wasting GPU resources
try:
row_count = partition.count()
except Exception:
row_count = -1 # can't determine, submit anyway
if row_count == 0:
logger.info(f"Partition {i}: empty (0 rows), skipping")
continue

# Check if partition is fully checkpointed before submitting
latest_ckpt = self.ckpt_manager.find_latest_checkpoint(i)
if latest_ckpt and latest_ckpt[0] >= len(ops) - 1:
logger.info(f"Partition {i}: already fully checkpointed, " f"loading from checkpoint")
loaded = self.ckpt_manager.load_checkpoint(latest_ckpt[0], latest_ckpt[1], i, cfg=self.cfg)
if loaded is not None:
futures[i] = loaded.data.materialize()
continue

self._log_event(
event_type=EventType.PARTITION_START,
message=f"Starting concurrent processing of partition " f"{i + 1}/{len(partitions)}",
partition_id=i,
)
futures[i] = _process_single_partition_task.remote(
partition,
i,
cfg,
max_conc,
ckpt_enabled,
ckpt_strategy,
ckpt_dir,
ckpt_n_ops,
ckpt_op_names,
op_fusion_enabled,
)

# Collect results
processed_partitions = []
for i in sorted(futures.keys()):
result = futures[i]
if isinstance(result, ray.ObjectRef):
try:
result = ray.get(result)
logger.info(f"Partition {i}: completed successfully")
except Exception as e:
logger.error(f"Partition {i}: failed with error: {e}")
raise
processed_partitions.append(result)
self._log_event(
event_type=EventType.PARTITION_COMPLETE,
message=f"Completed concurrent processing of partition " f"{i + 1}/{len(partitions)}",
partition_id=i,
)

# Union results
logger.info("Merging concurrently processed partitions...")
if not processed_partitions:
logger.warning("All partitions were empty or skipped. Returning an empty dataset.")
return RayDataset(ray.data.from_items([]), cfg=self.cfg)

if len(processed_partitions) == 1:
merged_dataset = processed_partitions[0]
else:
merged_dataset = processed_partitions[0]
for partition in processed_partitions[1:]:
merged_dataset = merged_dataset.union(partition)

return RayDataset(merged_dataset, cfg=self.cfg)

def _process_with_convergence(self, dataset: RayDataset, ops: List, convergence_points: List[int]):
"""
Process dataset with convergence support for global operations.
@@ -954,7 +1194,14 @@ def _split_dataset_deterministic(self, dataset: RayDataset) -> tuple:
# Check for existing partitioning info (resumption case)
saved_info = self._load_partitioning_info()

# Split the dataset
# Split using the dataset's natural block structure. split()
# distributes existing blocks round-robin, so partitions inherit
# multiple blocks and Ray Data's streaming executor can pipeline
# stages within each partition. Avoid repartition() here — it
# adds a costly shuffle and may reduce block count (e.g. 96 source
# blocks repartitioned to 32 loses parallelism). If there are
# fewer blocks than partitions, some partitions will be empty —
# that's handled downstream (empty partitions are skipped).
logger.info(f"Splitting dataset into {self.num_partitions} partitions (deterministic mode)...")
partitions = dataset.data.split(self.num_partitions)
logger.info(f"Created {len(partitions)} partitions")
@@ -974,24 +1221,16 @@ def _split_dataset_deterministic(self, dataset: RayDataset) -> tuple:
self._clear_invalid_checkpoints()
saved_info = None

# Collect metadata for new partitions
logger.info("Collecting partition metadata for checkpoint validation...")
total_rows = sum(p.count() for p in partitions)
partition_metadata = []

for i, partition in enumerate(partitions):
meta = self._collect_partition_metadata(partition, i)
partition_metadata.append(meta)
logger.debug(f"Partition {i}: {meta.row_count} rows, hash={meta.first_row_hash[:8]}...")

# On first run, skip expensive metadata collection (count(), take())
# which triggers redundant pipeline executions on lazy datasets.
# Save only the partition count; full metadata is not needed until
# resume validation.
partitioning_info = PartitioningInfo(
num_partitions=self.num_partitions,
total_rows=total_rows,
partitions=partition_metadata,
total_rows=-1, # unknown until processing completes
partitions=[],
deterministic=True,
)

# Save partitioning info
self._save_partitioning_info(partitioning_info)

return partitions, partitioning_info