This refactoring will maintain backward compatibility, ensuring no impact on application-level users.
class TrainEngine:
# Internal APIs
def _forward_compute_mb(
self,
mb_input: dict[str, Any],
loss_fn: Callable[[torch.Tensor, dict[str, Any]], torch.Tensor],
loss_weight_fn: Callable[[dict[str, Any]], torch.Tensor],
**kwargs,
) -> tuple[torch.Tensor, Callable[[torch.Tensor], tuple[torch.Tensor, dict]]:
# Update:
# 1. For FSDP engine, run a single forward pass and compute the loss
# 2. For Megatron engine, run forward on this pipeline stage and output hidden states & loss fn over outputs.
# If on the last stage, compute the weighted loss. Check the `forward_step` inline function in MegatronEngine.
...
# Edit: This method is not mandatory temparorily.
#def _forward_backward_compute_mb(
# self,
# mb_input: dict[str, Any],
# loss_fn: Callable[[torch.Tensor, dict[str, Any]], torch.Tensor],
# loss_weight_fn: Callable[[dict[str, Any]], torch.Tensor],
# mb_output: torch.Tensor | None,
# mb_output_grads: torch.Tensor | None,
# overlapped: bool,
# **kwargs,
#) -> tuple[torch.Tensor, dict[str, torch.Tensor]]:
# # Runs forward-backward computation with optional overlap
# ...
# Exposed APIs
def optimizer_zero_grad(self):
...
def optimizer_step(self):
...
def lr_scheduler_step(self):
...
def step_lr_scheduler(self):
# Backward compatibility
return self.lr_scheduler_step()
@dataclass
class ForwardBackwardOutputs:
mb_outputs: list[torch.Tensor] | None
losses: list[torch.Tensor] | None
# Edit: signature
def forward_backward_batch(
data_iterator: Iterable[dict[str, torch.Tensor],
loss_fn: Callable[[torch.Tensor, dict[str, Any]], torch.Tensor],
loss_weight_fn: Callable[[dict[str, Any]], torch.Tensor],
output_post_hook: Callable[[torch.Tensor, dict[str, Any]], Any] | None = None,
return_outputs: bool = False,
# edit: added a new argument
forward_only: bool = False,
) -> ForwardBackwardOutputs:
# Reduces code duplication in `train_batch`, `forward`, and `eval_batch`.
# For FSDP, composes internal APIs `_forward_compute_mb` and `loss.backward` in a for-loop
# For megatron, passes `_forward_compute_mb` to megatron's `forward_backward_func` and runs it.
...
# High-level training APIs (maintained for backward compatibility).
def train_batch(
self,
input_: dict[str, Any],
loss_fn: Callable[[torch.Tensor, dict[str, Any]], torch.Tensor],
loss_weight_fn: Callable[[dict[str, Any]], torch.Tensor],
) -> dict[str, float]:
# 1. split micro batch
# 2. call forward_backward_batch with loss_fn and loss_weight_fn
# 3. call optimizer step
...
@torch.no_grad()
def eval_batch(
self,
input_: dict[str, Any],
loss_fn: Callable[[torch.Tensor, dict[str, Any]], torch.Tensor],
loss_weight_fn: Callable[[dict[str, Any]], torch.Tensor],
) -> torch.Tensor | None:
# 1. split micro batch
# 2. call forward_backward_batch with loss_fn and loss_weight_fn and `forward_only=True`
# 3. aggregate losses and return
...
# Renamed for consistency; maintains backward compatibility via alias below.
@torch.no_grad()
def forward_batch(
self,
input_: dict[str, Any],
output_seqlens: list[int] | None = None,
post_hook: Callable[[torch.Tensor, dict[str, Any]], Any] | None = None,
aggregate_fn: Callable[[list[Any]], Any] = torch.cat,
) -> Any | None:
# 1. split micro batch
# 2. call forward_backward_batch without loss_fn, with output_post_hook, `forward_only=True`, and `return_outputs=True`
# 3. aggregate outputs according to output_seqlens and aggregate_fn
...
@torch.no_grad()
def forward(self, ...):
# Backward compatibility alias
return self.forward_batch(...)
# Other APIs (save/load, etc.) remain unchanged.
...
Checklist
Motivation
The current `TrainEngine` exposes only high-level APIs like `train_batch` and `eval_batch`. This implementation has the following drawbacks:
To address these issues and align with frameworks like verl 0.6.0 and tinker, we plan to refactor the `TrainEngine` implementation. This refactoring will maintain backward compatibility, ensuring no impact on application-level users.
Proposed Refactor Plan
Additional Context