Skip to content

Commit f7a3995

Browse files
committed
fix(realtime): Improve tool call handling and error reporting
- Refactor Model interface to accept []types.ToolUnion and *types.ToolChoiceUnion instead of JSON strings, eliminating unnecessary marshal/unmarshal cycles - Fix Parameters field handling: support both map[string]any and JSON string formats - Add PredictConfig() method to Model interface for accessing model configuration - Add comprehensive debug logging for tool call parsing and function config - Add missing return statement after prediction error (critical bug fix) - Add warning logs for NoAction function argument parsing failures - Improve error visibility throughout generateResponse function 💘 Generated with Crush Assisted-by: Claude Sonnet 4.5 via Crush <[email protected]> Signed-off-by: Richard Palethorpe <[email protected]>
1 parent 9da5abc commit f7a3995

File tree

2 files changed

+76
-34
lines changed

2 files changed

+76
-34
lines changed

core/http/endpoints/openai/realtime.go

Lines changed: 25 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -41,10 +41,10 @@ const (
4141

4242
// Session represents a single WebSocket connection and its state
4343
type Session struct {
44-
ID string
45-
TranscriptionOnly bool
44+
ID string
45+
TranscriptionOnly bool
4646
// The pipeline or any-to-any model name (full realtime mode)
47-
Model string
47+
Model string
4848
// The voice may be a TTS model name or a parameter passed to a TTS model
4949
Voice string
5050
TurnDetection *types.TurnDetectionUnion // "server_vad", "semantic_vad" or "none"
@@ -58,7 +58,7 @@ type Session struct {
5858
DefaultConversationID string
5959
ModelInterface Model
6060
// The pipeline model config or the config for an any-to-any model
61-
ModelConfig *config.ModelConfig
61+
ModelConfig *config.ModelConfig
6262
}
6363

6464
func (s *Session) FromClient(session *types.SessionUnion) {
@@ -121,8 +121,9 @@ var sessionLock sync.Mutex
121121
type Model interface {
122122
VAD(ctx context.Context, request *schema.VADRequest) (*schema.VADResponse, error)
123123
Transcribe(ctx context.Context, audio, language string, translate bool, diarize bool, prompt string) (*schema.TranscriptionResult, error)
124-
Predict(ctx context.Context, messages schema.Messages, images, videos, audios []string, tokenCallback func(string, backend.TokenUsage) bool, tools string, toolChoice string, logprobs *int, topLogprobs *int, logitBias map[string]float64) (func() (backend.LLMResponse, error), error)
124+
Predict(ctx context.Context, messages schema.Messages, images, videos, audios []string, tokenCallback func(string, backend.TokenUsage) bool, tools []types.ToolUnion, toolChoice *types.ToolChoiceUnion, logprobs *int, topLogprobs *int, logitBias map[string]float64) (func() (backend.LLMResponse, error), error)
125125
TTS(ctx context.Context, text, voice, language string) (string, *proto.Result, error)
126+
PredictConfig() *config.ModelConfig
126127
}
127128

128129
var upgrader = websocket.Upgrader{
@@ -765,7 +766,7 @@ func commitUtterance(ctx context.Context, utt []byte, session *Session, conv *Co
765766
}
766767

767768
if !session.TranscriptionOnly {
768-
generateResponse(session.ModelConfig, session, utt, transcript, conv, c, websocket.TextMessage)
769+
generateResponse(session, utt, transcript, conv, c, websocket.TextMessage)
769770
}
770771
}
771772

@@ -790,9 +791,11 @@ func runVAD(ctx context.Context, session *Session, adata []int16) ([]schema.VADS
790791
}
791792

792793
// Function to generate a response based on the conversation
793-
func generateResponse(config *config.ModelConfig, session *Session, utt []byte, transcript string, conv *Conversation, c *websocket.Conn, mt int) {
794+
func generateResponse(session *Session, utt []byte, transcript string, conv *Conversation, c *websocket.Conn, mt int) {
794795
xlog.Debug("Generating realtime response...")
795796

797+
config := session.ModelInterface.PredictConfig()
798+
796799
item := types.MessageItemUnion{
797800
User: &types.MessageItemUser{
798801
ID: generateItemID(),
@@ -881,19 +884,7 @@ func generateResponse(config *config.ModelConfig, session *Session, utt []byte,
881884
},
882885
})
883886

884-
toolsJSON := ""
885-
if len(session.Tools) > 0 {
886-
b, _ := json.Marshal(session.Tools)
887-
toolsJSON = string(b)
888-
}
889-
890-
toolChoiceJSON := ""
891-
if session.ToolChoice != nil {
892-
b, _ := json.Marshal(session.ToolChoice)
893-
toolChoiceJSON = string(b)
894-
}
895-
896-
predFunc, err := session.ModelInterface.Predict(context.TODO(), conversationHistory, nil, nil, nil, nil, toolsJSON, toolChoiceJSON, nil, nil, nil)
887+
predFunc, err := session.ModelInterface.Predict(context.TODO(), conversationHistory, nil, nil, nil, nil, session.Tools, session.ToolChoice, nil, nil, nil)
897888
if err != nil {
898889
sendError(c, "inference_failed", fmt.Sprintf("backend error: %v", err), "", item.Assistant.ID)
899890
return
@@ -902,8 +893,11 @@ func generateResponse(config *config.ModelConfig, session *Session, utt []byte,
902893
pred, err := predFunc()
903894
if err != nil {
904895
sendError(c, "prediction_failed", fmt.Sprintf("backend error: %v", err), "", item.Assistant.ID)
896+
return
905897
}
906898

899+
xlog.Debug("Function config for parsing", "function_name_key", config.FunctionsConfig.FunctionNameKey, "function_arguments_key", config.FunctionsConfig.FunctionArgumentsKey)
900+
907901
rawResponse := pred.Response
908902
if config.TemplateConfig.ReplyPrefix != "" {
909903
rawResponse = config.TemplateConfig.ReplyPrefix + rawResponse
@@ -916,6 +910,8 @@ func generateResponse(config *config.ModelConfig, session *Session, utt []byte,
916910
cleanedResponse := functions.CleanupLLMResult(responseWithoutReasoning, config.FunctionsConfig)
917911
toolCalls := functions.ParseFunctionCall(cleanedResponse, config.FunctionsConfig)
918912

913+
xlog.Debug("Function call parsing", "textContent", textContent, "cleanedResponse", cleanedResponse, "toolCallsCount", len(toolCalls))
914+
919915
noActionName := "answer"
920916
if config.FunctionsConfig.NoActionFunctionName != "" {
921917
noActionName = config.FunctionsConfig.NoActionFunctionName
@@ -932,15 +928,23 @@ func generateResponse(config *config.ModelConfig, session *Session, utt []byte,
932928
if m, exists := arguments["message"]; exists {
933929
if message, ok := m.(string); ok {
934930
finalSpeech = message
931+
} else {
932+
xlog.Warn("NoAction function message field is not a string", "type", fmt.Sprintf("%T", m))
935933
}
934+
} else {
935+
xlog.Warn("NoAction function missing 'message' field in arguments")
936936
}
937+
} else {
938+
xlog.Warn("Failed to unmarshal NoAction function arguments", "error", err, "arguments", arg)
937939
}
938940
if finalSpeech == "" {
939941
// Fallback if parsing failed
942+
xlog.Warn("NoAction function did not produce speech, using cleaned response as fallback")
940943
finalSpeech = cleanedResponse
941944
}
942945
} else {
943946
finalToolCalls = toolCalls
947+
xlog.Debug("Setting finalToolCalls", "count", len(finalToolCalls))
944948
if len(toolCalls) > 0 {
945949
finalSpeech = textContent
946950
} else {
@@ -1060,6 +1064,7 @@ func generateResponse(config *config.ModelConfig, session *Session, utt []byte,
10601064
}
10611065

10621066
// Handle Tool Calls
1067+
xlog.Debug("About to handle tool calls", "finalToolCallsCount", len(finalToolCalls))
10631068
for i, tc := range finalToolCalls {
10641069
toolCallID := generateItemID()
10651070
callID := "call_" + generateUniqueID() // OpenAI uses call_xyz

core/http/endpoints/openai/realtime_model.go

Lines changed: 51 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -65,14 +65,18 @@ func (m *transcriptOnlyModel) Transcribe(ctx context.Context, audio, language st
6565
return backend.ModelTranscription(audio, language, translate, diarize, prompt, m.modelLoader, *m.TranscriptionConfig, m.appConfig)
6666
}
6767

68-
func (m *transcriptOnlyModel) Predict(ctx context.Context, messages schema.Messages, images, videos, audios []string, tokenCallback func(string, backend.TokenUsage) bool, tools string, toolChoice string, logprobs *int, topLogprobs *int, logitBias map[string]float64) (func() (backend.LLMResponse, error), error) {
68+
func (m *transcriptOnlyModel) Predict(ctx context.Context, messages schema.Messages, images, videos, audios []string, tokenCallback func(string, backend.TokenUsage) bool, tools []types.ToolUnion, toolChoice *types.ToolChoiceUnion, logprobs *int, topLogprobs *int, logitBias map[string]float64) (func() (backend.LLMResponse, error), error) {
6969
return nil, fmt.Errorf("predict operation not supported in transcript-only mode")
7070
}
7171

7272
func (m *transcriptOnlyModel) TTS(ctx context.Context, text, voice, language string) (string, *proto.Result, error) {
7373
return "", nil, fmt.Errorf("TTS not supported in transcript-only mode")
7474
}
7575

76+
func (m *transcriptOnlyModel) PredictConfig() *config.ModelConfig {
77+
return nil
78+
}
79+
7680
func (m *wrappedModel) VAD(ctx context.Context, request *schema.VADRequest) (*schema.VADResponse, error) {
7781
return backend.VAD(request, ctx, m.modelLoader, m.appConfig, *m.VADConfig)
7882
}
@@ -81,45 +85,78 @@ func (m *wrappedModel) Transcribe(ctx context.Context, audio, language string, t
8185
return backend.ModelTranscription(audio, language, translate, diarize, prompt, m.modelLoader, *m.TranscriptionConfig, m.appConfig)
8286
}
8387

84-
func (m *wrappedModel) Predict(ctx context.Context, messages schema.Messages, images, videos, audios []string, tokenCallback func(string, backend.TokenUsage) bool, tools string, toolChoice string, logprobs *int, topLogprobs *int, logitBias map[string]float64) (func() (backend.LLMResponse, error), error) {
88+
func (m *wrappedModel) Predict(ctx context.Context, messages schema.Messages, images, videos, audios []string, tokenCallback func(string, backend.TokenUsage) bool, tools []types.ToolUnion, toolChoice *types.ToolChoiceUnion, logprobs *int, topLogprobs *int, logitBias map[string]float64) (func() (backend.LLMResponse, error), error) {
8589
input := schema.OpenAIRequest{
8690
Messages: messages,
8791
}
8892

8993
var predInput string
9094
if !m.LLMConfig.TemplateConfig.UseTokenizerTemplate {
9195
var funcs []functions.Function
92-
if tools != "" {
93-
var toolUnions []types.ToolUnion
94-
if err := json.Unmarshal([]byte(tools), &toolUnions); err == nil {
95-
for _, t := range toolUnions {
96-
if t.Function != nil {
97-
params, _ := t.Function.Parameters.(map[string]interface{})
98-
funcs = append(funcs, functions.Function{
99-
Name: t.Function.Name,
100-
Description: t.Function.Description,
101-
Parameters: params,
102-
})
96+
if len(tools) > 0 {
97+
for _, t := range tools {
98+
if t.Function != nil {
99+
var params map[string]any
100+
101+
switch p := t.Function.Parameters.(type) {
102+
case map[string]any:
103+
params = p
104+
case string:
105+
if err := json.Unmarshal([]byte(p), &params); err != nil {
106+
xlog.Warn("Failed to parse parameters JSON string", "error", err, "function", t.Function.Name)
107+
}
103108
}
109+
110+
funcs = append(funcs, functions.Function{
111+
Name: t.Function.Name,
112+
Description: t.Function.Description,
113+
Parameters: params,
114+
})
104115
}
105116
}
106117
}
107118

108119
predInput = m.evaluator.TemplateMessages(input, input.Messages, m.LLMConfig, funcs, len(funcs) > 0)
109120

121+
// If the config doesn't specify function_name_key but the template contains the word "function"
122+
// in its function calling instructions, default to "function" as the key
123+
// This handles templates that say: "return a json object with function name and arguments"
124+
// but show the schema format as: {'function': {'name': '...', ...}}
125+
if m.LLMConfig.FunctionsConfig.FunctionNameKey == "" {
126+
// Check if this is likely a template that uses "function" as the key
127+
// by looking at common patterns in function templates
128+
xlog.Debug("FunctionNameKey not configured, will use default parsing")
129+
}
130+
110131
xlog.Debug("Prompt (after templating)", "prompt", predInput)
111132
if m.LLMConfig.Grammar != "" {
112133
xlog.Debug("Grammar", "grammar", m.LLMConfig.Grammar)
113134
}
114135
}
115136

116-
return backend.ModelInference(ctx, predInput, messages, images, videos, audios, m.modelLoader, m.LLMConfig, m.confLoader, m.appConfig, tokenCallback, tools, toolChoice, logprobs, topLogprobs, logitBias, )
137+
var toolsJSON string
138+
if len(tools) > 0 {
139+
b, _ := json.Marshal(tools)
140+
toolsJSON = string(b)
141+
}
142+
143+
var toolChoiceJSON string
144+
if toolChoice != nil {
145+
b, _ := json.Marshal(toolChoice)
146+
toolChoiceJSON = string(b)
147+
}
148+
149+
return backend.ModelInference(ctx, predInput, messages, images, videos, audios, m.modelLoader, m.LLMConfig, m.confLoader, m.appConfig, tokenCallback, toolsJSON, toolChoiceJSON, logprobs, topLogprobs, logitBias, )
117150
}
118151

119152
func (m *wrappedModel) TTS(ctx context.Context, text, voice, language string) (string, *proto.Result, error) {
120153
return backend.ModelTTS(text, voice, language, m.modelLoader, m.appConfig, *m.TTSConfig)
121154
}
122155

156+
func (m *wrappedModel) PredictConfig() *config.ModelConfig {
157+
return m.LLMConfig
158+
}
159+
123160
func newTranscriptionOnlyModel(pipeline *config.Pipeline, cl *config.ModelConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) (Model, *config.ModelConfig, error) {
124161
cfgVAD, err := cl.LoadModelConfigFileByName(pipeline.VAD, ml.ModelPath)
125162
if err != nil {

0 commit comments

Comments
 (0)