Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
132 changes: 126 additions & 6 deletions areal/api/engine_api.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from __future__ import annotations

import abc
from collections.abc import Callable
from collections.abc import Callable, Iterable
from concurrent.futures import Future
from typing import TYPE_CHECKING, Any

Expand All @@ -18,6 +18,9 @@
SaveLoadMeta,
WeightUpdateMeta,
)
from areal.utils.data import (
MicroBatchList,
)

if TYPE_CHECKING:
from areal.api.workflow_api import RolloutWorkflow
Expand Down Expand Up @@ -204,12 +207,111 @@ def load(self, meta: SaveLoadMeta):
"""
raise NotImplementedError()

@abc.abstractmethod
def _split_micro_batch(
self,
input_: dict[str, Any],
loss_weight_fn: Callable[[dict[str, Any]], torch.Tensor] | None = None,
) -> tuple[Iterable[dict[str, torch.Tensor]], MicroBatchList, torch.Tensor]:
"""Split input batch into micro-batches for gradient accumulation.

Parameters
----------
input_ : dict[str, Any]
The input batch dictionary.
loss_weight_fn : Callable[[dict[str, Any]], torch.Tensor], optional
A function to compute the loss weight for each micro-batch.

Returns
-------
tuple[Iterable[dict[str, torch.Tensor]], MicroBatchList]
An iterator over micro-batch dictionaries, the MicroBatchList iterator with metadata and total_loss_weight.
"""
raise NotImplementedError()

@abc.abstractmethod
def _forward_compute_mb(
self,
mb_input: tuple[Any, ...],
post_process_fn: Callable[[torch.Tensor, dict[str, Any]], Any],
**kwargs,
) -> tuple[torch.Tensor, Callable[[torch.Tensor], tuple[torch.Tensor, dict]]]:
"""Compute forward pass and prepare loss function closure for a single micro-batch.

Parameters
----------
mb_input : tuple[Any, ...]
A tuple containing the micro-batch input data.
post_process_fn : Callable[[torch.Tensor, dict[str, Any]], Any]
A function that processes the model output.
**kwargs
Additional keyword arguments for specific implementations.

Returns
-------
tuple[torch.Tensor, Callable[[torch.Tensor], tuple[torch.Tensor, dict]]]
The model output (logits) and a callable that computes loss and returns
(loss_tensor, result_dict) for gradient accumulation.
"""
raise NotImplementedError()

@abc.abstractmethod
def optimizer_zero_grad(self):
"""Zero out all gradients in the optimizer."""
raise NotImplementedError()

@abc.abstractmethod
def optimizer_step(self):
"""Perform a single optimization step.

Returns
-------
dict[str, float]
Training statistics containing ``update_successful``, ``grad_norm``, and ``lr``.
"""
raise NotImplementedError()

@abc.abstractmethod
def lr_scheduler_step(self):
"""Advance the learning rate scheduler by one step."""
raise NotImplementedError()

def step_lr_scheduler(self):
"""Step the learning rate scheduler.

Since PPO uses minibatch updates, this method should be called periodically
(e.g., once per PPO step). It is separated from train_batch to allow
for more flexible learning rate scheduling.
This is an alias for `lr_scheduler_step()`.
"""
return self.lr_scheduler_step()

@abc.abstractmethod
def forward_backward_batch(
self,
data_iterator: Iterable[dict[str, torch.Tensor]],
post_process: Callable[[torch.Tensor, dict], torch.Tensor] | None = None,
return_outputs: bool = False,
forward_only: bool = False,
):
"""Process micro-batches through forward and optionally backward pass.

Parameters
----------
data_iterator : Iterable[dict[str, torch.Tensor]]
An iterable that yields micro-batch dictionaries containing packed tensors.
post_process : Callable[[torch.Tensor, dict[str, Any]], Any], optional
A function that processes the model output.

return_outputs : bool, optional
If True, collect and return model outputs (logits) instead of losses.
Only used when ``forward_only=True``. By default False.
forward_only : bool, optional
If True, only perform forward pass (no backward pass or gradient computation).
If False, perform both forward and backward passes for training. By default False.

Returns
-------
ForwardBackwardOutputs
A dataclass containing ``mb_outputs`` (list of output tensors) when
``return_outputs=True``, or ``losses`` (list of loss tensors) otherwise.
"""
raise NotImplementedError()

Expand Down Expand Up @@ -285,7 +387,7 @@ def eval_batch(
raise NotImplementedError()

@torch.no_grad()
def forward(
def forward_batch(
self,
input_: dict[str, Any],
output_seqlens: list[int] | None = None,
Expand Down Expand Up @@ -313,7 +415,19 @@ def forward(
For actor (is_critic=False): logprobs tensor aggregated by `aggregate_fn`.
For critic (is_critic=True): values tensor aggregated by `aggregate_fn`.
"""
raise NotImplementedError()
raise NotImplementedError

