Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
74 changes: 74 additions & 0 deletions examples/aishell/asr/finetune_osum.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,74 @@
#!/bin/bash
# Copyright 2025 Chengdong Liang(liangchengdongd@qq.com)

# This script is used to fine-tune [OSUM](https://huggingface.co/ASLP-lab/OSUM).
# Stages: data -> convert_model -> train -> decode (or "all" to run everything).

[ ! -s west ] && ln -s ../../../west
[ ! -s tools ] && ln -s ../../../tools
export PYTHONPATH=$PYTHONPATH:$PWD
# Change this to all your available gpus, such as "0,1,2,3"
export CUDA_VISIBLE_DEVICES="0"
num_gpus=$(echo $CUDA_VISIBLE_DEVICES | awk -F ',' '{print NF}')

stage=train # data/convert_model/train/decode
data=data
osum_model_path=/path/to/osum/infer.pt
dir=exp/OSUM-finetune
steps=5000 # training steps
# prompt config: https://github.com/ASLP-lab/OSUM/blob/main/OSUM/conf/prompt_config.yaml
# NOTE: the upstream prompt config is a YAML file and is loaded with a YAML
# reader in west, so the path should end in .yaml (was .json).
osum_prompt_config=/path/to/osum/prompt_config.yaml

. tools/parse_options.sh

# Quote "$stage" so the tests do not break if it is ever empty.
if [ "$stage" == "data" ] || [ "$stage" == "all" ]; then
  echo "Prepare required data"
fi

if [ "$stage" == "convert_model" ] || [ "$stage" == "all" ]; then
  echo "Convert osum model to west style"
  # Output dir must match what the train/decode stages read ($dir/osum-west);
  # it previously wrote to $dir/osum-west-test, which train never picked up.
  python tools/convert_scripts/convert_osum.py \
    --osum_model_path $osum_model_path \
    --llm_model_dir Qwen/Qwen2-7B-Instruct/ \
    --wenet_model_dir whisper-medium \
    --prompt_config_path $osum_prompt_config \
    --output_dir $dir/osum-west
fi

if [ "$stage" == "train" ] || [ "$stage" == "all" ]; then
  torchrun --standalone --nnodes=1 --nproc_per_node=$num_gpus west/bin/train.py \
    --model_config_or_dir $dir/osum-west \
    --data_path $data/train.jsonl \
    --output_dir $dir \
    --pack_size 8192 \
    --bf16 True \
    --max_steps $steps \
    --per_device_train_batch_size 1 \
    --per_device_eval_batch_size 1 \
    --gradient_accumulation_steps 4 \
    --save_strategy "steps" \
    --save_steps 100 \
    --save_total_limit 100 \
    --learning_rate 3e-4 \
    --weight_decay 0.01 \
    --adam_beta2 0.95 \
    --warmup_ratio 0.5 \
    --lr_scheduler_type "cosine" \
    --logging_steps 1 \
    --report_to "tensorboard" \
    --gradient_checkpointing \
    --dataloader_num_workers 2 \
    --dataloader_prefetch_factor 10 \
    --ignore_data_skip True \
    --deepspeed conf/ds_config_zero2.json \
    --accelerator_config conf/accelerator_config.json
fi

if [ "$stage" == "decode" ] || [ "$stage" == "all" ]; then
  # mdir=$dir/checkpoint-${steps}
  mdir=$dir/osum-west
  python west/bin/decode.py \
    --data_path $data/aishell2_test/test.jsonl \
    --model_dir $mdir \
    --result_path $mdir/result.jsonl
  # Score against the same test set that was decoded (was $data/test.jsonl).
  python tools/compute_wer.py --char=1 --v=1 \
    $data/aishell2_test/test.jsonl $mdir/result.jsonl > $mdir/result.wer
fi
104 changes: 104 additions & 0 deletions tools/convert_scripts/convert_osum.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,104 @@
# Copyright 2025 Chengdong Liang(liangchengdongd@qq.com)

import argparse
import json
import os

