
Fix: make RWKV7 infctx forward compatible with PEFT passing inputs_embeds (avoid TypeError) #81

Open
Deng-Xian-Sheng wants to merge 1 commit into Joluck:main from Deng-Xian-Sheng:main

Conversation


Deng-Xian-Sheng commented Jan 26, 2026

Background / Problem
When training with RWKV_TRAIN_TYPE=infctx and PEFT (PeftModelForCausalLM), starting a training run fails with:

  • TypeError: RWKV7.forward_infctx() got an unexpected keyword argument 'inputs_embeds'

Root cause
The PEFT/HF model wrapper's forward() passes inputs_embeds (even when it is None) and other common HF keyword arguments through to the base model.
On the infctx path, RWKV7's forward_infctx() signature has no inputs_embeds parameter and no **kwargs fallback, so the unexpected keyword argument error is raised.
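A minimal, self-contained illustration of the failure mode; the function names and bodies below are hypothetical stand-ins, not the project's actual code:

```python
# Hypothetical stand-ins, not the project code: the wrapper forwards
# inputs_embeds unconditionally, the wrapped method has neither a matching
# parameter nor **kwargs, so Python raises TypeError.

def forward_infctx(idx, last_shift_states=None, last_wkv_states=None, attention_mask=None):
    return idx

def peft_style_wrapper(*args, inputs_embeds=None, **kwargs):
    # PEFT/HF-style wrappers pass inputs_embeds through even when it is None
    return forward_infctx(*args, inputs_embeds=inputs_embeds, **kwargs)

peft_style_wrapper([1, 2, 3])
# TypeError: forward_infctx() got an unexpected keyword argument 'inputs_embeds'
```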

Fix

  • RWKV7.forward() now uses an HF/PEFT-friendly explicit signature accepting input_ids / attention_mask / inputs_embeds / **kwargs (see the sketch after this list)
  • forward_infctx() gains inputs_embeds=None and accepts **kwargs; when inputs_embeds is provided it is used directly as the embedding, otherwise input_ids is embedded as before
  • Optional: initialize empty states when last_shift_states/last_wkv_states are not passed, to avoid crashes on some call paths
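
A minimal sketch of the adjusted signatures, assuming RWKV7 embeds input_ids through an emb layer; the class and method bodies here are illustrative placeholders, not the actual diff:

```python
import torch
import torch.nn as nn

class RWKV7Sketch(nn.Module):
    """Illustrative stand-in for RWKV7; only the argument handling matters here."""

    def __init__(self, vocab_size=16, n_embd=8):
        super().__init__()
        self.emb = nn.Embedding(vocab_size, n_embd)

    def forward(self, input_ids=None, last_shift_states=None, last_wkv_states=None,
                attention_mask=None, inputs_embeds=None, **kwargs):
        # HF/PEFT-friendly explicit signature; any other HF keywords land in **kwargs
        return self.forward_infctx(input_ids, last_shift_states, last_wkv_states,
                                   attention_mask, inputs_embeds=inputs_embeds, **kwargs)

    def forward_infctx(self, input_ids, last_shift_states=None, last_wkv_states=None,
                       attention_mask=None, inputs_embeds=None, **kwargs):
        # Use caller-provided embeddings if present; otherwise embed input_ids as before
        x = inputs_embeds if inputs_embeds is not None else self.emb(input_ids)
        # ... the original chunked infctx recurrence would run on x here ...
        return x, last_shift_states, last_wkv_states

# With this signature, the PEFT wrapper's extra keywords no longer raise:
m = RWKV7Sketch()
ids = torch.randint(0, 16, (1, 4))
logits, _, _ = m(ids, inputs_embeds=None, use_cache=False)
```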

Scope / Compatibility

  • Default behavior is unchanged: when inputs_embeds is None the logic matches the original
  • Only adds compatibility under the HF/PEFT wrapper
  • The non-infctx (normal) path is also more robust (supports inputs_embeds)

Error log

########## work in progress ##########
########## WKV OP           fla               ##########

########## FUSED OP    False          ##########

RWKV_MY_TESTING x070
########## Loading /root/autodl-tmp/rwkv7-g1/rwkv7-g1c-13.3b-20251231-ctx8192.pth... ##########
trainable params: 143,917,056 || all params: 13,414,215,680 || trainable%: 1.0729
/root/miniconda3/lib/python3.12/site-packages/lightning/fabric/connector.py:571: `precision=bf16` is supported for historical reasons but its usage is discouraged. Please set your precision to bf16-mixed instead!
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
initializing deepspeed distributed: GLOBAL_RANK: 0, MEMBER: 1/1
Data has 9888074 tokens.
Trimmed to 1000 samples for epoch_steps 1000.
Enabling DeepSpeed BF16. Model parameters and inputs will be cast to `bfloat16`.
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
[rank0]:W0126 17:34:25.697000 10664 site-packages/torch/utils/cpp_extension.py:2425] TORCH_CUDA_ARCH_LIST is not set, all archs for visible cards are included for compilation. 
[rank0]:W0126 17:34:25.697000 10664 site-packages/torch/utils/cpp_extension.py:2425] If this is not desired, please set os.environ['TORCH_CUDA_ARCH_LIST'] to specific architectures.
Before initializing optimizer states
MA 25.54 GB         Max_MA 25.81 GB         CA 25.91 GB         Max_CA 26 GB 
CPU Virtual Memory:  used = 86.91 GB, percent = 8.6%
After initializing optimizer states
MA 25.54 GB         Max_MA 26.08 GB         CA 26.45 GB         Max_CA 26 GB 
CPU Virtual Memory:  used = 86.94 GB, percent = 8.6%
After initializing ZeRO optimizer
MA 25.54 GB         Max_MA 25.54 GB         CA 26.45 GB         Max_CA 26 GB 
CPU Virtual Memory:  used = 86.89 GB, percent = 8.6%
/root/miniconda3/lib/python3.12/site-packages/lightning/pytorch/utilities/model_summary/model_summary.py:242: Precision bf16-mixed is not supported by the model summary.  Estimated model size in MB will not be accurate. Using 32 bits instead.

  | Name      | Type                 | Params | Mode  | FLOPs
-------------------------------------------------------------------
0 | model     | PeftModelForCausalLM | 13.4 B | train | 0    
1 | criterion | CrossEntropyLoss     | 0      | train | 0    
-------------------------------------------------------------------
143 M     Trainable params
13.3 B    Non-trainable params
13.4 B    Total params
53,656.863Total estimated model params size (MB)
4462      Modules in train mode
0         Modules in eval mode
0         Total Flops
Epoch 0:   0%|                                         | 0/1000 [00:00<?, ?it/s]
{'zero_allow_untested_optimizer': True, 'zero_optimization': {'stage': 1, 'contiguous_gradients': True, 'overlap_comm': True, 'allgather_partitions': True, 'reduce_scatter': True, 'allgather_bucket_size': 200000000, 'reduce_bucket_size': 200000000, 'sub_group_size': 1000000000000}, 'activation_checkpointing': {'partition_activations': False, 'cpu_checkpointing': False, 'contiguous_memory_optimization': False, 'synchronize_checkpoint_boundary': False}, 'aio': {'block_size': 1048576, 'queue_depth': 8, 'single_submit': False, 'overlap_events': True, 'thread_count': 1}, 'gradient_accumulation_steps': 1, 'train_micro_batch_size_per_gpu': 1, 'gradient_clipping': 1.0, 'bf16': {'enabled': True}}

[rank0]: Traceback (most recent call last):
[rank0]:   File "/root/autodl-tmp/RWKV-PEFT/train.py", line 267, in <module>
[rank0]:     trainer.fit(model, train_data)
[rank0]:   File "/root/miniconda3/lib/python3.12/site-packages/lightning/pytorch/trainer/trainer.py", line 584, in fit
[rank0]:     call._call_and_handle_interrupt(
[rank0]:   File "/root/miniconda3/lib/python3.12/site-packages/lightning/pytorch/trainer/call.py", line 48, in _call_and_handle_interrupt
[rank0]:     return trainer.strategy.launcher.launch(trainer_fn, *args, trainer=trainer, **kwargs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/root/miniconda3/lib/python3.12/site-packages/lightning/pytorch/strategies/launchers/subprocess_script.py", line 105, in launch
[rank0]:     return function(*args, **kwargs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/root/miniconda3/lib/python3.12/site-packages/lightning/pytorch/trainer/trainer.py", line 630, in _fit_impl
[rank0]:     self._run(model, ckpt_path=ckpt_path, weights_only=weights_only)
[rank0]:   File "/root/miniconda3/lib/python3.12/site-packages/lightning/pytorch/trainer/trainer.py", line 1079, in _run
[rank0]:     results = self._run_stage()
[rank0]:               ^^^^^^^^^^^^^^^^^
[rank0]:   File "/root/miniconda3/lib/python3.12/site-packages/lightning/pytorch/trainer/trainer.py", line 1123, in _run_stage
[rank0]:     self.fit_loop.run()
[rank0]:   File "/root/miniconda3/lib/python3.12/site-packages/lightning/pytorch/loops/fit_loop.py", line 217, in run
[rank0]:     self.advance()
[rank0]:   File "/root/miniconda3/lib/python3.12/site-packages/lightning/pytorch/loops/fit_loop.py", line 465, in advance
[rank0]:     self.epoch_loop.run(self._data_fetcher)
[rank0]:   File "/root/miniconda3/lib/python3.12/site-packages/lightning/pytorch/loops/training_epoch_loop.py", line 153, in run
[rank0]:     self.advance(data_fetcher)
[rank0]:   File "/root/miniconda3/lib/python3.12/site-packages/lightning/pytorch/loops/training_epoch_loop.py", line 352, in advance
[rank0]:     batch_output = self.automatic_optimization.run(trainer.optimizers[0], batch_idx, kwargs)
[rank0]:                    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/root/miniconda3/lib/python3.12/site-packages/lightning/pytorch/loops/optimization/automatic.py", line 192, in run
[rank0]:     self._optimizer_step(batch_idx, closure)
[rank0]:   File "/root/miniconda3/lib/python3.12/site-packages/lightning/pytorch/loops/optimization/automatic.py", line 270, in _optimizer_step
[rank0]:     call._call_lightning_module_hook(
[rank0]:   File "/root/miniconda3/lib/python3.12/site-packages/lightning/pytorch/trainer/call.py", line 177, in _call_lightning_module_hook
[rank0]:     output = fn(*args, **kwargs)
[rank0]:              ^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/root/miniconda3/lib/python3.12/site-packages/lightning/pytorch/core/module.py", line 1368, in optimizer_step
[rank0]:     optimizer.step(closure=optimizer_closure)
[rank0]:   File "/root/miniconda3/lib/python3.12/site-packages/lightning/pytorch/core/optimizer.py", line 154, in step
[rank0]:     step_output = self._strategy.optimizer_step(self._optimizer, closure, **kwargs)
[rank0]:                   ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/root/miniconda3/lib/python3.12/site-packages/lightning/pytorch/strategies/ddp.py", line 274, in optimizer_step
[rank0]:     optimizer_output = super().optimizer_step(optimizer, closure, model, **kwargs)
[rank0]:                        ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/root/miniconda3/lib/python3.12/site-packages/lightning/pytorch/strategies/strategy.py", line 239, in optimizer_step
[rank0]:     return self.precision_plugin.optimizer_step(optimizer, model=model, closure=closure, **kwargs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/root/miniconda3/lib/python3.12/site-packages/lightning/pytorch/plugins/precision/deepspeed.py", line 129, in optimizer_step
[rank0]:     closure_result = closure()
[rank0]:                      ^^^^^^^^^
[rank0]:   File "/root/miniconda3/lib/python3.12/site-packages/lightning/pytorch/loops/optimization/automatic.py", line 146, in __call__
[rank0]:     self._result = self.closure(*args, **kwargs)
[rank0]:                    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/root/miniconda3/lib/python3.12/site-packages/torch/utils/_contextlib.py", line 120, in decorate_context
[rank0]:     return func(*args, **kwargs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/root/miniconda3/lib/python3.12/site-packages/lightning/pytorch/loops/optimization/automatic.py", line 131, in closure
[rank0]:     step_output = self._step_fn()
[rank0]:                   ^^^^^^^^^^^^^^^
[rank0]:   File "/root/miniconda3/lib/python3.12/site-packages/lightning/pytorch/loops/optimization/automatic.py", line 319, in _training_step
[rank0]:     training_step_output = call._call_strategy_hook(trainer, "training_step", *kwargs.values())
[rank0]:                            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/root/miniconda3/lib/python3.12/site-packages/lightning/pytorch/trainer/call.py", line 329, in _call_strategy_hook
[rank0]:     output = fn(*args, **kwargs)
[rank0]:              ^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/root/miniconda3/lib/python3.12/site-packages/lightning/pytorch/strategies/strategy.py", line 390, in training_step
[rank0]:     return self._forward_redirection(self.model, self.lightning_module, "training_step", *args, **kwargs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/root/miniconda3/lib/python3.12/site-packages/lightning/pytorch/strategies/strategy.py", line 641, in __call__
[rank0]:     wrapper_output = wrapper_module(*args, **kwargs)
[rank0]:                      ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/root/miniconda3/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1773, in _wrapped_call_impl
[rank0]:     return self._call_impl(*args, **kwargs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/root/miniconda3/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1784, in _call_impl
[rank0]:     return forward_call(*args, **kwargs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/root/miniconda3/lib/python3.12/site-packages/deepspeed/utils/nvtx.py", line 20, in wrapped_fn
[rank0]:     ret_val = func(*args, **kwargs)
[rank0]:               ^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/root/miniconda3/lib/python3.12/site-packages/deepspeed/runtime/engine.py", line 2236, in forward
[rank0]:     loss = self.module(*inputs, **kwargs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/root/miniconda3/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1773, in _wrapped_call_impl
[rank0]:     return self._call_impl(*args, **kwargs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/root/miniconda3/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1879, in _call_impl
[rank0]:     return inner()
[rank0]:            ^^^^^^^
[rank0]:   File "/root/miniconda3/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1827, in inner
[rank0]:     result = forward_call(*args, **kwargs)
[rank0]:              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/root/miniconda3/lib/python3.12/site-packages/lightning/pytorch/strategies/strategy.py", line 634, in wrapped_forward
[rank0]:     out = method(*_args, **_kwargs)
[rank0]:           ^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/root/autodl-tmp/RWKV-PEFT/rwkvt/lightning_train/light_rwkv.py", line 217, in training_step
[rank0]:     total_loss,new_shift_states, new_wkv_states,token_amount = torch_checkpoint(
[rank0]:                                                                ^^^^^^^^^^^^^^^^^
[rank0]:   File "/root/miniconda3/lib/python3.12/site-packages/torch/_compile.py", line 53, in inner
[rank0]:     return disable_fn(*args, **kwargs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/root/miniconda3/lib/python3.12/site-packages/torch/_dynamo/eval_frame.py", line 929, in _fn
[rank0]:     return fn(*args, **kwargs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/root/miniconda3/lib/python3.12/site-packages/torch/utils/checkpoint.py", line 495, in checkpoint
[rank0]:     ret = function(*args, **kwargs)
[rank0]:           ^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/root/autodl-tmp/RWKV-PEFT/rwkvt/lightning_train/light_rwkv.py", line 196, in checkpointed_step
[rank0]:     logits, new_shift_states, new_wkv_states = self(idx, last_shift_states, last_wkv_states)
[rank0]:                                                ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/root/miniconda3/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1773, in _wrapped_call_impl
[rank0]:     return self._call_impl(*args, **kwargs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/root/miniconda3/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1879, in _call_impl
[rank0]:     return inner()
[rank0]:            ^^^^^^^
[rank0]:   File "/root/miniconda3/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1827, in inner
[rank0]:     result = forward_call(*args, **kwargs)
[rank0]:              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/root/autodl-tmp/RWKV-PEFT/rwkvt/lightning_train/light_rwkv.py", line 176, in forward
[rank0]:     return self.model(idx, last_shift_states, last_wkv_states, attention_mask)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/root/miniconda3/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1773, in _wrapped_call_impl
[rank0]:     return self._call_impl(*args, **kwargs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/root/miniconda3/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1784, in _call_impl
[rank0]:     return forward_call(*args, **kwargs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/root/miniconda3/lib/python3.12/site-packages/peft/peft_model.py", line 1923, in forward
[rank0]:     return self.base_model(
[rank0]:            ^^^^^^^^^^^^^^^^
[rank0]:   File "/root/miniconda3/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1773, in _wrapped_call_impl
[rank0]:     return self._call_impl(*args, **kwargs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/root/miniconda3/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1784, in _call_impl
[rank0]:     return forward_call(*args, **kwargs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/root/miniconda3/lib/python3.12/site-packages/peft/tuners/tuners_utils.py", line 311, in forward
[rank0]:     return self.model.forward(*args, **kwargs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/root/autodl-tmp/RWKV-PEFT/rwkvt/rwkv7/model.py", line 54, in forward
[rank0]:     return self.forward_infctx(*args, **kwargs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]: TypeError: RWKV7.forward_infctx() got an unexpected keyword argument 'inputs_embeds'
[rank0]:[W126 17:34:52.357970813 ProcessGroupNCCL.cpp:1538] Warning: WARNING: destroy_process_group() was not called before program exit, which can leak resources. For more info, please see https://pytorch.org/docs/stable/distributed.html#shutdown (function operator())

@Deng-Xian-Sheng (Author)

Training launch script:

#!/usr/bin/env bash
set -euo pipefail

ROOT="/root/autodl-tmp"
PEFT_DIR="$ROOT/RWKV-PEFT"

# ====== 1) Base model and data ======
load_model="$ROOT/rwkv7-g1/rwkv7-g1c-13.3b-20251231-ctx8192.pth"   # <<< change to your real path
data_file="$ROOT/rwkv_dataset/rwkv_dataset_text_document"                  # your binidx prefix: rwkv_dataset_text_document.bin / rwkv_dataset_text_document.idx
proj_dir="$ROOT/out/novel_lora_13b_ctx8192_infctx"

mkdir -p "$proj_dir"

# ====== 2) Model structure (already confirmed from the GGUF meta) ======
n_layer=61
n_embd=4096
vocab_size=65536

# ====== 3) Context settings (matching your needs) ======
ctx_len=8192
chunk_ctx=4096                    # infctx: must be < ctx_len

# ====== 4) Step budget for the full training run (core) ======
# Here an "epoch" is not one pass over the 4801 samples but "how many steps to run"
# 8192 tokens/step, 40k steps ≈ 3.28e8 tokens (roughly 330M tokens)
# epoch_steps=5000
# epoch_count=8                     # total steps = 5000 * 8 = 40000
# epoch_save=1                      # save every epoch, making it easy to pick the best one

# -- total steps stay at 40000, but save every 1000 steps so you can stop at a checkpoint at any time --
epoch_steps=1000
epoch_count=40
epoch_save=1

# ====== 5) Batch and precision ======
micro_bsz=1                       # 13B + 8k: start from 1 even for the full run
# grad_cp saves VRAM per the tutorial/scripts; keeping it is a routine switch, not a fancy optimization
grad_cp=1

# ====== 6) LoRA config (a bit stronger than r=16 is recommended for a full run) ======
# Domain adaptation for novel writing: r=32 is common and sufficient; alpha is usually 2*r
peft_config='{"r":32,"lora_alpha":64,"lora_dropout":0.05}'

# ====== 7) Learning rate and schedule (more stable for a full run) ======
lr_init=1e-5
lr_final=2e-6
# The README says cos_decay is the default; you can also switch to wsd (more like a full training run)
lr_schedule="wsd"

export CUDA_VISIBLE_DEVICES=0

cd "$PEFT_DIR"

python train.py \
  --load_model "$load_model" \
  --proj_dir "$proj_dir" \
  --data_file "$data_file" \
  --vocab_size "$vocab_size" \
  --data_type binidx \
  --n_layer "$n_layer" --n_embd "$n_embd" \
  --ctx_len "$ctx_len" --micro_bsz "$micro_bsz" \
  --epoch_steps "$epoch_steps" --epoch_count "$epoch_count" --epoch_save "$epoch_save" \
  --lr_init "$lr_init" --lr_final "$lr_final" \
  --lr_schedule "$lr_schedule" \
  --accelerator gpu --precision bf16 \
  --devices 1 --strategy deepspeed_stage_1 --grad_cp "$grad_cp" \
  --my_testing "x070" \
  --op fla \
  --peft lora --peft_config "$peft_config" \
  --train_type infctx --chunk_ctx "$chunk_ctx"

@Deng-Xian-Sheng (Author)

Deng-Xian-Sheng commented Jan 26, 2026

I was using the --train_type infctx flag when the error occurred, and I fixed it (hmm, it runs now, but the loss goes straight to NaN at the 3rd step).

@Joluck (Owner)

Joluck commented Jan 27, 2026

Thanks. infctx hasn't been used in a long time and hardly anyone uses it, so it was never fixed. I'll check it later.

@Deng-Xian-Sheng (Author)

It seems infctx is meant to provide unlimited context?

My dataset's texts are very long, so I considered infctx.

If I don't use infctx, will the dataset be truncated during training?

I may be misunderstanding what infctx actually does.

@Joluck (Owner)

Joluck commented Jan 27, 2026

> It seems infctx is meant to provide unlimited context?
>
> My dataset's texts are very long, so I considered infctx.
>
> If I don't use infctx, will the dataset be truncated during training?
>
> I may be misunderstanding what infctx actually does.

Yes, without infctx the data will be truncated, but infctx makes training very slow. How long is the data you want to train on?

@Deng-Xian-Sheng (Author)

> > It seems infctx is meant to provide unlimited context?
> > My dataset's texts are very long, so I considered infctx.
> > If I don't use infctx, will the dataset be truncated during training?
> > I may be misunderstanding what infctx actually does.
>
> Yes, without infctx the data will be truncated, but infctx makes training very slow. How long is the data you want to train on?

Novels of tens of thousands of characters.

I synthesized a dataset for generating short and long novels, and plan to take advantage of RWKV's property of not increasing memory use or slowing down at long context.

@Dark1Forest

Any progress on this?
