Skip to content

Commit 1d6afbd

Browse files
authored
feat(llama.cpp): Add support to grammar triggers (#4733)
Signed-off-by: Ettore Di Giacinto <[email protected]>
1 parent d79f02e commit 1d6afbd

File tree

4 files changed

+46
-1
lines changed

4 files changed

+46
-1
lines changed

backend/backend.proto

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -163,6 +163,11 @@ message Reply {
163163
double timing_token_generation = 5;
164164
}
165165

166+
// GrammarTrigger describes a word that, when generated, switches on
// grammar-constrained decoding (the backend sets grammar_lazy = true
// whenever at least one trigger is configured).
message GrammarTrigger {
167+
// The literal string whose appearance activates the grammar.
string word = 1;
168+
// Presumably restricts the trigger to the start of generation — TODO confirm
// against the llama.cpp common_grammar_trigger semantics.
bool at_start = 2;
169+
}
170+
166171
message ModelOptions {
167172
string Model = 1;
168173
int32 ContextSize = 2;
@@ -247,6 +252,8 @@ message ModelOptions {
247252

248253
string CacheTypeKey = 63;
249254
string CacheTypeValue = 64;
255+
256+
repeated GrammarTrigger GrammarTriggers = 65;
250257
}
251258

252259
message Result {

backend/cpp/llama/grpc-server.cpp

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -468,6 +468,9 @@ struct llama_server_context
468468
bool add_bos_token = true;
469469
bool has_eos_token = true;
470470

471+
bool grammar_lazy = false;
472+
std::vector<common_grammar_trigger> grammar_trigger_words;
473+
471474
int32_t n_ctx; // total context for all clients / slots
472475

473476
// system prompt
@@ -706,6 +709,8 @@ struct llama_server_context
706709
slot->sparams.grammar = json_value(data, "grammar", default_sparams.grammar);
707710
slot->sparams.n_probs = json_value(data, "n_probs", default_sparams.n_probs);
708711
slot->sparams.min_keep = json_value(data, "min_keep", default_sparams.min_keep);
712+
slot->sparams.grammar_trigger_words = grammar_trigger_words;
713+
slot->sparams.grammar_lazy = grammar_lazy;
709714

710715
if (slot->n_predict > 0 && slot->params.n_predict > slot->n_predict) {
711716
// Might be better to reject the request with a 400 ?
@@ -2374,6 +2379,21 @@ static void params_parse(const backend::ModelOptions* request,
23742379
if ( request->ropefreqscale() != 0.0f ) {
23752380
params.rope_freq_scale = request->ropefreqscale();
23762381
}
2382+
2383+
if (request->grammartriggers_size() > 0) {
2384+
LOG_INFO("configuring grammar triggers", {});
2385+
llama.grammar_lazy = true;
2386+
for (int i = 0; i < request->grammartriggers_size(); i++) {
2387+
common_grammar_trigger trigger;
2388+
trigger.word = request->grammartriggers(i).word();
2389+
trigger.at_start = request->grammartriggers(i).at_start();
2390+
llama.grammar_trigger_words.push_back(trigger);
2391+
LOG_INFO("grammar trigger", {
2392+
{ "word", trigger.word },
2393+
{ "at_start", trigger.at_start }
2394+
});
2395+
}
2396+
}
23772397
}
23782398

23792399

core/backend/options.go

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -118,9 +118,19 @@ func grpcModelOpts(c config.BackendConfig) *pb.ModelOptions {
118118
nGPULayers = *c.NGPULayers
119119
}
120120

121+
triggers := make([]*pb.GrammarTrigger, 0)
122+
for _, t := range c.FunctionsConfig.GrammarConfig.GrammarTriggers {
123+
triggers = append(triggers, &pb.GrammarTrigger{
124+
Word: t.Word,
125+
AtStart: t.AtStart,
126+
})
127+
128+
}
129+
121130
return &pb.ModelOptions{
122131
CUDA: c.CUDA || c.Diffusers.CUDA,
123132
SchedulerType: c.Diffusers.SchedulerType,
133+
GrammarTriggers: triggers,
124134
PipelineType: c.Diffusers.PipelineType,
125135
CFGScale: c.CFGScale,
126136
LoraAdapter: c.LoraAdapter,

pkg/functions/parse.go

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,14 @@ type GrammarConfig struct {
4747
// SchemaType can be configured to use a specific schema type to force the grammar
4848
// available : json, llama3.1
4949
SchemaType string `yaml:"schema_type"`
50+
51+
GrammarTriggers []GrammarTrigger `yaml:"triggers"`
52+
}
53+
54+
// GrammarTrigger is the YAML-configurable counterpart of pb.GrammarTrigger:
// a word that enables lazy grammar-constrained decoding in the llama.cpp backend.
type GrammarTrigger struct {
55+
// Word is the string that triggers the grammar
// (the stale name "Trigger" in the old comment referred to this field).
56+
Word string `yaml:"word"`
57+
// AtStart — presumably limits the trigger to the start of the output;
// NOTE(review): semantics not visible here, verify against the backend.
AtStart bool `yaml:"at_start"`
5058
}
5159

5260
// FunctionsConfig is the configuration for the tool/function call.
@@ -361,6 +369,6 @@ func ParseFunctionCallArgs(functionArguments string, functionConfig FunctionsCon
361369
}
362370

363371
jsonBytes, _ := json.Marshal(args)
364-
372+
365373
return string(jsonBytes)
366374
}

0 commit comments

Comments
 (0)