Skip to content

Commit 4f3f1be

Browse files
author
firecoperana
committed
init n_buffer
1 parent 0319431 commit 4f3f1be

File tree

4 files changed

+24
-20
lines changed

4 files changed

+24
-20
lines changed

common/common.cpp

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1532,9 +1532,6 @@ bool gpt_params_find_arg(int argc, char ** argv, const std::string & arg, gpt_pa
15321532
ban_phrases.push_back(str);
15331533
}
15341534
}
1535-
std::sort(ban_phrases.begin(), ban_phrases.end(), [](const std::string& a, const std::string& b) {
1536-
return a.length() > b.length();
1537-
});
15381535
params.ban_phrases = ban_phrases;
15391536
return true;
15401537
}

common/common.h

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -216,9 +216,9 @@ struct gpt_params {
216216

217217
std::vector<std::string> in_files; // all input files
218218
std::vector<std::string> antiprompt; // strings upon which more user input is prompted (a.k.a. reverse prompts)
219-
std::vector<std::string> ban_phrases; //strings that are banned in generation
220-
int32_t banned_n = 1; // number of tokens that are banned in the phrase
221-
int32_t n_buffer; // number of token buffers for string ban
219+
std::vector<std::string> ban_phrases; // strings that are banned in generation
220+
int32_t banned_n = 1; // number of tokens that are banned in the phrase
221+
size_t n_buffer = 0; // number of token buffers for string ban
222222

223223
std::vector<llama_model_kv_override> kv_overrides;
224224
std::vector<llama_model_tensor_buft_override> tensor_buft_overrides;

examples/server/server-context.cpp

Lines changed: 20 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1143,21 +1143,28 @@ bool server_context::launch_slot_with_task(server_slot& slot, server_task& task)
11431143
std::sort(slot.ban_phrases.begin(), slot.ban_phrases.end(), [](const std::string& a, const std::string& b) {
11441144
return a.length() > b.length();
11451145
});
1146-
}
1147-
else if (params_base.ban_phrases.size()>0 && params_base.n_buffer == 0) {
1148-
slot.ban_phrases.clear();
1149-
for (const auto & val : params_base.ban_phrases) {
1150-
if (!val.empty()) {
1151-
std::string s = string_lower(val);
1152-
auto ban_tokens = common_tokenize(llama_get_model(ctx), s, false, true);
1153-
if (ban_tokens.size() > slot.n_buffer) {
1154-
slot.n_buffer = ban_tokens.size();
1146+
} else if (params_base.ban_phrases.size() > 0) {
1147+
if (params_base.n_buffer == 0) {
1148+
slot.ban_phrases.clear();
1149+
std::sort(params_base.ban_phrases.begin(), params_base.ban_phrases.end(), [](const std::string & a, const std::string & b) {
1150+
return a.length() > b.length();
1151+
});
1152+
for (auto & val : params_base.ban_phrases) {
1153+
if (!val.empty()) {
1154+
val = string_lower(val);
1155+
auto ban_tokens = common_tokenize(llama_get_model(ctx), val, false, true);
1156+
if (ban_tokens.size() > slot.n_buffer) {
1157+
slot.n_buffer = ban_tokens.size();
1158+
}
1159+
slot.ban_phrases.push_back(val);
11551160
}
1156-
slot.ban_phrases.push_back(s);
1157-
}
1161+
}
1162+
slot.n_buffer = slot.n_buffer + 3; // extra buffer in case
1163+
params_base.n_buffer = slot.n_buffer;
1164+
} else {
1165+
slot.ban_phrases = params_base.ban_phrases;
1166+
slot.n_buffer = params_base.n_buffer;
11581167
}
1159-
params_base.n_buffer = slot.n_buffer + 3;
1160-
slot.n_buffer = slot.n_buffer + 3; // extra buffer in case
11611168
}
11621169
slot.logit_bias = slot.sparams.logit_bias; // keep a copy to restore
11631170
slot.ban_phrases_bias = json_value(data, "banned_bias", params_base.ban_phrases_bias);

examples/server/server-context.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -84,7 +84,7 @@ struct server_slot {
8484
stop_type stop;
8585

8686
// For context rewind/ token buffer
87-
int32_t n_buffer = 0;
87+
size_t n_buffer = 0;
8888
int32_t rewind_count = 0;
8989
bool rewind_status = false;
9090
std::unordered_map<llama_token, float> logit_bias;

0 commit comments

Comments
 (0)