import torch
from transformers import AutoConfig, AutoModel

import west # for init touchasu model # noqa


def convert_to_west_state_dict(osum_state_dict):
    """Rename OSUM checkpoint parameters to the WeST naming scheme.

    Only the leading component of each parameter name is rewritten; keys
    whose prefix is not recognized are dropped, except the embedding
    table kept verbatim below.

    Fix over the original: ``str.replace`` substituted the pattern
    anywhere in the key, so a name containing e.g. ``encoder.`` in the
    middle could be corrupted. Prefix slicing rewrites only the prefix.

    Args:
        osum_state_dict: mapping of OSUM parameter names to tensors.

    Returns:
        dict with WeST-style parameter names mapped to the same tensors.
    """
    prefix_map = (
        ("encoder.", "encoder.encoder."),
        ("speech_transformer.", "projector.speech_transformer."),
        ("llama_model.", "llm."),
        ("down_sample_2.", "projector."),
        ("speech_llama_proj.", "projector.speech_llama_proj."),
    )
    west_state_dict = {}
    for name, tensor in osum_state_dict.items():
        for old, new in prefix_map:
            if name.startswith(old):
                west_state_dict[new + name[len(old):]] = tensor
                break
        else:
            # 'emded' matches the (misspelled) key name in upstream OSUM
            # checkpoints — do not "fix" the spelling here.
            if name.startswith('speech_token_emded.'):
                west_state_dict[name] = tensor

    return west_state_dict


def get_configs(llm_model_dir, wenet_model_dir, prompt_config_path):
    """Build the ``config.json`` contents for a converted OSUM model.

    Args:
        llm_model_dir: path or HF id of the backbone LLM.
        wenet_model_dir: path of the wenet/whisper speech encoder.
        prompt_config_path: path of the OSUM prompt-config file.

    Returns:
        A plain dict matching the WeST ``osum`` model configuration.
    """
    # LoRA adapter hyper-parameters applied to the LLM projections.
    lora = {
        "inference_mode": False,
        "lora_alpha": 32,
        "lora_dropout": 0.1,
        "r": 8,
        "target_modules": [
            "q_proj",
            "k_proj",
            "v_proj",
            "o_proj",
            "gate_proj",
            "down_proj",
        ],
        "task_type": "CAUSAL_LM",
    }
    return {
        "llm_model_name_or_path": llm_model_dir,
        "lora_config": lora,
        "model_type": "osum",
        "no_init_llm": False,
        "projector_type": "transformer",
        "transformers_version": "4.52.3",
        "wenet_model_name_or_path": wenet_model_dir,
        "prompt_conf_path": prompt_config_path,
    }


def get_args():
    """Parse command-line options for the OSUM-to-WeST conversion.

    All five options are mandatory string paths; see the fine-tune
    script for the expected values.
    """
    parser = argparse.ArgumentParser()
    for flag in ("--osum_model_path",
                 "--llm_model_dir",
                 "--wenet_model_dir",
                 "--prompt_config_path",
                 "--output_dir"):
        parser.add_argument(flag, type=str, required=True)
    return parser.parse_args()


def main():
    """Convert an OSUM checkpoint into a WeST-style pretrained model dir.

    Steps: load the raw checkpoint, remap its parameter names, write a
    WeST ``config.json``, instantiate the registered ``osum`` model from
    that config, load the remapped weights, and save model + tokenizer.
    """
    args = get_args()
    # weights_only=False unpickles arbitrary objects — only run this on
    # checkpoints from a trusted source.
    checkpoint = torch.load(args.osum_model_path,
                            map_location="cpu",
                            weights_only=False)
    # exist_ok=True: re-running the conversion must not crash when the
    # output directory already exists (the original makedirs did).
    os.makedirs(args.output_dir, exist_ok=True)
    state_dict = convert_to_west_state_dict(checkpoint)

    configs = get_configs(args.llm_model_dir,
                          args.wenet_model_dir,
                          args.prompt_config_path)
    print(configs)
    with open(os.path.join(args.output_dir, "config.json"), "w") as f:
        json.dump(configs, f, indent=4)

    config = AutoConfig.from_pretrained(f'{args.output_dir}/config.json')
    model = AutoModel.from_config(config)
    tokenizer = model.init_tokenizer()
    print("Loading osum weights...")
    # strict=False: the converted dict may omit parameters the WeST model
    # adds (e.g. LoRA adapters) — NOTE(review): verify coverage once.
    model.load_state_dict(state_dict, strict=False)
    print("Saving west model...")
    model.save_pretrained(args.output_dir)
    tokenizer.save_pretrained(args.output_dir)


