Commit b5eb5b7

aaaandychen authored and chenzhenyang committed
fix problems in refactoring train engine api
Signed-off-by: chenzhenyang <[email protected]>
1 parent 4fcf930 commit b5eb5b7

4 files changed, +400 -295 lines changed


areal/api/engine_api.py

Lines changed: 41 additions & 61 deletions
@@ -9,7 +9,6 @@
 import torch
 import torch.distributed as dist
 from torchdata.stateful_dataloader import StatefulDataLoader
-from megatron.core import parallel_state as mpu
 
 from areal.api.alloc_mode import ParallelStrategy
 from areal.api.io_struct import (
@@ -20,18 +19,24 @@
     SaveLoadMeta,
     WeightUpdateMeta,
 )
-from areal.utils.data import pack_tensor_dict, MicroBatchList, unpack_sequence, reorder_list, \
-    pad_and_stack_tensors_along_first_dim
-from areal.utils.ulysses import set_ulysses_sequence_parallel_group
+from areal.utils.data import (
+    MicroBatchList,
+    pack_tensor_dict,
+    pad_and_stack_tensors_along_first_dim,
+    reorder_list,
+    unpack_sequence,
+)
 
 if TYPE_CHECKING:
     from areal.api.workflow_api import RolloutWorkflow
 
+
 @dataclass
 class ForwardBackwardOutputs:
     mb_outputs: list[torch.Tensor] | None
     losses: list[torch.Tensor] | None
 
+
 class TrainEngine(abc.ABC):
     def __init__(self):
         self.is_offload = None
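
Note: `ForwardBackwardOutputs` is what the batch-level helpers later in this diff consume. A minimal sketch of how its two fields are used, based on the `eval_batch` and `forward_batch` bodies below (tensor values are made up for illustration):

    outputs = ForwardBackwardOutputs(
        mb_outputs=[torch.randn(4, 8), torch.randn(3, 8)],   # one output tensor per micro-batch
        losses=[torch.tensor(0.9), torch.tensor(1.1)],        # one loss per micro-batch
    )
    total_loss = torch.stack(outputs.losses).sum(dtype=torch.float32)  # as in eval_batch
    packed = torch.cat(outputs.mb_outputs, dim=0)                      # one way to aggregate mb_outputs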
@@ -218,33 +223,10 @@ def load(self, meta: SaveLoadMeta):
         """
         raise NotImplementedError()
 
-    def aggregate_result(
-        self,
-        result: torch.Tensor,
-    ) -> Any | None:
-        """Aggregate results across parallel ranks in distributed training.
-
-        In distributed settings (especially pipeline parallelism), results may only
-        exist on certain ranks (e.g., the last pipeline stage). This method handles
-        broadcasting or aggregating results to make them available on all ranks.
-
-        Parameters
-        ----------
-        result : torch.Tensor
-            The result tensor to aggregate.
-
-        Returns
-        -------
-        Any or None
-            The aggregated result tensor, broadcasted to all ranks if necessary.
-            Returns None if no result is available on any rank.
-        """
-        raise NotImplementedError()
-
-    def split_micro_batch(
+    def _split_micro_batch(
         self,
         input_: dict[str, Any],
-        loss_weight_fn: Callable[[dict[str, Any]], torch.Tensor]| None = None,
+        loss_weight_fn: Callable[[dict[str, Any]], torch.Tensor] | None = None,
     ) -> tuple[Iterable[dict[str, torch.Tensor]], MicroBatchList]:
         """Split input batch into micro-batches for gradient accumulation.
 
@@ -258,7 +240,7 @@ def split_micro_batch(
         input_ : dict[str, Any]
             The input batch dictionary.
         loss_weight_fn : Callable[[dict[str, Any]], torch.Tensor], optional
-            A function to compute the loss weight for each micro-batch.
+            A function to compute the loss weight for each micro-batch.
 
         Returns
         -------
272254
def _forward_compute_mb(
273255
self,
274256
mb_input: dict[str, Any],
275-
loss_fn: Callable[[torch.Tensor, dict[str, Any]], torch.Tensor],
257+
post_process_fn: Callable[[torch.Tensor, dict[str, Any]], Any],
276258
loss_weight_fn: Callable[[dict[str, Any]], torch.Tensor],
277259
**kwargs,
278260
) -> tuple[torch.Tensor, Callable[[torch.Tensor], tuple[torch.Tensor, dict]]]:
@@ -283,34 +265,33 @@ def _forward_compute_mb(
         closure is used by the training framework to compute loss and perform
         backward pass during gradient accumulation.
 
-        The exact structure of `mb_input` and the returned loss function depends
-        on the engine implementation, but the interface contract remains the same.
-
         Parameters
         ----------
         mb_input : dict[str, Any]
             A dictionary containing the micro-batch input data. The exact structure
             depends on the engine implementation, but typically includes packed/padded
             tensors, padding metadata, and total loss weight.
-        loss_fn : Callable[[torch.Tensor, dict[str, Any]], torch.Tensor]
+        loss_fn : Callable[[torch.Tensor, dict[str, Any]], torch.Tensor], optional
             A function that computes the normalized loss given model output and
-            input data.
-        loss_weight_fn : Callable[[dict[str, Any]], torch.Tensor]
+            input data. Optional for pure forward passes.
+        loss_weight_fn : Callable[[dict[str, Any]], torch.Tensor], optional
             A function that computes the weight for this micro-batch, typically
             the number of tokens. Used for proper loss scaling across micro-batches.
+            Optional when `loss_fn` is not provided.
         **kwargs
             Additional keyword arguments that may be used by specific implementations,
             such as model reference for pipeline parallel, batch type, post-processing
             hooks, etc.
 
         Returns
         -------
-        tuple[torch.Tensor, Callable[[torch.Tensor], tuple[torch.Tensor, dict]]]
+        tuple[torch.Tensor, Callable[[torch.Tensor], tuple[torch.Tensor, dict]] | None]
             A tuple containing:
             - The model output tensor (logits) from the forward pass
             - A callable loss function that takes the output tensor and returns:
                 - A loss tensor (scaled appropriately for gradient accumulation)
                 - A dictionary with additional data
+            If `loss_fn` is None (e.g., pure forward pass), the callable can be None.
         """
         raise NotImplementedError()
 
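Note: to make the closure contract concrete, here is a hedged sketch of what an implementation might return. The `padded_input` and `total_loss_weight` keys are illustrative assumptions, and the sketch follows the docstring's `loss_fn` naming rather than the renamed `post_process_fn` parameter in the signature above:

    def _forward_compute_mb(self, mb_input, loss_fn=None, loss_weight_fn=None, **kwargs):
        output = self.model(**mb_input["padded_input"])   # forward pass (assumed key)
        if loss_fn is None:
            return output, None                           # pure forward pass

        def loss_closure(out):
            loss = loss_fn(out, mb_input)                 # normalized loss for this micro-batch
            scale = loss_weight_fn(mb_input) / mb_input["total_loss_weight"]
            return loss * scale, {}                       # scaled loss plus extra data

        return output, loss_closure
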
@@ -376,8 +357,8 @@ def step_lr_scheduler(self):
     def forward_backward_batch(
         self,
         data_iterator: Iterable[dict[str, torch.Tensor]],
-        loss_fn: Callable[[torch.Tensor, dict[str, Any]], torch.Tensor]| None = None,
-        loss_weight_fn: Callable[[dict[str, Any]], torch.Tensor]| None = None,
+        loss_fn: Callable[[torch.Tensor, dict[str, Any]], torch.Tensor] | None = None,
+        loss_weight_fn: Callable[[dict[str, Any]], torch.Tensor] | None = None,
         output_post_hook: Callable[[torch.Tensor, dict[str, Any]], Any] | None = None,
         return_outputs: bool = False,
         forward_only: bool = False,
@@ -392,9 +373,9 @@ def forward_backward_batch(
         Parameters
         ----------
         data_iterator : Iterable[dict[str, torch.Tensor]]
-            An iterable of micro-batch dictionaries, typically produced by
-            `split_micro_batch`. Each dictionary contains micro-batch input data
-            and metadata (padding info, loss weights, etc.).
+            `data_iterator` is typically produced by converting a `MicroBatchList` into
+            an iterator (e.g., via `create_mb_iterator`), yielding per-micro-batch
+            payloads and any metadata computed during splitting for downstream use.
         loss_fn : Callable[[torch.Tensor, dict[str, Any]], torch.Tensor], optional
             A function that computes the normalized loss given model output and
             input data. Required when `forward_only=False` or when `return_outputs=False`
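
Note: the three call shapes used later in this diff show how the signature above is exercised; micro-batch iterators come from `_split_micro_batch`, and in practice each call needs its own freshly produced iterator:

    mb_iter, mb_list = engine._split_micro_batch(input_, loss_weight_fn)

    engine.forward_backward_batch(mb_iter, loss_fn, loss_weight_fn)                     # training (train_batch)
    engine.forward_backward_batch(mb_iter, loss_fn, loss_weight_fn, forward_only=True)  # evaluation (eval_batch)
    engine.forward_backward_batch(mb_iter, forward_only=True, return_outputs=True,
                                  output_post_hook=post_hook)                           # inference (forward_batch)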
@@ -457,11 +438,9 @@ def train_batch(
             Scalar statistics after training, e.g., the current learning rate,
             gradient norm, etc.
         """
-        if self.is_offload:
-            self.onload()
         self.optimizer_zero_grad()
-        _data_iterator, mb_list = self.split_micro_batch(input_,loss_weight_fn)
-        self.forward_backward_batch(_data_iterator,loss_fn,loss_weight_fn)
+        _data_iterator, _ = self._split_micro_batch(input_, loss_weight_fn)
+        self.forward_backward_batch(_data_iterator, loss_fn, loss_weight_fn)
         return self.optimizer_step()
 
     @torch.no_grad()
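
Note: a caller of the refactored `train_batch` might look like the sketch below. The loss computation and batch keys are illustrative assumptions, and the full `train_batch` signature is inferred from the body above:

    import torch.nn.functional as F

    def loss_fn(logits: torch.Tensor, batch: dict) -> torch.Tensor:
        # Illustrative token-level cross entropy; "labels" is an assumed key.
        return F.cross_entropy(logits.view(-1, logits.size(-1)), batch["labels"].view(-1))

    def loss_weight_fn(batch: dict) -> torch.Tensor:
        return batch["attention_mask"].sum()  # assumed key

    stats = engine.train_batch(input_, loss_fn, loss_weight_fn)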
@@ -497,10 +476,10 @@ def eval_batch(
             A scalar loss or None. The evaluation statistics should be aggregated
             with `stats_tracker`.
         """
-        if self.is_offload:
-            self.onload()
-        _data_iterator, mb_list = self.split_micro_batch(input_,loss_weight_fn)
-        output = self.forward_backward_batch(_data_iterator, loss_fn, loss_weight_fn, forward_only=True)
+        _data_iterator, _ = self._split_micro_batch(input_, loss_weight_fn)
+        output = self.forward_backward_batch(
+            _data_iterator, loss_fn, loss_weight_fn, forward_only=True
+        )
         loss = torch.stack(output.losses).sum(dtype=torch.float32)
         dist.all_reduce(loss, group=self.dp_group)
         return loss
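
Note: the new eval path sums the per-micro-batch losses on each rank and then all-reduces across the data-parallel group. A tiny numeric sketch (values made up):

    losses = [torch.tensor(0.9), torch.tensor(1.1)]         # per-micro-batch losses on this rank
    local = torch.stack(losses).sum(dtype=torch.float32)    # tensor(2.0)
    # dist.all_reduce(local, group=dp_group)                 # SUM over data-parallel ranks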
@@ -538,28 +517,26 @@ def forward_batch(
         Any or None
             The result produced by `post_hook` and `aggregate_fn`.
         """
-        if self.is_offload:
-            self.onload()
-
-        if self.parallel_helper.sp_size > 1:
-            set_ulysses_sequence_parallel_group(self.sp_group)
-
         cu_seqlens = pack_tensor_dict(input_)["cu_seqlens"]
 
         if output_seqlens is None:
             output_seqlens = (cu_seqlens[1:] - cu_seqlens[:-1]).cpu().numpy().tolist()
         assert output_seqlens is not None
 
-        _data_iterator, mb_list = self.split_micro_batch(input_)
+        _data_iterator, mb_list = self._split_micro_batch(input_)
 
-        result = self.forward_backward_batch(_data_iterator, forward_only=True,return_outputs=True,output_post_hook=post_hook)
+        result = self.forward_backward_batch(
+            _data_iterator,
+            forward_only=True,
+            return_outputs=True,
+            output_post_hook=post_hook,
+        )
 
         res = aggregate_fn(result.mb_outputs)
         output_seqlens = [output_seqlens[i] for i in mb_list.forward_indices]
         unpacked = unpack_sequence(res, lens=output_seqlens, dim=0)
         reordered = reorder_list(unpacked, mb_list.backward_indices)
-        result = pad_and_stack_tensors_along_first_dim(reordered)
-        return self.aggregate_result(result)
+        return pad_and_stack_tensors_along_first_dim(reordered)
 
     @torch.no_grad()
     def forward(
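
Note: with `aggregate_result` removed, `forward_batch` now returns the padded and stacked tensor directly. A hedged sketch of the hook arguments it wires into `forward_backward_batch`; the `forward_batch` call signature is inferred, and the log-prob post-hook is only an example:

    def post_hook(logits: torch.Tensor, mb_input: dict) -> torch.Tensor:
        return torch.log_softmax(logits.float(), dim=-1)     # example per-token output

    def aggregate_fn(mb_outputs: list[torch.Tensor]) -> torch.Tensor:
        return torch.cat(mb_outputs, dim=0)                  # concatenate packed micro-batch outputs

    result = engine.forward_batch(input_, post_hook=post_hook, aggregate_fn=aggregate_fn)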
@@ -592,6 +569,9 @@ def export_stats(self) -> dict[str, float]:
     def onload(self):
         raise NotImplementedError()
 
+    def offload(self):
+        raise NotImplementedError()
+
 
 class InferenceEngine(abc.ABC):
     def initialize(self, *args, **kwargs):
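
Note: with the implicit `if self.is_offload: self.onload()` checks removed from `train_batch`, `eval_batch`, and `forward_batch`, and `offload` now declared alongside `onload`, parameter placement presumably becomes the caller's responsibility. A minimal usage sketch, assuming an engine that supports offloading:

    engine.onload()    # move weights/optimizer state back to device memory
    stats = engine.train_batch(input_, loss_fn, loss_weight_fn)
    engine.offload()   # release device memory between steps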
