import torch
import torch.distributed as dist
from torchdata.stateful_dataloader import StatefulDataLoader
-from megatron.core import parallel_state as mpu

from areal.api.alloc_mode import ParallelStrategy
from areal.api.io_struct import (
    SaveLoadMeta,
    WeightUpdateMeta,
)
-from areal.utils.data import pack_tensor_dict, MicroBatchList, unpack_sequence, reorder_list, \
-    pad_and_stack_tensors_along_first_dim
-from areal.utils.ulysses import set_ulysses_sequence_parallel_group
+from areal.utils.data import (
+    MicroBatchList,
+    pack_tensor_dict,
+    pad_and_stack_tensors_along_first_dim,
+    reorder_list,
+    unpack_sequence,
+)

if TYPE_CHECKING:
    from areal.api.workflow_api import RolloutWorkflow

+
@dataclass
class ForwardBackwardOutputs:
    mb_outputs: list[torch.Tensor] | None
    losses: list[torch.Tensor] | None

+
class TrainEngine(abc.ABC):
    def __init__(self):
        self.is_offload = None
@@ -218,33 +223,10 @@ def load(self, meta: SaveLoadMeta):
        """
        raise NotImplementedError()

-    def aggregate_result(
-        self,
-        result: torch.Tensor,
-    ) -> Any | None:
-        """Aggregate results across parallel ranks in distributed training.
-
-        In distributed settings (especially pipeline parallelism), results may only
-        exist on certain ranks (e.g., the last pipeline stage). This method handles
-        broadcasting or aggregating results to make them available on all ranks.
-
-        Parameters
-        ----------
-        result : torch.Tensor
-            The result tensor to aggregate.
-
-        Returns
-        -------
-        Any or None
-            The aggregated result tensor, broadcasted to all ranks if necessary.
-            Returns None if no result is available on any rank.
-        """
-        raise NotImplementedError()
-
-    def split_micro_batch(
+    def _split_micro_batch(
        self,
        input_: dict[str, Any],
-        loss_weight_fn: Callable[[dict[str, Any]], torch.Tensor]| None = None,
+        loss_weight_fn: Callable[[dict[str, Any]], torch.Tensor] | None = None,
    ) -> tuple[Iterable[dict[str, torch.Tensor]], MicroBatchList]:
        """Split input batch into micro-batches for gradient accumulation.

@@ -258,7 +240,7 @@ def split_micro_batch(
        input_ : dict[str, Any]
            The input batch dictionary.
        loss_weight_fn : Callable[[dict[str, Any]], torch.Tensor], optional
-            A function to compute the loss weight for each micro-batch. 
+            A function to compute the loss weight for each micro-batch.

        Returns
        -------
@@ -272,7 +254,7 @@ def split_micro_batch(
    def _forward_compute_mb(
        self,
        mb_input: dict[str, Any],
-        loss_fn: Callable[[torch.Tensor, dict[str, Any]], torch.Tensor],
+        post_process_fn: Callable[[torch.Tensor, dict[str, Any]], Any],
        loss_weight_fn: Callable[[dict[str, Any]], torch.Tensor],
        **kwargs,
    ) -> tuple[torch.Tensor, Callable[[torch.Tensor], tuple[torch.Tensor, dict]]]:
@@ -283,34 +265,33 @@ def _forward_compute_mb(
        closure is used by the training framework to compute loss and perform
        backward pass during gradient accumulation.

-        The exact structure of `mb_input` and the returned loss function depends
-        on the engine implementation, but the interface contract remains the same.
-
        Parameters
        ----------
        mb_input : dict[str, Any]
            A dictionary containing the micro-batch input data. The exact structure
            depends on the engine implementation, but typically includes packed/padded
            tensors, padding metadata, and total loss weight.
-        loss_fn : Callable[[torch.Tensor, dict[str, Any]], torch.Tensor]
+        loss_fn : Callable[[torch.Tensor, dict[str, Any]], torch.Tensor], optional
            A function that computes the normalized loss given model output and
-            input data.
-        loss_weight_fn : Callable[[dict[str, Any]], torch.Tensor]
+            input data. Optional for pure forward passes.
+        loss_weight_fn : Callable[[dict[str, Any]], torch.Tensor], optional
            A function that computes the weight for this micro-batch, typically
            the number of tokens. Used for proper loss scaling across micro-batches.
+            Optional when `loss_fn` is not provided.
        **kwargs
            Additional keyword arguments that may be used by specific implementations,
            such as model reference for pipeline parallel, batch type, post-processing
            hooks, etc.

        Returns
        -------
-        tuple[torch.Tensor, Callable[[torch.Tensor], tuple[torch.Tensor, dict]]]
+        tuple[torch.Tensor, Callable[[torch.Tensor], tuple[torch.Tensor, dict]] | None]
            A tuple containing:
            - The model output tensor (logits) from the forward pass
            - A callable loss function that takes the output tensor and returns:
              - A loss tensor (scaled appropriately for gradient accumulation)
              - A dictionary with additional data
+            If `loss_fn` is None (e.g., pure forward pass), the callable can be None.
        """
        raise NotImplementedError()

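# --- Illustration (not part of the patch) ----------------------------------
# A minimal sketch of how the private helpers above fit together, assuming a
# hypothetical `engine` instance plus `batch`, `loss_fn`, and `loss_weight_fn`
# supplied by the caller. `_split_micro_batch` yields micro-batch payloads and
# split metadata; `_forward_compute_mb` returns the logits plus a closure that
# scales the loss for gradient accumulation (or None for pure forward passes).
def _accumulate_gradients_sketch(engine, batch, loss_fn, loss_weight_fn):
    data_iter, _mb_list = engine._split_micro_batch(batch, loss_weight_fn)
    losses = []
    for mb_input in data_iter:
        output, loss_closure = engine._forward_compute_mb(
            mb_input, loss_fn, loss_weight_fn
        )
        if loss_closure is not None:
            loss, _extra = loss_closure(output)
            loss.backward()  # gradients accumulate across micro-batches
            losses.append(loss.detach())
    return losses
# ----------------------------------------------------------------------------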
@@ -376,8 +357,8 @@ def step_lr_scheduler(self):
    def forward_backward_batch(
        self,
        data_iterator: Iterable[dict[str, torch.Tensor]],
-        loss_fn: Callable[[torch.Tensor, dict[str, Any]], torch.Tensor]| None = None,
-        loss_weight_fn: Callable[[dict[str, Any]], torch.Tensor]| None = None,
+        loss_fn: Callable[[torch.Tensor, dict[str, Any]], torch.Tensor] | None = None,
+        loss_weight_fn: Callable[[dict[str, Any]], torch.Tensor] | None = None,
        output_post_hook: Callable[[torch.Tensor, dict[str, Any]], Any] | None = None,
        return_outputs: bool = False,
        forward_only: bool = False,
@@ -392,9 +373,9 @@ def forward_backward_batch(
        Parameters
        ----------
        data_iterator : Iterable[dict[str, torch.Tensor]]
-            An iterable of micro-batch dictionaries, typically produced by
-            `split_micro_batch`. Each dictionary contains micro-batch input data
-            and metadata (padding info, loss weights, etc.).
+            `data_iterator` is typically produced by converting a `MicroBatchList` into
+            an iterator (e.g., via `create_mb_iterator`), yielding per-micro-batch
+            payloads and any metadata computed during splitting for downstream use.
        loss_fn : Callable[[torch.Tensor, dict[str, Any]], torch.Tensor], optional
            A function that computes the normalized loss given model output and
            input data. Required when `forward_only=False` or when `return_outputs=False`
@@ -457,11 +438,9 @@ def train_batch(
            Scalar statistics after training, e.g., the current learning rate,
            gradient norm, etc.
        """
-        if self.is_offload:
-            self.onload()
        self.optimizer_zero_grad()
-        _data_iterator, mb_list = self.split_micro_batch(input_,loss_weight_fn)
-        self.forward_backward_batch(_data_iterator,loss_fn,loss_weight_fn)
+        _data_iterator, _ = self._split_micro_batch(input_, loss_weight_fn)
+        self.forward_backward_batch(_data_iterator, loss_fn, loss_weight_fn)
        return self.optimizer_step()

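# --- Illustration (not part of the patch) ----------------------------------
# A possible caller of train_batch after this change. The batch keys
# ("input_ids", "loss_mask") and the token-level cross-entropy are assumptions
# used only to show the expected shapes of `loss_fn` (normalized loss from
# logits and input data) and `loss_weight_fn` (per-micro-batch token count).
import torch.nn.functional as F

def example_loss_fn(logits: torch.Tensor, data: dict) -> torch.Tensor:
    labels = data["input_ids"].view(-1)        # hypothetical key
    mask = data["loss_mask"].view(-1).float()  # hypothetical key
    ce = F.cross_entropy(logits.view(-1, logits.size(-1)), labels, reduction="none")
    return (ce * mask).sum() / mask.sum().clamp(min=1.0)

def example_loss_weight_fn(data: dict) -> torch.Tensor:
    return data["loss_mask"].sum()  # number of supervised tokens

# stats = engine.train_batch(batch, example_loss_fn, example_loss_weight_fn)
# ----------------------------------------------------------------------------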
    @torch.no_grad()
@@ -497,10 +476,10 @@ def eval_batch(
            A scalar loss or None. The evaluation statistics should be aggregated
            with `stats_tracker`.
        """
-        if self.is_offload:
-            self.onload()
-        _data_iterator, mb_list = self.split_micro_batch(input_, loss_weight_fn)
-        output = self.forward_backward_batch(_data_iterator, loss_fn, loss_weight_fn, forward_only=True)
+        _data_iterator, _ = self._split_micro_batch(input_, loss_weight_fn)
+        output = self.forward_backward_batch(
+            _data_iterator, loss_fn, loss_weight_fn, forward_only=True
+        )
        loss = torch.stack(output.losses).sum(dtype=torch.float32)
        dist.all_reduce(loss, group=self.dp_group)
        return loss
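# --- Illustration (not part of the patch) ----------------------------------
# eval_batch returns the per-micro-batch losses summed and all-reduced over
# engine.dp_group, so every data-parallel rank sees the same scalar. A simple
# evaluation loop (argument order and the averaging policy are assumptions):
def example_eval_loop(engine, eval_dataloader, loss_fn, loss_weight_fn):
    batch_losses = []
    for batch in eval_dataloader:
        loss = engine.eval_batch(batch, loss_fn, loss_weight_fn)
        batch_losses.append(loss.item())
    return sum(batch_losses) / max(len(batch_losses), 1)
# ----------------------------------------------------------------------------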
@@ -538,28 +517,26 @@ def forward_batch(
        Any or None
            The result produced by `post_hook` and `aggregate_fn`.
        """
-        if self.is_offload:
-            self.onload()
-
-        if self.parallel_helper.sp_size > 1:
-            set_ulysses_sequence_parallel_group(self.sp_group)
-
        cu_seqlens = pack_tensor_dict(input_)["cu_seqlens"]

        if output_seqlens is None:
            output_seqlens = (cu_seqlens[1:] - cu_seqlens[:-1]).cpu().numpy().tolist()
        assert output_seqlens is not None

-        _data_iterator, mb_list = self.split_micro_batch(input_)
+        _data_iterator, mb_list = self._split_micro_batch(input_)

-        result = self.forward_backward_batch(_data_iterator, forward_only=True,return_outputs=True,output_post_hook=post_hook)
+        result = self.forward_backward_batch(
+            _data_iterator,
+            forward_only=True,
+            return_outputs=True,
+            output_post_hook=post_hook,
+        )

        res = aggregate_fn(result.mb_outputs)
        output_seqlens = [output_seqlens[i] for i in mb_list.forward_indices]
        unpacked = unpack_sequence(res, lens=output_seqlens, dim=0)
        reordered = reorder_list(unpacked, mb_list.backward_indices)
-        result = pad_and_stack_tensors_along_first_dim(reordered)
-        return self.aggregate_result(result)
+        return pad_and_stack_tensors_along_first_dim(reordered)

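# --- Illustration (not part of the patch) ----------------------------------
# A possible caller of forward_batch: `post_hook` maps per-micro-batch logits
# to per-token log-probs over the packed sequence, and `aggregate_fn`
# concatenates the packed outputs along dim 0 so they can be unpacked by
# sequence length and re-padded as done above. The "input_ids" key, the packed
# logits layout, and the gather target are assumptions for illustration.
def example_logprob_post_hook(logits: torch.Tensor, mb_input: dict) -> torch.Tensor:
    logprobs = torch.log_softmax(logits.float(), dim=-1)
    labels = mb_input["input_ids"].view(-1, 1)  # hypothetical key, packed layout
    return logprobs.gather(-1, labels).squeeze(-1)

# token_logprobs = engine.forward_batch(
#     batch,
#     post_hook=example_logprob_post_hook,
#     aggregate_fn=lambda outs: torch.cat(outs, dim=0),
# )
# ----------------------------------------------------------------------------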
    @torch.no_grad()
    def forward(
@@ -592,6 +569,9 @@ def export_stats(self) -> dict[str, float]:
    def onload(self):
        raise NotImplementedError()

+    def offload(self):
+        raise NotImplementedError()
+

class InferenceEngine(abc.ABC):
    def initialize(self, *args, **kwargs):