From 59f44fe267dde729b7916d4cf496fd0a855919f1 Mon Sep 17 00:00:00 2001
From: cdliang11 <1404056823@qq.com>
Date: Tue, 18 Nov 2025 16:42:14 +0800
Subject: [PATCH] [osum] support finetuning for osum

---
 examples/aishell/asr/finetune_osum.sh         |  74 +++++
 tools/convert_scripts/convert_osum.py         | 104 ++++++
 west/__init__.py                              |   4 +-
 west/models/osum_echat/__init__.py            |   6 +-
 .../osum_echat/configuration_osum_echat.py    |  34 +-
 .../models/osum_echat/extractor_osum_echat.py |  76 ++++-
 west/models/osum_echat/modeling_osum_echat.py | 305 +++++++++++++-----
 7 files changed, 509 insertions(+), 94 deletions(-)
 create mode 100644 examples/aishell/asr/finetune_osum.sh
 create mode 100644 tools/convert_scripts/convert_osum.py

diff --git a/examples/aishell/asr/finetune_osum.sh b/examples/aishell/asr/finetune_osum.sh
new file mode 100644
index 0000000..46392f7
--- /dev/null
+++ b/examples/aishell/asr/finetune_osum.sh
@@ -0,0 +1,74 @@
+# Copyright 2025 Chengdong Liang(liangchengdongd@qq.com)
+
+# This script is used to fine-tune [OSUM](https://huggingface.co/ASLP-lab/OSUM).
+
+[ ! -s west ] && ln -s ../../../west
+[ ! -s tools ] && ln -s ../../../tools
+export PYTHONPATH=$PYTHONPATH:$PWD
+# Change this to all of your available GPUs, e.g. "0,1,2,3"
+export CUDA_VISIBLE_DEVICES="0"
+num_gpus=$(echo $CUDA_VISIBLE_DEVICES | awk -F ',' '{print NF}')
+
+stage=train  # data/convert_model/train/decode
+data=data
+osum_model_path=/path/to/osum/infer.pt
+dir=exp/OSUM-finetune
+steps=5000  # training steps
+# prompt config: https://github.com/ASLP-lab/OSUM/blob/main/OSUM/conf/prompt_config.yaml
+osum_prompt_config=/path/to/osum/prompt_config.yaml
+
+. tools/parse_options.sh
+
+if [ $stage == "data" ] || [ $stage == "all" ]; then
+  echo "Prepare required data"
+fi
+
+if [ $stage == "convert_model" ] || [ $stage == "all" ]; then
+  echo "Convert the OSUM model to WeST style"
+  python tools/convert_scripts/convert_osum.py \
+    --osum_model_path $osum_model_path \
+    --llm_model_dir Qwen/Qwen2-7B-Instruct/ \
+    --wenet_model_dir whisper-medium \
+    --prompt_config_path $osum_prompt_config \
+    --output_dir $dir/osum-west
+fi
+
+if [ $stage == "train" ] || [ $stage == "all" ]; then
+  torchrun --standalone --nnodes=1 --nproc_per_node=$num_gpus west/bin/train.py \
+    --model_config_or_dir $dir/osum-west \
+    --data_path $data/train.jsonl \
+    --output_dir $dir \
+    --pack_size 8192 \
+    --bf16 True \
+    --max_steps $steps \
+    --per_device_train_batch_size 1 \
+    --per_device_eval_batch_size 1 \
+    --gradient_accumulation_steps 4 \
+    --save_strategy "steps" \
+    --save_steps 100 \
+    --save_total_limit 100 \
+    --learning_rate 3e-4 \
+    --weight_decay 0.01 \
+    --adam_beta2 0.95 \
+    --warmup_ratio 0.5 \
+    --lr_scheduler_type "cosine" \
+    --logging_steps 1 \
+    --report_to "tensorboard" \
+    --gradient_checkpointing \
+    --dataloader_num_workers 2 \
+    --dataloader_prefetch_factor 10 \
+    --ignore_data_skip True \
+    --deepspeed conf/ds_config_zero2.json \
+    --accelerator_config conf/accelerator_config.json
+fi
+
+if [ $stage == "decode" ] || [ $stage == "all" ]; then
+  # mdir=$dir/checkpoint-${steps}
+  mdir=$dir/osum-west
+  python west/bin/decode.py \
+    --data_path $data/aishell2_test/test.jsonl \
+    --model_dir $mdir \
+    --result_path $mdir/result.jsonl
+  python tools/compute_wer.py --char=1 --v=1 \
+    $data/aishell2_test/test.jsonl $mdir/result.jsonl > $mdir/result.wer
+fi
diff --git a/tools/convert_scripts/convert_osum.py b/tools/convert_scripts/convert_osum.py
new file mode 100644
index 0000000..e929c08
--- /dev/null
+++ b/tools/convert_scripts/convert_osum.py
@@ -0,0 +1,104 @@
+# Copyright 2025 Chengdong Liang(liangchengdongd@qq.com)
+
+import argparse
+import json
+import os
+
+import torch
+from transformers import AutoConfig, AutoModel
+
+import west  # registers the west models (including osum) with AutoModel  # noqa
+
+
+def convert_to_west_state_dict(osum_state_dict):
+    west_state_dict = {}
+    for name in osum_state_dict.keys():
+        if name.startswith("encoder."):
+            new_name = name.replace("encoder.", "encoder.encoder.")
+            west_state_dict[new_name] = osum_state_dict[name]
+        elif name.startswith("speech_transformer."):
+            new_name = name.replace("speech_transformer.",
+                                    "projector.speech_transformer.")
+            west_state_dict[new_name] = osum_state_dict[name]
+        elif name.startswith("llama_model."):
+            new_name = name.replace("llama_model.", "llm.")
+            west_state_dict[new_name] = osum_state_dict[name]
+        elif name.startswith("down_sample_2."):
+            new_name = name.replace("down_sample_2.", "projector.")
+            west_state_dict[new_name] = osum_state_dict[name]
+        elif name.startswith("speech_llama_proj."):
+            new_name = name.replace("speech_llama_proj.",
+                                    "projector.speech_llama_proj.")
+            west_state_dict[new_name] = osum_state_dict[name]
+        elif name.startswith('speech_token_emded.'):
+            west_state_dict[name] = osum_state_dict[name]
+
+    return west_state_dict
+
+
+def get_configs(llm_model_dir, wenet_model_dir, prompt_config_path):
+    configs = {
+        "llm_model_name_or_path": llm_model_dir,
+        "lora_config": {
+            "inference_mode": False,
+            "lora_alpha": 32,
+            "lora_dropout": 0.1,
+            "r": 8,
+            "target_modules": [
+                "q_proj",
+                "k_proj",
+                "v_proj",
+                "o_proj",
+                "gate_proj",
+                "down_proj"
+            ],
+            "task_type": "CAUSAL_LM"
+        },
+        "model_type": "osum",
+        "no_init_llm": False,
+        "projector_type": "transformer",
+        "transformers_version": "4.52.3",
+        "wenet_model_name_or_path": wenet_model_dir,
+        "prompt_conf_path": prompt_config_path,
+    }
+    return configs
+
+
+def get_args():
+    parser = argparse.ArgumentParser()
+    parser.add_argument("--osum_model_path", type=str, required=True)
+    parser.add_argument("--llm_model_dir", type=str, required=True)
+    parser.add_argument("--wenet_model_dir", type=str, required=True)
+    parser.add_argument("--prompt_config_path", type=str, required=True)
+    parser.add_argument("--output_dir", type=str, required=True)
+    args = parser.parse_args()
+    return args
+
+
+def main():
+    args = get_args()
+    checkpoint = torch.load(args.osum_model_path,
+                            map_location="cpu",
+                            weights_only=False)
+    os.makedirs(args.output_dir, exist_ok=True)
+    state_dict = convert_to_west_state_dict(checkpoint)
+
+    with open(os.path.join(args.output_dir, "config.json"), "w") as f:
+        configs = get_configs(args.llm_model_dir,
+                              args.wenet_model_dir,
+                              args.prompt_config_path)
+        print(configs)
+        json.dump(configs, f, indent=4)
+
+    config = AutoConfig.from_pretrained(f'{args.output_dir}/config.json')
+    model = AutoModel.from_config(config)
+    tokenizer = model.init_tokenizer()
+    print("Loading osum weights...")
+    model.load_state_dict(state_dict, strict=False)
+    print("Saving west model...")
+    model.save_pretrained(args.output_dir)
+    tokenizer.save_pretrained(args.output_dir)
+
+
+if __name__ == "__main__":
+    main()
diff --git a/west/__init__.py b/west/__init__.py
index 9409312..11fc5d0 100644
--- a/west/__init__.py
+++ b/west/__init__.py
@@ -2,7 +2,7 @@
 
 from transformers import AutoConfig, AutoModel
 
-from west.models.osum_echat import OSUMEChat, OSUMEChatConfig
+from west.models.osum_echat import OSUM, OSUMConfig, OSUMEChat, OSUMEChatConfig
 from west.models.touch_asu import TouchASU,
TouchASUConfig from west.models.touch_chat import TouchChat, TouchChatConfig from west.models.touch_flow import TouchFlow, TouchFlowConfig @@ -19,3 +19,5 @@ AutoConfig.register("osum_echat", OSUMEChatConfig) AutoModel.register(OSUMEChatConfig, OSUMEChat) +AutoConfig.register("osum", OSUMConfig) +AutoModel.register(OSUMConfig, OSUM) diff --git a/west/models/osum_echat/__init__.py b/west/models/osum_echat/__init__.py index 688b111..e856f22 100644 --- a/west/models/osum_echat/__init__.py +++ b/west/models/osum_echat/__init__.py @@ -1,3 +1,3 @@ -from .configuration_osum_echat import OSUMEChatConfig # noqa -from .extractor_osum_echat import ExtractorOSUMEChat # noqa -from .modeling_osum_echat import OSUMEChat # noqa +from .configuration_osum_echat import OSUMConfig, OSUMEChatConfig # noqa +from .extractor_osum_echat import ExtractorOSUM, ExtractorOSUMEChat # noqa +from .modeling_osum_echat import OSUM, OSUMEChat # noqa diff --git a/west/models/osum_echat/configuration_osum_echat.py b/west/models/osum_echat/configuration_osum_echat.py index 9f2b859..35d93fe 100644 --- a/west/models/osum_echat/configuration_osum_echat.py +++ b/west/models/osum_echat/configuration_osum_echat.py @@ -25,4 +25,36 @@ def __init__( self.speech_token_num = speech_token_num -__all__ = ["OSUMEChatConfig"] +class OSUMConfig(PretrainedConfig): + model_type = "osum" + + def __init__( + self, + llm_model_name_or_path: str = 'Qwen/Qwen2.5-3B-Instruct', + no_init_llm: bool = True, + wenet_model_name_or_path: str = 'whisper-medium', + encoder_ds_rate: int = 2, + encoder_projector_ds_rate: int = 4, + hidden_size: int = 0, # Will override in OSUM Model + lora_config: Optional[Dict[str, Any]] = None, + freeze_encoder: bool = False, + freeze_llm: bool = False, + speech_token_num: int = 4097, + prompt_conf_path: str = 'conf/prompt_config.yaml', + **kwargs, + ): + super().__init__(**kwargs) + self.llm_model_name_or_path = llm_model_name_or_path + self.no_init_llm = no_init_llm + self.wenet_model_name_or_path = wenet_model_name_or_path + self.encoder_ds_rate = encoder_ds_rate + self.encoder_projector_ds_rate = encoder_projector_ds_rate + self.hidden_size = hidden_size + self.lora_config = lora_config + self.freeze_encoder = freeze_encoder + self.freeze_llm = freeze_llm + self.speech_token_num = speech_token_num + self.prompt_conf_path = prompt_conf_path + + +__all__ = ["OSUMEChatConfig", "OSUMConfig"] diff --git a/west/models/osum_echat/extractor_osum_echat.py b/west/models/osum_echat/extractor_osum_echat.py index 85fca81..1068c9a 100644 --- a/west/models/osum_echat/extractor_osum_echat.py +++ b/west/models/osum_echat/extractor_osum_echat.py @@ -1,9 +1,83 @@ # Copyright (c) 2025 Xuelong Geng(xlgeng@mail.nwpu.edu.cn) +import math +import random + +import torch +import wenet +from gxl_ai_utils.utils import utils_file +from transformers.trainer_pt_utils import LabelSmoother + from west.dataset.extractor import Extractor -class ExtractorOSUMEChat(Extractor): +class ExtractorOSUM(Extractor): + model_type = "osum" + fields_batch_static = {'audio_offsets'} + fields_batch_dynamic = {'audio_features', 'input_ids', 'labels'} + fields_pack_offset = {'audio_offsets'} + + def __init__(self, tokenizer, model_config, inference=False): + super().__init__(tokenizer, model_config, inference) + self.compute_feature, self.feature_dim = wenet.load_feature( + self.model_config.wenet_model_name_or_path + ) + self.ds_rate = (self.model_config.encoder_ds_rate * + self.model_config.encoder_projector_ds_rate) + + self.global_prompt_dict = 
utils_file.load_dict_from_yaml( + model_config.prompt_conf_path) + + def extract(self, item): + """ + {'wav': wav_path, 'txt': text, 'task': ''} + """ + IGNORE_TOKEN_ID = LabelSmoother.ignore_index + t0 = '' + if 'task' in item: + task_name = item['task'] + try: + random_index = random.randint( + 0, len(self.global_prompt_dict[task_name]) - 1) + prompt = self.global_prompt_dict[task_name][random_index] + if prompt != "": + t0 = prompt + except Exception as e: + print(f"Error: {e}") + else: + task_name = '' # default: speech recognition + try: + random_index = random.randint( + 0, len(self.global_prompt_dict[task_name]) - 1) + prompt = self.global_prompt_dict[task_name][random_index] + t0 = prompt + except Exception as e: + print(f"Error: {e}") + + mel = self.compute_feature(item['wav']) + ids_audio = [0] * (math.ceil(mel.size(0) / self.ds_rate) + 2) + + ids0 = self.tokenizer.encode(t0) + ids = ids0 + ids_audio + tgt = [IGNORE_TOKEN_ID] * len(ids) + + if not self.inference: + ids1 = self.tokenizer.encode(item['txt'] + "<|endoftext|>") + ids = ids + ids1 + tgt = tgt + ids1 + + input_ids = torch.tensor(ids, dtype=torch.int) + tgt_ids = torch.tensor(tgt, dtype=torch.long) + return { + 'input_ids': input_ids, + 'labels': tgt_ids, + 'audio_features': mel, + 'audio_offsets': len(ids0), + } + + +class ExtractorOSUMEChat(ExtractorOSUM): + model_type = "osum_echat" def __init__(self, tokenizer, model_config, inference=False): super().__init__(tokenizer, model_config, inference) diff --git a/west/models/osum_echat/modeling_osum_echat.py b/west/models/osum_echat/modeling_osum_echat.py index 3919ef8..0c4e045 100644 --- a/west/models/osum_echat/modeling_osum_echat.py +++ b/west/models/osum_echat/modeling_osum_echat.py @@ -5,13 +5,16 @@ import torch import wenet from gxl_ai_utils.utils import utils_file +from peft import LoraConfig, get_peft_model from torch import nn from transformers import (AutoConfig, AutoModelForCausalLM, AutoTokenizer, GenerationMixin, PreTrainedModel, StoppingCriteriaList) from wenet.models.transformer.encoder import TransformerEncoder -from .configuration_osum_echat import OSUMEChatConfig +from west.utils.utils import freeze_module + +from .configuration_osum_echat import OSUMConfig, OSUMEChatConfig from .cumstom_stop_criteria import (InterruptStopper, MaxTokenStopper, S2SStopCriteria) @@ -65,24 +68,15 @@ def forward(self, encoder_out, encoder_mask): 1) -class OSUMEChat(PreTrainedModel, GenerationMixin): - model_type = 'osum_echat' - config_class = OSUMEChatConfig +class OSUM(PreTrainedModel, GenerationMixin): + """OSUM: https://arxiv.org/pdf/2501.13306 + """ + model_type = 'osum' + config_class = OSUMConfig supports_gradient_checkpointing = True - def __init__(self, config: OSUMEChatConfig, - *inputs, **kwargs): - """ - TODO(Xuelong Geng): Complete the design of OSUMEChat - """ - super().__init__(config, *inputs, - **kwargs) - self.encoder = wenet.load_model( - config.wenet_model_name_or_path) - del self.encoder.decoder - del self.encoder.ctc - utils_file.logging_info( - f'self.encoder: {self.encoder}') + def __init__(self, config: OSUMConfig): + super().__init__(config) llm_config = AutoConfig.from_pretrained( config.llm_model_name_or_path) utils_file.logging_info( @@ -92,8 +86,7 @@ def __init__(self, config: OSUMEChatConfig, utils_file.logging_info( 'No init llm, only load llm structure' ) - self.llm = AutoModelForCausalLM.from_config( - llm_config, ) + self.llm = AutoModelForCausalLM.from_config(llm_config) self.llm.to(torch.bfloat16) else: self.llm = 
AutoModelForCausalLM.from_pretrained(
@@ -103,32 +96,221 @@ def __init__(self, config: OSUMEChatConfig,
                 attn_implementation="flash_attention_2",  # or "flex_attention"
             )
+        config.hidden_size = llm_config.hidden_size
+        self.encoder = wenet.load_model(config.wenet_model_name_or_path)
+        # remove decoder and ctc, keep encoder only
+        del self.encoder.decoder
+        del self.encoder.ctc
+        utils_file.logging_info(f'self.encoder: {self.encoder}')
+
+        self.embed_tokens = self.llm.model.embed_tokens  # still used by OSUMEChat.generate
-        self.tokenizer = AutoTokenizer.from_pretrained(
-            config.llm_model_name_or_path,
-            use_fast=False,
-            trust_remote_code=True)
-        self.embed_tokens = self.llm.model.embed_tokens
-        utils_file.logging_info(
-            f'self.llm: {self.llm}')
         self.projector = ProjectorTransformerWithCov1d(
-            encoder_dim=self.encoder.encoder.
-            output_size(),
+            encoder_dim=self.encoder.encoder.output_size(),
             llm_dim=llm_config.hidden_size,
         )
+        total_params = sum(p.numel() for p in self.projector.parameters())
         utils_file.logging_info(
-            f'self.projector: {self.projector}')
+            'self.projector: {}, projector total params: {:.2f}M'.format(
+                self.projector, total_params / 1e6
+            )
+        )
+
+        if config.lora_config is not None:
+            lora_config = LoraConfig(**config.lora_config)
+            self.llm = get_peft_model(self.llm, lora_config)
+            self.llm.print_trainable_parameters()
+        if config.freeze_encoder:
+            self.freeze_encoder()
+        if config.lora_config is None and config.freeze_llm:
+            self.freeze_llm()
+
+        self.add_embed_head = True
+        self.IGNORE_ID = -100
+        self.speech_token_num = config.speech_token_num
         self.speech_token_emded = torch.nn.Embedding(
             config.speech_token_num + 2, llm_config.hidden_size)
+
+    def tie_weights(self):
+        return self.llm.tie_weights()
+
+    def get_speech_embeddings(self, audio_features, audio_features_lengths):
+        encoder_out, encoder_mask = self.encoder._forward_encoder(
+            audio_features, audio_features_lengths
+        )
+        speech_embeds, speech_masks = self.projector(encoder_out, encoder_mask)
+        speech_embeds, speech_masks, _ = self._add_bos_eos(
+            0 + self.config.speech_token_num, 1 + self.config.speech_token_num,
+            speech_embeds, speech_masks, None,
+        )
+        speech_embeds_lens = speech_masks.sum(-1)
+
+        return speech_embeds, speech_embeds_lens
+
+    def compute_mix_embedding(
+        self,
+        input_ids: torch.LongTensor = None,
+        audio_offsets: Optional[torch.LongTensor] = None,
+        audio_features: Optional[torch.FloatTensor] = None,
+        audio_features_lengths: Optional[torch.LongTensor] = None,
+        batch_idx: Optional[torch.LongTensor] = None,
+    ):
+        text_emb = self.llm.get_input_embeddings()(input_ids)
+        speech_emb, speech_emb_lens = self.get_speech_embeddings(
+            audio_features, audio_features_lengths
+        )
+        # speech_emb_lens = speech_emb.shape[1]
+        inputs_embeds = text_emb
+        for i in range(audio_features.size(0)):
+            b = batch_idx[i]
+            s, e = audio_offsets[i], audio_offsets[i] + speech_emb_lens[i]
+            inputs_embeds[b, s:e, :] = speech_emb[i, :speech_emb_lens[i], :]
+        return inputs_embeds
+
+    @torch.autocast(device_type="cuda", dtype=torch.bfloat16)
+    def forward(
+        self,
+        input_ids: torch.LongTensor = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        labels: Optional[torch.LongTensor] = None,
+        position_ids: Optional[torch.LongTensor] = None,
+        audio_offsets: Optional[torch.LongTensor] = None,
+        audio_features: Optional[torch.FloatTensor] = None,
+        audio_features_lengths: Optional[torch.LongTensor] = None,
+        batch_idx: Optional[torch.LongTensor] = None,
+        **kwargs,
+    ):
+        inputs_embeds = self.compute_mix_embedding(
+            input_ids,
audio_offsets, + audio_features, + audio_features_lengths, + batch_idx, + ) + out = self.llm(inputs_embeds=inputs_embeds, + attention_mask=attention_mask, + labels=labels, + position_ids=position_ids, + **kwargs) + return out + + @torch.autocast(device_type="cuda", dtype=torch.bfloat16) + def generate( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + labels: Optional[torch.LongTensor] = None, + audio_offsets: Optional[torch.LongTensor] = None, + audio_features: Optional[torch.FloatTensor] = None, + audio_features_lengths: Optional[torch.LongTensor] = None, + batch_idx: Optional[torch.LongTensor] = None, + **kwargs, + ): + inputs_embeds = self.compute_mix_embedding( + input_ids, + audio_offsets, + audio_features, + audio_features_lengths, + batch_idx, + ) + model_outputs = self.llm.generate( + inputs_embeds=inputs_embeds, + attention_mask=attention_mask, + eos_token_id=151643, + max_new_tokens=400, + do_sample=True, + top_p=0.9, + top_k=5, + repetition_penalty=1.05, + length_penalty=1.0, + temperature=1.0, + num_beams=4, + **kwargs, + ) + return model_outputs + + def enable_input_require_grads(self): + self.llm.enable_input_require_grads() + + def freeze_encoder(self): + freeze_module(self.encoder) + self.encoder.eval() + + def freeze_llm(self): + freeze_module(self.llm) + + def init_tokenizer(self): + self.tokenizer = AutoTokenizer.from_pretrained( + self.config.llm_model_name_or_path, + padding_side="right", + use_fast=False, + trust_remote_code=True) + return self.tokenizer + + def _add_bos_eos(self, + bos, + eos, + inputs_embeds, + attention_mask, + target=None): + B = len(inputs_embeds) + bos_eos_target = torch.full( + [B, 1], self.IGNORE_ID).to( + inputs_embeds.device) # B,1 + bos_eos_mask = torch.full( + [B, 1], + True).to(inputs_embeds.device) # B, 1 + + if bos is not None: + bos_embed = self.speech_token_emded( + torch.full([B, 1], bos).to( + inputs_embeds.device) + ) # B, 1, D + inputs_embeds = torch.cat( + (bos_embed, inputs_embeds), + 1) # B, (1+T), D + attention_mask = torch.cat( + (bos_eos_mask, attention_mask), + 1) # B, (1+T) + if target is not None: + target = torch.cat( + (bos_eos_target, target), + 1) # B, (1+T), D + + if eos is not None: + eos_embed = self.speech_token_emded( + torch.full([B, 1], eos).to( + inputs_embeds.device) + ) # B, 1, D + inputs_embeds = torch.cat( + (inputs_embeds, eos_embed), + 1) # B, (1+T+1), D + attention_mask = torch.cat( + (attention_mask, bos_eos_mask), + 1) # B, (1+T+1) + if target is not None: + target = torch.cat( + (target, bos_eos_target), + 1) # B, (1+T+1), D + + return inputs_embeds, attention_mask, target + + +class OSUMEChat(OSUM): + """OSUMEChat: https://arxiv.org/pdf/2508.09600 + """ + model_type = 'osum_echat' + config_class = OSUMEChatConfig + supports_gradient_checkpointing = True + + def __init__(self, config: OSUMEChatConfig): + super().__init__(config) + self.init_custom_stop_criteria() self.speech_head = torch.nn.Linear( - llm_config.hidden_size, + self.llm.config.hidden_size, config.speech_token_num) - self.add_embed_head = True - self.IGNORE_ID = -100 - self.speech_token_num = config.speech_token_num - self.init_custom_stop_criteria() @torch.autocast(device_type="cuda", dtype=torch.bfloat16) @@ -229,7 +411,7 @@ def generate( prompt_pattern2_embeds = self.embed_tokens( prompt_pattern2) - hyps = [4098] + hyps = [4098] # for generate start token_emb = self.speech_token_emded( torch.tensor(hyps[-1:]).to( device)).unsqueeze(0) @@ -273,11 +455,6 @@ def generate( 
skip_special_tokens=True) return (output_text, text_res, speech_res) - def init_tokenizer(self): - """ - TODO(Xuelong Geng): Complete the design of OSUMEChat - """ - def init_custom_stop_criteria(self): """ 创建需要的stop criteria @@ -323,51 +500,3 @@ def do_add_speech_embed_head(self): self.llm.speech_head = self.speech_head.to( torch.bfloat16) self.add_embed_head = False - - def _add_bos_eos(self, - bos, - eos, - inputs_embeds, - attention_mask, - target=None): - B = len(inputs_embeds) - bos_eos_target = torch.full( - [B, 1], self.IGNORE_ID).to( - inputs_embeds.device) # B,1 - bos_eos_mask = torch.full( - [B, 1], - True).to(inputs_embeds.device) # B, 1 - - if bos is not None: - bos_embed = self.speech_token_emded( - torch.full([B, 1], bos).to( - inputs_embeds.device) - ) # B, 1, D - inputs_embeds = torch.cat( - (bos_embed, inputs_embeds), - 1) # B, (1+T), D - attention_mask = torch.cat( - (bos_eos_mask, attention_mask), - 1) # B, (1+T) - if target is not None: - target = torch.cat( - (bos_eos_target, target), - 1) # B, (1+T), D - - if eos is not None: - eos_embed = self.speech_token_emded( - torch.full([B, 1], eos).to( - inputs_embeds.device) - ) # B, 1, D - inputs_embeds = torch.cat( - (inputs_embeds, eos_embed), - 1) # B, (1+T+1), D - attention_mask = torch.cat( - (attention_mask, bos_eos_mask), - 1) # B, (1+T+1) - if target is not None: - target = torch.cat( - (target, bos_eos_target), - 1) # B, (1+T+1), D - - return inputs_embeds, attention_mask, target
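
Note on the training data consumed by the new extractor: ExtractorOSUM.extract() reads 'wav', an optional 'task' tag, and (for training) 'txt' from each line of the JSONL file passed as --data_path, and the 'task' value must be a key of the OSUM prompt_config.yaml so a prompt can be sampled for it. Below is a minimal sketch of how $data/train.jsonl could be produced; the task tag "<TRANSCRIBE>" and the wav paths are illustrative placeholders, not values defined by this patch, and should be replaced with keys from your own prompt config and real audio paths.

import json

# Hypothetical items; 'task' must match a key in your OSUM prompt_config.yaml.
# If 'task' is omitted, ExtractorOSUM falls back to the '' entry
# (default: speech recognition), provided that key exists in the config.
items = [
    {"wav": "/path/to/audio_0001.wav", "txt": "今天天气怎么样", "task": "<TRANSCRIBE>"},
    {"wav": "/path/to/audio_0002.wav", "txt": "把空调温度调低两度", "task": "<TRANSCRIBE>"},
]

with open("data/train.jsonl", "w", encoding="utf-8") as f:
    for item in items:
        f.write(json.dumps(item, ensure_ascii=False) + "\n")

Each example is packed as prompt tokens, placeholder audio positions, and the target text; the prompt and audio spans are label-masked with IGNORE_TOKEN_ID, so only the 'txt' tokens contribute to the loss, and the audio span is filled with encoder/projector embeddings at training time.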