-
Notifications
You must be signed in to change notification settings - Fork 276
refactor: refactor train engine high level APIs #658
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
rchardx
merged 12 commits into
inclusionAI:main
from
aaaandychen:refactor-trainengine-api
Dec 11, 2025
Merged
Changes from all commits
Commits
Show all changes
12 commits
Select commit
Hold shift + click to select a range
2a01f8b
refactor TrainEngine api
aaaandychen 4fcf930
Merge branch 'main' into refactor-trainengine-api
aaaandychen b5eb5b7
fix problems in refactoring train engine api
aaaandychen 4d39ab2
Merge branch 'main' into refactor-trainengine-api
53dcd95
fix problems in refactoring train engine api
f18da60
Merge branch 'inclusionAI:main' into refactor-trainengine-api
aaaandychen efa91f9
Merge remote-tracking branch 'origin/main' into refactor-trainengine-api
aaaandychen b8afbb1
refine the comments
aaaandychen a2979d0
refactor:implement BaseTrainEngine with Template Method and adapt hooks
4a77121
refactor:implement hook method to fetch output list
77a1aba
Merge remote-tracking branch 'origin/main' into refactor-trainengine-api
869858d
Merge branch 'main' into refactor-trainengine-api
aaaandychen File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Some comments aren't visible on the classic Files Changed page.
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,163 @@ | ||
| from collections.abc import Callable | ||
| from typing import Any | ||
|
|
||
| import torch | ||
|
|
||
| from areal.api.engine_api import TrainEngine | ||
| from areal.utils.data import ( | ||
| pack_tensor_dict, | ||
| pad_and_stack_tensors_along_first_dim, | ||
| reorder_list, | ||
| unpack_sequence, | ||
| ) | ||
|
|
||
| """ | ||
| provide template method of high level APIs | ||
| """ | ||
|
|
||
|
|
||
| class BaseTrainEngine(TrainEngine): | ||
| def __init__(self): | ||
| pass | ||
|
|
||
| def train_batch( | ||
| self, | ||
| input_: dict[str, Any], | ||
| loss_fn: Callable[..., torch.Tensor], | ||
| loss_weight_fn: Callable[[dict[str, Any]], torch.Tensor], | ||
| ) -> dict[str, float]: | ||
| """ | ||
| template method of train_batch | ||
| """ | ||
| self._ensure_ready() | ||
| self.optimizer_zero_grad() | ||
| _data_iterator, _, total_loss_weight = self._split_micro_batch( | ||
| input_, loss_weight_fn | ||
| ) | ||
|
|
||
| def post_process(logits: torch.Tensor, inputs: dict) -> torch.Tensor: | ||
| return self._loss_compute( | ||
| output=logits, | ||
| inputs=inputs, | ||
| forward_only=False, | ||
| loss_fn=loss_fn, | ||
| total_loss_weight=total_loss_weight, | ||
| loss_weight_fn=loss_weight_fn, | ||
| ) | ||
|
|
||
| self.forward_backward_batch(_data_iterator, post_process=post_process) | ||
| return self.optimizer_step() | ||
|
|
||
| @torch.no_grad() | ||
| def eval_batch( | ||
| self, | ||
| input_: dict[str, Any], | ||
| loss_fn: Callable[..., torch.Tensor], | ||
| loss_weight_fn: Callable[[dict[str, Any]], torch.Tensor], | ||
| ) -> torch.Tensor | None: | ||
| """ | ||
| template method of eval_batch | ||
| """ | ||
| self._ensure_ready() | ||
| _data_iterator, _, total_loss_weight = self._split_micro_batch( | ||
| input_, loss_weight_fn | ||
| ) | ||
|
|
||
| losses: list[torch.Tensor] = [] | ||
|
|
||
| def post_process(logits: torch.Tensor, inputs: dict) -> torch.Tensor: | ||
| loss = self._loss_compute( | ||
| output=logits, | ||
| inputs=inputs, | ||
| forward_only=True, | ||
| loss_fn=loss_fn, | ||
| total_loss_weight=total_loss_weight, | ||
| loss_weight_fn=loss_weight_fn, | ||
| ) | ||
| losses.append(loss) | ||
| return loss | ||
|
|
||
| self.forward_backward_batch( | ||
| _data_iterator, post_process=post_process, forward_only=True | ||
| ) | ||
| return self._post_eval(losses) | ||
|
|
||
| @torch.no_grad() | ||
| def forward_batch( | ||
| self, | ||
| input_: dict[str, Any], | ||
| output_seqlens: list[int] | None = None, | ||
| aggregate_fn: Callable[[list[Any]], Any] = torch.cat, | ||
| ) -> Any | None: | ||
| """ | ||
| template method of forward_batch | ||
| """ | ||
| self._ensure_ready() | ||
| cu_seqlens = pack_tensor_dict(input_)["cu_seqlens"] | ||
|
|
||
| if output_seqlens is None: | ||
| output_seqlens = (cu_seqlens[1:] - cu_seqlens[:-1]).cpu().numpy().tolist() | ||
| assert output_seqlens is not None | ||
|
|
||
| _data_iterator, mb_list, _ = self._split_micro_batch(input_) | ||
| outputs: list[torch.Tensor] = [] | ||
|
|
||
| def post_process(logits: torch.Tensor, inputs: dict) -> torch.Tensor: | ||
| result = self._post_hook(logits, inputs) | ||
| outputs.append(result) | ||
| return torch.tensor(1.0, device=logits.device) | ||
|
|
||
| self.forward_backward_batch( | ||
| _data_iterator, | ||
| post_process=post_process, | ||
| forward_only=True, | ||
| return_outputs=True, | ||
| ) | ||
|
|
||
| def aggregate_fn_wrap(result): | ||
| res = aggregate_fn(result) | ||
| seqlens = [output_seqlens[i] for i in mb_list.forward_indices] | ||
| unpacked = unpack_sequence(res, lens=seqlens, dim=0) | ||
| reordered = reorder_list(unpacked, mb_list.backward_indices) | ||
| res = pad_and_stack_tensors_along_first_dim(reordered) | ||
| return res | ||
|
|
||
| return self._post_forward_batch(outputs, aggregate_fn_wrap) | ||
|
|
||
| def _ensure_ready(self): | ||
| """ | ||
|
|
||
| :return: | ||
| """ | ||
| pass | ||
|
|
||
| def _post_forward_batch(self, result, aggregate_fn): | ||
| """ | ||
|
|
||
| :return: | ||
| """ | ||
| return aggregate_fn(result) | ||
|
|
||
| def _loss_compute( | ||
| self, | ||
| output: torch.Tensor, | ||
| inputs: dict[str, Any], | ||
| forward_only: bool, | ||
| total_loss_weight: torch.Tensor, | ||
| loss_fn: Callable[..., torch.Tensor], | ||
| loss_weight_fn: Callable[[dict[str, Any]], torch.Tensor], | ||
| ) -> torch.Tensor: | ||
| pass | ||
|
|
||
| def _post_hook( | ||
| self, | ||
| output: torch.Tensor, | ||
| inputs: dict, | ||
| ) -> torch.Tensor: | ||
| pass | ||
|
|
||
| def _post_eval( | ||
| self, | ||
| losses: list[torch.Tensor], | ||
| ) -> torch.Tensor: | ||
| pass |
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.