[Refactor] refactor TrainEngine APIs #601

@garrett4wade

Description

Checklist

  • This refactor maintains backward compatibility with all user-facing APIs.
  • For large-scale refactors, I’ve prepared a phased implementation plan.

Motivation

The current TrainEngine exposes only high-level APIs like train_batch and eval_batch.
This implementation has the following drawbacks:

  • Significant code duplication across public APIs and between different backends
  • Insufficient flexibility for fine-grained microbatch control

To address these issues and align with frameworks like verl 0.6.0 and tinker,
we plan to refactor the TrainEngine implementation.

This refactoring will maintain backward compatibility, ensuring no impact on application-level users.

Proposed Refactor Plan

class TrainEngine:
    # Internal APIs
    def _forward_compute_mb(
        self, 
        mb_input: dict[str, Any], 
        loss_fn: Callable[[torch.Tensor, dict[str, Any]], torch.Tensor],
        loss_weight_fn: Callable[[dict[str, Any]], torch.Tensor],
        **kwargs,
    ) -> tuple[torch.Tensor, Callable[[torch.Tensor], tuple[torch.Tensor, dict]]]:
        # Update:
        # 1. For FSDP engine, run a single forward pass and compute the loss
        # 2. For Megatron engine, run forward on this pipeline stage and return the
        #    hidden states together with a loss fn over the outputs. On the last
        #    stage, compute the weighted loss; see the `forward_step` inline
        #    function in MegatronEngine.
        ...

    # Edit: This method is not mandatory for now.
    #def _forward_backward_compute_mb(
    #    self, 
    #    mb_input: dict[str, Any], 
    #    loss_fn: Callable[[torch.Tensor, dict[str, Any]], torch.Tensor],
    #    loss_weight_fn: Callable[[dict[str, Any]], torch.Tensor],
    #    mb_output: torch.Tensor | None, 
    #    mb_output_grads: torch.Tensor | None,
    #    overlapped: bool,
    #    **kwargs,
    #) -> tuple[torch.Tensor, dict[str, torch.Tensor]]:
    #    # Runs forward-backward computation with optional overlap
    #    ...
    
    # Exposed APIs
    def optimizer_zero_grad(self):
        ...
    def optimizer_step(self):
        ...
    def lr_scheduler_step(self):
        ...
    def step_lr_scheduler(self):
        # Backward compatibility
        return self.lr_scheduler_step()

    @dataclass
    class ForwardBackwardOutputs:
        mb_outputs: list[torch.Tensor] | None
        losses: list[torch.Tensor] | None

    # Edit: signature
    def forward_backward_batch(
        self,
        data_iterator: Iterable[dict[str, torch.Tensor]],
        loss_fn: Callable[[torch.Tensor, dict[str, Any]], torch.Tensor],
        loss_weight_fn: Callable[[dict[str, Any]], torch.Tensor],
        output_post_hook: Callable[[torch.Tensor, dict[str, Any]], Any] | None = None,
        return_outputs: bool = False,
        # edit: added a new argument
        forward_only: bool = False,
    ) -> ForwardBackwardOutputs:
        # Reduces code duplication in `train_batch`, `forward`, and `eval_batch`.
        # For FSDP, composes internal APIs `_forward_compute_mb` and `loss.backward` in a for-loop
        # For Megatron, passes `_forward_compute_mb` to Megatron's `forward_backward_func` and runs it.
        ...

    # High-level training APIs (maintained for backward compatibility).
    def train_batch(
        self,
        input_: dict[str, Any],
        loss_fn: Callable[[torch.Tensor, dict[str, Any]], torch.Tensor],
        loss_weight_fn: Callable[[dict[str, Any]], torch.Tensor],
    ) -> dict[str, float]:
        # 1. split micro batch
        # 2. call forward_backward_batch with loss_fn and loss_weight_fn
        # 3. call optimizer step
        ...
    @torch.no_grad()
    def eval_batch(
        self,
        input_: dict[str, Any],
        loss_fn: Callable[[torch.Tensor, dict[str, Any]], torch.Tensor],
        loss_weight_fn: Callable[[dict[str, Any]], torch.Tensor],
    ) -> torch.Tensor | None:
        # 1. split micro batch
        # 2. call forward_backward_batch with loss_fn and loss_weight_fn and `forward_only=True`
        # 3. aggregate losses and return
        ...
    # Renamed for consistency; maintains backward compatibility via alias below.
    @torch.no_grad()
    def forward_batch(
        self,
        input_: dict[str, Any],
        output_seqlens: list[int] | None = None,
        post_hook: Callable[[torch.Tensor, dict[str, Any]], Any] | None = None,
        aggregate_fn: Callable[[list[Any]], Any] = torch.cat,
    ) -> Any | None:
        # 1. split micro batch
        # 2. call forward_backward_batch without loss_fn, with output_post_hook, `forward_only=True`, and `return_outputs=True`
        # 3. aggregate outputs according to output_seqlens and aggregate_fn
        ...
    @torch.no_grad()
    def forward(self, *args, **kwargs):
        # Backward compatibility alias
        return self.forward_batch(*args, **kwargs)
    # Other APIs (save/load, etc.) remain unchanged.
    ...
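To make the FSDP path of `forward_backward_batch` concrete: it is essentially a gradient-accumulation loop over microbatches, where each microbatch's loss is scaled by `loss_weight_fn` before `backward`. Below is a minimal sketch of that control flow using plain Python floats as stand-ins for tensors and for the model forward pass (the function name and the scalar `accumulated_grad` are illustrative only, not part of the proposal):

```python
from typing import Any, Callable, Iterable


def forward_backward_batch_sketch(
    data_iterator: Iterable[dict[str, Any]],
    loss_fn: Callable[[Any, dict[str, Any]], float],
    loss_weight_fn: Callable[[dict[str, Any]], float],
    forward_only: bool = False,
) -> tuple[list[float], float]:
    """Sketch of the FSDP loop: forward each microbatch, weight its loss,
    and (when training) backprop the weighted loss."""
    losses: list[float] = []
    accumulated_grad = 0.0  # stand-in for gradients accumulated by .backward()
    mbs = list(data_iterator)
    # Normalize per-microbatch contributions by the total weight so that the
    # effective loss is a weighted average over the whole batch.
    total_weight = sum(loss_weight_fn(mb) for mb in mbs)
    for mb in mbs:
        output = mb["x"] * 2.0  # stand-in for the model forward pass
        loss = loss_fn(output, mb)
        if not forward_only:
            # Real code: (loss * loss_weight_fn(mb) / total_weight).backward()
            accumulated_grad += loss * loss_weight_fn(mb) / total_weight
        losses.append(loss)
    return losses, accumulated_grad
```

With `forward_only=True` the same loop serves `eval_batch` and `forward_batch`, which is where the de-duplication claimed above comes from.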

Additional Context
