Skip to content

Commit 39e0ff5

Browse files
authored
Improve automatic tool call (#1460)
* Improved auto tool call * Add logging
1 parent a5c4eda commit 39e0ff5

File tree

1 file changed

+25
-16
lines changed

1 file changed

+25
-16
lines changed

mistralrs-core/src/engine/search_request.rs

Lines changed: 25 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -356,14 +356,31 @@ async fn do_custom_tool(
356356
}
357357

358358
let result = if let Some(cb) = this.tool_callbacks.get(&tool_calls.function.name) {
359-
cb(&tool_calls.function).unwrap_or_else(|e| format!("ERROR: {e}"))
359+
tracing::info!("Called tool `{}`.", tool_calls.function.name);
360+
cb(&tool_calls.function).unwrap_or_else(|e| {
361+
tracing::error!(
362+
"Error when calling tool `{}`: {e}",
363+
tool_calls.function.name
364+
);
365+
format!("ERROR: {e}")
366+
})
360367
} else if let Some(callback_with_tool) = this
361368
.tool_callbacks_with_tools
362369
.get(&tool_calls.function.name)
363370
{
364-
(callback_with_tool.callback)(&tool_calls.function)
365-
.unwrap_or_else(|e| format!("ERROR: {e}"))
371+
tracing::info!("Called tool `{}`.", tool_calls.function.name);
372+
(callback_with_tool.callback)(&tool_calls.function).unwrap_or_else(|e| {
373+
tracing::error!(
374+
"Error when calling tool `{}`: {e}",
375+
tool_calls.function.name
376+
);
377+
format!("ERROR: {e}")
378+
})
366379
} else {
380+
tracing::error!(
381+
"Attempted to call tool `{}`, but it doesn't exist.",
382+
tool_calls.function.name
383+
);
367384
format!("ERROR: no tool callback for {}", tool_calls.function.name)
368385
};
369386

@@ -432,8 +449,6 @@ pub(super) async fn search_request(this: Arc<Engine>, request: NormalRequest) {
432449
// `current` is what we actually dispatch each loop.
433450
// The very first time that is the hidden probe.
434451
let mut current = probe;
435-
// Forward results to the user after the first loop.
436-
let mut forward_to_user = false;
437452

438453
loop {
439454
// Each dispatch gets its own one-shot channel so we can peek at
@@ -493,14 +508,6 @@ pub(super) async fn search_request(this: Arc<Engine>, request: NormalRequest) {
493508
}
494509
};
495510

496-
// Forward to the caller once the probe is out of the way.
497-
if forward_to_user {
498-
user_sender
499-
.send(Response::Done(done.clone()))
500-
.await
501-
.unwrap();
502-
}
503-
504511
// Did the assistant ask to run a tool?
505512
let tc_opt = match &done.choices[0].message.tool_calls {
506513
Some(calls) if calls.len() == 1 => Some(&calls[0]),
@@ -509,7 +516,11 @@ pub(super) async fn search_request(this: Arc<Engine>, request: NormalRequest) {
509516

510517
// No tool call? We are finished.
511518
if tc_opt.is_none() {
512-
break;
519+
user_sender
520+
.send(Response::Done(done.clone()))
521+
.await
522+
.unwrap();
523+
return;
513524
}
514525

515526
// Tool requested → build the next turn.
@@ -530,7 +541,6 @@ pub(super) async fn search_request(this: Arc<Engine>, request: NormalRequest) {
530541
visible_req = next_visible.clone();
531542
visible_req.response = user_sender.clone();
532543
current = visible_req.clone();
533-
forward_to_user = true;
534544
}
535545
// ------------------------- STREAMING -------------------------
536546
else {
@@ -636,7 +646,6 @@ pub(super) async fn search_request(this: Arc<Engine>, request: NormalRequest) {
636646
visible_req = next_visible.clone();
637647
visible_req.response = user_sender.clone();
638648
current = visible_req.clone();
639-
forward_to_user = true;
640649
}
641650
}
642651
});

0 commit comments

Comments
 (0)