Skip to content

Commit 91f05b8

Browse files
committed
feat: add tool call cancellation
1 parent e761c43 commit 91f05b8

File tree

8 files changed

+148
-20
lines changed

8 files changed

+148
-20
lines changed

src-tauri/src/core/mcp/commands.rs

Lines changed: 73 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@ use rmcp::model::{CallToolRequestParam, CallToolResult};
22
use serde_json::{Map, Value};
33
use tauri::{AppHandle, Emitter, Runtime, State};
44
use tokio::time::timeout;
5+
use tokio::sync::oneshot;
56

67
use super::{
78
constants::{DEFAULT_MCP_CONFIG, MCP_TOOL_CALL_TIMEOUT},
@@ -179,6 +180,7 @@ pub async fn get_tools(state: State<'_, AppState>) -> Result<Vec<ToolWithServer>
179180
/// * `state` - Application state containing MCP server connections
180181
/// * `tool_name` - Name of the tool to call
181182
/// * `arguments` - Optional map of argument names to values
183+
/// * `cancellation_token` - Optional token to allow cancellation from JS side
182184
///
183185
/// # Returns
184186
/// * `Result<CallToolResult, String>` - Result of the tool call if successful, or error message if failed
@@ -187,13 +189,23 @@ pub async fn get_tools(state: State<'_, AppState>) -> Result<Vec<ToolWithServer>
187189
/// 1. Locks the MCP servers mutex to access server connections
188190
/// 2. Searches through all servers for one containing the named tool
189191
/// 3. When found, calls the tool on that server with the provided arguments
190-
/// 4. Returns error if no server has the requested tool
192+
/// 4. Supports cancellation via cancellation_token
193+
/// 5. Returns error if no server has the requested tool
191194
#[tauri::command]
192195
pub async fn call_tool(
193196
state: State<'_, AppState>,
194197
tool_name: String,
195198
arguments: Option<Map<String, Value>>,
199+
cancellation_token: Option<String>,
196200
) -> Result<CallToolResult, String> {
201+
// Set up cancellation if token is provided
202+
let (cancel_tx, cancel_rx) = oneshot::channel::<()>();
203+
204+
if let Some(token) = &cancellation_token {
205+
let mut cancellations = state.tool_call_cancellations.lock().await;
206+
cancellations.insert(token.clone(), cancel_tx);
207+
}
208+
197209
let servers = state.mcp_servers.lock().await;
198210

199211
// Iterate through servers and find the first one that contains the tool
@@ -209,25 +221,77 @@ pub async fn call_tool(
209221

210222
println!("Found tool {} in server", tool_name);
211223

212-
// Call the tool with timeout
224+
// Call the tool with timeout and cancellation support
213225
let tool_call = service.call_tool(CallToolRequestParam {
214226
name: tool_name.clone().into(),
215227
arguments,
216228
});
217229

218-
return match timeout(MCP_TOOL_CALL_TIMEOUT, tool_call).await {
219-
Ok(result) => result.map_err(|e| e.to_string()),
220-
Err(_) => Err(format!(
221-
"Tool call '{}' timed out after {} seconds",
222-
tool_name,
223-
MCP_TOOL_CALL_TIMEOUT.as_secs()
224-
)),
230+
// Race between timeout, tool call, and cancellation
231+
let result = if cancellation_token.is_some() {
232+
tokio::select! {
233+
result = timeout(MCP_TOOL_CALL_TIMEOUT, tool_call) => {
234+
match result {
235+
Ok(call_result) => call_result.map_err(|e| e.to_string()),
236+
Err(_) => Err(format!(
237+
"Tool call '{}' timed out after {} seconds",
238+
tool_name,
239+
MCP_TOOL_CALL_TIMEOUT.as_secs()
240+
)),
241+
}
242+
}
243+
_ = cancel_rx => {
244+
Err(format!("Tool call '{}' was cancelled", tool_name))
245+
}
246+
}
247+
} else {
248+
match timeout(MCP_TOOL_CALL_TIMEOUT, tool_call).await {
249+
Ok(call_result) => call_result.map_err(|e| e.to_string()),
250+
Err(_) => Err(format!(
251+
"Tool call '{}' timed out after {} seconds",
252+
tool_name,
253+
MCP_TOOL_CALL_TIMEOUT.as_secs()
254+
)),
255+
}
225256
};
257+
258+
// Clean up cancellation token
259+
if let Some(token) = &cancellation_token {
260+
let mut cancellations = state.tool_call_cancellations.lock().await;
261+
cancellations.remove(token);
262+
}
263+
264+
return result;
226265
}
227266

228267
Err(format!("Tool {} not found", tool_name))
229268
}
230269

270+
/// Cancels a running tool call by its cancellation token
271+
///
272+
/// # Arguments
273+
/// * `state` - Application state containing cancellation tokens
274+
/// * `cancellation_token` - Token identifying the tool call to cancel
275+
///
276+
/// # Returns
277+
/// * `Result<(), String>` - Success if token found and cancelled, error otherwise
278+
#[tauri::command]
279+
pub async fn cancel_tool_call(
280+
state: State<'_, AppState>,
281+
cancellation_token: String,
282+
) -> Result<(), String> {
283+
let mut cancellations = state.tool_call_cancellations.lock().await;
284+
285+
if let Some(cancel_tx) = cancellations.remove(&cancellation_token) {
286+
// Send cancellation signal - ignore if receiver is already dropped
287+
let _ = cancel_tx.send(());
288+
println!("Tool call with token {} cancelled", cancellation_token);
289+
Ok(())
290+
} else {
291+
Err(format!("Cancellation token {} not found", cancellation_token))
292+
}
293+
}
294+
231295
#[tauri::command]
232296
pub async fn get_mcp_configs(app: AppHandle) -> Result<String, String> {
233297
let mut path = get_jan_data_folder_path(app);

src-tauri/src/core/state.rs

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@ use rmcp::{
66
service::RunningService,
77
RoleClient, ServiceError,
88
};
9-
use tokio::sync::Mutex;
9+
use tokio::sync::{Mutex, oneshot};
1010
use tokio::task::JoinHandle;
1111

1212
/// Server handle type for managing the proxy server lifecycle
@@ -27,6 +27,7 @@ pub struct AppState {
2727
pub mcp_active_servers: Arc<Mutex<HashMap<String, serde_json::Value>>>,
2828
pub mcp_successfully_connected: Arc<Mutex<HashMap<String, bool>>>,
2929
pub server_handle: Arc<Mutex<Option<ServerHandle>>>,
30+
pub tool_call_cancellations: Arc<Mutex<HashMap<String, oneshot::Sender<()>>>>,
3031
}
3132

3233
impl RunningServiceEnum {

src-tauri/src/lib.rs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -74,6 +74,7 @@ pub fn run() {
7474
// MCP commands
7575
core::mcp::commands::get_tools,
7676
core::mcp::commands::call_tool,
77+
core::mcp::commands::cancel_tool_call,
7778
core::mcp::commands::restart_mcp_servers,
7879
core::mcp::commands::get_connected_servers,
7980
core::mcp::commands::save_mcp_configs,
@@ -105,6 +106,7 @@ pub fn run() {
105106
mcp_active_servers: Arc::new(Mutex::new(HashMap::new())),
106107
mcp_successfully_connected: Arc::new(Mutex::new(HashMap::new())),
107108
server_handle: Arc::new(Mutex::new(None)),
109+
tool_call_cancellations: Arc::new(Mutex::new(HashMap::new())),
108110
})
109111
.setup(|app| {
110112
app.handle().plugin(

web-app/src/containers/ChatInput.tsx

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -46,8 +46,13 @@ const ChatInput = ({ model, className, initialMessage }: ChatInputProps) => {
4646
const textareaRef = useRef<HTMLTextAreaElement>(null)
4747
const [isFocused, setIsFocused] = useState(false)
4848
const [rows, setRows] = useState(1)
49-
const { streamingContent, abortControllers, loadingModel, tools } =
50-
useAppState()
49+
const {
50+
streamingContent,
51+
abortControllers,
52+
loadingModel,
53+
tools,
54+
cancelToolCall,
55+
} = useAppState()
5156
const { prompt, setPrompt } = usePrompt()
5257
const { currentThreadId } = useThreads()
5358
const { t } = useTranslation()
@@ -161,8 +166,9 @@ const ChatInput = ({ model, className, initialMessage }: ChatInputProps) => {
161166
const stopStreaming = useCallback(
162167
(threadId: string) => {
163168
abortControllers[threadId]?.abort()
169+
cancelToolCall?.()
164170
},
165-
[abortControllers]
171+
[abortControllers, cancelToolCall]
166172
)
167173

168174
const fileInputRef = useRef<HTMLInputElement>(null)

web-app/src/hooks/useAppState.ts

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@ type AppState = {
1313
tokenSpeed?: TokenSpeed
1414
currentToolCall?: ChatCompletionMessageToolCall
1515
showOutOfContextDialog?: boolean
16+
cancelToolCall?: () => void
1617
setServerStatus: (value: 'running' | 'stopped' | 'pending') => void
1718
updateStreamingContent: (content: ThreadMessage | undefined) => void
1819
updateCurrentToolCall: (
@@ -24,6 +25,7 @@ type AppState = {
2425
updateTokenSpeed: (message: ThreadMessage, increment?: number) => void
2526
resetTokenSpeed: () => void
2627
setOutOfContextDialog: (show: boolean) => void
28+
setCancelToolCall: (cancel: (() => void) | undefined) => void
2729
}
2830

2931
export const useAppState = create<AppState>()((set) => ({
@@ -34,6 +36,7 @@ export const useAppState = create<AppState>()((set) => ({
3436
abortControllers: {},
3537
tokenSpeed: undefined,
3638
currentToolCall: undefined,
39+
cancelToolCall: undefined,
3740
updateStreamingContent: (content: ThreadMessage | undefined) => {
3841
const assistants = useAssistant.getState().assistants
3942
const currentAssistant = useAssistant.getState().currentAssistant
@@ -112,4 +115,9 @@ export const useAppState = create<AppState>()((set) => ({
112115
showOutOfContextDialog: show,
113116
}))
114117
},
118+
setCancelToolCall: (cancel) => {
119+
set(() => ({
120+
cancelToolCall: cancel,
121+
}))
122+
},
115123
}))

web-app/src/lib/completion.ts

Lines changed: 12 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -31,8 +31,9 @@ import { ulid } from 'ulidx'
3131
import { MCPTool } from '@/types/completion'
3232
import { CompletionMessagesBuilder } from './messages'
3333
import { ChatCompletionMessageToolCall } from 'openai/resources'
34-
import { callTool } from '@/services/mcp'
34+
import { callToolWithCancellation } from '@/services/mcp'
3535
import { ExtensionManager } from './extension'
36+
import { useAppState } from '@/hooks/useAppState'
3637

3738
export type ChatCompletionResponse =
3839
| chatCompletion
@@ -381,13 +382,17 @@ export const postMessageProcessing = async (
381382
)
382383
: true)
383384

385+
const { promise, cancel } = callToolWithCancellation({
386+
toolName: toolCall.function.name,
387+
arguments: toolCall.function.arguments.length
388+
? JSON.parse(toolCall.function.arguments)
389+
: {},
390+
})
391+
392+
useAppState.getState().setCancelToolCall(cancel)
393+
384394
let result = approved
385-
? await callTool({
386-
toolName: toolCall.function.name,
387-
arguments: toolCall.function.arguments.length
388-
? JSON.parse(toolCall.function.arguments)
389-
: {},
390-
}).catch((e) => {
395+
? await promise.catch((e) => {
391396
console.error('Tool call failed:', e)
392397
return {
393398
content: [

web-app/src/lib/service.ts

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@ export const AppRoutes = [
55
'installExtensions',
66
'getTools',
77
'callTool',
8+
'cancelToolCall',
89
'listThreads',
910
'createThread',
1011
'modifyThread',

web-app/src/services/mcp.ts

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -56,3 +56,44 @@ export const callTool = (args: {
5656
}): Promise<{ error: string; content: { text: string }[] }> => {
5757
return window.core?.api?.callTool(args)
5858
}
59+
60+
/**
61+
* @description Enhanced function to invoke an MCP tool with cancellation support
62+
* @param args - Tool call arguments
63+
* @param cancellationToken - Optional cancellation token
64+
* @returns Promise with tool result and cancellation function
65+
*/
66+
export const callToolWithCancellation = (args: {
67+
toolName: string
68+
arguments: object
69+
cancellationToken?: string
70+
}): {
71+
promise: Promise<{ error: string; content: { text: string }[] }>
72+
cancel: () => Promise<void>
73+
token: string
74+
} => {
75+
// Generate a unique cancellation token if not provided
76+
const token = args.cancellationToken ?? `tool_call_${Date.now()}_${Math.random().toString(36).substr(2, 9)}`
77+
78+
// Create the tool call promise with cancellation token
79+
const promise = window.core?.api?.callTool({
80+
...args,
81+
cancellationToken: token
82+
})
83+
84+
// Create cancel function
85+
const cancel = async () => {
86+
await window.core?.api?.cancelToolCall({ cancellationToken: token })
87+
}
88+
89+
return { promise, cancel, token }
90+
}
91+
92+
/**
93+
* @description This function cancels a running tool call
94+
* @param cancellationToken - The token identifying the tool call to cancel
95+
* @returns
96+
*/
97+
export const cancelToolCall = (cancellationToken: string): Promise<void> => {
98+
return window.core?.api?.cancelToolCall({ cancellationToken })
99+
}

0 commit comments

Comments
 (0)