@@ -2,6 +2,7 @@ use rmcp::model::{CallToolRequestParam, CallToolResult};
22use serde_json:: { Map , Value } ;
33use tauri:: { AppHandle , Emitter , Runtime , State } ;
44use tokio:: time:: timeout;
5+ use tokio:: sync:: oneshot;
56
67use super :: {
78 constants:: { DEFAULT_MCP_CONFIG , MCP_TOOL_CALL_TIMEOUT } ,
@@ -179,6 +180,7 @@ pub async fn get_tools(state: State<'_, AppState>) -> Result<Vec<ToolWithServer>
179180/// * `state` - Application state containing MCP server connections
180181/// * `tool_name` - Name of the tool to call
181182/// * `arguments` - Optional map of argument names to values
183+ /// * `cancellation_token` - Optional token to allow cancellation from JS side
182184///
183185/// # Returns
184186/// * `Result<CallToolResult, String>` - Result of the tool call if successful, or error message if failed
@@ -187,13 +189,23 @@ pub async fn get_tools(state: State<'_, AppState>) -> Result<Vec<ToolWithServer>
187189/// 1. Locks the MCP servers mutex to access server connections
188190/// 2. Searches through all servers for one containing the named tool
189191/// 3. When found, calls the tool on that server with the provided arguments
190- /// 4. Returns error if no server has the requested tool
192+ /// 4. Supports cancellation via cancellation_token
193+ /// 5. Returns error if no server has the requested tool
191194#[ tauri:: command]
192195pub async fn call_tool (
193196 state : State < ' _ , AppState > ,
194197 tool_name : String ,
195198 arguments : Option < Map < String , Value > > ,
199+ cancellation_token : Option < String > ,
196200) -> Result < CallToolResult , String > {
201+ // Set up cancellation if token is provided
202+ let ( cancel_tx, cancel_rx) = oneshot:: channel :: < ( ) > ( ) ;
203+
204+ if let Some ( token) = & cancellation_token {
205+ let mut cancellations = state. tool_call_cancellations . lock ( ) . await ;
206+ cancellations. insert ( token. clone ( ) , cancel_tx) ;
207+ }
208+
197209 let servers = state. mcp_servers . lock ( ) . await ;
198210
199211 // Iterate through servers and find the first one that contains the tool
@@ -209,25 +221,77 @@ pub async fn call_tool(
209221
210222 println ! ( "Found tool {} in server" , tool_name) ;
211223
212- // Call the tool with timeout
224+ // Call the tool with timeout and cancellation support
213225 let tool_call = service. call_tool ( CallToolRequestParam {
214226 name : tool_name. clone ( ) . into ( ) ,
215227 arguments,
216228 } ) ;
217229
218- return match timeout ( MCP_TOOL_CALL_TIMEOUT , tool_call) . await {
219- Ok ( result) => result. map_err ( |e| e. to_string ( ) ) ,
220- Err ( _) => Err ( format ! (
221- "Tool call '{}' timed out after {} seconds" ,
222- tool_name,
223- MCP_TOOL_CALL_TIMEOUT . as_secs( )
224- ) ) ,
230+ // Race between timeout, tool call, and cancellation
231+ let result = if cancellation_token. is_some ( ) {
232+ tokio:: select! {
233+ result = timeout( MCP_TOOL_CALL_TIMEOUT , tool_call) => {
234+ match result {
235+ Ok ( call_result) => call_result. map_err( |e| e. to_string( ) ) ,
236+ Err ( _) => Err ( format!(
237+ "Tool call '{}' timed out after {} seconds" ,
238+ tool_name,
239+ MCP_TOOL_CALL_TIMEOUT . as_secs( )
240+ ) ) ,
241+ }
242+ }
243+ _ = cancel_rx => {
244+ Err ( format!( "Tool call '{}' was cancelled" , tool_name) )
245+ }
246+ }
247+ } else {
248+ match timeout ( MCP_TOOL_CALL_TIMEOUT , tool_call) . await {
249+ Ok ( call_result) => call_result. map_err ( |e| e. to_string ( ) ) ,
250+ Err ( _) => Err ( format ! (
251+ "Tool call '{}' timed out after {} seconds" ,
252+ tool_name,
253+ MCP_TOOL_CALL_TIMEOUT . as_secs( )
254+ ) ) ,
255+ }
225256 } ;
257+
258+ // Clean up cancellation token
259+ if let Some ( token) = & cancellation_token {
260+ let mut cancellations = state. tool_call_cancellations . lock ( ) . await ;
261+ cancellations. remove ( token) ;
262+ }
263+
264+ return result;
226265 }
227266
228267 Err ( format ! ( "Tool {} not found" , tool_name) )
229268}
230269
270+ /// Cancels a running tool call by its cancellation token
271+ ///
272+ /// # Arguments
273+ /// * `state` - Application state containing cancellation tokens
274+ /// * `cancellation_token` - Token identifying the tool call to cancel
275+ ///
276+ /// # Returns
277+ /// * `Result<(), String>` - Success if token found and cancelled, error otherwise
278+ #[ tauri:: command]
279+ pub async fn cancel_tool_call (
280+ state : State < ' _ , AppState > ,
281+ cancellation_token : String ,
282+ ) -> Result < ( ) , String > {
283+ let mut cancellations = state. tool_call_cancellations . lock ( ) . await ;
284+
285+ if let Some ( cancel_tx) = cancellations. remove ( & cancellation_token) {
286+ // Send cancellation signal - ignore if receiver is already dropped
287+ let _ = cancel_tx. send ( ( ) ) ;
288+ println ! ( "Tool call with token {} cancelled" , cancellation_token) ;
289+ Ok ( ( ) )
290+ } else {
291+ Err ( format ! ( "Cancellation token {} not found" , cancellation_token) )
292+ }
293+ }
294+
231295#[ tauri:: command]
232296pub async fn get_mcp_configs ( app : AppHandle ) -> Result < String , String > {
233297 let mut path = get_jan_data_folder_path ( app) ;
0 commit comments