if __name__ == "__main__":
    main()
4 changes: 3 additions & 1 deletion west/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

from transformers import AutoConfig, AutoModel

from west.models.osum_echat import OSUMEChat, OSUMEChatConfig
from west.models.osum_echat import OSUM, OSUMConfig, OSUMEChat, OSUMEChatConfig
from west.models.touch_asu import TouchASU, TouchASUConfig
from west.models.touch_chat import TouchChat, TouchChatConfig
from west.models.touch_flow import TouchFlow, TouchFlowConfig
Expand All @@ -19,3 +19,5 @@

AutoConfig.register("osum_echat", OSUMEChatConfig)
AutoModel.register(OSUMEChatConfig, OSUMEChat)
AutoConfig.register("osum", OSUMConfig)
AutoModel.register(OSUMConfig, OSUM)
6 changes: 3 additions & 3 deletions west/models/osum_echat/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
from .configuration_osum_echat import OSUMEChatConfig # noqa
from .extractor_osum_echat import ExtractorOSUMEChat # noqa
from .modeling_osum_echat import OSUMEChat # noqa
from .configuration_osum_echat import OSUMConfig, OSUMEChatConfig # noqa
from .extractor_osum_echat import ExtractorOSUM, ExtractorOSUMEChat # noqa
from .modeling_osum_echat import OSUM, OSUMEChat # noqa
34 changes: 33 additions & 1 deletion west/models/osum_echat/configuration_osum_echat.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,4 +25,36 @@ def __init__(
self.speech_token_num = speech_token_num


__all__ = ["OSUMEChatConfig"]
class OSUMConfig(PretrainedConfig):
    """HF-style configuration for the WeST ``osum`` model.

    Carries the paths of the backbone LLM and wenet speech encoder,
    the downsampling rates used to size the audio token span, optional
    LoRA settings, freeze flags, and the path of the OSUM prompt
    configuration (a YAML file).
    """

    model_type = "osum"

    def __init__(
        self,
        llm_model_name_or_path: str = 'Qwen/Qwen2.5-3B-Instruct',
        no_init_llm: bool = True,
        wenet_model_name_or_path: str = 'whisper-medium',
        encoder_ds_rate: int = 2,
        encoder_projector_ds_rate: int = 4,
        hidden_size: int = 0, # Will override in OSUM Model
        lora_config: Optional[Dict[str, Any]] = None,
        freeze_encoder: bool = False,
        freeze_llm: bool = False,
        speech_token_num: int = 4097,
        prompt_conf_path: str = 'conf/prompt_config.yaml',
        **kwargs,
    ):
        super().__init__(**kwargs)
        # Checkpoint locations for the two backbones.
        self.llm_model_name_or_path = llm_model_name_or_path
        # When True, skip LLM weight init (weights loaded separately).
        self.no_init_llm = no_init_llm
        self.wenet_model_name_or_path = wenet_model_name_or_path
        # Encoder and projector temporal downsampling factors; their
        # product is the frames-per-token ratio used by the extractor.
        self.encoder_ds_rate = encoder_ds_rate
        self.encoder_projector_ds_rate = encoder_projector_ds_rate
        self.hidden_size = hidden_size
        # None disables LoRA; otherwise a peft-style kwargs dict.
        self.lora_config = lora_config
        self.freeze_encoder = freeze_encoder
        self.freeze_llm = freeze_llm
        self.speech_token_num = speech_token_num
        # YAML mapping of task tag -> list of candidate prompts.
        self.prompt_conf_path = prompt_conf_path


__all__ = ["OSUMEChatConfig", "OSUMConfig"]
76 changes: 75 additions & 1 deletion west/models/osum_echat/extractor_osum_echat.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,83 @@
# Copyright (c) 2025 Xuelong Geng(xlgeng@mail.nwpu.edu.cn)

import math
import random

import torch
import wenet
from gxl_ai_utils.utils import utils_file
from transformers.trainer_pt_utils import LabelSmoother

from west.dataset.extractor import Extractor


class ExtractorOSUMEChat(Extractor):
class ExtractorOSUM(Extractor):
    """Turns ``{'wav', 'txt', 'task'}`` items into OSUM training features.

    Produces token ids with an audio placeholder span, labels that mask
    everything except the target text, the mel features, and the offset
    where the audio span starts.
    """

    model_type = "osum"
    fields_batch_static = {'audio_offsets'}
    fields_batch_dynamic = {'audio_features', 'input_ids', 'labels'}
    fields_pack_offset = {'audio_offsets'}

    def __init__(self, tokenizer, model_config, inference=False):
        super().__init__(tokenizer, model_config, inference)
        # wenet provides a feature extractor matched to the encoder.
        self.compute_feature, self.feature_dim = wenet.load_feature(
            self.model_config.wenet_model_name_or_path
        )
        # Total temporal downsampling from mel frames to LLM positions.
        self.ds_rate = (self.model_config.encoder_ds_rate *
                        self.model_config.encoder_projector_ds_rate)
        # task tag -> list of candidate natural-language prompts.
        self.global_prompt_dict = utils_file.load_dict_from_yaml(
            model_config.prompt_conf_path)

    def _sample_prompt(self, item):
        """Return a randomly sampled prompt for the item's task ('' on error).

        Replaces two duplicated try/except branches in the original with
        one path; uses random.choice instead of randint-indexing. For an
        explicit task the "<no_prompt>" sentinel yields '', while the
        default '<TRANSCRIBE>' path uses the sampled prompt as-is —
        this asymmetry matches the original behavior.
        """
        explicit = 'task' in item
        task_name = item['task'] if explicit else '<TRANSCRIBE>'
        try:
            prompt = random.choice(self.global_prompt_dict[task_name])
        except Exception as e:
            # Best-effort: unknown task or empty prompt list -> no prompt.
            print(f"Error: {e}")
            return ''
        if explicit and prompt == "<no_prompt>":
            return ''
        return prompt

    def extract(self, item):
        """
        {'wav': wav_path, 'txt': text, 'task': '<TRANSCRIBE>'}
        """
        IGNORE_TOKEN_ID = LabelSmoother.ignore_index
        t0 = self._sample_prompt(item)

        mel = self.compute_feature(item['wav'])
        # Placeholder ids for the audio span; +2 extra positions —
        # presumably begin/end markers around the audio. TODO confirm.
        ids_audio = [0] * (math.ceil(mel.size(0) / self.ds_rate) + 2)

        ids0 = self.tokenizer.encode(t0)
        ids = ids0 + ids_audio
        # Prompt and audio positions never contribute to the loss.
        tgt = [IGNORE_TOKEN_ID] * len(ids)

        if not self.inference:
            ids1 = self.tokenizer.encode(item['txt'] + "<|endoftext|>")
            ids = ids + ids1
            tgt = tgt + ids1

        input_ids = torch.tensor(ids, dtype=torch.int)
        tgt_ids = torch.tensor(tgt, dtype=torch.long)
        return {
            'input_ids': input_ids,
            'labels': tgt_ids,
            'audio_features': mel,
            'audio_offsets': len(ids0),
        }


class ExtractorOSUMEChat(ExtractorOSUM):
model_type = "osum_echat"

def __init__(self, tokenizer, model_config, inference=False):
super().__init__(tokenizer, model_config, inference)
Expand Down
Loading