
Commit 28f3eb2

add gpt oss

1 parent 48615dd

File tree

8 files changed (+90, -15 lines)

README.md

Lines changed: 6 additions & 3 deletions

```diff
@@ -118,10 +118,14 @@ Choose your path:
 
 ## Changelog
 
-[25/07/02] We supported fine-tuning the **[GLM-4.1V-9B-Thinking](https://github.com/THUDM/GLM-4.1V-Thinking)** model. Please install transformers from **main** branch to use.
+[25/08/06] We supported fine-tuning the **[GPT-OSS](https://github.com/openai/gpt-oss)** models. See [PR #8826](https://github.com/hiyouga/LLaMA-Factory/pull/8826) to get started.
+
+[25/07/02] We supported fine-tuning the **[GLM-4.1V-9B-Thinking](https://github.com/THUDM/GLM-4.1V-Thinking)** model.
 
 [25/04/28] We supported fine-tuning the **[Qwen3](https://qwenlm.github.io/blog/qwen3/)** model family.
 
+<details><summary>Full Changelog</summary>
+
 [25/04/21] We supported the **[Muon](https://github.com/KellerJordan/Muon)** optimizer. See [examples](examples/README.md) for usage. Thank [@tianshijing](https://github.com/tianshijing)'s PR.
 
 [25/04/16] We supported fine-tuning the **[InternVL3](https://huggingface.co/OpenGVLab/InternVL3-8B)** model. See [PR #7258](https://github.com/hiyouga/LLaMA-Factory/pull/7258) to get started.
@@ -130,8 +134,6 @@ Choose your path:
 
 [25/04/06] We supported fine-tuning the **[Llama 4](https://ai.meta.com/blog/llama-4-multimodal-intelligence/)** model. See [PR #7611](https://github.com/hiyouga/LLaMA-Factory/pull/7611) to get started.
 
-<details><summary>Full Changelog</summary>
-
 [25/03/31] We supported fine-tuning the **[Qwen2.5 Omni](https://qwenlm.github.io/blog/qwen2.5-omni/)** model. See [PR #7537](https://github.com/hiyouga/LLaMA-Factory/pull/7537) to get started.
 
 [25/03/15] We supported **[SGLang](https://github.com/sgl-project/sglang)** as inference backend. Try `infer_backend: sglang` to accelerate inference.
@@ -268,6 +270,7 @@ Choose your path:
 | [GLM-4.1V](https://huggingface.co/zai-org)* | 9B | glm4v |
 | [GLM-4.5](https://huggingface.co/zai-org)* | 106B/355B | glm4_moe |
 | [GPT-2](https://huggingface.co/openai-community) | 0.1B/0.4B/0.8B/1.5B | - |
+| [GPT-OSS](https://huggingface.co/openai) | 20B/120B | gpt |
 | [Granite 3.0-3.3](https://huggingface.co/ibm-granite) | 1B/2B/3B/8B | granite3 |
 | [Granite 4](https://huggingface.co/ibm-granite) | 7B | granite4 |
 | [Hunyuan](https://huggingface.co/tencent/) | 7B | hunyuan |
```

README_zh.md (content translated from Chinese)

Lines changed: 6 additions & 3 deletions

```diff
@@ -120,10 +120,14 @@ https://github.com/user-attachments/assets/43b700c6-a178-41db-b1f8-8190a5d3fcfc
 
 ## Changelog
 
-[25/07/02] We supported fine-tuning the **[GLM-4.1V-9B-Thinking](https://github.com/THUDM/GLM-4.1V-Thinking)** model. Please install transformers from the main branch to use it.
+[25/08/06] We supported fine-tuning the **[GPT-OSS](https://github.com/openai/gpt-oss)** models. See [PR #8826](https://github.com/hiyouga/LLaMA-Factory/pull/8826) to get started.
+
+[25/07/02] We supported fine-tuning the **[GLM-4.1V-9B-Thinking](https://github.com/THUDM/GLM-4.1V-Thinking)** model.
 
 [25/04/28] We supported fine-tuning the **[Qwen3](https://qwenlm.github.io/blog/qwen3/)** model family.
 
+<details><summary>Full Changelog</summary>
+
 [25/04/21] We supported the **[Muon](https://github.com/KellerJordan/Muon)** optimizer. See [examples](examples/README_zh.md) for usage. Thanks to [@tianshijing](https://github.com/tianshijing) for the PR.
 
 [25/04/16] We supported fine-tuning the **[InternVL3](https://huggingface.co/OpenGVLab/InternVL3-8B)** model. See [PR #7258](https://github.com/hiyouga/LLaMA-Factory/pull/7258) to get started.
@@ -132,8 +136,6 @@ https://github.com/user-attachments/assets/43b700c6-a178-41db-b1f8-8190a5d3fcfc
 
 [25/04/06] We supported fine-tuning the **[Llama 4](https://ai.meta.com/blog/llama-4-multimodal-intelligence/)** model. See [PR #7611](https://github.com/hiyouga/LLaMA-Factory/pull/7611) to get started.
 
-<details><summary>Full Changelog</summary>
-
 [25/03/31] We supported fine-tuning the **[Qwen2.5 Omni](https://qwenlm.github.io/blog/qwen2.5-omni/)** model. See [PR #7537](https://github.com/hiyouga/LLaMA-Factory/pull/7537) to get started.
 
 [25/03/15] We supported **[SGLang](https://github.com/sgl-project/sglang)** as an inference backend. Use `infer_backend: sglang` to enable it.
@@ -270,6 +272,7 @@ https://github.com/user-attachments/assets/43b700c6-a178-41db-b1f8-8190a5d3fcfc
 | [GLM-4.1V](https://huggingface.co/zai-org)* | 9B | glm4v |
 | [GLM-4.5](https://huggingface.co/zai-org)* | 106B/355B | glm4_moe |
 | [GPT-2](https://huggingface.co/openai-community) | 0.1B/0.4B/0.8B/1.5B | - |
+| [GPT-OSS](https://huggingface.co/openai) | 20B/120B | gpt |
 | [Granite 3.0-3.3](https://huggingface.co/ibm-granite) | 1B/2B/3B/8B | granite3 |
 | [Granite 4](https://huggingface.co/ibm-granite) | 7B | granite4 |
 | [Hunyuan](https://huggingface.co/tencent/) | 7B | hunyuan |
```
(new example YAML config)

Lines changed: 46 additions & 0 deletions

```diff
@@ -0,0 +1,46 @@
+### model
+model_name_or_path: openai/gpt-oss-20b
+trust_remote_code: true
+
+### method
+stage: sft
+do_train: true
+finetuning_type: lora
+lora_rank: 8
+lora_target: all
+
+### dataset
+dataset: identity,alpaca_en_demo
+template: gpt
+cutoff_len: 2048
+max_samples: 1000
+overwrite_cache: true
+preprocessing_num_workers: 16
+dataloader_num_workers: 4
+
+### output
+output_dir: saves/llama3-8b/lora/sft
+logging_steps: 10
+save_steps: 500
+plot_loss: true
+overwrite_output_dir: true
+save_only_model: false
+report_to: none  # choices: [none, wandb, tensorboard, swanlab, mlflow]
+
+### train
+per_device_train_batch_size: 1
+gradient_accumulation_steps: 8
+learning_rate: 1.0e-4
+num_train_epochs: 3.0
+lr_scheduler_type: cosine
+warmup_ratio: 0.1
+bf16: true
+ddp_timeout: 180000000
+resume_from_checkpoint: null
+
+### eval
+# eval_dataset: alpaca_en_demo
+# val_size: 0.1
+# per_device_eval_batch_size: 1
+# eval_strategy: steps
+# eval_steps: 500
```
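This config fine-tunes `openai/gpt-oss-20b` with rank-8 LoRA on the demo datasets, and would typically be launched with `llamafactory-cli train <path-to-this-yaml>`. As a minimal sanity check, the effective per-device batch size follows from two of its keys; a short sketch (the filename is an assumption, since the new file's path is not shown above):

```python
# Sketch: load the example config and compute the effective batch size.
import yaml  # pip install pyyaml

with open("gpt_oss_lora_sft.yaml") as f:  # filename assumed for illustration
    cfg = yaml.safe_load(f)

effective = cfg["per_device_train_batch_size"] * cfg["gradient_accumulation_steps"]
print(f"effective batch size per device: {effective}")  # 1 * 8 = 8
```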

requirements.txt

Lines changed: 1 addition & 2 deletions

```diff
@@ -1,6 +1,5 @@
 # core deps
-transformers>=4.49.0,<=4.52.4,!=4.52.0; sys_platform != 'darwin'
-transformers>=4.49.0,<=4.51.3,!=4.52.0; sys_platform == 'darwin'
+transformers>=4.49.0,<=4.55.0,!=4.52.0
 datasets>=2.16.0,<=3.6.0
 accelerate>=1.3.0,<=1.7.0
 peft>=0.14.0,<=0.15.2
```
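The single unified specifier replaces the two platform-split pins and raises the ceiling to 4.55.0 (the transformers release that ships GPT-OSS support), while still excluding 4.52.0. A quick way to verify what the new range admits, sketched with the `packaging` library that pip itself builds on:

```python
# Check which transformers versions the new requirement admits.
from packaging.specifiers import SpecifierSet

spec = SpecifierSet(">=4.49.0,<=4.55.0,!=4.52.0")
for version in ["4.49.0", "4.51.3", "4.52.0", "4.52.4", "4.55.0", "4.56.0"]:
    print(version, version in spec)
# 4.52.0 is excluded; 4.56.0 exceeds the ceiling; the rest are allowed.
```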

src/llamafactory/data/template.py

Lines changed: 9 additions & 0 deletions

```diff
@@ -1063,6 +1063,15 @@ def get_template_and_fix_tokenizer(tokenizer: "PreTrainedTokenizer", data_args:
 )
 
 
+register_template(
+    name="gpt",
+    format_user=StringFormatter(slots=["<|start|>user<|message|>{{content}}<|end|><|start|>assistant"]),
+    format_assistant=StringFormatter(slots=["{{content}}<|end|>"]),
+    format_system=StringFormatter(slots=["<|start|>system<|message|>{{content}}<|end|>"]),
+    default_system="You are ChatGPT, a large language model trained by OpenAI.",
+)
+
+
 register_template(
     name="granite3",
     format_user=StringFormatter(
```
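The new `gpt` template encodes the GPT-OSS chat format: each turn is wrapped in `<|start|>role<|message|>…<|end|>`, and the user formatter already opens the assistant turn so generation continues right after `<|start|>assistant`. A minimal sketch of what these slots produce for one exchange (plain string formatting, not LLaMA-Factory internals):

```python
# Sketch: render one system + user turn the way the "gpt" template's slots do.
def render_gpt_prompt(system: str, user: str, assistant: str | None = None) -> str:
    text = f"<|start|>system<|message|>{system}<|end|>"                    # format_system
    text += f"<|start|>user<|message|>{user}<|end|><|start|>assistant"     # format_user
    if assistant is not None:
        text += f"{assistant}<|end|>"                                      # format_assistant
    return text

print(render_gpt_prompt(
    "You are ChatGPT, a large language model trained by OpenAI.",  # default_system
    "Hello!",
))
```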

src/llamafactory/extras/constants.py

Lines changed: 15 additions & 0 deletions

```diff
@@ -945,6 +945,21 @@ def register_model_group(
 )
 
 
+register_model_group(
+    models={
+        "GPT-OSS-20B-Thinking": {
+            DownloadSource.DEFAULT: "openai/gpt-oss-20b",
+            DownloadSource.MODELSCOPE: "openai/gpt-oss-20b",
+        },
+        "GPT-OSS-120B-Thinking": {
+            DownloadSource.DEFAULT: "openai/gpt-oss-120b",
+            DownloadSource.MODELSCOPE: "openai/gpt-oss-120b",
+        },
+    },
+    template="gpt",
+)
+
+
 register_model_group(
     models={
         "Granite-3.0-1B-A400M-Base": {
```

src/llamafactory/extras/misc.py

Lines changed: 4 additions & 4 deletions

```diff
@@ -18,7 +18,7 @@
 import gc
 import os
 import socket
-from typing import TYPE_CHECKING, Any, Literal, Union
+from typing import TYPE_CHECKING, Any, Literal, Optional, Union
 
 import torch
 import torch.distributed as dist
@@ -94,7 +94,7 @@ def check_version(requirement: str, mandatory: bool = False) -> None:
 
 def check_dependencies() -> None:
     r"""Check the version of the required packages."""
-    check_version("transformers>=4.49.0,<=4.52.4,!=4.52.0")
+    check_version("transformers>=4.49.0,<=4.55.0")
     check_version("datasets>=2.16.0,<=3.6.0")
     check_version("accelerate>=1.3.0,<=1.7.0")
     check_version("peft>=0.14.0,<=0.15.2")
@@ -211,9 +211,9 @@ def has_tokenized_data(path: "os.PathLike") -> bool:
     return os.path.isdir(path) and len(os.listdir(path)) > 0
 
 
-def infer_optim_dtype(model_dtype: "torch.dtype") -> "torch.dtype":
+def infer_optim_dtype(model_dtype: Optional["torch.dtype"]) -> "torch.dtype":
     r"""Infer the optimal dtype according to the model_dtype and device compatibility."""
-    if _is_bf16_available and model_dtype == torch.bfloat16:
+    if _is_bf16_available and (model_dtype == torch.bfloat16 or model_dtype is None):
         return torch.bfloat16
     elif _is_fp16_available:
         return torch.float16
```
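`infer_optim_dtype` now treats a missing `torch_dtype` (`model_dtype is None`) as a vote for bf16 on hardware that supports it, instead of falling through to fp16. A self-contained sketch of the resulting decision logic (availability flags and the float32 fallback are assumptions that mirror the surrounding code, not imports from the repo):

```python
# Sketch of the updated dtype resolution: None now prefers bf16 when available.
from typing import Optional

import torch

_is_bf16_available = torch.cuda.is_available() and torch.cuda.is_bf16_supported()
_is_fp16_available = torch.cuda.is_available()

def infer_optim_dtype(model_dtype: Optional[torch.dtype]) -> torch.dtype:
    if _is_bf16_available and (model_dtype == torch.bfloat16 or model_dtype is None):
        return torch.bfloat16
    elif _is_fp16_available:
        return torch.float16
    else:
        return torch.float32

print(infer_optim_dtype(None))  # bfloat16 on bf16-capable GPUs, else fp16/fp32
```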

src/llamafactory/model/loader.py

Lines changed: 3 additions & 3 deletions

```diff
@@ -156,10 +156,10 @@ def load_model(
     if model_args.mixture_of_depths == "load":
         model = load_mod_pretrained_model(**init_kwargs)
     else:
-        if type(config) in AutoModelForVision2Seq._model_mapping.keys():  # image-text
-            load_class = AutoModelForVision2Seq
-        elif type(config) in AutoModelForImageTextToText._model_mapping.keys():  # image-text
+        if type(config) in AutoModelForImageTextToText._model_mapping.keys():  # image-text
             load_class = AutoModelForImageTextToText
+        elif type(config) in AutoModelForVision2Seq._model_mapping.keys():  # image-text
+            load_class = AutoModelForVision2Seq
         elif type(config) in AutoModelForSeq2SeqLM._model_mapping.keys():  # audio-text
             load_class = AutoModelForSeq2SeqLM
         elif type(config) in AutoModelForTextToWaveform._model_mapping.keys():  # audio hack for qwen2_5_omni
```
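The reorder matters because a config class can be registered in more than one Auto mapping; since the branches are checked in order, putting `AutoModelForImageTextToText` first makes it win any overlap with the older `AutoModelForVision2Seq` mapping. A minimal sketch of that first-match-wins dispatch (the mapping contents below are hypothetical placeholders, not real transformers registries):

```python
# First-match-wins dispatch: earlier mappings take priority on overlap.
def pick_loader(config_type: str, ordered_mappings: list[tuple[str, set[str]]]) -> str:
    for loader_name, mapping in ordered_mappings:
        if config_type in mapping:
            return loader_name
    return "AutoModelForCausalLM"  # default fallback

ordered_mappings = [
    ("AutoModelForImageTextToText", {"some_vlm_config"}),                      # hypothetical members
    ("AutoModelForVision2Seq", {"some_vlm_config", "legacy_vlm_config"}),      # overlap resolved by order
]
print(pick_loader("some_vlm_config", ordered_mappings))  # AutoModelForImageTextToText
```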
