Add KD distributed recipe #1631
# Config for multi-device knowledge distillation in knowledge_distillation_distributed.py
# using a teacher and student model
#
# This config assumes that you've run the following commands before launching KD:
# First download the student and teacher models
#   tune download Qwen/Qwen2-0.5B-Instruct --output-dir /tmp/Qwen2-0.5B-Instruct --ignore-patterns None
#   tune download Qwen/Qwen2-1.5B-Instruct --output-dir /tmp/Qwen2-1.5B-Instruct --ignore-patterns None
#
# You get better results using KD if the teacher model has already been fine-tuned on the target dataset:
#   tune run --nnodes 1 --nproc_per_node 2 lora_finetune_distributed --config qwen2/1.5B_lora
#
# To launch on 2 devices, run the following command from root:
#   tune run --nnodes 1 --nproc_per_node 2 knowledge_distillation_distributed --config qwen2/knowledge_distillation_distributed
#
# This config works best for distilling on 2+ devices.

# Model Arguments
model:
  _component_: torchtune.models.qwen2.lora_qwen2_0_5b
  lora_attn_modules: ['q_proj', 'k_proj', 'v_proj']
  apply_lora_to_mlp: False
  lora_rank: 32
  lora_alpha: 64

teacher_model:
  _component_: torchtune.models.qwen2.qwen2_1_5b

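As a rough illustration (not part of the config file, and not the recipe's verbatim code): the `_component_` entries above are resolved into the student and teacher models roughly as sketched below. The config path, and freezing the teacher, are assumptions made for the sketch.

```python
# Minimal sketch, assuming the YAML above has been saved locally.
from omegaconf import OmegaConf
from torchtune import config

cfg = OmegaConf.load("knowledge_distillation_distributed.yaml")  # hypothetical local path

# lora_qwen2_0_5b(...) builds the 0.5B student with LoRA adapters on the q/k/v projections
student = config.instantiate(cfg.model)

# qwen2_1_5b() builds the full 1.5B teacher; it is assumed to stay frozen during distillation
teacher = config.instantiate(cfg.teacher_model)
teacher.eval()
for p in teacher.parameters():
    p.requires_grad = False
```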
tokenizer:
  _component_: torchtune.models.qwen2.qwen2_tokenizer
  path: /tmp/Qwen2-0.5B-Instruct/vocab.json
  merges_file: /tmp/Qwen2-0.5B-Instruct/merges.txt
  max_seq_len: null

checkpointer:
  _component_: torchtune.training.FullModelHFCheckpointer
  checkpoint_dir: /tmp/Qwen2-0.5B-Instruct
  checkpoint_files: [
    model.safetensors
  ]
  recipe_checkpoint: null
  output_dir: /tmp/Qwen2-0.5B-Instruct-kd
  model_type: QWEN2

teacher_checkpointer:
  _component_: torchtune.training.FullModelHFCheckpointer
  checkpoint_dir: /tmp/Qwen2-1.5B-Instruct-lora-finetune
  checkpoint_files: [
    hf_model_0001_0.pt
  ]
  recipe_checkpoint: null
  output_dir: /tmp/Qwen2-1.5B-Instruct-lora-finetune
  model_type: QWEN2

Inline review discussion on the teacher checkpoint_files:

Member: Why is the default […]? If so, the first example (Llama3.2) is wrong b/c the default files are safetensors, which are only saved if the checkpointer specifies […]

Contributor (Author): I used the default LoRA distributed finetune configs for qwen2 and llama3.1 8b. I'm not sure why qwen2/1.5B_lora outputs […]
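For context on the discussion above, a hedged sketch of how a teacher_checkpointer block like this is typically consumed, continuing from the earlier snippet; the "model" key and the exact call pattern are assumptions, not the recipe's code.

```python
# Rough sketch; assumes cfg and teacher from the earlier snippet.
from torchtune import config

# Builds a FullModelHFCheckpointer from the teacher_checkpointer block above
teacher_checkpointer = config.instantiate(cfg.teacher_checkpointer)

# load_checkpoint() returns a state-dict container; "model" is assumed to be the key
# holding the merged teacher weights (e.g. hf_model_0001_0.pt from the LoRA finetune)
ckpt = teacher_checkpointer.load_checkpoint()
teacher.load_state_dict(ckpt["model"])
```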
resume_from_checkpoint: False

# Dataset and Sampler
dataset:
  _component_: torchtune.datasets.alpaca_cleaned_dataset
seed: null
shuffle: True
batch_size: 8

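A hedged sketch of how these flat keys feed the dataloading side of the distributed recipe; the use of padded_collate_sft and the exact wiring are illustrative assumptions.

```python
# Illustrative sketch; assumes cfg from earlier and an initialized torch.distributed process group.
from functools import partial

from torch.utils.data import DataLoader, DistributedSampler
from torchtune import config
from torchtune.data import padded_collate_sft  # assumed collate for SFT-style datasets

tokenizer = config.instantiate(cfg.tokenizer)
ds = config.instantiate(cfg.dataset, tokenizer)

# Each rank sees a different shard; shuffle and seed come straight from the config
sampler = DistributedSampler(ds, shuffle=cfg.shuffle, seed=cfg.seed or 0)
dataloader = DataLoader(
    ds,
    batch_size=cfg.batch_size,
    sampler=sampler,
    collate_fn=partial(padded_collate_sft, padding_idx=tokenizer.pad_id),
)
```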
# Optimizer and Scheduler
optimizer:
  _component_: torch.optim.AdamW
  weight_decay: 0.01
  lr: 3e-4
lr_scheduler:
  _component_: torchtune.modules.get_cosine_schedule_with_warmup
  num_warmup_steps: 100

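A short, hedged sketch of how these two blocks are usually wired together, continuing the earlier snippets; the total-step arithmetic is an assumption about how epochs, dataloader length, and gradient accumulation combine.

```python
# Illustrative sketch; assumes cfg, student, and dataloader from the earlier snippets.
from torchtune import config

# AdamW over the student's (LoRA) parameters, with lr/weight_decay from the config
optimizer = config.instantiate(cfg.optimizer, student.parameters())

# Cosine schedule with 100 warmup steps; total steps derived from the Training section below
total_steps = cfg.epochs * (len(dataloader) // cfg.gradient_accumulation_steps)
lr_scheduler = config.instantiate(
    cfg.lr_scheduler,
    optimizer,
    num_training_steps=total_steps,
)
```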
loss:
  _component_: torchtune.modules.loss.CEWithChunkedOutputLoss

kd_loss:
  _component_: torchtune.modules.loss.ForwardKLWithChunkedOutputLoss
kd_ratio: 0.5

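Since this is the heart of the recipe, a hedged sketch of how kd_ratio is meant to blend the two losses; the function below is illustrative rather than the recipe's exact code, and the chunked losses expect logits split into a list of chunks.

```python
# Illustrative only: how kd_ratio blends the cross-entropy and forward-KL terms.
from torchtune.modules.loss import CEWithChunkedOutputLoss, ForwardKLWithChunkedOutputLoss

ce_loss_fn = CEWithChunkedOutputLoss()
kd_loss_fn = ForwardKLWithChunkedOutputLoss()

def kd_step_loss(student_logits, teacher_logits, labels, kd_ratio=0.5):
    """student_logits/teacher_logits are lists of logit chunks, as the chunked losses expect."""
    class_loss = ce_loss_fn(student_logits, labels)                # hard-label cross-entropy
    kd_loss = kd_loss_fn(student_logits, teacher_logits, labels)   # forward KL vs. the teacher
    # Convex combination controlled by kd_ratio (0.5 above weights the two terms equally)
    return (1 - kd_ratio) * class_loss + kd_ratio * kd_loss
```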
# Training
epochs: 1
max_steps_per_epoch: null
gradient_accumulation_steps: 2

# Logging
output_dir: /tmp/qwen_kd
metric_logger:
  _component_: torchtune.training.metric_logging.DiskLogger
  log_dir: ${output_dir}
log_every_n_steps: 1
log_peak_memory_stats: False

# Environment
device: cuda
dtype: bf16
enable_activation_checkpointing: True

# Showcase the usage of the PyTorch profiler
# Set enabled to False as it's only needed for debugging training
profiler:
  _component_: torchtune.training.setup_torch_profiler
  enabled: False

  # Output directory of trace artifacts
  output_dir: ${output_dir}/profiling_outputs

  # `torch.profiler.ProfilerActivity` types to trace
  cpu: True
  cuda: True

  # Trace options passed to `torch.profiler.profile`
  profile_memory: False
  with_stack: False
  record_shapes: True
  with_flops: False

  # `torch.profiler.schedule` options:
  # wait_steps -> wait, warmup_steps -> warmup, active_steps -> active, num_cycles -> repeat
  wait_steps: 5
  warmup_steps: 5
  active_steps: 2
  num_cycles: 1
Review comment: nit: update this
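Closing note: a hedged sketch of what the profiler keys above correspond to in raw torch.profiler terms (wait_steps -> wait, warmup_steps -> warmup, active_steps -> active, num_cycles -> repeat). The trace-handler path and the loop are illustrative, not torchtune's setup_torch_profiler internals.

```python
# Illustrative mapping of the profiler section onto torch.profiler.
import torch
from torch.profiler import ProfilerActivity, profile, schedule

prof_schedule = schedule(wait=5, warmup=5, active=2, repeat=1)

with profile(
    activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA],
    schedule=prof_schedule,
    record_shapes=True,
    profile_memory=False,
    with_stack=False,
    with_flops=False,
    on_trace_ready=torch.profiler.tensorboard_trace_handler("/tmp/qwen_kd/profiling_outputs"),
) as prof:
    for _ in range(20):   # enough steps to cover one wait + warmup + active cycle
        pass              # one KD training step would run here
        prof.step()
```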