Conversation

@FrankLeeeee
Collaborator

Motivation

Modifications

Related Issues

Accuracy Test

Benchmark & Profiling

Checklist

@gemini-code-assist
Contributor

Summary of Changes

Hello @FrankLeeeee, I'm Gemini Code Assist¹! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request introduces a robust abstraction layer for the target model backend within the Eagle3 training framework. It allows for flexible integration with various model backends, including SGLang, HuggingFace, and custom implementations, enhancing the system's modularity and extensibility. The core training script has been thoroughly refactored into distinct, manageable functions, improving code organization and maintainability. Additionally, tensor parallelism loading for linear layers has been standardized through state dictionary pre-hooks, streamlining distributed model initialization and ensuring proper weight distribution across devices.

Highlights

  • Target Model Backend Abstraction: Introduced a new Eagle3TargetModel interface with concrete implementations for SGLang, HuggingFace, and custom backends, allowing flexible choice of the target model's underlying inference engine.
  • Refactored Training Script: The main training script train_eagle3_online.py was renamed to train_eagle3.py and extensively refactored into modular helper functions for initialization, model building, data loading, and training steps, significantly improving code organization and maintainability.
  • Standardized Tensor Parallelism Loading: Implemented state dictionary pre-hooks in RowParallelLinear and ColumnParallelLinear layers to automatically handle weight sharding during model loading, simplifying the setup for distributed models and supporting various weight layouts (e.g., merged QKV, gate/up projections); see the sketch after this list.
  • Enhanced SGLang Integration: Customized SGLang's ModelRunner and patched its logits processor to correctly extract auxiliary hidden states and logits, which are crucial for Eagle3 training, ensuring seamless compatibility with the SGLang inference engine.
  • Dependency Updates: Updated key dependencies, including transformers to 4.57.1 and sglang to 0.5.4, to leverage the latest features and improvements from these libraries.
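
To make the tensor-parallel loading highlight above more concrete, here is a minimal sketch of a state-dict pre-hook that shards a column-parallel weight before it is loaded. The shard_tensor helper and the tp_group argument are assumptions modelled on the snippets quoted later in this review, not the PR's actual implementation.

import torch
import torch.distributed as dist

def shard_tensor(tensor: torch.Tensor, group, dim: int) -> torch.Tensor:
    # Return the slice of `tensor` along `dim` that belongs to this rank.
    world_size = dist.get_world_size(group)
    rank = dist.get_rank(group)
    return tensor.chunk(world_size, dim=dim)[rank].contiguous()

def column_parallel_pre_hook(state_dict: dict, prefix: str, tp_group) -> None:
    # A ColumnParallelLinear splits its output dimension (dim 0) across
    # tensor-parallel ranks, so each rank keeps only its own slice of the
    # full checkpoint weight; the hook mutates state_dict in place before
    # load_state_dict consumes it.
    key = prefix + "weight"
    if key in state_dict:
        state_dict[key] = shard_tensor(state_dict[key], tp_group, dim=0)
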
Using Gemini Code Assist

The full guide for Gemini Code Assist can be found on our documentation page, here are some quick tips.

Invoking Gemini

You can request assistance from Gemini at any point by creating a comment using either /gemini <command> or @gemini-code-assist <command>. Below is a summary of the supported commands on the current page.

  • Code Review (/gemini review): Performs a code review for the current pull request in its current state.
  • Pull Request Summary (/gemini summary): Provides a summary of the current pull request in its current state.
  • Comment (@gemini-code-assist): Responds in comments when explicitly tagged, both in pull request comments and review comments.
  • Help (/gemini help): Displays a list of available commands.

Customization

To customize the Gemini Code Assist for GitHub experience, repository maintainers can create a configuration file and/or provide a custom code review style guide (such as PEP-8 for Python) by creating and adding files to a .gemini/ folder in the base of the repository. Detailed instructions can be found here.

Limitations & Feedback

Gemini Code Assist may make mistakes. Please leave feedback on any instances where its feedback is incorrect or counterproductive. You can react with 👍 and 👎 on @gemini-code-assist comments. If you're interested in giving feedback about your experience with Gemini Code Assist for GitHub and other Google products, sign up here.

You can also get AI-powered code generation, chat, and code reviews directly in the IDE at no cost with the Gemini Code Assist IDE Extension.

Footnotes

  1. Review the Privacy Notices, Generative AI Prohibited Use Policy, Terms of Service, and learn how to configure Gemini Code Assist in GitHub here. Gemini can make mistakes, so double check it and use code with caution.


@gemini-code-assist (bot) left a comment


Code Review

This pull request introduces a significant and valuable refactoring by abstracting the target model backend. This allows for more flexibility in choosing between different inference engines like HuggingFace Transformers, SGLang, or custom implementations, which is a great improvement for the project. The code is now more modular and easier to maintain.

However, I've found several critical issues that need to be addressed. A hardcoded API key has been added, which is a major security risk. The logic for resuming training from checkpoints is broken, and there are a few bugs in the training loop logic, including swapped function arguments and incorrect unpacking of return values, which will lead to runtime errors. I've also identified a typo in another script that will cause it to fail. Please review the detailed comments for each issue.

Comment on lines 591 to 636
global_step = 0
start_epoch = 0

critical

The logic for resuming training from a checkpoint appears to be incomplete. While the draft model weights are loaded from the last checkpoint in build_draft_model, the optimizer state, global step, and start epoch are not restored in the main function. The variables global_step and start_epoch are initialized to 0 and never updated from a checkpoint, causing training to always start from the beginning, even when resume is specified. The previous implementation correctly loaded these states from training_state.pt.
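
For reference, a minimal sketch of what restoring this state could look like on resume, assuming the training_state.pt layout of the previous implementation; the directory argument and the "optimizer", "global_step", and "start_epoch" keys are illustrative assumptions, not the PR's code.

import os
import torch

def maybe_resume_training_state(resume_dir, optimizer):
    # Restore optimizer state, global step and start epoch from
    # training_state.pt if a checkpoint directory is given; otherwise
    # fall back to a fresh run.
    global_step, start_epoch = 0, 0
    if resume_dir is not None:
        state_path = os.path.join(resume_dir, "training_state.pt")
        if os.path.isfile(state_path):
            state = torch.load(state_path, map_location="cpu")
            optimizer.load_state_dict(state["optimizer"])
            global_step = state["global_step"]
            start_epoch = state["start_epoch"]
    return global_step, start_epoch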

# ================================================
# 7.1 Training Step
# ================================================
plosses, acces = run_forward(args, eagle3_model, target_model, data)

critical

The arguments in the call to run_forward are swapped. The function signature is run_forward(args, eagle3_model, data, target_model, ...) but it's being called as run_forward(args, eagle3_model, target_model, data). This will pass the target_model object as the data dictionary and vice-versa, leading to a runtime error.

Suggested change
plosses, acces = run_forward(args, eagle3_model, target_model, data)
plosses, acces = run_forward(args, eagle3_model, data, target_model=target_model)

)
print_on_rank0(f"Saved model configuration to {epoch_output_dir}")
dist.barrier()
plosses, _, acces = run_forward(args, eagle3_model, target_model, data)

critical

The run_forward function returns a tuple of two lists: (plosses, acces). However, the return value is being unpacked into three variables: plosses, _, acces. This will raise a ValueError at runtime because there are not enough values to unpack. The arguments in the function call are also swapped.

Suggested change
plosses, _, acces = run_forward(args, eagle3_model, target_model, data)
plosses, acces = run_forward(args, eagle3_model, data, target_model=target_model)

Comment on lines 452 to 495
if target_model:
eagle3_data = target_model.generate_eagle3_data(
input_ids=data["input_ids"].cuda(),
attention_mask=data["attention_mask"].cuda(),
loss_mask=data["loss_mask"].cuda(),
)
else:
target = target_head(data["hidden_states"].cuda())

plosses, _, acces = eagle3_model(
input_ids=eagle3_data.input_ids,
attention_mask=eagle3_data.attention_mask,
loss_mask=eagle3_data.loss_mask,
target=eagle3_data.target,
hidden_states=eagle3_data.hidden_states,
)

critical

There's a potential UnboundLocalError in the run_forward function. If target_model is None, the if target_model: block is skipped, and eagle3_data is never defined. However, it is used later in the call to eagle3_model. The else block only defines target, not the full eagle3_data object. Both branches of the conditional should define eagle3_data to avoid this error.
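
One possible shape of the fix, sketched with the field names visible in the snippet above; the SimpleNamespace container and the offline-branch assignments are assumptions about the intended behaviour rather than the PR's code:

from types import SimpleNamespace

if target_model is not None:
    eagle3_data = target_model.generate_eagle3_data(
        input_ids=data["input_ids"].cuda(),
        attention_mask=data["attention_mask"].cuda(),
        loss_mask=data["loss_mask"].cuda(),
    )
else:
    # Offline path: hidden states come from the dataset, so only the target
    # distribution still needs to be produced by the detached LM head.
    eagle3_data = SimpleNamespace(
        input_ids=data["input_ids"].cuda(),
        attention_mask=data["attention_mask"].cuda(),
        loss_mask=data["loss_mask"].cuda(),
        target=target_head(data["hidden_states"].cuda()),
        hidden_states=data["hidden_states"].cuda(),
    )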

next(dataloader)
for data in dataloader:
self.steps_consumed_in_current_epoch += 1
·elf.steps_consumed_in_current_epoch += 1

critical

There is a typo ·elf which should be self. This will cause a NameError when this line is executed.

Suggested change
·elf.steps_consumed_in_current_epoch += 1
self.steps_consumed_in_current_epoch += 1

print_with_rank("Initialized eval dataloader")
else:
eval_dataloader = None
return train_dataloader, vocab_mapping_path, eval_dataloader,

medium

The function build_dataloaders has a return type hint of Tuple[DataLoader, str, Optional[DataLoader]], which is a 3-element tuple. The return statement ends with a trailing comma; Python still evaluates this to a 3-element tuple, but the stray comma is easy to misread as a fourth element and is inconsistent with the type hint. Please remove the trailing comma.
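
As a quick check in plain Python, independent of this codebase:

t = (1, 2, 3,)          # trailing comma after the last element
assert t == (1, 2, 3)   # still a 3-element tuple
assert len(t) == 3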

Suggested change
return train_dataloader, vocab_mapping_path, eval_dataloader,
return train_dataloader, vocab_mapping_path, eval_dataloader

up = shard_tensor(up, self.tp_group, 0)
state_dict['weight'] = torch.cat((gate, up), dim=0)

assert "bias" not in state_dict, "Bias is not supported for merged QKV layout"

medium

The assertion error message in handle_gate_up_layout is "Bias is not supported for merged QKV layout". This seems to be a copy-paste error from handle_merged_qkv. The message should refer to the 'gate_up' layout for better clarity when debugging.

Suggested change
assert "bias" not in state_dict, "Bias is not supported for merged QKV layout"
assert "bias" not in state_dict, "Bias is not supported for gate_up layout"

@zyksir
Collaborator

zyksir commented Nov 1, 2025

@FrankLeeeee The lint test failed. Did you test it end to end? Is the accept length good?

@zyksir
Collaborator

zyksir commented Nov 1, 2025

Do you restore the optimizer state when loading checkpoints? I am thinking about adding a Trainer class to wrap everything, just like I do in train_eagle3_sgl_online.py.

@jiapingW
Contributor

jiapingW commented Nov 3, 2025

@FrankLeeeee The lint test failed. Did you test it end to end? Is the accept length good?

I'll test this code by training qwen2.5-7b's eagle3 tonight and give some feedback soon.

@FrankLeeeee
Collaborator Author

I am splitting this PR so it is easier to review.

@jiapingW
Contributor

jiapingW commented Nov 4, 2025

@FrankLeeeee The lint test failed. Did you test it end to end? Is the accept length good?

I'll test this code by training qwen2.5-7b's eagle3 tonight and give some feedback soon.

I trained the model on the ShareGPT dataset for 3 epochs with the hf backend, and tested it on the training set with topk=1, steps=3, and tokens=4. The result is:

[2025-11-04 08:41:02] Prefill batch. #new-seq: 1, #new-token: 63, #cached-token: 0, token usage: 0.00, #running-req: 0, #queue-req: 0, 
[2025-11-04 08:41:02] Decode batch. #running-req: 1, #token: 0, token usage: 0.00, accept len: 1.86, cuda graph: True, gen throughput (token/s): 4.27, #queue-req: 0, 
[2025-11-04 08:41:02] INFO:     127.0.0.1:48648 - "POST /v1/chat/completions HTTP/1.1" 200 OK
[2025-11-04 08:41:02] Prefill batch. #new-seq: 1, #new-token: 22, #cached-token: 24, token usage: 0.00, #running-req: 0, #queue-req: 0, 
[2025-11-04 08:41:02] INFO:     127.0.0.1:48650 - "POST /v1/chat/completions HTTP/1.1" 200 OK
[2025-11-04 08:41:02] Prefill batch. #new-seq: 1, #new-token: 67, #cached-token: 24, token usage: 0.00, #running-req: 0, #queue-req: 0, 
[2025-11-04 08:41:02] Decode batch. #running-req: 1, #token: 112, token usage: 0.00, accept len: 2.19, cuda graph: True, gen throughput (token/s): 258.07, #queue-req: 0, 
[2025-11-04 08:41:03] INFO:     127.0.0.1:48660 - "POST /v1/chat/completions HTTP/1.1" 200 OK
[2025-11-04 08:41:03] Prefill batch. #new-seq: 1, #new-token: 98, #cached-token: 24, token usage: 0.00, #running-req: 0, #queue-req: 0, 
[2025-11-04 08:41:03] Decode batch. #running-req: 1, #token: 159, token usage: 0.00, accept len: 2.10, cuda graph: True, gen throughput (token/s): 252.92, #queue-req: 0, 
[2025-11-04 08:41:03] INFO:     127.0.0.1:48666 - "POST /v1/chat/completions HTTP/1.1" 200 OK
[2025-11-04 08:41:03] Prefill batch. #new-seq: 1, #new-token: 455, #cached-token: 24, token usage: 0.00, #running-req: 0, #queue-req: 0, 
[2025-11-04 08:41:03] Decode batch. #running-req: 1, #token: 512, token usage: 0.00, accept len: 1.61, cuda graph: True, gen throughput (token/s): 170.16, #queue-req: 0, 
[2025-11-04 08:41:03] INFO:     127.0.0.1:48682 - "POST /v1/chat/completions HTTP/1.1" 200 OK
[2025-11-04 08:41:03] Prefill batch. #new-seq: 1, #new-token: 275, #cached-token: 24, token usage: 0.00, #running-req: 0, #queue-req: 0, 
[2025-11-04 08:41:03] Decode batch. #running-req: 1, #token: 349, token usage: 0.00, accept len: 2.12, cuda graph: True, gen throughput (token/s): 236.20, #queue-req: 0, 
[2025-11-04 08:41:03] INFO:     127.0.0.1:48694 - "POST /v1/chat/completions HTTP/1.1" 200 OK
[2025-11-04 08:41:03] Prefill batch. #new-seq: 1, #new-token: 22, #cached-token: 24, token usage: 0.00, #running-req: 0, #queue-req: 0, 

I encountered the following problems:

  1. An error occurred when using the sglang backend:
[rank0]: File "/disk3/wjp/pr_test/SpecForge/scripts/train_eagle3_online.py", line 742, in <module>
[rank0]: main()
[rank0]: File "/disk3/wjp/pr_test/SpecForge/scripts/train_eagle3_online.py", line 673, in main
[rank0]: plosses, acces = run_forward(args, eagle3_model, data, target_model)
[rank0]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]: File "/disk3/wjp/pr_test/SpecForge/scripts/train_eagle3_online.py", line 481, in run_forward
[rank0]: eagle3_data = target_model.generate_eagle3_data(
[rank0]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]: File "/disk3/wjp/uv_env/pr/lib/python3.11/site-packages/torch/utils/_contextlib.py", line 120, in decorate_context
[rank0]: return func(*args, **kwargs)
[rank0]: ^^^^^^^^^^^^^^^^^^^^^
[rank0]: File "/disk3/wjp/pr_test/SpecForge/specforge/modeling/target/eagle3_target_model.py", line 315, in generate_eagle3_data
[rank0]: logits_list, aux_hidden_states_list = self.extend(
[rank0]: ^^^^^^^^^^^^
[rank0]: File "/disk3/wjp/uv_env/pr/lib/python3.11/site-packages/torch/utils/_contextlib.py", line 120, in decorate_context
[rank0]: return func(*args, **kwargs)
[rank0]: ^^^^^^^^^^^^^^^^^^^^^
[rank0]: File "/disk3/wjp/pr_test/SpecForge/specforge/modeling/target/eagle3_target_model.py", line 240, in extend
[rank0]: batch.prepare_for_extend()
[rank0]: File "/disk3/wjp/pr_test/sglang/python/sglang/srt/managers/schedule_batch.py", line 1198, in prepare_for_extend
[rank0]: out_cache_loc, req_pool_indices_tensor, req_pool_indices = alloc_for_extend(
[rank0]: ^^^^^^^^^^^^^^^^^
[rank0]: File "/disk3/wjp/pr_test/sglang/python/sglang/srt/mem_cache/common.py", line 355, in alloc_for_extend
[rank0]: out_cache_loc = alloc_token_slots(batch.tree_cache, batch.extend_num_tokens)
[rank0]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]: File "/disk3/wjp/pr_test/sglang/python/sglang/srt/mem_cache/common.py", line 203, in alloc_token_slots
[rank0]: evict_from_tree_cache(tree_cache, num_tokens)
[rank0]: File "/disk3/wjp/pr_test/sglang/python/sglang/srt/mem_cache/common.py", line 247, in evict_from_tree_cache
[rank0]: tree_cache.evict(num_tokens)
[rank0]: ^^^^^^^^^^^^^^^^
[rank0]: AttributeError: 'types.SimpleNamespace' object has no attribute 'evict'. Did you mean: 'device'?
  2. The current implementation uses the same DP group for both the target model and the draft model, which is not the most efficient approach. The draft model should be able to use all ranks for DP, similar to what is shown in https://github.com/jiapingW/SpecForge/blob/main/scripts/train_eagle3_sgl_online.py#L696.
    You can use the following patch:
     parser.add_argument("--cache-key", type=str, default=None)
     parser.add_argument("--cache-dir", type=str, default="./cache")
     parser.add_argument("--output-dir", type=str, required=True)
-    parser.add_argument("--eval-interval", type=int, default=1)
-    parser.add_argument("--save-interval", type=int, default=1)
+    parser.add_argument("--eval-interval", type=int, default=-1)
+    parser.add_argument("--save-interval", type=int, default=-1)
     parser.add_argument("--seed", type=int, default=0)
     parser.add_argument(
         "--dist-timeout",
@@ -571,6 +571,12 @@ def main():
     draft_model.load_vocab_mapping(vocab_mapping_path)
     print_with_rank("Loaded vocab mapping")
 
+    # calculate steps per epoch if not provided
+    if args.save_interval == -1:
+        args.save_interval = len(train_dataloader)
+    if args.eval_interval == -1:
+        args.eval_interval = len(train_dataloader)
+
     # Calculate total steps if not provided
     if args.total_steps is None:
         steps_per_epoch = math.ceil(
@@ -661,7 +667,7 @@ def main():
             # ================================================
             # 7.1 Training Step
             # ================================================
-            plosses, acces = run_forward(args, eagle3_model, target_model, data)
+            plosses, acces = run_forward(args, eagle3_model, data, target_model)
             run_backward_and_update(args, plosses, optimizer, global_step)
 
             # log training metrics
@@ -676,12 +682,12 @@ def main():
             if dist.get_rank() == 0:
                 time_per_step = time.time() - last_time
                 last_time = time.time()
-                avg_loss = sum(pl.item() for pl in plosses) / len(plosses)
+                avg_loss = sum(pl for pl in plosses) / len(plosses)
                 avg_acc = sum(acces) / len(acces)
                 progress_bar.set_postfix(
                     {
-                        "loss": f"{avg_loss:.2f}",
-                        "acc": f"{avg_acc:.2f}",
+                        "loss": f"{avg_loss:.4f}",
+                        "acc": f"{avg_acc:.4f}",
                         "time": f"{time_per_step:.2f}s",
                     }
                 )

@jiapingW
Contributor

jiapingW commented Nov 4, 2025

Because I don't have permission to push code to this branch, you can use the following patch to fix the sglang backend bug. I'll try to train a model with it.

index a208dd7..bcbc7db 100644
--- a/specforge/modeling/target/eagle3_target_model.py
+++ b/specforge/modeling/target/eagle3_target_model.py
@@ -9,6 +9,7 @@ import torch.nn as nn
 from sglang.srt.configs.model_config import ModelConfig
 from sglang.srt.managers.schedule_batch import Req, ScheduleBatch
 from sglang.srt.managers.scheduler import Scheduler
+from sglang.srt.mem_cache.radix_cache import RadixCache
 from sglang.srt.model_executor.forward_batch_info import CaptureHiddenMode, ForwardBatch
 from sglang.srt.sampling.sampling_params import SamplingParams
 from sglang.srt.server_args import ServerArgs
@@ -222,17 +223,16 @@ class SGLangEagle3TargetModel(Eagle3TargetModel):
     @torch.no_grad
     def extend(self, reqs, capture_aux_hidden_states: bool = True):
         # Create dummy tree_cache for benchmarks (no prefix caching, just allocation)
-        dummy_tree_cache = SimpleNamespace(
-            page_size=self.model_runner.server_args.page_size,
-            device=self.model_runner.device,
+        tree_cache = RadixCache(
+            None,
             token_to_kv_pool_allocator=self.model_runner.token_to_kv_pool_allocator,
+            page_size=self.model_runner.server_args.page_size,
         )
-
         batch = ScheduleBatch.init_new(
             reqs=reqs,
             req_to_token_pool=self.model_runner.req_to_token_pool,
             token_to_kv_pool_allocator=self.model_runner.token_to_kv_pool_allocator,
-            tree_cache=dummy_tree_cache,
+            tree_cache=tree_cache,
             model_config=self.model_runner.model_config,
             enable_overlap=False,
             spec_algorithm=SpeculativeAlgorithm.NONE,

@FrankLeeeee force-pushed the feature/target-backend branch from 5e5c339 to 9584780 on November 4, 2025 at 07:55