diff --git a/configs/draft_lora_trainable_config.json b/configs/draft_lora_trainable_config.json new file mode 100644 index 00000000..d5f85a18 --- /dev/null +++ b/configs/draft_lora_trainable_config.json @@ -0,0 +1,35 @@ +{ + "alpha_pattern": {}, + "auto_mapping": null, + "base_model_name_or_path": "jamesliu1/sglang-EAGLE3-Llama-3.1-Instruct-8B", + "bias": "none", + "fan_in_fan_out": false, + "inference_mode": false, + "init_lora_weights": true, + "layer_replication": null, + "layers_pattern": null, + "layers_to_transform": null, + "loftq_config": {}, + "lora_alpha": 128, + "lora_dropout": 0.1, + "megatron_config": null, + "megatron_core": "megatron.core", + "modules_to_save": null, + "peft_type": "LORA", + "qalora_group_size": 16, + "r": 64, + "rank_pattern": {}, + "revision": null, + "target_modules": [ + "gate_proj", + "o_proj", + "q_proj", + "v_proj", + "k_proj" + ], + "task_type": "CAUSAL_LM", + "trainable_token_indices": null, + "use_dora": false, + "use_qalora": false, + "use_rslora": false +} diff --git a/download.py b/download.py new file mode 100644 index 00000000..4fcdbeee --- /dev/null +++ b/download.py @@ -0,0 +1,28 @@ +import os + +from huggingface_hub import snapshot_download + + +def download_model(model_id, local_dir): + print(f"downloading model: {model_id}") + print(f"will save to: {local_dir}") + + try: + snapshot_download( + repo_id=model_id, + local_dir=local_dir, + local_dir_use_symlinks=False, + ) + print("download success!") + except Exception as e: + print(f"error: {e}") + + +if __name__ == "__main__": + model_identifier = "jamesliu1/sglang-EAGLE3-Llama-3.1-Instruct-8B" + save_directory = f"./{model_identifier.replace('/', '_')}" + + if not os.path.exists(save_directory): + os.makedirs(save_directory) + + download_model(model_id=model_identifier, local_dir=save_directory) diff --git a/examples/run_llama3_eagle3_lora_online.sh b/examples/run_llama3_eagle3_lora_online.sh new file mode 100755 index 00000000..fa325b4b --- /dev/null +++ b/examples/run_llama3_eagle3_lora_online.sh @@ -0,0 +1,33 @@ +#!/bin/bash + +SCRIPT_DIR=$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd ) +ROOT_DIR=$(dirname $SCRIPT_DIR) + +NUM_GPUS=${1:-8} +TARGET_LORA_PATH=${2:-/sgl-workspace/llama-duo_llama3.1-8b-summarize-gpt4o-128k} +DRAFT_LORA_CONFIG=${3:-$ROOT_DIR/configs/draft_lora_trainable_config.json} +BASE_DRAFT_MODEL_PATH=${4:-/sgl-workspace/jamesliu1_sglang-EAGLE3-Llama-3.1-Instruct-8B} + +torchrun \ + --standalone \ + --nproc_per_node $NUM_GPUS \ + $ROOT_DIR/scripts/train_eagle3_lora_online.py \ + --target-model-path meta-llama/Llama-3.1-8B-Instruct \ + --draft-model-config /sgl-workspace/jamesliu1_sglang-EAGLE3-Llama-3.1-Instruct-8B/config.json \ + --base-draft-model-path $BASE_DRAFT_MODEL_PATH \ + --train-data-path $ROOT_DIR/cache/dataset/synth_summarize.jsonl \ + --output-dir $ROOT_DIR/outputs/llama3-8b-eagle3-lora-fixed \ + --use-lora \ + --lora-config $DRAFT_LORA_CONFIG \ + --target-lora-path $TARGET_LORA_PATH \ + --num-epochs 1 \ + --batch-size 1 \ + --learning-rate 1e-4 \ + --max-length 2048 \ + --chat-template llama3 \ + --cache-dir $ROOT_DIR/cache \ + --skip-vocab-mapping \ + --wandb \ + --wandb-key "f3b46a484034ca1fe99fc5ae4d19402c94da12c1" \ + --wandb-project "specforge-training" \ + --wandb-name "llama3-8b-lora-online-fixed-run-1" diff --git a/examples/run_llama3_eagle3_lora_online_fixed.sh b/examples/run_llama3_eagle3_lora_online_fixed.sh new file mode 100755 index 00000000..92dd52b8 --- /dev/null +++ b/examples/run_llama3_eagle3_lora_online_fixed.sh @@ 
-0,0 +1,32 @@
+#!/bin/bash
+
+SCRIPT_DIR=$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd )
+ROOT_DIR=$(dirname $SCRIPT_DIR)
+
+NUM_GPUS=${1:-8}
+TARGET_LORA_PATH=${2:-/sgl-workspace/llama-duo_llama3.1-8b-summarize-gpt4o-128k}
+DRAFT_LORA_CONFIG=${3:-$ROOT_DIR/configs/draft_lora_trainable_config.json}
+BASE_DRAFT_MODEL_PATH=${4:-/sgl-workspace/jamesliu1_sglang-EAGLE3-Llama-3.1-Instruct-8B}
+
+torchrun \
+    --standalone \
+    --nproc_per_node $NUM_GPUS \
+    $ROOT_DIR/scripts/train_eagle3_lora_online.py \
+    --target-model-path meta-llama/Llama-3.1-8B-Instruct \
+    --draft-model-config /sgl-workspace/jamesliu1_sglang-EAGLE3-Llama-3.1-Instruct-8B/config.json \
+    --base-draft-model-path $BASE_DRAFT_MODEL_PATH \
+    --train-data-path $ROOT_DIR/cache/dataset/synth_summarize.jsonl \
+    --output-dir $ROOT_DIR/outputs/llama3-8b-eagle3-lora-fixed \
+    --use-lora \
+    --target-lora-path $TARGET_LORA_PATH \
+    --num-epochs 1 \
+    --batch-size 1 \
+    --learning-rate 1e-4 \
+    --max-length 2048 \
+    --chat-template llama3 \
+    --cache-dir $ROOT_DIR/cache \
+    --skip-vocab-mapping \
+    --wandb \
+    --wandb-key "f3b46a484034ca1fe99fc5ae4d19402c94da12c1" \
+    --wandb-project "specforge-training" \
+    --wandb-name "llama3-8b-lora-online-fixed-run-1"
diff --git a/examples/run_llama3_eagle3_online.sh b/examples/run_llama3_eagle3_online.sh
index 406b339f..d067cb85 100755
--- a/examples/run_llama3_eagle3_online.sh
+++ b/examples/run_llama3_eagle3_online.sh
@@ -25,3 +25,8 @@ torchrun \
 # --mlflow-tracking-uri http://mlflow.grid1.ard.grid.linkedin.com:31812 \
 # --eval-data-split 0.01 \