@torch.no_grad()
def forward(
self,
input_: dict[str, Any],
output_seqlens: list[int] | None = None,
aggregate_fn: Callable[[list[Any]], Any] = torch.cat,
) -> Any | None:
"""
alias for forward_batch
"""
return self.forward_batch(input_, output_seqlens, aggregate_fn)

def export_stats(self) -> dict[str, float]:
"""Export the statistics recorded in this engine process.
Expand All @@ -330,6 +444,12 @@ def export_stats(self) -> dict[str, float]:
"""
raise NotImplementedError()

def onload(self):
raise NotImplementedError()

def offload(self):
raise NotImplementedError()


class InferenceEngine(abc.ABC):
def initialize(self, *args, **kwargs):
Expand Down
163 changes: 163 additions & 0 deletions areal/engine/base_train_engine.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,163 @@
from collections.abc import Callable
from typing import Any

import torch

from areal.api.engine_api import TrainEngine
from areal.utils.data import (
pack_tensor_dict,
pad_and_stack_tensors_along_first_dim,
reorder_list,
unpack_sequence,
)

"""
provide template method of high level APIs
"""


class BaseTrainEngine(TrainEngine):
def __init__(self):
pass

def train_batch(
self,
input_: dict[str, Any],
loss_fn: Callable[..., torch.Tensor],
loss_weight_fn: Callable[[dict[str, Any]], torch.Tensor],
) -> dict[str, float]:
"""
template method of train_batch
"""
self._ensure_ready()
self.optimizer_zero_grad()
_data_iterator, _, total_loss_weight = self._split_micro_batch(
input_, loss_weight_fn
)

def post_process(logits: torch.Tensor, inputs: dict) -> torch.Tensor:
return self._loss_compute(
output=logits,
inputs=inputs,
forward_only=False,
loss_fn=loss_fn,
total_loss_weight=total_loss_weight,
loss_weight_fn=loss_weight_fn,
)

self.forward_backward_batch(_data_iterator, post_process=post_process)
return self.optimizer_step()

@torch.no_grad()
def eval_batch(
self,
input_: dict[str, Any],
loss_fn: Callable[..., torch.Tensor],
loss_weight_fn: Callable[[dict[str, Any]], torch.Tensor],
) -> torch.Tensor | None:
"""
template method of eval_batch
"""
self._ensure_ready()
_data_iterator, _, total_loss_weight = self._split_micro_batch(
input_, loss_weight_fn
)

losses: list[torch.Tensor] = []

def post_process(logits: torch.Tensor, inputs: dict) -> torch.Tensor:
loss = self._loss_compute(
output=logits,
inputs=inputs,
forward_only=True,
loss_fn=loss_fn,
total_loss_weight=total_loss_weight,
loss_weight_fn=loss_weight_fn,
)
losses.append(loss)
return loss

self.forward_backward_batch(
_data_iterator, post_process=post_process, forward_only=True
)
return self._post_eval(losses)

@torch.no_grad()
def forward_batch(
self,
input_: dict[str, Any],
output_seqlens: list[int] | None = None,
aggregate_fn: Callable[[list[Any]], Any] = torch.cat,
) -> Any | None:
"""
template method of forward_batch
"""
self._ensure_ready()
cu_seqlens = pack_tensor_dict(input_)["cu_seqlens"]

if output_seqlens is None:
output_seqlens = (cu_seqlens[1:] - cu_seqlens[:-1]).cpu().numpy().tolist()
assert output_seqlens is not None

_data_iterator, mb_list, _ = self._split_micro_batch(input_)
outputs: list[torch.Tensor] = []

def post_process(logits: torch.Tensor, inputs: dict) -> torch.Tensor:
result = self._post_hook(logits, inputs)
outputs.append(result)
return torch.tensor(1.0, device=logits.device)

self.forward_backward_batch(
_data_iterator,
post_process=post_process,
forward_only=True,
return_outputs=True,
)

def aggregate_fn_wrap(result):
res = aggregate_fn(result)
seqlens = [output_seqlens[i] for i in mb_list.forward_indices]
unpacked = unpack_sequence(res, lens=seqlens, dim=0)
reordered = reorder_list(unpacked, mb_list.backward_indices)
res = pad_and_stack_tensors_along_first_dim(reordered)
return res

return self._post_forward_batch(outputs, aggregate_fn_wrap)

def _ensure_ready(self):
"""

:return:
"""
pass

def _post_forward_batch(self, result, aggregate_fn):
"""

:return:
"""
return aggregate_fn(result)

def _loss_compute(
self,
output: torch.Tensor,
inputs: dict[str, Any],
forward_only: bool,
total_loss_weight: torch.Tensor,
loss_fn: Callable[..., torch.Tensor],
loss_weight_fn: Callable[[dict[str, Any]], torch.Tensor],
) -> torch.Tensor:
pass

def _post_hook(
self,
output: torch.Tensor,
inputs: dict,
) -> torch.Tensor:
pass

def _post_eval(
self,
losses: list[torch.Tensor],
) -> torch.Tensor:
pass
Loading