2 changes: 1 addition & 1 deletion 3rdparty/tokenizers-cpp
1 change: 1 addition & 0 deletions cpp/conv_templates.cc
@@ -635,6 +635,7 @@ Conversation Conversation::FromTemplate(const std::string& name) {
{"redpajama_chat", RedPajamaChat},
{"rwkv_world", RWKVWorld},
{"rwkv", RWKV},
{"rwkv5", RWKVWorld},
{"gorilla", Gorilla},
{"guanaco", Guanaco},
{"dolly", Dolly},
2 changes: 1 addition & 1 deletion cpp/llm_chat.cc
@@ -615,7 +615,7 @@ class LLMChat {
std::vector<int32_t> encoded = this->tokenizer_->Encode(all_prompt);
tokens.insert(tokens.end(), encoded.begin(), encoded.end());
if (this->sliding_window_ != -1 || // There is no max window size if we use sliding window
this->total_seq_len_ + tokens.size() + gen_mean_gen_len < this->max_window_size_) {
this->total_seq_len_ + (int)tokens.size() + gen_mean_gen_len < this->max_window_size_) {
Contributor:
use static_cast<int64_t>(tokens.size())

Contributor Author:
OK. I think this is quite a serious bug that has troubled me for more than two weeks. The quantized int4 version of rwkv5 seemed to give very unintelligent responses, and it was only today that I thought to print out the prompt. I then discovered that none of the code after line 618 was taking effect, and finally pinpointed this issue. Now the quantized int4 version of rwkv5 can also generate text normally, and performance has improved in the other rwkv modes as well.

Contributor:
Can you elaborate a bit on why the static cast to int is needed here? Do we involve some negative numbers in computing this?

Contributor Author:
[image]

Contributor Author:
This error might occur whenever the input prompt is relatively long. Before tokens.size() is converted with static_cast<int64_t>, the expression (this->total_seq_len_ + tokens.size() + gen_mean_gen_len) is evaluated with mixed signed/unsigned types: tokens.size() is an unsigned size_t, so the whole sum is promoted to unsigned and a negative operand (or a wrapped result) is no longer compared as a negative number, which makes the comparison erroneously return true.
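
A minimal standalone sketch of this signed/unsigned pitfall (the variable names and values below are illustrative stand-ins, not the actual LLMChat members):

```cpp
#include <cstdint>
#include <iostream>
#include <vector>

int main() {
  std::vector<int32_t> tokens(100);  // pretend the prompt encoded to 100 tokens
  int64_t total_seq_len = 0;
  int64_t mean_gen_len = 128;
  int64_t max_window_size = -1;  // -1 used as a sentinel for "no limit"

  // tokens.size() is an unsigned size_t, so the whole left-hand side is
  // promoted to unsigned, and max_window_size is converted to a huge unsigned
  // value for the comparison; the test no longer means what it reads as.
  bool mixed = total_seq_len + tokens.size() + mean_gen_len < max_window_size;

  // Casting the size to a signed 64-bit value keeps the arithmetic signed.
  bool fixed =
      total_seq_len + static_cast<int64_t>(tokens.size()) + mean_gen_len < max_window_size;

  std::cout << mixed << " " << fixed << "\n";  // prints "1 0"
}
```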

Contributor:
OK, I feel this is a strange way to think about it. Given max_window_size_ == -1, we should check for it explicitly: that means there is no out-of-bound and we do not need to re-encode (i.e., run the code after). It would be good for @Hzfengsy to take a look as well.
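
A hedged sketch of that suggestion (standalone, with illustrative parameter names rather than the real class members): check the -1 sentinels explicitly so the re-encode decision never depends on the arithmetic comparison.

```cpp
#include <cstdint>
#include <vector>

// Sketch only: decide whether the window must be shifted and the prompt
// re-encoded. sliding_window != -1 means a sliding window is in use (no max
// window size), and max_window_size == -1 means the context is unlimited.
bool NeedShiftWindowAndReencode(int64_t sliding_window, int64_t max_window_size,
                                int64_t total_seq_len, int64_t mean_gen_len,
                                const std::vector<int32_t>& tokens) {
  if (sliding_window != -1) return false;
  if (max_window_size == -1) return false;
  return total_seq_len + static_cast<int64_t>(tokens.size()) + mean_gen_len >=
         max_window_size;
}
```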

Contributor Author:
Okay, there also seems to be a bug in the handling of the rwkv system prompts. I expect each interaction with the rwkv model to include the system prompt along with the current text, because this series of models (rwkv4, 5, 6) has higher requirements for prompts. Currently, only the first round of dialogue includes the system prompt, and it is forgotten in subsequent dialogues.

return tokens;
}
// need shift window and re-encode
10 changes: 6 additions & 4 deletions mlc_llm/core.py
@@ -22,6 +22,7 @@
mistral,
param_manager,
rwkv,
rwkv5,
stablelm_3b,
)
from mlc_llm.relax_model.commons import (
@@ -805,6 +806,7 @@ def build_model_from_args(args: argparse.Namespace):
"gptj": gptj,
"rwkv": rwkv,
"rwkv_world": rwkv,
"rwkv5": rwkv5,
"chatglm": chatglm,
}

@@ -870,16 +872,16 @@ def build_model_from_args(args: argparse.Namespace):

if args.model_category != "minigpt":
utils.copy_tokenizer(args)
if args.model_category == "rwkv" or args.model_category == "rwkv_world":
if args.model_category == "rwkv" or args.model_category == "rwkv_world" or args.model_category == "rwkv5":
# TODO: refactor config into model definition
dump_mlc_chat_config(
args,
vocab_size=config["vocab_size"],
max_window_size=model_config.max_sequence_length,
max_gen_len=model_config.max_sequence_length,
top_p=0.6,
temperature=1.2,
repetition_penalty=0.996,
top_p=0.3,
temperature=1.0,
repetition_penalty=1.0,
rwkv_world=True,
)
elif args.model_category == "chatglm":
3 changes: 3 additions & 0 deletions mlc_llm/dispatch/dispatch_tir_operator.py
@@ -21,6 +21,9 @@ def __init__(self, model: str):

elif model == "rwkv":
lookup = None

elif model == "rwkv5":
lookup = None

elif model == "rwkv_world":
lookup = None