-    --attention-backend flex_attention
+    --attention-backend flex_attention \
+    --cache-dir $ROOT_DIR/cache \
+    --wandb \
+    --wandb-key "f3b46a484034ca1fe99fc5ae4d19402c94da12c1" \
+    --wandb-project "specforge-training" \
+    --wandb-name "llama3-8b-online-fixed-run-1"
diff --git a/push_to_hf.sh b/push_to_hf.sh
new file mode 100755
index 00000000..d7d926d4
--- /dev/null
+++ b/push_to_hf.sh
@@ -0,0 +1,109 @@
+#!/usr/bin/env bash
+set -euo pipefail
+
+usage() {
+  cat <<'EOF'
+Usage:
+  HF_TOKEN=xxx ./push_to_hf.sh --repo-id <repo-id> --path <local-dir> [--branch main] [--private true|false] [--repo-type model|dataset|space] [--commit "msg"] [--force true|false]
+
+Examples:
+  HF_TOKEN=hf_xxx ./push_to_hf.sh \
+    --repo-id yourname/llama3-8b-eagle3-lora-fixed \
+    --path /sgl-workspace/SpecForge/outputs/llama3-8b-eagle3-lora-fixed/epoch_0/draft_lora \
+    --private true
+EOF
+}
+
+# defaults
+BRANCH="main"
+PRIVATE="false"
+REPO_TYPE="model"
+COMMIT_MSG=""
+FORCE="false"
+
+# parse args
+REPO_ID=""
+SRC_PATH=""
+while [[ $# -gt 0 ]]; do
+  case "$1" in
+    --repo-id) REPO_ID="${2:-}"; shift 2;;
+    --path) SRC_PATH="${2:-}"; shift 2;;
+    --branch) BRANCH="${2:-}"; shift 2;;
+    --private) PRIVATE="${2:-}"; shift 2;;
+    --repo-type) REPO_TYPE="${2:-}"; shift 2;;
+    --commit) COMMIT_MSG="${2:-}"; shift 2;;
+    --force) FORCE="${2:-}"; shift 2;;
+    -h|--help) usage; exit 0;;
+    *) echo "Unknown arg: $1"; usage; exit 1;;
+  esac
+done
+
+# validate
+: "${HF_TOKEN:?Set HF_TOKEN in env}"
+: "${REPO_ID:?--repo-id is required}"
+: "${SRC_PATH:?--path is required}"
+
+if [[ ! -d "$SRC_PATH" ]]; then
+  echo "Path not found: $SRC_PATH" >&2
+  exit 1
+fi
+
+if ! command -v git >/dev/null 2>&1; then
+  echo "git not found. Please install git." >&2
+  exit 1
+fi
+if ! command -v git-lfs >/dev/null 2>&1; then
+  echo "git-lfs not found. Please install git-lfs." >&2
+  exit 1
+fi
+if ! command -v hf >/dev/null 2>&1 && ! command -v huggingface-cli >/dev/null 2>&1; then
+  echo "huggingface CLI not found; attempting to install huggingface_hub..."
+  python3 -m pip install --user -U huggingface_hub >/dev/null
+fi
+
+# create repo if missing (ignore error if exists)
+if command -v hf >/dev/null 2>&1; then
+  CREATE_FLAGS=(--repo-type "$REPO_TYPE" --token "$HF_TOKEN")
+  [[ "$PRIVATE" == "true" ]] && CREATE_FLAGS+=(--private)
+  hf repo create "$REPO_ID" "${CREATE_FLAGS[@]}" -y 2>/dev/null || true
+else
+  CREATE_FLAGS=(--type "$REPO_TYPE" -y --token "$HF_TOKEN")
+  [[ "$PRIVATE" == "true" ]] && CREATE_FLAGS+=(--private)
+  huggingface-cli repo create "$REPO_ID" "${CREATE_FLAGS[@]}" 2>/dev/null || true
+fi
+
+# commit msg
+if [[ -z "$COMMIT_MSG" ]]; then
+  COMMIT_MSG="Upload from script on $(date -u +'%Y-%m-%dT%H:%M:%SZ')"
+fi
+
+# work in a temp dir to avoid touching source
+WORKDIR="$(mktemp -d)"
+trap 'rm -rf "$WORKDIR"' EXIT
+
+# copy content excluding .git
+tar -C "$SRC_PATH" --exclude='.git' -cf - . | tar -C "$WORKDIR" -xf -
+
+cd "$WORKDIR"
+git init -q
+# Some git-lfs versions do not support -q; keep output quiet via redirection
+git lfs install --skip-repo >/dev/null 2>&1 || true
+
+# sensible LFS defaults for model assets
+git lfs track "*.safetensors" "*.bin" "*.pt" "*.ckpt" "*.h5" "*.gguf" "*.onnx" "*.tflite" "*.tar" "*.zip" 2>/dev/null || true
+# track large tokenizer assets to satisfy HF pre-receive hooks (>10 MiB)
+git lfs track "tokenizer.json" "tokenizer.model" "spiece.model" "sentencepiece.bpe.model" "*.spm" 2>/dev/null || true
+echo ".gitattributes" >> .gitignore || true
+
+git add -A
+git commit -m "$COMMIT_MSG" -q
+
+REMOTE="https://oauth2:${HF_TOKEN}@huggingface.co/${REPO_ID}"
+git branch -M "$BRANCH"
+git remote add origin "$REMOTE"
+
+PUSH_FLAGS=()
+[[ "$FORCE" == "true" ]] && PUSH_FLAGS+=("--force")
+git push "${PUSH_FLAGS[@]}" -u origin "$BRANCH"
+
+echo "Pushed $SRC_PATH to https://huggingface.co/${REPO_ID} (branch: $BRANCH)"
diff --git a/requirements.txt b/requirements.txt
index 0de4a941..0df86e17 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -11,3 +11,4 @@ psutil
 numpy
 accelerate
 pydantic
