added abstraction for target model backend #269
base: main
Conversation
Summary of Changes

Hello @FrankLeeeee, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed.

This pull request introduces a robust abstraction layer for the target model backend within the Eagle3 training framework. It allows for flexible integration with various model backends, including SGLang, HuggingFace, and custom implementations, enhancing the system's modularity and extensibility. The core training script has been thoroughly refactored into distinct, manageable functions, improving code organization and maintainability. Additionally, tensor-parallel loading for linear layers has been standardized through state-dict pre-hooks, streamlining distributed model initialization and ensuring proper weight distribution across devices.
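To make the shape of that abstraction concrete, here is a minimal sketch. The names Eagle3TargetModel, SGLangEagle3TargetModel, and generate_eagle3_data appear elsewhere in this thread; the HF subclass name and the exact signatures are assumptions for illustration, not the PR's actual API.

from abc import ABC, abstractmethod


class Eagle3TargetModel(ABC):
    """Backend-agnostic interface that the training loop programs against."""

    @abstractmethod
    def generate_eagle3_data(self, input_ids, attention_mask, loss_mask):
        """Run the frozen target model and return the tensors the draft model
        trains on (targets plus auxiliary hidden states) in one container."""


class SGLangEagle3TargetModel(Eagle3TargetModel):
    """Serves the target model through SGLang's model runner."""

    def generate_eagle3_data(self, input_ids, attention_mask, loss_mask):
        ...  # dispatch to SGLang's extend path


class HFEagle3TargetModel(Eagle3TargetModel):  # hypothetical name
    """Serves the target model through a plain HuggingFace forward pass."""

    def generate_eagle3_data(self, input_ids, attention_mask, loss_mask):
        ...  # ordinary model(input_ids, attention_mask=...) forward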
Code Review
This pull request introduces a significant and valuable refactoring by abstracting the target model backend. This allows for more flexibility in choosing between different inference engines like HuggingFace Transformers, SGLang, or custom implementations, which is a great improvement for the project. The code is now more modular and easier to maintain.
However, I've found several critical issues that need to be addressed. A hardcoded API key has been added, which is a major security risk. The logic for resuming training from checkpoints is broken, and there are a few bugs in the training loop logic, including swapped function arguments and incorrect unpacking of return values, which will lead to runtime errors. I've also identified a typo in another script that will cause it to fail. Please review the detailed comments for each issue.
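On the hardcoded key, a minimal remediation is to read it from the environment instead of the source tree. This is a hypothetical sketch; the variable name below is an assumption, not one used by the PR.

import os

# Hypothetical fix: fail fast if the key is not provided via the environment.
api_key = os.environ["TARGET_BACKEND_API_KEY"]  # raises KeyError if unset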
global_step = 0
start_epoch = 0
The logic for resuming training from a checkpoint appears to be incomplete. While the draft model weights are loaded from the last checkpoint in build_draft_model, the optimizer state, global step, and start epoch are not restored in the main function. The variables global_step and start_epoch are initialized to 0 and never updated from a checkpoint, causing training to always start from the beginning, even when resume is specified. The previous implementation correctly loaded these states from training_state.pt.
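A hedged sketch of the missing resume path, following the training_state.pt layout the previous implementation used; the exact key names here are assumptions:

import os

import torch


def maybe_resume(args, optimizer):
    """Restore optimizer state, global step, and start epoch if resuming."""
    global_step, start_epoch = 0, 0
    state_path = os.path.join(args.output_dir, "training_state.pt")
    if args.resume and os.path.exists(state_path):
        state = torch.load(state_path, map_location="cpu")
        optimizer.load_state_dict(state["optimizer_state_dict"])
        global_step = state["global_step"]
        start_epoch = state["epoch"] + 1  # continue from the next epoch
    return global_step, start_epoch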
scripts/train_eagle3_online.py
# ================================================
# 7.1 Training Step
# ================================================
plosses, acces = run_forward(args, eagle3_model, target_model, data)
The arguments in the call to run_forward are swapped. The function signature is run_forward(args, eagle3_model, data, target_model, ...) but it's being called as run_forward(args, eagle3_model, target_model, data). This will pass the target_model object as the data dictionary and vice-versa, leading to a runtime error.
- plosses, acces = run_forward(args, eagle3_model, target_model, data)
+ plosses, acces = run_forward(args, eagle3_model, data, target_model=target_model)
scripts/train_eagle3.py
)
print_on_rank0(f"Saved model configuration to {epoch_output_dir}")
dist.barrier()
plosses, _, acces = run_forward(args, eagle3_model, target_model, data)
The run_forward function returns a tuple of two lists: (plosses, acces). However, the return value is being unpacked into three variables: plosses, _, acces. This will raise a ValueError at runtime because there are not enough values to unpack. The arguments in the function call are also swapped.
- plosses, _, acces = run_forward(args, eagle3_model, target_model, data)
+ plosses, acces = run_forward(args, eagle3_model, data, target_model=target_model)
scripts/train_eagle3_online.py
if target_model:
    eagle3_data = target_model.generate_eagle3_data(
        input_ids=data["input_ids"].cuda(),
        attention_mask=data["attention_mask"].cuda(),
        loss_mask=data["loss_mask"].cuda(),
    )
else:
    target = target_head(data["hidden_states"].cuda())

plosses, _, acces = eagle3_model(
    input_ids=eagle3_data.input_ids,
    attention_mask=eagle3_data.attention_mask,
    loss_mask=eagle3_data.loss_mask,
    target=eagle3_data.target,
    hidden_states=eagle3_data.hidden_states,
)
There's a potential UnboundLocalError in the run_forward function. If target_model is None, the if target_model: block is skipped, and eagle3_data is never defined. However, it is used later in the call to eagle3_model. The else block only defines target, not the full eagle3_data object. Both branches of the conditional should define eagle3_data to avoid this error.
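One way to fix this is to have both branches build the same container. A sketch, using SimpleNamespace purely as an illustrative stand-in for whatever type generate_eagle3_data actually returns; the field choices in the else branch are assumptions based on the offline data format above:

from types import SimpleNamespace

if target_model:
    eagle3_data = target_model.generate_eagle3_data(
        input_ids=data["input_ids"].cuda(),
        attention_mask=data["attention_mask"].cuda(),
        loss_mask=data["loss_mask"].cuda(),
    )
else:
    # Offline path: targets and hidden states come precomputed in the batch.
    eagle3_data = SimpleNamespace(
        input_ids=data["input_ids"].cuda(),
        attention_mask=data["attention_mask"].cuda(),
        loss_mask=data["loss_mask"].cuda(),
        target=target_head(data["hidden_states"].cuda()),
        hidden_states=data["hidden_states"].cuda(),
    )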
scripts/train_eagle3_sgl_online.py
next(dataloader)
for data in dataloader:
    self.steps_consumed_in_current_epoch += 1
    elf.steps_consumed_in_current_epoch += 1
This is the typo flagged in the review summary: elf should be self. As written, this line will raise a NameError at runtime.
scripts/train_eagle3.py
    print_with_rank("Initialized eval dataloader")
else:
    eval_dataloader = None
return train_dataloader, vocab_mapping_path, eval_dataloader,
The function build_dataloaders has a return type hint of Tuple[DataLoader, str, Optional[DataLoader]], a 3-element tuple. The return statement has a trailing comma; in Python a trailing comma in a tuple display is harmless (the value is still a 3-tuple), but it is easy to misread as an accidental fourth element. Please remove it to match the type hint.
- return train_dataloader, vocab_mapping_path, eval_dataloader,
+ return train_dataloader, vocab_mapping_path, eval_dataloader
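A quick check in any Python REPL confirms the trailing comma does not change the value:

# A trailing comma in a tuple display adds no element.
assert (1, 2, 3,) == (1, 2, 3)
assert len((1, 2, 3,)) == 3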
specforge/layers/linear.py
up = shard_tensor(up, self.tp_group, 0)
state_dict['weight'] = torch.cat((gate, up), dim=0)

assert "bias" not in state_dict, "Bias is not supported for merged QKV layout"
The assertion error message in handle_gate_up_layout is "Bias is not supported for merged QKV layout". This seems to be a copy-paste error from handle_merged_qkv. The message should refer to the 'gate_up' layout for better clarity when debugging.
| assert "bias" not in state_dict, "Bias is not supported for merged QKV layout" | |
| assert "bias" not in state_dict, "Bias is not supported for gate_up layout" |
@FrankLeeeee The lint test failed. And did you test it end to end? Is the accept length good?
Do you update the optimizer state while loading the checkpoints? I am thinking about adding a Trainer class to wrap everything, just like I do in …
I'll test this code to train Qwen2.5-7B's Eagle3 tonight and give some feedback soon.
I am splitting this PR so it is easier to review.
I trained the model on the ShareGPT dataset for 3 epochs with the HF backend, then tested it on the training set with topk=1, steps=3, and tokens=4. The result:

[2025-11-04 08:41:02] Prefill batch. #new-seq: 1, #new-token: 63, #cached-token: 0, token usage: 0.00, #running-req: 0, #queue-req: 0,
[2025-11-04 08:41:02] Decode batch. #running-req: 1, #token: 0, token usage: 0.00, accept len: 1.86, cuda graph: True, gen throughput (token/s): 4.27, #queue-req: 0,
[2025-11-04 08:41:02] INFO: 127.0.0.1:48648 - "POST /v1/chat/completions HTTP/1.1" 200 OK
[2025-11-04 08:41:02] Prefill batch. #new-seq: 1, #new-token: 22, #cached-token: 24, token usage: 0.00, #running-req: 0, #queue-req: 0,
[2025-11-04 08:41:02] INFO: 127.0.0.1:48650 - "POST /v1/chat/completions HTTP/1.1" 200 OK
[2025-11-04 08:41:02] Prefill batch. #new-seq: 1, #new-token: 67, #cached-token: 24, token usage: 0.00, #running-req: 0, #queue-req: 0,
[2025-11-04 08:41:02] Decode batch. #running-req: 1, #token: 112, token usage: 0.00, accept len: 2.19, cuda graph: True, gen throughput (token/s): 258.07, #queue-req: 0,
[2025-11-04 08:41:03] INFO: 127.0.0.1:48660 - "POST /v1/chat/completions HTTP/1.1" 200 OK
[2025-11-04 08:41:03] Prefill batch. #new-seq: 1, #new-token: 98, #cached-token: 24, token usage: 0.00, #running-req: 0, #queue-req: 0,
[2025-11-04 08:41:03] Decode batch. #running-req: 1, #token: 159, token usage: 0.00, accept len: 2.10, cuda graph: True, gen throughput (token/s): 252.92, #queue-req: 0,
[2025-11-04 08:41:03] INFO: 127.0.0.1:48666 - "POST /v1/chat/completions HTTP/1.1" 200 OK
[2025-11-04 08:41:03] Prefill batch. #new-seq: 1, #new-token: 455, #cached-token: 24, token usage: 0.00, #running-req: 0, #queue-req: 0,
[2025-11-04 08:41:03] Decode batch. #running-req: 1, #token: 512, token usage: 0.00, accept len: 1.61, cuda graph: True, gen throughput (token/s): 170.16, #queue-req: 0,
[2025-11-04 08:41:03] INFO: 127.0.0.1:48682 - "POST /v1/chat/completions HTTP/1.1" 200 OK
[2025-11-04 08:41:03] Prefill batch. #new-seq: 1, #new-token: 275, #cached-token: 24, token usage: 0.00, #running-req: 0, #queue-req: 0,
[2025-11-04 08:41:03] Decode batch. #running-req: 1, #token: 349, token usage: 0.00, accept len: 2.12, cuda graph: True, gen throughput (token/s): 236.20, #queue-req: 0,
[2025-11-04 08:41:03] INFO: 127.0.0.1:48694 - "POST /v1/chat/completions HTTP/1.1" 200 OK
[2025-11-04 08:41:03] Prefill batch. #new-seq: 1, #new-token: 22, #cached-token: 24, token usage: 0.00, #running-req: 0, #queue-req: 0,

I encountered the following problems:
[rank0]: File "/disk3/wjp/pr_test/SpecForge/scripts/train_eagle3_online.py", line 742, in <module>
[rank0]: main()
[rank0]: File "/disk3/wjp/pr_test/SpecForge/scripts/train_eagle3_online.py", line 673, in main
[rank0]: plosses, acces = run_forward(args, eagle3_model, data, target_model)
[rank0]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]: File "/disk3/wjp/pr_test/SpecForge/scripts/train_eagle3_online.py", line 481, in run_forward
[rank0]: eagle3_data = target_model.generate_eagle3_data(
[rank0]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]: File "/disk3/wjp/uv_env/pr/lib/python3.11/site-packages/torch/utils/_contextlib.py", line 120, in decorate_context
[rank0]: return func(*args, **kwargs)
[rank0]: ^^^^^^^^^^^^^^^^^^^^^
[rank0]: File "/disk3/wjp/pr_test/SpecForge/specforge/modeling/target/eagle3_target_model.py", line 315, in generate_eagle3_data
[rank0]: logits_list, aux_hidden_states_list = self.extend(
[rank0]: ^^^^^^^^^^^^
[rank0]: File "/disk3/wjp/uv_env/pr/lib/python3.11/site-packages/torch/utils/_contextlib.py", line 120, in decorate_context
[rank0]: return func(*args, **kwargs)
[rank0]: ^^^^^^^^^^^^^^^^^^^^^
[rank0]: File "/disk3/wjp/pr_test/SpecForge/specforge/modeling/target/eagle3_target_model.py", line 240, in extend
[rank0]: batch.prepare_for_extend()
[rank0]: File "/disk3/wjp/pr_test/sglang/python/sglang/srt/managers/schedule_batch.py", line 1198, in prepare_for_extend
[rank0]: out_cache_loc, req_pool_indices_tensor, req_pool_indices = alloc_for_extend(
[rank0]: ^^^^^^^^^^^^^^^^^
[rank0]: File "/disk3/wjp/pr_test/sglang/python/sglang/srt/mem_cache/common.py", line 355, in alloc_for_extend
[rank0]: out_cache_loc = alloc_token_slots(batch.tree_cache, batch.extend_num_tokens)
[rank0]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]: File "/disk3/wjp/pr_test/sglang/python/sglang/srt/mem_cache/common.py", line 203, in alloc_token_slots
[rank0]: evict_from_tree_cache(tree_cache, num_tokens)
[rank0]: File "/disk3/wjp/pr_test/sglang/python/sglang/srt/mem_cache/common.py", line 247, in evict_from_tree_cache
[rank0]: tree_cache.evict(num_tokens)
[rank0]: ^^^^^^^^^^^^^^^^
[rank0]: AttributeError: 'types.SimpleNamespace' object has no attribute 'evict'. Did you mean: 'device'?
parser.add_argument("--cache-key", type=str, default=None)
parser.add_argument("--cache-dir", type=str, default="./cache")
parser.add_argument("--output-dir", type=str, required=True)
- parser.add_argument("--eval-interval", type=int, default=1)
- parser.add_argument("--save-interval", type=int, default=1)
+ parser.add_argument("--eval-interval", type=int, default=-1)
+ parser.add_argument("--save-interval", type=int, default=-1)
parser.add_argument("--seed", type=int, default=0)
parser.add_argument(
"--dist-timeout",
@@ -571,6 +571,12 @@ def main():
draft_model.load_vocab_mapping(vocab_mapping_path)
print_with_rank("Loaded vocab mapping")
+ # calculate steps per epoch if not provided
+ if args.save_interval == -1:
+ args.save_interval = len(train_dataloader)
+ if args.eval_interval == -1:
+ args.eval_interval = len(train_dataloader)
+
# Calculate total steps if not provided
if args.total_steps is None:
steps_per_epoch = math.ceil(
@@ -661,7 +667,7 @@ def main():
# ================================================
# 7.1 Training Step
# ================================================
- plosses, acces = run_forward(args, eagle3_model, target_model, data)
+ plosses, acces = run_forward(args, eagle3_model, data, target_model)
run_backward_and_update(args, plosses, optimizer, global_step)
# log training metrics
@@ -676,12 +682,12 @@ def main():
if dist.get_rank() == 0:
time_per_step = time.time() - last_time
last_time = time.time()
- avg_loss = sum(pl.item() for pl in plosses) / len(plosses)
+ avg_loss = sum(pl for pl in plosses) / len(plosses)
avg_acc = sum(acces) / len(acces)
progress_bar.set_postfix(
{
- "loss": f"{avg_loss:.2f}",
- "acc": f"{avg_acc:.2f}",
+ "loss": f"{avg_loss:.4f}",
+ "acc": f"{avg_acc:.4f}",
"time": f"{time_per_step:.2f}s",
}
)
Because I don't have permission to push the code, you can use the following patch to fix the sglang backend bug (the SimpleNamespace stub lacks the evict method that the tree-cache eviction path calls, per the AttributeError above). I'll try to train a model with it.

index a208dd7..bcbc7db 100644
--- a/specforge/modeling/target/eagle3_target_model.py
+++ b/specforge/modeling/target/eagle3_target_model.py
@@ -9,6 +9,7 @@ import torch.nn as nn
from sglang.srt.configs.model_config import ModelConfig
from sglang.srt.managers.schedule_batch import Req, ScheduleBatch
from sglang.srt.managers.scheduler import Scheduler
+from sglang.srt.mem_cache.radix_cache import RadixCache
from sglang.srt.model_executor.forward_batch_info import CaptureHiddenMode, ForwardBatch
from sglang.srt.sampling.sampling_params import SamplingParams
from sglang.srt.server_args import ServerArgs
@@ -222,17 +223,16 @@ class SGLangEagle3TargetModel(Eagle3TargetModel):
@torch.no_grad
def extend(self, reqs, capture_aux_hidden_states: bool = True):
# Create dummy tree_cache for benchmarks (no prefix caching, just allocation)
- dummy_tree_cache = SimpleNamespace(
- page_size=self.model_runner.server_args.page_size,
- device=self.model_runner.device,
+ tree_cache = RadixCache(
+ None,
token_to_kv_pool_allocator=self.model_runner.token_to_kv_pool_allocator,
+ page_size=self.model_runner.server_args.page_size,
)
-
batch = ScheduleBatch.init_new(
reqs=reqs,
req_to_token_pool=self.model_runner.req_to_token_pool,
token_to_kv_pool_allocator=self.model_runner.token_to_kv_pool_allocator,
- tree_cache=dummy_tree_cache,
+ tree_cache=tree_cache,
model_config=self.model_runner.model_config,
enable_overlap=False,
spec_algorithm=SpeculativeAlgorithm.NONE,