+peft
diff --git a/scripts/build_eagle3_dataset.py b/scripts/build_eagle3_dataset.py
index 13004681..f07beaac 100644
--- a/scripts/build_eagle3_dataset.py
+++ b/scripts/build_eagle3_dataset.py
@@ -6,8 +6,10 @@ import hashlib
 import os
 from pathlib import Path
+from typing import Optional
 
 import torch
+import torch.nn as nn
 from datasets import load_dataset
 from transformers import AutoTokenizer
 
diff --git a/scripts/prepare_data.py b/scripts/prepare_data.py
index 7577e974..78727f14 100644
--- a/scripts/prepare_data.py
+++ b/scripts/prepare_data.py
@@ -33,7 +33,7 @@ def parse_args():
     parser.add_argument(
         "--dataset",
         type=str,
-        choices=["ultrachat", "sharegpt", "opc"],
+        choices=["ultrachat", "sharegpt", "opc", "synth_summarize"],
         help="The demo dataset to quickly run the training for speculative decoding",
     )
     parser.add_argument(
@@ -110,6 +110,29 @@ def load_dataset_from_path(data_path: Path):
 import hashlib
 
 
+def process_synth_summarize_row(row) -> Dict:
+    """Process a row from the synth_summarize dataset.
+
+    The function expects a row with the following schema:
+    "messages": [
+        {
+            "role": "user" | "assistant",
+            "content": str
+        }
+    ],
+    "prompt_id": str
+    """
+    conversations = row["messages"]
+    formatted_conversations = []
+    for message in conversations:
+        role = message["role"]
+        content = message["content"]
+        assert role in ["user", "assistant"]
+        formatted_conversations.append({"role": role, "content": content})
+    row = {"id": row["prompt_id"], "conversations": formatted_conversations}
+    return row
+
+
 def process_opc_sft_stage1(row) -> Dict:
     row_id = hashlib.md5((row["instruction"] + row["output"]).encode()).hexdigest()
     return {
@@ -139,9 +162,18 @@ def main():
             "OpenCoder-LLM/opc-sft-stage1", "largescale_diverse_instruct"
         )["train"]
         proc_fn = process_opc_sft_stage1
+    elif args.dataset == "synth_summarize":
+        if args.data_path is None:
+            ds = load_dataset("llama-duo/synth_summarize_dataset_dedup")[
+                "train_sft_claude3sonnet"
+            ]
+        else:
+            print("Loading dataset from custom data path: ", args.data_path)
+            ds = load_dataset_from_path(Path(args.data_path))
+        proc_fn = process_synth_summarize_row
     else:
         raise ValueError(
-            "This script only supports ultrachat_200k and sharegpt datasets for demo purpose, if you wish to use other datasets, please modify this script."
+            "This script only supports the ultrachat_200k, sharegpt, opc, and synth_summarize datasets for demo purposes; if you wish to use other datasets, please modify this script."
         )
 
     if args.output_path is None:
diff --git a/scripts/train_eagle3_lora_online.py b/scripts/train_eagle3_lora_online.py
new file mode 100644
index 00000000..e2e5f431
--- /dev/null
+++ b/scripts/train_eagle3_lora_online.py
@@ -0,0 +1,706 @@
+import argparse
+import hashlib
+import json
+import os
+
+import torch
+import torch.distributed as dist
+from accelerate.utils import set_seed
+from datasets import load_dataset
+
+# Import PEFT library
+from peft import LoraConfig, PeftConfig, TaskType, get_peft_model
+from torch.distributed.fsdp import FullStateDictConfig
+from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
+from torch.distributed.fsdp import MixedPrecision, ShardingStrategy, StateDictType
+from tqdm import tqdm
+from transformers import AutoModelForCausalLM, AutoTokenizer
+
+import wandb
+from specforge import (
+    AutoDistributedTargetModel,
+    AutoDraftModelConfig,
+    AutoEagle3DraftModel,
+    OnlineEagle3Model,
+)
+from specforge.data import (
+    build_eagle3_dataset,
+    generate_vocab_mapping_file,
+    prepare_dp_dataloaders,
+)
+from specforge.distributed import destroy_distributed, get_dp_group, init_distributed
+from specforge.lr_scheduler import CosineAnnealingWarmupLR
+from specforge.utils import get_last_checkpoint, print_with_rank, rank_0_priority
+
+
+def parse_args():
+    parser = argparse.ArgumentParser(description="Train Eagle3 with online data")
+
+    # add model-related arguments
+    parser.add_argument("--target-model-path", type=str, required=True)
+    parser.add_argument("--draft-model-config", type=str, required=True)
+
+    # --- MODIFICATION START ---
+    # Add new parameter for loading pre-trained draft model
+    parser.add_argument(
+        "--base-draft-model-path",
+        type=str,
+        default=None,
+        help="Path to a pre-trained base draft model",
+    )
+    # --- MODIFICATION END ---
+
+    parser.add_argument(
+        "--embedding-key",
+        type=str,
+
default="model.embed_tokens.weight", + help="The key of the embedding weight to load from the target model", + ) + + # LoRA configuration + parser.add_argument("--use-lora", action="store_true", help="Enable LoRA training") + parser.add_argument( + "--lora-config", + type=str, + default=None, + help="Path to LoRA config file for draft model", + ) + parser.add_argument( + "--target-lora-path", + type=str, + default=None, + help="Path to pre-trained target LoRA adapter", + ) + + # add training-related arguments + parser.add_argument("--train-data-path", type=str, required=True) + parser.add_argument("--eval-data-path", type=str, default=None) + parser.add_argument("--num-epochs", type=int, default=10) + parser.add_argument("--batch-size", type=int, default=1) + parser.add_argument("--learning-rate", type=float, default=1e-4) + parser.add_argument("--max-length", type=int, default=2048) + parser.add_argument("--warmup-ratio", type=float, default=0.02) + + # data processing type + parser.add_argument("--chat-template", type=str, default="llama3") + + # distributed training + parser.add_argument("--tp-size", type=int, default=1) + + # other args + parser.add_argument("--cache-key", type=str, default=None) + parser.add_argument("--cache-dir", type=str, default="./cache") + parser.add_argument("--output-dir", type=str, required=True) + parser.add_argument("--eval-interval", type=int, default=1) + parser.add_argument("--save-interval", type=int, default=1) + parser.add_argument("--seed", type=int, default=0) + parser.add_argument( + "--dist-timeout", + type=int, + default=20, + help="Timeout for collective communication in minutes", + ) + parser.add_argument( + "--skip-vocab-mapping", + action="store_true", + help="Use pretrained vocab mapping without regeneration", + ) + + # resume + parser.add_argument("--resume", action="store_true") + + # wandb wandb args + parser.add_argument("--wandb", action="store_true") + parser.add_argument("--wandb-project", type=str, default=None) + parser.add_argument("--wandb-name", type=str, default=None) + parser.add_argument("--wandb-key", type=str, default=None) + + args = parser.parse_args() + return args + + +def init_wandb(args): + wandb.login(key=args.wandb_key) + wandb.init(project=args.wandb_project, name=args.wandb_name) + + +def wandb_log_if_initialized(log_dict): + if dist.get_rank() == 0 and wandb.run is not None: + wandb.log(log_dict) + + +def print_on_rank0(message): + if dist.get_rank() == 0: + print(message) + + +def print_trainable_parameters(model, model_name="Model"): + """Print model trainable parameter statistics""" + trainable_params = 0 + all_param = 0 + for _, param in model.named_parameters(): + all_param += param.numel() + if param.requires_grad: + trainable_params += param.numel() + + print_with_rank( + f"{model_name}: {trainable_params:,} trainable parameters out of {all_param:,} total parameters " + f"({100 * trainable_params / all_param:.2f}% trainable)" + ) + + +def load_lora_config(lora_config_path, is_trainable=False): + """Load LoRA configuration from config file""" + if not lora_config_path or not os.path.exists(lora_config_path): + raise ValueError(f"LoRA config file not found: {lora_config_path}") + + print_with_rank(f"Loading LoRA config from: {lora_config_path}") + + # Create LoraConfig from config file + with open(lora_config_path, "r") as f: + config_dict = json.load(f) + + if is_trainable: + config_dict["inference_mode"] = False + + lora_config = LoraConfig(**config_dict) + print_with_rank( + f"Loaded LoRA config: 
r={lora_config.r}, alpha={lora_config.lora_alpha}, dropout={lora_config.lora_dropout}" + ) + + return lora_config + + +def main(): + # initialize + args = parse_args() + set_seed(args.seed) + init_distributed(timeout=args.dist_timeout, tp_size=args.tp_size) + print_with_rank(f"Initialized distributed environment") + + if args.wandb and dist.get_rank() == 0: + init_wandb(args) + + # detecting last ckpt for draft model + draft_model_last_checkpoint = None + if args.resume and os.path.isdir(args.output_dir): + print_on_rank0(args.output_dir) + draft_model_last_checkpoint = get_last_checkpoint(args.output_dir) + print_on_rank0(f"Last checkpoint detected: {draft_model_last_checkpoint}") + + # build target and draft model + if args.tp_size > 1: + # to avoid CPU RAM OOM, we directly init the model on CUDA + target_model = AutoDistributedTargetModel.from_pretrained( + pretrained_model_name_or_path=args.target_model_path, + torch_dtype=torch.bfloat16, + device="cuda", + ) + else: + target_model = AutoModelForCausalLM.from_pretrained( + pretrained_model_name_or_path=args.target_model_path, + torch_dtype=torch.bfloat16, + ).cuda() + + # add target LoRA + if args.use_lora: + if args.target_lora_path and os.path.exists(args.target_lora_path): + print_with_rank( + f"Loading pre-trained target LoRA from: {args.target_lora_path}" + ) + # Load configuration from target LoRA path + target_lora_config = load_lora_config( + os.path.join(args.target_lora_path, "adapter_config.json") + ) + target_model = get_peft_model(target_model, target_lora_config) + target_model.load_adapter(args.target_lora_path, "default") + print_with_rank(f"Loaded pre-trained target LoRA adapter") + else: + print_with_rank( + f"No pre-trained target LoRA specified, using base target model" + ) + + # Freeze all parameters of target model (including LoRA) + for param in target_model.parameters(): + param.requires_grad = False + target_model = target_model.eval() + print_with_rank(f"Target model frozen for inference") + else: + target_model = target_model.eval() + + print_with_rank(f"Initialized target model") + + # Modify draft model loading logic + draft_model_config = AutoDraftModelConfig.from_file(args.draft_model_config) + + if args.base_draft_model_path and os.path.exists(args.base_draft_model_path): + # Load from pre-trained draft model path first + print_with_rank(f"Loading base draft model from: {args.base_draft_model_path}") + draft_model = ( + AutoEagle3DraftModel.from_pretrained(args.base_draft_model_path) + .cuda() + .to(torch.bfloat16) + ) + draft_model.load_embedding( + args.target_model_path, embedding_key=args.embedding_key + ) + draft_model.freeze_embedding() + print_with_rank(f"Loaded pre-trained base draft model.") + + if args.use_lora and not args.lora_config: + for param in draft_model.parameters(): + param.requires_grad = False + print_with_rank(f"Frozen all base draft model parameters") + + # Freeze or add LoRA based on strategy + if args.use_lora and args.lora_config: + # for param in draft_model.parameters(): + # param.requires_grad = False + # print_with_rank(f"Frozen all base draft model parameters") + + # Add LoRA to draft model + draft_lora_config = load_lora_config(args.lora_config, is_trainable=True) + + # PEFT compatibility fix: ensure prepare_inputs_for_generation is accessible + if not hasattr(draft_model, "prepare_inputs_for_generation"): + print_with_rank( + "Warning: draft_model doesn't have prepare_inputs_for_generation, adding compatibility method" + ) + # Add a simple prepare_inputs_for_generation 
method from GenerationMixin + from transformers.generation.utils import GenerationMixin + + if hasattr(GenerationMixin, "prepare_inputs_for_generation"): + draft_model.prepare_inputs_for_generation = ( + GenerationMixin.prepare_inputs_for_generation.__get__( + draft_model, draft_model.__class__ + ) + ) + else: + print_with_rank("Draft model has prepare_inputs_for_generation method") + + draft_model = get_peft_model(draft_model, draft_lora_config) + draft_model = draft_model.to(torch.bfloat16) + print_with_rank(f"Added new LoRA to draft model for training") + + # Log detailed status of all model parameters + def log_model_parameters(model, model_name): + """Log detailed parameter information for debugging""" + print_with_rank(f"\n=== {model_name} Parameter Details ===") + trainable_count = 0 + frozen_count = 0 + total_params = 0 + + for name, param in model.named_parameters(): + trainable = param.requires_grad + if trainable: + trainable_count += param.numel() + else: + frozen_count += param.numel() + total_params += param.numel() + + # Print detailed information for each parameter + print_with_rank( + f" {name:60s} | " + f"Trainable: {str(trainable):5s} | " + f"Shape: {str(tuple(param.shape)):20s} | " + f"Dtype: {str(param.dtype):10s} | " + f"Device: {str(param.device):10s} | " + f"Params: {param.numel():,}" + ) + + print_with_rank(f"\n{model_name} Summary:") + print_with_rank(f" Total parameters: {total_params:,}") + print_with_rank( + f" Trainable parameters: {trainable_count:,} ({100*trainable_count/total_params:.2f}%)" + ) + print_with_rank( + f" Frozen parameters: {frozen_count:,} ({100*frozen_count/total_params:.2f}%)" + ) + print_with_rank("=" * 80) + + # Log detailed parameter information for target and draft models + log_model_parameters(target_model, "Target Model") + log_model_parameters(draft_model, "Draft Model") + + # # Restore from LoRA checkpoint (only load draft LoRA) + # if draft_model_last_checkpoint: + # draft_lora_path = os.path.join(draft_model_last_checkpoint, "draft_lora") + + # if os.path.exists(draft_lora_path): + # draft_model.load_adapter(draft_lora_path, "default") + # print_with_rank(f"Loaded draft LoRA from checkpoint: {draft_lora_path}") + + print_with_rank(f"Initialized draft model") + + # build dataloaders + tokenizer = AutoTokenizer.from_pretrained(args.target_model_path) + + # convert to dataloader + cache_key = hashlib.md5(args.train_data_path.encode()).hexdigest() + train_dataset = load_dataset("json", data_files=args.train_data_path)["train"] + vocab_mapping_path = None + with rank_0_priority(): + train_eagle3_dataset = build_eagle3_dataset( + dataset=train_dataset, + tokenizer=tokenizer, + chat_template=args.chat_template, + max_length=args.max_length, + cache_dir=os.path.join(args.cache_dir, "processed_dataset"), + cache_key=cache_key, + ) + if not args.skip_vocab_mapping: + vocab_mapping_path = generate_vocab_mapping_file( + dataset=train_eagle3_dataset, + target_vocab_size=draft_model_config.vocab_size, + draft_vocab_size=draft_model_config.draft_vocab_size, + cache_dir=os.path.join(args.cache_dir, "vocab_mapping"), + cache_key=cache_key, + ) + train_dataloader = prepare_dp_dataloaders( + train_eagle3_dataset, + args.batch_size, + num_workers=4, + shuffle=True, + process_group=get_dp_group(), + ) + print_with_rank(f"Initialized train dataloader") + + # we load the vocab mapping then + if not args.skip_vocab_mapping: + draft_model.load_vocab_mapping(vocab_mapping_path) + print_with_rank(f"Loaded vocab mapping") + + if args.eval_data_path is not 
None: + eval_dataset = load_dataset("json", data_files=args.eval_data_path)["train"] + eval_eagle3_dataset = build_eagle3_dataset( + eval_dataset, + tokenizer, + args.chat_template, + args.max_length, + ) + eval_dataloader = prepare_dp_dataloaders( + eval_eagle3_dataset, + args.batch_size, + num_workers=4, + shuffle=False, + process_group=get_dp_group(), + ) + print_with_rank(f"Initialized eval dataloader") + + # build Eagle3 model + # broadcast draft model + eagle3_model = OnlineEagle3Model( + target_model=target_model, + draft_model=draft_model, + ) + # eagle3_model = DDP(eagle3_model, find_unused_parameters=True) + # Target model is always ignored (target is frozen regardless of whether LoRA is used) + eagle3_model = FSDP( + eagle3_model, + use_orig_params=True, + mixed_precision=MixedPrecision( + param_dtype=torch.bfloat16, + buffer_dtype=torch.bfloat16, + ), + sharding_strategy=ShardingStrategy.SHARD_GRAD_OP, + ignored_modules=[target_model], + process_group=get_dp_group(), + ) + print_with_rank(f"Initialized Eagle3 FSDP model") + + # Print parameter statistics + if args.use_lora: + print_trainable_parameters(target_model, "Target Model (Frozen)") + print_trainable_parameters(draft_model, "Draft Model (LoRA Only)") + print_trainable_parameters(eagle3_model, "Eagle3 Model (Overall)") + + # build other components + has_trainable_params = any( + param.requires_grad for param in eagle3_model.parameters() + ) + optimizer = torch.optim.AdamW(eagle3_model.parameters(), lr=args.learning_rate) + + if args.use_lora: + # Count LoRA parameters for logging + lora_param_count = sum( + 1 + for name, param in eagle3_model.named_parameters() + if param.requires_grad + and "draft_model" in name + and ("lora_" in name or "adapter" in name) + ) + print_with_rank( + f"Optimizer will train {lora_param_count} LoRA parameters out of total parameters" + ) + else: + trainable_param_count = sum( + 1 for param in eagle3_model.parameters() if param.requires_grad + ) + print_with_rank( + f"Optimizer configured for {trainable_param_count} trainable parameters" + ) + + total_steps = args.num_epochs * len(train_dataloader) + warmup_steps = int(total_steps * args.warmup_ratio) + scheduler = CosineAnnealingWarmupLR( + optimizer, total_steps=total_steps, warmup_steps=warmup_steps + ) + print_with_rank(f"Initialized optimizer and scheduler") + + # resume + start_epoch = 0 + if draft_model_last_checkpoint is not None: + print_on_rank0( + f"Resuming draft model training from checkpoint: {draft_model_last_checkpoint}" + ) + state_path = os.path.join(draft_model_last_checkpoint, "training_state.pt") + + if os.path.exists(state_path): + state = torch.load(state_path, map_location="cpu", weights_only=False) + + optimizer.load_state_dict(state["optimizer_state_dict"]) + print_on_rank0("Successfully loaded optimizer state_dict.") + + scheduler.load_state_dict(state["scheduler_state_dict"]) + print_on_rank0("Successfully loaded scheduler state_dict.") + + start_epoch = state["epoch"] + 1 + print_on_rank0(f"Resuming from epoch {start_epoch}") + else: + print_on_rank0( + f"Warning: Checkpoint directory {draft_model_last_checkpoint} found, but training_state.pt is missing. Starting from scratch." 
+ ) + + dist.barrier() + + # start running + print_on_rank0(f"Starting training from epoch {start_epoch}") + for epoch in range(start_epoch, args.num_epochs): + # Run training / inference-only depending on whether any params are trainable + train_dataloader.sampler.set_epoch(epoch + 1) + if has_trainable_params: + draft_model.train() + eagle3_model.train() + else: + draft_model.eval() + eagle3_model.eval() + epoch_acces = [[] for _ in range(eagle3_model.module.length)] + epoch_plosses = [[] for _ in range(eagle3_model.module.length)] + + for data in tqdm(train_dataloader, desc=f"Training Epoch {epoch}"): + optimizer.zero_grad() + with torch.set_grad_enabled(has_trainable_params): + plosses, _, acces = eagle3_model( + input_ids=data["input_ids"].cuda(), + attention_mask=data["attention_mask"].cuda(), + loss_mask=data["loss_mask"].cuda(), + ) + + # calculate weighted loss + ploss_weight = [0.8**i for i in range(len(plosses))] + ploss = sum([ploss_weight[i] * plosses[i] for i in range(len(plosses))]) + + if has_trainable_params: + ploss.backward() + if args.use_lora: + grad_log_dict = {} + lora_param_idx = 0 + for name, param in eagle3_model.named_parameters(): + if ( + param.requires_grad + and "draft_model" in name + and ("lora_" in name or "adapter" in name) + ): + if param.grad is not None: + grad_norm = param.grad.norm().item() + print_on_rank0( + f"LoRA param {lora_param_idx} ({name}) grad norm: {grad_norm}" + ) + grad_log_dict[ + f"train/lora_grad_norm_{lora_param_idx}" + ] = grad_norm + lora_param_idx += 1 + if grad_log_dict: + wandb_log_if_initialized(grad_log_dict) + optimizer.step() + scheduler.step() + + logdict = {"train/lr": optimizer.param_groups[0]["lr"]} + for i in range(len(plosses)): + logdict[f"train/ploss_{i}"] = plosses[i].item() + for i in range(len(acces)): + logdict[f"train/acc_{i}"] = acces[i] + wandb_log_if_initialized(logdict) + + epoch_acces = [epoch_acces[i] + [acces[i]] for i in range(len(acces))] + epoch_plosses = [ + epoch_plosses[i] + [plosses[i].item()] for i in range(len(plosses)) + ] + + for i in range(len(epoch_acces)): + acc_i = torch.tensor(epoch_acces[i]).cuda().mean() + dist.all_reduce(acc_i) + acc_i = acc_i / dist.get_world_size() + acc_i = acc_i.item() + wandb_log_if_initialized({f"train/epochacc_{i}": acc_i}) + print_on_rank0( + f"Train Epoch [{epoch + 1}/{args.num_epochs}], position {i}, Acc: {acc_i:.2f}" + ) + + for i in range(len(epoch_plosses)): + loss_i = torch.tensor(epoch_plosses[i]).cuda().mean() + dist.all_reduce(loss_i) + loss_i = loss_i / dist.get_world_size() + loss_i = loss_i.item() + wandb_log_if_initialized({f"train/epochploss_{i}": loss_i}) + print_on_rank0( + f"Train Epoch [{epoch + 1}/{args.num_epochs}], position {i}, pLoss: {loss_i:.2f}" + ) + + # run evaluation + if args.eval_data_path is not None and epoch % args.eval_interval == 0: + # Run evaluation + draft_model.eval() + eval_acces = [[] for _ in range(eagle3_model.length)] + eval_plosses = [[] for _ in range(eagle3_model.length)] + + for data in tqdm(eval_dataloader, desc=f"Evaluating Epoch {epoch}"): + plosses, _, acces = eagle3_model( + input_ids=data["input_ids"].cuda(), + attention_mask=data["attention_mask"].cuda(), + loss_mask=data["loss_mask"].cuda(), + ) + eval_acces = [eval_acces[i] + [acces[i]] for i in range(len(acces))] + eval_plosses = [ + eval_plosses[i] + [plosses[i].item()] for i in range(len(plosses)) + ] + + for i in range(len(eval_acces)): + acc_i = torch.tensor(eval_acces[i]).cuda().mean() + dist.all_reduce(acc_i) + acc_i = acc_i / dist.get_world_size() 
+ acc_i = acc_i.item() + + wandb_log_if_initialized({f"eval/epochacc_{i}": acc_i}) + print_on_rank0( + f"Eval Epoch [{epoch + 1}/{args.num_epochs}], position {i}, Acc: {acc_i:.2f}" + ) + + for i in range(len(eval_plosses)): + loss_i = torch.tensor(eval_plosses[i]).cuda().mean() + dist.all_reduce(loss_i) + loss_i = loss_i / dist.get_world_size() + loss_i = loss_i.item() + + wandb_log_if_initialized({f"eval/epochploss_{i}": loss_i}) + print_on_rank0( + f"Eval Epoch [{epoch + 1}/{args.num_epochs}], position {i}, pLoss: {loss_i:.2f}" + ) + + if epoch % args.save_interval == 0: + # Save the model + epoch_output_dir = os.path.join(args.output_dir, f"epoch_{epoch}") + + if dist.get_rank() == 0: + os.makedirs(epoch_output_dir, exist_ok=True) + dist.barrier() + + # Only gather a full state dict on rank 0 to reduce sync pressure + with FSDP.state_dict_type( + eagle3_model, + StateDictType.FULL_STATE_DICT, + FullStateDictConfig(offload_to_cpu=True, rank0_only=True), + ): + model_state_dict = eagle3_model.state_dict() + state_to_save = { + "optimizer_state_dict": optimizer.state_dict(), + "scheduler_state_dict": scheduler.state_dict(), + "epoch": epoch, + "args": args, + } + + if dist.get_rank() == 0: + torch.save( + state_to_save, + os.path.join(epoch_output_dir, "training_state.pt"), + ) + print_on_rank0( + f"Saved full training state to {epoch_output_dir}/training_state.pt" + ) + + if args.use_lora and args.lora_config: + # Manually extract and save draft model LoRA weights (from FSDP state_dict) + draft_lora_output_dir = os.path.join( + epoch_output_dir, "draft_lora" + ) + os.makedirs(draft_lora_output_dir, exist_ok=True) + + # Extract LoRA related weights + lora_state_dict = {} + for key, value in model_state_dict.items(): + if "draft_model." in key and ( + "lora_" in key or "adapter" in key + ): + # Remove "draft_model." prefix because we want to save the LoRA weights inside the draft model + lora_key = key.replace("draft_model.", "") + lora_state_dict[lora_key] = value + + if lora_state_dict: + # Save LoRA weights + import safetensors.torch as st + + st.save_file( + lora_state_dict, + os.path.join( + draft_lora_output_dir, "adapter_model.safetensors" + ), + ) + print_on_rank0( + f"Saved {len(lora_state_dict)} LoRA weights to {draft_lora_output_dir}/adapter_model.safetensors" + ) + + # Save LoRA configuration file + import shutil + + shutil.copy2( + args.lora_config, + os.path.join( + draft_lora_output_dir, "adapter_config.json" + ), + ) + print_on_rank0( + f"Copied LoRA config to {draft_lora_output_dir}/adapter_config.json" + ) + else: + print_on_rank0( + "Warning: No LoRA weights found in state_dict!" + ) + + # Save tokenizer to draft_lora directory + tokenizer.save_pretrained(draft_lora_output_dir) + print_on_rank0(f"Saved tokenizer to {draft_lora_output_dir}") + else: + # Original logic: save draft model state + draft_model_state_dict = { + k.replace("draft_model.", ""): v + for k, v in model_state_dict.items() + if "draft_model." 
in k + } + draft_model.save_pretrained( + epoch_output_dir, + state_dict=draft_model_state_dict, + ) + print_on_rank0( + f"Saved model configuration to {epoch_output_dir}" + ) + + # Save tokenizer + tokenizer.save_pretrained(epoch_output_dir) + print_on_rank0(f"Saved tokenizer to {epoch_output_dir}") + # Avoid a trailing barrier here to reduce chances of hanging on sync + + destroy_distributed() + + +if __name__ == "__main__": + main() diff --git a/scripts/train_eagle3_offline.py b/scripts/train_eagle3_offline.py index a7dad45c..f3b4b8f6 100644 --- a/scripts/train_eagle3_offline.py +++ b/scripts/train_eagle3_offline.py @@ -4,7 +4,6 @@ import torch import torch.distributed as dist -import wandb from accelerate.utils import set_seed from datasets import load_dataset from torch.distributed.fsdp import FullyShardedDataParallel as FSDP @@ -12,6 +11,7 @@ from tqdm import tqdm from transformers import AutoTokenizer +import wandb from specforge import AutoDraftModelConfig, AutoEagle3DraftModel, OfflineEagle3Model from specforge.data import ( build_eagle3_dataset, diff --git a/scripts/train_eagle3_online.py b/scripts/train_eagle3_online.py index 276544e7..2ef09621 100644 --- a/scripts/train_eagle3_online.py +++ b/scripts/train_eagle3_online.py @@ -4,7 +4,6 @@ import torch import torch.distributed as dist -import wandb from accelerate.utils import set_seed from datasets import load_dataset from torch.distributed.fsdp import FullyShardedDataParallel as FSDP @@ -12,6 +11,7 @@ from tqdm import tqdm from transformers import AutoModelForCausalLM, AutoTokenizer +import wandb from specforge import ( AutoDistributedTargetModel, AutoDraftModelConfig, diff --git a/specforge/modeling/draft/base.py b/specforge/modeling/draft/base.py index 619fbd66..814b683d 100644 --- a/specforge/modeling/draft/base.py +++ b/specforge/modeling/draft/base.py @@ -30,13 +30,13 @@ import torch.nn as nn from huggingface_hub import snapshot_download from safetensors import safe_open -from transformers import PreTrainedModel +from transformers import GenerationMixin, PreTrainedModel from transformers.cache_utils import Cache from specforge.modeling._mask_utils import _expand_mask, _make_causal_mask -class Eagle3DraftModel(PreTrainedModel, ABC): +class Eagle3DraftModel(PreTrainedModel, GenerationMixin, ABC): """ This is the base class for the Eagle3 draft model implementation. The child class needs to implement the abstract methods to support training with TTT. 
diff --git a/specforge/modeling/draft/llama3_eagle.py b/specforge/modeling/draft/llama3_eagle.py index 41e69a49..ddb1a8e4 100644 --- a/specforge/modeling/draft/llama3_eagle.py +++ b/specforge/modeling/draft/llama3_eagle.py @@ -698,6 +698,18 @@ def __init__(self, config, quant_config=None, attention_backend="sdpa") -> None: self.register_buffer("t2d", t2d) self.register_buffer("d2t", d2t) + # Initialize the model completely before setting up PEFT compatibility + self.post_init() + + # Ensure this method is explicitly available for PEFT compatibility + def get_input_embeddings(self): + """Required for PEFT compatibility.""" + return self.embed_tokens + + def set_input_embeddings(self, value): + """Required for PEFT compatibility.""" + self.embed_tokens = value + def forward( self, hidden_states: torch.Tensor, @@ -720,6 +732,25 @@ def forward( logger.info(f"using ttt_length {ttt_length}, caching hidden states") cache_hidden = [[], []] + if not hasattr(self, "_lora_logged"): + print( + "self.midlayer.self_attn.q_proj type:", + type(self.midlayer.self_attn.q_proj), + ) + print( + "self.midlayer.self_attn.k_proj type:", + type(self.midlayer.self_attn.k_proj), + ) + print( + "self.midlayer.self_attn.v_proj type:", + type(self.midlayer.self_attn.v_proj), + ) + print( + "self.midlayer.self_attn.o_proj type:", + type(self.midlayer.self_attn.o_proj), + ) + self._lora_logged = True + batch_size, seq_length, _ = hidden_states.size() # make position ids @@ -786,3 +817,50 @@ def backbone( output_attentions=False, use_cache=False, ) + + def prepare_inputs_for_generation( + self, + input_ids, + past_key_values=None, + attention_mask=None, + inputs_embeds=None, + **kwargs, + ): + """ + Prepare inputs for generation. This method is required for PEFT compatibility. + """ + if past_key_values is not None: + past_length = past_key_values[0][0].shape[2] + + # Some generation methods already pass only the last input ID + if input_ids.shape[1] > past_length: + remove_prefix_length = past_length + else: + # Default to old behavior: keep only final ID + remove_prefix_length = input_ids.shape[1] - 1 + + input_ids = input_ids[:, remove_prefix_length:] + + position_ids = kwargs.get("position_ids", None) + if attention_mask is not None and position_ids is None: + # create position_ids on the fly for batch generation + position_ids = attention_mask.long().cumsum(-1) - 1 + position_ids.masked_fill_(attention_mask == 0, 1) + if past_key_values: + position_ids = position_ids[:, -input_ids.shape[1] :] + + # if `inputs_embeds` are passed, we only want to use them in the 1st generation step + if inputs_embeds is not None and past_key_values is None: + model_inputs = {"inputs_embeds": inputs_embeds} + else: + model_inputs = {"input_ids": input_ids} + + model_inputs.update( + { + "position_ids": position_ids, + "past_key_values": past_key_values, + "use_cache": kwargs.get("use_cache"), + "attention_mask": attention_mask, + } + ) + return model_inputs